【OpenCV】opencv3.0中的SVM訓練 mnist 手寫字體識別,opencvopencv3.0
前言:
SVM(支持向量機)一種訓練分類器的學習方法
mnist 是一個手寫字體圖像數據庫,訓練樣本有60000個,測試樣本有10000個
LibSVM 一個常用的SVM框架
OpenCV3.0 中的ml包含了很多的ML框架接口,就試試了。
詳細的OpenCV文檔:http://docs.opencv.org/3.0-beta/doc/tutorials/ml/introduction_to_svm/introduction_to_svm.html
mnist數據下載:http://yann.lecun.com/exdb/mnist/
LibSVM下載:http://www.csie.ntu.edu.tw/~cjlin/libsvm/
========================我是分割線=============================
訓練的過程大致如下:
1. 讀取mnist訓練集數據
2. 訓練
3. 讀取mnist測試數據,對比預測結果,得到錯誤率
具體實現:
1. mnist給出的數據文件是二進制文件
四個文件,解壓後如下
"train-images.idx3-ubyte" 二進制文件,存儲了頭文件信息以及60000張28*28圖像pixel信息(用於訓練)
"train-labels.idx1-ubyte" 二進制文件,存儲了頭文件信息以及60000張圖像label信息
"t10k-images.idx3-ubyte"二進制文件,存儲了頭文件信息以及10000張28*28圖像pixel信息(用於測試)
"t10k-labels.idx1-ubyte"二進制文件,存儲了頭文件信息以及10000張圖像label信息
因為OpenCV中沒有直接導入MINST數據的文件,所以需要自己寫函數來讀取
首先要知道,MNIST數據的數據格式

IMAGE FILE包含四個int型的頭部數據(magic number,number_of_images, number_of_rows, number_of_columns)
余下的每一個byte表示一個pixel的數據,范圍是0-255(可以在讀入的時候scale到0~1的區間)
LABEL FILE包含兩個int型的頭部數據(magic number, number of items)
余下的每一個byte表示一個label數據,范圍是0-9
注意(第一個坑):MNIST是大端存儲,然而大部分的Intel處理器都是小端存儲,所以對於int、long、float這些多字節的數據類型,就要一個一個byte地翻轉過來,才能正確顯示。

![]()
1 //翻轉
2 int reverseInt(int i) {
3 unsigned char c1, c2, c3, c4;
4
5 c1 = i & 255;
6 c2 = (i >> 8) & 255;
7 c3 = (i >> 16) & 255;
8 c4 = (i >> 24) & 255;
9
10 return ((int)c1 << 24) + ((int)c2 << 16) + ((int)c3 << 8) + c4;
11 }
View Code
然後讀取MNIST文件,但是它是二進制文件,打開方式
所以不能用
ifstream file(fileName);
而要改成
ifstream file(fileName, ios::binary);
注意(第二個坑):如果用第一條指令來打開文件,不會報錯,但是數據會出現錯誤,頭部數據仍然正確,但是後面的pixel數據大部分都是0,我剛開始沒注意,開始training的時候發現等了很久...真的是很久...(7+ hours)...估計是達到迭代終止的最大次數了,才停下來的
嗯,stack overflow上也有類似的提問:

注意(第三個坑):
training時,IMAGE和LABEL的數據分別都放進一個MAT中存儲,但是只能是CV32_F或者CV32_S的格式,不然會assertion報錯
OPENCV給出的文檔中,例子是這樣的:(但是predict的時候又會要求label的格式是unsigned int)所以...可以設置data的Mat格式為CV_32FC1,label的Mat格式為CV_32SC1

