飞天神龟 2018-03-20 04:42 采纳率: 33.3%
浏览 1257
已结题

请问opencv如何加载bvlc_reference_caffenet

如下代码是根据opencv加载googlenet的代码修改的,用来调用自己训练的caffenet,可是根本不能输出正确的识别结果。

 #include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/core/utils/trace.hpp>
using namespace cv;
using namespace cv::dnn;

#include <fstream>
#include <iostream>
#include <cstdlib>
using namespace std;

//寻找出概率最高的一类
static void getMaxClass(const Mat &probBlob, int *classId, double *classProb)
{
    Mat probMat = probBlob.reshape(1, 1);
    Point classNumber;

    minMaxLoc(probMat, NULL, classProb, NULL, &classNumber);
    *classId = classNumber.x;
}
//从标签文件读取分类 空格为标志
static std::vector<String> readClassNames(const char *filename = "label.txt")
{
    std::vector<String> classNames;

    std::ifstream fp(filename);
    if (!fp.is_open())
    {
        std::cerr << "File with classes labels not found: " << filename << std::endl;
        exit(-1);
    }

    std::string name;
    while (!fp.eof())
    {
        std::getline(fp, name);
        if (name.length())
            classNames.push_back(name.substr(name.find(' ') + 1));
    }
    fp.close();
    return classNames;
}
//主程序
int main(int argc, char **argv)
{
    //初始化
    CV_TRACE_FUNCTION();
    //读取模型参数和模型结构文件
    String modelTxt = "deploy.prototxt";
    String modelBin = "caffe_train_iter_5000.caffemodel";
    //读取图片
    String imageFile = (argc > 1) ? argv[1] : "./ceshi/0.jpg";

    //合成网络
    Net net = dnn::readNetFromCaffe(modelTxt, modelBin);
    //判断网络是否生成成功
    if (net.empty())
    {
        std::cerr << "Can't load network by using the following files: " << std::endl;
        exit(-1);
    }
    cerr << "net read successfully" << endl;

    //读取图片
    Mat img = imread(imageFile);
    imshow("image", img);
    if (img.empty())
    {
        std::cerr << "Can't read image from the file: " << imageFile << std::endl;
        exit(-1);
    }
    cerr << "image read sucessfully" << endl;

    Mat inputBlob = blobFromImage(img, 1.0f, Size(227, 227));

    Mat prob;
    cv::TickMeter t;
    for (int i = 0; i < 10; i++)
    {
        CV_TRACE_REGION("forward");
        //将构建的blob传入网络data层
        net.setInput(inputBlob, "data");
        //计时
        t.start();
        //前向预测
        prob = net.forward("prob");
        //停止计时
        t.stop();
    }

    int classId;
    double classProb;
    //找出最高的概率ID存储在classId,对应的标签在classProb中
    getMaxClass(prob, &classId, &classProb);

    //打印出结果
    std::vector<String> classNames = readClassNames();
    std::cout << "Best class: #" << classId << " '" << classNames.at(classId) << "'" << std::endl;
    std::cout << "Probability: " << classProb * 100 << "%" << std::endl;
    //打印出花费时间
    std::cout << "Time: " << (double)t.getTimeMilli() / t.getCounter() << " ms (average from " << t.getCounter() << " iterations)" << std::endl;

    //便于观察结果
    waitKey(0);
    return 0;
}
  • 写回答

1条回答 默认 最新

  • devmiao 2018-03-21 16:38
    关注
    评论

报告相同问题?