如何用flask部署pytorch模型

如何用flask部署pytorch模型

來自專欄 深度煉丹

隨著深度學習越來越火,各種框架也層出不窮,如何訓練一個深度學習模型變得越來越簡單。然而在實際的工業場景中,我們往往更加關注如何部署一個已經訓練好的模型。

在這一點上,tensorflow做得非常好,提供了tensorflow serving來幫助我們非常方便地部署到工業場景下。眾所周知,PyTorch的一個非常大的劣勢就是沒有辦法很方便地部署模型,facebook和Microsoft一起搞了一個神經交換機,ONNX,可以將pytorch model轉換到Caffe2 model,這樣一是麻煩,二是Caffe2目前還在測試,一堆bug,用的人也不多,三是還要多學一個框架Caffe2。所以這並不是一個非常好的選擇。

目前的最新消息,Caffe2的源碼已經併到pytorch中了,或許這是Facebook準備對付TensorFlow的大招,我們拭目以待。

本文受到keras的一篇博文的啟發,會教大家如何使用flask來部署訓練好的pytorch模型。首先聲明一下,flask我也不太會用,因為看到了keras的文章,希望分享這種思路和想法,使用一種web框架實現深度學習模型的部署。

環境配置

首先確保安裝了pytorch,因為需要使用flask這個web框架,所以當然需要安裝flask,非常簡單,使用下面的命令進行安裝。

pip install flask

配置REST API

我們知道每次啟動模型,load參數是一件非常費時間的事情,而每次做前向傳播的時候模型其實都是一樣的,所以我們最好的辦法就是load一次模型,然後做完前向傳播之後仍然保留這個load好的模型,下一次有新的數據進來,我們就可以不用重新load模型,可以直接做前向傳播得到結果,這樣無疑節約了很多load模型的時間。所以我們需要建立一個類似於伺服器的機制,將模型在伺服器上load好,方便我們不斷去調用模型做前向傳播,那麼怎麼能夠達到這個目的呢?我們可以使用flask來建立一個REST API來達到這一目的。

REST API 是什麼呢?REST 是Representational State Transfer的縮寫,這是一種架構風格,這裡就不再過多描述,感興趣的同學可以自己去google一下。

那麼如何用flask啟動這個服務呢?

載入模型

app = flask.Flask(__name__)model = Noneuse_gpu = Truedef load_model(): """Load the pre-trained model, you can use your model just as easily. """ global model model = resnet50(pretrained=True) model.eval() if use_gpu: model.cuda()

首先我們需要使用上面的代碼來載入模型,前面三句話的非常簡單,第一句話表示調用flask初始化一個app,接著定義一個變數model來表示模型,use_gpu表示是否使用gpu。

接著定義 load_model 這個函數,在函數中將模型的參數load到前面定義的model中,這裡使用了resnet50,記得使用 .eval將model轉換成eval模式,如果要使用GPU,則加上 .cuda()

數據預處理

def prepare_image(image, target_size): """Do image preprocessing before prediction on any data. :param image: original image :param target_size: target image size :return: preprocessed image """ if image.mode != RGB: image = image.convert("RGB") # Resize the input image nad preprocess it. image = T.Resize(target_size)(image) image = T.ToTensor()(image) # Convert to Torch.Tensor and normalize. image = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image) # Add batch_size axis. image = image[None] if use_gpu: image = image.cuda() return torch.autograd.Variable(image, volatile=True)

這裡就是 pytorch 中標準的預處理流程,首先將圖片resize到固定的大小,然後轉換成 tensor,接著做標準化。

啟動REST API

定義好了模型和數據預處理,接下來我們就需要開始啟動 flask 服務了。

