This repository was archived by the owner on Apr 24, 2025. It is now read-only.
- Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathgrpo_interface.py
484 lines (425 loc) · 16.8 KB
/
grpo_interface.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
import collections
import dataclasses
import functools
from typing import *

import torch
import torch.distributed as dist

import realhf.api.core.model_api as model_api
import realhf.base.constants as constants
import realhf.base.logging as logging
from realhf.api.core.data_api import SequenceSample
from realhf.base.datapack import flat2d
# Module-level logger; `logging` here is the project module `realhf.base.logging`.
logger = logging.getLogger("GRPO Interface")
def _grpo_loss(
    logits: torch.FloatTensor,  # [tot_seqlen, vocab_size]
    input_: SequenceSample,
    kl_adapter: Any,  # const
    eps_clip: float,  # const
    early_stop_imp_ratio: Optional[float],  # const
    early_stop_kl: Optional[float],  # const
) -> Tuple[torch.FloatTensor, Dict]:
    """Compute the GRPO actor loss for one micro-batch.

    The loss is the clipped policy loss from ``ppo_functional.actor_loss_fn``
    plus ``kl_adapter.value`` times a KL penalty. The penalty uses the
    estimator ``exp(r) - r - 1`` with ``r = ref_logp - logprobs``, summed
    over tokens selected by ``ppo_loss_mask`` and divided by the local number
    of such tokens.

    Args:
        logits: Packed policy logits. Modified in place by ``apply_logits_mask``
            when ``packed_logits_mask`` is not None.
        input_: Micro-batch whose ``data`` holds ``packed_input_ids``,
            ``ppo_loss_mask``, ``advantages``, ``old_logp``, ``ref_logp`` and
            ``packed_logits_mask`` (which may be None).
        kl_adapter: KL controller. ``value`` scales the KL penalty; ``update``
            is called with the data-parallel mean KL.
        eps_clip: Clipping range forwarded to ``actor_loss_fn``.
        early_stop_imp_ratio: If set and the data-parallel mean importance
            weight exceeds it, the returned loss is multiplied by 0.
        early_stop_kl: If set and the data-parallel mean approximate KL
            exceeds it, the returned loss is multiplied by 0.

    Returns:
        ``(loss, stats)``. Each stat is a token-weighted sum all-reduced over
        the data-parallel group. ``stats["token_denorm"]`` is the matching
        total number of valid tokens, so callers can compute means.
    """
    # NOTE: import here to avoid cuda initialization
    import realhf.impl.model.utils.ppo_functional as ppo_functional
    from realhf.impl.model.utils.functional import (
        apply_logits_mask,
        gather_packed_shifted_log_probs,
    )

    packed_input_ids = input_.data["packed_input_ids"]
    # Build the sequence offsets on the logits' device instead of a
    # hard-coded "cuda" device.
    seqlens = torch.tensor(
        flat2d(input_.seqlens["packed_input_ids"]), device=logits.device
    )
    cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0), (1, 0)).int()
    ppo_loss_mask = input_.data["ppo_loss_mask"]
    advantages = input_.data["advantages"].float()
    old_logp = input_.data["old_logp"].float()
    ref_logp = input_.data["ref_logp"].float()
    logits_mask = input_.data["packed_logits_mask"]
    if logits_mask is not None:
        apply_logits_mask(logits, logits_mask)

    logprobs = gather_packed_shifted_log_probs(
        logits, cu_seqlens, packed_input_ids
    ).float()
    loss, ppo_stat = ppo_functional.actor_loss_fn(
        logprobs=logprobs,
        old_logprobs=old_logp,
        advantages=advantages,
        eps_clip=eps_clip,
        loss_mask=ppo_loss_mask,
    )

    # Number of valid tokens in this local micro-batch.
    token_denorm = ppo_loss_mask.count_nonzero().float()
    # NOTE(review): if `loss` is a token-mean scalar, this equals
    # loss * token_denorm, i.e. a token-weighted sum — confirm actor_loss_fn.
    actor_loss = torch.where(ppo_loss_mask, loss.detach(), 0.0).sum()

    # Zero the log-ratio on masked positions *before* exp(). Masking only the
    # final value would still back-propagate 0 * exp(r) through masked
    # positions, which is NaN whenever exp(r) overflows to inf. With r == 0
    # the estimator is exactly 0, so the forward value is unchanged.
    ref_kl = torch.where(ppo_loss_mask, ref_logp - logprobs, 0.0)
    kl_loss = (ref_kl.exp() - ref_kl - 1).sum()
    # clamp(min=1) keeps the KL term finite for an all-masked micro-batch.
    loss = loss + kl_adapter.value * kl_loss / token_denorm.clamp(min=1.0)
    kl_loss = kl_loss.detach()

    # Convert per-token means into token-weighted sums before reducing.
    importance_weight = ppo_stat["importance_weight"].float() * token_denorm
    clip_ratio = ppo_stat["clip_ratio"].float() * token_denorm
    approx_kl = ppo_stat["approx_kl"].float() * token_denorm
    advantages = torch.where(ppo_loss_mask, advantages, 0.0).sum()

    # Sum every statistic (in place) over the data-parallel group.
    dist.all_reduce_coalesced(
        [
            token_denorm,
            importance_weight,
            clip_ratio,
            approx_kl,
            actor_loss,
            kl_loss,
            advantages,
        ],
        group=constants.data_parallel_group(),
    )

    # Global means. clamp(min=1) avoids 0/0 = NaN, which would otherwise be
    # fed into kl_adapter.update when no rank has a valid token.
    global_denorm = token_denorm.clamp(min=1.0)
    kl_adapter.update(kl_loss / global_denorm, n_steps=cu_seqlens.shape[0] - 1)

    # Early stopping.
    _imp = importance_weight / global_denorm
    _kl = approx_kl / global_denorm
    if early_stop_imp_ratio is not None and _imp > early_stop_imp_ratio:
        logger.warning(
            f"Current importance ratio {_imp.item():.4f} is larger "
            f"than early stop threshold {early_stop_imp_ratio}. Abandon this minibatch."
        )
        loss = loss * 0.0
    if early_stop_kl is not None and _kl > early_stop_kl:
        logger.warning(
            f"Current approximate KL divergence {_kl.item():.4f} is larger "
            f"than early stop threshold {early_stop_kl}. Abort actor update."
        )
        loss = loss * 0.0

    stats = dict(
        ppo_approx_kl=approx_kl,
        actor_loss=actor_loss,
        kl_loss=kl_loss,
        actor_clip_ratio=clip_ratio,
        token_denorm=token_denorm,
        advantages=advantages,
        importance_weight=importance_weight,
    )
    return loss, stats
