14.决策树的最终构建
2024-04-09 16:16:02  阅读数 8100

前面是做了一轮决策,按照信息论的方式,对各特征做了分析,确定了能够带来最大信息增益(注意是熵减)的特征。但仅这一步是不够的,我们需要继续对叶子节点进行同样的操作,直到完成如下的目标:

[if !supportLists]1)[endif]程序遍历完所有划分数据集的属性;

[if !supportLists]2)[endif]每个分支下的所有实例都具有相同的分类;

如果程序已经遍历完所有划分数据集的属性,叶子节点下的实例仍然不具备相同的分类,那就采用多数表决的方法(有点像KNN)来决定该叶子节点的分类。

好,上代码。

def majorityCnt(classList):

    classCount={}

    for vote in classList:

        if vote not in classCount.keys(): classCount[vote] = 0

        classCount[vote] += 1

    sortedClassCount=sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)

    return sortedClassCount[0][0]

如上代码就不逐行展开了,其实就是把一个数组中的标签项数一下数,然后找到哪一个标签出现的次数最多,和KNN中相关的排序方式类似。

我们再来看看,整棵树的遍历:

def createTree(dataSet, labels):

    classList = [example[-1] for example in dataSet]

    if classList.count(classList[0]) == len(classList):

        return classList[0]

    if len(dataSet[0]) == 1:

        return majorityCnt(classList)

    bestFeat = chooseBestFeatureToSplit(dataSet)

    bestFeatLabel = labels[bestFeat]

    myTree = {bestFeatLabel:{}}

    del(labels[bestFeat])

    featValues = [example[bestFeat] for example in dataSet]

    uniqueVals = set(featValues)

    for value in uniqueVals:

        subLabels = labels[:]

        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)

    return myTree

代码一共16行。这段代码比较关键,而且不算太容易看懂,我们来逐行看一下:

def createTree(dataSet, labels):

#定义函数,dataSet实际上是带了标签值的数据集,labels其实是标签值的意义定义,详见前面的数据集定义

    classList = [example[-1] for example in dataSet]

#classList实际上就是标签值的数组,标签值位于数据集的最后一列

    if classList.count(classList[0]) == len(classList):

#这里做了一个判断,count方法的作用就是数出数组中某个元素值的个数,在这里就是对classList[0]做了计数,当它的数量等同于数组的长度时,说明这个数组里没有别的标签了,即已经分到了标签唯一的状态;按决策树叶子节点是否达到不可分的条件2,已经完成

        return classList[0]

#返回classList[0],即当前叶子节点唯一的标签值

    if len(dataSet[0]) == 1:

        return majorityCnt(classList)

#如果dataSet的长度为1,那就等于是叶子节点中特征值只有1个,这个时候就满足了决策树叶子节点是否达到不可分的条件1,程序遍历完所有划分数据集的属性,这个时候我们要对叶子节点中的标签进行统计,通过多数表决的方法确定并返回这一分支的标签值。

    bestFeat = chooseBestFeatureToSplit(dataSet)

    bestFeatLabel = labels[bestFeat]

#如果不是上述的两种情况,即节点还可分,那么就用chooseBestFeatureToSplit方法找到最佳的特征项,并对节点进行分解。

    myTree = {bestFeatLabel:{}}

#建立一个字典myTree

    del(labels[bestFeat])

#删除已选出的特征

    featValues = [example[bestFeat] for example in dataSet]

#根据最佳特征的下标,从dataSet里取出相关特征的值的数组

    uniqueVals = set(featValues)

#去重,得到该数组中可能的值

    for value in uniqueVals:

#遍历所有可能的值

        subLabels = labels[:]

#复制出一个label数组,这个数组已经删掉了之前的最佳特征项

        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)

#根据最优的特征进行分叉,每一个分叉再去进行生成子树的递归操作(这是关键,通过递归遍历生成所有的子节点,并根据那两条要求确定是否终止并直接return)

    return myTree

#返回树(注意这里的返回,可能是一棵子树,因为是通过递归生成了整棵树,只有最开始的调用才是根节点)


好了,至此,这段代码结束。这段代码不算太容易看懂,原理好懂,但是算法理论要想跟代码联系在一起,还是挺复杂的。最终生成的树结构如下:


它是个什么呢?其实就是一开始的最优特征项是“no surfacing”,然后进行分叉,左边由于标签一致所以结束,右边进行再分叉,然后因为特征用完,结束。

一棵有意思的树。