- Notifications
You must be signed in to change notification settings - Fork 358
/
Copy pathsess_megatron.py
302 lines (231 loc) · 10.1 KB
/
sess_megatron.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
import os
import pathlib
import sys
import traceback
from typing import List, Tuple

import numpy as np
import torch

from megatron_mini import get_args
from megatron_mini.initialize import initialize_megatron
from megatron_mini.model import LLaMAModel
from megatron_mini.utils import get_model_for_infer, Tokenizer
def print_rank_0(message):
    """Write ``message`` to stderr, restricted to rank 0 once torch.distributed is up.

    When no process group has been initialized, every caller prints.
    """
    distributed = torch.distributed.is_initialized()
    if distributed and torch.distributed.get_rank() != 0:
        # Non-zero ranks stay silent to avoid duplicated log lines.
        return
    print(message, flush=True, file=sys.stderr)
def add_code_generation_args(parser):
    """Register this script's extra command-line options on ``parser``.

    Used as ``extra_args_provider`` for ``initialize_megatron``. Returns the same
    ``parser`` object so the call can be chained.
    """
    group = parser.add_argument_group(title="code generation")
    group.add_argument(
        "--padded_vocab_size",
        type=int,
        default=40000,
        help="Padded vocabulary size of the model.",
    )
    group.add_argument(
        "--model_dir",
        type=str,
        default="",
        help="Directory holding tokenizer.model and the checkpoint file(s).",
    )
    group.add_argument(
        "--model_name",
        type=str,
        default="aix3-7b-base",
        help="Checkpoint file prefix: <model_name>.pt or <model_name>_states_*.pt.",
    )
    return parser
