- Notifications
You must be signed in to change notification settings - Fork 99
/
Copy pathyolov8_pose_e2e.py
370 lines (309 loc) · 18.6 KB
/
yolov8_pose_e2e.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
from pathlib import Path
from typing import List, Optional, Union

import onnx.shape_inference
import onnxruntime_extensions
from onnxruntime_extensions.tools.pre_post_processing import *
from PIL import Image, ImageDraw
def _get_yolov8_pose_model(onnx_model_path: Path):
    """
    Export the pretrained yolov8n-pose PyTorch model to ONNX using ultralytics.

    ultralytics is installed with pip (into the interpreter running this script) if it cannot be imported.

    @param onnx_model_path: Path to move the exported ONNX model to.
    @raises subprocess.CalledProcessError: if installing ultralytics with pip fails.
    @raises RuntimeError: if exporting the model to ONNX fails.
    """
    import importlib
    import shutil
    import subprocess
    import sys

    try:
        import ultralytics
    except ImportError:
        # pip has no supported in-process API (pip._internal is private and changes between releases),
        # so run it as a subprocess of the current interpreter and fail loudly if the install fails.
        subprocess.check_call([sys.executable, "-m", "pip", "install", "ultralytics"])
        # make the import system notice the package that was installed after startup.
        importlib.invalidate_caches()
        import ultralytics

    pt_model = Path("yolov8n-pose.pt")
    model = ultralytics.YOLO(str(pt_model))  # load a pretrained model
    success = model.export(format="onnx")  # export the model to ONNX format
    if not success:
        # explicit raise rather than assert so the check is not stripped when running with -O
        raise RuntimeError("Failed to export yolov8n-pose.pt to onnx")

    # the exported model is expected next to the .pt file with an .onnx suffix; move it to the requested path.
    shutil.move(str(pt_model.with_suffix('.onnx')), str(onnx_model_path))
def _get_model_and_info(input_model_path: Path):
    """
    Load the YOLOv8 pose ONNX model and read the model input size from it.

    If the model file does not exist it is fetched and exported first.

    @param input_model_path: Path to the ONNX model.
    @returns: Tuple of (model, w_in, h_in) where w_in/h_in are the model input width and height.
    @raises ValueError: if the model input is not 640x640 or the output is not [*, 56, 8400].
    """
    if not input_model_path.is_file():
        print(f"Fetching the model... {input_model_path}")
        _get_yolov8_pose_model(input_model_path)

    print("Adding pre/post processing to the model...")
    model = onnx.load(str(input_model_path.resolve(strict=True)))
    model_with_shape_info = onnx.shape_inference.infer_shapes(model)

    model_input_shape = model_with_shape_info.graph.input[0].type.tensor_type.shape
    model_output_shape = model_with_shape_info.graph.output[0].type.tensor_type.shape

    # infer the input sizes from the model. width and height are the last two dims of the input.
    w_in = model_input_shape.dim[-1].dim_value
    h_in = model_input_shape.dim[-2].dim_value
    # explicit raise rather than assert so validation is not stripped when running with -O
    if w_in != 640 or h_in != 640:
        raise ValueError(f"Expected model input width and height of 640x640. Got {w_in}x{h_in}.")

    # output should be [1, 56, 8400]: 56 values (4 box + 1 score + 17 key points x 3) for each of 8400 boxes.
    classes_masks_out = model_output_shape.dim[1].dim_value
    boxes_out = model_output_shape.dim[2].dim_value
    if classes_masks_out != 56 or boxes_out != 8400:
        raise ValueError(f"Expected model output shape of [1, 56, 8400]. "
                         f"Got [*, {classes_masks_out}, {boxes_out}].")

    return (model, w_in, h_in)
def _update_model(model: onnx.ModelProto, output_model_path: Path, pipeline: PrePostProcessor):
    """
    Update the model by running the pre/post processing pipeline, validate it, and save it.

    @param model: ONNX model to update
    @param output_model_path: Filename to write the updated model to.
    @param pipeline: Pre/Post processing pipeline to run.
    @raises: whatever onnx.shape_inference.infer_shapes raises in strict mode if the updated model is invalid.
             The model is not saved in that case.
    """
    new_model = pipeline.run(model)
    print("Pre/post processing added.")

    # run shape inferencing to validate the new model. shape inferencing will fail if any of the new node
    # types or shapes are incorrect. infer_shapes returns a copy of the model with ValueInfo populated,
    # but we ignore that and save new_model as it is smaller due to not containing the inferred shape information.
    _ = onnx.shape_inference.infer_shapes(new_model, strict_mode=True)

    onnx.save_model(new_model, str(output_model_path.resolve()))
    print("Updated model saved.")
