forked from google-ai-edge/ai-edge-torch
- Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathllama.py
196 lines (165 loc) · 6.44 KB
/
llama.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Example of building Llama 3.2 models."""
from functools import partial
import math
from typing import Tuple

import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.utilities import model_builder
import torch
# Tensor-name mapping re-exported from the generic model builder; _build_model
# passes it as `tensor_names` to model_builder.build_decoder_only_model.
# Presumably maps model parameters to checkpoint tensor names — confirm in
# model_builder.
TENSOR_NAMES = model_builder.TENSOR_NAMES
def _build_llama3_rope_cache(
    input_pos: torch.Tensor,
    n_elem: int,
    base: int,
    condense_ratio: int,
    dtype: torch.dtype,
    device: torch.device,
    factor: float,
    low_freq_factor: float,
    high_freq_factor: float,
    max_seq_len: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
  """Computes Rotary Positional Embeddings for Llama 3.2 model.

  It's a modified version of attn_utils.build_rope_cache with additional
  arguments for Llama 3.2 model. It precomputes Rotary Positional Embedding Sin
  and Cos values with scaling factors for quick lookup during the inference.

  Frequencies (theta) are rescaled per dimension based on their wavelength:
  wavelengths shorter than `max_seq_len / high_freq_factor` keep theta as is,
  wavelengths longer than `max_seq_len / low_freq_factor` get theta divided by
  `factor`, and wavelengths in between get a blend of the two weighted by a
  smooth factor.

  Reference:
  https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L307

  Args:
    input_pos (torch.Tensor): 1-D tensor of input sequence positions.
    n_elem (int): Each sequence's dimension; one frequency is computed for
      every other element, i.e. ceil(n_elem / 2) frequencies.
    base (int): RoPE base value.
    condense_ratio (int): The ratio by which sequence indices are condensed.
    dtype (torch.dtype): Output tensors' data type.
    device (torch.device): Output tensors' device.
    factor (float): Factor to scale theta down for long-wavelength (low
      frequency) dimensions.
    low_freq_factor (float): Sets the wavelength threshold
      `max_seq_len / low_freq_factor` above which theta is divided by `factor`.
    high_freq_factor (float): Sets the wavelength threshold
      `max_seq_len / high_freq_factor` below which theta is left unchanged.
    max_seq_len (int): The original token sequence length before extending
      RoPE to support longer sequences.

  Returns:
    Tuple[torch.Tensor, torch.Tensor]: RoPE's cosine and sine waves, each of
    shape (len(input_pos), ceil(n_elem / 2)).
  """
  # Build the frequencies on the same device as `input_pos` so the outer
  # product below does not hit a device mismatch when positions are not on
  # CPU; results are moved to `device` at the end.
  dims = torch.arange(0, n_elem, 2, device=input_pos.device).float()
  theta = 1.0 / (base ** (dims / n_elem))
  low_freq_wavelen = max_seq_len / low_freq_factor
  high_freq_wavelen = max_seq_len / high_freq_factor
  wavelen = 2 * math.pi / theta
  # wavelen < high_freq_wavelen: do nothing
  # wavelen > low_freq_wavelen: divide by factor
  theta = torch.where(wavelen > low_freq_wavelen, theta / factor, theta)
  # otherwise: interpolate between the two, using a smooth factor
  smooth_factor = (max_seq_len / wavelen - low_freq_factor) / (
      high_freq_factor - low_freq_factor
  )
  smoothed_theta = (1 - smooth_factor) * theta / factor + smooth_factor * theta
  # Medium band: high_freq_wavelen <= wavelen <= low_freq_wavelen.
  is_medium = ~(wavelen < high_freq_wavelen) & ~(wavelen > low_freq_wavelen)
  theta = torch.where(is_medium, smoothed_theta, theta)
  seq_idx = input_pos / condense_ratio
  # torch.outer requires 1-D inputs: one row of angles per position.
  idx_theta = torch.outer(seq_idx, theta)
  cos = torch.cos(idx_theta).to(dtype=dtype, device=device)
  sin = torch.sin(idx_theta).to(dtype=dtype, device=device)
  return cos, sin
class Llama(model_builder.DecoderOnlyModel):
  """A Llama model built from the Edge Generative API layers.

  Llama 3.2 shares the same architecture as TinyLlama except ROPE calculation.
  The Llama-specific RoPE is not implemented in this class; the config
  factories in this module supply it through the config's `build_rope`
  callable.
  """

  def __init__(self, config: cfg.ModelConfig):
    """Initializes the model from `config` via DecoderOnlyModel."""
    super().__init__(config)
def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  """Returns the model config for a Llama 3.2-1B model.

  Args:
    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
      is 1024.

  Returns:
    The model config for a Llama 3.2-1B model.
  """
  attn_config = cfg.AttentionConfig(
      num_heads=32,
      head_dim=64,
      num_query_groups=8,
      rotary_base=500000,
      rotary_percentage=1.0,
  )
  ff_config = cfg.FeedForwardConfig(
      type=cfg.FeedForwardType.GATED,
      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
      intermediate_size=8192,
  )
  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
  # The same RMSNorm config is used for the pre- and post-attention norms.
  block_config = cfg.TransformerBlockConfig(
      attn_config=attn_config,
      ff_config=ff_config,
      pre_attention_norm_config=norm_config,
      post_attention_norm_config=norm_config,
  )
  # Used both as the model's max_seq_len and as the original context length
  # that the Llama 3 RoPE frequency scaling is relative to.
  max_seq_len = 8192
  # Create the RoPE callable: Llama 3 frequency scaling with factor 32 and
  # low/high frequency factors 1 and 4, computed in float32 on CPU.
  build_rope = partial(
      _build_llama3_rope_cache,
      condense_ratio=1,
      dtype=torch.float32,
      device=torch.device("cpu"),
      factor=32.0,
      low_freq_factor=1.0,
      high_freq_factor=4.0,
      max_seq_len=max_seq_len,
  )
  config = cfg.ModelConfig(
      vocab_size=128256,
      num_layers=16,
      max_seq_len=max_seq_len,
      embedding_dim=2048,
      kv_cache_max_len=kv_cache_max_len,
      # A single block config; callers patch it via config.block_config(0).
      block_configs=block_config,
      final_norm_config=norm_config,
      enable_hlfb=True,
      build_rope=build_rope,
  )
  return config
def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  """Returns the model config for a Llama 3.2-3B model.

  Starts from the Llama 3.2-1B config and changes only the attention head
  layout (24 heads of dimension 128), the layer count (28) and the embedding
  dimension (3072). The feed-forward, norm, RoPE and vocabulary settings are
  inherited unchanged.

  Args:
    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
      is 1024.

  Returns:
    The model config for a Llama 3.2-3B model.
  """
  base_config = get_1b_model_config(kv_cache_max_len)
  # Llama 3.2 has only one block config, so block 0's attention config is the
  # one to patch.
  shared_attn = base_config.block_config(0).attn_config
  shared_attn.num_heads, shared_attn.head_dim = 24, 128
  base_config.num_layers, base_config.embedding_dim = 28, 3072
  return base_config
def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
  """Returns a tiny "fake" Llama 3.2 config (presumably for tests).

  Starts from the Llama 3.2-1B config and shrinks the vocabulary to 128, the
  layer count to 2 and the feed-forward intermediate size to 64.

  Args:
    **kwargs: Forwarded to get_1b_model_config (e.g. kv_cache_max_len).

  Returns:
    The shrunken model config.
  """
  config = get_1b_model_config(**kwargs)
  config.vocab_size = 128
  config.num_layers = 2
  # Llama 3.2 has only one block config.
  config.block_config(0).ff_config.intermediate_size = 64
  return config
def _build_model(
    checkpoint_path: str, config: cfg.ModelConfig
) -> torch.nn.Module:
  """Builds a `Llama` decoder-only model for `config` from a checkpoint.

  Delegates to model_builder.build_decoder_only_model with this module's
  TENSOR_NAMES mapping and the Llama model class.

  Args:
    checkpoint_path: Path of the checkpoint to build the model from.
    config: Model config, e.g. from get_1b_model_config.

  Returns:
    The module returned by model_builder.build_decoder_only_model.
  """
  builder_kwargs = dict(
      checkpoint_path=checkpoint_path,
      config=config,
      tensor_names=TENSOR_NAMES,
      model_class=Llama,
  )
  return model_builder.build_decoder_only_model(**builder_kwargs)
def build_1b_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
  """Builds a Llama 3.2-1B model from `checkpoint_path`.

  Args:
    checkpoint_path: Path of the checkpoint to build the model from.
    **kwargs: Forwarded to get_1b_model_config (e.g. kv_cache_max_len).

  Returns:
    The built model.
  """
  model_config = get_1b_model_config(**kwargs)
  return _build_model(checkpoint_path, model_config)
def build_3b_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
  """Builds a Llama 3.2-3B model from `checkpoint_path`.

  Args:
    checkpoint_path: Path of the checkpoint to build the model from.
    **kwargs: Forwarded to get_3b_model_config (e.g. kv_cache_max_len).

  Returns:
    The built model.
  """
  model_config = get_3b_model_config(**kwargs)
  return _build_model(checkpoint_path, model_config)