- Notifications
You must be signed in to change notification settings - Fork 509
/
Copy pathmisc.py
157 lines (127 loc) · 5.07 KB
/
misc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import os
import random
import re
import time
from os import path as osp

import numpy as np
import torch

from .dist_util import master_only
from .logger import get_root_logger
def parse_torch_version(version):
    """Parse the leading ``major.minor[.patch]`` numbers of a version string.

    Only the leading numeric components are inspected, so local or
    pre-release suffixes such as ``+cu118``, ``a0+git1234`` or
    ``+rocm5.4.2`` are ignored. A missing patch component counts as 0.

    Args:
        version (str): Version string, e.g. ``torch.__version__``.

    Returns:
        tuple[int, int, int] | None: ``(major, minor, patch)``, or None when
        the string does not start with ``<int>.<int>``.
    """
    match = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?', version)
    if match is None:
        return None
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


# True when the installed torch is >= 1.12.0; the device helpers below only
# query ``torch.backends.mps`` when this is set. Matching only the leading
# numbers avoids the IndexError a full-string regex raised at import time for
# suffixes like '+rocm5.4.2'. An unparseable version conservatively gives False.
IS_HIGH_VERSION = (parse_torch_version(torch.__version__) or (0, 0, 0)) >= (1, 12, 0)
def gpu_is_available():
    """Report whether a GPU backend is available.

    Apple MPS is checked first, but only when ``IS_HIGH_VERSION`` is set.
    Otherwise CUDA counts as available only when both the CUDA runtime and
    cuDNN report availability (the same checks ``get_device`` uses).

    Returns:
        bool: True if MPS, or CUDA together with cuDNN, is available.
    """
    if IS_HIGH_VERSION and torch.backends.mps.is_available():
        return True
    return bool(torch.cuda.is_available() and torch.backends.cudnn.is_available())
def get_device(gpu_id=None):
    """Select a torch device, preferring MPS, then CUDA, then CPU.

    MPS is only considered when ``IS_HIGH_VERSION`` is set. CUDA is only
    chosen when both the CUDA runtime and cuDNN report availability.

    Args:
        gpu_id (int | None): Device index, appended as ``':<id>'`` to the
            MPS/CUDA device string and ignored for the CPU fallback.
            Default: None (no index).

    Returns:
        torch.device: The selected device.

    Raises:
        TypeError: If ``gpu_id`` is neither None nor an int.
    """
    if gpu_id is not None and not isinstance(gpu_id, int):
        raise TypeError('Input should be int value.')
    index_suffix = '' if gpu_id is None else f':{gpu_id}'

    if IS_HIGH_VERSION and torch.backends.mps.is_available():
        return torch.device('mps' + index_suffix)

    cuda_ready = torch.cuda.is_available() and torch.backends.cudnn.is_available()
    if cuda_ready:
        return torch.device('cuda' + index_suffix)
    return torch.device('cpu')
