- Notifications
You must be signed in to change notification settings - Fork 99
/
Copy pathsuperresolution_e2e.py
161 lines (124 loc) · 7.03 KB
/
superresolution_e2e.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
importio
importnumpyasnp
importonnxruntimeasort
importos
frompathlibimportPath
fromPILimportImage
# Absolute directory containing this script; used below to locate the test image under 'data'.
_this_dirpath=Path(os.path.dirname(os.path.abspath(__file__)))
# File name the exported super resolution model is written to (relative path).
ONNX_MODEL='pytorch_superresolution.onnx'
# File name of the model produced by adding pre/post processing to ONNX_MODEL (relative path).
ONNX_MODEL_WITH_PRE_POST_PROCESSING='pytorch_superresolution.with_pre_post_processing.onnx'
# Export the PyTorch super resolution model to ONNX, following
# https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
def convert_pytorch_superresolution_to_onnx():
    """Build the super resolution network, load pretrained weights and export it.

    The exported model is written to the path named by ``ONNX_MODEL`` with a
    single input called 'input' and a single output called 'output'.
    The pretrained weights are downloaded from ``model_url`` below.
    """
    import torch.utils.model_zoo as model_zoo
    import torch.onnx

    # Super Resolution model definition in PyTorch
    import torch.nn as nn
    import torch.nn.init as init

    class SuperResolutionNet(nn.Module):
        """Four convolutions followed by a pixel shuffle that upscales by `upscale_factor`."""

        def __init__(self, upscale_factor, inplace=False):
            super().__init__()

            # NOTE: attribute names and order are kept as-is so that the
            # pretrained state dict keys continue to match.
            self.relu = nn.ReLU(inplace=inplace)
            self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
            self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
            self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
            self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
            self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

            self._initialize_weights()

        def forward(self, x):
            features = self.relu(self.conv1(x))
            features = self.relu(self.conv2(features))
            features = self.relu(self.conv3(features))
            upscaled = self.pixel_shuffle(self.conv4(features))
            return upscaled

        def _initialize_weights(self):
            # The three ReLU-activated layers share the same gain; the final
            # layer uses the default gain of orthogonal_.
            relu_gain = init.calculate_gain('relu')
            for conv in (self.conv1, self.conv2, self.conv3):
                init.orthogonal_(conv.weight, relu_gain)
            init.orthogonal_(self.conv4.weight)

    # Instantiate the network defined above.
    torch_model = SuperResolutionNet(upscale_factor=3)

    # Location of the pretrained weights.
    model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
    batch_size = 1  # fix batch size to 1 for use in mobile scenarios

    def _identity_map_location(storage, loc):
        # Hand the deserialized storage back unchanged.
        return storage

    # Populate the network with the pretrained weights.
    pretrained_state = model_zoo.load_url(model_url, map_location=_identity_map_location)
    torch_model.load_state_dict(pretrained_state)

    # Switch to inference mode before exporting.
    torch_model.eval()

    # Random input used to run the model during export.
    dummy_input = torch.randn(batch_size, 1, 224, 224, requires_grad=True)

    # Export the model
    torch.onnx.export(
        torch_model,               # model being run
        dummy_input,               # model input (or a tuple for multiple inputs)
        ONNX_MODEL,                # where to save the model (can be a file or file-like object)
        export_params=True,        # store the trained parameter weights inside the model file
        opset_version=15,          # the ONNX version to export the model to
        do_constant_folding=True,  # whether to execute constant folding for optimization
        input_names=['input'],     # the model's input names
        output_names=['output'],   # the model's output names
    )
