- Notifications
You must be signed in to change notification settings - Fork 358
/
Copy pathparallel_state.py
522 lines (422 loc) · 20.9 KB
/
parallel_state.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Model and data parallel groups."""
from typing import Optional

import torch

from .utils import GlobalMemoryBuffer
# ---------------------------------------------------------------------------
# Process groups the current rank belongs to (filled in by
# initialize_model_parallel, cleared by destroy_model_parallel).
# ---------------------------------------------------------------------------

# Intra-layer (tensor) model-parallel group.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Inter-layer (pipeline) model-parallel group.
_PIPELINE_MODEL_PARALLEL_GROUP = None
# Combined tensor + pipeline model-parallel group.
_MODEL_PARALLEL_GROUP = None
# Group of ranks holding a copy of the embedding.
_EMBEDDING_GROUP = None
# Group of ranks holding a copy of the position embedding.
_POSITION_EMBEDDING_GROUP = None
# Data-parallel group.
_DATA_PARALLEL_GROUP = None

# ---------------------------------------------------------------------------
# Interleaved-schedule and encoder/decoder-split bookkeeping.
# ---------------------------------------------------------------------------

# Virtual pipeline rank / world size; only populated when a virtual
# (interleaved) pipeline size is passed to initialize_model_parallel.
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
# Pipeline rank at which the encoder/decoder split happens; None means no split.
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None

# ---------------------------------------------------------------------------
# Optional overrides: when set, these take precedence over the values queried
# from torch.distributed, which lets callers change mpu sizes/ranks on the fly.
# ---------------------------------------------------------------------------
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None

# ---------------------------------------------------------------------------
# Global-rank lists kept alongside the groups.
# ---------------------------------------------------------------------------

# Global ranks that hold a copy of the embedding.
_EMBEDDING_GLOBAL_RANKS = None
# Global ranks that hold a copy of the position embedding.
_POSITION_EMBEDDING_GLOBAL_RANKS = None
# Global ranks of this rank's pipeline group; used to find the source rank
# when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None
# Global ranks of this rank's data-parallel group; used to find the source
# rank when broadcasting weights to the other data-parallel ranks.
_DATA_PARALLEL_GLOBAL_RANKS = None

# Shared memory buffers, kept to avoid repeated dynamic allocation.
_GLOBAL_MEMORY_BUFFER = None
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    virtual_pipeline_model_parallel_size: Optional[int] = None,
    pipeline_model_parallel_split_rank: Optional[int] = None,
) -> None:
    """
    Initialize model data parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism.
        virtual_pipeline_model_parallel_size: number of virtual stages (interleaved
            pipeline).
        pipeline_model_parallel_split_rank: for models with both encoder and decoder,
            rank in pipeline with split point.

    Raises:
        RuntimeError: if the world size is not divisible by
            tensor_model_parallel_size * pipeline_model_parallel_size, or if a
            virtual pipeline size is given while pipeline_model_parallel_size <= 2.
        AssertionError: if torch.distributed is not initialized, or if any of the
            groups built here has already been initialized.

    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
    and 8 data-parallel groups as:
        8 data_parallel groups:
            [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
        8 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
        4 pipeline model-parallel groups:
            [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
    It also creates data_parallel_size (here 2) model-parallel groups, each
    taking the i-th member of every data-parallel group:
        [g0, g1, g4, g5, g8, g9, g12, g13], [g2, g3, g6, g7, g10, g11, g14, g15]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()

    if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:
        raise RuntimeError(
            f"world_size ({world_size}) is not divisible by tensor_model_parallel_size "
            f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})"
        )

    data_parallel_size: int = world_size // (tensor_model_parallel_size *
                                             pipeline_model_parallel_size)

    num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
    num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size

    if virtual_pipeline_model_parallel_size is not None:
        if not pipeline_model_parallel_size > 2:
            raise RuntimeError("pipeline-model-parallel size should be greater than 2 with "
                               "interleaved schedule")
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size

    if pipeline_model_parallel_split_rank is not None:
        global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
        _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank

    rank = torch.distributed.get_rank()

    # NOTE: every rank calls torch.distributed.new_group for *every* group
    # below, including groups it is not a member of; it only keeps the handle
    # of the group that contains it.

    # Build the data-parallel groups.
    global _DATA_PARALLEL_GROUP
    global _DATA_PARALLEL_GLOBAL_RANKS
    assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'
    all_data_parallel_group_ranks = []
    for i in range(pipeline_model_parallel_size):
        # Pipeline stage i owns the contiguous global ranks [start_rank, end_rank).
        start_rank = i * num_pipeline_model_parallel_groups
        end_rank = (i + 1) * num_pipeline_model_parallel_groups
        for j in range(tensor_model_parallel_size):
            # Ranks sharing tensor-parallel index j within this stage.
            ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
            all_data_parallel_group_ranks.append(list(ranks))
            group = torch.distributed.new_group(ranks)
            if rank in ranks:
                _DATA_PARALLEL_GROUP = group
                _DATA_PARALLEL_GLOBAL_RANKS = ranks

    # Build the model-parallel groups: the i-th member of every data-parallel group.
    global _MODEL_PARALLEL_GROUP
    assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'
    for i in range(data_parallel_size):
        ranks = [data_parallel_group_ranks[i]
                 for data_parallel_group_ranks in all_data_parallel_group_ranks]
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _MODEL_PARALLEL_GROUP = group

    # Build the tensor model-parallel groups: blocks of consecutive ranks.
    global _TENSOR_MODEL_PARALLEL_GROUP
    assert _TENSOR_MODEL_PARALLEL_GROUP is None, \
        'tensor model parallel group is already initialized'
    for i in range(num_tensor_model_parallel_groups):
        ranks = range(i * tensor_model_parallel_size,
                      (i + 1) * tensor_model_parallel_size)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _TENSOR_MODEL_PARALLEL_GROUP = group

    # Build the pipeline model-parallel groups and embedding groups
    # (first and last rank in each pipeline model-parallel group).
    global _PIPELINE_MODEL_PARALLEL_GROUP
    global _PIPELINE_GLOBAL_RANKS
    assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
        'pipeline model parallel group is already initialized'
    global _EMBEDDING_GROUP
    global _EMBEDDING_GLOBAL_RANKS
    assert _EMBEDDING_GROUP is None, 'embedding group is already initialized'
    global _POSITION_EMBEDDING_GROUP
    global _POSITION_EMBEDDING_GLOBAL_RANKS
    assert _POSITION_EMBEDDING_GROUP is None, \
        'position embedding group is already initialized'
    for i in range(num_pipeline_model_parallel_groups):
        # Strided ranks: one rank per pipeline stage.
        ranks = range(i, world_size, num_pipeline_model_parallel_groups)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _PIPELINE_MODEL_PARALLEL_GROUP = group
            _PIPELINE_GLOBAL_RANKS = ranks

        # Setup embedding group (to exchange gradients between
        # first and last stages).
        if len(ranks) > 1:
            embedding_ranks = [ranks[0], ranks[-1]]
            position_embedding_ranks = [ranks[0]]
            if pipeline_model_parallel_split_rank is not None:
                # Encoder/decoder models also need the first decoder stage.
                split_global_rank = ranks[pipeline_model_parallel_split_rank]
                if split_global_rank not in embedding_ranks:
                    embedding_ranks = [ranks[0], split_global_rank, ranks[-1]]
                if split_global_rank not in position_embedding_ranks:
                    position_embedding_ranks = [ranks[0], split_global_rank]
        else:
            embedding_ranks = ranks
            position_embedding_ranks = ranks

        group = torch.distributed.new_group(embedding_ranks)
        if rank in embedding_ranks:
            _EMBEDDING_GROUP = group
        if rank in ranks:
            _EMBEDDING_GLOBAL_RANKS = embedding_ranks

        group = torch.distributed.new_group(position_embedding_ranks)
        if rank in position_embedding_ranks:
            _POSITION_EMBEDDING_GROUP = group
        if rank in ranks:
            _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks

    # Initialize global memory buffer
    # This isn't really "parallel state" but there isn't another good place to
    # put this. If we end up with a more generic initialization of megatron-core
    # we could stick it there
    _set_global_memory_buffer()
