程序師世界是廣大編程愛好者互助、分享、學習的平台,程序師世界有你更精彩!
首頁
編程語言
C語言|JAVA編程
Python編程
網頁編程
ASP編程|PHP編程
JSP編程
數據庫知識
MYSQL數據庫|SqlServer數據庫
Oracle數據庫|DB2數據庫
 程式師世界 >> 編程語言 >> C語言 >> 關於C語言 >> SVM數字識別

SVM數字識別

編輯:關於C語言

SVM數字識別


#include "stdafx.h"
#include <cstdio>
#include <cstring>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include "opencv2/opencv.hpp"
using namespace std;
using namespace cv;
// When non-zero, newRtPredict() draws each sample with its label and
// prediction in a window and waits for a key press (ESC stops the run).
#define SHOW_PROCESS 1
// Training switch: when non-zero main() reads the training data and trains
// the SVM; when zero main() runs the SVM prediction instead.
#define ON_STUDY 1
// One training record: a 64-value feature vector and the label attached to it.
// A freshly constructed record has every feature set to zero and the label
// set to -1.
class NumTrainData
{
public:
    float data[64];   // feature values
    int result;       // label of the sample, -1 until assigned

    NumTrainData() : result(-1)
    {
        const int count = sizeof(data) / sizeof(data[0]);
        for (int idx = 0; idx < count; ++idx)
            data[idx] = 0.0f;
    }
};
// Training samples collected by ReadTrainData() and handed to the trainers in main().
vector<NumTrainData> buffer;
// Number of float features per sample; used to size the OpenCV matrices.
// NOTE(review): must not exceed the 64 slots of NumTrainData::data.
int featureLen = 64;
// Reverses the order of the four bytes starting at buf, in place
// (byte 0 <-> byte 3, byte 1 <-> byte 2).
void swapBuffer(char* buf)
{
    for (int lo = 0, hi = 3; lo < hi; ++lo, --hi)
    {
        const char saved = buf[lo];
        buf[lo] = buf[hi];
        buf[hi] = saved;
    }
}
/**
 * Crops the non-zero region of an image and centres it in a square,
 * zero-filled image.
 *
 * Pixels are read as uchar, so src is assumed to be a single-channel 8-bit
 * image (CV_8UC1); any pixel greater than zero counts as foreground.
 *
 * @param src  Input image; it is only read.
 * @param dst  Receives a new CV_8UC1 square image whose side is the longer
 *             edge of the foreground bounding box, with the foreground copied
 *             to its centre. When src holds no foreground pixel (or is
 *             empty), dst becomes a 1x1 zero image so that callers can still
 *             resize it.
 */
void GetROI(Mat& src, Mat& dst)
{
    // Start from an "inverted" box so that the first foreground pixel sets
    // all four edges; right/bottom stay at -1 when nothing is found.
    int left = src.cols;
    int right = -1;
    int top = src.rows;
    int bottom = -1;

    // Find the bounding box of all foreground pixels.
    for (int i = 0; i < src.rows; i++)
    {
        const uchar* row = src.ptr<uchar>(i);
        for (int j = 0; j < src.cols; j++)
        {
            if (row[j] > 0)
            {
                if (j < left) left = j;
                if (j > right) right = j;
                if (i < top) top = i;
                if (i > bottom) bottom = i;
            }
        }
    }

    // Blank input: a negative extent must not reach Mat::zeros / Rect.
    if (right < 0 || bottom < 0)
    {
        dst = Mat::zeros(1, 1, CV_8UC1);
        return;
    }

    // The edges are inclusive pixel indices, hence the +1; without it the
    // last row and column of the digit are cut off and a one-pixel-wide
    // shape yields an empty box.
    const int width = right - left + 1;
    const int height = bottom - top + 1;
    const int len = (width < height) ? height : width;

    // Take the source view before dst is reassigned, so the call still
    // works when src and dst refer to the same Mat.
    const Rect srcRect(left, top, width, height);
    Mat srcROI = src(srcRect);

    // Create the square and copy the foreground into its centre.
    dst = Mat::zeros(len, len, CV_8UC1);
    const Rect dstRect((len - width) / 2, (len - height) / 2, width, height);
    Mat dstROI = dst(dstRect);
    srcROI.copyTo(dstROI);
}
int ReadTrainData()
{
    Mat src;
    Mat temp = Mat::zeros(8, 8, CV_8UC1);
    Mat m = Mat::zeros(1, 64, CV_8UC1);
    Mat dst;
    NumTrainData rtd;
    const int p_num = 377;
    //圖片對應的數字,存儲在nclass.xml中
    Mat label=cvCreateMat(1, p_num, CV_32SC1);
    FileStorage file("nclass.xml", FileStorage::READ);
    file["data"]>>label;
    file.release();
    cout<<label.size()<<endl;
    int k=0;
    char ch[10];
    int x=0;
    int ff;
    int i, j;
    int re;
    //讀入圖片,將其轉化為行矩陣,存入NumTrainData中
    while (k < p_num)
    {
        cout<<k<<endl;
        string str;
        sprintf(ch,"%d",k);
        str=ch;
        str=str+".jpg";
        //Read data
        imread(str.c_str(),  0);
        GetROI(src, dst);
        resize(dst, temp, temp.size());
        for(i = 0; i<8; i++)
        {
            for(j = 0; j<8; j++)
            {
                ff = temp.at<uchar>(i, j);
                rtd.data[ i*8 + j] = ff;
            }
        }
        re=label.at<int>(0,k);
        rtd.result=re;
        buffer.push_back(rtd);
        k++;
    }
    return 0;
}
/**
 * Trains a random-trees model on the given samples and saves it to
 * new_rtrees.xml.
 *
 * Each sample contributes one row of featureLen floats (copied unchanged
 * from NumTrainData::data) and one integer response (NumTrainData::result).
 *
 * @param trainData  Training samples; not modified. When it is empty the
 *                   function reports that on stderr and returns without
 *                   training or writing a file.
 */