順便地,圖像訓練數據的轉換存儲格式(http://stackoverflow.com/questions/14694810/using-opencv-and-svm-with-images?rq=1)

最後,為了驗證讀取數據的正確性,一個有效的辦法就是輸出第一個和最後一個數據(可以輸出打印第一個/最後一個image以及label)
2. 訓練
(此處我是直接對原圖像訓練,並沒有提取任何的特征)
也有人建議這裡應該對圖像做HOG特征提取,再配合label訓練(我還沒試過...不知道效果如何...)

opencv3.0和2.4的SVM接口有不同,基本可以按照以下的格式來執行:
ml::SVM::Params params;
params.svmType = ml::SVM::C_SVC;
params.kernelType = ml::SVM::POLY;
params.gamma = 3;
Ptr<ml::SVM> svm = ml::SVM::create(params);
Mat trainData; // 每行為一個樣本
Mat labels;
svm->train( trainData , ml::ROW_SAMPLE , labels );
// ...
svm->save("....");//文件形式為xml,可以保存在txt或者xml文件中
Ptr<SVM> svm=statModel::load<SVM>("....");
Mat query; // 輸入, 1個通道
Mat res; // 輸出
svm->predict(query, res);
但是要注意,如果報錯的話最好去看opencv3.0的文檔,裡面有函數原型和解釋,我在實際操作的過程中,也做了一些改動
1)設置參數
SVM的參數有很多,但是與C_SVC和RBF有關的就只有gamma和C,所以設置這兩個就好,終止條件設置和默認一樣,由經驗可得(其實是查閱了很多的資料,把gamma設置成0.01,這樣訓練收斂速度會快很多)
Ptr<SVM> svm = SVM::create();
svm->setType(SVM::C_SVC);
svm->setKernel(SVM::RBF);
svm->setGamma(0.01);
svm->setC(10.0);
svm->setTermCriteria(TermCriteria(CV_TERMCRIT_EPS, 1000,FLT_EPSILON));
svm_type –指定SVM的類型,下面是可能的取值:
CvSVM::C_SVC C類支持向量分類機。 n類分組 (n \geq 2),允許用異常值懲罰因子C進行不完全分類。
CvSVM::NU_SVC \nu類支持向量分類機。n類似然不完全分類的分類器。參數為 \nu 取代C(其值在區間【0,1】中,nu越大,決策邊界越平滑)。
CvSVM::ONE_CLASS 單分類器,所有的訓練數據提取自同一個類裡,然後SVM建立了一個分界線以分割該類在特征空間中所占區域和其它類在特征空間中所占區域。
CvSVM::EPS_SVR \epsilon類支持向量回歸機。訓練集中的特征向量和擬合出來的超平面的距離需要小於p。異常值懲罰因子C被采用。
CvSVM::NU_SVR \nu類支持向量回歸機。 \nu 代替了 p。
kernel_type –SVM的內核類型,下面是可能的取值:
CvSVM::LINEAR 線性內核。沒有任何向映射至高維空間,線性區分(或回歸)在原始特征空間中被完成,這是最快的選擇。K(x_i, x_j) = x_i^T x_j.
CvSVM::POLY 多項式內核: K(x_i, x_j) = (\gamma x_i^T x_j + coef0)^{degree}, \gamma > 0.
CvSVM::RBF 基於徑向的函數,對於大多數情況都是一個較好的選擇: K(x_i, x_j) = e^{-\gamma ||x_i - x_j||^2}, \gamma > 0.
CvSVM::SIGMOID Sigmoid函數內核:K(x_i, x_j) = \tanh(\gamma x_i^T x_j + coef0).
degree – 內核函數(POLY)的參數degree。
gamma – 內核函數(POLY/ RBF/ SIGMOID)的參數\gamma。
coef0 – 內核函數(POLY/ SIGMOID)的參數coef0。
Cvalue – SVM類型(C_SVC/ EPS_SVR/ NU_SVR)的參數C。
nu – SVM類型(NU_SVC/ ONE_CLASS/ NU_SVR)的參數 \nu。
p – SVM類型(EPS_SVR)的參數 \epsilon。
class_weights – C_SVC中的可選權重,賦給指定的類,乘以C以後變成 class\_weights_i * C。所以這些權重影響不同類別的錯誤分類懲罰項。權重越大,某一類別的誤分類數據的懲罰項就越大。
term_crit – SVM的迭代訓練過程的中止條件,解決部分受約束二次最優問題。您可以指定的公差和/或最大迭代次數。
2)訓練
Mat trainData;
Mat labels;
trainData = read_mnist_image(trainImage);
labels = read_mnist_label(trainLabel);
svm->train(trainData, ROW_SAMPLE, labels);
3)保存
svm->save("mnist_dataset/mnist_svm.xml");
3. 測試,比對結果
(此處的FLT_EPSILON是一個極小的數,1.0 - FLT_EPSILON != 1.0)
Mat testData;
Mat tLabel;
testData = read_mnist_image(testImage);
tLabel = read_mnist_label(testLabel);
float count = 0;
for (int i = 0; i < testData.rows; i++) {
Mat sample = testData.row(i);
float res = svm1->predict(sample);
res = std::abs(res - tLabel.at<unsigned int>(i, 0)) <= FLT_EPSILON ? 1.f : 0.f;
count += res;
}
cout << "正確的識別個數 count = " << count << endl;
cout << "錯誤率為..." << (10000 - count + 0.0) / 10000 * 100.0 << "%....\n";
這裡沒有使用svm->predict(query, res);
然後就查看了opencv的文檔,當傳入數據是Mat 而不是cvMat時,可以利用predict的返回值(float)來判斷預測是否正確。

