- Notifications
You must be signed in to change notification settings - Fork 1.6k
/
Copy pathclient.py
executable file
·119 lines (91 loc) · 4.19 KB
/
client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#!/usr/bin/env python3
"""Client for TensorFlow Serving.

Reads titles from STDIN, and writes comment samples to STDOUT.
"""
# adapted from https://github.com/OpenNMT/OpenNMT-tf/blob/master/examples/serving/ende_client.py
import argparse
import subprocess
import sys

import grpc
import subword_nmt.apply_bpe as apply_bpe
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2, prediction_service_pb2_grpc
class Generator:
    """Samples comments for a title from a TensorFlow Serving seq2seq model.

    Pipeline per title: optional external tokenizer (``preprocessor``
    script run as a subprocess), BPE segmentation, a gRPC ``Predict``
    call, then optional external detokenizer (``postprocessor`` script).
    """

    def __init__(self,
                 host,
                 port,
                 model_name,
                 preprocessor,
                 postprocessor,
                 bpe_codes):
        """Connect to the model server and load the BPE codes.

        Args:
            host: Model-server hostname.
            port: Model-server port (int).
            model_name: Name of the served model to query.
            preprocessor: Path to a tokenization script, or None/'' to skip.
            postprocessor: Path to a detokenization script, or None/'' to skip.
            bpe_codes: Path to a subword-nmt BPE codes file.
        """
        channel = grpc.insecure_channel("%s:%d" % (host, port))
        self.stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
        self.model_name = model_name
        self.preprocessor = preprocessor
        self.postprocessor = postprocessor
        with open(bpe_codes) as f:
            self.bpe = apply_bpe.BPE(f)

    def __call__(self, title, n=8, timeout=50.0):
        """Generate ``n`` comment hypotheses for ``title``.

        Args:
            title: Raw input title (one line of text).
            n: Number of samples to request from the model.
            timeout: gRPC deadline in seconds for the Predict call.

        Returns:
            List of ``(comment, log_prob)`` tuples, one per sample.
        """
        if self.preprocessor:
            # FIXME: Tried to reuse the process, but something seems to be buffering
            preprocessor = subprocess.Popen([self.preprocessor],
                                            stdout=subprocess.PIPE,
                                            stdin=subprocess.PIPE,
                                            stderr=subprocess.DEVNULL)
            title_pp = preprocessor.communicate((title.strip() + '\n').encode())[0].decode('utf-8')
        else:
            title_pp = title
        # Lowercase, whitespace-tokenize, then split into BPE subword units.
        title_bpe = self.bpe.segment_tokens(title_pp.strip().lower().split(' '))
        #print(title_bpe)
        request = predict_pb2.PredictRequest()
        request.model_spec.name = self.model_name
        # Batch of n identical inputs so the server samples n hypotheses.
        request.inputs['tokens'].CopyFrom(
            tf.make_tensor_proto([title_bpe] * n, shape=(n, len(title_bpe))))
        request.inputs['length'].CopyFrom(
            tf.make_tensor_proto([len(title_bpe)] * n, shape=(n,)))
        future = self.stub.Predict.future(request, timeout)
        result = future.result()
        batch_predictions = tf.make_ndarray(result.outputs["tokens"])
        batch_lengths = tf.make_ndarray(result.outputs["length"])
        batch_scores = tf.make_ndarray(result.outputs["log_probs"])
        hyps = []
        for (predictions, lengths, scores) in zip(batch_predictions, batch_lengths, batch_scores):
            # ignore </s>
            prediction = predictions[0][:lengths[0] - 1]
            comment = ' '.join([token.decode('utf-8') for token in prediction])
            # Undo BPE: '@@ ' marks a subword continuation.
            comment = comment.replace('@@ ', '')
            #comment = comment.replace('<NL>', '\n')
            # FIXME: Tried to reuse the process, but something seems to be buffering
            if self.postprocessor:
                postprocessor = subprocess.Popen([self.postprocessor],
                                                 stdout=subprocess.PIPE,
                                                 stdin=subprocess.PIPE,
                                                 stderr=subprocess.DEVNULL)
                prediction_ready = postprocessor.communicate((comment + '\n').encode())[0].decode('utf-8')
            else:
                prediction_ready = comment
            hyps.append((prediction_ready.strip(), float(scores[0])))
        #hyps.sort(key=lambda hyp: hyp[1] / len(hyp), reverse=True)
        return hyps
def main():
    """Parse CLI flags, then stream titles from STDIN to the model server.

    For each input title, writes one tab-separated ``title\\tcomment`` line
    per sampled hypothesis to STDOUT.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--host', default='localhost', help='model server host')
    parser.add_argument('--port', type=int, default=9000, help='model server port')
    parser.add_argument('--model_name', required=True, help='model name')
    parser.add_argument('--preprocessor', help='tokenization script')
    parser.add_argument('--postprocessor', help='postprocessing script')
    parser.add_argument('--bpe_codes', required=True, help='BPE codes')
    parser.add_argument('-n', type=int, default=5, help='Number of comments to sample per title')
    args = parser.parse_args()
    generator = Generator(host=args.host,
                          port=args.port,
                          model_name=args.model_name,
                          preprocessor=args.preprocessor,
                          postprocessor=args.postprocessor,
                          bpe_codes=args.bpe_codes)
    for title in sys.stdin:
        hyps = generator(title, args.n)
        for prediction, score in hyps:
            # NOTE(review): score is unpacked but intentionally not printed;
            # output format is "<title>\t<comment>".
            sys.stdout.write('{}\t{}\n'.format(title.strip(), prediction))


if __name__ == "__main__":
    main()