这一章主要是介绍决策树的算法思想并最后用决策树来预测隐形眼睛的类型,在python中一般使用字典数据类型来保存决策树,并借助matplotlib注解工具annotations来可视化决策树。
目前的决策树构建算法有:
- ID3
- C4.5
- CART
本章讲解的是ID3算法,其实这几个算法的不同主要是在选取根节点上的衡量标准不一样,ID3采用的是信息增益来衡量,即若某个节点的信息增益最大就把它作为根节点。C4.5采用信息增益率来衡量,信息增益率最大的就把它作为根节点。CART采用基尼指数来衡量,基尼指数是指样本被误分类的概率。
信息增益是指划分数据集前后信息发生的变化。其计算公式为:
信息的定义是:
所有类别所有可能值包含的信息是:
信息增益为:
这里的信息增益写的可能不是特别准确,有疑问的可以看这里
下面是具体的代码,我记得由于Python的版本不同代码有点改动,但具体改了那里我想不起来了,下面贴的肯定是能运行的。
先是决策树的构造:tree.py
# -*- coding: UTF-8 -*-
from math import log
import operator
import pickle
#用于计算香农熵
def clacShannonEnt(dataSet):
numEntries = len(dataSet)
labelCounts = {}
for featVec in dataSet:
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key])/numEntries
shannonEnt -= prob * log(prob , 2)
return shannonEnt
def creatDataSet():
dataSet = [[1,1,'yes'],[1,1,'yes'],[0,1,'no'],[0,1,'no'],[1,0,'no']]
labels = ['man','woman']
return dataSet,labels
#返回的是满足featVec[axis]=value的数据集合
def splotDataSet(dataSet, axis, value):
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
#选择信息增益最大的作为根节点
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1
baseEntropy = clacShannonEnt(dataSet)
bestInfoGain = 0.0
bestFeature = -1
for i in range(numFeatures):
featList = [example[i] for example in dataSet]
uniqueVals = set(featList)
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splotDataSet(dataSet, i, value)
prob = len(subDataSet)/float(len(dataSet))
newEntropy += prob* clacShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
if infoGain > bestInfoGain:
bestInfoGain = infoGain
bestFeature = i
return bestFeature
def majorityCnt(classList):
classCount = {}
for vote in classList:
if vote not in classCount.keys():
classCount[vote] = 0
else:
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(),key = operator.itemgetter(1),reverse=True)
return sortedClassCount[0][0]
def creatTree(dataSet,labels):
classList = [example[-1] for example in dataSet]
if classList.count(classList[0]) == len(classList):
#程序的第一个出口:数据只有一个类别
return classList[0]
if len(dataSet[0]) == 1:
#程序的第二个出口:只有一个特征,只能返回类别数最多的类作为该组数据的分类
return majorityCnt(classList)
#选一个信息增益最大的特征作为根节点并找到其对应的类标签,将该类标签加入树中
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel:{}}
del(labels[bestFeat])
#找出该类标签在数据中的取值并去重
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels[:]
subdataSet = splotDataSet(dataSet,bestFeat,value)
myTree[bestFeatLabel][value] = creatTree(subdataSet,subLabels)
return myTree
#测试分类结果
def classify(inputtree,featlabels,testvec):
#firststr就是根节点的值,也就是dict中的第一个Key
firststr = list(inputtree.keys())[0]
seconddict = inputtree[firststr]
#返回的是firststr在featlabels中的索引值
featindex = featlabels.index(firststr)
for key in seconddict.keys():
if testvec[featindex] == key:
if type(seconddict[key]).__name__ == 'dict':
#如果子节点还是树结构的话就递归调用继续分类
classlabel = classify(seconddict[key], featlabels, testvec)
else:
classlabel = seconddict[key]
return classlabel
#存储决策树
def storetree(inputtree,filename):
fw = open(filename,'wb')
#pickle.dump(obj, file, [,protocol]),把obj保存在file中 ,'wb'写入二进制文件
pickle.dump(inputtree,fw)
fw.close()
def grabtree(filename):
fr = open(filename,'rb')
#pickle.load()读取数据的方式,‘rb’读取二进制文件
return pickle.load(fr)
下面是决策树的可视化:treeplotter.py
# -*- coding: utf-8 -*-
import matplotlib.pyplot as plt
decisionNode = dict(boxstyle='sawtooth',fc='0.8')
leafNode = dict(boxstyle='round4',fc="0.8")
arrow_args = dict(arrowstyle="<-")
#annotate函数是注解函数,对xy点进行注解
#关于annotate函数的更多细节https://blog.csdn.net/leaf_zizi/article/details/82886755
#https://blog.csdn.net/wizardforcel/article/details/54782628
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords = "axes fraction",
xytext=centerPt,textcoords='axes fraction',
va = 'center',ha = 'center', bbox = nodeType,
arrowprops = arrow_args)
def plotMidText(cntrPt,parentPt,txtString):
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid,yMid,txtString)
def plotTree(mytree,parentPt,nodeTxt):
numLeafs = getNumLeaf(mytree)
numdepth = getTreeDepth(mytree)
firstStr = list(mytree.keys())[0]
cntrPt = (plotTree.xoff + (1.0 +float(numLeafs))/2.0/plotTree.totalW,plotTree.yoff)
plotMidText(cntrPt,parentPt,nodeTxt)
#先绘制根节点
plotNode(firstStr,cntrPt,parentPt,decisionNode)
secondDict = mytree[firstStr]
plotTree.yoff = plotTree.yoff -1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
plotTree(secondDict[key],cntrPt,str(key))
else:
plotTree.xoff = plotTree.xoff + 1.0/plotTree.totalW
plotNode(secondDict[key],(plotTree.xoff,plotTree.yoff),cntrPt,leafNode)
plotMidText((plotTree.xoff, plotTree.yoff), cntrPt, str(key))
plotTree.yoff = plotTree.yoff + 1.0/plotTree.totalD
def createPlot(inTree):
flg = plt.figure(1,facecolor="white")
flg.clf()
axprops = dict(xticks=[], yticks=[])#显示图像时不显示坐标
createPlot.ax1 = plt.subplot(111,frameon = False, **axprops)
plotTree.totalW = float(getNumLeaf(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xoff,plotTree.yoff = -0.5/plotTree.totalW,1.0
plotTree(inTree,(0.5,1.0),'')
plt.show()
#看树有多少叶子节点
def getNumLeaf(myTree):
numLeafs = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
#测试该节点的类型是否为dict,如果是的话则该节点下面还有叶子节点,否则该节点就是叶子节点
if type(secondDict[key]).__name__=='dict':
numLeafs += getNumLeaf(secondDict[key])
else:
numLeafs += 1
return numLeafs
#看树的深度
def getTreeDepth(myTree):
maxDepth = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth:
maxDepth = thisDepth
return maxDepth
def retrieveTree(i):
listOfTrees =[{'man':{0: 'no', 1: {'woman':{0:'no', 1:'yes'}}}},
{'man':{0: 'no', 1: {'woman':{0:{'head':{0:'no',1:'yes'}},1: 'no'}}}}]
return listOfTrees[i]