kd树 python实现_统计学习方法第三章:k近邻法(k-NN),kd树及python实现
完整代码:xjwhhh/LearningMLgithub.com欢迎follow和star欢迎关注公众号:常失眠少年,谢谢。k近邻法(k-nearest neighbor,k-NN)是一种基本分类与回归方法。k近邻法假设给定一个训练数据集,其中的实例类别已定。分类时,对新的实例,根据其k个最近邻的训练实例的类别,通过多数表决等方式进行预测。因此,k邻近法不具有显式的学习过程。k近邻法实际上利用训
完整代码:xjwhhh/LearningMLgithub.com
欢迎follow和star
欢迎关注公众号:常失眠少年,谢谢。
k近邻法(k-nearest neighbor,k-NN)是一种基本分类与回归方法。
k近邻法假设给定一个训练数据集,其中的实例类别已定。分类时,对新的实例,根据其k个最近邻的训练实例的类别,通过多数表决等方式进行预测。因此,k邻近法不具有显式的学习过程。
k近邻法实际上利用训练数据集对特征空间进行划分,并作为其分类的“模型”。
k值的选择,距离度量及分类决策规则是k近邻法的三个基本要素。
下图是k近邻法:
实现k近邻法时,主要考虑的问题是如何对训练数据进行快速k近邻搜索。这点在特征空间的维数大以及训练数据容量大时尤其必要
k近邻法最简单的实现方式是线性扫描。这时要计算输入实例与每一个训练实例的距离,当训练集很大时,计算非常耗时,这种方法是不可行的
为了提高k近邻搜索的效率,可以考虑使用特殊的结构存储训练数据,以减少计算距离的次数。其中一种方法就是kd树方法
下图是kd树的构造算法:
下图是kd树的搜索算法:
更具体的解释和证明可以看《统计学习方法》或者其他解释kd树的博文,我在这里不再赘述
下面是python代码实现,使用MINST数据集,构造kd树进行搜索,实现的是最近邻算法,即只搜寻最近的一个实例来决定类别
但有一个问题是运算很慢,我也不得其解,但算法核心部分实现应当是无误的
import pandas as pd
import numpy as np
import cv2
import logging
import time
from math import sqrt
from collections import namedtuple
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
def log(func):
def wrapper(*args, **kwargs):
start_time = time.time()
logging.debug('start %s()' % func.__name__)
ret = func(*args, **kwargs)
end_time = time.time()
logging.debug('end %s(), cost %s seconds' % (func.__name__, end_time - start_time))
return ret
return wrapper
def get_hog_features(trainset):
# 利用opencv获取图像hog特征
features = []
hog = cv2.HOGDescriptor('../hog.xml')
for img in trainset:
img = np.reshape(img, (28, 28))
cv_img = img.astype(np.uint8)
hog_feature = hog.compute(cv_img)
# hog_feature = np.transpose(hog_feature)
features.append(hog_feature)
features = np.array(features)
features = np.reshape(features, (-1, 324))
return features
def predict(test_set, kd_tree):
predict = []
for i in range(len(test_set)):
predict.append(find_nearest(kd_tree, test_set[i]).label)
return np.array(predict)
# 构造kdTree搜索
# 现在的实现是最近邻,
# 问题1:怎么保存每个结点对应的label,现在的实现似乎成功了,但我不确定
# 问题2:速度非常慢
class KdNode(object):
def __init__(self, dom_elt, split, left, right, label):
self.dom_elt = dom_elt # k维向量节点(k维空间中的一个样本点)
self.split = split # 整数(进行分割维度的序号)
self.left = left # 该结点分割超平面左子空间构成的kd-tree
self.right = right # 该结点分割超平面右子空间构成的kd-tree
self.label = label
class KdTree(object):
@log
def __init__(self, data, labels):
k = len(data[0]) # 数据维度
def create_node(split, data_set, labels): # 按第split维划分数据集,创建KdNode
# print(len(data_set))
if (len(data_set) == 0):
return None
sort_index = data_set[:, split].argsort()
data_set = data_set[sort_index]
labels = labels[sort_index]
# print(data_set)
split_pos = len(data_set) // 2
# print(split_pos)
median = data_set[split_pos] # 中位数分割点
label = labels[split_pos]
split_next = (split + 1) % k # cycle coordinates
# 递归的创建kd树
return KdNode(median, split,
create_node(split_next, data_set[:split_pos], labels[:split_pos]), # 创建左子树
create_node(split_next, data_set[split_pos + 1:], labels[split_pos + 1:]), # 创建右子树
label)
self.root = create_node(0, data, labels) # 从第0维分量开始构建kd树,返回根节点
# 定义一个namedtuple,分别存放最近坐标点、最近距离和访问过的节点数
result = namedtuple("Result_tuple", "nearest_point nearest_dist nodes_visited label")
@log
def find_nearest(tree, point):
k = len(point) # 数据维度
def travel(kd_node, target, max_dist):
if kd_node is None:
return result([0] * k, float("inf"), 0, 0) # python中用float("inf")和float("-inf")表示正负无穷
nodes_visited = 1
s = kd_node.split # 进行分割的维度
pivot = kd_node.dom_elt # 进行分割的“轴”
if target[s] <= pivot[s]: # 如果目标点第s维小于分割轴的对应值(目标离左子树更近)
nearer_node = kd_node.left # 下一个访问节点为左子树根节点
further_node = kd_node.right # 同时记录下右子树
else: # 目标离右子树更近
nearer_node = kd_node.right # 下一个访问节点为右子树根节点
further_node = kd_node.left
if (nearer_node is None):
label = 0
else:
label = nearer_node.label
temp1 = travel(nearer_node, target, max_dist) # 进行遍历找到包含目标点的区域
nearest = temp1.nearest_point # 以此叶结点作为“当前最近点”
dist = temp1.nearest_dist # 更新最近距离
nodes_visited += temp1.nodes_visited
if dist < max_dist:
max_dist = dist # 最近点将在以目标点为球心,max_dist为半径的超球体内
temp_dist = abs(pivot[s] - target[s]) # 第s维上目标点与分割超平面的距离
if max_dist < temp_dist: # 判断超球体是否与超平面相交
return result(nearest, dist, nodes_visited, temp1.label) # 不相交则可以直接返回,不用继续判断
# ----------------------------------------------------------------------
# 计算目标点与分割点的欧氏距离
temp_dist = sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(pivot, target)))
if temp_dist < dist: # 如果“更近”
nearest = pivot # 更新最近点
dist = temp_dist # 更新最近距离
max_dist = dist # 更新超球体半径
label = kd_node
# 检查另一个子结点对应的区域是否有更近的点
temp2 = travel(further_node, target, max_dist)
nodes_visited += temp2.nodes_visited
if temp2.nearest_dist < dist: # 如果另一个子结点内存在更近距离
nearest = temp2.nearest_point # 更新最近点
dist = temp2.nearest_dist # 更新最近距离
label = temp2.label
return result(nearest, dist, nodes_visited, label)
return travel(tree.root, point, float("inf")) # 从根节点开始递归
k = 10
if __name__ == '__main__':
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
raw_data = pd.read_csv('../data/train.csv', header=0)
data = raw_data.values
images = data[0:, 1:]
labels = data[:, 0]
features = get_hog_features(images)
# 选取 2/3 数据作为训练集, 1/3 数据作为测试集
train_features, test_features, train_labels, test_labels = train_test_split(features, labels, test_size=0.33,random_state=1)
kd_tree = KdTree(train_features, train_labels)
test_predict = predict(test_features, kd_tree)
score = accuracy_score(test_labels, test_predict)
print("The accuracy score is ", score)
水平有限,如有错误,希望指出
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
所有评论(0)