batch_predictor.py
# General
import argparse
import os
from typing import Dict

from huggingface_hub import login

# Serving
import datasets
import numpy as np
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.pipelines import pipeline

# Ray
import ray

# Settings
datasets.disable_progress_bar()

# Variables
base_model_path = "google/gemma-2b-it"
# Helpers
def get_args():
    parser = argparse.ArgumentParser(description='Batch prediction with Gemma on Ray on Vertex AI')
    parser.add_argument('--tuned_model_path', type=str, help='path of the tuned adapter model')
    parser.add_argument('--num_gpus', type=int, default=1, help='number of GPUs')
    parser.add_argument('--batch_size', type=int, default=8, help='batch size')
    parser.add_argument('--sample_size', type=int, default=20, help='number of articles to summarize')
    parser.add_argument('--temperature', type=float, default=0.1, help='temperature for generating summaries')
    parser.add_argument('--max_new_tokens', type=int, default=50, help='max new tokens for generating summaries')
    parser.add_argument('--output_dir', type=str, help='output directory for predictions')
    args = parser.parse_args()
    return args
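
# Example invocation (illustrative only; the paths below are placeholders, not
# values shipped with this script — on a Vertex AI Ray cluster, GCS buckets are
# typically reachable under a FUSE mount such as /gcs/<bucket>):
#
#   python batch_predictor.py \
#       --tuned_model_path /gcs/my-bucket/gemma-adapter \
#       --num_gpus 2 \
#       --batch_size 8 \
#       --output_dir /gcs/my-bucket/predictions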
def main():
    # Set configuration
    args = get_args()
    config = vars(args)

    # Authenticate with Hugging Face and fix the random seed for reproducibility
    login(token=os.environ['HF_TOKEN'], add_to_git_credential=True)
    transformers.set_seed(8)

    # Load a sample of the XSum validation split and convert it to a Ray Dataset
    dataset_id = "xsum"
    sample_size = config["sample_size"]
    input_data = datasets.load_dataset(dataset_id, split="validation", trust_remote_code=True)
    input_data = input_data.select(range(sample_size))
    ray_input_data = ray.data.from_huggingface(input_data)
    # Generate predictions
    class Summarizer:
        def __init__(self):
            self.tokenizer = AutoTokenizer.from_pretrained(base_model_path)
            self.tokenizer.padding_side = "right"
            self.tuned_model = AutoModelForCausalLM.from_pretrained(config["tuned_model_path"],
                                                                    device_map='auto',
                                                                    torch_dtype=torch.float16)
            self.pipeline = pipeline("text-generation",
                                     model=self.tuned_model,
                                     tokenizer=self.tokenizer,
                                     max_new_tokens=config["max_new_tokens"])

        def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
            # Prepare prompts: wrap each article in the chat template
            messages = [{"role": "user",
                         "content": f"Summarize the following ARTICLE in one sentence.\n###ARTICLE: {document}"}
                        for document in batch["document"]]
            batch['prompt'] = [self.tokenizer.apply_chat_template([message], tokenize=False, add_generation_prompt=True)
                               for message in messages]
            # Generate, then strip the echoed prompt from each output
            batch['generated_summary'] = [self.pipeline(prompt,
                                                        do_sample=True,
                                                        temperature=config["temperature"],
                                                        add_special_tokens=True)[0]["generated_text"][len(prompt):]
                                          for prompt in batch['prompt']]
            return batch
    predictions_data = ray_input_data.map_batches(
        Summarizer,
        concurrency=config["num_gpus"],
        num_gpus=1,
        batch_size=config['batch_size'])
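    # Note: concurrency sets the number of Summarizer replicas and num_gpus=1
    # reserves one GPU per replica, so --num_gpus N runs N model replicas,
    # one per GPU.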
    # Store resulting predictions
    predictions_data.write_json(config["output_dir"], try_create_dir=True)


if __name__ == "__main__":
    main()