- Notifications
You must be signed in to change notification settings - Fork 99
/
Copy pathyolo_e2e.py
69 lines (55 loc) · 2.61 KB
/
yolo_e2e.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from pathlib import Path

import numpy
import onnxruntime_extensions
def get_yolo_model(version: int, onnx_model_name: str):
    """Fetch a pretrained YOLO 'n' checkpoint and export it to an ONNX file.

    If ``ultralytics`` cannot be imported it is installed with pip first. The
    ``yolov{version}n.pt`` checkpoint is then loaded through ``ultralytics.YOLO``,
    exported to ONNX, and the exported file is moved to ``onnx_model_name``.

    Args:
        version (int): YOLO version used to build the checkpoint name,
            e.g. 8 -> ``yolov8n.pt``.
        onnx_model_name (str): destination path of the exported ONNX model.

    Raises:
        subprocess.CalledProcessError: if ``pip install ultralytics`` fails.
        RuntimeError: if the ONNX export does not report an output file.
    """
    import importlib
    import shutil
    import subprocess
    import sys

    # install ultralytics (provides the YOLO models) on demand
    try:
        import ultralytics
    except ImportError:
        # pip does not support being driven in-process via ``pip._internal``;
        # run it as a subprocess of the current interpreter instead.
        subprocess.check_call([sys.executable, "-m", "pip", "install", "ultralytics"])
        # Make the freshly installed package visible to the import system.
        importlib.invalidate_caches()
        import ultralytics

    pt_model = Path(f"yolov{version}n.pt")
    model = ultralytics.YOLO(str(pt_model))  # load a pretrained model
    exported_filename = model.export(format="onnx")  # export the model to ONNX format
    # Explicit check rather than `assert`, which is removed when Python runs with -O.
    if not exported_filename:
        raise RuntimeError(f"Failed to export yolov{version}n.pt to onnx")

    destination = Path(onnx_model_name)
    # The target directory (e.g. ../test/data) may not exist yet.
    destination.parent.mkdir(parents=True, exist_ok=True)
    shutil.move(str(exported_filename), str(destination))
def add_pre_post_processing_to_yolo(input_model_file: Path, output_model_file: Path,
                                    output_format: str = "jpg", onnx_opset: int = 18):
    """Construct the pipeline for an end-to-end model with pre and post processing.

    The final model takes raw (encoded) image bytes as input and produces the
    result as an encoded image.

    Args:
        input_model_file (Path): the ONNX YOLO model.
        output_model_file (Path): where to save the final ONNX model.
        output_format (str): output image format passed to ``yolo_detection``.
            Defaults to "jpg", the value that was previously hard-coded.
        onnx_opset (int): ONNX opset used for the generated model. Defaults to 18.
    """
    from onnxruntime_extensions.tools import add_pre_post_processing_to_model as add_ppp

    add_ppp.yolo_detection(input_model_file, output_model_file, output_format, onnx_opset=onnx_opset)
def run_inference(onnx_model_file: Path,
                  image_file='../test/data/ppp_vision/wolves.jpg',
                  output_filename='../test/data/result.jpg',
                  show: bool = True):
    """Run the end-to-end model on an encoded image and save the annotated result.

    The onnxruntime-extensions custom-op library is registered so the pre/post
    processing nodes can execute. The raw bytes of ``image_file`` are fed to the
    model's first input, and the ``image_out`` output is written to
    ``output_filename``.

    Args:
        onnx_model_file (Path): model produced by add_pre_post_processing_to_yolo.
        image_file: encoded input image. Defaults to the wolves.jpg test image;
            the path is relative to the current working directory.
        output_filename: where the encoded result image is written. The path is
            relative to the current working directory.
        show (bool): when True (the original behavior), open the result with
            PIL's default image viewer.
    """
    import onnxruntime as ort
    import numpy as np

    providers = ['CPUExecutionProvider']
    session_options = ort.SessionOptions()
    # Pre/post processing nodes are custom ops implemented in onnxruntime-extensions.
    session_options.register_custom_ops_library(onnxruntime_extensions.get_library_path())

    # The model consumes the raw encoded image bytes as a uint8 tensor.
    with open(image_file, 'rb') as f:
        image = np.frombuffer(f.read(), dtype=np.uint8)

    session = ort.InferenceSession(str(onnx_model_file), providers=providers, sess_options=session_options)
    input_name = session.get_inputs()[0].name
    output = session.run(['image_out'], {input_name: image})[0]

    # Close (and therefore flush) the file before PIL reopens it.
    with open(output_filename, 'wb') as f:
        f.write(output)

    if show:
        from PIL import Image
        with Image.open(output_filename) as result:
            result.show()
if __name__ == '__main__':
    # YOLO version to fetch and convert (tested with 5 and 8).
    version = 8
    base_model = Path(f"../test/data/yolov{version}n.onnx")
    e2e_model = base_model.with_suffix(suffix=".with_pre_post_processing.onnx")

    # Download/export the base model only when it is not already on disk.
    if not base_model.exists():
        print("Fetching original model...")
        get_yolo_model(version, str(base_model))

    print("Adding pre/post processing...")
    add_pre_post_processing_to_yolo(base_model, e2e_model)

    print("Testing updated model...")
    run_inference(e2e_model)