- Notifications
You must be signed in to change notification settings - Fork 71
/
Copy pathphi3.py
218 lines (195 loc) · 6.81 KB
/
phi3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Example of building a Phi-3.5 model up to 4K tokens, not to 128K tokens."""
from functools import partial
import math
from typing import Tuple

import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.utilities import model_builder
import ai_edge_torch.generative.utilities.loader as loading_utils
import torch
# Maps each logical weight slot of the Edge Generative API model to the tensor
# name used in the Phi-3.5 checkpoint. Per-layer names contain a "{}"
# placeholder, presumably filled with the layer index by the loader.
TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
    # Model-level tensors (no layer placeholder).
    embedding="model.embed_tokens",
    final_norm="model.norm",
    lm_head="lm_head",
    # Per-layer attention; the checkpoint stores Q/K/V in one qkv_proj tensor.
    attn_fused_qkv_proj="model.layers.{}.self_attn.qkv_proj",
    attn_output_proj="model.layers.{}.self_attn.o_proj",
    # Per-layer feed-forward; the checkpoint's gate_up_proj tensor is loaded
    # into the ff_up_proj slot.
    ff_up_proj="model.layers.{}.mlp.gate_up_proj",
    ff_down_proj="model.layers.{}.mlp.down_proj",
    # Per-layer norms.
    pre_attn_norm="model.layers.{}.input_layernorm",
    post_attn_norm="model.layers.{}.post_attention_layernorm",
)
# Context-extension ratio from the Phi-3.5 config:
# max_position_embeddings (131072, i.e. 128K) divided by
# original_max_position_embeddings (4096, i.e. 4K).
ROPE_SCALE_FACTOR = 131072 // 4096
# RoPE short factors from the Phi-3.5 config. According to the LongRoPE paper
# and its code at https://github.com/microsoft/LongRoPE, these values were
# found by a search starting at min=1.0 with step=0.01 to minimize the error
# on a sample dataset. There are 48 entries, one per rotary frequency; this
# presumably matches head_dim=96 (96 / 2 = 48) used in get_model_config.
ROPE_SHORT_FACTOR = [
    1.0, 1.0199999809265137, 1.0299999713897705,
    1.0299999713897705, 1.0499999523162842, 1.0499999523162842,
    1.0499999523162842, 1.0499999523162842, 1.0499999523162842,
    1.0699999332427979, 1.0999999046325684, 1.1099998950958252,
    1.1599998474121094, 1.1599998474121094, 1.1699998378753662,
    1.2899998426437378, 1.339999794960022, 1.679999828338623,
    1.7899998426437378, 1.8199998140335083, 1.8499997854232788,
    1.8799997568130493, 1.9099997282028198, 1.9399996995925903,
    1.9899996519088745, 2.0199997425079346, 2.0199997425079346,
    2.0199997425079346, 2.0199997425079346, 2.0199997425079346,
    2.0199997425079346, 2.0299997329711914, 2.0299997329711914,
    2.0299997329711914, 2.0299997329711914, 2.0299997329711914,
    2.0299997329711914, 2.0299997329711914, 2.0299997329711914,
    2.0299997329711914, 2.0799996852874756, 2.0899996757507324,
    2.189999580383301, 2.2199995517730713, 2.5899994373321533,
    2.729999542236328, 2.749999523162842, 2.8399994373321533,
]
def _build_phi3_rope(
    input_pos: torch.Tensor,
    n_elem: int,
    base: int,
    condense_ratio: int,
    dtype: torch.dtype,
    device: torch.device,
    theta_factors: torch.Tensor,
    scale: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
  """Computes Rotary Positional Embeddings for the Phi-3.5 model.

  This is a modified version of attn_utils.build_rope_cache with additional
  arguments for Phi-3.5. Each rotary frequency is divided by a per-frequency
  factor, and the resulting cosine/sine tables are multiplied by a global
  scale. The tables are precomputed for quick lookup during inference.

  Args:
    input_pos (torch.Tensor): 1-D tensor of sequence positions.
    n_elem (int): Number of rotated dimensions per head. One frequency is
      produced for every two dimensions (`torch.arange(0, n_elem, 2)`).
    base (int): RoPE base value.
    condense_ratio (int): Ratio by which positions are divided before the
      angles are computed.
    dtype (torch.dtype): Data type of the returned tensors.
    device (torch.device): Device of the returned tensors.
    theta_factors (torch.Tensor): 1-D tensor with one factor per frequency
      (length `len(range(0, n_elem, 2))`). Each frequency is divided by its
      factor.
    scale (float): Factor applied to the cosine and sine values.

  Returns:
    Tuple[torch.Tensor, torch.Tensor]: The cosine and sine tables, each of
    shape `(len(input_pos), len(range(0, n_elem, 2)))`.
  """
  # Build the frequencies on the positions' device so a non-CPU input_pos does
  # not hit a device mismatch in torch.outer. Outputs still go to `device`.
  work_device = input_pos.device
  exponents = torch.arange(0, n_elem, 2, device=work_device).float() / n_elem
  theta = 1.0 / (base**exponents)
  theta = theta / theta_factors.to(device=work_device)
  seq_idx = input_pos / condense_ratio
  idx_theta = torch.outer(seq_idx, theta)
  # Scale before casting, so low-precision output dtypes are rounded only
  # once. For float32 output this is identical to scaling after the cast.
  cos = (torch.cos(idx_theta) * scale).to(dtype=dtype, device=device)
  sin = (torch.sin(idx_theta) * scale).to(dtype=dtype, device=device)
  return cos, sin
class Phi3_5Mini(model_builder.DecoderOnlyModel):
  """A Phi-3.5 model built from the Edge Generative API layers.

  Adds no behavior of its own: everything is inherited from
  model_builder.DecoderOnlyModel. build_model passes this class as
  `model_class` so the built model has a Phi-3.5-specific type.
  """
def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  """Returns the model config for a Phi-3.5 model.

  The config targets the original 4K-token window. RoPE tables use the short
  LongRoPE factors (ROPE_SHORT_FACTOR) together with an attention scale
  derived from ROPE_SCALE_FACTOR.

  Args:
    kv_cache_max_len (int): The maximum sequence length of the KV cache.
      Defaults to 1024.

  Returns:
    The model config for a Phi-3.5 model.
  """
  max_seq_len = 4096

  # RoPE callable: each frequency is divided by its short factor, and the
  # cos/sin tables are scaled by sqrt(1 + ln(ROPE_SCALE_FACTOR) / ln(4096)).
  rope_scale = math.sqrt(
      1 + math.log(ROPE_SCALE_FACTOR) / math.log(max_seq_len)
  )
  build_rope = partial(
      _build_phi3_rope,
      condense_ratio=1,
      dtype=torch.float32,
      device=torch.device("cpu"),
      theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
      scale=rope_scale,
  )

  # A single RMSNorm config object is shared by the pre-attention norm, the
  # post-attention norm and the final norm.
  rms_norm = cfg.NormalizationConfig(
      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
  )
  transformer_block = cfg.TransformerBlockConfig(
      attn_config=cfg.AttentionConfig(
          num_heads=32,
          head_dim=96,
          num_query_groups=32,
          rotary_base=10000,
          rotary_percentage=1.0,
          qkv_transpose_before_split=True,
      ),
      ff_config=cfg.FeedForwardConfig(
          type=cfg.FeedForwardType.SEQUENTIAL,
          activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
          intermediate_size=8192,
      ),
      pre_attention_norm_config=rms_norm,
      post_attention_norm_config=rms_norm,
  )

  return cfg.ModelConfig(
      vocab_size=32064,
      num_layers=32,
      max_seq_len=max_seq_len,
      kv_cache_max_len=kv_cache_max_len,
      embedding_dim=3072,
      block_configs=transformer_block,
      final_norm_config=rms_norm,
      lm_head_share_weight_with_embedding=False,
      enable_hlfb=True,
      build_rope=build_rope,
  )
def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
  """Returns a shrunken Phi-3.5 config, e.g. for lightweight testing.

  Starts from get_model_config() and reduces the vocabulary size, the number
  of layers and the feed-forward width.

  Note: the RoPE scale inside `build_rope` was computed by get_model_config()
  from its own max_seq_len (4096). It is not recomputed for the overridden
  max_seq_len below.

  Args:
    kv_cache_max_len (int): The maximum sequence length of the KV cache.
      Defaults to 128. max_seq_len is set to twice this value.

  Returns:
    The fake model config.
  """
  fake_config = get_model_config(kv_cache_max_len)
  fake_config.vocab_size = 128
  fake_config.num_layers = 2
  fake_config.max_seq_len = kv_cache_max_len * 2
  # Phi-3.5 has only one block config (get_model_config passes a single
  # TransformerBlockConfig as block_configs), so index 0 refers to it.
  fake_config.block_config(0).ff_config.intermediate_size = 128
  return fake_config
def build_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
  """Instantiates a Phi3_5Mini model and loads the checkpoint if provided.

  Args:
    checkpoint_path (str): Path to the checkpoint with the model weights.
    **kwargs: Forwarded unchanged to get_model_config (e.g. kv_cache_max_len).

  Returns:
    The model returned by model_builder.build_decoder_only_model.
  """
  model_config = get_model_config(**kwargs)
  return model_builder.build_decoder_only_model(
      checkpoint_path=checkpoint_path,
      config=model_config,
      tensor_names=TENSOR_NAMES,
      model_class=Phi3_5Mini,
  )