dist_util.py
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py  # noqa: E501
import functools
import os
import subprocess

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def init_dist(launcher, backend='nccl', **kwargs):
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
    elif launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
    else:
        raise ValueError(f'Invalid launcher type: {launcher}')

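# Note: the 'pytorch' launcher path below assumes the process was started by
# torch.distributed.launch / torchrun, which export RANK (and related
# variables such as WORLD_SIZE and MASTER_ADDR) before this code runs.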
def _init_dist_pytorch(backend, **kwargs):
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)

def _init_dist_slurm(backend, port=None):
    """Initialize slurm distributed training environment.

    If argument ``port`` is not specified, then the master port will be system
    environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
    environment variable, then a default port ``29500`` will be used.

    Args:
        backend (str): Backend of torch.distributed.
        port (int, optional): Master port. Defaults to None.
    """
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(proc_id % num_gpus)
    addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')
    # specify master port
    if port is not None:
        os.environ['MASTER_PORT'] = str(port)
    elif 'MASTER_PORT' in os.environ:
        pass  # use MASTER_PORT in the environment variable
    else:
        # 29500 is torch.distributed default port
        os.environ['MASTER_PORT'] = '29500'
    os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
    os.environ['RANK'] = str(proc_id)
    dist.init_process_group(backend=backend)

def get_dist_info():
    if dist.is_available():
        initialized = dist.is_initialized()
    else:
        initialized = False
    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    return rank, world_size

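# Decorator: run the wrapped function only on the master process (rank 0);
# on every other rank the call is a no-op that returns None.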
def master_only(func):

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank, _ = get_dist_info()
        if rank == 0:
            return func(*args, **kwargs)

    return wrapper
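A minimal usage sketch, not part of the file above: it assumes the module is importable as dist_util, that the script is launched with torchrun (so RANK and related variables are already in the environment), and that the log helper is purely illustrative.

from dist_util import get_dist_info, init_dist, master_only


@master_only
def log(msg):
    # Executes on rank 0 only; other ranks silently skip the call.
    print(msg)


def main():
    # launcher='pytorch' expects torchrun (or torch.distributed.launch) to have
    # exported RANK and friends; pass launcher='slurm' when started by srun.
    init_dist('pytorch')
    rank, world_size = get_dist_info()
    log(f'rank {rank} of {world_size} initialized')
    # From here on, shard data by rank and wrap the model in
    # torch.nn.parallel.DistributedDataParallel as usual.


if __name__ == '__main__':
    main()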