@dataclasses.dataclass
class GRPOInterface(model_api.ModelInterface):
    """Model interface implementing Group Relative Policy Optimization (GRPO).

    * ``generate`` repeats every prompt ``group_size`` times and samples one
      response per copy.
    * ``inference`` computes temperature-scaled log-probabilities of packed
      sequences (stored as ``packed_ref_logprobs``).
    * ``train_step`` normalizes rewards within each group of ``group_size``
      responses, turns them into per-token advantages, and runs
      ``n_minibatches`` updates with ``_grpo_loss``.
    """

    # Number of responses sampled per prompt (one row of the reward matrix).
    group_size: int
    # Number of mini-batches the training batch is split into per train step.
    n_minibatches: int = 4
    generation_config: model_api.GenerationHyperparameters = dataclasses.field(
        default_factory=model_api.GenerationHyperparameters
    )
    # Initial KL coefficient. It is handed to the KL controller in
    # __post_init__, after which this field is reset to None.
    kl_ctl: float = 0.1
    # Normalize advantages over the loss mask before training.
    adv_norm: bool = True
    # Passed as `lam` (with gamma=1.0) when spreading the terminal reward
    # over the response tokens.
    discount: float = 0.99
    # Clipping range forwarded to the actor loss.
    eps_clip: float = 0.2
    # Group-normalized rewards are clipped to [-max_reward_clip, max_reward_clip].
    max_reward_clip: float = 5.0
    early_stop_kl: Optional[float] = None  # e.g. 0.1
    early_stop_imp_ratio: Optional[float] = None  # e.g., 10.0
    # Use an adaptive KL controller instead of a fixed coefficient.
    adaptive_kl_ctl: bool = False
    adaptive_kl_target: Optional[float] = 6
    adaptive_kl_horizon: Optional[float] = 10000
    # When False, `save` is a no-op.
    enable_save: bool = True

    def __post_init__(self):
        """Create ``self.kl_adapter`` from the KL-related fields."""
        from realhf.impl.model.utils import ppo_functional

        if self.adaptive_kl_ctl:
            assert self.adaptive_kl_target is not None
            assert self.adaptive_kl_horizon is not None
            self.kl_adapter = ppo_functional.AdaptiveKLController(
                self.kl_ctl, self.adaptive_kl_target, self.adaptive_kl_horizon
            )
        else:
            self.kl_adapter = ppo_functional.FixedKLController(self.kl_ctl)
        # The controller now owns the coefficient (read via kl_adapter.value).
        self.kl_ctl = None

    def save(self, model: model_api.Model, save_dir: str):
        """Export the model to ``save_dir`` through ``save_to_hf``.

        Does nothing when ``enable_save`` is False. If ``model.module`` is not
        a ``ReaLModel``, its own ``.module`` attribute is used instead.
        """
        # NOTE: import here to avoid cuda initialization
        from realhf.impl.model.nn.real_llm_api import ReaLModel

        if not self.enable_save:
            return
        module = model.module
        if not isinstance(module, ReaLModel):
            module = module.module
        module.save_to_hf(
            tokenizer=model.tokenizer,
            save_dir=save_dir,
        )

    @torch.no_grad()
    def generate(
        self, model: model_api.Model, input_: SequenceSample, n_mbs=None
    ) -> SequenceSample:
        """Sample ``group_size`` responses for every prompt in ``input_``.

        Returns:
            A ``SequenceSample`` with ``ids=input_.ids``. Each entry holds
            ``group_size`` prompt+response sequences under the keys
            ``packed_input_ids``, ``prompt_mask``, ``packed_logprobs`` and
            ``packed_logits_mask`` (None when no logits mask is used).
            Returns None when ``module.generate`` returns None.
        """
        # NOTE: import here to avoid cuda initialization
        from realhf.impl.model.nn.real_llm_generate import (
            concat_prompt_to_generation_output,
        )

        module = model.module
        module.eval()

        # Repeat each prompt `self.group_size` times. Copies of the same prompt
        # stay contiguous ([p0]*G + [p1]*G + ...) because the output below is
        # partitioned into consecutive chunks of `group_size` sequences.
        packed_input_ids = input_.data["packed_input_ids"]
        prompt_lens = [int(x[0]) for x in input_.seqlens["packed_input_ids"]]
        repeated_prompts = []
        offset = 0
        for plen in prompt_lens:
            repeated_prompts += [
                packed_input_ids[offset : offset + plen]
            ] * self.group_size
            offset += plen
        grouped_input = SequenceSample.from_default(
            ids=list(range(input_.bs * self.group_size)),
            # Seqlens must use the same prompt-major order as the data, so
            # every length lines up with its sequence even when prompts have
            # different lengths.
            seqlens=[plen for plen in prompt_lens for _ in range(self.group_size)],
            data=dict(packed_input_ids=torch.cat(repeated_prompts)),
        )

        res = module.generate(
            input_=grouped_input,
            tokenizer=model.tokenizer,
            gconfig=self.generation_config,
            num_micro_batches=n_mbs,
        )
        if res is None:
            return None

        gen_tokens, logprobs, logits_mask, *_ = res
        pad_token_id = model.tokenizer.pad_token_id
        eos_token_id = model.tokenizer.eos_token_id
        # We also want gen_lengths to include the eos token, where the reward
        # model outputs a score for this sequence.
        gen_lengths = (gen_tokens != pad_token_id).logical_and(
            gen_tokens != eos_token_id
        ).sum(dim=-1) + 1
        gen_lengths = gen_lengths.clip(max=gen_tokens.shape[-1])

        (
            packed_input_ids,
            packed_logprobs,
            packed_logits_mask,
            seq_lengths,
            prompt_mask,
        ) = concat_prompt_to_generation_output(
            packed_prompts=grouped_input.data["packed_input_ids"],
            prompt_lengths=torch.tensor(
                flat2d(grouped_input.seqlens["packed_input_ids"]),
                device=model.device,
            ),
            gen_tokens=gen_tokens,
            logprobs=logprobs,
            logits_mask=logits_mask,
            gen_lengths=gen_lengths,
        )

        # Partition generated data into groups of `group_size` sequences.
        seqlens = [
            seq_lengths[i * self.group_size : (i + 1) * self.group_size].cpu().int()
            for i in range(input_.bs)
        ]
        data = dict(
            packed_input_ids=packed_input_ids,
            prompt_mask=prompt_mask,
            packed_logprobs=packed_logprobs,
            packed_logits_mask=(
                packed_logits_mask.bool()
                if not self.generation_config.force_no_logits_mask
                and packed_logits_mask is not None
                else None
            ),
        )
        # `packed_logits_mask` may be None (see the data entry above). Only read
        # its trailing dim when it exists; `()` is a placeholder otherwise.
        # NOTE(review): confirm SequenceSample ignores shapes of None data.
        logits_mask_trailing_shape = (
            (packed_logits_mask.shape[-1],) if packed_logits_mask is not None else ()
        )
        res = SequenceSample(
            keys=[
                "packed_input_ids",
                "prompt_mask",
                "packed_logprobs",
                "packed_logits_mask",
            ],
            trailing_shapes=dict(
                packed_input_ids=(),
                prompt_mask=(),
                packed_logprobs=(),
                packed_logits_mask=logits_mask_trailing_shape,
            ),
            dtypes=dict(
                packed_input_ids=torch.long,
                prompt_mask=torch.bool,
                packed_logprobs=torch.float,
                packed_logits_mask=torch.bool,
            ),
            seqlens=dict(
                packed_input_ids=seqlens,
                packed_logits_mask=seqlens,
                packed_logprobs=[x - 1 for x in seqlens],
                prompt_mask=seqlens,
            ),
            data=data,
            ids=input_.ids,
        )
        return res

    @torch.no_grad()
    def inference(
        self, model: model_api.Model, input_: SequenceSample, n_mbs=None
    ) -> SequenceSample:
        """Compute temperature-scaled log-probabilities of ``packed_input_ids``.

        The result is stored under ``packed_ref_logprobs``, with ``seqlen - 1``
        values per sequence (shifted log-probs). Returns None when
        ``module.forward`` returns None.
        """
        from realhf.impl.model.utils.functional import (
            apply_logits_mask,
            gather_packed_shifted_log_probs,
        )

        module = model.module
        module.eval()

        # This post_hook will gather log probabilities in mini-batches,
        # reducing peak memory usage.
        def calc_logprobs(logits, input_):
            input_lens = torch.tensor(flat2d(input_.seqlens["packed_input_ids"]))
            cu_seqlens = torch.nn.functional.pad(input_lens.cumsum(0), (1, 0)).int()

            logits /= self.generation_config.temperature
            if (
                "packed_logits_mask" in input_.data
                and input_.data["packed_logits_mask"] is not None
            ):
                apply_logits_mask(logits, input_.data["packed_logits_mask"])
            logprobs = gather_packed_shifted_log_probs(
                logits, cu_seqlens, input_.data["packed_input_ids"]
            )
            return logprobs

        logprobs = module.forward(
            input_=input_,
            num_micro_batches=n_mbs,
            post_hook=calc_logprobs,
        )
        if logprobs is None:
            return None

        res = SequenceSample(
            keys=["packed_ref_logprobs"],
            ids=input_.ids,
            dtypes=dict(packed_ref_logprobs=logprobs.dtype),
            trailing_shapes=dict(packed_ref_logprobs=()),
            data=dict(packed_ref_logprobs=logprobs),
            seqlens=dict(
                packed_ref_logprobs=[
                    [xx - 1 for xx in x] for x in input_.seqlens["packed_input_ids"]
                ]
            ),
        )
        return res

    def train_step(
        self, model: model_api.Model, input_: SequenceSample, n_mbs=None
    ) -> Dict:
        """Run one GRPO update over a batch of grouped responses.

        ``input_.data["rewards"]`` holds one reward per sequence, with
        ``group_size`` consecutive sequences per prompt. Each reward is:

        1. normalized within its group,
        2. clipped to ``±max_reward_clip``,
        3. placed on the last shifted position of its sequence, and
        4. spread into per-token advantages.

        The batch is then split into ``n_minibatches`` mini-batches, each
        trained with ``_grpo_loss``.

        Returns:
            Token-averaged statistics accumulated over all mini-batches plus
            the reward mean/std, or ``{}`` when ``module.train_batch`` returned
            no statistics.
        """
        # NOTE: import here to avoid cuda initialization
        from realhf.impl.model.utils.functional import masked_normalization
        from realhf.impl.model.utils.ppo_functional import (
            get_packed_advantages_and_returns,
        )

        module = model.module
        module.eval()

        # Get the useful sequence length indices.
        seqlens = torch.tensor(
            flat2d(input_.seqlens["packed_input_ids"]), device=model.device
        )
        cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0), (1, 0)).int()
        short1seqlens = seqlens - 1
        short1cu_seqlens = torch.nn.functional.pad(
            short1seqlens.cumsum(0), (1, 0)
        ).int()
        # Positions 1..len-1 of every packed sequence. The bounds are read to
        # the host once instead of indexing the device tensor per element.
        bounds = cu_seqlens.tolist()
        shift_one_indices = torch.cat(
            [
                torch.arange(
                    start + 1, end, dtype=torch.long, device=cu_seqlens.device
                )
                for start, end in zip(bounds[:-1], bounds[1:])
            ]
        )

        # Get loss mask that filters prompts out.
        loss_mask = input_.data["prompt_mask"][shift_one_indices].logical_not()

        # Apply the mask to log probabilities.
        input_.data["packed_ref_logprobs"] *= loss_mask
        input_.data["packed_logprobs"] *= loss_mask

        # Gather rewards for all groups and normalize them within each group.
        # Cast to float32 so the scatter into `episode_rewards` below matches dtypes.
        group_rewards = input_.data["rewards"].float().view(-1, self.group_size)
        rewards_mean = group_rewards.mean(1, keepdim=True)
        if self.group_size > 1:
            rewards_std = group_rewards.std(1, keepdim=True)
        else:
            # The unbiased std of a single sample is NaN. With one response per
            # prompt there is no relative signal, so the advantage becomes 0.
            rewards_std = torch.zeros_like(rewards_mean)
        all_rewards = (
            ((group_rewards - rewards_mean) / (rewards_std + 1e-5))
            .clip(-self.max_reward_clip, self.max_reward_clip)
            .view(-1)
        )
        assert all_rewards.shape[0] == input_.bs * self.group_size, (
            all_rewards.shape,
            input_.bs,
            self.group_size,
        )

        # Place each sequence's reward on its last shifted position.
        episode_rewards = torch.zeros(
            int(short1seqlens.sum()), dtype=torch.float32, device=model.device
        )
        episode_rewards.scatter_(
            0,
            short1seqlens.cumsum(0) - 1,
            all_rewards,
        )

        # Get discounted reward. Values are all zero and gamma=1.0, so
        # `lam=self.discount` presumably acts as the per-token discount
        # (GAE-style helper — confirm).
        adv, _ = get_packed_advantages_and_returns(
            gamma=1.0,
            lam=self.discount,
            rewards=episode_rewards,
            values=torch.zeros(
                int(seqlens.sum()), dtype=torch.float32, device=model.device
            ),
            short1cu_seqlens=short1cu_seqlens,
            seq_no_eos_mask=torch.zeros(
                input_.bs * self.group_size, dtype=torch.bool, device=model.device
            ),
        )

        # Optionally normalize computed advantages.
        if self.adv_norm:
            adv = masked_normalization(adv, mask=loss_mask)

        # Unpack grouped inputs to individual sequences for training.
        data_ = SequenceSample.from_default(
            seqlens=[int(x) for x in seqlens.cpu().numpy().tolist()],
            data=dict(
                packed_input_ids=input_.data["packed_input_ids"],
                ppo_loss_mask=loss_mask,
                advantages=adv,
                old_logp=input_.data["packed_logprobs"],
                ref_logp=input_.data["packed_ref_logprobs"],
                packed_logits_mask=(
                    None
                    if self.generation_config.force_no_logits_mask
                    else input_.data.get("packed_logits_mask")
                ),
            ),
            ids=list(range(input_.bs * self.group_size)),
        )

        # Split mini-batches and run PPO training. Mini-batches have balanced sizes.
        datas = data_.split(self.n_minibatches, min_size=data_.bs // self.n_minibatches)

        grpo_loss_fn = functools.partial(
            _grpo_loss,
            kl_adapter=self.kl_adapter,
            eps_clip=self.eps_clip,
            early_stop_imp_ratio=self.early_stop_imp_ratio,
            early_stop_kl=self.early_stop_kl,
        )
        temperature = self.generation_config.temperature

        def loss_fn(logits, input_):
            # Score tokens at the same temperature that `inference` uses for
            # the reference log-probs, so the KL term compares like with like.
            # NOTE(review): assumes `module.generate` log-probs are also
            # temperature-scaled — confirm.
            if temperature != 1.0:
                logits /= temperature
            return grpo_loss_fn(logits, input_)

        train_stats = collections.defaultdict(float)
        for data in datas:
            stats = module.train_batch(
                input_=data,
                version_steps=model.version.global_step,
                loss_fn=loss_fn,
                num_micro_batches=n_mbs,
            )
            if stats:
                for k, v in stats.items():
                    train_stats[k] += v
        model.inc_version()

        # Logging. These collectives run on every data-parallel rank.
        rewards_group_mean_sq = group_rewards.square().mean(1).sum()
        rewards_group_mean = rewards_mean.sum()
        bs = torch.tensor([input_.bs], device=model.device, dtype=torch.float32)
        dist.all_reduce_coalesced(
            [rewards_group_mean, rewards_group_mean_sq, bs],
            group=constants.data_parallel_group(),
        )
        log_rewards_mean = float(rewards_group_mean / bs)
        # Clamp tiny negative variances caused by floating-point error.
        log_rewards_std = float(
            torch.sqrt(
                (rewards_group_mean_sq / bs - log_rewards_mean**2).clamp(min=0.0)
            )
        )

        if not train_stats:
            return {}

        # `_grpo_loss` reports token-weighted sums. Divide the totals over all
        # mini-batches by the total number of valid tokens.
        token_denorm = max(float(train_stats["token_denorm"]), 1.0)
        return dict(
            ppo_approx_kl=float(train_stats["ppo_approx_kl"]) / token_denorm,
            actor_loss=float(train_stats["actor_loss"]) / token_denorm,
            kl_loss=float(train_stats["kl_loss"]) / token_denorm,
            kl_ctl=self.kl_adapter.value,
            actor_clip_ratio=float(train_stats["actor_clip_ratio"]) / token_denorm,
            importance_weight=float(train_stats["importance_weight"]) / token_denorm,
            advantages=float(train_stats["advantages"]) / token_denorm,
            rewards=log_rewards_mean,
            rewards_std=log_rewards_std,
            # FIXME: It only logs the MoE aux loss of the final PPO mini-batch.
            **constants.log_global_stats_tracker(
                return_dict=True, clear_stats_after_logging=True
            ),
        )