def model_parallel_is_initialized():
    """Report whether the tensor, pipeline and data parallel groups all exist.

    Returns:
        True when all three groups have been created, False otherwise.
    """
    required_groups = (
        _TENSOR_MODEL_PARALLEL_GROUP,
        _PIPELINE_MODEL_PARALLEL_GROUP,
        _DATA_PARALLEL_GROUP,
    )
    return all(group is not None for group in required_groups)
def get_model_parallel_group():
    """Return the model-parallel process group of the calling rank.

    Raises:
        AssertionError: if the group has not been created yet.
    """
    group = _MODEL_PARALLEL_GROUP
    assert group is not None, 'model parallel group is not initialized'
    return group
def get_tensor_model_parallel_group():
    """Return the tensor model-parallel process group of the calling rank.

    Raises:
        AssertionError: if the group has not been created yet.
    """
    group = _TENSOR_MODEL_PARALLEL_GROUP
    assert group is not None, 'intra_layer_model parallel group is not initialized'
    return group
def get_pipeline_model_parallel_group():
    """Return the pipeline model-parallel process group of the calling rank.

    Raises:
        AssertionError: if the group has not been created yet.
    """
    group = _PIPELINE_MODEL_PARALLEL_GROUP
    assert group is not None, 'pipeline_model parallel group is not initialized'
    return group
