机器学习--回归树python代码
1、用到的库:matplotlib、numpy(import matplotlib.pyplot as plt;import numpy as np);使用平台:PyCharm 社区版 + Python 3.9.6 解释器。2、具体程序:数据集;(1)加载数据:def loadDataSet(fileName): dataMat = []; fr = open(fileName); for line in fr.readlines(): …(摘要截断,完整代码见下文)
·
1、用到的库:matplotlib、numpy
import matplotlib.pyplot as plt
import numpy as np
使用平台:PyCharm 社区版 + Python 3.9.6 解释器
2、具体程序:
数据集:
(1)加载数据
def loadDataSet(fileName):
    """Load a tab-separated numeric data file into a list of float rows.

    Args:
        fileName: path of a text file with one sample per line and the
            fields separated by tab characters.

    Returns:
        A list with one ``list[float]`` per non-blank line.

    Raises:
        ValueError: if a non-blank field cannot be parsed as a float.
    """
    dataMat = []
    # `with` guarantees the file is closed even if parsing raises.
    with open(fileName) as fr:
        for line in fr:
            line = line.strip()
            # Skip blank lines (e.g. a trailing empty line at EOF); otherwise
            # float('') would raise ValueError.
            if not line:
                continue
            curLine = line.split('\t')
            dataMat.append(list(map(float, curLine)))
    return dataMat
(2)绘制数据集
def plotDataSet(filename):
    """Show a scatter plot of the data set stored in *filename*.

    The file is read with loadDataSet; column 0 of each row is drawn on the
    x axis and column 1 on the y axis, then the figure is displayed with
    plt.show().
    """
    points = loadDataSet(filename)
    xs = [row[0] for row in points]  # x coordinates: first column
    ys = [row[1] for row in points]  # y coordinates: second column
    fig = plt.figure()
    axes = fig.add_subplot(111)
    axes.scatter(xs, ys, s=20, c='blue', alpha=.5)
    plt.title('DataSet')
    plt.xlabel('X')
    plt.show()
