筆記 - TensorFlow模型的跨平台部署(四)
來自專欄 Cerulean的專欄4 人贊了文章
本文接上文:
Cerulean:筆記 - TensorFlow模型的跨平台部署(三)
方案三:使用TensorFlow Serving
TensorFlow Serving是一個專門用於將TensorFlow模型部署於生產環境一個工具模塊,通過TensorFlow Serving,我們可以輕易地部署TensorFlow模型到生產環境。在介紹TensorFlow Serving的核心概念和使用方法前,請先安裝依賴:
sudo apt-get update && sudo apt-get install -y automake build-essential curl libcurl3-dev git libtool libfreetype6-dev libpng12-dev libzmq3-dev pkg-config python-dev python-numpy python-pip software-properties-common swig zip zlib1g-dev
以上的依賴是對應於Ubuntu 16.04版本,如果是其他發行版本的Linux,請自行解決依賴問題。在確認依賴安裝後,請通過以下命令安裝TensorFlow Serving的客戶端Python API工具包:
pip install tensorflow-serving-api
然後,通過以下命令安裝TensorFlow Serving的伺服器包:
echo「deb [arch = amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal」| sudo tee /etc/apt/sources.list.d/tensorflow-serving.listcurl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add - sudo apt-get update && sudo apt-get install tensorflow-model-server
Servables
Servables是TensorFlow Serving的一個核心概念,我們可以直觀的理解為一個服務,這個服務運行在某個埠,監聽某個埠並處理請求,這個請求可以是gRPC協議的,也可以是HTTP的,請求處理完成後返回結果,直覺地,我們將某個模型的輸入封裝成gRPC或者HTTP的request,然後POST給Servables,然後Servables作為模型的抽象,接受輸入並計算輸出,將結果封裝成gRPC或者HTTP的response,返回給客戶端,這個過程與Web服務非常類似。本質上,Servables所作的核心工作,即是自動地讀取我們已經持久化的模型,並通過代碼模板自動地以C++版本恢復模型,並實現了一個靈活的服務,這個服務支持一個或者多個協議的請求,監聽某個埠,專註地處理request並返回response。一個典型的Servables的核心即是一個SavedModelBundle
實例,關於SavedModelBundle
類,我們在介紹方案一時,在用C++恢復模型時簡要的展示了這個類的用法:
SessionOptions sessionOptions;RunOptions runOptions;SavedModelBundle bundle;Status status;status = LoadSavedModel(sessionOptions, runOptions, GraphDir, {kSavedModelTagServe}, &bundle);bundle.session->Run({{xInputTensor.name(), input}}, {yPredictTensor.name()}, {}, &yPredict);
可以看到,SavedModelBundle
的實例有一個session
成員變數,這個session
與Python版本的Session
的功能是基本類似的。另一方面,我們在介紹方案一時曾經提到,Python版本的TensorFlow的SavedModelBuilder
在調用add_meta_graph_and_variables
時有一個小問題,即對於每一個SavedModelBuilder
實例來說,你只能調用一次add_meta_graph_and_variables
方法,如果你試圖調用兩次,那麼你會捕獲到一個異常,提示meta_graph
和variables
已經添加,不可以重複添加,除此之外,在SavedModelBuilder
的某個實例多次調用builder.save()
方法時,你仍會捕獲一個異常,這個異常提示持久化目錄不為空,如果你執意要持久化此模型,那麼你必須先清空當前持久化目錄才可以調用builder.save()
方法。回顧這兩個問題,是為了引出一個版本控制話題。
Servable Versions
事實上,Servables提供了一些機制支持不同版本的模型,在後文我們會看到,我們會通過不同的URI,例如不同的協議,不同的版本號:
http://host:port/v1/models/test/versions/${MODEL_VERSION}:predict
可以出發Servables調度不同版本的模型處理不同URI的request,即我們可以將不同的版本的模型按約定的版本號建立不同的文件夾存儲在磁碟,通過TensorFlow Serving,我們可以非常自然地將他們組合起來使用而使得彼此互不影響,這可能非常適用於某一些生產環境的部署,但是目前就筆者的問題規模,暫時沒有用到這個特性。
Loader and Source and Manager
事實上,Loader和Sources支持著不同版本Servables的實例化與調度,即Source將會根據特定的模型版本創建Loader,而Loader將會實例化Servables,這些Servables又統一受Manager管理,Manager更像一個反向代理的角色,它直接處理request,並分發給特定的Servables處理,並返回response。
構建計算圖並持久化模型
首先,仍然以方案一與方案二的常式為例,創建一個持久化模型,假設我們的網路具有如下結構:
session = tf.Session()x_input = tf.placeholder(tf.float32, [None, 1])y_input = tf.placeholder(tf.float32, [None, 1])fc1 = tf.layers.dense(x_input, 10, tf.nn.relu)fc2 = tf.layers.dense(fc1, 10, tf.nn.relu)y_predict = tf.layers.dense(fc2, 1)loss_func = tf.losses.mean_squared_error(labels=y_input, predictions=y_predict)optimizer = tf.train.AdamOptimizer().minimize(loss_func)session.run(tf.global_variables_initializer())
那我們首先需要構造signature_def
對象,我們可以通過以下代碼構造signature_def
對象:
signature = tf.saved_model.signature_def_utils.build_signature_def( inputs={ x_input: tf.saved_model.utils.build_tensor_info(x_input), y_input: tf.saved_model.utils.build_tensor_info(y_input) }, outputs={ y_predict: tf.saved_model.utils.build_tensor_info(y_predict), loss_func: tf.saved_model.utils.build_tensor_info(loss_func) }, method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
正如方案一中介紹的,我們這裡簡要的做一下回顧,API接收三個重要參數,其中inputs
即是tf.placeholder
,用於描述輸入張量,outputs
即是輸出張量。在這裡,可以看到我們分別用字元串x_input
與y_input
,y_predict
與loss_func
作為序列化後獲取張量的鍵,而tf.saved_model.utils.build_tensor_info
就是將張量轉為protobuf結構的快捷方法。然後訓練模型並持久化模型:
for step in range(2000): session.run(optimizer, { x_input: x_train, y_input: y_train }) if (step + 1) % 500 == 0: if os.path.exists(graph_save_dir): shutil.rmtree(graph_save_dir) builder = tf.saved_model.builder.SavedModelBuilder(graph_save_dir) builder.add_meta_graph_and_variables(session, [tf.saved_model.tag_constants.SERVING], {tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature}) builder.save()
我們構造SavedModelBuilder
對象,並調用add_meta_graph_and_variables
添加meta_graph
與variables
,並調用多次builder.save()
方法持久化模型,其中[tf.saved_model.tag_constants.SERVING]
是持久化模型的標籤,我們後續可以通過命令行檢查哪些標籤有對應有效的持久化模型,{tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature}
鍵值對指定了對於SERVING
標籤下signature
鍵對應的計算圖,我們在方案二中提到過,持久化模型可以對應多個計算圖。
檢查持久化模型
持久化模型完成後,我們可以通過方案二介紹的命令行工具saved_model_cli
檢查持久化模型對應Tags的signature_def
信息::
saved_model_cli show --dir=. --tag_set=serve --signatue_def=serving_defaultThe given SavedModel SignatureDef contains the following input(s): inputs[x_input] tensor_info: dtype: DT_FLOAT shape: (-1, 1) name: Placeholder:0 inputs[y_input] tensor_info: dtype: DT_FLOAT shape: (-1, 1) name: Placeholder_1:0The given SavedModel SignatureDef contains the following output(s): outputs[loss_func] tensor_info: dtype: DT_FLOAT shape: () name: mean_squared_error/value:0 outputs[y_predict] tensor_info: dtype: DT_FLOAT shape: (-1, 1) name: dense_2/BiasAdd:0Method name is: tensorflow/serving/predict
可以看到,如果以上所有操作無誤,應該可以看到saved_model_cli
列出的模型信息應該與我們構造signature_def
時一致。
使用TensorFlow Serving API啟動Serving服務
接下來,我們將通過TensorFlow Serving API啟動Serving服務:
tensorflow_model_server --port=9000 --rest_api_port=9001 --model_name=test --model_base_path=$(pwd)2018-07-24 20:11:09.772135: I tensorflow_serving/model_servers/main.cc:153] Building single TensorFlow model file config: model_name: test model_base_path: /home/keyunlong/pycharm_project_775/playground/graph2018-07-24 20:11:09.772452: I tensorflow_serving/model_servers/server_core.cc:459] Adding/updating models.2018-07-24 20:11:09.772479: I tensorflow_serving/model_servers/server_core.cc:514] (Re-)adding model: test2018-07-24 20:11:10.179066: I tensorflow_serving/core/basic_manager.cc:716] Successfully reserved resources to load servable {name: test version: 1}2018-07-24 20:11:10.179115: I tensorflow_serving/core/loader_harness.cc:66] Approving load for servable version {name: test version: 1}2018-07-24 20:11:10.179153: I tensorflow_serving/core/loader_harness.cc:74] Loading servable version {name: test version: 1}2018-07-24 20:11:10.179224: I external/org_tensorflow/tensorflow/contrib/session_bundle/bundle_shim.cc:360] Attempting to load native SavedModelBundle in bundle-shim from: /home/keyunlong/pycharm_project_775/playground/graph/12018-07-24 20:11:10.179279: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:242] Loading SavedModel with tags: { serve }; from: /home/keyunlong/pycharm_project_775/playground/graph/12018-07-24 20:11:10.199380: I external/org_tensorflow/tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA2018-07-24 20:11:10.331559: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:161] Restoring SavedModel bundle.2018-07-24 20:11:10.464002: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:196] Running LegacyInitOp on SavedModel bundle.2018-07-24 20:11:10.466276: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:291] SavedModel load for tags { serve }; Status: success. Took 286961 microseconds.2018-07-24 20:11:10.466395: I tensorflow_serving/servables/tensorflow/saved_model_warmup.cc:83] No warmup data file found at /home/keyunlong/pycharm_project_775/playground/graph/1/assets.extra/tf_serving_warmup_requests2018-07-24 20:11:10.467172: I tensorflow_serving/core/loader_harness.cc:86] Successfully loaded servable version {name: test version: 1}2018-07-24 20:11:10.653258: I tensorflow_serving/model_servers/main.cc:323] Running ModelServer at 0.0.0.0:9000 ...2018-07-24 20:11:10.713689: I tensorflow_serving/model_servers/main.cc:333] Exporting HTTP/REST API at:localhost:9001 ...[evhttp_server.cc : 235] RAW: Entering the event loop ...
可以看到,我們通過tensorflow_model_server
命令行工具,指定監聽gRPC和HTTP的埠,指定模型名,指定模型基目錄(支持多版本),就可以啟動一個Serving服務了。通過Serving的啟動日誌我們可以發現,Serving成功的找到了版本號為1的模型,並通過SavedModelBundle
類恢復模型,並通過Loader創建了一個Servable,然後開始監聽9000和9001埠,進入事件循環,等待請求。
通過gRPC調用Serving服務
這裡僅僅展示gRPC的小程序,我們編寫如下代碼通過gRPC協議調用Serving服務:
# Init channel.channel = implementations.insecure_channel(localhost, 9000)# Init stub.stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)# Init request.request = predict_pb2.PredictRequest()request.model_spec.name = testrequest.model_spec.signature_name = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEYrequest.inputs[x_input].CopyFrom( tf.contrib.util.make_tensor_proto(x_train, shape=x_train.shape))request.inputs[y_input].CopyFrom( tf.contrib.util.make_tensor_proto(y_train, shape=y_train.shape))# Predict.future = stub.Predict.future(request, 2.0)result = future.result().outputs[loss_func].float_vallogging.warning(Loss: {}.format(result))
如果一切正常,你會看到以下數值作為輸出,會因訓練結果而異:
Connected to pydev debugger (build 181.4203.547)WARNING:root:Loss: [0.12157168239355087]
快速地解釋一下這個小程序,我們構造了一個gRPC協議的請求,主要包含了模型名test
,signature_name
的鍵,即指定使用哪個計算圖,構造了x_input
和y_input
張量,並封裝成protobuf,然後通過實例化的程序樁調用Predict發送request,然後輸出結果
推薦閱讀:
※數據挖掘:從學術界到業界
※404錯誤是個啥?
※Google 面試學習手冊,來看看谷歌,微軟等大廠都面試什麼
※站長之家論壇(bbs.chinaz.com)宣布關站
※阿拉丁飛毯不是童話?未來世界竟然是這樣的!
TAG:TensorFlow | 計算機科學 | 科技 |