- Notifications
You must be signed in to change notification settings - Fork 438
/
Copy pathsave_pretrained_model.py
107 lines (93 loc) · 3.96 KB
/
save_pretrained_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# SPDX-License-Identifier: Apache-2.0
"""
Save pre-trained model.
"""
importtensorflowastf
importnumpyasnp
# pylint: disable=redefined-outer-name,reimported,import-outside-toplevel
def _save_environment(envars_path, pip_requirement_path):
    """Record the TF/Python versions and the installed pip packages.

    Writes ``tf.__version__`` and ``sys.version`` to ``envars_path`` and the
    output of ``pip freeze --all`` (for the running interpreter) to
    ``pip_requirement_path``.
    """
    import subprocess
    import sys
    import tensorflow as tf
    pip_packages = subprocess.check_output(
        [sys.executable, "-m", "pip", "freeze", "--all"]).decode("UTF-8")
    # Text mode already translates "\n" to the platform line ending; writing
    # os.linesep here would produce "\r\r\n" on Windows.
    with open(envars_path, "w", encoding="utf-8") as fp:
        fp.write(tf.__version__ + "\n")
        fp.write(sys.version)
    with open(pip_requirement_path, "w", encoding="utf-8") as fp:
        fp.write(pip_packages)


def _build_pretrained_yaml(model_name, input_names, output_names):
    """Return the pretrained-model YAML entry for ``model_name`` as text.

    The entry points at ``./saved_model`` and ``./inputs.npy`` (relative paths
    matching the files written next to the YAML). Each input is reloaded from
    ``inputs.npy`` by its tensor name; outputs are listed by tensor name.
    """
    lines = [
        "",
        "{}:".format(model_name),
        "  model: ./saved_model",
        "  model_type: saved_model",
        "  input_get: get_ramp",
        "  inputs:",
    ]
    for name in input_names:
        # inputs.npy stores a pickled dict in a 0-d object array, so np.load
        # needs allow_pickle=True (numpy >= 1.16.3 defaults it to False).
        lines.append(
            "    \"{input}\": np.array(np.load(\"./inputs.npy\", allow_pickle=True)[()][\"{input}\"])"
            .format(input=name))
    lines.append("  outputs:")
    lines.extend("    - {}".format(name) for name in output_names)
    return "\n".join(lines) + "\n"


def save_pretrained_model(sess, outputs, feeds, out_dir, model_name="pretrained"):
    """Save a TF1 session's model, its inputs and a tf2onnx test config.

    Everything is written under ``<out_dir>/to_onnx``:

    * ``environment.txt``  -- TensorFlow version and Python version
    * ``requirements.txt`` -- output of ``pip freeze --all``
    * ``inputs.npy``       -- dict mapping input tensor name to its fed value
    * ``saved_model/``     -- SavedModel written by ``simple_save``
    * ``pretrained.yaml``  -- config entry pointing at the files above

    If ``saved_model`` already exists nothing else is written, so call this
    only after running the model you want to keep. Any exception is printed
    (with its traceback) instead of raised, so a failure here never breaks the
    caller's script.

    Args:
        sess: TF1 session whose graph and variables are saved.
        outputs: iterable of output tensors; each tensor's ``.name`` is used
            as its signature key and listed in the YAML config.
        feeds: feed dict as passed to ``sess.run``; keys are tensors or
            tensor-name strings such as ``"x:0"``.
        out_dir: directory under which ``to_onnx`` is created.
        model_name: top-level key of the generated YAML entry.
    """
    try:
        # Imports are local so this function can be pasted into a user script.
        import os
        from tensorflow.saved_model import simple_save

        to_onnx_path = os.path.join(out_dir, "to_onnx")
        os.makedirs(to_onnx_path, exist_ok=True)
        saved_model = os.path.join(to_onnx_path, "saved_model")
        inputs_path = os.path.join(to_onnx_path, "inputs.npy")
        pretrained_model_yaml_path = os.path.join(to_onnx_path, "pretrained.yaml")

        print("===============Save Saved Model========================")
        if os.path.exists(saved_model):
            # Keep the first export; later calls are no-ops.
            print("{} already exists, SKIP".format(saved_model))
            return

        print("Save tf version, python version and installed packages")
        _save_environment(os.path.join(to_onnx_path, "environment.txt"),
                          os.path.join(to_onnx_path, "requirements.txt"))

        print("Save model for tf2onnx: {}".format(to_onnx_path))
        # Key both the saved values and the signature by tensor name, resolving
        # string feed keys to tensors so simple_save always receives tensors.
        inputs = {}
        input_tensors = {}
        for key, value in feeds.items():
            tensor = sess.graph.get_tensor_by_name(key) if isinstance(key, str) else key
            inputs[tensor.name] = value
            input_tensors[tensor.name] = tensor
        np.save(inputs_path, inputs)
        print("Saved inputs to {}".format(inputs_path))

        output_tensors = {out.name: out for out in outputs}
        simple_save(sess, saved_model, input_tensors, output_tensors)
        print("Saved model to {}".format(saved_model))

        with open(pretrained_model_yaml_path, "w", encoding="utf-8") as fp:
            fp.write(_build_pretrained_yaml(model_name, inputs, output_tensors))
        print("Saved pretrained model yaml to {}".format(pretrained_model_yaml_path))
        print("=========================================================")
    except Exception as ex:  # pylint: disable=broad-except
        # Best effort: report the failure (with traceback) but never raise.
        import traceback
        print("Error: {}".format(ex))
        traceback.print_exc()
def test():
    """Export a toy graph computing ``x @ y + weight`` as a usage example.

    Builds the graph, initialises its variable, runs it once with a feed dict
    and then passes the same fetches and feeds to ``save_pretrained_model``,
    which writes the export under ``./tests``.
    """
    # Input data is drawn in the same order as the placeholders below.
    lhs_value = np.random.rand(5, 20).astype(np.float32)
    rhs_value = np.random.rand(20, 10).astype(np.float32)

    lhs = tf.placeholder(tf.float32, lhs_value.shape, name="x")
    rhs = tf.placeholder(tf.float32, rhs_value.shape, name="y")
    product = tf.matmul(lhs, rhs)
    bias = tf.get_variable("weight", [5, 10], dtype=tf.float32)
    init_op = tf.global_variables_initializer()
    fetches = [product + bias]
    feed_dict = {lhs: lhs_value, rhs: rhs_value}

    with tf.Session() as session:
        session.run(init_op)
        session.run(fetches, feed_dict)
        # NOTE: an existing export is never overwritten. If you evaluate
        # several times, call this only after evaluating the model to keep.
        save_pretrained_model(session, fetches, feed_dict, "./tests", model_name="test")
# Run the example export only when this file is executed as a script.
if __name__ == "__main__":
    test()