class Predictor(object):
    """Greedy next-token predictor around a megatron_mini ``LLaMAModel``.

    Every collective in this class broadcasts from rank 0, so the inputs given on
    rank 0 are the ones all ranks actually compute on.
    """

    def __init__(self, args):
        """Build the tokenizer and model from the parsed Megatron ``args``.

        Reads ``bf16``, ``fp16``, ``model_dir``, ``model_name``, ``rank``,
        ``tensor_model_parallel_size``, ``num_attention_heads`` and
        ``hidden_size`` from ``args``.
        """
        self.args = args
        self.checkpoint_head_hash: str = ""
        self.np_rand = np.random.RandomState(seed=1414)
        # build predictor
        self.tokenizer = self.create_tokenizer()
        # bf16 takes precedence over fp16; otherwise run in fp32.
        self.dtype = torch.float32
        if self.args.bf16:
            self.dtype = torch.bfloat16
        elif self.args.fp16:
            self.dtype = torch.half
        self.predictor = self.create_predictor()
        # Hold all ranks until each one has finished loading its weights.
        if torch.distributed.is_initialized():
            torch.distributed.barrier()

    @staticmethod
    def model_provider(pre_process=True, post_process=True):
        """Build the model.

        ``pre_process``/``post_process`` are unused here; presumably they are
        part of the provider signature ``get_model_for_infer`` expects.
        """
        print_rank_0("Building Codemodel ...")
        model = LLaMAModel(parallel_output=False)
        return model

    @staticmethod
    def pad_batch(tokens_id, max_seq_len=2048):
        """Reshape ``tokens_id`` to ``[1, n]`` and right-pad it with zeros to ``max_seq_len``.

        Used so every rank holds identically shaped token arrays before broadcasting.

        Returns:
            ``(int64 array of shape [1, max_seq_len], int64 array [n])``.

        Raises:
            AssertionError: if ``n > max_seq_len``.
        """
        tokens_id = np.reshape(tokens_id, [1, -1])
        context_length = tokens_id.shape[-1]
        assert context_length <= max_seq_len, f"{context_length}, {max_seq_len}"
        if context_length < max_seq_len:
            padding = np.zeros(shape=[1, max_seq_len - context_length], dtype=tokens_id.dtype)
            tokens_id = np.concatenate([tokens_id, padding], axis=-1)
        return tokens_id.astype(np.int64), np.array([context_length], dtype=np.int64)

    @staticmethod
    def sync_type_info(sess_id: int) -> int:
        """Broadcast an integer from rank 0 (via a CUDA tensor) and return rank 0's value."""
        input_info = np.array([sess_id], dtype=np.int64)
        input_info_tensor = torch.tensor(input_info, dtype=torch.int64, device='cuda')
        torch.distributed.broadcast(input_info_tensor, 0)
        return input_info_tensor[0].item()

    @staticmethod
    def sync_obj_info(model_dir: str) -> str:
        """Broadcast a picklable object from rank 0 and return rank 0's value on every rank."""
        tmp_list = [model_dir]
        torch.distributed.broadcast_object_list(tmp_list, 0)
        return tmp_list[0]

    def create_predictor(self):
        """Build the model, load this rank's checkpoint, cast to ``self.dtype`` and move it to the GPU."""
        model_dir = self.args.model_dir
        # Heads must split evenly across tensor-parallel ranks and evenly divide the hidden size.
        assert self.args.num_attention_heads % self.args.tensor_model_parallel_size == 0
        assert self.args.hidden_size % self.args.num_attention_heads == 0
        model = get_model_for_infer(self.model_provider)
        print_rank_0("Loading state dict ...")
        self.load_checkpoint(model, model_dir)
        assert len(model) == 1, "Above condition should have caught this"
        model = model[0]
        model.eval()
        if self.args.bf16 or self.args.fp16:
            print_rank_0(f" > converting model to {'bf16' if self.args.bf16 else 'fp16'} ...")
            model.to(self.dtype)
        print_rank_0(" > moving model to GPU ...")
        model.cuda(torch.cuda.current_device())
        return model

    def create_tokenizer(self):
        """Load the ``tokenizer.model`` file located in ``args.model_dir``."""
        tokenizer_path = os.path.join(self.args.model_dir, "tokenizer.model")
        assert os.path.exists(tokenizer_path), f"tokenizer not found: {tokenizer_path}"
        return Tokenizer(model_path=tokenizer_path)

    @staticmethod
    def _shard_sort_key(shard_path):
        """Sort key that orders ``*_states_<i>.pt`` shard files by the integer ``i``.

        Plain lexicographic order puts ``_10`` before ``_2``, which would hand ranks
        the wrong shard once there are more than 10 shards. Zero-padded names sort
        identically under both schemes. Non-numeric suffixes go last, ordered by name.
        """
        shard = pathlib.Path(shard_path)
        suffix = shard.stem.rsplit("_", 1)[-1]
        if suffix.isdecimal():
            return (0, int(suffix), shard.name)
        return (1, 0, shard.name)

    def load_checkpoint(self, model: "List[LLaMAModel]", path):
        """Load this rank's checkpoint into ``model[0]`` (strict key matching).

        With ``tensor_model_parallel_size == 1``, rank 0 loads ``<path>/<model_name>.pt``.
        Otherwise rank ``r`` loads the ``r``-th ``<model_name>_states_*.pt`` file,
        with the files ordered by their numeric shard index.

        Returns:
            The checkpoint's ``"iteration"`` entry, or 0 when it has none.

        Raises:
            ValueError: if ``path`` is missing, or this rank has no shard
                (``rank >= tensor_model_parallel_size``).
            AssertionError: if the expected checkpoint file(s) are not found.
        """
        assert isinstance(model, list)
        if path is None or not os.path.exists(path):
            raise ValueError(f"checkpoint directory does not exist: {path}")
        rank = self.args.rank
        tp_size = self.args.tensor_model_parallel_size
        if rank >= tp_size:
            raise ValueError(
                f"rank {rank} has no checkpoint shard (tensor_model_parallel_size={tp_size})"
            )
        if tp_size == 1:
            checkpoint_name = os.path.join(path, f"{self.args.model_name}.pt")
            assert os.path.isfile(checkpoint_name), f"checkpoint not found: {checkpoint_name}"
        else:
            checkpoints = sorted(
                pathlib.Path(path).glob(f"{self.args.model_name}_states_*.pt"),
                key=self._shard_sort_key,
            )
            assert len(checkpoints) == tp_size, (
                f"expected {tp_size} checkpoint shards in {path}, found {len(checkpoints)}"
            )
            checkpoint_name = checkpoints[rank]
        # Load the checkpoint.
        print(f"rank_{self.args.rank} load: {checkpoint_name}", flush=True, file=sys.stderr)
        # NOTE(review): torch.load unpickles the file -- only point this at trusted checkpoints.
        state_dict = torch.load(checkpoint_name, map_location="cpu")
        iteration = state_dict.get("iteration", 0)
        # Unwrap nested containers (presumably Megatron/DeepSpeed-style) to the raw parameter dict.
        if "model" in state_dict:
            state_dict = state_dict["model"]
        if "module" in state_dict:
            state_dict = state_dict["module"]
        model[0].load_state_dict(state_dict, strict=True)
        print_rank_0(f"successfully loaded checkpoint from {path} at iteration {iteration}")
        return iteration

    def predict_batch(self, data):
        """Run one forward pass and return next-token probabilities.

        Args:
            data: ``[token_ids_tensor, common_len_tensor]``. ``common_len`` is passed to
                the model as ``start_pos``.

        Returns:
            A one-element list holding the squeezed numpy softmax over the
            vocabulary at the last input position.
        """
        common_len = int(data[1].item())
        with torch.no_grad():
            tokens_ids = data[0].clone().detach().cuda()
            logits = self.predictor(
                tokens=tokens_ids,  # presumably [bsz, seq_len]
                start_pos=common_len,
            )
            # view(1, -1) assumes a single sequence in the batch.
            logits = logits[:, -1].view(1, -1).contiguous()
            # Upcast before softmax/.numpy(): numpy has no bfloat16, so a bf16 model
            # output would otherwise fail to convert; fp32 softmax is also more precise.
            probs = torch.softmax(logits.float(), dim=-1).cpu().numpy()
        return [np.squeeze(probs)]

    def predict(self, token_ids: List[int], common_len: int) -> Tuple[List[int], List[float]]:
        """Greedily pick the next token. Must be called collectively on every rank.

        The ``token_ids`` and ``common_len`` given on rank 0 are broadcast to all
        ranks. Other ranks only need arguments that yield tensors of the same shape.

        Returns:
            ``([next_token_id], [its probability])``.

        Raises:
            RuntimeError: wrapping any failure; the traceback is printed to stderr first.
        """
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
        try:
            common_len_nda = np.array([common_len]).astype("int64")
            token_ids_nda = np.array([token_ids], dtype=np.int64)
            # Pad to at least 128 tokens, using rank 0's length so all ranks agree on shape.
            max_pad_len = max(token_ids_nda.shape[-1], 128)
            max_pad_len = self.sync_type_info(max_pad_len)
            token_ids_nda, tokens_id_len = self.pad_batch(token_ids_nda, max_seq_len=max_pad_len)
            context_tensor = torch.tensor(token_ids_nda, dtype=torch.int64, device='cuda')
            context_tensor_length = torch.tensor(tokens_id_len, dtype=torch.int64, device='cuda')
            context_common_len = torch.tensor(common_len_nda, dtype=torch.int64, device='cuda')
            # Same broadcast order on every rank; collectives must match across ranks.
            for tensor in (context_tensor, context_tensor_length, context_common_len):
                torch.distributed.broadcast(tensor, 0)
            # Strip the padding again before the forward pass.
            tokens_id_len = context_tensor_length.min().item()
            batch = [context_tensor[:, :tokens_id_len], context_common_len]
            # shape: [bsz, vocab_size] => [vocab_size]
            out = self.predict_batch(batch)[0]
            predict_id = np.argmax(out)
            return [int(predict_id)], [out[predict_id]]
        except Exception as e:
            traceback.print_exc(file=sys.stderr)
            raise RuntimeError(e) from e
