tensorflow.js官方特别支持了微信小程序,看tfjs-core,fjs-backend-webgl等等模块的dist下都一个单独的miniprogram目录。
还特别提供了一个微信小程序的插件:https://github.com/tensorflow/tfjs-wechat
tensorflow.js对微信小游戏还不支持,但是可以类似的实现,但是效率不高;
小游戏不支持插件,直接使用tfjs-webchat源码,我名称改为tfjs-plugin。
试试一个头部姿势,左右,点头控制:
1、fetch等的适配
基本的引用:
let sysInfo=wx.getSystemInfoSync();
const fetchWechat = require('fetch-wechat');
//window.fetch=fetchWechat.fetchFunc();
const tf_core = require('@tensorflow/tfjs-core');
let tf_poseNet=require('@tensorflow-models/posenet');
const tf_webgl = require('@tensorflow/tfjs-backend-webgl');
const tf_plugin = require('./js/libs/@tensorflow/tfjs-plugin/index.js');
//不能用主屏幕,组屏幕是2d的
let tf_canvas=wx.createCanvas();
///必须指定webGL 1.0版本,微信小游戏只实现了这个版本,不支持2.0
tf_core.ENV.flagRegistry.WEBGL_VERSION.evaluationFn = function() {return 1;}
//tf_core.ENV.set('WEBGL_PACK', false);
tf_plugin.configPlugin({
//backendName:'wechat-webgl',
fetchFunc: fetchWechat.fetchFunc(),
tf:tf_core,
webgl:tf_webgl,
canvas: tf_canvas
},false);
2、模型加载
tf_poseNet加载模型有5种方法:
(1)从官方地址
tf_poseNet.load()不带参数时,是从storage.googleapis.com读取的模型:
- https://storage.googleapis.com/tfjs-models/savedmodel/posenet/mobilenet/float/075/model-stride16.json
- https://storage.googleapis.com/tfjs-models/savedmodel/posenet/mobilenet/float/075/group1-shard1of2.bin
- https://storage.googleapis.com/tfjs-models/savedmodel/posenet/mobilenet/float/075/group1-shard2of2.bin
(2)从别的服务器或者你自己的服务器
let cfg={
architecture: 'MobileNetV1',
outputStride: 16,//越大速度越快
inputResolution: 193,//越小速度越快
multiplier: 0.5//越小越快
}
//cfg.modelUrl='https://ai.flypot.cn/models/posenet/model.json';
cfg.modelUrl='https://cnpmcore.oss-accelerate.aliyuncs.com/binaries/tfjs-models/savedmodel/posenet/mobilenet/float/050/model-stride16.json';
tf_poseNet.load(cfg)
(3)先尝试从本地缓存加载,失败时从服务器加载,加载成功后保存到本地缓存:
const POSENET_URL = 'https://ai.flypot.cn/models/posenet/model.json';
const FILE_STORAGE_PATH='poseNet';
const fsm=wx.getFileSystemManager();
const fileStorageHandler = tf_plugin.fileStorageIO(
FILE_STORAGE_PATH, fsm);
let g_model=null;
function loadModel(options){
if (g_model){
if (options.success) options.success(g_model);
return;
}
let cfg={
architecture: 'MobileNetV1',
outputStride: 16,//越大速度越快
inputResolution: 193,//越小速度越快
multiplier: 0.5//越小越快
}
function loadFromLocal(){
console.log('load model from local................');
wx.showLoading({
title: '从本地加载模型...',
mask:true
});
//https://github.com/tensorflow/tfjs-models/tree/master/posenet
cfg.modelUrl=fileStorageHandler;
/**
* tf_poseNet.load()可以不传参数,使用内置模型。
* 如果传cfg参数,会读取3个文件:
* (1)posenet_info.json
* (2)posenet_model_without_weight.json
* (3)posenet_weight_data
*/
//tf_poseNet.load(cfg)
tf_poseNet.load()
.then(function(model){
console.log('model loaded');
wx.hideLoading();
//console.log(model);
g_model = model;
if (options.success){
options.success(model);
}
},function(err){
console.log(err);
loadFromServer();
});
}
function loadFromServer(){
console.log('load model from server................');
wx.showLoading({
title: '从服务器加载模型...',
mask:true
});
cfg.modelUrl=POSENET_URL;
tf_poseNet.load(cfg)
.then(function(model){
console.log('model loaded');
wx.hideLoading();
//console.log(model);
model.baseModel.model.save(fileStorageHandler);
g_model = model;
if (options.success){
options.success(model);
}
},function(err){
console.log(err);
wx.hideLoading();
if (options.fail){
options.fail(err);
}
});
}
loadFromLocal();
}
注意保存的的模型不是原始文件了,变成了另外3个文件:
- posenet_info.json
- posenet_model_without_weight.json
- posenet_weight_data
(4)只从本地加载
在windows上调试时,前面第(3)种方法保存到本地的模型文件,不在源码目录中,你可以用everything等工具搜索一下posenet_model_without_weight.json文件,可能就找到了。可能在:C:\Users\musta\AppData\Local\微信开发者工具\User Data\8bd760e6f7c30cca133c1a584f36db58\WeappSimulator\WeappFileSystem\o6zAJs8Lo1aWsQT2veP8wlUxz6kg\wx283e6dfb50d72a13\usr\tensorflowjs_models
把posenet_weight_data改为posenet_weight_data.bin。因为微信小游戏不认上传的没有扩展名的文件。
- posenet_info.json
- posenet_model_without_weight.json
- posenet_weight_data.bin
把找到的3个文件放在代码目录asset/tensorflow下,共2.24M:
cfg.modelUrl=tf_plugin.fileStorageIO(
'/assets/tensorflow/posenet', fsm);
tf_poseNet.load(cfg)
但是会报错:
Error: readFile:fail no such file or directory http://usr/tensorflowjs_models//assets/tensorflow/posenet_info.json
需要修改一下@tensorflow\tfjs-plugin\utils\file_storage.js:
/**
* 没有扩展名的文件,微信小游戏上传不认
*/
// var WEIGHT_DATA_SUFFIX = 'weight_data';//delete by wxh
var WEIGHT_DATA_SUFFIX = 'weight_data.bin';//add by wxh
// function getModelPaths(prefix) {
// return {
// info: [MODEL_PATH, prefix + "_" + INFO_SUFFIX].join(PATH_SEPARATOR),
// modelArtifactsWithoutWeights: [MODEL_PATH, prefix + "_" + MODEL_SUFFIX].join(PATH_SEPARATOR),
// weightData: [MODEL_PATH, prefix + "_" + WEIGHT_DATA_SUFFIX].join(PATH_SEPARATOR),
// };
// }//delete by wxh
function getModelPaths(prefix) {
let sPath=MODEL_PATH;
if (prefix.startsWith('/')){
sPath='';
prefix=prefix.substring(1);
}
return {
info: [sPath, prefix + "_" + INFO_SUFFIX].join(PATH_SEPARATOR),
modelArtifactsWithoutWeights: [sPath, prefix + "_" + MODEL_SUFFIX].join(PATH_SEPARATOR),
weightData: [sPath, prefix + "_" + WEIGHT_DATA_SUFFIX].join(PATH_SEPARATOR),
};
}
(5)开启静态资源服务器
把原始模型文件放在静态资源目录中,如:assets/posenet。
- model-stride16.json
- group1-shard1of2.bin
- group1-shard2of2.bin
3个文件共4.84M。
//cfg.modelUrl='http://localhost:8001/posenet/model-stride16.json';
cfg.modelUrl='http://192.168.2.112:8001/posenet/model-stride16.json';
tf_poseNet.load(cfg)
从本地加载模型文件时,总体程序大小可能超过小游戏主包4M大小限制,可以分包处理。
注意:这种模式只是在IDE的PC上启动了一个http server,给我们测试用的,当手机用wifi在同一网段时是可以访问到这个server的,但是并不是说在手机内部启动了一个内置的http server。当IDE关闭时,server就关闭了。
3、打开摄像头
载入模型后,打开camera,开始侦测,没有办法直接画在摄像头的canvas上,所以摄像头的canvas设置为1x1大小,相当于隐藏起来。
因为每次人体姿势检测在手机上大概要花80-110ms,很慢,所以不能摄像头视频每帧都检测,10帧检测一次,否则很卡:
let g_camera=null;
let frameIndex=0;
let canvas_camera,ctx_camera;
let camera_imageData;
function startPoseNetControl(options){
function doIt(){
//demoDetect();
openCamera();
if (options.success) options.success();
options.complete();
}
if (g_model) {
doIt();
return;
}
loadModel({
success:function(model){
// wx.showModal({
// content: 'model loaded.'
// });
doIt();
},
fail:function(err){
wx.showModal({
content: JSON.stringify(err)
});
if (options.fail) options.fail(err);
options.complete();
}
});
}
function stopPoseNetControl(options){
if (g_camera){
g_camera.destroy();
g_camera=null;
}
canvas_camera=null;
if (options.success) options.success();
options.complete();
}
function openCamera(){
g_camera=wx.createCamera({
width:1,//不影响onFrame返回的width和height
height:1,
devicePosition:'front',
size:'small',
flash:'off',
success:function(res){
console.log('camera opened.');
console.log(res);
g_camera.listenFrameChange();
},
fail:function(e){
console.log('camera open fail:',e);
}
});
g_camera.onCameraFrame(function(frame){
//console.log(frame.width);
frameIndex++;
if (frameIndex<10) return;
if (!canvas_camera){
canvas_camera=wx.createCanvas();
canvas_camera.width=frame.width;
canvas_camera.height=frame.height;
ctx_camera=canvas_camera.getContext('2d');
camera_imageData = ctx_camera.createImageData(frame.width,frame.height);
}
let pixels=new Uint8Array(frame.data);
detectFrame({
pixels:pixels,
width:frame.width,
height:frame.height,
success:function(res){
camera_imageData.data.set(pixels);
ctx_camera.putImageData(camera_imageData,0,0);
displayResult(res);
frameIndex=0;
},
fail:function(err){
frameIndex=0;
}
});
});
}
4、检测人体姿势
调用人体姿势检测:
function detectFrame(options){
//let t1=new Date();
g_model.estimateSinglePose({
data:options.pixels,
width:options.width,
height:options.height
}, {
flipHorizontal: false
})
.then(function(res){
//每次检测,在pc模拟器中:第1次要800-900ms,第2次只要60-90ms
//手机上:第2次后每次80-110ms
options.success(res);
},function(err){
options.fail(err);
});
}
let g_noseEys_distance=999999;
function displayResult(pose){
//console.log(pose);
const minPoseConfidence = 0.3;
const minPartConfidence = 0.3;
if (pose.score >= minPoseConfidence) {
drawKeypoints(pose.keypoints, minPartConfidence, ctx_camera);
//drawBoundingBox(pose.keypoints,ctx_camera);
//drawSkeleton(pose.keypoints, minPartConfidence, ctx_camera);
if (!wx.tmGlobal.isAllowControl()) return;
//鼻子高度-左眼高度
let dy1=Math.round(pose.keypoints[0].position.y-pose.keypoints[1].position.y);
//左眼高度-右眼高度
let dy2=pose.keypoints[1].position.y-pose.keypoints[2].position.y;
//console.log(dy2);
if (dy1-g_noseEys_distance>10 && dy2>-10 && dy2<10){
//低头时,鼻子-左眼间距加大
wx.tmGlobal.webGL.releaseCurrentBall();
}
else{
wx.tmGlobal.webGL.moveCurrentBall(dy2);
}
g_noseEys_distance=dy1;
}
}
const color = 'aqua';
const boundingBoxColor = 'red';
const lineWidth = 2;
function drawKeypoints(keypoints, minConfidence, ctx, scale = 1) {
for (let i = 0; i < keypoints.length; i++) {
const keypoint = keypoints[i];
if (keypoint.score < minConfidence) {
continue;
}
const { y, x } = keypoint.position;
drawPoint(ctx, y * scale, x * scale, 5, color);
}
}
function drawPoint(ctx, y, x, r, color) {
ctx.beginPath()
ctx.arc(x, y, r, 0, 2 * Math.PI);
ctx.fillStyle = color;
ctx.fill();
ctx.stroke();
}