def get_data_parallel_group():
    """Return the data-parallel process group of the calling rank.

    Raises:
        AssertionError: if the group has not been created yet.
    """
    group = _DATA_PARALLEL_GROUP
    assert group is not None, 'data parallel group is not initialized'
    return group
def get_embedding_group():
    """Return the embedding process group of the calling rank.

    Raises:
        AssertionError: if the group has not been created yet.
    """
    group = _EMBEDDING_GROUP
    assert group is not None, 'embedding group is not initialized'
    return group
def get_position_embedding_group():
    """Return the position-embedding process group of the calling rank.

    Raises:
        AssertionError: if the group has not been created yet.
    """
    group = _POSITION_EMBEDDING_GROUP
    assert group is not None, 'position embedding group is not initialized'
    return group
def set_tensor_model_parallel_world_size(world_size):
    """Override the tensor model-parallel world size.

    While the override is not None, get_tensor_model_parallel_world_size
    returns it instead of querying torch.distributed.
    """
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size
def set_pipeline_model_parallel_world_size(world_size):
    """Override the pipeline model-parallel world size.

    While the override is not None, get_pipeline_model_parallel_world_size
    returns it instead of querying torch.distributed.
    """
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
def get_tensor_model_parallel_world_size():
    """Return the size of the tensor model-parallel group.

    A value installed with set_tensor_model_parallel_world_size takes
    precedence; otherwise the size is queried from torch.distributed.
    """
    override = _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    if override is None:
        return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
    return override
def get_pipeline_model_parallel_world_size():
    """Return the size of the pipeline model-parallel group.

    A value installed with set_pipeline_model_parallel_world_size takes
    precedence; otherwise the size is queried from torch.distributed.
    """
    override = _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    if override is None:
        return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())
    return override
def set_tensor_model_parallel_rank(rank):
    """Override this process's rank within the tensor model-parallel group.

    While the override is not None, get_tensor_model_parallel_rank returns it.
    """
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    _MPU_TENSOR_MODEL_PARALLEL_RANK = rank
def set_pipeline_model_parallel_rank(rank):
    """Override this process's rank within the pipeline model-parallel group.

    While the override is not None, get_pipeline_model_parallel_rank returns it.
    """
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank
def set_pipeline_model_parallel_split_rank(rank):
    """Set the pipeline rank at which an encoder/decoder model is split.

    Args:
        rank: first pipeline rank that executes the decoder, or None for no split.
    """
    # Write the variable that is_pipeline_stage_before_split /
    # is_pipeline_stage_after_split actually read; writing a differently
    # named global would make this setter a no-op.
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank
def get_tensor_model_parallel_rank():
    """Return this process's rank within the tensor model-parallel group.

    A value installed with set_tensor_model_parallel_rank takes precedence;
    otherwise the rank is queried from torch.distributed.
    """
    override = _MPU_TENSOR_MODEL_PARALLEL_RANK
    if override is None:
        return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
    return override
def get_pipeline_model_parallel_rank():
    """Return this process's rank within the pipeline model-parallel group.

    A value installed with set_pipeline_model_parallel_rank takes precedence;
    otherwise the rank is queried from torch.distributed.
    """
    override = _MPU_PIPELINE_MODEL_PARALLEL_RANK
    if override is None:
        return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
    return override
def is_pipeline_first_stage(ignore_virtual=False):
    """Return True if in the first pipeline model-parallel stage, False otherwise.

    Args:
        ignore_virtual: when False and a virtual pipeline size is configured,
            the virtual rank must also be 0.
    """
    check_virtual = (not ignore_virtual
                     and get_virtual_pipeline_model_parallel_world_size() is not None)
    if check_virtual and get_virtual_pipeline_model_parallel_rank() != 0:
        return False
    return get_pipeline_model_parallel_rank() == 0