class TestInference:
    """Demo driver: initializes Megatron with a fixed config and runs greedy decoding.

    Must be instantiated on every rank of the job. ``run_infer`` has rank 0 drive
    the loop while the other ranks mirror its collective calls.
    """

    def __init__(self) -> None:
        # Model hyper-parameters for initialize_megatron (presumably matching the
        # default aix3-7b-base checkpoint -- confirm against the checkpoint config).
        aix_config = {
            "num_layers": 32, "hidden_size": 4096, "num_attention_heads": 32,
            "max_position_embeddings": 32768, "fp16": False, "bf16": True,
            "rope_theta": 256000, "inner_hidden_dim": 14464, "padded_vocab_size": 49152,
            "seq_length": 4096, "micro_batch_size": 1, "use_flash_attn": True,
            "use_cpu_initialization": True, "attention_head_type": "groupedquery"
        }
        initialize_megatron(
            extra_args_provider=add_code_generation_args,
            aix_config=aix_config
        )
        args = get_args()
        self.sess = Predictor(args=args)
        self.end_token_set = self.sess.tokenizer.end_token_set

    def run_infer(self, code_string: str, max_new_tokens: int = 256, later_code: str = "", file_path: str = "") -> str:
        """Greedily generate up to ``max_new_tokens`` tokens after ``code_string``.

        Every rank must call this with the same arguments. Rank 0 runs the real
        prediction and decides when to stop. Other ranks pass dummy inputs into
        ``Predictor.predict``, whose rank-0 broadcasts overwrite them.

        Returns:
            The completion decoded on rank 0 and broadcast to all ranks, or ``""``
            when the prompt encodes to no tokens.
        """
        tokens = self.sess.tokenizer.encode(
            code_string=code_string, later_code=later_code, file_path=file_path
        )
        if len(tokens) == 0:
            return self.sess.sync_obj_info("")
        predict_list = []
        common_len = 0
        while True:
            if torch.distributed.get_rank() == 0:
                output_vals = self.sess.predict(
                    np.array([tokens], dtype='int32'),
                    np.array([common_len], dtype='int32')
                )
                predict_list.append(output_vals[0][0])
                if len(predict_list) >= max_new_tokens or predict_list[-1] in self.end_token_set:
                    terminate_runs = 1
                else:
                    terminate_runs = 0
                # common_len is the model's start_pos: advance it past the tokens just
                # fed so the next step only feeds the newly predicted token.
                common_len += len(tokens)
                tokens = predict_list[-1:]
            else:
                # Dummy inputs shaped like rank 0's call (predict has no `input_vals`
                # parameter, so the old keyword call raised TypeError here). Their
                # contents are replaced by rank 0's broadcast inside predict().
                tokens = [0] * 4
                self.sess.predict(
                    np.array([tokens], dtype='int32'),
                    np.array([0], dtype='int32')
                )
                predict_list.append(0)
                terminate_runs = 0
            # Rank 0's stop decision is broadcast so all ranks leave the loop together.
            if self.sess.sync_type_info(terminate_runs) > 0:
                break
        return self.sess.sync_obj_info(self.sess.tokenizer.decode(predict_list))
if __name__ == "__main__":
    # Demo prompt: "# 快速排序算法" is Chinese for "# quicksort algorithm".
    prompt = """# 快速排序算法"""
    runner = TestInference()
    completion = runner.run_infer(code_string=prompt, later_code="\n", file_path="test.py")
    print(completion)