@app.route("/predict", methods=["POST"])def predict(): # Initialize the data dictionary that will be returned from the view. data = {"success": False} # Ensure an image was properly uploaded to our endpoint. if flask.request.method == POST: if flask.request.files.get("image"): # Read the image in PIL format image = flask.request.files["image"].read() image = Image.open(io.BytesIO(image)) # Preprocess the image and prepare it for classification. image = prepare_image(image, target_size=(224, 224)) # Classify the input image and then initialize the list of predictions to return to the client. preds = F.softmax(model(image), dim=1) results = torch.topk(preds.cpu().data, k=3, dim=1) data[predictions] = list() # Loop over the results and add them to the list of returned predictions for prob, label in zip(results[0][0], results[1][0]): label_name = idx2label[label] r = {"label": label_name, "probability": float(prob)} data[predictions].append(r) # Indicate that the request was a success. data["success"] = True # Return the data dictionary as a JSON response. return flask.jsonify(data)

首先定義請求方式為 POST,表示向伺服器傳輸數據,接著定義一個 predict 函數來進行模型的前向傳播。

prediect 中,首先建立一個字典 data 來存儲請求狀態,初始化為 false。接著通過 flask.request.method 來判斷是否是 POST 請求,如果是的話,我們就通過 flask.request.files.get("image") 來判斷是否能夠得到從遠端傳過來的數據,如果確實有數據傳過來,我們就可以通過 flask.request.files["image"].read() 來得到從遠方 POST 上來的數據。

為了傳輸的速度考慮,一般都會傳二進位的文件,所以通過 io.BytesIO(image) 將二進位的文件讀取出來,再通過 PIL.Image.open 來讀取這個圖片,這樣我們就解碼了一張從遠端傳過來的圖片了。

然後下面的操作就很簡單了,首先通過 prepare_image 將圖片做預處理,接著傳入到網路當中,這裡需要注意我們會使用 F.softmax 將模型的輸出得分轉換成一個概率分布,因為我們想要輸出 top3 的結果和置信概率。最後我們就是將結果存到 data 中,返回成 json 的文件。

最後我們在 main 函數中調用

load_model()app.run()

就可以啟動 flask 服務了。通過上面的代碼,我們知道了如何處理傳過來的圖片並輸出預測的結果,那麼我們如何傳圖片呢?這就是下面會講的如何發起數據請求。

發送數據請求

發送數據請求並不難,首先我們需要知道上面定義好的 flask server 的地址,因為這就是我們在本地定義的,所以地址是

PyTorch_REST_API_URL = http://127.0.0.1:5000/predict

上面的 /predict 是因為我們前面使用了 @app.route("/predict", methods=["POST"])

接著我們定義一個函數來發送數據請求

def predict_result(image_path): # Initialize image path image = open(image_path, rb).read() payload = {image: image} # Submit the request. r = requests.post(PyTorch_REST_API_URL, files=payload).json() # Ensure the request was successful. if r[success]: # Loop over the predictions and display them. for (i, result) in enumerate(r[predictions]): print({}. {}: {:.4f}.format(i + 1, result[label], result[probability])) # Otherwise, the request failed. else: print(Request failed)

傳入的參數 image_path 是圖片路徑,然後使用 requests.post(PyTorch_REST_API_URL, files=payload).json() 向伺服器傳入數據,同時得到伺服器計算的結果,最後將結果 print 出來就可以了。

實驗結果

我們使用 ResNet50 作為預訓練的模型,傳入下面這張圖片作為測試

首先在一個終端中運行

python run_pytorch_server.py

來啟動 flask server,等待一會兒,可以得到下面的結果

然後我們重新打開一個新的終端,運行下面的代碼

python simple_request.py --file=./dog.jpg

這裡的 dog.jpg 可以改成你自己的文件路徑,然後我們可以得到下面的結果

討論

最後我們實現了一個簡單的深度學習伺服器,當然這個模型是在本地建立的,我們當然可以將模型建立到遠端的伺服器上,本地向遠端發送請求。當然這只是一個 toy model,我們可以基於這種思想設計更加複雜的結構。


本文內容參考自 Building a simple Keras + deep learning REST API

本文的完整code

歡迎關注我的知乎專欄深度煉丹

歡迎訪問我的博客


推薦閱讀:

Rethinking ICCV 2017 [Part 1]
人臉識別演算法演化史
Rocket Training: 一種提升輕量網路性能的訓練方法
基於深度學習的計算機視覺應用之目標檢測
深入理解GoogLeNet結構(原創)

TAG:PyTorch | 深度學習DeepLearning | 計算機視覺 |