def add_pre_post_processing(output_format: str = "png"):
    """Add image pre/post processing to the exported super resolution model.

    Reads the model at ``ONNX_MODEL`` and writes the augmented model to
    ``ONNX_MODEL_WITH_PRE_POST_PROCESSING``.

    Args:
        output_format: Image format the updated model should output. Defaults
            to "png"; JPG is also valid.
    """
    # Use the pre-defined helper to add pre/post processing to a super resolution model based on YCbCr input.
    # Note: if you're creating a custom pre/post processing pipeline, use
    # `from onnxruntime_extensions.tools.pre_post_processing import *` to pull in the pre/post processing
    # infrastructure and Step definitions.
    from onnxruntime_extensions.tools import add_pre_post_processing_to_model as add_ppp
    from packaging import version

    # ORT 1.14 and later support ONNX opset 18, which added antialiasing to the Resize operator.
    # Results are much better when that can be used. Minimum opset is 16.
    installed_ort = version.parse(ort.__version__)
    onnx_opset = 18 if installed_ort >= version.parse("1.14.0") else 16

    source_model = Path(ONNX_MODEL)
    augmented_model = Path(ONNX_MODEL_WITH_PRE_POST_PROCESSING)
    add_ppp.superresolution(source_model, augmented_model, output_format, onnx_opset)
def _center_crop_to_square(img: "Image.Image") -> "Image.Image":
    """Return a centered square crop of `img`.

    The side of the square is the shorter of the image's width and height.
    When the margin on the longer axis is odd, the extra pixel is left on the
    far (right/bottom) side because the start offset is rounded down.

    Args:
        img: Image exposing `width`, `height` and `crop((left, upper, right, lower))`.

    Returns:
        `img` itself when it is already square, otherwise the result of
        `img.crop` for the centered square region.
    """
    if img.height == img.width:
        # Already square; nothing to crop.
        return img

    side = min(img.width, img.height)

    # Integer floor division gives the same offsets as floor((dim - side) / 2)
    # without a float round trip.
    left = (img.width - side) // 2
    upper = (img.height - side) // 2

    return img.crop((left, upper, left + side, upper + side))
def run_updated_onnx_model():
    """Run the model with pre/post processing on the bundled test image.

    Returns:
        A tuple `(original_cropped, updated)`: the input image center-cropped
        to a square, and the image decoded from the model's 'image_out' output.
    """
    from onnxruntime_extensions import get_library_path

    session_options = ort.SessionOptions()

    # register the custom operators for the image decode/encode pre/post processing provided by
    # onnxruntime-extensions with onnxruntime. if we do not do this we'll get an error on model load
    # about the operators not being found.
    session_options.register_custom_ops_library(get_library_path())

    session = ort.InferenceSession(ONNX_MODEL_WITH_PRE_POST_PROCESSING, session_options)

    # The model is fed the raw encoded bytes of the image file.
    input_path = _this_dirpath / 'data' / 'super_res_input.png'
    encoded_input = np.fromfile(input_path, dtype=np.uint8)

    encoded_output = session.run(['image_out'], {'image': encoded_input})[0]

    source_image = Image.open(io.BytesIO(encoded_input))
    upscaled_image = Image.open(io.BytesIO(encoded_output))

    # centered crop of original to match the area processed
    return _center_crop_to_square(source_image), upscaled_image
if __name__ == '__main__':
    # Verify the installed onnxruntime-extensions is recent enough before doing any work.
    import onnxruntime_extensions
    from packaging import version

    minimum_extensions_version = version.parse("0.6.0")
    if version.parse(onnxruntime_extensions.__version__) < minimum_extensions_version:
        # temporarily install using nightly until we have official release on pypi.
        raise ImportError(
            f"onnxruntime_extensions version 0.6 or later is required. {onnxruntime_extensions.__version__} is installed. "
            "Please install the latest version using "
            "`pip install --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ onnxruntime-extensions`")  # noqa

    # Export the model, add pre/post processing, then run it on the test image.
    convert_pytorch_superresolution_to_onnx()
    add_pre_post_processing('png')
    original_img, updated_img = run_updated_onnx_model()

    out_width, out_height = updated_img.size

    # Build a side-by-side comparison image.
    # The original gets a plain resize to the model input size followed by the model output size
    # so the two halves are easier to compare.
    baseline_img = original_img.resize((224, 224)).resize((out_width, out_height))

    side_by_side = Image.new('RGB', (out_width * 2, out_height))
    side_by_side.paste(baseline_img, (0, 0))
    side_by_side.paste(updated_img, (out_width, 0))

    # NOTE: The output is significantly better with ONNX opset 18 as Resize supports anti-aliasing.
    side_by_side.show('Original resized vs original vs Super Resolution resized')