筆記 - TensorFlow模型的跨平台部署(三)
來自專欄 Cerulean的專欄9 人贊了文章
本文接上文:
Cerulean:筆記 - TensorFlow模型的跨平台部署(二)
方案二:使用saved_model_cli
saved_model_cli
提供了一種通過命令行檢查並恢復模型的機制,如果你的TensorFlow是通過pip安裝的,那麼saved_model_cli
應該已經被一同安裝,saved_model_cli
主要有兩個命令,一個是show
,一個是run
,假設我們已經按照方案一提到的saved_model_builder
方式在某個路徑有了持久化的模型,我們可以通過如下方式檢查該模型的相關信息:
saved_model_cli show --dir=graph_dirThe given SavedModel contains the following tag-sets:train
通過show
命令,給出持久化模型的路徑,saved_model_cli
返回有效的meta_graph_def
值對應的鍵,這個鍵即是我們在:
builder.add_meta_graph_and_variables(session, [tf.saved_model.tag_constants.SERVING], {tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature})
即在持久化模型前為當前元計算圖和變數信息定義的Tags,事實上我們可以根據需要在持久化時定義多個Tags,但是我目前沒有用到這個功能。我們可以通過以下命令檢查持久化模型對應Tags的signature_def
信息:
saved_model_cli show --dir=. --tag_set=serve --signatue_def=serving_defaultThe given SavedModel SignatureDef contains the following input(s): inputs[x_input] tensor_info: dtype: DT_FLOAT shape: (-1, 1) name: Placeholder:0 inputs[y_input] tensor_info: dtype: DT_FLOAT shape: (-1, 1) name: Placeholder_1:0The given SavedModel SignatureDef contains the following output(s): outputs[loss_func] tensor_info: dtype: DT_FLOAT shape: () name: mean_squared_error/value:0 outputs[y_predict] tensor_info: dtype: DT_FLOAT shape: (-1, 1) name: dense_2/BiasAdd:0Method name is: tensorflow/serving/predict
可以看到,我們在通過:
signature = tf.saved_model.signature_def_utils.build_signature_def( inputs={ x_input: tf.saved_model.utils.build_tensor_info(x_input), y_input: tf.saved_model.utils.build_tensor_info(y_input) }, outputs={ y_predict: tf.saved_model.utils.build_tensor_info(y_predict), loss_func: tf.saved_model.utils.build_tensor_info(loss_func) }, method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
這種方式定義的signature_def
信息可以通過上述命令檢出,這裡回顧以下signature_def
,我們構造了signature_def
對象,這個對象包含了計算圖中輸入與輸出張量的鍵值對信息,鍵即是張量名,值即是protobuff結構的張量,用一個method_name
鍵來描述功能。可以看到,在通過save_model_cli
的show
命令檢出的signature_def
信息中,我們定義的關鍵張量的鍵值對,都被成功檢出。接著我們可以通過run
命令恢復模型,並根據需要進行計算:
saved_model_cli run --dir=./graph --tag_set=serve --signature_def=serving_default --inputs "x_input=./x_train.npy;y_input=./y_train.npy"2018-07-21 16:23:38.098684: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.2 AVXResult for output key loss_func:0.0157029Result for output key y_predict:[[ 0.94755262] [ 0.92479306]]
這裡有兩個大坑,需要格外注意:
- Python3可以運行
run
命令,但是無法通過給定--inputs以*.npy
文件的形式計算結果,如果你這麼做,你會捕獲到一個異常,大意是提示存在編解碼的錯誤,簡單谷歌后無法解決,遂擱置。 - Python2.7可以運行
run
命令,但是inputs後面的參數形式必須要用雙引號""包含,如果你這麼做,那麼會導致第一個之後的輸入無法被解析。
針對第二個坑,官網常式中給出的例子並沒有用雙引號包含:
$ saved_model_cli run --dir /tmp/saved_model_dir --tag_set serve --signature_def x1_x2_to_y --inputs x1=/tmp/my_data1.npz[x];x2=/tmp/my_data2.pkl --outdir /tmp/out --overwriteResult for output key y:[[ 1.5] [ 2.5] [ 3.5]]
起初我以為是shell的問題,因為筆者的shell是zsh,但是切換回bash已然會報錯,這點請務必注意。
明天整理一下Serving相關的內容。
推薦閱讀:
※探秘「棧」之旅
※Mozilla Firefox Ltd面經分享:真題+參考答案匯總
※PDF加密工具及方法
※在GitHub上搭建自己的個人主頁
※從mSATA到M.2,新生代固態硬碟介面優勢解讀
TAG:科技 | TensorFlow | 計算機科學 |