運行結果:
1)1000個訓練數據/1000個測試數據

2)2000個訓練數據/2000個測試數據

3)5000個訓練數據/5000個測試數據

4)10000個訓練數據/10000個測試數據

5)60000個訓練數據/10000個測試數據

最後,關於運行時間(在程序正確的前提下,訓練時長和初始的參數設置有關),給出我最的運行結果(1000張圖是11s左右,60000張是1300s ~ 2000s左右)
代碼:

![]()
1 #ifndef MNIST_H
2 #define MNIST_H
3
4 #include <iostream>
5 #include <string>
6 #include <fstream>
7 #include <ctime>
8 #include <opencv2/opencv.hpp>
9
10 using namespace cv;
11 using namespace std;
12
13 //小端存儲轉換
14 int reverseInt(int i);
15
16 //讀取image數據集信息
17 Mat read_mnist_image(const string fileName);
18
19 //讀取label數據集信息
20 Mat read_mnist_label(const string fileName);
21
22 #endif
mnist.h

![]()
1 #include "mnist.h"
2
3 //計時器
4 double cost_time;
5 clock_t start_time;
6 clock_t end_time;
7
8 //測試item個數
9 int testNum = 10000;
10
11 int reverseInt(int i) {
12 unsigned char c1, c2, c3, c4;
13
14 c1 = i & 255;
15 c2 = (i >> 8) & 255;
16 c3 = (i >> 16) & 255;
17 c4 = (i >> 24) & 255;
18
19 return ((int)c1 << 24) + ((int)c2 << 16) + ((int)c3 << 8) + c4;
20 }
21
22 Mat read_mnist_image(const string fileName) {
23 int magic_number = 0;
24 int number_of_images = 0;
25 int n_rows = 0;
26 int n_cols = 0;
27
28 Mat DataMat;
29
30 ifstream file(fileName, ios::binary);
31 if (file.is_open())
32 {
33 cout << "成功打開圖像集 ... \n";
34
35 file.read((char*)&magic_number, sizeof(magic_number));
36 file.read((char*)&number_of_images, sizeof(number_of_images));
37 file.read((char*)&n_rows, sizeof(n_rows));
38 file.read((char*)&n_cols, sizeof(n_cols));
39 //cout << magic_number << " " << number_of_images << " " << n_rows << " " << n_cols << endl;
40
41 magic_number = reverseInt(magic_number);
42 number_of_images = reverseInt(number_of_images);
43 n_rows = reverseInt(n_rows);
44 n_cols = reverseInt(n_cols);
45 cout << "MAGIC NUMBER = " << magic_number
46 << " ;NUMBER OF IMAGES = " << number_of_images
47 << " ; NUMBER OF ROWS = " << n_rows
48 << " ; NUMBER OF COLS = " << n_cols << endl;
49
50 //-test-
51 //number_of_images = testNum;
52 //輸出第一張和最後一張圖,檢測讀取數據無誤
53 Mat s = Mat::zeros(n_rows, n_rows * n_cols, CV_32FC1);
54 Mat e = Mat::zeros(n_rows, n_rows * n_cols, CV_32FC1);
55
56 cout << "開始讀取Image數據......\n";
57 start_time = clock();
58 DataMat = Mat::zeros(number_of_images, n_rows * n_cols, CV_32FC1);
59 for (int i = 0; i < number_of_images; i++) {
60 for (int j = 0; j < n_rows * n_cols; j++) {
61 unsigned char temp = 0;
62 file.read((char*)&temp, sizeof(temp));
63 float pixel_value = float((temp + 0.0) / 255.0);
64 DataMat.at<float>(i, j) = pixel_value;
65
66 //打印第一張和最後一張圖像數據
67 if (i == 0) {
68 s.at<float>(j / n_cols, j % n_cols) = pixel_value;
69 }
70 else if (i == number_of_images - 1) {
71 e.at<float>(j / n_cols, j % n_cols) = pixel_value;
72 }
73 }
74 }
75 end_time = clock();
76 cost_time = (end_time - start_time) / CLOCKS_PER_SEC;
77 cout << "讀取Image數據完畢......" << cost_time << "s\n";
78
79 imshow("first image", s);
80 imshow("last image", e);
81 waitKey(0);
82 }
83 file.close();
84 return DataMat;
85 }
86
87 Mat read_mnist_label(const string fileName) {
88 int magic_number;
89 int number_of_items;
90
91 Mat LabelMat;
92
93 ifstream file(fileName, ios::binary);
94 if (file.is_open())
95 {
96 cout << "成功打開Label集 ... \n";
97
98 file.read((char*)&magic_number, sizeof(magic_number));
99 file.read((char*)&number_of_items, sizeof(number_of_items));
100 magic_number = reverseInt(magic_number);
101 number_of_items = reverseInt(number_of_items);
102
103 cout << "MAGIC NUMBER = " << magic_number << " ; NUMBER OF ITEMS = " << number_of_items << endl;
104
105 //-test-
106 //number_of_items = testNum;
107 //記錄第一個label和最後一個label
108 unsigned int s = 0, e = 0;
109
110 cout << "開始讀取Label數據......\n";
111 start_time = clock();
112 LabelMat = Mat::zeros(number_of_items, 1, CV_32SC1);
113 for (int i = 0; i < number_of_items; i++) {
114 unsigned char temp = 0;
115 file.read((char*)&temp, sizeof(temp));
116 LabelMat.at<unsigned int>(i, 0) = (unsigned int)temp;
117
118 //打印第一個和最後一個label
119 if (i == 0) s = (unsigned int)temp;
120 else if (i == number_of_items - 1) e = (unsigned int)temp;
121 }
122 end_time = clock();
123 cost_time = (end_time - start_time) / CLOCKS_PER_SEC;
124 cout << "讀取Label數據完畢......" << cost_time << "s\n";
125
126 cout << "first label = " << s << endl;
127 cout << "last label = " << e << endl;
128 }
129 file.close();
130 return LabelMat;
131 }
mnist.cpp

