程序師世界是廣大編程愛好者互助、分享、學習的平台,程序師世界有你更精彩!
首頁
編程語言
C語言|JAVA編程
Python編程
網頁編程
ASP編程|PHP編程
JSP編程
數據庫知識
MYSQL數據庫|SqlServer數據庫
Oracle數據庫|DB2數據庫
 程式師世界 >> 編程語言 >> C語言 >> C++ >> 關於C++ >> k近鄰法的C++實現:kd樹

k近鄰法的C++實現:kd樹

編輯:關於C++
1.k近鄰算法的思想   給定一個訓練集,對於新的輸入實例,在訓練集中找到與該實例最近的k個實例,這k個實例中的多數屬於某個類,就把該輸入實例分為這個類。   因為要找到最近的k個實例,所以計算輸入實例與訓練集中實例之間的距離是關鍵!   k近鄰算法最簡單的方法是線性掃描,這時要計算輸入實例與每一個訓練實例的距離,當訓練集很大時,非常耗時,這種方法不可行,為了提高k近鄰的搜索效率,常常考慮使用特殊的存儲結構存儲訓練數據,以減少計算距離的次數,具體方法很多,這裡介紹實現經典的kd樹方法。   2.構造kd樹   kd樹是一種對k維空間中的實例點進行存儲以便對其進行快速檢索的樹形數據結構,kd樹是二叉樹。   下面舉例說明:   給定一個二維空間的數據集: T = {(2,3), (5,4), (9,6), (4,7), (8,1), (7,2)},構造一個平衡kd樹。   根結點對應包含數據集T的矩形選擇x(1) 軸,6個數據點的x(1) 坐標的中位數是7,以超平面x(1) = 7將空間分為左右兩個子矩形(子結點) 左矩形以x(2) = 4為中位數分為兩個子矩形 右矩形以x(2) = 6 分為兩個子矩形 如此遞歸,直到兩個子區域沒有實例存在時停止           3.利用kd樹搜索最近鄰   輸入:已構造的kd樹;目標點x;   輸出:x的最近鄰   在kd樹中找出包含目標點x的葉結點:從根結點出發,遞歸的向下訪問kd樹,若目標點x的當前維的坐標小於切分點的坐標,則移動到左子結點,否則移動到右子結點,直到子結點為葉結點為止。 以此葉結點為“當前最近點” 遞歸地向上回退,在每個結點進行以下操作:(a)如果該結點保存的實例點比當前最近點距離目標點更近,則以該實例點為“當前最近點”; (b)當前最近點一定存在於某結點一個子結點對應的區域,檢查該子結點的父結點的另 一子結點對應區域是否有更近的點(即檢查另一子結點對應的區域是否與以目標點為球 心、以目標點與“當前最近點”間的距離為半徑的球體相交);如果相交,可能在另一 個子結點對應的區域內存在距目標點更近的點,移動到另一個子結點,接著遞歸進行最 近鄰搜索;如果不相交,向上回退   當回退到根結點時,搜索結束,最後的“當前最近點”即為x的最近鄰點。 4.C++實現        1 #include <iostream>   2 #include <vector>   3 #include <algorithm>   4 #include <string>   5 #include <cmath>   6 using namespace std;   7    8    9   10   11 struct KdTree{  12     vector<double> root;  13     KdTree* parent;  14     KdTree* leftChild;  15     KdTree* rightChild;  16     //默認構造函數  17     KdTree(){parent = leftChild = rightChild = NULL;}  18     //判斷kd樹是否為空  19     bool isEmpty()  20     {  21         return root.empty();  22     }  23     //判斷kd樹是否只是一個葉子結點  24     bool isLeaf()  25     {  26         return (!root.empty()) &&   27             rightChild == NULL && leftChild == NULL;  28     }  29     //判斷是否是樹的根結點  30     bool isRoot()  31     {  32         return (!isEmpty()) && parent == NULL;  33     }  34     //判斷該子kd樹的根結點是否是其父kd樹的左結點  35     bool isLeft()  36     {  37         return parent->leftChild->root == root;  38     }  39     //判斷該子kd樹的根結點是否是其父kd樹的右結點  40     bool isRight()  41     {  42         return parent->rightChild->root == root;  43     }  44 };  45   46 int data[6][2] = {{2,3},{5,4},{9,6},{4,7},{8,1},{7,2}};  47   48 template<typename T>  49 vector<vector<T> > Transpose(vector<vector<T> > Matrix)  50 {  51     unsigned row = Matrix.size();  52     unsigned col = Matrix[0].size();  53     vector<vector<T> > Trans(col,vector<T>(row,0));  54     for (unsigned i = 0; i < col; ++i)  55     {  56         for (unsigned j = 0; j < row; ++j)  57         {  58             Trans[i][j] = Matrix[j][i];  59         }  60     }  61     return Trans;  62 }  63   64 template <typename T>  65 T findMiddleValue(vector<T> vec)  66 {  67     sort(vec.begin(),vec.end());  68     auto pos = vec.size() / 2;  69     return vec[pos];  70 }  71   72   73 //構建kd樹  74 void buildKdTree(KdTree* tree, vector<vector<double> > data, unsigned depth)  75 {  76   77     //樣本的數量  78     unsigned samplesNum = data.size();  79     //終止條件  80     if (samplesNum == 0)  81     {  82         return;  83     }  84     if (samplesNum == 1)  85     {  86         tree->root = data[0];  87         return;  88     }  89     //樣本的維度  90     unsigned k = data[0].size();  91     vector<vector<double> > transData = Transpose(data);  92     //選擇切分屬性  93     unsigned splitAttribute = depth % k;  94     vector<double> splitAttributeValues = transData[splitAttribute];  95     //選擇切分值  96     double splitValue = findMiddleValue(splitAttributeValues);  97     //cout << "splitValue" << splitValue  << endl;  98   99     // 根據選定的切分屬性和切分值,將數據集分為兩個子集 100     vector<vector<double> > subset1; 101     vector<vector<double> > subset2; 102     for (unsigned i = 0; i < samplesNum; ++i) 103     { 104         if (splitAttributeValues[i] == splitValue && tree->root.empty()) 105             tree->root = data[i]; 106         else 107         { 108             if (splitAttributeValues[i] < splitValue) 109                 subset1.push_back(data[i]); 110             else 111                 subset2.push_back(data[i]); 112         } 113     } 114  115     //子集遞歸調用buildKdTree函數 116  117     tree->leftChild = new KdTree; 118     tree->leftChild->parent = tree; 119     tree->rightChild = new KdTree; 120     tree->rightChild->parent = tree; 121     buildKdTree(tree->leftChild, subset1, depth + 1); 122     buildKdTree(tree->rightChild, subset2, depth + 1); 123 } 124  125 //逐層打印kd樹 126 void printKdTree(KdTree *tree, unsigned depth) 127 { 128     for (unsigned i = 0; i < depth; ++i) 129         cout << "\t"; 130              131     for (vector<double>::size_type j = 0; j < tree->root.size(); ++j) 132         cout << tree->root[j] << ","; 133     cout << endl; 134     if (tree->leftChild == NULL && tree->rightChild == NULL )//葉子節點 135         return; 136     else //非葉子節點 137     { 138         if (tree->leftChild != NULL) 139         { 140             for (unsigned i = 0; i < depth + 1; ++i) 141                 cout << "\t"; 142             cout << " left:"; 143             printKdTree(tree->leftChild, depth + 1); 144         } 145              146         cout << endl; 147         if (tree->rightChild != NULL) 148         { 149             for (unsigned i = 0; i < depth + 1; ++i) 150                 cout << "\t"; 151             cout << "right:"; 152             printKdTree(tree->rightChild, depth + 1); 153         } 154         cout << endl; 155     } 156 } 157  158  159 //計算空間中兩個點的距離 160 double measureDistance(vector<double> point1, vector<double> point2, unsigned method) 161 { 162     if (point1.size() != point2.size()) 163     { 164         cerr << "Dimensions don't match!!" ; 165         exit(1); 166     } 167     switch (method) 168     { 169         case 0://歐氏距離 170             { 171                 double res = 0; 172                 for (vector<double>::size_type i = 0; i < point1.size(); ++i) 173                 { 174                     res += pow((point1[i] - point2[i]), 2); 175                 } 176                 return sqrt(res); 177             } 178         case 1://曼哈頓距離 179             { 180                 double res = 0; 181                 for (vector<double>::size_type i = 0; i < point1.size(); ++i) 182                 { 183                     res += abs(point1[i] - point2[i]); 184                 } 185                 return res; 186             } 187         default: 188             { 189                 cerr << "Invalid method!!" << endl; 190                 return -1; 191             } 192     } 193 } 194 //在kd樹tree中搜索目標點goal的最近鄰 195 //輸入:目標點;已構造的kd樹 196 //輸出:目標點的最近鄰 197 vector<double> searchNearestNeighbor(vector<double> goal, KdTree *tree) 198 { 199     /*第一步:在kd樹中找出包含目標點的葉子結點:從根結點出發, 200     遞歸的向下訪問kd樹,若目標點的當前維的坐標小於切分點的 201     坐標,則移動到左子結點,否則移動到右子結點,直到子結點為 202     葉結點為止,以此葉子結點為“當前最近點” 203     */ 204     unsigned k = tree->root.size();//計算出數據的維數 205     unsigned d = 0;//維度初始化為0,即從第1維開始 206     KdTree* currentTree = tree; 207     vector<double> currentNearest = currentTree->root; 208     while(!currentTree->isLeaf()) 209     { 210         unsigned index = d % k;//計算當前維 211         if (currentTree->rightChild->isEmpty() || goal[index] < currentNearest[index]) 212         { 213             currentTree = currentTree->leftChild; 214         } 215         else 216         { 217             currentTree = currentTree->rightChild; 218         } 219         ++d; 220     } 221     currentNearest = currentTree->root; 222  223     /*第二步:遞歸地向上回退, 在每個結點進行如下操作: 224     (a)如果該結點保存的實例比當前最近點距離目標點更近,則以該例點為“當前最近點” 225     (b)當前最近點一定存在於某結點一個子結點對應的區域,檢查該子結點的父結點的另 226     一子結點對應區域是否有更近的點(即檢查另一子結點對應的區域是否與以目標點為球 227     心、以目標點與“當前最近點”間的距離為半徑的球體相交);如果相交,可能在另一 228     個子結點對應的區域內存在距目標點更近的點,移動到另一個子結點,接著遞歸進行最 229     近鄰搜索;如果不相交,向上回退*/ 230  231     //當前最近鄰與目標點的距離 232     double currentDistance = measureDistance(goal, currentNearest, 0); 233  234     //如果當前子kd樹的根結點是其父結點的左孩子,則搜索其父結點的右孩子結點所代表 235     //的區域,反之亦反 236     KdTree* searchDistrict; 237     if (currentTree->isLeft()) 238     { 239         if (currentTree->parent->rightChild == NULL) 240             searchDistrict = currentTree; 241         else 242             searchDistrict = currentTree->parent->rightChild; 243     } 244     else 245     { 246         searchDistrict = currentTree->parent->leftChild; 247     } 248  249     //如果搜索區域對應的子kd樹的根結點不是整個kd樹的根結點,繼續回退搜索 250     while (searchDistrict->parent != NULL) 251     { 252         //搜索區域與目標點的最近距離 253         double districtDistance = abs(goal[(d+1)%k] - searchDistrict->parent->root[(d+1)%k]); 254  255         //如果“搜索區域與目標點的最近距離”比“當前最近鄰與目標點的距離”短,表明搜索 256         //區域內可能存在距離目標點更近的點 257         if (districtDistance < currentDistance )//&& !searchDistrict->isEmpty() 258         { 259  260             double parentDistance = measureDistance(goal, searchDistrict->parent->root, 0); 261  262             if (parentDistance < currentDistance) 263             { 264                 currentDistance = parentDistance; 265                 currentTree = searchDistrict->parent; 266                 currentNearest = currentTree->root; 267             } 268             if (!searchDistrict->isEmpty()) 269             { 270                 double rootDistance = measureDistance(goal, searchDistrict->root, 0); 271                 if (rootDistance < currentDistance) 272                 { 273                     currentDistance = rootDistance; 274                     currentTree = searchDistrict; 275                     currentNearest = currentTree->root; 276                 } 277             } 278             if (searchDistrict->leftChild != NULL) 279             { 280                 double leftDistance = measureDistance(goal, searchDistrict->leftChild->root, 0); 281                 if (leftDistance < currentDistance) 282                 { 283                     currentDistance = leftDistance; 284                     currentTree = searchDistrict; 285                     currentNearest = currentTree->root; 286                 } 287             } 288             if (searchDistrict->rightChild != NULL) 289             { 290                 double rightDistance = measureDistance(goal, searchDistrict->rightChild->root, 0); 291                 if (rightDistance < currentDistance) 292                 { 293                     currentDistance = rightDistance; 294                     currentTree = searchDistrict; 295                     currentNearest = currentTree->root; 296                 } 297             } 298         }//end if 299  300         if (searchDistrict->parent->parent != NULL) 301         { 302             searchDistrict = searchDistrict->parent->isLeft()?  303                             searchDistrict->parent->parent->rightChild: 304                             searchDistrict->parent->parent->leftChild; 305         } 306         else 307         { 308             searchDistrict = searchDistrict->parent; 309         } 310         ++d; 311     }//end while 312     return currentNearest; 313 } 314  315 int main() 316 { 317     vector<vector<double> > train(6, vector<double>(2, 0)); 318     for (unsigned i = 0; i < 6; ++i) 319         for (unsigned j = 0; j < 2; ++j) 320             train[i][j] = data[i][j]; 321  322     KdTree* kdTree = new KdTree; 323     buildKdTree(kdTree, train, 0); 324  325     printKdTree(kdTree, 0); 326  327     vector<double> goal; 328     goal.push_back(3); 329     goal.push_back(4.5); 330     vector<double> nearestNeighbor = searchNearestNeighbor(goal, kdTree); 331     vector<double>::iterator beg = nearestNeighbor.begin(); 332     cout << "The nearest neighbor is: "; 333     while(beg != nearestNeighbor.end()) cout << *beg++ << ","; 334     cout << endl; 335     return 0; 336 }
  1. 上一頁:
  2. 下一頁:
Copyright © 程式師世界 All Rights Reserved