def _add_pre_post_processing_to_rgb_input(input_model_path: Path,
                                          output_model_path: Path,
                                          input_shape: List[Union[int, str]]):
    """
    Add pre and post processing with model input of RGB data.

    Pre-processing will convert the input to the correct height, width and data type for the model.
    Post-processing will select the best bounding boxes using NonMaxSuppression, and scale the selected bounding
    boxes and key-points to the original image size.

    @param input_model_path: Path to ONNX model.
    @param output_model_path: Path to write updated model to.
    @param input_shape: Input shape of RGB data. Must be 3D. First or last value must be 3 (channels first or last).
                        Symbolic dimensions may be given as strings.
    @raises ValueError: if input_shape is missing, is not 3D, or neither its first nor last dimension is 3.
    """
    # Validate the requested shape before loading the model, as loading may need to fetch and export it.
    if input_shape is None or len(input_shape) != 3:
        raise ValueError("Invalid input shape. Must have 3 dimensions.")

    if input_shape[0] == 3:
        layout = "CHW"
    elif input_shape[2] == 3:
        layout = "HWC"
    else:
        raise ValueError("Invalid input shape. Either first or last dimension must be 3.")

    model, w_in, h_in = _get_model_and_info(input_model_path)

    onnx_opset = 18
    inputs = [create_named_value("rgb_data", onnx.TensorProto.UINT8, input_shape)]
    pipeline = PrePostProcessor(inputs, onnx_opset)

    if layout == "CHW":
        # use Identity so we have an output named RGBImageCHW
        # for ScaleNMSBoundingBoxesAndKeyPoints in the post-processing steps
        pre_processing_steps = [Identity(name="RGBImageCHW")]
    else:
        pre_processing_steps = [ChannelsLastToChannelsFirst(name="RGBImageCHW")]  # HWC to CHW

    pre_processing_steps += [
        # Resize to match model input. Uses not_larger as we use LetterBox to pad as needed.
        Resize((h_in, w_in), policy='not_larger', layout='CHW'),
        LetterBox(target_shape=(h_in, w_in), layout='CHW'),  # padding or cropping the image to (h_in, w_in)
        ImageBytesToFloat(),  # Convert to float in range 0..1
        Unsqueeze([0]),  # add batch, CHW --> 1CHW
    ]

    pipeline.add_pre_processing(pre_processing_steps)

    post_processing_steps = [
        Squeeze([0]),  # Squeeze to remove batch dimension from [batch, 56, 8400] output
        Transpose([1, 0]),  # reverse so result info is inner dim: [8400, 56]
        # split the 56 elements into 4 for the box, 1 for the score of the single class,
        # and 51 for the key point data (17 key points x 3 values)
        Split(num_outputs=3, axis=1, splits=[4, 1, 51]),
        # Apply NMS to select best boxes. The iou and score thresholds match the ultralytics defaults in
        # https://github.com/ultralytics/ultralytics/blob/e7bd159a44cf7426c0f33ed9b413ef4439505a03/ultralytics/models/yolo/pose/predict.py#L34-L35
        # Adjust as needed.
        SelectBestBoundingBoxesByNMS(iou_threshold=0.7, score_threshold=0.25, has_mask_data=True),
        # Scale boxes and key point coords back to original image. Mask data has 17 key points per box.
        (ScaleNMSBoundingBoxesAndKeyPoints(num_key_points=17, layout='CHW'),
         [
             # A default connection from SelectBestBoundingBoxesByNMS for input 0
             # A connection from original image to input 1
             # A connection from the resized image to input 2
             # A connection from the LetterBoxed image to input 3
             # We use the three images to calculate the scale factor and offset.
             # With scale and offset, we can scale the bounding box and key points back to the original image.
             utils.IoMapEntry("RGBImageCHW", producer_idx=0, consumer_idx=1),
             utils.IoMapEntry("Resize", producer_idx=0, consumer_idx=2),
             utils.IoMapEntry("LetterBox", producer_idx=0, consumer_idx=3),
         ]),
    ]

    pipeline.add_post_processing(post_processing_steps)

    _update_model(model, output_model_path, pipeline)
