Calling TensorFlow-Serving from Java

In TensorFlow-Serving入门 (Getting Started with TensorFlow-Serving), we learned how to set up TensorFlow-Serving as a Docker service and start the container. This article shows how to call that TensorFlow-Serving service from Java, in two ways: HTTP and gRPC.

Starting the TensorFlow-Serving Service
nohup sudo docker run -p 8502:8500 -p 8501:8501 --name tfserving_testnet  --mount type=bind,source=/home/tensorflow/xception,target=/models/xception  -e MODEL_NAME=xception -t tensorflow/serving &
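
Before wiring up a client, it can be worth confirming that the container has actually loaded the model. TensorFlow-Serving exposes a model-status endpoint on its REST port (8501 above); a minimal check from Java, using the same Spring RestTemplate that the HTTP client below relies on, could look like this:

import org.springframework.web.client.RestTemplate;

public class ServingStatusCheck {
    public static void main(String[] args) {
        RestTemplate restTemplate = new RestTemplate();
        // GET /v1/models/<model_name> returns the version states of the loaded model.
        String status = restTemplate.getForObject("http://localhost:8501/v1/models/xception", String.class);
        System.out.println(status); // expect "state": "AVAILABLE" once the model has been loaded
    }
}
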
Java HTTP Client

In Java调用TensorFlow库预测图片质量 (Predicting Image Quality with the TensorFlow Library from Java), we used TensorFlow to preprocess the image into a Tensor and then fed it to the model for prediction. Here we assume the image has already been preprocessed into a Tensor; the HTTP client code is as follows:

float result = -1f;
Tensor<Float> input = null;
try {
    input = getImageTensor1(imageData); // ~30 ms: preprocess the image into a Tensor
    if (input == null) {
        return result;
    }

    float[][][][] imgTensor = new float[1][IMAGE_HEIGTH][IMAGE_WIDTH][3];
    input.copyTo(imgTensor); // ~120 ms
    String request = "{\"instances\":" + Arrays.deepToString(imgTensor) + "}"; // ~120 ms

    // The HTTP call to TensorFlow-Serving takes ~240 ms
    return Metrics.timer("image_recognize_predict_latency2").record(() -> {
        HttpHeaders headers = new HttpHeaders();
        headers.setContentType(MediaType.valueOf("application/json;charset=UTF-8"));
        HttpEntity<String> strEntity = new HttpEntity<String>(request, headers);
        RestTemplate restTemplate = new RestTemplate();
        JSONObject result1 = restTemplate.postForObject("http://localhost:8501/v1/models/xception:predict", strEntity, JSONObject.class);
        float result2 = result1.getJSONArray("predictions").getJSONArray(0).getFloat(0);
        return result2;
    });
} catch (Exception e) {
    LOGGER.warn("image recognize failed.", e);
} finally {
    if (input != null) {
        input.close();
    }
}

return result;
The copyTo and deepToString calls are fairly expensive, and the HTTP request to TensorFlow-Serving itself also takes quite long, presumably because of serialization (performance improves noticeably with gRPC).
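
As an aside, getImageTensor1 comes from the earlier article and is not reproduced here. For a self-contained experiment, a rough stand-in that builds an equivalent Tensor with plain Java image APIs (instead of TensorFlow preprocessing ops) is sketched below; the 480x480 size matches the constants used in this post, while the [-1, 1] pixel scaling is only an assumption about how the xception model expects its input:

import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import javax.imageio.ImageIO;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;

// Hypothetical stand-in for getImageTensor1; adjust resizing and normalization
// to match how the xception model was actually exported.
private Tensor<Float> getImageTensorSketch(byte[] imageData) throws IOException {
    BufferedImage src = ImageIO.read(new ByteArrayInputStream(imageData));
    if (src == null) {
        return null;
    }
    // Resize to the model's expected input size.
    BufferedImage resized = new BufferedImage(IMAGE_WIDTH, IMAGE_HEIGTH, BufferedImage.TYPE_INT_RGB);
    Graphics2D g = resized.createGraphics();
    g.drawImage(src, 0, 0, IMAGE_WIDTH, IMAGE_HEIGTH, null);
    g.dispose();

    // Copy pixels into a [1, height, width, 3] float array, scaled to [-1, 1].
    float[][][][] data = new float[1][IMAGE_HEIGTH][IMAGE_WIDTH][3];
    for (int y = 0; y < IMAGE_HEIGTH; y++) {
        for (int x = 0; x < IMAGE_WIDTH; x++) {
            int rgb = resized.getRGB(x, y);
            data[0][y][x][0] = ((rgb >> 16) & 0xFF) / 127.5f - 1f; // R
            data[0][y][x][1] = ((rgb >> 8) & 0xFF) / 127.5f - 1f;  // G
            data[0][y][x][2] = (rgb & 0xFF) / 127.5f - 1f;         // B
        }
    }
    return Tensors.create(data); // Tensor<Float> with shape [1, 480, 480, 3]
}
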

Java gRPC Client

The gRPC client looks roughly like this:

public class TensorflowGrpcClient {
    private final Logger LOGGER = LoggerFactory.getLogger(TensorflowGrpcClient.class);
    private final int IMAGE_WIDTH = 480;
    private final int IMAGE_HEIGTH = 480;

    private PredictionServiceGrpc.PredictionServiceBlockingStub blockingStub;
    private final ManagedChannel channel;
    private static TensorflowGrpcClient _instance;

    TensorflowGrpcClient() {
        ManagedChannelBuilder<?> channelBuilder = ManagedChannelBuilder.forAddress("127.0.0.1", 8502).usePlaintext();
        channel = channelBuilder.build();
        blockingStub = PredictionServiceGrpc.newBlockingStub(channel);
    }

    public static synchronized TensorflowGrpcClient getInstance() {
        if (_instance == null) {
            _instance = new TensorflowGrpcClient();
        }
        return _instance;
    }

    public float predict(byte[] buffer) {
        // Build the input tensor: shape [1, 480, 480, 3], dtype float, raw bytes as tensor_content
        TensorProto.Builder tensorProto = TensorProto.newBuilder();
        TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
        tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
        tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(IMAGE_WIDTH));
        tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(IMAGE_HEIGTH));
        tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(3));
        tensorProto.setDtype(DataType.DT_FLOAT);
        tensorProto.setTensorShape(tensorShapeBuilder.build());
        ByteString content = ByteString.copyFrom(buffer);
        tensorProto.setTensorContent(content);
        /*
        // This alternative takes roughly 10 ms, slower than setTensorContent
        for (int i = 0; i < IMAGE_WIDTH; i++) {
            for (int j = 0; j < IMAGE_HEIGTH; j++) {
                for (int k = 0; k < 3; k++) {
                    tensorProto.addFloatVal(imgTensor[0][i][j][k]);
                }
            }
        }
        */

        Model.ModelSpec modelSpec = Model.ModelSpec.newBuilder()
                .setName("xception")
                .setSignatureName("serving_default")
                .build();
        Predict.PredictRequest request = Predict.PredictRequest.newBuilder()
                .setModelSpec(modelSpec)
                .putInputs("in", tensorProto.build())
                .build();

        return Metrics.timer("image_recognize_predict_latency3").record(() -> {
            Predict.PredictResponse predictResponse = blockingStub.predict(request); // call the TensorFlow-Serving service
            LOGGER.info("predictResponse is {}", predictResponse.toString());
            return predictResponse.getOutputsMap().get("out").getFloatVal(0);
        });
    }
}
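
One detail the class above leaves out is channel lifecycle: a gRPC ManagedChannel should be shut down when the application stops. A minimal, hypothetical addition to the client could look like this:

// Hypothetical addition to TensorflowGrpcClient; not part of the original code.
public void shutdown() throws InterruptedException {
    // Start an orderly shutdown and give in-flight predict calls a few seconds to finish.
    channel.shutdown().awaitTermination(5, java.util.concurrent.TimeUnit.SECONDS);
}
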

The main Java program calls the gRPC client's predict method to score image quality:

float result = -1f;
Tensor<Float> input = null;
try {
    input = getImageTensor1(imageData); // ~30 ms: preprocess the image into a Tensor
    if (input == null) {
        return result;
    }

    //float[][][][] imgTensor = new float[1][IMAGE_HEIGTH][IMAGE_WIDTH][3];
    //input.copyTo(imgTensor); // ~120 ms
    ByteBuffer buffer = ByteBuffer.allocate(IMAGE_HEIGTH * IMAGE_WIDTH * 3 * 4);
    input.writeTo(buffer);

    // The gRPC call to TensorFlow-Serving takes ~130 ms
    return Metrics.timer("image_recognize_predict_latency2").record(() -> {
        return TensorflowGrpcClient.getInstance().predict(buffer.array()); // ~150 ms
    });
} catch (Exception e) {
    LOGGER.warn("image recognize failed.", e);
} finally {
    if (input != null) {
        input.close();
    }
}

return result;

Summary

Calling TensorFlow-Serving through the HTTP client takes about 500 ms in total, most of it spent on serializing and deserializing the data. Calling it through the gRPC client takes about 160 ms in total (comparable to calling the TensorFlow library directly from Java), so gRPC clearly outperforms HTTP here.