深度学习模型部署篇——利用Flask实现深度学习模型部署(三)-灵析社区

我不是魔法师

flask模型部署初探

下面我们就来开始介绍了喔,让我们一起来学一下叭。【这回采用步骤式讲解看看效果】

首先我先来梳理一下代码运行的整体流程,这样大家可能会更清晰一点,如下图所示。

大家要注意的是,我们会有两个.py文件,一个用于服务端开启服务,另一个用于客户端发送请求并接受服务端的返回值。客户端的代码较为简单,这里重点说说服务端的代码运行流程。服务端代码中主要有三个函数,分别为predict、get_prediction、transform_image。当我们运行服务端程序时,app.run启动,服务开启,此时会监听客户端是否发送请求,若检测到客户端发送请求,则会进入predict函数处理这个请求,接着predict函数会调用get_prediction函数,而get_prediction函数会调用transform_image函数。

先给大家介绍代码运行流程,大家再看下面的代码应该就比较清晰了,下文将分为服务端和客户端两部分介绍代码。

服务端

  • 创建一个Flask应用:
 app = Flask(__name__)
  • 准备模型和资源
 # 在指定设备上创建 AlexNet 模型
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 model = AlexNet(num_classes=5).to(device)
 # 加载预训练权重
 model.load_state_dict(torch.load(r'E:\模型部署\checkpoint\AlexNet.pth', map_location='cpu'))
 # 设置模型为评估模式
 model.eval()

注意AlexNet.pth模型是我们在上讲中介绍的花的五分类模型。其实模型部署本质上就是模型的测试,所有我们要将模型设置成评估模式。

  • 定义预测函数
 def transform_image(image_bytes):
     my_transforms = transforms.Compose([transforms.Resize(255),
                                         transforms.CenterCrop(227),
                                         transforms.ToTensor(),
                                         transforms.Normalize(
                                             [0.485, 0.456, 0.406],
                                             [0.229, 0.224, 0.225])])
     image = Image.open(io.BytesIO(image_bytes))
     return my_transforms(image).unsqueeze(0)

大家还是要注意一下这里,我们将图像resize到227×227大小,这是因为AlexNet的输入要求,不清楚的可以看下这篇博客:深度学习经典网络模型汇总1——LeNet、AlexNet、ZFNet

还要注意这里最后使用unsqueeze(0)方法添加了一个batch维度信息。

  • 定义预测函数
 def get_prediction(image_bytes):
     # 记录该帧开始处理的时间
     start_time = time.time()
 ​
     # 转换图像数据为模型输入格式
     tensor = transform_image(image_bytes=image_bytes)
 ​
     # 通过模型进行前向传播
     outputs = model.forward(tensor)
 ​
     # 对模型输出进行 softmax 操作
     pred_softmax = F.softmax(outputs, dim=1)
 ​
     # 获取前N个预测结果
     top_n = torch.topk(pred_softmax, 5)
     pred_ids = top_n.indices[0].cpu().detach().numpy()  # 转换为NumPy数组
     confs = top_n.values[0].cpu().detach().numpy() * 100  # 转换为NumPy数组,并转换为百分比
 ​
     # 记录该帧处理完毕的时间
     end_time = time.time()
     # 计算每秒处理图像帧数FPS
     FPS = 1 / (end_time - start_time)
 ​
     # 载入类别和对应 ID
     idx_to_labels = np.load('idx_to_labels1.npy', allow_pickle=True).item()
 ​
     results = []  # 用于存储结果的列表
     for i in range(5):
         class_name = idx_to_labels[pred_ids[i]]  # 获取类别名称
         confidence = confs[i]  # 获取置信度
         text = '{:<6} {:>.3f}'.format(class_name, confidence)
         results.append(text)  # 将结果添加到列表中
 ​
     return results, FPS  # 返回包含类别和置信度的列表

这步其实和我上一讲的内容是差不多的,这个函数主要是对图像进行推理,并输出推理的结果和推理时间。

  • 定义接收上传图片并预测的路由
 @app.route('/predict', methods=['POST'])
 def predict():
     if request.method == 'POST':
         file = request.files['file']
         img_bytes = file.read()
         class_info, FPS = get_prediction(image_bytes=img_bytes)
         response_data = {'class_info': class_info, 'FPS': FPS}
         return jsonify(response_data)

解释一下这个@app.route('/predict', methods=['POST'])叭,它是一个 Flask 路由装饰器,它告诉 Flask 在接收到 /predict 路径上的 POST 请求时,会调用下面定义的predict()函数来处理这个请求。

  • 启动 Flask 应用
 if __name__ == '__main__':
     app.run()

app.run()是 Flask 应用的运行函数,它启动了一个本地的开发服务器,用于监听来自客户端的请求并响应。

客户端

  • 发送请求
 # 发送 POST 请求到 Flask 服务器
 resp = requests.post("http://localhost:5000/predict",
                      files={"file": open('flower.jpg', 'rb')})
  • 处理服务端返回结果
 # 检查服务器响应状态码
 if resp.status_code == 200:  # 如果响应状态码为 200 表示成功
     # 从响应中提取 JSON 数据
     response_data = resp.json()
     class_info = response_data['class_info']  # 提取预测结果信息
     fps = response_data['FPS']  # 提取处理帧数信息
 ​
     # 输出预测的类别信息
     for info in class_info:
         print(info)
 ​
     # 输出处理帧数信息
     print("FPS:", fps)
 else:  # 如果响应状态码不是 200,则表示出现了错误
     print("Error:", resp.text)  # 输出错误信息

