- Notifications
You must be signed in to change notification settings - Fork 99
/
Copy pathwhisper_e2e.py
135 lines (110 loc) · 5.23 KB
/
whisper_e2e.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Run Whisper end-to-end inference, using ONNXRuntime-Extensions for pre- and post-processing.
# THIS SCRIPT IS FOR DEMO PURPOSES ONLY AND IS NOT PART OF THE PACKAGE.
# TO GENERATE THE FULL-FUNCTION MODEL, PLEASE USE https://github.com/microsoft/Olive
import os
import subprocess
import sys

import numpy as np
import onnx
import onnxruntime as ort
from packaging import version
from transformers import WhisperProcessor

from onnxruntime_extensions import OrtPyFunction, util
from onnxruntime_extensions.cvt import gen_processing_models
# Constants
# Hugging Face model id of the Whisper checkpoint used throughout the demo.
MODEL_NAME = "openai/whisper-tiny.en"
# Passed as the download cache to the ONNX exporter and to WhisperProcessor.
CACHE_DIR = 'temp_caches_onnx'
# Directory the exporter writes the core (beam-search) ONNX model into.
OUTPUT_DIR = 'temp_model_onnx'
# Output path of the merged pre-processing + core + post-processing model.
FINAL_MODEL = "whisper_onnx_tiny_en_fp32_e2e.onnx"
# Encoded audio sample fed to the pipeline.
# NOTE(review): the '../test/data' location is resolved by
# util.get_test_data_file — presumably relative to this script; confirm.
TEST_AUDIO_FILE = util.get_test_data_file('../test/data', "1272-141231-0002.mp3")
def check_onnx_version(minimum="1.16.0"):
    """Ensure the installed ONNX Runtime is at least ``minimum``.

    Args:
        minimum: lowest acceptable ``onnxruntime`` version string. Defaults to
            "1.16.0", the version this demo requires.

    Raises:
        RuntimeError: if ``onnxruntime.__version__`` is older than ``minimum``.
    """
    if version.parse(ort.__version__) < version.parse(minimum):
        # Report the installed version so the user knows what to upgrade from.
        raise RuntimeError(
            f"ONNXRuntime version must be >= {minimum}, found {ort.__version__}")
def export_onnx_model():
    """Export the Whisper beam-search model with ONNX Runtime's converter.

    Runs the ``onnxruntime.transformers.models.whisper.convert_to_onnx`` module
    in a child process for ``MODEL_NAME`` at fp32 precision, with ``CACHE_DIR``
    as the cache directory and ``OUTPUT_DIR`` as the output directory.

    Raises:
        RuntimeError: if the converter process exits with a non-zero status.
    """
    print("Exporting Whisper ONNX model from Huggingface model hub...")
    # Use the interpreter that is running this script rather than whatever
    # "python" resolves to on PATH: that name may not exist (e.g. only
    # "python3" is installed) or may belong to another environment that lacks
    # onnxruntime/transformers.
    command = [sys.executable, '-m',
               'onnxruntime.transformers.models.whisper.convert_to_onnx',
               '-m', MODEL_NAME,
               '--cache_dir', CACHE_DIR,
               '--output', OUTPUT_DIR,
               '--precision', 'fp32']
    process = subprocess.run(command)
    if process.returncode != 0:
        raise RuntimeError("Failed to export the core ONNX models.")
def process_test_file():
    """Build the processing graphs and run pre-processing on the test clip.

    Generates the Whisper pre- and post-processing ONNX models from the Hugging
    Face ``WhisperProcessor`` (opset 17) and executes the pre-processing model
    on the raw bytes of ``TEST_AUDIO_FILE`` with a leading batch axis of 1.

    Returns:
        tuple: ``(pre_output, pre_model, post_model)`` — the result of the
        pre-processing model for the clip, followed by the two generated models.

    Raises:
        FileNotFoundError: if ``TEST_AUDIO_FILE`` does not exist.
    """
    audio_path = TEST_AUDIO_FILE
    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"Test audio path {audio_path} does not exist.")

    # The file is read as opaque uint8 bytes; given USE_AUDIO_DECODER=True
    # below, decoding presumably happens inside the pre-processing graph.
    encoded_bytes = np.fromfile(audio_path, dtype=np.uint8)

    processor = WhisperProcessor.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
    pre_options = {"USE_AUDIO_DECODER": True, "USE_ONNX_STFT": True}
    pre_model, post_model = gen_processing_models(
        processor, pre_kwargs=pre_options, post_kwargs={}, opset=17)

    # Graph optimizations are disabled for the pre-processing session.
    run_pre = OrtPyFunction.from_model(
        pre_model, session_options={"graph_optimization_level": 0})
    batch = encoded_bytes[np.newaxis, :]  # shape (1, n_bytes)
    return run_pre(batch), pre_model, post_model
