k近鄰法的C++實現:kd樹
1.k近鄰算法的思想
給定一個訓練集,對於新的輸入實例,在訓練集中找到與該實例最近的k個實例,這k個實例中的多數屬於某個類,就把該輸入實例分為這個類。
因為要找到最近的k個實例,所以計算輸入實例與訓練集中實例之間的距離是關鍵!
k近鄰算法最簡單的方法是線性掃描,這時要計算輸入實例與每一個訓練實例的距離,當訓練集很大時,非常耗時,這種方法不可行,為了提高k近鄰的搜索效率,常常考慮使用特殊的存儲結構存儲訓練數據,以減少計算距離的次數,具體方法很多,這裡介紹實現經典的kd樹方法。
2.構造kd樹
kd樹是一種對k維空間中的實例點進行存儲以便對其進行快速檢索的樹形數據結構,kd樹是二叉樹。
下面舉例說明:
給定一個二維空間的數據集: T = {(2,3), (5,4), (9,6), (4,7), (8,1), (7,2)},構造一個平衡kd樹。
根結點對應包含數據集T的矩形選擇x(1) 軸,6個數據點的x(1) 坐標的中位數是7,以超平面x(1) = 7將空間分為左右兩個子矩形(子結點)
左矩形以x(2) = 4為中位數分為兩個子矩形
右矩形以x(2) = 6 分為兩個子矩形
如此遞歸,直到兩個子區域沒有實例存在時停止
3.利用kd樹搜索最近鄰
輸入:已構造的kd樹;目標點x;
輸出:x的最近鄰
在kd樹中找出包含目標點x的葉結點:從根結點出發,遞歸的向下訪問kd樹,若目標點x的當前維的坐標小於切分點的坐標,則移動到左子結點,否則移動到右子結點,直到子結點為葉結點為止。
以此葉結點為“當前最近點”
遞歸地向上回退,在每個結點進行以下操作:(a)如果該結點保存的實例點比當前最近點距離目標點更近,則以該實例點為“當前最近點”;
(b)當前最近點一定存在於某結點一個子結點對應的區域,檢查該子結點的父結點的另
一子結點對應區域是否有更近的點(即檢查另一子結點對應的區域是否與以目標點為球
心、以目標點與“當前最近點”間的距離為半徑的球體相交);如果相交,可能在另一
個子結點對應的區域內存在距目標點更近的點,移動到另一個子結點,接著遞歸進行最
近鄰搜索;如果不相交,向上回退
當回退到根結點時,搜索結束,最後的“當前最近點”即為x的最近鄰點。
4.C++實現
1 #include <iostream>
2 #include <vector>
3 #include <algorithm>
4 #include <string>
5 #include <cmath>
6 using namespace std;
7
8
9
10
11 struct KdTree{
12 vector<double> root;
13 KdTree* parent;
14 KdTree* leftChild;
15 KdTree* rightChild;
16 //默認構造函數
17 KdTree(){parent = leftChild = rightChild = NULL;}
18 //判斷kd樹是否為空
19 bool isEmpty()
20 {
21 return root.empty();
22 }
23 //判斷kd樹是否只是一個葉子結點
24 bool isLeaf()
25 {
26 return (!root.empty()) &&
27 rightChild == NULL && leftChild == NULL;
28 }
29 //判斷是否是樹的根結點
30 bool isRoot()
31 {
32 return (!isEmpty()) && parent == NULL;
33 }
34 //判斷該子kd樹的根結點是否是其父kd樹的左結點
35 bool isLeft()
36 {
37 return parent->leftChild->root == root;
38 }
39 //判斷該子kd樹的根結點是否是其父kd樹的右結點
40 bool isRight()
41 {
42 return parent->rightChild->root == root;
43 }
44 };
45
46 int data[6][2] = {{2,3},{5,4},{9,6},{4,7},{8,1},{7,2}};
47
48 template<typename T>
49 vector<vector<T> > Transpose(vector<vector<T> > Matrix)
50 {
51 unsigned row = Matrix.size();
52 unsigned col = Matrix[0].size();
53 vector<vector<T> > Trans(col,vector<T>(row,0));
54 for (unsigned i = 0; i < col; ++i)
55 {
56 for (unsigned j = 0; j < row; ++j)
57 {
58 Trans[i][j] = Matrix[j][i];
59 }
60 }
61 return Trans;
62 }
63
64 template <typename T>
65 T findMiddleValue(vector<T> vec)
66 {
67 sort(vec.begin(),vec.end());
68 auto pos = vec.size() / 2;
69 return vec[pos];
70 }
71
72
73 //構建kd樹
74 void buildKdTree(KdTree* tree, vector<vector<double> > data, unsigned depth)
75 {
76
77 //樣本的數量
78 unsigned samplesNum = data.size();
79 //終止條件
80 if (samplesNum == 0)
81 {
82 return;
83 }
84 if (samplesNum == 1)
85 {
86 tree->root = data[0];
87 return;
88 }
89 //樣本的維度
90 unsigned k = data[0].size();
91 vector<vector<double> > transData = Transpose(data);
92 //選擇切分屬性
93 unsigned splitAttribute = depth % k;
94 vector<double> splitAttributeValues = transData[splitAttribute];
95 //選擇切分值
96 double splitValue = findMiddleValue(splitAttributeValues);
97 //cout << "splitValue" << splitValue << endl;
98
99 // 根據選定的切分屬性和切分值,將數據集分為兩個子集
100 vector<vector<double> > subset1;
101 vector<vector<double> > subset2;
102 for (unsigned i = 0; i < samplesNum; ++i)
103 {
104 if (splitAttributeValues[i] == splitValue && tree->root.empty())
105 tree->root = data[i];
106 else
107 {
108 if (splitAttributeValues[i] < splitValue)
109 subset1.push_back(data[i]);
110 else
111 subset2.push_back(data[i]);
112 }
113 }
114
115 //子集遞歸調用buildKdTree函數
116
117 tree->leftChild = new KdTree;
118 tree->leftChild->parent = tree;
119 tree->rightChild = new KdTree;
120 tree->rightChild->parent = tree;
121 buildKdTree(tree->leftChild, subset1, depth + 1);
122 buildKdTree(tree->rightChild, subset2, depth + 1);
123 }
124
125 //逐層打印kd樹
126 void printKdTree(KdTree *tree, unsigned depth)
127 {
128 for (unsigned i = 0; i < depth; ++i)
129 cout << "\t";
130
131 for (vector<double>::size_type j = 0; j < tree->root.size(); ++j)
132 cout << tree->root[j] << ",";
133 cout << endl;
134 if (tree->leftChild == NULL && tree->rightChild == NULL )//葉子節點
135 return;
136 else //非葉子節點
137 {
138 if (tree->leftChild != NULL)
139 {
140 for (unsigned i = 0; i < depth + 1; ++i)
141 cout << "\t";
142 cout << " left:";
143 printKdTree(tree->leftChild, depth + 1);
144 }
145
146 cout << endl;
147 if (tree->rightChild != NULL)
148 {
149 for (unsigned i = 0; i < depth + 1; ++i)
150 cout << "\t";
151 cout << "right:";
152 printKdTree(tree->rightChild, depth + 1);
153 }
154 cout << endl;
155 }
156 }
157
158
159 //計算空間中兩個點的距離
160 double measureDistance(vector<double> point1, vector<double> point2, unsigned method)
161 {
162 if (point1.size() != point2.size())
163 {
164 cerr << "Dimensions don't match!!" ;
165 exit(1);
166 }
167 switch (method)
168 {
169 case 0://歐氏距離
170 {
171 double res = 0;
172 for (vector<double>::size_type i = 0; i < point1.size(); ++i)
173 {
174 res += pow((point1[i] - point2[i]), 2);
175 }
176 return sqrt(res);
177 }
178 case 1://曼哈頓距離
179 {
180 double res = 0;
181 for (vector<double>::size_type i = 0; i < point1.size(); ++i)
182 {
183 res += abs(point1[i] - point2[i]);
184 }
185 return res;
186 }
187 default:
188 {
189 cerr << "Invalid method!!" << endl;
190 return -1;
191 }
192 }
193 }
194 //在kd樹tree中搜索目標點goal的最近鄰
195 //輸入:目標點;已構造的kd樹
196 //輸出:目標點的最近鄰
197 vector<double> searchNearestNeighbor(vector<double> goal, KdTree *tree)
198 {
199 /*第一步:在kd樹中找出包含目標點的葉子結點:從根結點出發,
200 遞歸的向下訪問kd樹,若目標點的當前維的坐標小於切分點的
201 坐標,則移動到左子結點,否則移動到右子結點,直到子結點為
202 葉結點為止,以此葉子結點為“當前最近點”
203 */
204 unsigned k = tree->root.size();//計算出數據的維數
205 unsigned d = 0;//維度初始化為0,即從第1維開始
206 KdTree* currentTree = tree;
207 vector<double> currentNearest = currentTree->root;
208 while(!currentTree->isLeaf())
209 {
210 unsigned index = d % k;//計算當前維
211 if (currentTree->rightChild->isEmpty() || goal[index] < currentNearest[index])
212 {
213 currentTree = currentTree->leftChild;
214 }
215 else
216 {
217 currentTree = currentTree->rightChild;
218 }
219 ++d;
220 }
221 currentNearest = currentTree->root;
222
223 /*第二步:遞歸地向上回退, 在每個結點進行如下操作:
224 (a)如果該結點保存的實例比當前最近點距離目標點更近,則以該例點為“當前最近點”
225 (b)當前最近點一定存在於某結點一個子結點對應的區域,檢查該子結點的父結點的另
226 一子結點對應區域是否有更近的點(即檢查另一子結點對應的區域是否與以目標點為球
227 心、以目標點與“當前最近點”間的距離為半徑的球體相交);如果相交,可能在另一
228 個子結點對應的區域內存在距目標點更近的點,移動到另一個子結點,接著遞歸進行最
229 近鄰搜索;如果不相交,向上回退*/
230
231 //當前最近鄰與目標點的距離
232 double currentDistance = measureDistance(goal, currentNearest, 0);
233
234 //如果當前子kd樹的根結點是其父結點的左孩子,則搜索其父結點的右孩子結點所代表
235 //的區域,反之亦反
236 KdTree* searchDistrict;
237 if (currentTree->isLeft())
238 {
239 if (currentTree->parent->rightChild == NULL)
240 searchDistrict = currentTree;
241 else
242 searchDistrict = currentTree->parent->rightChild;
243 }
244 else
245 {
246 searchDistrict = currentTree->parent->leftChild;
247 }
248
249 //如果搜索區域對應的子kd樹的根結點不是整個kd樹的根結點,繼續回退搜索
250 while (searchDistrict->parent != NULL)
251 {
252 //搜索區域與目標點的最近距離
253 double districtDistance = abs(goal[(d+1)%k] - searchDistrict->parent->root[(d+1)%k]);
254
255 //如果“搜索區域與目標點的最近距離”比“當前最近鄰與目標點的距離”短,表明搜索
256 //區域內可能存在距離目標點更近的點
257 if (districtDistance < currentDistance )//&& !searchDistrict->isEmpty()
258 {
259
260 double parentDistance = measureDistance(goal, searchDistrict->parent->root, 0);
261
262 if (parentDistance < currentDistance)
263 {
264 currentDistance = parentDistance;
265 currentTree = searchDistrict->parent;
266 currentNearest = currentTree->root;
267 }
268 if (!searchDistrict->isEmpty())
269 {
270 double rootDistance = measureDistance(goal, searchDistrict->root, 0);
271 if (rootDistance < currentDistance)
272 {
273 currentDistance = rootDistance;
274 currentTree = searchDistrict;
275 currentNearest = currentTree->root;
276 }
277 }
278 if (searchDistrict->leftChild != NULL)
279 {
280 double leftDistance = measureDistance(goal, searchDistrict->leftChild->root, 0);
281 if (leftDistance < currentDistance)
282 {
283 currentDistance = leftDistance;
284 currentTree = searchDistrict;
285 currentNearest = currentTree->root;
286 }
287 }
288 if (searchDistrict->rightChild != NULL)
289 {
290 double rightDistance = measureDistance(goal, searchDistrict->rightChild->root, 0);
291 if (rightDistance < currentDistance)
292 {
293 currentDistance = rightDistance;
294 currentTree = searchDistrict;
295 currentNearest = currentTree->root;
296 }
297 }
298 }//end if
299
300 if (searchDistrict->parent->parent != NULL)
301 {
302 searchDistrict = searchDistrict->parent->isLeft()?
303 searchDistrict->parent->parent->rightChild:
304 searchDistrict->parent->parent->leftChild;
305 }
306 else
307 {
308 searchDistrict = searchDistrict->parent;
309 }
310 ++d;
311 }//end while
312 return currentNearest;
313 }
314
315 int main()
316 {
317 vector<vector<double> > train(6, vector<double>(2, 0));
318 for (unsigned i = 0; i < 6; ++i)
319 for (unsigned j = 0; j < 2; ++j)
320 train[i][j] = data[i][j];
321
322 KdTree* kdTree = new KdTree;
323 buildKdTree(kdTree, train, 0);
324
325 printKdTree(kdTree, 0);
326
327 vector<double> goal;
328 goal.push_back(3);
329 goal.push_back(4.5);
330 vector<double> nearestNeighbor = searchNearestNeighbor(goal, kdTree);
331 vector<double>::iterator beg = nearestNeighbor.begin();
332 cout << "The nearest neighbor is: ";
333 while(beg != nearestNeighbor.end()) cout << *beg++ << ",";
334 cout << endl;
335 return 0;
336 }