运行结果

首先我们要运行服务端的程序test_alexnet.py开启服务,可以通过anaconda终端执行,如下:

接着我们可以新开一个终端执行客户端程序sent_post.py发送请求,或者直接在pycharm上执行程序,如下:

从上图可以看出郁金香的识别率达到了99.964,哦,忘记给大家看我测试的图片了,是这张喔:

我们也可以发现FPS为15.6,但是我们一般不取第一次的结果,因为会进行初始化等操作,影响速度,我们在运行几次看看FPS结果。

后面几次FPS大概稳定在20-21左右。

通过ONNX加速模型部署

在上两讲我们介绍了通过ONNX加速模型部署,那么这里我们自然也要试一试,看看速度有没有加速腻。

那么其实这一部分的代码和上一小节是非常类似的,我将服务端代码写在了test_alexnet_onnx.py中,客户端代码没有改变,仍然是sent_post.py。我把主要修改的地方说明一下,其它一些细节大家可以自己去github下载源码查看。

  • 加载ONNX模型
 # 加载模型
 model = AlexNet(num_classes=5).to(device)
 def load_onnx_model():
     global ort_session
     ort_session = onnxruntime.InferenceSession(r'E:\模型部署\Alex_flower5.onnx')
 ​
 # 在应用启动时加载 ONNX 模型
 load_onnx_model()

这里加载的是ONNX模型,至于如何得到ONNX模型可以看我上一讲的内容。

  • ONNX推理引擎推理
 ort_inputs = {'input': tensor.numpy()}
 pred_logits = ort_session.run(['output'], ort_inputs)[0]


剩下的基本都差不多了,我们直接来看看运行的效果叭。

首先开启服务,等待请求,如下:

然后运行客户端代码,发送请求,获得结果:

可以看到预测精度和之前使用pytorch预测时是一致的,但FPS提高到了26.8。当然了,同样的道理,这是第一次运行,FPS会相对较低,我们再运行几次,如下:

可以发现,现在的FPS可以基本稳定在32左右,是不是比之前快了不少呢,大家快去试试叭。

加点佐料

不知道大家发现没有,上面的功能算是实现了通过Flask部署深度学习模型,但是总感觉差点意思,于是准备结合前端来搭建一个稍微好看的界面,通过点击前端的按钮来发送请求。

说干就干,但是好像干不动,因为自己不会前端呀,但是又问题不大,因为我会百度呀,直接找一个前端的代码就好了嘛,于是找到了霹雳吧啦Wz大佬滴代码,对其稍微改进了一下,使用了ONNX进行模型推理 ,并在前端输出FPS信息,代码为main_html_test.py。同样的,一些细节大家详细移步源码查看

我们先来看一看实现的效果叭,然后我再来解释一下如何实现的,效果如下:

enmmmm,开始准备展示动态图的,但是运行录屏工具后,预测的FPS就下降了,所以大家还是看看图片叭。

首先我们运行main_html_test.py程序,会得到如下结果:

点击上图中的链接进入前端界面:

然后点击选择文件,再点击预测,即可显示预测结果和FPS,如下:



上面就是最终的效果啦,最后我来稍微解释代码的整个流程,如下:

  1. 用户在前端界面上选择一个图像文件。
  2. 用户点击预测按钮,触发 test() 函数。
  3. test() 函数使用 AJAX 将图像文件发送到后端的 /predict 路由。
  4. 后端接收到请求,调用 predict() 函数进行图像预测。
  5. predict() 函数返回预测结果和FPS信息,发送回前端。
  6. 前端接收到后端返回的数据,将预测结果和FPS信息展示在页面上。

关于test()函数的内容如下:

 function test() {
     // 获取选择的文件对象
     var fileobj = $("#file0")[0].files[0];
     console.log(fileobj);
     
     // 创建一个 FormData 对象,用于将文件对象传递到后端
     var form = new FormData();
     form.append("file", fileobj);
     
     // 初始化变量用于存储分类结果和FPS信息
     var flower='';
     var fps = '';
 ​
     // 发送AJAX请求到后端的predict路由
     $.ajax({
         type: 'POST',
         url: "predict",
         data: form,
         async: false,
         processData: false,
         contentType: false,
         success: function (data) {
             console.log(data);
             
             // 从返回的数据中获取分类结果和FPS信息
             var results = data.class_info;
             fps = data.FPS;
             
             console.log(results);
             console.log(fps);
             
             // 生成分类结果的HTML字符串
             var flowerResult = '';
             results.forEach(e => {
                 flowerResult += `<div style="border-bottom: 1px solid #CCCCCC;line-height: 60px;font-size:16px;">${e}</div>`;
             });
 ​
             // 生成FPS信息的HTML字符串
             var fpsResult = `<div style="border-top: 1px solid #CCCCCC;line-height: 60px;font-size:16px;">FPS: ${fps.toFixed(2)}</div>`;
 ​
             // 将生成的分类结果和FPS信息插入到页面元素中
             document.getElementById("out").innerHTML = flowerResult + fpsResult;
         },
         error: function () {
             console.log("后台处理错误");
         }
     });
 }


阅读量:1563

点赞量:0

收藏量:0