前言

本文只是分享思路,不提供可完整运行的项目代码

onnx部署

以目标检测类模型为例,该类模型会输出类别信息,置信度,包含检测框的4个坐标信息

但不是所有的onnx模型都能在微信小程序部署,有些算子不支持,这种情况需要点特殊操作。

微信小程序提供的接口相当于使用onnxruntime的接口运行onnx模型,我们要做的就是将视频帧数据(包含RGBA的一维像素数组)转换成对应形状的数组(比如3*224*224的一维Float32Array),然后调用接口并将图像输入得到运行的结果(比如一个1*10*6的一维Float32Array,代表着10个预测框的类别,置信度和框的4个坐标),然后将结果处理(比如行人检测,给置信度设置一个阈值0.5,筛选置信度大于阈值的数组的index,然后按照index取出相应的类别和框坐标),最后在wxml中显示类别名或置信度或在canvas绘制框。

代码框架

这里采用的是实时帧数据,按预设频率调用一帧数据并后处理得到结果

初始化session

首先得将onnx上传至云端,获得一个存储路径(比如cloud://cloud1-8gcwcxqrb8722e9e.636c-cloud1-8gcwcxqrb8722e9e-1324077753/rtdetrWorker.onnx)

当用户首次使用该小程序时,手机里没有onnx模型的存储,需要从云端下载;而已经非第一次使用该小程序的用户手机里已经保存了之前下载的onnx模型,就无需下载。所以此处代码逻辑是需要检测用户的存储里是否有该onnx模型,不存在就下载,下载完并保存模型文件后就执行下一步;存在就直接执行下一步。

InitSession()

{

return new Promise(resolve=>{

const cloudPath = 'cloud://cloud1-8gcwcxqrb8722e9e.636c-cloud1-8gcwcxqrb8722e9e-1324077753/mobilnet.onnx'

const lastindex=cloudPath.lastIndexOf('/')

const filename=cloudPath.substring(lastindex+1)

const modelPath = `${wx.env.USER_DATA_PATH}/`+filename;

// 判断之前是否已经下载过onnx模型

wx.getFileSystemManager().access({

path: modelPath,

success: (res) =>

{

console.log("文件已经存在")

// 创建session

this.createInferenceSession(modelPath)

// 监听帧,频率为1秒1次

setInterval(this.oneFrame, 1000)

resolve()

},

fail: (res) => {

// 文件不存在

console.error(res)

wx.cloud.init();

console.log("开始下载模型");

// 调用自定义函数下载文件

this.downloadFile(cloudPath, function(r) {

console.log(`下载进度:${r.progress}%,已下载${r.totalBytesWritten}B,共${r.totalBytesExpectedToWrite}B`)

}).then(result => {

// 保存模型到本地

wx.getFileSystemManager().saveFile({

tempFilePath:result.tempFilePath,

filePath: modelPath,

success: (res) => { // 注册回调函数

console.log(res)

const modelPath = res.savedFilePath;

console.log("保存模型到路径: " + modelPath)

// 创建session

this.createInferenceSession(modelPath)

// 监听帧,频率为1秒1次

setInterval(this.oneFrame, 1000)

resolve()

},

fail(res) {

console.error(res)

}

})

});

}

})

})

自定义的下载文件函数

downloadFile(fileID, onCall = () => {}) {

return new Promise((resolve, reject) => {

const task = wx.cloud.downloadFile({

fileID,

success: res => resolve(res),

})

task.onProgressUpdate((res) => {

if (onCall(res) == false) {

task.abort()

}

})

})

},

自定义创建session的函数

createInferenceSession(modelPath) {

return new Promise((resolve, reject) => {

this.session = wx.createInferenceSession({

model: modelPath,

precisionLevel : 4,

allowNPU : false,

allowQuantize: false,

});

// 监听error事件

this.session.onError((error) => {

console.error(error);

reject(error);

});

this.session.onLoad(() => {

resolve();

});

})

},

自定义处理帧函数

就是上面初始化session步骤里面 创建session后 按预设频率执行的函数

开启相机监听,在回调函数内获取帧数据、处理帧数据、开始推理、关闭监听

oneFrame(){

const context=wx.createCameraContext()

const camCallback=(frame)=>{

// 处理图片数据

var dstInput=new Float32Array(this.data.imageChannel*this.data.imageWidth*this.data.imageHeight)

this.preProcess(frame,dstInput)

// 推理得到结果

this.infer(dstInput)

// 关闭监听

listener.stop()

}

const listener=context.onCameraFrame(camCallback)

listener.start()

},

自定义的图像处理函数

该函数接收帧数据(RGBA一维数组)和在外面初始化的Float32Array数组,执行归一化、去除透明度通道。

preProcess(frame, dstInput) {

return new Promise((resolve, reject) =>

{

const origData = new Uint8Array(frame.data);

const hRatio = frame.height / this.data.imageHeight;

const wRatio = frame.width / this.data.imageWidth;

const origHStride = frame.width * 4;

const origWStride = 4;

const mean = [0.485, 0.456, 0.406]

// Reverse of std = [0.229, 0.224, 0.225]

const reverse_div = [4.367, 4.464, 4.444]

const ratio = 1 / 255.0

const normalized_div = [ratio / reverse_div[0], ratio * reverse_div[1], ratio * reverse_div[2]];

const normalized_mean = [mean[0] * reverse_div[0], mean[1] * reverse_div[1], mean[2] * reverse_div[2]];

var idx = 0;

for (var c = 0; c < this.data.imageChannel; ++c)

{

for (var h = 0; h < this.data.imageHeight; ++h)

{

const origH = Math.round(h * hRatio);

const origHOffset = origH * origHStride;

for (var w = 0; w < this.data.imageWidth; ++w)

{

const origW = Math.round(w * wRatio);

const origIndex = origHOffset + origW * origWStride + c;

const val = origData[origIndex] * (normalized_div[c]) - normalized_mean[c];

dstInput[idx] = val;

idx++;

}

}

}

resolve();

});

},

自定义的推理函数

推理接口接收数个键值对input,具体需要参照自己的onnx模型,在Netron查看相应的模型信息

我这里只有1个输入,对应的名字为"images",接收(1,3,300,300)性质的图像数组

我这里有2个输出,对应的名字是“794”和“output”,分别对应相应类别的置信度(1*10*2)&框的坐标信息(1*10*4),这里的10对应10个预测框,2代表有2个类别

接着就是获取某一类别(比如前景)最大置信度的索引并取出其框的信息

然后绘制在canvas上

当然也可以设置阈值比如0.5,前景类别置信度大于0.5的就保留,然后根据得到的index取出框的信息,绘制到canvas上,或者只取类别和对应的置信度,根据自己的需求处理

infer(imgData){

this.session.run({

"images":{

shape: [1, this.data.imageChannel, this.data.imageHeight, this.data.imageWidth],

data: imgData.buffer,

type: 'float32',

}

}).then((res)=>{

let box = new Float32Array(res.output.data)

let score = new Float32Array(res[794].data)

// console.log(box)

let num = new Float32Array(score)

var maxVar = num[0];

var index = 0;

for (var i = 0; i < num.length; i+=2)

{

if (maxVar < num[i])

{

maxVar = num[i]

index = i/2

}

}

this.setData({

xmin:box[index*4],

xmax:box[index*4+2],

ymin:box[index*4+1],

ymax:box[index*4+3]

})

this.drawRectangle()

})

},

自定义的绘制框函数

这里用的是微信新的canvas接口

drawRectangle(){

wx.createSelectorQuery().select('#myCanvas')

.fields({node:true,size:true})

.exec((res)=>{

const canvas=res[0].node

const ctx=canvas.getContext('2d')

const dpr = wx.getSystemInfoSync().pixelRatio

canvas.width = res[0].width * dpr

canvas.height = res[0].height * dpr

ctx.scale(dpr, dpr)

ctx.strokeStyle='red'

ctx.lineWidth=2

console.log(this.data.xmin, this.data.ymin, this.data.xmax, this.data.ymax)

ctx.strokeRect(this.data.xmin, this.data.ymin, this.data.xmax, this.data.ymax,canvas.width,canvas.height)

})

}

代码总览

index.js

Page({

session:null,

data: {

src : '',

windowWidth:0,

imageWidth : 300,

imageHeight : 300,

imageChannel : 3,

xmin:0,

ymin:0,

xmax:0,

ymax:0

},

onLoad(){

this.setData({

windowWidth:wx.getSystemInfoSync().windowWidth*0.9

})

this.InitSession()

},

oneFrame(){

const context=wx.createCameraContext()

const camCallback=(frame)=>{

// 处理图片数据

var dstInput=new Float32Array(this.data.imageChannel*this.data.imageWidth*this.data.imageHeight)

this.preProcess(frame,dstInput)

// 推理得到结果

this.infer(dstInput)

// 关闭监听

listener.stop()

}

const listener=context.onCameraFrame(camCallback)

listener.start()

},

downloadFile(fileID, onCall = () => {}) {

return new Promise((resolve, reject) => {

const task = wx.cloud.downloadFile({

fileID,

success: res => resolve(res),

})

task.onProgressUpdate((res) => {

if (onCall(res) == false) {

task.abort()

}

})

})

},

preProcess(frame, dstInput) {

return new Promise((resolve, reject) =>

{

const origData = new Uint8Array(frame.data);

const hRatio = frame.height / this.data.imageHeight;

const wRatio = frame.width / this.data.imageWidth;

const origHStride = frame.width * 4;

const origWStride = 4;

const mean = [0.485, 0.456, 0.406]

// Reverse of std = [0.229, 0.224, 0.225]

const reverse_div = [4.367, 4.464, 4.444]

const ratio = 1 / 255.0

const normalized_div = [ratio / reverse_div[0], ratio * reverse_div[1], ratio * reverse_div[2]];

const normalized_mean = [mean[0] * reverse_div[0], mean[1] * reverse_div[1], mean[2] * reverse_div[2]];

var idx = 0;

for (var c = 0; c < this.data.imageChannel; ++c)

{

for (var h = 0; h < this.data.imageHeight; ++h)

{

const origH = Math.round(h * hRatio);

const origHOffset = origH * origHStride;

for (var w = 0; w < this.data.imageWidth; ++w)

{

const origW = Math.round(w * wRatio);

const origIndex = origHOffset + origW * origWStride + c;

const val = origData[origIndex] * (normalized_div[c]) - normalized_mean[c];

dstInput[idx] = val;

idx++;

}

}

}

resolve();

});

},

infer(imgData){

this.session.run({

"images":{

shape: [1, this.data.imageChannel, this.data.imageHeight, this.data.imageWidth],

data: imgData.buffer,

type: 'float32',

}

}).then((res)=>{

let box = new Float32Array(res.output.data)

let score = new Float32Array(res[794].data)

// console.log(box)

let num = new Float32Array(score)

var maxVar = num[0];

var index = 0;

for (var i = 0; i < num.length; i+=2)

{

if (maxVar < num[i])

{

maxVar = num[i]

index = i/2

}

}

this.setData({

xmin:box[index*4],

xmax:box[index*4+2],

ymin:box[index*4+1],

ymax:box[index*4+3]

})

this.drawRectangle()

})

},

InitSession()

{

return new Promise(resolve=>{

const cloudPath = 'cloud://cloud1-8gcwcxqrb8722e9e.636c-cloud1-8gcwcxqrb8722e9e-1324077753/mobilnet.onnx'

const lastindex=cloudPath.lastIndexOf('/')

const filename=cloudPath.substring(lastindex+1)

const modelPath = `${wx.env.USER_DATA_PATH}/`+filename;

// 判断之前是否已经下载过onnx模型

wx.getFileSystemManager().access({

path: modelPath,

success: (res) =>

{

console.log("file already exist at: " + modelPath)

this.createInferenceSession(modelPath)

setInterval(this.oneFrame, 1000)

resolve()

},

fail: (res) => {

console.error(res)

wx.cloud.init();

console.log("begin download model");

this.downloadFile(cloudPath, function(r) {

console.log(`下载进度:${r.progress}%,已下载${r.totalBytesWritten}B,共${r.totalBytesExpectedToWrite}B`)

}).then(result => {

wx.getFileSystemManager().saveFile({

tempFilePath:result.tempFilePath,

filePath: modelPath,

success: (res) => { // 注册回调函数

console.log(res)

const modelPath = res.savedFilePath;

console.log("save onnx model at path: " + modelPath)

this.createInferenceSession(modelPath)

setInterval(this.oneFrame, 1000)

resolve()

},

fail(res) {

console.error(res)

}

})

});

}

})

})

},

createInferenceSession(modelPath) {

return new Promise((resolve, reject) => {

this.session = wx.createInferenceSession({

model: modelPath,

precisionLevel : 4,

allowNPU : false,

allowQuantize: false,

});

// 监听error事件

this.session.onError((error) => {

console.error(error);

reject(error);

});

this.session.onLoad(() => {

resolve();

});

})

},

drawRectangle(){

wx.createSelectorQuery().select('#myCanvas')

.fields({node:true,size:true})

.exec((res)=>{

const canvas=res[0].node

const ctx=canvas.getContext('2d')

const dpr = wx.getSystemInfoSync().pixelRatio

canvas.width = res[0].width * dpr

canvas.height = res[0].height * dpr

ctx.scale(dpr, dpr)

ctx.strokeStyle='red'

ctx.lineWidth=2

console.log(this.data.xmin, this.data.ymin, this.data.xmax, this.data.ymax)

ctx.strokeRect(this.data.xmin, this.data.ymin, this.data.xmax, this.data.ymax,canvas.width,canvas.height)

})

}

})

 index.wxss

.c1{

width: 100%;

align-items: center;

text-align: center;

display: flex;

flex-direction: column;

}

.camera{

width: 100%;

}

#myCanvas{

width: 100%;

height: 100%;

}

index.wxml

flask部署

微信小程序负责把图像数据或帧数据传到服务器,在服务器用falsk搭建相关模型运行环境,将接收到的图像数据或帧数据预处理后输入模型里,在将结果返回给微信小程序,微信小程序再显示结果。

我这里给的例子是传送帧数据的,也就是实时检测。

前端

在前端,获得帧数据后,因为帧数据的格式是一维RGBA数组,为了将其转成png,方便服务器处理,把帧数据绘制到画布上,再导出为png送入服务器。接收到服务器的结果后,将检测框绘制到相机的界面,需要在标签里加上标签,然后画上矩形框,并在下方显示分类结果。

主体代码框架

Page({

data: {

windowWidth:wx.getSystemInfoSync().windowWidth*1.33,

boxNum:'',

},

// 自定义实时检测的频率,这里是800ms检测一次

// http://t.csdnimg.cn/rLLLw 具体见此地址

onLoad(){

setInterval(this.oneProcessFrame, 800);

},

})

oneProcessFrame(){

const context = wx.createCameraContext();

const data={"pngData":null}

const CamFramCall = (frame)=>{

// 调整显示页面的相机画面,为了使显示页面的横宽比等于frame数据的横宽比

// 在画框的时候,模型跑出来的检测框坐标是相对于输入的图像的大小

// 如果显示画面和输入框的比例不匹配,就会出现检测框不完整或者检测框有部分跑到画面外的情况

// 微信小程序的frame,我没有找到官方提供的可以修改尺寸的API,所以用了这个办法

//当然还有一种思路,将frame进行裁剪,使frame包含的图片信息正好对应显示画面的信息(像素一一对应)

this.setData({

windowWidth:frame.height/frame.width*wx.getSystemInfoSync().windowWidth*0.9

})

// 调用自定义函数将frame转png,然后把png数据绑定到传送给服务器的data

// 再将data传给服务器

// 这里用了异步编程,只有帧数据顺利转成png才发送给服务器,确保模型接收正确数据

this.base64ToPNG(frame).then((pngData)=>{

data["pngData"]=pngData

this.interWithServer(data)

})

// 这里已经处理完一帧的数据,如果不关闭监听相机,那么微信小程序会持续触发相机帧数据回调函数,导致小程序卡顿,资源浪费

console.log('完成一次帧循环')

listener.stop()

}

// 定义相机帧回调函数

const listener = context.onCameraFrame(CamFramCall);

开启监听

listener.start()

},

自定义帧数据转base64的函数

参考http://t.csdnimg.cn/2hc7k

这里增加了异步编程的语句,更合理

base64ToPNG(frame){

return new Promise(resolve=>{

const query = wx.createSelectorQuery()

query.select('#canvas')

.fields({node:true,size:true})

.exec((res)=>{

const canvas=res[0].node

const ctx=canvas.getContext('2d')

canvas.width=frame.width

canvas.height=frame.height

var imageData=ctx.createImageData(canvas.width,canvas.height)

var ImgU8Array = new Uint8ClampedArray(frame.data);

for(var i=0;i

imageData.data[0+i]=ImgU8Array[i+0]

imageData.data[1+i]=ImgU8Array[i+1]

imageData.data[2+i]=ImgU8Array[i+2]

imageData.data[3+i]=ImgU8Array[i+3]

}

ctx.putImageData(imageData,0,0,0,0,canvas.width,canvas.height)

resolve(canvas.toDataURL())

})

})

},

自定义传数据到服务器函数 

interWithServer(data){

const header = {

'content-type': 'application/x-www-form-urlencoded'

};

wx.request({

// 填上自己的服务器地址(下面这个是我的服务器内网地址,仅供展示)

url: 'http://172.16.3.186:5000/predict',

method: 'POST',

header: header,

data: data,

success: (res) => {

console.log(res.data['xmin'],res.data['ymin'],res.data['xmax'],res.data['ymax'])

// 调用自定义的画框函数

this.drawRect(res.data['xmin'],res.data['ymin'],res.data['xmax'],res.data['ymax'])

},

fail: () => {

wx.showToast({

title: 'Failed to process frame!',

icon: 'none',

});

// 如果与服务器交互失败,清空画布

ctx.clearRect(0,0,canvas.width,canvas.height)

}

});

},

自定义的画检测框函数 

drawRect(x1,y1,x2,y2){

wx.createSelectorQuery().select('#myCanvas')

.fields({node:true,size:true})

.exec((res)=>{

const canvas=res[0].node

const ctx=canvas.getContext('2d')

canvas.width=wx.getSystemInfoSync().windowWidth*0.9

canvas.height=this.data.windowWidth

ctx.clearRect(0,0,canvas.width,canvas.height)

ctx.strokeStyle='red'

ctx.lineWidth=2

ctx.strokeRect(x1,y1,x2,y2)

})

},

index.js

Page({

data: {

windowWidth:wx.getSystemInfoSync().windowWidth*1.33,

boxNum:'',

},

onLoad(){

setInterval(this.oneProcessFrame, 800);

},

oneProcessFrame(){

const context = wx.createCameraContext();

const data={"pngData":null}

const CamFramCall = (frame)=>{

this.setData({

windowWidth:frame.height/frame.width*wx.getSystemInfoSync().windowWidth*0.9

})

this.base64ToPNG(frame).then((pngData)=>{

data["pngData"]=pngData

this.interWithServer(data)

})

console.log('完成一次帧循环')

listener.stop()

}

const listener = context.onCameraFrame(CamFramCall);

listener.start()

},

base64ToPNG(frame){

return new Promise(resolve=>{

const query = wx.createSelectorQuery()

query.select('#canvas')

.fields({node:true,size:true})

.exec((res)=>{

const canvas=res[0].node

const ctx=canvas.getContext('2d')

canvas.width=frame.width

canvas.height=frame.height

var imageData=ctx.createImageData(canvas.width,canvas.height)

var ImgU8Array = new Uint8ClampedArray(frame.data);

for(var i=0;i

imageData.data[0+i]=ImgU8Array[i+0]

imageData.data[1+i]=ImgU8Array[i+1]

imageData.data[2+i]=ImgU8Array[i+2]

imageData.data[3+i]=ImgU8Array[i+3]

}

ctx.putImageData(imageData,0,0,0,0,canvas.width,canvas.height)

resolve(canvas.toDataURL())

})

})

},

drawRect(x1,y1,x2,y2){

wx.createSelectorQuery().select('#myCanvas')

.fields({node:true,size:true})

.exec((res)=>{

const canvas=res[0].node

const ctx=canvas.getContext('2d')

canvas.width=wx.getSystemInfoSync().windowWidth*0.9

canvas.height=this.data.windowWidth

ctx.clearRect(0,0,canvas.width,canvas.height)

ctx.strokeStyle='red'

ctx.lineWidth=2

ctx.strokeRect(x1,y1,x2,y2)

})

},

interWithServer(data){

const header = {

'content-type': 'application/x-www-form-urlencoded'

};

wx.request({

url: 'http://172.16.3.186:5000/predict',

method: 'POST',

header: header,

data: data,

success: (res) => {

console.log(res.data['xmin'],res.data['ymin'],res.data['xmax'],res.data['ymax'])

this.drawRect(res.data['xmin'],res.data['ymin'],res.data['xmax'],res.data['ymax'])

},

fail: () => {

wx.showToast({

title: 'Failed to process frame!',

icon: 'none',

});

ctx.clearRect(0,0,canvas.width,canvas.height)

}

});

},

onUnload(){

}

})

 index.wxml

类别:{{className}}

数量:{{boxNum}}

index.wxss

.c1{

width: 100%;

align-items: center;

text-align: center;

display: flex;

flex-direction: column;

}

.camera{

width: 100%;

}

#myCanvas{

width: 100%;

height: 100%;

}