def get_model_inputs(ort_session, audio_data, model_name=None):
    """Build the positional input list for a Whisper beam-search ONNX model.

    The returned list follows the order of ``ort_session.get_inputs()``. It can
    therefore be splatted straight into the model call (``fn(*inputs)``)
    whatever order the exporter chose and whichever optional inputs it emitted.

    Args:
        ort_session: object exposing ``get_inputs()`` whose entries have a
            ``.name`` attribute (an ONNX Runtime session in this script).
        audio_data: value fed to the ``audio_stream`` / ``input_features``
            input; its first axis is taken as the batch size.
        model_name: model id used to choose the vocabulary size and decoder
            start token (checkpoints containing ".en" use different values).
            Defaults to ``MODEL_NAME``.

    Returns:
        list: one value per model input, in model input order.

    Raises:
        NotImplementedError: if the model has an input this helper cannot fill.
    """
    if model_name is None:
        model_name = MODEL_NAME

    ort_names = [entry.name for entry in ort_session.get_inputs()]
    print(ort_names)

    is_english_only = ".en" in model_name
    vocab_size = 51864 if is_english_only else 51865
    decoder_start_token_id = 50257 if is_english_only else 50258
    batch_size = np.shape(audio_data)[0] if np.ndim(audio_data) > 0 else 1
    n_mels, n_frames = 80, 3000

    # Factories rather than values so large optional arrays (e.g. the
    # attention mask) are only allocated when the model actually has them.
    factories = {
        "audio_stream": lambda: audio_data,
        "input_features": lambda: audio_data,
        "max_length": lambda: np.asarray([200], dtype=np.int32),
        "min_length": lambda: np.asarray([0], dtype=np.int32),
        "num_beams": lambda: np.asarray([2], dtype=np.int32),
        "num_return_sequences": lambda: np.asarray([1], dtype=np.int32),
        "length_penalty": lambda: np.asarray([1.0], dtype=np.float32),
        "repetition_penalty": lambda: np.asarray([1.0], dtype=np.float32),
        "vocab_mask": lambda: np.ones(vocab_size, dtype=np.int32),
        "prefix_vocab_mask": lambda: np.ones((batch_size, vocab_size), dtype=np.int32),
        # For older ORT versions that have the dummy attention mask input
        # for the beam search op.
        "attention_mask": lambda: np.zeros((batch_size, n_mels, n_frames), dtype=np.int32),
        "decoder_input_ids": lambda: np.full(
            (batch_size, 1), decoder_start_token_id, dtype=np.int32),
        "logits_processor": lambda: np.array([1], dtype=np.int32),
    }

    inputs = []
    for name in ort_names:
        try:
            factory = factories[name]
        except KeyError:
            raise NotImplementedError(f"'{name}' input is not supported") from None
        inputs.append(factory())
    return inputs
def main():
    """Run the full Whisper end-to-end demo.

    Checks the ONNX Runtime version and exports the core beam-search model.
    Runs pre-processing, the core model and post-processing as separate models
    on the test clip. Merges the three graphs and saves the result to
    ``FINAL_MODEL``. Finally runs the merged model directly on the raw audio
    bytes, printing intermediate and final results along the way.
    """
    check_onnx_version()
    export_onnx_model()
    log_mel, pre_m, post_m = process_test_file()

    # Apply core ONNX model.
    # Derive the exported file name from MODEL_NAME instead of hard-coding it,
    # so changing MODEL_NAME does not leave this path pointing at a missing
    # file. For the default MODEL_NAME this is "whisper-tiny.en_beamsearch.onnx".
    # NOTE(review): assumes the exporter names its output
    # "<last component of model id>_beamsearch.onnx" — confirm for other
    # checkpoints/precisions.
    core_model_file = f"{MODEL_NAME.split('/')[-1]}_beamsearch.onnx"
    fn_core = OrtPyFunction.from_model(os.path.join(OUTPUT_DIR, core_model_file), cpu_only=True)
    # NOTE(review): relies on the private _ensure_ort_session() to read the
    # model's input names.
    fn_core_ort_session = fn_core._ensure_ort_session()
    model_inputs = get_model_inputs(fn_core_ort_session, log_mel)
    token_seq = fn_core(*model_inputs)
    print(token_seq.shape)

    # Apply post processing to the generated token sequence.
    fn_post = OrtPyFunction.from_model(post_m, cpu_only=True)
    output_text = fn_post(token_seq)
    print(output_text)

    # Merge models and save final model.
    print("Combine the data processing graphs into the ONNX model...")
    final_m = util.quick_merge(pre_m, fn_core.onnx_model, post_m)
    onnx.save(final_m, FINAL_MODEL)

    # Test the final model: it is fed the raw encoded bytes with a batch axis.
    raw_audio = np.fromfile(TEST_AUDIO_FILE, dtype=np.uint8)
    raw_audio = np.expand_dims(raw_audio, axis=0)
    e2e_model = OrtPyFunction.from_model(final_m, cpu_only=True)
    e2e_model_ort_session = e2e_model._ensure_ort_session()
    model_inputs = get_model_inputs(e2e_model_ort_session, raw_audio)
    text = e2e_model(*model_inputs)
    print(text)
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()