1、用到的库:matplotlib、numpy 

import matplotlib.pyplot as plt
import numpy as np

    使用平台:PyCharm 社区版 + Python 3.9.6 解释器

2、具体程序:

     数据集:

     

 (1)加载数据

def loadDataSet(fileName):
    """Load a tab-separated numeric text file into a list of float rows.

    Args:
        fileName: Path of the text file to read. Each non-blank line is
            split on tab characters and every field is converted with
            ``float``.

    Returns:
        A list with one ``list[float]`` per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a field cannot be converted to ``float``.
    """
    dataMat = []
    # The context manager closes the file even if a conversion fails.
    with open(fileName) as fr:
        for line in fr:
            stripped = line.strip()
            # A blank (e.g. trailing) line would otherwise hit float('').
            if not stripped:
                continue
            dataMat.append([float(field) for field in stripped.split('\t')])
    return dataMat

(2)绘制数据集

def plotDataSet(filename):
    """Show a scatter plot of the first two columns of a data file.

    Args:
        filename: Path handed to ``loadDataSet``; column 0 of each row is
            plotted on the x axis and column 1 on the y axis.

    Returns:
        None. The figure is displayed through ``plt.show()``.
    """
    samples = loadDataSet(filename)
    # Separate the two plotted columns.
    xcord = [row[0] for row in samples]
    ycord = [row[1] for row in samples]
    figure = plt.figure()
    axes = figure.add_subplot(111)
    # Small, semi-transparent blue markers for the sample points.
    axes.scatter(xcord, ycord, s=20, c='blue', alpha=.5)
    plt.title('DataSet')
    plt.xlabel('X')
    plt.show()

(3)由特征切分数据集合

def binSplitDataSet(dataSet, feature, value):
    """Split a data set into two parts by thresholding one feature column.

    Args:
        dataSet: 2-D ``np.matrix`` or ``np.ndarray`` with one sample per row.
        feature: Column index of the feature to split on.
        value: Threshold compared against that column.

    Returns:
        A tuple ``(mat0, mat1)``: ``mat0`` holds the rows whose feature is
        strictly greater than ``value``, ``mat1`` the rows whose feature is
        less than or equal to it. Either part may have zero rows.
    """
    # NOTE(review): "greater-than rows first" is the usual regression-tree
    # convention; the tree builder and pruner both rely on this one function,
    # so they stay consistent with each other either way.
    # Flatten to 1-D so the same code handles np.matrix and np.ndarray columns.
    column = np.asarray(dataSet[:, feature]).ravel()
    # np.nonzero on a boolean array yields the indices of its True entries.
    upperRows = np.nonzero(column > value)[0]
    lowerRows = np.nonzero(column <= value)[0]
    mat0 = dataSet[upperRows, :]
    mat1 = dataSet[lowerRows, :]
    return mat0, mat1
注:np.nonzero 是 NumPy 中用于返回数组中非零元素位置(即数组索引)的函数;作用于布尔数组时,返回的是值为 True 的元素的索引。

 (4)生成叶节点

def regLeaf(dataSet):
    """Build a leaf value: the mean of the last column of ``dataSet``.

    Args:
        dataSet: 2-D array or matrix; the last column is averaged.

    Returns:
        The mean of the last column, as computed by ``np.mean``.
    """
    targets = dataSet[:, -1]
    return np.mean(targets)
注:np.mean() 函数定义:numpy.mean(a, axis=None, dtype=None, out=None, keepdims=...),除 a 外均为可选参数。功能:沿指定轴求均值;未指定 axis 时对全部元素求均值。

(5)误差估计

def regErr(dataSet):
    """Return the total squared error of the last column of ``dataSet``.

    The variance of the last column is multiplied by the number of rows,
    which gives the sum of squared deviations from the column mean.

    Args:
        dataSet: 2-D array or matrix; the last column is measured.

    Returns:
        ``np.var`` of the last column times the row count.
    """
    rowCount = np.shape(dataSet)[0]
    targets = dataSet[:, -1]
    return np.var(targets) * rowCount
注:np.var 计算数组元素沿指定轴(未指定时对全部元素)的方差;这里将方差乘以样本个数,得到目标值的总平方误差。

(6)找到数据最佳二元切分方式函数

def chooseBestSplit(dataSet, leafType = regLeaf, errType = regErr, ops = (1,4)):
    """Find the best binary split of ``dataSet``, or decide it should be a leaf.

    Args:
        dataSet: 2-D ``np.matrix`` or ``np.ndarray``; every column except the
            last is a candidate feature, the last column is the target.
        leafType: Callable that turns a data set into a leaf value.
        errType: Callable that returns the error measure of a data set.
        ops: Sequence where ``ops[0]`` (tolS) is the minimum error reduction a
            split must achieve and ``ops[1]`` (tolN) is the minimum number of
            samples allowed on each side of a split.

    Returns:
        ``(featureIndex, splitValue)`` for the best split, or
        ``(None, leafType(dataSet))`` when no split should be made.
    """
    tolS, tolN = ops[0], ops[1]
    # All target values identical: nothing to gain from splitting.
    # The column is flattened so matrix and ndarray inputs behave the same.
    targets = np.asarray(dataSet[:, -1]).ravel().tolist()
    if len(set(targets)) == 1:
        return None, leafType(dataSet)
    n = np.shape(dataSet)[1]
    # Error of the unsplit data set, used to judge the improvement.
    S = errType(dataSet)
    bestS = float('inf'); bestIndex = 0; bestValue = 0
    # Try every distinct value of every feature column as a threshold.
    for featIndex in range(n - 1):
        candidates = set(np.asarray(dataSet[:, featIndex]).ravel().tolist())
        for splitVal in candidates:
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            # Skip splits that leave too few samples on either side.
            if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # The improvement is too small (bestS stays inf if no split was admissible).
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    # Defensive re-check of the chosen split's partition sizes.
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex, bestValue

 (7)构建树

