trainer.py
# libraries
import argparse

# training libraries
from train import train_func

# ray libraries
import ray
import ray.train.huggingface.transformers
from ray.train import ScalingConfig, RunConfig, CheckpointConfig
from ray.train.torch import TorchTrainer
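
# `train` is expected to be a sibling module in this repo that defines
# train_func, the per-worker training loop handed to TorchTrainer below.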
# helpers
defget_args():
parser=argparse.ArgumentParser(description='Supervised tuning Gemma on Ray on Vertex AI')
# some gemma parameters
parser.add_argument("--train_batch_size", type=int, default=1, help="train batch size")
parser.add_argument("--eval_batch_size", type=int, default=1, help="eval batch size")
parser.add_argument("--gradient_accumulation_steps", type=int, default=4, help="gradient accumulation steps")
parser.add_argument("--learning_rate", type=float, default=2e-4, help="learning rate")
parser.add_argument("--max_steps", type=int, default=100, help="max steps")
parser.add_argument("--save_steps", type=int, default=10, help="save steps")
parser.add_argument("--logging_steps", type=int, default=10, help="logging steps")
# ray parameters
parser.add_argument('--num-workers', dest='num_workers', type=int, default=1, help='Number of workers')
parser.add_argument('--use-gpu', dest='use_gpu', action='store_true', default=False, help='Use GPU')
parser.add_argument('--experiment-name', dest='experiment_name', type=str, default='gemma-on-rov', help='Experiment name')
parser.add_argument('--logging-dir', dest='logging_dir', type=str, help='Logging directory')
args=parser.parse_args()
returnargs
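

# Example invocation (illustrative values; adjust workers, flags, and the
# logging path for your environment):
#   python trainer.py --num-workers 4 --use-gpu \
#       --experiment-name gemma-on-rov --logging-dir gs://your-bucket/logs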


def main():
    args = get_args()
    config = vars(args)

    # initialize ray session
    ray.shutdown()
    ray.init()

    # training config
    train_loop_config = {
        "per_device_train_batch_size": config['train_batch_size'],
        "per_device_eval_batch_size": config['eval_batch_size'],
        "gradient_accumulation_steps": config['gradient_accumulation_steps'],
        "learning_rate": config['learning_rate'],
        "max_steps": config['max_steps'],
        "save_steps": config['save_steps'],
        "logging_steps": config['logging_steps'],
    }
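
    # These keys mirror transformers.TrainingArguments fields; Ray Train passes
    # this dict to each worker's train_func as its `config` argument, where it
    # is presumably used to build the Hugging Face trainer.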
    scaling_config = ScalingConfig(num_workers=config['num_workers'], use_gpu=config['use_gpu'])
    run_config = RunConfig(
        checkpoint_config=CheckpointConfig(
            num_to_keep=5,
            checkpoint_score_attribute="loss",
            checkpoint_score_order="min",
        ),
        storage_path=config['logging_dir'],
        name=config['experiment_name'],
    )
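
    # CheckpointConfig keeps only the 5 best checkpoints, ranked by the
    # reported "loss" metric in ascending order; lower-ranked ones are deleted.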
    trainer = TorchTrainer(
        train_loop_per_worker=train_func,
        train_loop_config=train_loop_config,
        run_config=run_config,
        scaling_config=scaling_config,
    )

    # train
    result = trainer.fit()
    ray.shutdown()


if __name__ == "__main__":
    main()