背着锄头的互联网农民

0%

使用Grpc调用TensorFlow-Serving服务

上一篇文章中,我们学习了在Mac机器上搭建TensorFlow-Serving的示例,并提供了使用Http接口访问的方法。本文中,我们将尝试使用Grpc接口访问TensorFlow-Serving服务。

启动服务
1
nohup sudo docker run -p 8502:8500 -p 8501:8501 --name tfserving_testnet  --mount type=bind,source=/home/tensorflow/xception,target=/models/xception  -e MODEL_NAME=xception -t tensorflow/serving &

其中,本机的8502端口对应Docker的8500端口(GRPC端口),本机8501端口对应Docker的8501端口(HTTP端口)。

模型信息

在写Gprc服务之前,需要明确模型的名字、输入、输出等。我们使用curl http://localhost:8501/v1/models/xception/metadata可以看到Docker中模型的基本信息。 其中,方框内的内容要在下面Client的代码中用到。

Python Grpc Client

下面代码是python的客户端代码,输入一张图片,输出模型对这张图片的打分。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#encoding=utf8
import requests
import numpy as np
import tensorflow.compat.v1 as tf
import time
tf.disable_v2_behavior()
np.set_printoptions(threshold=np.inf)
np.set_printoptions(precision=3)

from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
from keras.applications import xception
from tensorflow.python.platform import gfile

from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc
import grpc

tf.app.flags.DEFINE_string('server', '127.0.0.1:8502', 'PredictionService host:port')
FLAGS = tf.app.flags.FLAGS

def prediction():
images = image.load_img("test.jpg", target_size=(480, 480))
x = image.img_to_array(images)
x = np.expand_dims(x, axis=0)
image_np = xception.preprocess_input(x)

options = [('grpc.max_send_message_length', 1000 * 1024 * 1024), ('grpc.max_receive_message_length', 1000 * 1024 * 1024)]
channel = grpc.insecure_channel(FLAGS.server, options = options)
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
request = predict_pb2.PredictRequest()
request.model_spec.name = 'xception' #对应上图第一个方框
request.model_spec.signature_name = 'serving_default' #对应上图第二个方框
request.inputs['in'].CopyFrom(tf.make_tensor_proto(image_np)) #in对应上图第三个方框,为模型的输入Name

result_future = stub.Predict.future(request, 10.0) # 10 secs timeout
result = result_future.result()
print result

if __name__ == "__main__":
prediction()