def createTree(dataSet, leafType = regLeaf, errType = regErr, ops = (1, 4)):
    """Recursively build a regression tree from ``dataSet``.

    Args:
        dataSet: 2-D data passed through to ``chooseBestSplit`` and
            ``binSplitDataSet``.
        leafType: Callable that produces a leaf value from a data set.
        errType: Callable that produces an error measure for a data set.
        ops: Stopping parameters forwarded unchanged to ``chooseBestSplit``.

    Returns:
        A leaf value when ``chooseBestSplit`` reports no split, otherwise a
        dict with keys ``'spInd'`` (split feature index), ``'spVal'`` (split
        value), ``'left'`` and ``'right'`` (subtrees or leaf values).
    """
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    # No usable split: `val` is already the leaf value. Identity comparison
    # is the correct check against the None sentinel.
    if feat is None:
        return val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    return {
        'spInd': feat,
        'spVal': val,
        'left': createTree(lSet, leafType, errType, ops),
        'right': createTree(rSet, leafType, errType, ops),
    }

 (8)塌陷处理:返回平均值

def isTree(obj):
    """Return True when ``obj`` is an internal tree node (a dict), not a leaf."""
    return isinstance(obj, dict)


def getMean(tree):
    """Collapse a subtree into one value by averaging its children.

    Each child that is itself a subtree is first collapsed recursively and
    written back into ``tree``, so the passed dict is modified in place.

    Args:
        tree: Node dict with ``'left'`` and ``'right'`` entries, each either
            a nested node dict or a numeric leaf value.

    Returns:
        The average of the (collapsed) left and right values.
    """
    if isTree(tree['right']):
        tree['right'] = getMean(tree['right'])
    if isTree(tree['left']):
        tree['left'] = getMean(tree['left'])
    return (tree['left'] + tree['right']) / 2.0

(9)后剪枝

def prune(tree, testData):
    """Post-prune a regression tree using held-out test data.

    Subtrees are pruned first (and written back into ``tree``); when both
    children of a node are leaves, they are merged into their average if
    that gives a smaller squared error on ``testData`` than keeping them.

    Args:
        tree: Node dict with ``'spInd'``, ``'spVal'``, ``'left'`` and
            ``'right'`` entries.
        testData: 2-D test data; its last column is compared with the leaf
            values.

    Returns:
        The (possibly modified) ``tree`` dict, or a single numeric value when
        the node was collapsed or merged.
    """
    # No test samples reach this node: collapse the whole subtree.
    if np.shape(testData)[0] == 0:
        return getMean(tree)
    # Route the test samples to the children only if there is a subtree to visit.
    if isTree(tree['right']) or isTree(tree['left']):
        leftData, rightData = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']):
        tree['left'] = prune(tree['left'], leftData)
    if isTree(tree['right']):
        tree['right'] = prune(tree['right'], rightData)
    # A child is still a subtree after pruning: nothing to merge at this node.
    if isTree(tree['left']) or isTree(tree['right']):
        return tree
    leftData, rightData = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    # Squared error when the two leaves are kept separate.
    unmergedError = (
        np.sum(np.power(leftData[:, -1] - tree['left'], 2))
        + np.sum(np.power(rightData[:, -1] - tree['right'], 2))
    )
    # Squared error when both leaves are replaced by their average.
    mergedValue = (tree['left'] + tree['right']) / 2.0
    mergedError = np.sum(np.power(testData[:, -1] - mergedValue, 2))
    return mergedValue if mergedError < unmergedError else tree

 (10)主函数调用

if __name__ == '__main__':
    # The tree must exist before it can be pruned; the original script used
    # `tree` without ever building it.
    # NOTE(review): training file assumed to be 'ex2.txt', the counterpart of
    # 'ex2test.txt' -- confirm against the data set actually used.
    train_filename = 'ex2.txt'
    # np.asmatrix replaces np.mat, which was removed in NumPy 2.0.
    train_Mat = np.asmatrix(loadDataSet(train_filename))
    tree = createTree(train_Mat)
    print('\n剪枝后:')
    test_filename = 'ex2test.txt'
    test_Data = loadDataSet(test_filename)
    test_Mat = np.asmatrix(test_Data)
    print(prune(tree, test_Mat))

参考:https://download.csdn.net/download/weixin_42709150/11146870

 

 

 

 

 

 

Logo

瓜分20万奖金 获得内推名额 丰厚实物奖励 易参与易上手

更多推荐