void newRtStudy(vector<NumTrainData>& trainData)
{
    const int testCount = static_cast<int>(trainData.size());
    if (testCount == 0)
    {
        cerr << "newRtStudy: no training samples, nothing to train" << endl;
        return;
    }

    Mat data = Mat::zeros(testCount, featureLen, CV_32FC1);
    Mat res = Mat::zeros(testCount, 1, CV_32SC1);
    const size_t rowBytes = featureLen * sizeof(float);
    for (int i = 0; i < testCount; i++)
    {
        const NumTrainData& td = trainData[i];
        // ptr<float>(i) addresses row i without relying on manual byte offsets.
        memcpy(data.ptr<float>(i), td.data, rowBytes);
        // res is CV_32SC1, so it is accessed as int.
        res.at<int>(i, 0) = td.result;
    }

    /////////////START RT TRAINING//////////////////
    CvRTrees forest;
    // CvRTParams arguments, in order: max_depth, min_sample_count,
    // regression_accuracy, use_surrogates, max_categories, priors,
    // calc_var_importance, nactive_vars, max_num_of_trees_in_the_forest,
    // forest_accuracy, termcrit_type.
    forest.train( data, CV_ROW_SAMPLE, res, Mat(), Mat(), Mat(), Mat(),
        CvRTParams(10,10,0,false,15,0,true,4,100,0.01f,CV_TERMCRIT_ITER));
    forest.save( "new_rtrees.xml" );
}
int newRtPredict()
{
    CvRTrees forest;
    forest.load( "new_rtrees.xml" );
    const char fileName[] = "../res/t10k-images.idx3-ubyte";
    const char labelFileName[] = "../res/t10k-labels.idx1-ubyte";
    ifstream lab_ifs(labelFileName, ios_base::binary);
    ifstream ifs(fileName, ios_base::binary);
    if( ifs.fail() == true )
        return -1;
    if( lab_ifs.fail() == true )
        return -1;
    char magicNum[4], ccount[4], crows[4], ccols[4];
    ifs.read(magicNum, sizeof(magicNum));
    ifs.read(ccount, sizeof(ccount));
    ifs.read(crows, sizeof(crows));
    ifs.read(ccols, sizeof(ccols));
    int count, rows, cols;
    swapBuffer(ccount);
    swapBuffer(crows);
    swapBuffer(ccols);
    memcpy(&count, ccount, sizeof(count));
    memcpy(&rows, crows, sizeof(rows));
    memcpy(&cols, ccols, sizeof(cols));
    Mat src = Mat::zeros(rows, cols, CV_8UC1);
    Mat temp = Mat::zeros(8, 8, CV_8UC1);
    Mat m = Mat::zeros(1, featureLen, CV_32FC1);
    Mat img, dst;
    //Just skip label header
    lab_ifs.read(magicNum, sizeof(magicNum));
    lab_ifs.read(ccount, sizeof(ccount));
    char label = 0;
    Scalar templateColor(255, 0, 0);
    NumTrainData rtd;
    int right = 0, error = 0, total = 0;
    int right_1 = 0, error_1 = 0, right_2 = 0, error_2 = 0;
    while(ifs.good())
    {
        //Read label
        lab_ifs.read(&label, 1);
        label = label + '0';
        //Read data
        ifs.read((char*)src.data, rows * cols);
        GetROI(src, dst);
        //Too small to watch
        img = Mat::zeros(dst.rows*30, dst.cols*30, CV_8UC3);
        resize(dst, img, img.size());
        rtd.result = label;
        resize(dst, temp, temp.size());
        //threshold(temp, temp, 10, 1, CV_THRESH_BINARY);
        for(int i = 0; i<8; i++)
        {
            for(int j = 0; j<8; j++)
            {
                m.at<float>(0,j + i*8) = temp.at<uchar>(i, j);
            }
        }
        if(total >= count)
            break;
        char ret = (char)forest.predict(m);
        if(ret == label)
        {
            right++;
            if(total <= 5000)
                right_1++;
            else
                right_2++;
        }
        else
        {
            error++;
            if(total <= 5000)
                error_1++;
            else
                error_2++;
        }
        total++;
#if(SHOW_PROCESS)
        stringstream ss;
        ss << "Number " << label << ", predict " << ret;
        string text = ss.str();
        putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
        imshow("img", img);
        if(waitKey(0)==27) //ESC to quit
            break;
#endif
    }
    ifs.close();
    lab_ifs.close();
    stringstream ss;
    ss << "Total " << total << ", right " << right <<", error " << error;
    string text = ss.str();
    putText(img, text, Point(50, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
    imshow("img", img);
    //waitKey(0);
    return 0;
}
/**
 * Trains an SVM classifier on the given samples and saves it to SVM_DATA.xml.
 *
 * Each sample's featureLen features are passed through cv::normalize (with
 * its default arguments) before being stored as one row of the training
 * matrix; NumTrainData::result is used as the integer response.
 *
 * @param trainData  Training samples; not modified. When it is empty the
 *                   function reports that on stderr and returns without
 *                   training or writing a file.
 */
void newSvmStudy(vector<NumTrainData>& trainData)
{
    const int testCount = static_cast<int>(trainData.size());
    if (testCount == 0)
    {
        cerr << "newSvmStudy: no training samples, nothing to train" << endl;
        return;
    }

    Mat m = Mat::zeros(1, featureLen, CV_32FC1);   // scratch row for normalisation
    Mat data = Mat::zeros(testCount, featureLen, CV_32FC1);
    Mat res = Mat::zeros(testCount, 1, CV_32SC1);
    const size_t rowBytes = featureLen * sizeof(float);
    for (int i = 0; i < testCount; i++)
    {
        const NumTrainData& td = trainData[i];
        memcpy(m.ptr<float>(0), td.data, rowBytes);
        normalize(m, m);
        // ptr<float>(i) addresses row i without relying on manual byte offsets.
        memcpy(data.ptr<float>(i), m.ptr<float>(0), rowBytes);
        // res is CV_32SC1, so it is accessed as int.
        res.at<int>(i, 0) = td.result;
    }

    /////////////START SVM TRAINING//////////////////
    CvSVM svm;
    const CvTermCriteria criteria = cvTermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON);
    // CvSVMParams arguments, in order: svm_type, kernel_type, degree, gamma,
    // coef0, C, nu, p, class_weights, term_crit.
    const CvSVMParams param(CvSVM::C_SVC, CvSVM::RBF, 10.0, 8.0, 1.0, 10.0, 0.5, 0.1, NULL, criteria);
    svm.train(data, res, Mat(), Mat(), param);
    svm.save( "SVM_DATA.xml" );
}
int newSvmPredict()
{
    CvSVM svm = CvSVM();
    svm.load( "SVM_DATA.xml" );
    Mat temp = Mat::zeros(8, 8, CV_8UC1);
    Mat m = Mat::zeros(1, featureLen, CV_32FC1);
    Mat dst;
    int k=0;
    const int p_num=17;
    char ch[10];
    while (k <= p_num)
    {
        string str;
        sprintf(ch,"%d",k+313);
        str=ch;
        str=str+".jpg";
        //Read data
        Mat  imread(str.c_str(),  0);
        GetROI(src, dst);
        resize(dst, temp, temp.size());
        //threshold(temp, temp, 10, 1, CV_THRESH_BINARY);
        for(int i = 0; i<8; i++)
        {
            for(int j = 0; j<8; j++)
            {
                m.at<float>(0,j + i*8) = temp.at<uchar>(i, j);
            }
        }
        normalize(m, m);
        int ret = (int)svm.predict(m);
        cout<<ch<<","<<ret<<endl;
        k++;
    }
    return 0;
}
// Entry point. ON_STUDY selects the mode at compile time: non-zero trains a
// model from the images on disk, zero runs prediction with a saved model.
// The command-line arguments are not used.
int main( int argc, char *argv[] )
{
#if ON_STUDY
    // Load the training samples; they end up in the global buffer.
    ReadTrainData();

    // Train; the resulting SVM is saved to SVM_DATA.xml.
    // (Random-trees alternative: newRtStudy(buffer);)
    newSvmStudy(buffer);
#else
    // Recognise digits with the saved SVM.
    // (Random-trees alternative: newRtPredict();)
    newSvmPredict();
#endif

    return 0;
}



訓練時輸入的是編號為 0 到 p_num-1 的圖片(0.jpg、1.jpg……)以及每張圖片對應的數字。這些數字存放在 nclass.xml 中:該檔案採用 OpenCV 儲存 Mat 類型時自動產生的格式,產生之後再手動修改其中的內容。這種做法比較繁瑣,我想找一個更好的方式,但目前還沒有找到。

  1. 上一頁:
  2. 下一頁:
Copyright © 程式師世界 All Rights Reserved