def _add_pre_post_processing_to_image_input(input_model_path: Path,
                                            output_model_path: Path,
                                            output_image_format: Optional[str]):
    """
    Add pre and post processing with model input of jpg or png image bytes.

    Pre-processing will convert the input to the correct height, width and data type for the model.
    Post-processing will select the best bounding boxes using NonMaxSuppression, and scale the selected bounding
    boxes and key-points to the original image size.

    The post-processing can alternatively return the original image with the bounding boxes drawn on it
    instead of the scaled bounding box and key point data.

    @param input_model_path: Path to ONNX model.
    @param output_model_path: Path to write updated model to.
    @param output_image_format: Optional. Specify 'jpg' or 'png' for the post-processing to return image bytes in that
                                format with the bounding boxes drawn on it.
                                Otherwise the model will return the scaled bounding boxes and key points.
    """
    model, w_in, h_in = _get_model_and_info(input_model_path)

    onnx_opset = 18
    inputs = [create_named_value("image_bytes", onnx.TensorProto.UINT8, ["num_bytes"])]
    pipeline = PrePostProcessor(inputs, onnx_opset)

    pre_processing_steps = [
        ConvertImageToBGR(name="BGRImageHWC"),  # jpg/png image to BGR in HWC layout
        ChannelsLastToChannelsFirst(name="BGRImageCHW"),  # HWC to CHW
        # Resize to match model input. Uses not_larger as we use LetterBox to pad as needed.
        Resize((h_in, w_in), policy='not_larger', layout='CHW'),
        LetterBox(target_shape=(h_in, w_in), layout='CHW'),  # padding or cropping the image to (h_in, w_in)
        ImageBytesToFloat(),  # Convert to float in range 0..1
        Unsqueeze([0]),  # add batch, CHW --> 1CHW
    ]

    pipeline.add_pre_processing(pre_processing_steps)

    # NonMaxSuppression and drawing boxes
    post_processing_steps = [
        Squeeze([0]),  # Squeeze to remove batch dimension from [batch, 56, 8400] output
        Transpose([1, 0]),  # reverse so result info is inner dim: [8400, 56]
        # split the 56 elements into 4 for the box, 1 for the score of the single class,
        # and 51 for the key point data (17 key points x 3 values)
        Split(num_outputs=3, axis=1, splits=[4, 1, 51]),
        # Apply NMS to select best boxes. The iou and score thresholds match the ultralytics defaults in
        # https://github.com/ultralytics/ultralytics/blob/e7bd159a44cf7426c0f33ed9b413ef4439505a03/ultralytics/models/yolo/pose/predict.py#L34-L35
        # Adjust as needed.
        SelectBestBoundingBoxesByNMS(iou_threshold=0.7, score_threshold=0.25, has_mask_data=True),
        # Scale boxes and key point coords back to original image. Mask data has 17 key points per box.
        (ScaleNMSBoundingBoxesAndKeyPoints(num_key_points=17, layout='CHW'),
         [
             # A default connection from SelectBestBoundingBoxesByNMS for input 0
             # A connection from original image to input 1
             # A connection from the resized image to input 2
             # A connection from the LetterBoxed image to input 3
             # We use the three images to calculate the scale factor and offset.
             # With scale and offset, we can scale the bounding box and key points back to the original image.
             utils.IoMapEntry("BGRImageCHW", producer_idx=0, consumer_idx=1),
             utils.IoMapEntry("Resize", producer_idx=0, consumer_idx=2),
             utils.IoMapEntry("LetterBox", producer_idx=0, consumer_idx=3),
         ]),
    ]

    if output_image_format:
        post_processing_steps += [
            # split each scaled result into the first 6 values (box, score, class) and the 51 key point values
            # so the existing DrawBoundingBoxes step/custom op can draw the bounding boxes.
            Split(num_outputs=2, axis=-1, splits=[6, 51], name="SplitScaledBoxesAndKeypoints"),
            (DrawBoundingBoxes(mode='CENTER_XYWH', num_classes=1, colour_by_classes=True),
             [
                 utils.IoMapEntry("BGRImageHWC", producer_idx=0, consumer_idx=0),
                 utils.IoMapEntry("SplitScaledBoxesAndKeypoints", producer_idx=0, consumer_idx=1),
             ]),
            # Encode to jpg/png
            ConvertBGRToImage(image_format=output_image_format),
        ]

    pipeline.add_post_processing(post_processing_steps)

    print("Updating model ...")

    _update_model(model, output_model_path, pipeline)
def _run_inference(onnx_model_path: Path, model_input: str, model_outputs_image: bool, test_image: Path,
                   rgb_layout: Optional[str]):
    """
    Run the updated model with onnxruntime on a test image and display the result.

    @param onnx_model_path: Path to the updated ONNX model.
    @param model_input: 'image' if the model takes jpg/png bytes. Any other value means the model takes RGB data.
    @param model_outputs_image: True if the model outputs encoded image bytes (output name 'image').
                                Otherwise the 'nms_output_with_scaled_boxes_and_keypoints' output is read and
                                the boxes and skeletons are drawn on the test image here.
    @param test_image: jpg or png image to run the model with.
    @param rgb_layout: 'CHW' to transpose the RGB data to channels first. Only used when model_input is not 'image'.
    """
    import onnxruntime as ort
    import numpy as np
    from io import BytesIO

    print(f"Running the model to validate output using {test_image}.")

    providers = ['CPUExecutionProvider']
    session_options = ort.SessionOptions()
    session_options.register_custom_ops_library(onnxruntime_extensions.get_library_path())
    session = ort.InferenceSession(str(onnx_model_path), providers=providers, sess_options=session_options)

    input_name = session.get_inputs()[0].name

    if model_input == "image":
        # read_bytes closes the file, unlike a bare open(...).read()
        image_bytes = np.frombuffer(Path(test_image).read_bytes(), dtype=np.uint8)
        feeds = {input_name: image_bytes}
    else:
        with Image.open(test_image) as src:
            rgb_image = np.array(src.convert('RGB'))
        if rgb_layout == "CHW":
            rgb_image = rgb_image.transpose((2, 0, 1))  # Channels first
        feeds = {input_name: rgb_image}

    output_names = ['image'] if model_outputs_image else ['nms_output_with_scaled_boxes_and_keypoints']
    outputs = session.run(output_names, feeds)

    if model_outputs_image:
        # jpg or png with bounding boxes drawn
        Image.open(BytesIO(outputs[0])).show()
        return

    # manually draw the bounding boxes and skeleton just to prove it works.
    # each pair holds 1-based indices of the key points to connect.
    skeleton = [[16, 14], [14, 12], [17, 15], [15, 13], [12, 13], [6, 12], [7, 13], [6, 7], [6, 8], [7, 9],
                [8, 10], [9, 11], [2, 3], [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7]]

    # open original image so we can draw on it
    with Image.open(test_image) as src:
        input_image = src.convert('RGB')
    input_image_draw = ImageDraw.Draw(input_image)

    def coord_valid(coord):
        x, y = coord
        return 0 <= x < input_image.width and 0 <= y < input_image.height

    for result in outputs[0]:
        # split the 4 box coords, 1 score, 1 class (ignored), keypoints
        box, _score, _class, keypoints = np.split(result, (4, 5, 6))
        keypoints = keypoints.reshape((17, 3))

        # convert box from centered XYWH to corner co-ords and draw rectangle
        # NOTE: The pytorch model seems to output XYXY co-ords. Not sure why that's different.
        center_x, center_y, width, height = box
        x0 = center_x - width / 2
        y0 = center_y - height / 2
        x1 = center_x + width / 2
        y1 = center_y + height / 2
        input_image_draw.rectangle(((x0, y0), (x1, y1)), outline='red', width=4)

        # draw skeleton
        # See https://github.com/ultralytics/ultralytics/blob/e7bd159a44cf7426c0f33ed9b413ef4439505a03/ultralytics/utils/plotting.py#L171
        for idx1, idx2 in skeleton:
            # convert 1-based skeleton indices to 0-based rows of (x, y, confidence)
            keypoint1 = keypoints[idx1 - 1]
            keypoint2 = keypoints[idx2 - 1]
            if keypoint1[2] < 0.5 or keypoint2[2] < 0.5:
                continue

            pos1 = (int(keypoint1[0]), int(keypoint1[1]))
            pos2 = (int(keypoint2[0]), int(keypoint2[1]))
            if coord_valid(pos1) and coord_valid(pos2):
                input_image_draw.line((pos1, pos2), fill='yellow', width=2)

    print("Displaying original image with bounding boxes and skeletons.")
    input_image.show()
if __name__ == '__main__':
    # NOTE: the text must be passed as `description`. The first positional argument of ArgumentParser is `prog`,
    # which would put the whole paragraph into the usage line.
    parser = argparse.ArgumentParser(
        description="""Add pre and post processing to the YOLOv8 POSE model. The model can be updated to take either
        jpg/png bytes as input (--input image), or RGB data (--input rgb).
        By default the post processing will scale the bounding boxes and key points to the original image.
        """)
    parser.add_argument("--onnx_model_path", type=Path, default="yolov8n-pose.onnx",
                        help="The ONNX YOLOv8 POSE model.")
    parser.add_argument("--updated_onnx_model_path", type=Path, required=False,
                        help="Filename to save the updated ONNX model to. If not provided default to the filename "
                             "from --onnx_model_path with '.with_pre_post_processing' before the '.onnx' "
                             "e.g. yolov8n-pose.onnx -> yolov8n-pose.with_pre_post_processing.onnx")
    parser.add_argument("--input", choices=("image", "rgb"), default="image",
                        help="Desired model input format. Image bytes from jpg/png or RGB data.")
    parser.add_argument("--input_shape",
                        # numeric dims become ints; anything else is kept as a symbolic dim name
                        type=lambda x: [int(dim.strip()) if dim.strip().isnumeric() else dim.strip()
                                        for dim in x.split(",")],
                        required=False,
                        help="Shape of RGB input if input is 'rgb'. Provide a comma separated list of 3 dimensions. "
                             "Symbolic dimensions are allowed. Either the first or last dimension must be 3 to infer "
                             "if layout is HWC or CHW. "
                             "examples: channels first with symbolic dims for height and width: --input_shape 3,H,W "
                             "or channels last with fixed input shape: --input_shape 384,512,3")
    parser.add_argument("--output_image", choices=("jpg", "png"), required=False,
                        help="OPTIONAL. If the input is an image, instead of outputting the scaled bounding boxes and "
                             "key points the model will draw the bounding boxes on the original image, convert to the "
                             "specified format, and output the updated image bytes. The scaled key points for each "
                             "selected bounding box will also be a model output. "
                             "NOTE: it will NOT draw the key points as there's no custom operator to handle that.")
    parser.add_argument("--run_model", action='store_true',
                        help="Run inference on the model to validate output.")
    parser.add_argument("--test_image", type=Path, default="data/stormtroopers.jpg",
                        help="JPG or PNG image to run model with.")

    args = parser.parse_args()

    # parser.error prints usage plus the message and exits with status 2. argparse.ArgumentError expects an
    # Action as its first argument and fails with AttributeError when given a parsed value.
    if args.output_image and args.input == "rgb":
        parser.error("output_image argument can only be used if input is 'image'")

    if args.input == "rgb" and args.input_shape is None:
        parser.error("input_shape argument is required if input is 'rgb'")

    if args.input_shape and len(args.input_shape) != 3:
        parser.error("Shape of RGB input must have 3 dimensions.")

    updated_model_path = (args.updated_onnx_model_path
                          if args.updated_onnx_model_path
                          else args.onnx_model_path.with_suffix(suffix=".with_pre_post_processing.onnx"))

    # default output is the scaled non-max suppression data which matches the original model.
    # each result has bounding box (4), score (1), class (1), key points (17 x 3) = 57 elements
    # bounding box is centered XYWH format.
    # alternative is to output the original image with the bounding boxes but no key points drawn.
    if args.input == "rgb":
        print("Updating model with RGB data as input.")
        _add_pre_post_processing_to_rgb_input(args.onnx_model_path, updated_model_path, args.input_shape)
        rgb_layout = "CHW" if args.input_shape[0] == 3 else "HWC"
    else:
        assert args.input == "image"  # guaranteed by argparse `choices`
        print("Updating model with jpg/png image bytes as input.")
        _add_pre_post_processing_to_image_input(args.onnx_model_path, updated_model_path, args.output_image)
        rgb_layout = None

    if args.run_model:
        _run_inference(updated_model_path, args.input, args.output_image is not None, args.test_image, rgb_layout)