×

决策树算法原理与Python实现

hqy hqy 发表于2025-07-14 01:58:33 浏览1 评论0百度已收录

抢沙发发表评论

决策树作为机器学习中基础且应用广泛的分类算法,以其直观的树形结构和强大的解释性备受青睐。本文将从理论基础出发,结合Python代码实现,深入剖析决策树的核心原理与实战应用。

决策树基础概念

决策树是一种树形结构的分类模型,由内部节点和叶节点组成:

内部节点表示特征判断叶节点表示分类结果

其工作原理类似"二十个问题"游戏,通过层层特征筛选缩小分类范围,最终得到实例的类别归属。在邮件分类场景中,决策树会先判断发件域名,再根据内容关键词进一步分类,展现出清晰的层级决策逻辑。

决策树核心数学原理

信息熵与信息增益

决策树构建算法

决策树通过递归方式构建,核心逻辑如下:

def createBranch(): 决策树递归构建逻辑 if 所有数据分类标签相同: return 类标签 else: 选择信息增益最大的特征 划分数据集 为每个划分子集递归创建分支 return 分支节点

决策树核心代码实现

基础函数实现

计算香农熵

def calcShannonEnt(dataSet): """计算数据集的香农熵""" numEntries = len(dataSet) labelCounts = {} # 统计各类标签出现次数 for featVec in dataSet: currentLabel = featVec[-1] if currentLabel not in labelCounts: labelCounts[currentLabel] = 0 labelCounts[currentLabel] += 1 # 计算香农熵 shannonEnt = 0.0 for key in labelCounts: prob = float(labelCounts[key]) / numEntries shannonEnt -= prob * math.log(prob, 2) return shannonEnt

数据集划分函数

def splitDataSet(dataSet, index, value): """根据特征和值划分数据集""" retDataSet = [] for featVec in dataSet: if featVec[index] == value: # 提取划分后的数据(排除当前特征列) reducedFeatVec = featVec[:index] reducedFeatVec.extend(featVec[index+1:]) retDataSet.append(reducedFeatVec) return retDataSet

最优特征选择

def chooseBestFeatureToSplit(dataSet): """选择信息增益最大的特征""" numFeatures = len(dataSet[0]) - 1 # 特征数量 baseEntropy = calcShannonEnt(dataSet) bestInfoGain, bestFeature = 0.0, -1 # 遍历所有特征计算信息增益 for i in range(numFeatures): featList = [example[i] for example in dataSet] uniqueVals = set(featList) newEntropy = 0.0 # 计算当前特征划分后的熵 for value in uniqueVals: subDataSet = splitDataSet(dataSet, i, value) prob = len(subDataSet) / float(len(dataSet)) newEntropy += prob * calcShannonEnt(subDataSet) # 计算信息增益并更新最优特征 infoGain = baseEntropy - newEntropy if infoGain > bestInfoGain: bestInfoGain = infoGain bestFeature = i return bestFeature

决策树构建与分类

递归创建决策树

def createTree(dataSet, labels): """递归构建决策树""" classList = [example[-1] for example in dataSet] # 停止条件1:所有标签相同 if classList.count(classList[0]) == len(classList): return classList[0] # 停止条件2:所有特征使用完毕 if len(dataSet[0]) == 1: return majorityCnt(classList) # 选择最优特征并构建树 bestFeat = chooseBestFeatureToSplit(dataSet) bestFeatLabel = labels[bestFeat] myTree = {bestFeatLabel: {}} # 递归处理每个划分子集 del(labels[bestFeat]) featValues = [example[bestFeat] for example in dataSet] uniqueVals = set(featValues) for value in uniqueVals: subLabels = labels[:] myTree[bestFeatLabel][value] = createTree( splitDataSet(dataSet, bestFeat, value), subLabels) return myTree

使用决策树进行分类

def classify(inputTree, featLabels, testVec): """使用决策树对测试数据分类""" firstStr = list(inputTree.keys())[0] secondDict = inputTree[firstStr] featIndex = featLabels.index(firstStr) # 递归遍历决策树 key = testVec[featIndex] valueOfFeat = secondDict[key] if isinstance(valueOfFeat, dict): classLabel = classify(valueOfFeat, featLabels, testVec) else: classLabel = valueOfFeat return classLabel

决策树项目实战

案例1:鱼类与非鱼类分类

数据准备

def createDataSet(): """创建鱼类分类数据集""" dataSet = [[1, 1, yes], [1, 1, yes], [1, 0, no], [0, 1, no], [0, 1, no]] labels = [no surfacing, flippers] return dataSet, labels

构建与测试流程

收集数据:使用createDataSet()生成样本准备数据:由于数据已离散化,无需额外处理训练模型:调用createTree构建决策树测试模型:使用classify对新数据分类

案例2:隐形眼镜类型预测

数据解析

# 解析文本文件数据 lenses = [inst.strip().split(\t) for inst in fr.readlines()] lensesLabels = [age, prescript, astigmatic, tearRate]

决策树存储与加载

def storeTree(inputTree, filename): """存储决策树到文件""" import pickle fw = open(filename, wb) pickle.dump(inputTree, fw) fw.close() def grabTree(filename): """从文件加载决策树""" import pickle fr = open(filename, rb) return pickle.load(fr)

决策树算法特点

优点

计算复杂度低,适合大规模数据分类结果直观易懂,具有天然可解释性支持处理缺失值和不相关特征

缺点

容易产生过拟合,需通过剪枝优化对高度非线性数据分类效果有限

适用场景

标称型数据(分类数据)和离散化的数值型数据需要解释分类逻辑的场景数据预处理要求较低的快速建模场景

通过以上代码实现和项目案例,我们可以清晰理解决策树从理论到实践的完整流程。作为基础分类算法,决策树不仅是机器学习入门的重要内容,也是理解集成学习算法(如随机森林)的基础,在数据挖掘领域具有不可替代的地位。