#canvas{

width: 100%;

}

 

后端

接收数据,预处理图像,送入模型,得到初始结果,转化初始结果得到最终结果,返回数据到前端

这里仅作演示,不提供完整项目运行代码和依赖项

from deploy.infer import Detector

from PIL import Image

import cv2

import numpy as np

import io

from gevent import monkey

import base64

from flask import Flask, jsonify, request

from gevent.pywsgi import WSGIServer

monkey.patch_all()

app = Flask(__name__)

model_dir = "inferer2 fewshot\infer" # 模型路径

save_path = "output" # 推理结果保存路径

# 推理参数设置

detector = Detector(

model_dir,

device='CPU',

run_mode='paddle',

trt_min_shape=1,

trt_max_shape=1280,

trt_opt_shape=640,

trt_calib_mode=False,

cpu_threads=1,

enable_mkldnn=False,

enable_mkldnn_bfloat16=False,

output_dir=save_path,

threshold=0.1)

// 推理函数,接收预处理后的数据,返回最终结果

def infer_start(img, threshold=0.2):

results = detector.predict_image([img[:, :, ::-1]], visual=False)

np_boxes=results['boxes']

expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)

np_boxes = np_boxes[expect_boxes, :]

if len(np_boxes)>0:

for dt in np_boxes:

