- Notifications
You must be signed in to change notification settings - Fork 362
/
Copy pathresnet50_data_reader.py
63 lines (54 loc) · 2.41 KB
/
resnet50_data_reader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
importnumpy
importonnxruntime
importos
fromonnxruntime.quantizationimportCalibrationDataReader
fromPILimportImage
def _preprocess_images(images_folder: str, height: int, width: int, size_limit=0):
    """
    Loads a batch of images and preprocesses them.

    Each image is resized to (width, height), pasted onto an RGB canvas,
    converted to float32, shifted by the per-channel offsets
    [123.68, 116.78, 103.94] and transposed from (1, height, width, 3) to
    (1, 3, height, width).

    parameter images_folder: path to folder storing images
    parameter height: image height in pixels
    parameter width: image width in pixels
    parameter size_limit: number of images to load. Default is 0 which means all images are picked.
    return: array holding one (1, 3, height, width) float32 entry per image,
            i.e. shape (num_images, 1, 3, height, width); an empty array of
            shape (0,) when the folder has no entries
    """
    # os.listdir returns entries in arbitrary order; sorting makes both the
    # subset chosen by size_limit and the order of the batch reproducible.
    image_names = sorted(os.listdir(images_folder))
    if size_limit > 0:
        # Slicing yields every name when fewer than size_limit are available.
        image_names = image_names[:size_limit]

    # Offsets subtracted from the three colour channels (presumably the
    # ImageNet RGB means -- confirm against the model's training recipe).
    channel_offsets = numpy.array([123.68, 116.78, 103.94], dtype=numpy.float32)

    per_image_data = []
    for image_name in image_names:
        image_filepath = os.path.join(images_folder, image_name)
        pillow_img = Image.new("RGB", (width, height))
        # The context manager releases the underlying file handle right away
        # instead of leaving it to garbage collection.
        with Image.open(image_filepath) as source_img:
            pillow_img.paste(source_img.resize((width, height)))
        input_data = numpy.float32(pillow_img) - channel_offsets
        nhwc_data = numpy.expand_dims(input_data, axis=0)
        nchw_data = nhwc_data.transpose(0, 3, 1, 2)  # ONNX Runtime standard
        per_image_data.append(nchw_data)

    # Stacks the per-image arrays along a new leading axis; an empty list
    # becomes an empty array of shape (0,).
    return numpy.asarray(per_image_data)
class ResNet50DataReader(CalibrationDataReader):
    """
    Calibration data reader that hands out preprocessed images one by one.

    Every image in the calibration folder is preprocessed in the constructor.
    get_next then serves them as feed dictionaries keyed by the name of the
    model's first input.
    """

    def __init__(self, calibration_image_folder: str, model_path: str):
        """
        parameter calibration_image_folder: path to folder storing calibration images
        parameter model_path: path of the model whose first input defines the image size
        """
        # Iterator over feed dictionaries; created on the first get_next call.
        self.enum_data = None

        # The inference session is only used to query the first model input.
        session = onnxruntime.InferenceSession(model_path, None)
        first_input = session.get_inputs()[0]
        # The shape is unpacked as four dimensions; the last two are taken as
        # the image height and width.
        _batch, _channels, height, width = first_input.shape

        # Despite the attribute name, the layout of these arrays is whatever
        # _preprocess_images returns.
        self.nhwc_data_list = _preprocess_images(
            calibration_image_folder, height, width, size_limit=0
        )
        self.input_name = first_input.name
        self.datasize = len(self.nhwc_data_list)

    def get_next(self):
        """
        return: the next {input_name: image_data} dictionary, or None once
                every image has been served
        """
        if self.enum_data is None:
            feeds = []
            for image_data in self.nhwc_data_list:
                feeds.append({self.input_name: image_data})
            self.enum_data = iter(feeds)
        return next(self.enum_data, None)

    def rewind(self):
        """Drops the current iterator so the next get_next starts over from the first image."""
        self.enum_data = None