def is_pipeline_last_stage(ignore_virtual=False):
    """Return True if in the last pipeline model-parallel stage, False otherwise.

    Args:
        ignore_virtual: when False and a virtual pipeline size is configured,
            the virtual rank must also be the last virtual rank.
    """
    if not ignore_virtual:
        virtual_size = get_virtual_pipeline_model_parallel_world_size()
        if virtual_size is not None:
            if get_virtual_pipeline_model_parallel_rank() != virtual_size - 1:
                return False
    my_stage = get_pipeline_model_parallel_rank()
    return my_stage == get_pipeline_model_parallel_world_size() - 1
def is_rank_in_embedding_group(ignore_virtual=False):
    """Return True if the current global rank is in the embedding group.

    Args:
        ignore_virtual: when True, plain membership in the embedding ranks
            decides. Otherwise the first/last embedding rank only counts when it
            is also on the first/last (virtual) pipeline stage respectively.
    """
    my_rank = torch.distributed.get_rank()
    embedding_ranks = _EMBEDDING_GLOBAL_RANKS
    if my_rank not in embedding_ranks:
        return False
    if ignore_virtual:
        return True
    if my_rank == embedding_ranks[0]:
        return is_pipeline_first_stage(ignore_virtual=False)
    if my_rank == embedding_ranks[-1]:
        return is_pipeline_last_stage(ignore_virtual=False)
    # Any middle embedding rank (e.g. the encoder/decoder split stage).
    return True
def is_rank_in_position_embedding_group():
    """Return True if the current global rank holds a position-embedding copy."""
    return torch.distributed.get_rank() in _POSITION_EMBEDDING_GLOBAL_RANKS
def is_pipeline_stage_before_split(rank=None):
    """Return True if pipeline stage executes encoder block for a model
    with both encoder and decoder.

    Args:
        rank: pipeline rank to test; defaults to this process's pipeline rank.

    With a single pipeline stage, or with no split rank configured, every
    stage counts as before the split.
    """
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    if rank is None:
        rank = get_pipeline_model_parallel_rank()
    split_rank = _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    return split_rank is None or rank < split_rank
def is_pipeline_stage_after_split(rank=None):
    """Return True if pipeline stage executes decoder block for a model
    with both encoder and decoder.

    Args:
        rank: pipeline rank to test; defaults to this process's pipeline rank.

    With a single pipeline stage, or with no split rank configured, every
    stage counts as after the split.
    """
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    if rank is None:
        rank = get_pipeline_model_parallel_rank()
    split_rank = _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    return split_rank is None or rank >= split_rank
def is_pipeline_stage_at_split():
    """Return True if this stage executes the encoder block and the next
    stage executes the decoder block, for a model with both encoder and
    decoder.

    That is, the current pipeline rank is before the split and rank + 1 is
    after it. Note that is_pipeline_stage_before_split/after_split both
    return True when the pipeline has a single stage or no split rank is
    configured, so this function also returns True in those cases.
    """
    rank = get_pipeline_model_parallel_rank()
    return is_pipeline_stage_before_split(rank) and \
        is_pipeline_stage_after_split(rank + 1)
def get_virtual_pipeline_model_parallel_rank():
    """Return the current virtual pipeline-parallel rank (None if unset)."""
    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
def set_virtual_pipeline_model_parallel_rank(rank):
    """Store *rank* as the current virtual pipeline-parallel rank."""
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank
def get_virtual_pipeline_model_parallel_world_size():
    """Return the virtual pipeline-parallel world size (None if no interleaving)."""
    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