clsid, bbox, score = int(dt[0]), dt[2:], dt[1]

xmin, ymin, xmax, ymax = bbox

print('class_id:{:d}, confidence:{:.4f}, left_top:[{:.2f},{:.2f}],'

'right_bottom:[{:.2f},{:.2f}]'.format(

int(clsid), score, xmin, ymin, xmax, ymax))

return jsonify({"class_name":"行人","prob":float(score),"xmin":int(xmin),"ymin":int(ymin),"xmax":int(xmax),"ymax":int(ymax)})

else:

return jsonify({"class_name":"未检测到红火蚁","prob":0,"xmin":0,"ymin":0,"xmax":0,"ymax":0})

// 交互主函数

@app.route('/predict', methods=['POST'])

def predict():

if request.method == 'POST':

// 得到png数据,进行预处理

img_base64 = request.form.get('frameData')

if img_base64!='':

img_base64 = img_base64.replace("data:image/png;base64,", "")

img_base64 = base64.b64decode(img_base64)

img = Image.open(io.BytesIO(img_base64))

img=img.convert('RGB')

img=np.array(img)

// 调用推理函数并将结果返回

return infer_start(img)

else:

return "数据为空"

if __name__ == '__main__':

server = WSGIServer(('0.0.0.0', 5000), app)

server.serve_forever()

相关文章

评论可见,请评论后查看内容,谢谢!!!
 您阅读本篇文章共花了: