- Notifications
You must be signed in to change notification settings - Fork 633
/
Copy pathtimm_app.py
54 lines (41 loc) · 1.88 KB
/
timm_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""Pytorch Image Models (the timm package).
TODO: timm doesn't take in gs:// (fake imagenet data hosted by Google), so this
example finishes up to provisioning and doesn't really start training. An
error on data path being invalid is expected. We should investigate supporting
large datasets.
"""
importos
importsubprocess
importsky
PROJECT_DIR='~/Downloads/pytorch-image-models'
defclone_project():
ifnotos.path.isdir(os.path.expanduser(PROJECT_DIR)):
subprocess.run(
'git clone https://github.com/rwightman/pytorch-image-models {}'.
format(PROJECT_DIR),
shell=True,
check=True)
clone_project()
withsky.Dag() asdag:
# The working directory contains all code and will be synced to remote.
workdir=PROJECT_DIR
# The setup command. Will be run under the working directory.
setup='pip3 install timm pyyaml'
# The command to run. Will be run under the working directory.
# https://rwightman.github.io/pytorch-image-models/training_hparam_examples/
# fake_imagenet (tfrecords) doesn't work:
# RuntimeError: Found 0 images in subfolders of
# /tmp/fake_imagenet. Supported image extensions are .png, .jpg, .jpeg
run='./distributed_train.sh 1 /tmp/fake_imagenet --model efficientnet_b2 -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.3 --drop-connect 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .016'
train=sky.Task(
'train',
workdir=workdir,
setup=setup,
run=run,
)
train.set_file_mounts({
# Download from GCS.
'/tmp/fake_imagenet': 'gs://cloud-tpu-test-datasets/fake_imagenet',
})
train.set_resources({sky.Resources(sky.AWS(), accelerators='V100')})
sky.launch(dag)