def get_tensor_model_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the tensor model parallel group.

    This relies on tensor-parallel groups being blocks of consecutive global
    ranks, so the first member is the global rank rounded down to a multiple
    of the group size.
    """
    global_rank = torch.distributed.get_rank()
    group_size = get_tensor_model_parallel_world_size()
    return global_rank - global_rank % group_size
def get_data_parallel_src_rank():
    """Return the global rank of the first member of this rank's data-parallel group.

    Raises:
        AssertionError: if the data-parallel ranks have not been recorded yet.
    """
    dp_ranks = _DATA_PARALLEL_GLOBAL_RANKS
    assert dp_ranks is not None, "Data parallel group is not initialized"
    return dp_ranks[0]
def get_pipeline_model_parallel_first_rank():
    """Return the global rank of the first stage in this rank's pipeline group.

    Raises:
        AssertionError: if the pipeline ranks have not been recorded yet.
    """
    pipeline_ranks = _PIPELINE_GLOBAL_RANKS
    assert pipeline_ranks is not None, "Pipeline parallel group is not initialized"
    return pipeline_ranks[0]
def get_pipeline_model_parallel_last_rank():
    """Return the global rank of the last stage in this rank's pipeline group.

    Raises:
        AssertionError: if the pipeline ranks have not been recorded yet.
    """
    assert _PIPELINE_GLOBAL_RANKS is not None, \
        "Pipeline parallel group is not initialized"
    last_stage = get_pipeline_model_parallel_world_size() - 1
    return _PIPELINE_GLOBAL_RANKS[last_stage]
def get_pipeline_model_parallel_next_rank():
    """Return the global rank of the following pipeline stage.

    The last stage wraps around to the first stage.

    Raises:
        AssertionError: if the pipeline ranks have not been recorded yet.
    """
    assert _PIPELINE_GLOBAL_RANKS is not None, \
        "Pipeline parallel group is not initialized"
    stage = get_pipeline_model_parallel_rank()
    num_stages = get_pipeline_model_parallel_world_size()
    next_stage = (stage + 1) % num_stages
    return _PIPELINE_GLOBAL_RANKS[next_stage]
def get_pipeline_model_parallel_prev_rank():
    """Return the global rank of the pipeline stage that precedes the caller.

    The first stage wraps around to the last stage.

    Raises:
        AssertionError: if the pipeline ranks have not been recorded yet.
    """
    assert _PIPELINE_GLOBAL_RANKS is not None, \
        "Pipeline parallel group is not initialized"
    stage = get_pipeline_model_parallel_rank()
    num_stages = get_pipeline_model_parallel_world_size()
    prev_stage = (stage - 1) % num_stages
    return _PIPELINE_GLOBAL_RANKS[prev_stage]
def get_data_parallel_world_size():
    """Return the number of ranks in this rank's data-parallel group."""
    dp_group = get_data_parallel_group()
    return torch.distributed.get_world_size(group=dp_group)
def get_data_parallel_rank():
    """Return this process's rank within its data-parallel group."""
    dp_group = get_data_parallel_group()
    return torch.distributed.get_rank(group=dp_group)
def _set_global_memory_buffer():
    """Create the module-wide GlobalMemoryBuffer; it must not exist yet."""
    global _GLOBAL_MEMORY_BUFFER
    assert _GLOBAL_MEMORY_BUFFER is None, 'global memory buffer is already initialized'
    new_buffer = GlobalMemoryBuffer()
    _GLOBAL_MEMORY_BUFFER = new_buffer
def get_global_memory_buffer():
    """Return the module-wide GlobalMemoryBuffer.

    Raises:
        AssertionError: if the buffer has not been created yet.
    """
    memory_buffer = _GLOBAL_MEMORY_BUFFER
    assert memory_buffer is not None, 'global memory buffer is not initialized'
    return memory_buffer
def destroy_model_parallel():
    """Reset all parallel state so initialize_model_parallel can run again.

    Clears every process-group handle, the recorded global-rank lists, the
    virtual-pipeline and split-rank settings, the mpu size/rank overrides and
    the global memory buffer.
    """
    # Process groups.
    global _MODEL_PARALLEL_GROUP, _TENSOR_MODEL_PARALLEL_GROUP
    global _PIPELINE_MODEL_PARALLEL_GROUP, _DATA_PARALLEL_GROUP
    global _EMBEDDING_GROUP, _POSITION_EMBEDDING_GROUP
    _MODEL_PARALLEL_GROUP = None
    _TENSOR_MODEL_PARALLEL_GROUP = None
    _PIPELINE_MODEL_PARALLEL_GROUP = None
    _DATA_PARALLEL_GROUP = None
    _EMBEDDING_GROUP = None
    _POSITION_EMBEDDING_GROUP = None

    # Global-rank lists. Without resetting these, the "not initialized"
    # asserts in the src/first/last/next/prev rank helpers would pass with
    # stale ranks after destruction.
    global _EMBEDDING_GLOBAL_RANKS, _POSITION_EMBEDDING_GLOBAL_RANKS
    global _PIPELINE_GLOBAL_RANKS, _DATA_PARALLEL_GLOBAL_RANKS
    _EMBEDDING_GLOBAL_RANKS = None
    _POSITION_EMBEDDING_GLOBAL_RANKS = None
    _PIPELINE_GLOBAL_RANKS = None
    _DATA_PARALLEL_GLOBAL_RANKS = None

    # Virtual pipeline and encoder/decoder split settings. The split rank must
    # be cleared because initialize_model_parallel only writes it when a split
    # rank is passed, so a stale value would otherwise survive re-initialization.
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
    _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None

    # mpu size/rank overrides.
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
    _MPU_TENSOR_MODEL_PARALLEL_RANK = None
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = None

    global _GLOBAL_MEMORY_BUFFER
    _GLOBAL_MEMORY_BUFFER = None