def set_random_seed(seed):
    """Seed every random number generator used by the code base.

    Seeds, in order: Python's ``random`` module, NumPy's global RNG, torch's
    CPU generator, torch's current CUDA device and all CUDA devices.

    Args:
        seed (int): Seed value.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
def get_time_str():
    """Return the current local time as a ``YYYYmmdd_HHMMSS`` string."""
    now = time.localtime()
    return time.strftime('%Y%m%d_%H%M%S', now)
def mkdir_and_rename(path):
    """Create a folder, archiving whatever already sits at that path.

    If ``path`` exists, it is first renamed to
    ``<path>_archived_<get_time_str()>`` and a notice is printed. The folder
    at ``path`` is then created, along with any missing parents.

    Args:
        path (str): Folder path.
    """
    if osp.exists(path):
        archived_path = path + '_archived_' + get_time_str()
        print(f'Path already exists. Rename it to {archived_path}', flush=True)
        os.rename(path, archived_path)
    os.makedirs(path, exist_ok=True)
@master_only
def make_exp_dirs(opt):
    """Create the experiment directories listed in ``opt['path']``.

    The root folder (``experiments_root`` when ``opt['is_train']`` is
    truthy, ``results_root`` otherwise) goes through ``mkdir_and_rename``, so
    an existing root is archived. Every remaining entry is created with
    ``os.makedirs(..., exist_ok=True)``, except entries whose key contains
    ``strict_load``, ``pretrain_network`` or ``resume``, which are skipped.

    ``master_only`` comes from ``.dist_util``. Presumably it limits the call
    to the main process in distributed runs; confirm against that module.

    Args:
        opt (dict): Options. ``opt['path']`` is shallow-copied before the
            root key is popped, so the caller's dict is not modified.
    """
    remaining = opt['path'].copy()
    root_key = 'experiments_root' if opt['is_train'] else 'results_root'
    mkdir_and_rename(remaining.pop(root_key))

    skipped_markers = ('strict_load', 'pretrain_network', 'resume')
    for key, folder in remaining.items():
        if any(marker in key for marker in skipped_markers):
            continue
        os.makedirs(folder, exist_ok=True)
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the interested files.

    Entries whose name starts with '.' are never yielded as files. When
    ``recursive`` is True, sub-directories (hidden ones included) are
    descended into. Other non-file entries, such as hidden files or broken
    links, are skipped.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator for all the interested files. Paths are relative to
        ``dir_path``, or full paths when ``full_path`` is True.

    Raises:
        TypeError: If ``suffix`` is not None, a str or a tuple. The error is
            raised at call time, not on first iteration, because this outer
            function is not itself a generator.
    """
    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        # The context manager releases the OS directory handle once this
        # generator finishes or is closed.
        with os.scandir(dir_path) as entries:
            for entry in entries:
                if not entry.name.startswith('.') and entry.is_file():
                    if full_path:
                        return_path = entry.path
                    else:
                        return_path = osp.relpath(entry.path, root)
                    if suffix is None or return_path.endswith(suffix):
                        yield return_path
                elif recursive and entry.is_dir():
                    # Hidden files also reach this branch. Recursing only
                    # into directories avoids NotADirectoryError from
                    # os.scandir on a file (e.g. '.DS_Store').
                    yield from _scandir(entry.path, suffix=suffix, recursive=recursive)

    return _scandir(dir_path, suffix=suffix, recursive=recursive)
def check_resume(opt, resume_iter):
    """Redirect each network's pretrain path to its resume checkpoint.

    Has no effect on ``opt`` unless ``opt['path']['resume_state']`` is
    truthy. When resuming, it logs a warning if any ``pretrain_network_*``
    path was set. Then, for every top-level ``network_<name>`` key in
    ``opt``, it sets ``opt['path']['pretrain_network_<name>']`` to
    ``<opt['path']['models']>/net_<name>_<resume_iter>.pth`` and logs the
    new value. Networks whose ``<name>`` appears in
    ``opt['path']['ignore_resume_networks']`` are left unchanged.

    Args:
        opt (dict): Options. ``opt['path']`` is modified in place.
        resume_iter (int): Resume iteration.
    """
    logger = get_root_logger()
    path_opt = opt['path']
    if not path_opt['resume_state']:
        return

    networks = [key for key in opt.keys() if key.startswith('network_')]
    if any(path_opt.get(f'pretrain_{net}') is not None for net in networks):
        logger.warning('pretrain_network path will be ignored during resuming.')

    ignored = path_opt.get('ignore_resume_networks')
    for net in networks:
        name = f'pretrain_{net}'
        basename = net.replace('network_', '')
        if ignored is not None and basename in ignored:
            continue
        path_opt[name] = osp.join(path_opt['models'], f'net_{basename}_{resume_iter}.pth')
        logger.info(f"Set {name} to {path_opt[name]}")
def sizeof_fmt(size, suffix='B'):
    """Get human readable file size.

    Uses binary (1024-based) steps and one decimal place. A space separates
    the number from the unit for every magnitude, e.g. ``'512.0 B'``,
    ``'1.5 KB'``, ``'1.0 YB'``.

    Args:
        size (int | float): File size. Negative values keep their sign.
        suffix (str): Suffix. Default: 'B'.

    Return:
        str: Formatted file size.
    """
    for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
        if abs(size) < 1024.0:
            return f'{size:3.1f} {unit}{suffix}'
        size /= 1024.0
    # Anything of at least 1024 Z-units is reported in Y-units.
    return f'{size:3.1f} Y{suffix}'