![]()
1 /*
2 svm_type –
3 指定SVM的類型,下面是可能的取值:
4 CvSVM::C_SVC C類支持向量分類機。 n類分組 (n \geq 2),允許用異常值懲罰因子C進行不完全分類。
5 CvSVM::NU_SVC \nu類支持向量分類機。n類似然不完全分類的分類器。參數為 \nu 取代C(其值在區間【0,1】中,nu越大,決策邊界越平滑)。
6 CvSVM::ONE_CLASS 單分類器,所有的訓練數據提取自同一個類裡,然後SVM建立了一個分界線以分割該類在特征空間中所占區域和其它類在特征空間中所占區域。
7 CvSVM::EPS_SVR \epsilon類支持向量回歸機。訓練集中的特征向量和擬合出來的超平面的距離需要小於p。異常值懲罰因子C被采用。
8 CvSVM::NU_SVR \nu類支持向量回歸機。 \nu 代替了 p。
9
10 可從 [LibSVM] 獲取更多細節。
11
12 kernel_type –
13 SVM的內核類型,下面是可能的取值:
14 CvSVM::LINEAR 線性內核。沒有任何向映射至高維空間,線性區分(或回歸)在原始特征空間中被完成,這是最快的選擇。K(x_i, x_j) = x_i^T x_j.
15 CvSVM::POLY 多項式內核: K(x_i, x_j) = (\gamma x_i^T x_j + coef0)^{degree}, \gamma > 0.
16 CvSVM::RBF 基於徑向的函數,對於大多數情況都是一個較好的選擇: K(x_i, x_j) = e^{-\gamma ||x_i - x_j||^2}, \gamma > 0.
17 CvSVM::SIGMOID Sigmoid函數內核:K(x_i, x_j) = \tanh(\gamma x_i^T x_j + coef0).
18
19 degree – 內核函數(POLY)的參數degree。
20
21 gamma – 內核函數(POLY/ RBF/ SIGMOID)的參數\gamma。
22
23 coef0 – 內核函數(POLY/ SIGMOID)的參數coef0。
24
25 Cvalue – SVM類型(C_SVC/ EPS_SVR/ NU_SVR)的參數C。
26
27 nu – SVM類型(NU_SVC/ ONE_CLASS/ NU_SVR)的參數 \nu。
28
29 p – SVM類型(EPS_SVR)的參數 \epsilon。
30
31 class_weights – C_SVC中的可選權重,賦給指定的類,乘以C以後變成 class\_weights_i * C。所以這些權重影響不同類別的錯誤分類懲罰項。權重越大,某一類別的誤分類數據的懲罰項就越大。
32
33 term_crit – SVM的迭代訓練過程的中止條件,解決部分受約束二次最優問題。您可以指定的公差和/或最大迭代次數。
34
35 */
36
37
38 #include "mnist.h"
39
40 #include <opencv2/core.hpp>
41 #include <opencv2/imgproc.hpp>
42 #include "opencv2/imgcodecs.hpp"
43 #include <opencv2/highgui.hpp>
44 #include <opencv2/ml.hpp>
45
46 #include <string>
47 #include <iostream>
48
49 using namespace std;
50 using namespace cv;
51 using namespace cv::ml;
52
53 string trainImage = "mnist_dataset/train-images.idx3-ubyte";
54 string trainLabel = "mnist_dataset/train-labels.idx1-ubyte";
55 string testImage = "mnist_dataset/t10k-images.idx3-ubyte";
56 string testLabel = "mnist_dataset/t10k-labels.idx1-ubyte";
57 //string testImage = "mnist_dataset/train-images.idx3-ubyte";
58 //string testLabel = "mnist_dataset/train-labels.idx1-ubyte";
59
60 //計時器
61 double cost_time_;
62 clock_t start_time_;
63 clock_t end_time_;
64
65 int main()
66 {
67
68 //--------------------- 1. Set up training data ---------------------------------------
69 Mat trainData;
70 Mat labels;
71 trainData = read_mnist_image(trainImage);
72 labels = read_mnist_label(trainLabel);
73
74 cout << trainData.rows << " " << trainData.cols << endl;
75 cout << labels.rows << " " << labels.cols << endl;
76
77 //------------------------ 2. Set up the support vector machines parameters --------------------
78 Ptr<SVM> svm = SVM::create();
79 svm->setType(SVM::C_SVC);
80 svm->setKernel(SVM::RBF);
81 //svm->setDegree(10.0);
82 svm->setGamma(0.01);
83 //svm->setCoef0(1.0);
84 svm->setC(10.0);
85 //svm->setNu(0.5);
86 //svm->setP(0.1);
87 svm->setTermCriteria(TermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON));
88
89 //------------------------ 3. Train the svm ----------------------------------------------------
90 cout << "Starting training process" << endl;
91 start_time_ = clock();
92 svm->train(trainData, ROW_SAMPLE, labels);
93 end_time_ = clock();
94 cost_time_ = (end_time_ - start_time_) / CLOCKS_PER_SEC;
95 cout << "Finished training process...cost " << cost_time_ << " seconds..." << endl;
96
97 //------------------------ 4. save the svm ----------------------------------------------------
98 svm->save("mnist_dataset/mnist_svm.xml");
99 cout << "save as /mnist_dataset/mnist_svm.xml" << endl;
100
101
102 //------------------------ 5. load the svm ----------------------------------------------------
103 cout << "開始導入SVM文件...\n";
104 Ptr<SVM> svm1 = StatModel::load<SVM>("mnist_dataset/mnist_svm.xml");
105 cout << "成功導入SVM文件...\n";
106
107
108 //------------------------ 6. read the test dataset -------------------------------------------
109 cout << "開始導入測試數據...\n";
110 Mat testData;
111 Mat tLabel;
112 testData = read_mnist_image(testImage);
113 tLabel = read_mnist_label(testLabel);
114 cout << "成功導入測試數據!!!\n";
115
116
117 float count = 0;
118 for (int i = 0; i < testData.rows; i++) {
119 Mat sample = testData.row(i);
120 float res = svm1->predict(sample);
121 res = std::abs(res - tLabel.at<unsigned int>(i, 0)) <= FLT_EPSILON ? 1.f : 0.f;
122 count += res;
123 }
124 cout << "正確的識別個數 count = " << count << endl;
125 cout << "錯誤率為..." << (10000 - count + 0.0) / 10000 * 100.0 << "%....\n";
126
127 system("pause");
128 return 0;
129 }
main.cpp
一些網站(資料):(其實都很容易搜索到的=_=, 但是搬了人家的東西,就還是貼一下...
http://blog.csdn.net/augusdi/article/details/9005352
http://blog.csdn.net/arthur503/article/details/19974057
http://blog.csdn.net/laihonghuan/article/details/49387237
http://docs.opencv.org/3.0-beta/modules/ml/doc/support_vector_machines.html#prediction-with-svm
http://stackoverflow.com/questions/14694810/using-opencv-and-svm-with-images?rq=1
http://docs.opencv.org/2.4/modules/ml/doc/support_vector_machines.html#cvsvm-train
http://blog.csdn.net/u010869312/article/details/44927721
http://blog.csdn.net/heroacool/article/details/50579955
http://docs.opencv.org/3.0-beta/doc/tutorials/ml/introduction_to_svm/introduction_to_svm.html
http://guyvercz.blog.163.com/blog/static/252545292011112974915402/
http://stackoverflow.com/questions/12993941/how-can-i-read-the-mnist-dataset-with-c?lq=1