概述
本文参考 TensorFlow 官方的 Java 示例,将深度学习模型保存为 pb 文件,然后在 Java 环境中加载模型并进行预测。
环境安装
安装pip
yum -y install epel-release   # 启用 EPEL 源,之后即可通过 yum 安装 pip(如 python-pip)
安装TensorFlow、Keras、numpy
pip install tensorflow   # 写作本文时安装到的是最新的 TensorFlow 2.1;numpy 会作为依赖一并安装,Keras 可直接使用内置的 tf.keras
Maven配置
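在 pom.xml 中增加如下依赖,引入 TensorFlow 的 Java 库:
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.15.0</version>
</dependency>
注意:Java 端使用的是 1.15.0 版本的库,而上面 pip 安装的是 2.1 版本,导出的 pb 模型中用到的算子必须能被 1.15.0 版本识别。
加载模型:从 classpath 读取 pb 文件,将其导入 Graph,再基于该 Graph 创建 Session。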
4InputStream inputStream = ImageRecognize.class.getResourceAsStream(MODEL_PATH);
Graph graph = new Graph();
graph.importGraphDef(IOUtils.toByteArray(inputStream));
Session session = new Session(graph);1
Java预处理
使用 Java 自带的图像 API 做预处理:先把图片缩放到 480×480,再将每个像素的 R、G、B 值分别归一化到 [-1,1] 区间(像素值 / 127.5 - 1)。
// Java-side preprocessing: decode imageData, scale it to 480x480 and map every R/G/B value to
// [-1, 1] via value / 127.5 - 1, giving a float tensor of shape [1, 480, 480, 3].
InputStream in = new ByteArrayInputStream(imageData);
Image srcImage = ImageIO.read(in);
if (srcImage == null) {
    // ImageIO.read returns null when no installed reader recognizes the bytes. Drawing a null image
    // is a no-op in the JDK's Graphics implementation, which would feed an all-black picture to the
    // model, so return null instead; the prediction code treats a null tensor as "no result".
    return null;
}
BufferedImage bufferedImage = new BufferedImage(480, 480, BufferedImage.TYPE_INT_RGB);
Graphics graphics = bufferedImage.getGraphics();
try {
    graphics.drawImage(srcImage, 0, 0, 480, 480, null); // scale the picture to 480x480
} finally {
    // Graphics contexts hold system resources; release this one as soon as drawing is done.
    graphics.dispose();
}
int w = bufferedImage.getWidth();
int h = bufferedImage.getHeight();
float[][][][] imgTensor = new float[1][h][w][3];
for (int i = 0; i < h; i++) {
    for (int j = 0; j < w; j++) {
        // getRGB packs the pixel as 0xAARRGGBB; extract R, G, B and normalize each to [-1, 1].
        int pixel = bufferedImage.getRGB(j, i);
        imgTensor[0][i][j][0] = (float) ((pixel & 0xff0000) >> 16) / 127.5f - 1;
        imgTensor[0][i][j][1] = (float) ((pixel & 0xff00) >> 8) / 127.5f - 1;
        imgTensor[0][i][j][2] = (float) ((pixel & 0xff)) / 127.5f - 1;
    }
}
return Tensors.create(imgTensor);
TensorFlow预处理
TensorFlow 预处理参考了官方示例 LabelImage.java 的调用方式:利用 TensorFlow Graph 中预定义好的算子(Operator)对图片做预处理。
102private Tensor<Float> getImageTensor(byte[] imageBytes){
Graph g = new Graph();
GraphBuilder b = new GraphBuilder(g);
final int H = IMAGE_HEIGTH;
final int W = IMAGE_WIDTH;
final float mean = 1f;
final float scale = 127.5f;
final Output<String> input = b.constant("input", imageBytes);
final Output<Float> output =
b.sub(
b.div(
b.resizeBilinear(
b.expandDims(
b.cast(b.decodeJpeg(input, 3), Float.class), //解析jpeg文件
b.constant("make_batch", 0) //扩展成4维Tensor
),
b.constant("size", new int[]{H, W}) //resize图片成[H,W]大小
),
b.constant("scale", scale) //每个值除以127.5f
),
b.constant("mean", mean) //归一化到[-1,1]区间
);
try (Session s = new Session(g)) {
// Generally, there may be multiple output tensors, all of them must be closed to prevent resource leaks.
return s.runner().fetch(output.op().name()).run().get(0).expect(Float.class);
}
}
/**
 * Helper for adding TensorFlow operations to a {@link Graph} (adapted from the official
 * LabelImage.java example).
 *
 * <p>Non-constant operations are named after their op type. The first op of a given type keeps the
 * bare type name ("Div", "Sub", ...), and later ones get a numeric suffix ("Div_1", "Div_2", ...).
 * This lets the same op type be used more than once in a graph without a duplicate-name failure.
 * Constants use the caller-supplied name, which must be unique within the graph.
 */
static class GraphBuilder {
    /** Graph that all operations are added to. */
    private final Graph g;
    /** Number of operations created per op type, used by {@link #uniqueName(String)}. */
    private final java.util.Map<String, Integer> opCounts = new java.util.HashMap<String, Integer>();

    GraphBuilder(Graph g) {
        this.g = g;
    }

    /** Element-wise {@code x / y}. */
    <T> Output<T> div(Output<T> x, Output<T> y) {
        return binaryOp("Div", x, y);
    }

    /** Element-wise {@code x - y}. */
    <T> Output<T> sub(Output<T> x, Output<T> y) {
        return binaryOp("Sub", x, y);
    }

    /** Bilinearly resizes 4-D {@code images} to {@code size} = [height, width]; the result is float. */
    <T> Output<Float> resizeBilinear(Output<T> images, Output<Integer> size) {
        return binaryOp3("ResizeBilinear", images, size);
    }

    /** Inserts a dimension of size 1 into {@code input} at index {@code dim}. */
    <T> Output<T> expandDims(Output<T> input, Output<Integer> dim) {
        return binaryOp3("ExpandDims", input, dim);
    }

    /** Casts {@code value} to the TensorFlow data type corresponding to {@code type}. */
    <T, U> Output<U> cast(Output<T> value, Class<U> type) {
        DataType dtype = DataType.fromClass(type);
        return g.opBuilder("Cast", uniqueName("Cast"))
            .addInput(value)
            .setAttr("DstT", dtype)
            .build()
            .<U>output(0);
    }

    /** Decodes JPEG bytes in {@code contents} into a uint8 image with {@code channels} channels. */
    Output<UInt8> decodeJpeg(Output<String> contents, long channels) {
        return g.opBuilder("DecodeJpeg", uniqueName("DecodeJpeg"))
            .addInput(contents)
            .setAttr("channels", channels)
            .build()
            .<UInt8>output(0);
    }

    /** Adds a constant op named {@code name} holding {@code value} as a tensor of {@code type}. */
    <T> Output<T> constant(String name, Object value, Class<T> type) {
        // The temporary Tensor is only needed while the op is being built, so close it right after.
        try (Tensor<T> t = Tensor.<T>create(value, type)) {
            return g.opBuilder("Const", name)
                .setAttr("dtype", DataType.fromClass(type))
                .setAttr("value", t)
                .build()
                .<T>output(0);
        }
    }

    /** Adds a string (byte-array) constant. */
    Output<String> constant(String name, byte[] value) {
        return this.constant(name, value, String.class);
    }

    /** Adds a scalar int32 constant. */
    Output<Integer> constant(String name, int value) {
        return this.constant(name, value, Integer.class);
    }

    /** Adds a 1-D int32 constant. */
    Output<Integer> constant(String name, int[] value) {
        return this.constant(name, value, Integer.class);
    }

    /** Adds a scalar float constant. */
    Output<Float> constant(String name, float value) {
        return this.constant(name, value, Float.class);
    }

    /** Adds a two-input op whose inputs and output share the element type {@code T}. */
    private <T> Output<T> binaryOp(String type, Output<T> in1, Output<T> in2) {
        return g.opBuilder(type, uniqueName(type)).addInput(in1).addInput(in2).build().<T>output(0);
    }

    /** Adds a two-input op whose inputs and output may have different element types. */
    private <T, U, V> Output<T> binaryOp3(String type, Output<U> in1, Output<V> in2) {
        return g.opBuilder(type, uniqueName(type)).addInput(in1).addInput(in2).build().<T>output(0);
    }

    /**
     * Returns {@code type} the first time it is requested, then {@code type_1}, {@code type_2}, ...
     * so that repeated op types do not collide; first uses keep the original names.
     */
    private String uniqueName(String type) {
        Integer seen = opCounts.get(type);
        int n = (seen == null) ? 0 : seen;
        opCounts.put(type, n + 1);
        return (n == 0) ? type : type + "_" + n;
    }
}
模型预测
在我们的 Xception 模型中,输入节点名为 input_1,输出节点名为 output;代码中 feed/fetch 使用的节点名必须与之完全一致。
// Run the model on one image and keep its first output score; -1 means "no prediction".
float result = -1;
input = getImageTensor1(imageData);
if ( input == null ) {
    return result; // preprocessing failed
}
// Every Tensor (the fed input and each fetched output) holds native memory and must be closed,
// otherwise every prediction leaks memory outside the Java heap.
List<Tensor<?>> results = null;
try {
    // Node names must match the exported graph exactly: input "input_1", output "output".
    results = session.runner().feed("input_1", input).fetch("output").run();
    if (results.size() > 0 && results.get(0).shape().length == 2) {
        long[] rshape = results.get(0).shape();
        int rs = (int) rshape[0];
        int rt = (int) rshape[1];
        // Take the first score of the first row. This assumes a batch of one image, which is what the
        // preprocessing shown in this article builds. The previous nested loop's `break` only left the
        // inner loop and read row rs-1, which is the same element when rs == 1.
        if (rs > 0 && rt > 0) {
            float[][] realResult = new float[rs][rt];
            results.get(0).copyTo(realResult);
            result = realResult[0][0];
        }
    }
} finally {
    input.close();
    if (results != null) {
        for (Tensor<?> fetched : results) {
            fetched.close();
        }
    }
}
线上部署
线上部署时,有一个线程持续检查 HDFS 上是否有新模型;一旦模型更新,就加载新模型替换旧模型。替换后,旧模型的 Session 和 Graph 应在不再被使用时调用 close() 释放本地资源。