# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools to convert a quantized deeplab model to tflite."""
from absl import app
from absl import flags
import numpy as np
from PIL import Image
import tensorflow as tf

flags.DEFINE_string('quantized_graph_def_path', None,
                    'Path to quantized graphdef.')
flags.DEFINE_string('output_tflite_path', None, 'Output TFlite model path.')
flags.DEFINE_string(
    'input_tensor_name', None,
    'Input tensor to TFlite model. This usually should be the input tensor to '
    'model backbone.'
)
flags.DEFINE_string(
    'output_tensor_name', 'ArgMax:0',
    'Output tensor name of TFlite model. By default we output the raw semantic '
    'label predictions.'
)
flags.DEFINE_string(
    'test_image_path', None,
    'Path to an image to test the consistency between input graphdef / '
    'converted tflite model.'
)

FLAGS = flags.FLAGS


def convert_to_tflite(quantized_graphdef,
                      backbone_input_tensor,
                      output_tensor):
  """Helper method to convert quantized deeplab model to TFlite."""
  with tf.Graph().as_default() as graph:
    tf.graph_util.import_graph_def(quantized_graphdef, name='')
    sess = tf.compat.v1.Session()

    tflite_input = graph.get_tensor_by_name(backbone_input_tensor)
    tflite_output = graph.get_tensor_by_name(output_tensor)
    converter = tf.compat.v1.lite.TFLiteConverter.from_session(
        sess, [tflite_input], [tflite_output])
    converter.inference_type = tf.compat.v1.lite.constants.QUANTIZED_UINT8
    input_arrays = converter.get_input_arrays()
    # Input stats are (mean, std_dev): real_value = (quantized - mean) / std_dev,
    # so (127.5, 127.5) maps uint8 inputs in [0, 255] to roughly [-1, 1].
    converter.quantized_input_stats = {input_arrays[0]: (127.5, 127.5)}
    return converter.convert()


def check_tflite_consistency(graph_def, tflite_model, image_path):
  """Runs tflite and frozen graph on same input, check their outputs match."""
  # Load tflite model and check input size.
  interpreter = tf.lite.Interpreter(model_content=tflite_model)
  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()
  output_details = interpreter.get_output_details()
  height, width = input_details[0]['shape'][1:3]

  # Prepare input image data.
  with tf.io.gfile.GFile(image_path, 'rb') as f:
    image = Image.open(f)
    image = np.asarray(image.convert('RGB').resize((width, height)))
    image = np.expand_dims(image, 0)

  # Output from tflite model.
  interpreter.set_tensor(input_details[0]['index'], image)
  interpreter.invoke()
  output_tflite = interpreter.get_tensor(output_details[0]['index'])

  with tf.Graph().as_default():
    tf.graph_util.import_graph_def(graph_def, name='')
    with tf.compat.v1.Session() as sess:
      # Note here the graph will include preprocessing part of the graph
      # (e.g. resize, pad, normalize). Given the input image size is at the
      # crop size (backbone input size), resize / pad should be an identity op.
      output_graph = sess.run(
          FLAGS.output_tensor_name, feed_dict={'ImageTensor:0': image})

  print('%.2f%% pixels have matched semantic labels.' % (
      100 * np.mean(output_graph == output_tflite)))


def main(unused_argv):
  with tf.io.gfile.GFile(FLAGS.quantized_graph_def_path, 'rb') as f:
    graph_def = tf.compat.v1.GraphDef.FromString(f.read())
  tflite_model = convert_to_tflite(
      graph_def, FLAGS.input_tensor_name, FLAGS.output_tensor_name)

  if FLAGS.output_tflite_path:
    with tf.io.gfile.GFile(FLAGS.output_tflite_path, 'wb') as f:
      f.write(tflite_model)

  if FLAGS.test_image_path:
    check_tflite_consistency(graph_def, tflite_model, FLAGS.test_image_path)


if __name__ == '__main__':
  app.run(main)