Building a PyTorch Deep Learning Service with Flask
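Unlike TensorFlow, which comes with TensorFlow Serving, PyTorch had no official deployment tool at the time of writing, so using PyTorch in production means building your own PyTorch server. For convenience, this article uses Flask, a lightweight Python web framework.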
Passing a NumPy Array with Flask
The core data structure in PyTorch is the Tensor, which is essentially a multidimensional array, so we start by sending a NumPy array over the network.
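JSON has no type for NumPy data, so `json.dumps` rejects an `ndarray` outright. The array is therefore sent as nested Python lists via `tolist()` and rebuilt with `np.asarray` on the receiving side. A minimal round-trip sketch (the helper names `encode_array` and `decode_array` are hypothetical, not part of the original code):

```python
import json

import numpy as np


def encode_array(arr):
    # ndarray -> JSON string; tolist() turns the array into nested Python lists
    return json.dumps({"data": arr.tolist()})


def decode_array(payload):
    # JSON string -> ndarray; np.asarray rebuilds the shape from the list nesting
    return np.asarray(json.loads(payload)["data"])


original = np.zeros((2, 4)) + 0.1
restored = decode_array(encode_array(original))
print(restored.shape, restored.dtype)  # (2, 4) float64

try:
    json.dumps({"data": original})  # passing the ndarray directly fails
except TypeError as exc:
    print("ndarray is not JSON serializable:", type(exc).__name__)
```

The nesting preserves the shape but not the dtype: a `float32` image, for example, comes back as `float64`, because `tolist()` produces plain Python floats.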
Client
import numpy as np
import requests

addr = "http://localhost:5000"
test_url = addr + "/api/test"

# A 2x4 array filled with 0.1; JSON cannot carry an ndarray, so convert it to nested lists
temp = np.zeros((2, 4)) + 0.1
data = {"data": temp.tolist()}

# json= serializes the dict and sets the Content-Type: application/json header
response = requests.post(test_url, json=data)
print(response.text)
Server
import json

import numpy as np
from flask import Flask, Response, request

app = Flask(__name__)


@app.route("/api/test", methods=["POST"])
def test():
    # get_json() parses the JSON request body into a Python dict
    r_json = request.get_json()
    numpy_data = np.asarray(r_json["data"])
    response = {
        "message": "Data type:{},Shape:{}".format(type(numpy_data), numpy_data.shape)
    }
    # Encode the response dict as JSON
    return Response(response=json.dumps(response), status=200,
                    mimetype="application/json")


if __name__ == "__main__":
    # start flask app
    app.run(host="0.0.0.0", port=5000)
Start the server first, then run the client. The output is:
{"message": "Data type:<class 'numpy.ndarray'>,Shape:(2, 4)"}
This shows that a NumPy array was successfully sent over the network.
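The endpoint can also be exercised without starting a server or opening a port, using Flask's built-in test client. A sketch that declares an equivalent `/api/test` route (reading a plain JSON body with `get_json()`) so the block is self-contained; the view function is named `api_test` here only to keep test runners from mistaking it for a test:

```python
import json

import numpy as np
from flask import Flask, Response, request

app = Flask(__name__)


@app.route("/api/test", methods=["POST"])
def api_test():
    numpy_data = np.asarray(request.get_json()["data"])
    response = {
        "message": "Data type:{},Shape:{}".format(type(numpy_data), numpy_data.shape)
    }
    return Response(response=json.dumps(response), status=200,
                    mimetype="application/json")


# The test client dispatches requests straight to the app; no network is involved
client = app.test_client()
resp = client.post("/api/test", json={"data": (np.zeros((2, 4)) + 0.1).tolist()})
print(resp.status_code, resp.get_json()["message"])
```

This makes it easy to check the request/response contract in an automated test before wiring in a real client.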
Adding the PyTorch Model
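Training the model is beside the point of this article, so it is not covered here; any model will do as an example. The one used here recognizes whether the shape in an image is a square or a circle, and takes a 3-channel 32×32 image as input. The file structure is: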
Model/
├── model.py
└── model.pkl
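The generator below produces 100×100×3 arrays in height × width × channel order, while a PyTorch convolutional network takes a batch in N × C × H × W order. In the test script, `transforms.ToTensor()` moves the channel axis first and scales uint8 values to [0, 1], and `unsqueeze(0)` adds the batch axis. A NumPy sketch of that bookkeeping for an image that is already 32×32 (the resize step is left out, and `to_chw_batch` is a hypothetical name):

```python
import numpy as np


def to_chw_batch(img):
    # HxWxC uint8 in [0, 255] -> 1xCxHxW float32 in [0, 1]
    x = img.astype(np.float32) / 255.0  # scale as ToTensor does for uint8 input
    x = np.transpose(x, (2, 0, 1))      # HWC -> CHW
    return x[np.newaxis, ...]           # add the batch axis, like unsqueeze(0)


img = np.full((32, 32, 3), 255, dtype=np.uint8)
img[:, :, 0] = 0  # leave the first channel empty
batch = to_chw_batch(img)
print(batch.shape)  # (1, 3, 32, 32)
```

A shape mismatch at this step is the most common reason a model that works in a notebook fails behind an API, so it is worth checking explicitly.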
Random Shape Generator
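First, write a module (shape_maker.py, the name the later scripts import it by) that randomly generates a square or a circle.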
import random

import cv2
import matplotlib.pyplot as plt
import numpy as np


def imshow(img):
    # img holds values in [0, 255]; matplotlib expects floats in [0, 1]
    fig, ax = plt.subplots()
    fig.set_size_inches(5, 5)
    ax.axis("off")
    ax.imshow(img / 255)
    plt.show()


class Shape:
    def __init__(self):
        # OpenCV colors are in BGR order
        self.colors = [
            (0, 0, 255),      # red
            (0, 255, 0),      # green
            (255, 0, 0),      # blue
            (0, 156, 255),    # orange
            (128, 128, 128),  # gray
            (0, 255, 255),    # yellow
        ]
        self.canvas_size = 100

    def make(self, shape):
        # Black 100x100 canvas with 3 channels; pixel values stay in [0, 255]
        img = np.zeros((self.canvas_size, self.canvas_size, 3), dtype=np.float32)
        color = random.choice(self.colors)
        center = (random.randint(40, 60), random.randint(40, 60))
        object_size = random.randint(10, 40)
        if shape == "retrangle":
            start = (center[0] - object_size, center[1] - object_size)  # top-left corner
            end = (center[0] + object_size, center[1] + object_size)    # bottom-right corner
            cv2.rectangle(img, start, end, color, -1)  # thickness -1 fills the shape
        elif shape == "circle":
            cv2.circle(img, center, object_size, color, -1)
        return img


if __name__ == "__main__":
    # Quick visual check; the guard keeps it from running when the module is imported
    creator = Shape()
    circle = creator.make("circle")
    imshow(circle)
    retrangle = creator.make("retrangle")
    imshow(retrangle)
Running it displays a circle and a square.
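With thickness `-1`, `cv2.rectangle` and `cv2.circle` fill every pixel inside the shape. A dependency-free sketch of the same idea with NumPy boolean masks, handy for checking what the generator produces; the helper names are hypothetical, and OpenCV's edge rasterization may differ by a pixel:

```python
import numpy as np


def draw_square(img, center, half_size, color):
    # Fill the square [cx - s, cx + s] x [cy - s, cy + s]; center is (x, y) as in OpenCV
    cx, cy = center
    x0, x1 = max(cx - half_size, 0), cx + half_size + 1
    y0, y1 = max(cy - half_size, 0), cy + half_size + 1
    img[y0:y1, x0:x1] = color  # rows are y, columns are x
    return img


def draw_circle(img, center, radius, color):
    # Fill every pixel whose distance to the center is at most radius
    h, w = img.shape[:2]
    ys, xs = np.ogrid[:h, :w]
    cx, cy = center
    mask = (xs - cx) ** 2 + (ys - cy) ** 2 <= radius ** 2
    img[mask] = color
    return img


canvas = np.zeros((100, 100, 3), dtype=np.float32)
square = draw_square(canvas.copy(), (50, 50), 20, (0, 0, 255))
circle = draw_circle(canvas.copy(), (50, 50), 20, (0, 0, 255))
# A square covers its corner region; a circle of the same size does not
print(square[31, 31].tolist(), circle[31, 31].tolist())  # [0.0, 0.0, 255.0] [0.0, 0.0, 0.0]
```

That corner difference is exactly the visual cue the classifier has to learn.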
Testing the Model
import torch
import torchvision.transforms as transforms
from PIL import Image

from model import model  # model/model.py defines ShapeDetectNetwork
from shape_maker import Shape

# Load the network structure
net = model.ShapeDetectNetwork()
# Load the trained parameters, mapping them onto the CPU
net.load_state_dict(torch.load("./model/shapeDetect", map_location="cpu"))
net.eval()  # inference mode for layers such as dropout and batch norm

# Create two shapes
creator = Shape()
circle = creator.make("circle")
retrangle = creator.make("retrangle")


# Transform an HxWx3 image array (values in [0, 255]) into a 3x32x32 tensor
def array2tensor(img):
    img = Image.fromarray(img.astype("uint8"))
    trans = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),  # HWC uint8 -> CHW float in [0, 1]
    ])
    return trans(img)


# 2 classes, in the order used by the training set
classes = ["circle", "retrangle"]

with torch.no_grad():  # gradients are not needed for inference
    o1 = net(array2tensor(circle).unsqueeze(0))  # unsqueeze(0) adds the batch dimension
    o2 = net(array2tensor(retrangle).unsqueeze(0))

idx1 = o1.argmax(dim=1).item()  # index of the highest class score
idx2 = o2.argmax(dim=1).item()
print(classes[idx1], classes[idx2])
The output is:
circle retrangle
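The network returns one row of class scores per image, and the predicted class is the index of the highest score in each row (what `torch.max(o, 1)[1]` or `o.argmax(dim=1)` gives). The same lookup in NumPy on a made-up batch of scores; the numbers are illustrative, not real model output:

```python
import numpy as np

classes = ["circle", "retrangle"]

# Hypothetical network output: one row of class scores per image
scores = np.array([
    [2.3, -1.1],   # image 0 scores higher for "circle"
    [-0.4, 1.7],   # image 1 scores higher for "retrangle"
])

idx = scores.argmax(axis=1)  # index of the best class in each row
labels = [classes[i] for i in idx]
print(labels)  # ['circle', 'retrangle']
```

This is why the order of `classes` must match the label order used during training.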
The model works as expected. Next, deploy the model on the server and have the client send images to it.
Client
import requests

from shape_maker import Shape

# Set server address
addr = "http://localhost:5000"
test_url = addr + "/api/test"

# Create 2 shapes
creator = Shape()
circle = creator.make("circle")
retrangle = creator.make("retrangle")

# JSON cannot carry an ndarray, so send each image as nested lists
for img in (retrangle, circle):
    # json= serializes the payload and sets Content-Type: application/json
    response = requests.post(test_url, json={"data": img.tolist()})
    print(response.text)
Server
import json

import numpy as np
import torch
from flask import Flask, Response, request

from model import model
from utilities import array2tensor  # the array2tensor function from the test script

app = Flask(__name__)

# Load the network once at startup instead of on every request
net = model.ShapeDetectNetwork()
net.load_state_dict(torch.load("./model/shapeDetect", map_location="cpu"))
net.eval()

classes = ["circle", "retrangle"]


@app.route("/api/test", methods=["POST"])
def test():
    # Rebuild the image array from the JSON body
    numpy_data = np.asarray(request.get_json()["data"])
    with torch.no_grad():
        o = net(array2tensor(numpy_data).unsqueeze(0))
    shape = classes[o.argmax(dim=1).item()]
    response = {"message": "The shape is {}".format(shape)}
    # Encode the response dict as JSON
    return Response(response=json.dumps(response), status=200,
                    mimetype="application/json")


if __name__ == "__main__":
    # start flask app
    app.run(host="0.0.0.0", port=5000)
Again, start the server first and then run the client. The output is:
{"message": "The shape is retrangle"}
{"message": "The shape is circle"}
It works!
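Sending each image as nested JSON lists, as the client above does, writes every float out as text, which makes the request body several times larger than the raw data. One common alternative, not used in the original code, is to send the raw bytes as base64 together with the shape and dtype. A sketch with hypothetical helper names:

```python
import base64
import json

import numpy as np


def encode_array_b64(arr):
    # Raw bytes as base64 text, plus the metadata needed to rebuild the array
    return json.dumps({
        "data": base64.b64encode(arr.tobytes()).decode("ascii"),
        "dtype": arr.dtype.str,  # e.g. '<f4', which records the byte order too
        "shape": list(arr.shape),
    })


def decode_array_b64(payload):
    obj = json.loads(payload)
    raw = base64.b64decode(obj["data"])
    # frombuffer returns a read-only view; call .copy() if it must be modified
    return np.frombuffer(raw, dtype=obj["dtype"]).reshape(obj["shape"])


img = np.random.rand(100, 100, 3).astype(np.float32) * 255
as_list = json.dumps({"data": img.tolist()})
as_b64 = encode_array_b64(img)
restored = decode_array_b64(as_b64)
print(np.array_equal(img, restored), restored.dtype, len(as_b64) < len(as_list))
```

Unlike the list encoding, this round trip also preserves the dtype exactly. The trade-off is that the payload is no longer human-readable.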
Source code:
https://github.com/nofacer/Flask_Pytorch_Server
This article is licensed under CC-BY; when reposting, please credit the original author (倪旭彬). https://creativecommons.org/licenses/by-sa/3.0/cn/