(3)由特征切分数据集合
def binSplitDataSet(dataSet, feature, value):
    """Split *dataSet* into two row subsets on one feature column.

    Args:
        dataSet: 2-D data set (np.matrix or 2-D np.ndarray), one sample per row.
        feature: index of the column to split on.
        value: split threshold.

    Returns:
        ``(mat0, mat1)``: ``mat0`` holds the rows whose *feature* value is
        greater than *value*, ``mat1`` the rows whose value is less than or
        equal to it. Either part may have zero rows.
    """
    # Note: np.nonzero returns the indices (positions) of the non-zero
    # elements of an array; element [0] of its result is the row indices.
    mat0 = dataSet[np.nonzero(dataSet[:, feature] > value)[0], :]
    mat1 = dataSet[np.nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1
(4)生成叶节点
def regLeaf(dataSet):
    """Return the leaf value for *dataSet*: the mean of its last (target) column.

    Note: numpy.mean(a, axis, dtype, out, keepdims) computes the arithmetic mean.
    """
    targets = dataSet[:, -1]
    return np.mean(targets)
(5)误差估计
def regErr(dataSet):
    """Return the total squared error of the target (last) column of *dataSet*.

    Computed as variance(target) * number of rows, i.e. the sum of squared
    deviations from the column mean.

    Note: np.var computes the variance of the array elements, along the given
    axis if one is specified.
    """
    targets = dataSet[:, -1]
    rowCount = np.shape(dataSet)[0]
    return rowCount * np.var(targets)
(6)找到数据最佳二元切分方式函数
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """Find the best binary split (feature index, threshold) of *dataSet*.

    The last column of *dataSet* is the regression target; every other column
    is a candidate split feature. Accepts an np.matrix or a 2-D np.ndarray.

    Args:
        dataSet: 2-D numeric data set with the target in the last column.
        leafType: callable producing a leaf value from a data set.
        errType: callable returning the error of a data set (lower is better).
        ops: ``(tolS, tolN)``: the minimum error decrease a split must achieve,
            and the minimum number of rows allowed on each side of a split.

    Returns:
        ``(bestIndex, bestValue)`` when a worthwhile split exists, otherwise
        ``(None, leafType(dataSet))`` meaning "turn this node into a leaf".
    """
    # tolS: minimum allowed error decrease; tolN: minimum rows per side.
    tolS, tolN = ops[0], ops[1]
    # Flatten the target column to a list of Python floats; works whether the
    # column is an (m, 1) matrix or a 1-D ndarray slice.
    targets = np.asarray(dataSet[:, -1]).ravel().tolist()
    # If all target values are equal there is nothing to gain from splitting.
    if len(set(targets)) == 1:
        return None, leafType(dataSet)
    n = np.shape(dataSet)[1]
    # Error of the current, unsplit data set; a split must beat it by tolS.
    S = errType(dataSet)
    # Best error so far, and the feature index / threshold that achieved it.
    bestS = float('inf')
    bestIndex = 0
    bestValue = 0
    # Every column except the last (target) is a candidate feature.
    for featIndex in range(n - 1):
        featureValues = np.asarray(dataSet[:, featIndex]).ravel().tolist()
        # Try each distinct value of this feature as the threshold.
        for splitVal in set(featureValues):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            # Reject splits leaving fewer than tolN rows on either side.
            if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # Too little improvement. This also covers "no admissible split":
    # bestS is still inf, so S - bestS is -inf.
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    # Defensive re-check of the chosen split's side sizes.
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex, bestValue
(7)构建树
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """Recursively grow a regression tree on *dataSet*.

    Args:
        dataSet: 2-D data set with the target in the last column.
        leafType: callable producing a leaf value from a data set.
        errType: callable returning the error of a data set.
        ops: ``(tolS, tolN)`` stopping thresholds, passed to chooseBestSplit.

    Returns:
        Either a leaf value (whatever chooseBestSplit returned as its second
        element when no split was chosen), or a dict with keys ``'spInd'``
        (split feature index), ``'spVal'`` (split value), ``'left'`` and
        ``'right'``. The last two are subtrees grown on the two halves
        returned by binSplitDataSet.
    """
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    # No split chosen: val is the leaf value for this data set.
    if feat is None:
        return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    # Split into left/right data sets and grow each subtree recursively.
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree
(8)塌陷处理:返回平均值
def getMean(tree):
    """Collapse *tree* into a single value by recursive averaging.

    Working bottom-up, every internal node is replaced by the mean of its two
    children. Subtrees are overwritten in place with their collapsed values,
    so *tree* itself is modified.

    Args:
        tree: a tree dict as built by createTree; ``'left'`` and ``'right'``
            hold either a subtree dict or a numeric leaf value.

    Returns:
        The averaged value of the whole tree.
    """
    # Trees are plain dicts (createTree builds them with {}); anything else
    # is a leaf. Checked inline so this function needs no isTree helper.
    if isinstance(tree['right'], dict):
        tree['right'] = getMean(tree['right'])
    if isinstance(tree['left'], dict):
        tree['left'] = getMean(tree['left'])
    return (tree['left'] + tree['right']) / 2.0
(9)后剪枝
def prune(tree, testData):
    """Post-prune *tree* using held-out *testData*.

    Subtrees are pruned recursively first. Then, if both children of a node
    are leaves, they are merged into their mean whenever that lowers the
    squared error on the test rows reaching the node.

    Args:
        tree: a tree dict as built by createTree (modified in place).
        testData: 2-D test set with the target in the last column.

    Returns:
        The pruned tree (a dict), or a single leaf value if the node was
        merged or collapsed.
    """
    # No test rows reach this node: collapse the whole subtree to its mean.
    if np.shape(testData)[0] == 0:
        return getMean(tree)
    # Trees are plain dicts (createTree builds them with {}); anything else
    # is a leaf.
    if isinstance(tree['right'], dict) or isinstance(tree['left'], dict):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    # Prune the left subtree.
    if isinstance(tree['left'], dict):
        tree['left'] = prune(tree['left'], lSet)
    # Prune the right subtree.
    if isinstance(tree['right'], dict):
        tree['right'] = prune(tree['right'], rSet)
    # Both children are now leaves: consider merging them.
    if not isinstance(tree['left'], dict) and not isinstance(tree['right'], dict):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        # Squared error on the test rows with the two leaves kept separate.
        errorNoMerge = (np.sum(np.power(lSet[:, -1] - tree['left'], 2))
                        + np.sum(np.power(rSet[:, -1] - tree['right'], 2)))
        # Merged leaf value, and its squared error on the same rows.
        treeMean = (tree['left'] + tree['right']) / 2.0
        errorMerge = np.sum(np.power(testData[:, -1] - treeMean, 2))
        # Merge only if that lowers the test error.
        if errorMerge < errorNoMerge:
            return treeMean
        return tree
    return tree
(10)主函数调用
if __name__ == '__main__':
    # Build the tree from the training set first; the original snippet passed
    # an undefined `tree` to prune().
    # NOTE(review): the training file name 'ex2.txt' is assumed (companion of
    # 'ex2test.txt' in the referenced download). Confirm the path.
    train_filename = 'ex2.txt'
    train_Data = loadDataSet(train_filename)
    # np.mat was removed in NumPy 2.0; np.asmatrix is equivalent and also
    # works on NumPy 1.x.
    train_Mat = np.asmatrix(train_Data)
    tree = createTree(train_Mat)
    print('剪枝前:')
    print(tree)
    print('\n剪枝后:')
    test_filename = 'ex2test.txt'
    test_Data = loadDataSet(test_filename)
    test_Mat = np.asmatrix(test_Data)
    print(prune(tree, test_Mat))
参考:https://download.csdn.net/download/weixin_42709150/11146870
更多推荐
已为社区贡献1条内容
所有评论(0)