- Notifications
You must be signed in to change notification settings - Fork 175
/
Copy pathphi3-qa.py
99 lines (80 loc) · 4.78 KB
/
phi3-qa.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
importonnxruntime_genaiasog
importargparse
importtime
def _rate(count, seconds):
    """Return ``count / seconds``, or 0.0 when the interval is not positive."""
    return count / seconds if seconds > 0 else 0.0


def main(args: argparse.Namespace) -> None:
    """Run an interactive question/answer loop against a Phi-3 ONNX model.

    Loads the model and tokenizer once, then repeatedly reads a prompt from
    stdin, wraps it in the Phi-3 chat template and streams the generated
    tokens to stdout. Ctrl+C during generation aborts only the current
    answer; closing stdin (Ctrl-D / end of piped input) ends the loop.

    Args:
        args: Parsed command-line options. ``model_path``,
            ``execution_provider``, ``verbose`` and ``timings`` are always
            read. The search options (``do_sample``, ``max_length``,
            ``min_length``, ``top_p``, ``top_k``, ``temperature``,
            ``repetition_penalty``) are forwarded to the generator only when
            they are present on the namespace.
    """
    if args.verbose:
        print("Loading model...")

    config = og.Config(args.model_path)
    if args.execution_provider != "follow_config":
        # Drop whatever providers the config lists; for "cpu" nothing is
        # appended afterwards.
        config.clear_providers()
        if args.execution_provider != "cpu":
            if args.verbose:
                print(f"Setting model to {args.execution_provider}")
            config.append_provider(args.execution_provider)
    model = og.Model(config)
    if args.verbose:
        print("Model loaded")

    tokenizer = og.Tokenizer(model)
    tokenizer_stream = tokenizer.create_stream()
    if args.verbose:
        print("Tokenizer created")
        print()

    option_names = (
        'do_sample',
        'max_length',
        'min_length',
        'top_p',
        'top_k',
        'temperature',
        'repetition_penalty',
    )
    # Only forward options the user actually supplied (the parser suppresses
    # absent arguments), so everything else keeps the model's defaults.
    search_options = {
        name: getattr(args, name) for name in option_names if name in args
    }

    # Set the max length to something sensible by default, unless it is
    # specified by the user, since otherwise it will be set to the entire
    # context length
    search_options.setdefault('max_length', 2048)

    chat_template = '<|user|>\n{input} <|end|>\n<|assistant|>'

    # Keep asking for input prompts in a loop
    while True:
        try:
            text = input("Input: ")
        except EOFError:
            # stdin was closed: finish cleanly instead of dying with a
            # traceback.
            print()
            return
        if not text:
            print("Error, input cannot be empty")
            continue

        # perf_counter is monotonic, so the measured intervals cannot go
        # negative if the wall clock is adjusted mid-run.
        started_timestamp = time.perf_counter()
        # Reset per prompt so an aborted generation never reports the
        # previous prompt's first-token time.
        first_token_timestamp = None
        new_tokens = []

        # Wrap the user text in the chat template
        prompt = chat_template.format(input=text)
        input_tokens = tokenizer.encode(prompt)

        params = og.GeneratorParams(model)
        params.set_search_options(**search_options)
        generator = og.Generator(model, params)
        generator.append_tokens(input_tokens)
        if args.verbose:
            print("Generator created")
            print("Running generation loop ...")

        print()
        print("Output: ", end='', flush=True)

        try:
            while not generator.is_done():
                generator.generate_next_token()
                if args.timings and first_token_timestamp is None:
                    first_token_timestamp = time.perf_counter()

                new_token = generator.get_next_tokens()[0]
                print(tokenizer_stream.decode(new_token), end='', flush=True)
                if args.timings:
                    new_tokens.append(new_token)
        except KeyboardInterrupt:
            print(" --control+c pressed, aborting generation--")
        print()
        print()

        if args.timings:
            finished_timestamp = time.perf_counter()
            if first_token_timestamp is None:
                # Aborted (or finished) before any token was produced, so
                # there is nothing meaningful to measure.
                print("No tokens were generated, timing information is unavailable")
            else:
                prompt_time = first_token_timestamp - started_timestamp
                run_time = finished_timestamp - first_token_timestamp
                prompt_tps = _rate(len(input_tokens), prompt_time)
                new_tps = _rate(len(new_tokens), run_time)
                print(f"Prompt length: {len(input_tokens)}, New tokens: {len(new_tokens)}, Time to first: {(prompt_time):.2f}s, Prompt tokens per second: {prompt_tps:.2f} tps, New tokens per second: {new_tps:.2f} tps")
if __name__ == "__main__":
    # argument_default=SUPPRESS leaves unspecified options off the namespace
    # entirely, so main() can tell "not given" apart from an explicit value.
    # The store_true flags therefore need an explicit default=False, because
    # main() reads them unconditionally.
    parser = argparse.ArgumentParser(
        argument_default=argparse.SUPPRESS,
        description="End-to-end AI Question/Answer example for gen-ai",
    )
    parser.add_argument(
        '-m', '--model_path', type=str, required=True,
        help='Onnx model folder path (must contain genai_config.json and model.onnx)',
    )
    parser.add_argument(
        '-e', '--execution_provider', type=str, required=False,
        default='follow_config',
        choices=["cpu", "cuda", "dml", "follow_config"],
        help="Execution provider to run the ONNX Runtime session with. Defaults to follow_config that uses the execution provider listed in the genai_config.json instead.",
    )
    parser.add_argument(
        '-i', '--min_length', type=int,
        help='Min number of tokens to generate including the prompt',
    )
    parser.add_argument(
        '-l', '--max_length', type=int,
        help='Max number of tokens to generate including the prompt',
    )
    parser.add_argument(
        '-ds', '--do_sample', action='store_true', default=False,
        help='Do random sampling. When false, greedy or beam search are used to generate the output. Defaults to false',
    )
    parser.add_argument(
        '-p', '--top_p', type=float,
        help='Top p probability to sample with',
    )
    parser.add_argument(
        '-k', '--top_k', type=int,
        help='Top k tokens to sample from',
    )
    parser.add_argument(
        '-t', '--temperature', type=float,
        help='Temperature to sample with',
    )
    parser.add_argument(
        '-r', '--repetition_penalty', type=float,
        help='Repetition penalty to sample with',
    )
    # Timing output is controlled solely by --timings; --verbose only prints
    # progress messages, so its help text no longer mentions timing.
    parser.add_argument(
        '-v', '--verbose', action='store_true', default=False,
        help='Print verbose output. Defaults to false',
    )
    parser.add_argument(
        '-g', '--timings', action='store_true', default=False,
        help='Print timing information for each generation step. Defaults to false',
    )
    args = parser.parse_args()
    main(args)