
Decision Trees: The Final Chapter


I've finally worked through the code for building and testing a decision tree, and I got a lot out of it, so here is a wrap-up of the decision-tree material. Put plainly, a decision tree uses an item's known attributes to decide how to classify it. How to split the data was covered in the earlier posts, so beyond a brief recap I won't go into it again. What the earlier posts never showed, though, is how to use the tree we built to classify new data, so after the recap below I give the classification code:
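As a quick recap of the split criterion (standard information gain, stated here in my own notation): the Shannon entropy of a dataset $D$ whose classes occur with proportions $p_k$ is

$$H(D) = -\sum_{k} p_k \log_2 p_k,$$

and a split on feature $A$ taking values $v$ is scored by the gain

$$g(D, A) = H(D) - \sum_{v} \frac{|D_v|}{|D|} H(D_v),$$

where $D_v$ is the subset of $D$ with $A = v$. In the code, calcShannonEnt computes $H$, and chooseBestFeatureToSplit picks the feature with the largest gain.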

def classify(inputTree, featLabels, testVec):
    # Walk the tree until a leaf is reached and return its class label.
    firstStr = list(inputTree.keys())[0]       # feature tested at this node
    secondDict = inputTree[firstStr]           # subtrees keyed by feature value
    featIndex = featLabels.index(firstStr)     # position of that feature in testVec
    classLabel = None                          # stays None if the value is not in the tree
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if isinstance(secondDict[key], dict):   # internal node: recurse
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:                                   # leaf node: done
                classLabel = secondDict[key]
    return classLabel


Heh, this simple test function performs the classification decision for a given record: it just traverses the tree from the root until it reaches a leaf node, as the short sketch below illustrates.
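As a quick illustration, here is how classify walks a tree. The tree literal is hand-written to match what createTree produces for the sample data given further below; it is only a sketch for demonstration:

myTree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
labels = ['no surfacing', 'flippers']
print(classify(myTree, labels, [1, 0]))   # 'no surfacing'=1, then 'flippers'=0 -> prints 'no'
print(classify(myTree, labels, [1, 1]))   # 'no surfacing'=1, then 'flippers'=1 -> prints 'yes'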

[Screenshot of the program's run omitted; the expected output is listed after the full source below.]
Just to be safe, here is the complete source, so that readers who skipped the earlier posts can run it directly, experiment with it, and adapt it into their own code.


import math
import operator

def calcShannonEnt(dataset):
    # Shannon entropy of the class labels (the last column of each row).
    numEntries = len(dataset)
    labelCounts = {}
    for featVec in dataset:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1

    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * math.log(prob, 2)   # accumulate -p * log2(p)
    return shannonEnt
     
def CreateDataSet():
    # Toy dataset: columns are 'no surfacing', 'flippers', class label.
    dataset = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataset, labels
 
def splitDataSet(dataSet, axis, value):
    # Keep the rows whose feature `axis` equals `value`, with that column removed.
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
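# Worked example (assuming the sample data from CreateDataSet above):
# splitDataSet(dataset, 0, 1) -> [[1, 'yes'], [1, 'yes'], [0, 'no']],
# i.e. the rows whose first feature equals 1, with that column removed.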
 
def chooseBestFeatureToSplit(dataSet):
    # Pick the feature whose split yields the largest information gain.
    numberFeatures = len(dataSet[0]) - 1       # last column is the class label
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numberFeatures):
        featList = [example[i] for example in dataSet]
        print(featList)                        # debug output
        uniqueVals = set(featList)
        print(uniqueVals)                      # debug output
        newEntropy = 0.0
        for value in uniqueVals:               # expected entropy after splitting on feature i
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
 
def majorityCnt(classList):
    # Majority vote over class labels, used when no features are left to split on.
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
  
 
def createTree(dataSet, inputlabels):
    # Recursively build the decision tree as nested dicts.
    labels = inputlabels[:]                    # copy so the caller's label list survives
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]                    # all labels identical: leaf node
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)          # no features left: majority vote
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:                   # grow one subtree per feature value
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
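# For the sample data, this builds the nested-dict tree:
# {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}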
 
 
 
def classify(inputTree, featLabels, testVec):
    # Walk the tree until a leaf is reached and return its class label.
    firstStr = list(inputTree.keys())[0]       # feature tested at this node
    secondDict = inputTree[firstStr]           # subtrees keyed by feature value
    featIndex = featLabels.index(firstStr)     # position of that feature in testVec
    classLabel = None                          # stays None if the value is not in the tree
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if isinstance(secondDict[key], dict):   # internal node: recurse
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:                                   # leaf node: done
                classLabel = secondDict[key]
    return classLabel
 
     
         
myDat, labels = CreateDataSet()
print(calcShannonEnt(myDat))             # entropy of the raw dataset

print(splitDataSet(myDat, 1, 1))         # rows where 'flippers' == 1

print(chooseBestFeatureToSplit(myDat))   # index of the best split feature

myTree = createTree(myDat, labels)       # build the tree from the sample data

print(classify(myTree, labels, [1, 0]))  # expect 'no'
print(classify(myTree, labels, [1, 1]))  # expect 'yes'
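
For reference, running this script end-to-end should print something like the following (my expected output for the sample dataset above, ignoring the debug lines printed by chooseBestFeatureToSplit):

0.9709505944546686
[[1, 'yes'], [1, 'yes'], [0, 'no'], [0, 'no']]
0
no
yes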


Heh, and with that, everything about our decision tree has been put into practice. Happy studying and working, everyone!
