一句代码发布你的TensorFlow模型,简明TensorFlowServing上手教程

专知 2018-12-04 12:58
关注文章

【导读】深度学习框架TensorFlow不仅在学术界得到了普及,在工业界也有非常广泛的应用。日常我们接触到的TensorFlow的用法大多为基于Python的实验用法,并不能直接用于工业界的线上产品。本文介绍一种简单的发布TensorFlow模型的方法。


在工业产品中使用TensorFlow模型的方法

在工业产品中TensorFlow大概有下面几种使用方法:

  1. 用TensorFlow的C++/Java/Nodejs API直接使用保存的TensorFlow模型:类似Caffe,适合做桌面软件。

  2. 直接将使用TensorFlow的Python代码放到Flask等Web程序中,提供Restful接口:实现和调试方便,但效率不太高,不大适合高负荷场景,且没有版本管理、模型热更新等功能。

  3. 将TensorFlow模型托管到TensorFlow Serving中,提供RPC或Restful服务:实现方便,高效,自带版本管理、模型热更新等,很适合大规模线上业务。

本文介绍的是方法3,如何用最简单的方法将TensorFlow发布到TensorFlow Serving中。


一句代码保存TensorFlow模型

# coding=utf-8
import tensorflow as tf

# 模型版本号
model_version = 1
# 定义模型
x = tf.placeholder(tf.float32, shape=[None, 4], name="x")
y = tf.layers.dense(x, 10, activation=tf.nn.softmax)

with tf.Session() as sess:
   # 初始化变量
   
sess.run(tf.global_variables_initializer())
   # 模型训练过程,省略
   
# ......
   
   #
保存训练好的模型到"model/版本号"
   
tf.saved_model.simple_save(
       session=sess,
       
export_dir="model/{}".format(model_version),
       
inputs={"x": x},
       
outputs={"y": y}
   )


代码中除了最后一句,其它部分都是常规的TensorFlow代码,模型定义、进入Session、模型训练等。代码的最后用tf.saved_model.simple_save将模型保存为SavedModel。注意,这里将模型保存在了"model/版本号"文件夹中,而不是直接保存在了"model"文件夹中,这是因为TensorFlow Serving要求在模型目录下加一层版本目录,来进行版本维护、热更新等:


安装TensorFlow Serving


方法一:用apt-get安装

对于Ubuntu或Debian(Bash on Windows10也可以),可以使用apt-get安装Tensorflow Serving。先用下面的命令添加软件源:

echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | sudo tee /etc/apt/sources.list.d/tensorflow-serving.list && \
curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add -

添加成功后可直接用apt-get进行安装:

apt-get update && apt-get install tensorflow-model-server


方法二:用Docker安装

TensorFlow Serving官方提供了Docker容器,可以一键安装:

docker pull tensorflow/serving


将模型发布到TensorFlow Serving中

下面的方法基于在本机使用apt-get安装TensorFlow Serving的方法。对于Docker用户,需要将模型挂载或复制到Docker中,按照Docker中的路径来执行下面的教程。

用下面这行命令,就可以启动TensorFlow Serving,并将刚才保存的模型发布到TensorFlow Serving中。注意,这里的模型所在路径是刚才"model"目录的路径,而不是"model/版本号"目录的路径,因为TensorFlow Serving认为用户的模型所在路径中包含了多个版本的模型。

tensorflow_model_server --port=8500 --rest_api_port=8501 --model_name=模型名 --model_base_path=模型所在路径

客户端可以用GRPC和Restful两种方式来调用TensorFlow Serving,这里我们介绍基于Restful的方法,可以看到,命令中指定的Restful服务端口为8501,我们可以用curl命令来查看服务的状态:

curl http://localhost:8501/v1/models/model

执行结果:

{
"model_version_status": [
 {
  "version": "1",
 
"state": "AVAILABLE",
 
"status": {
   "error_code": "OK",
   
"error_message": ""
 
}
 }
]
}

下面我们用curl向TensorFlow Serving发送一个输入x=[1.1, 1.2, 0.8, 1.3],来获取预测的输出信息y:

curl -d '{"instances": [[1.1,1.2,0.8,1.3]]}' -X POST http://localhost:8501/v1/models/模型名:predict

服务器返回的结果如下:

{
   "predictions": [[0.0649088, 0.0974758, 0.0456831, 0.297224, 0.152209, 0.0177431, 0.104193, 0.0450511, 0.13074, 0.044771]]
}

我们的模型成功地输出了y=[0.0649088, 0.0974758, 0.0456831, 0.297224, 0.152209, 0.0177431, 0.104193, 0.0450511, 0.13074, 0.044771]


这里我们使用的是curl命令,在实际工程中,使用requests(Python)、OkHttp(Java)等Http请求库可以用类似的方法方便地请求TensorFlow Serving来获取模型的预测结果。


版本维护和模型热更新

刚才我们将模型保存在了"model/1"中,其中1是模型的版本号。如果我们的算法工程师研发出了更好的模型,此时我们并不需要将TensorFlow Serving重启,只需要将新模型发布在"model/新版本号"中,如"model/2"。TensorFlow Serving就会自动发布新版本的模型,客户端也可以请求新版本对应的API了。


转自专知公众号


微信扫一扫
关注该公众号

{{panelTitle}}
支持Markdown和数学公式,公式格式:\\(...\\)或\\[...\\]

还没有内容

关注微信公众号