之所以把这个算法单独拿出来讲解,是因为这个算法比较有用,很多算法搭配这个算法,实现起来效果会更好,但是这个算法放到哪个章节都显得很突兀,因为这个算法其实是一个机器学习中的聚类算法,所以应该把这个算法和其他几种聚类算法比如kmeans、DBSCAN、层次聚类算法Birch等放在一起写,但是由于精力有限,先单独成一章把这个算法说清楚吧,而且这个算法的亮点就是用在图像处理领域。其他几种聚类算法以后有机会再写机器学习算法时再写吧。

meanshift算法可以实现对图像在色彩层面进行平滑滤波,它可以中和色彩分布相近的颜色,平滑色彩细节,侵蚀掉面积较小的颜色区域。实现效果如下:

我们可以看到meanshift可以很好的平滑掉图像上彩色细节,图像上的颜色都变成一片一片的了,所以用这个算法我们做一些图像特效还是蛮不错的。但是如果只是用于图像特效,那价值太小了。这个算法的价值在于:它搭配canny算法可以得到效果更加好的边缘检测效果;搭配分水岭算法可以得到更好的图像分割效果;搭配轮廓检测函数cv2.findContours()可以得到更好的轮廓效果;搭配直方图可以得到更好的匹配效果,可以用于视频中的运动跟踪了。
小结:meanshift算法的基本功能是实现图像的色彩滤波、但这个算法搭配其他算法可以更好的实现图像分割、前后景提取、视频跟踪等功能。

meanshift算法并不是像我们第7章图像平滑那章里面讲的,都是用各种滤波核去和图像进行卷积运算而实现图像平滑目的的。meanshift算法本质上一种机器学习算法,是机器学习中的聚类算法的思路和原理,是一种无监督的分类算法。就是把图像的所有像素点看成一个个没有标签的样本数据,然后用聚类方法探索这些样本数据的内部规律,把所有样本分成若干个类别,实现聚类。如果我们对机器学习算法中的聚类算法比较熟悉,那就很容易理解meanshift算法。

一、MeanShift算法原理

机器学习中的meanshift算法原理非常简单,如下图所示:

上面这些数据不需要标签,是无监督机器学习,就是自动聚类,不需要人为干预,也不需要事先规定打算把数据分为几类(kmeans就需要事先确定分几类,也就是k这个超参数需要提前指定),我们只要在初始化时随机选取一个样本点,以这个样本点为圆心,设定一个半径,被这个圆套住的样本点,计算这些样本点的均值(也称为质心),然后将圆心移动到这个均值点上,再画圆,再计算被圈住的样本点的均值,再移动圆心到新的均值点上,再画圆,再算均值,再移动圆心。。。如此反复迭代,直到质心不变,或者改变很小,小于你设定的阈值,我们就认为达到收敛状态,就停止迭代,此时被最后一个圆圈套住的样本点我们就认为是一类的,我们把这些样本标记为簇1。
然后我们再重复上面的步骤,再随机选取一个样本点,重复上面的步骤,直到收敛,就又被圈出一类,我们给这些样本标记为簇2。
如此反复。。。直到所有的样本点都被标记上标签,就停止。
然后我们再计算各个簇的质心之间的距离,如果两个簇的质心之间的距离很近,小于你设定的阈值,就合并两个簇为一簇。这样就实现了分类。看下面的动态图演示:

每个黑点都代表每轮迭代时随机选取的样本点,但它们经过迭代后最终都重叠到每一类的中心位置。

二、meanshift算法的公式推导

Mean Shift这个概念最早是由Fukunaga等人于1975年在一篇关于概率密度梯度函数的估计(The Estimation of the Gradient of a Density Function, with Applications in Pattern Recognition)中提出来的,其最初含义正如其名,自动找到最高密度处。但是在以后的很长一段时间内Mean Shift并没有引起人们的注意,直到20年以后,也就是1995年,另外一篇关于Mean Shift的重要文献(Mean shift,mode seeking, and clustering )发表。在这篇文献中,Yizong Cheng对Mean Shift算法进行了两个方面的改进:一是,Yizong Cheng定义了一族核函数,使得随着样本与被偏移点的距离不同,其偏移量对均值偏移向量的贡献也不同;二是,Yizong Cheng还设定了一个权重系数,使得不同的样本点重要性不一样,这大大扩大了Mean Shift的适用范围。

import numpy as np
import random
DISTANCE_THRESHOLD = 1e-4
CLUSTER_THRESHOLD = 1e-1

def distance(a,b):
    return np.linalg.norm(np.array(a)-np.array(b))

def Gaussian_kernal(distance,sigma):
    return (1/(sigma*np.sqrt(2*np.pi)))*np.exp(-0.5*distance/(sigma**2))

class MeanShift(object):
    def __init__(self,kernal = Gaussian_kernal):
        self.kernal = kernal
    def shift_points(self,center_point,whole_points,Gaussian_sigma):         ##计算center_point点移动后的坐标
        shifting_px = 0.0
        shifting_py = 0.0
        sum_weight = 0.0
        for each_point in whole_points:#遍历每一个点
            dis = distance(center_point,each_point)#计算当前点与中心点的距离
            Gaussian_weight = self.kernal(dis,Gaussian_sigma)#计算当前点距离中心点的高斯权重
            shifting_px += Gaussian_weight * each_point[0]
            shifting_py += Gaussian_weight * each_point[1]
            sum_weight += Gaussian_weight
        shifting_px /= sum_weight          #归一化
        shifting_py /= sum_weight
        return [shifting_px,shifting_py]
    
    #根据shift之后的点坐标shifting_points获得聚类id
    def cluster_points(self,shifting_points):
        clusterID_points = []#用于存放每一个点的类别号
        cluster_id=0#聚类号初始化为0
        cluster_centers = []#聚类中心点
        for i,each_point in enumerate(shifting_points):#遍历处理每一个点
            if i==0:#如果是处理的第一个点
                clusterID_points.append(cluster_id)#将这个点归为初始化的聚类号(0)
                cluster_centers.append(each_point)#将这个点看作聚类中心点
                cluster_id+=1#聚类号加1
            else:#处理的不是第一个点的情况
                for each_center in cluster_centers:#遍历每一个聚类中心点
                    dis = distance(each_center,each_point)#计算当前点与该聚类中心点的距离
                    if dis < CLUSTER_THRESHOLD:#如果距离小于聚类阈值
                        clusterID_points.append(cluster_centers.index(each_center))#就将当前处理的点归为当前中心点同类(聚类号赋值)
                if(len(clusterID_points)<i+1):#如果上面那个for,所有的聚类中心点都没能收纳一个点,说明是时候开拓一个新类了
                    clusterID_points.append(cluster_id)#把当前点置为一个新类,因为此时的cluster_idx以前谁都没用过
                    cluster_centers.append(each_point)#将这个点作为这个这个新聚类的中心点
                    cluster_id+=1#聚类号加1以备后用
        return clusterID_points
        
    #whole_points:输入的所有点
    #Gaussian_sigma:Gaussian核的sigma
    def fit(self,whole_points,Gaussian_sigma):
        shifting_points = np.array(whole_points)
        need_shifting_flag = [True] * np.shape(whole_points)[0]#每一个点初始都标记为需要shifting
        while True:
            distance_max = 0.0
            #每一轮迭代都对每一个点进行处理
            for i in range(0,np.shape(whole_points)[0]):
                if not need_shifting_flag[i]:#如果这个点已经被标记为不需要继续shifting,就continue
                    continue
                shifting_point_init = shifting_points[i].copy()#将初始的第i个点的坐标备份一下
                #shifting_points[i]由第i个点的坐标更新为第i个点移动后的坐标
                shifting_points[i] = self.shift_points(shifting_points[i],whole_points,Gaussian_sigma)
                #计算第i个点移动的距离
                dis = distance(shifting_point_init,shifting_points[i])
                #如果该点移动的距离小于停止阈值,标记need_shifting_flag[i]为False,下一轮迭代对该点不做处理
                need_shifting_flag[i] = dis > DISTANCE_THRESHOLD
                #本轮迭代中最大的距离存储到distance_max中
                distance_max = max(distance_max,dis)
            #如果在一轮迭代中,所有点移动的最大距离都小于停止阈值,就停止迭代
            if(distance_max < DISTANCE_THRESHOLD):
                break
        #根据shift之后的点坐标shift_points获得聚类id
        cluster_class_id = self.cluster_points(shifting_points.tolist())
        return shifting_points,cluster_class_id
        
from sklearn.datasets.samples_generator import make_blobs
import matplotlib.pyplot as plt 


#按照均匀分布随机产生n个颜色,每个颜色都由R、G、B三个分量表示
def colors(n):
    ret = []
    for i in range(n):
        ret.append((random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1)))
    return ret

def main():
    centers = [[0, 1], [-1, 2], [1, 2], [-2.5, 2.5], [2.5,2.5], [-4,1], [4,1], [-3,-1], [3,-1], [-2,-3], [2,-3], [0,-4]]#设置一些中心点
    X, _ = make_blobs(n_samples=300, centers=centers, cluster_std=0.3)#产生以这些中心点为中心,一定标准差的n个samples

    mean_shifter = MeanShift()
    shifted_points, mean_shift_result = mean_shifter.fit(X, Gaussian_sigma=0.3)#Gaussian核设置为0.5,对X进行mean_shift

    np.set_printoptions(precision=3)
#    print('input: {}'.format(X))
#    print('assined clusters: {}'.format(mean_shift_result))
    color = colors(np.unique(mean_shift_result).size)

    for i in range(len(mean_shift_result)):
        plt.scatter(X[i, 0], X[i, 1], color = color[mean_shift_result[i]])
        plt.scatter(shifted_points[i,0],shifted_points[i,1], color = 'r')
    plt.xlabel("2018.06.13")
    plt.savefig("result_meanshift.png")
    plt.show()

if __name__ == '__main__':
    main()
       

from sklearn.datasets.samples_generator import make_blobs
import matplotlib.pyplot as plt 
#-------------生成一个12堆的数据--------------------------
centers = [[0, 1], [-1, 2], [1, 2], [-2.5, 2.5], [2.5,2.5], [-4,1], [4,1], [-3,-1], [3,-1], [-2,-3], [2,-3], [0,-4]]   #12个中心
x, _ = make_blobs(n_samples=300, centers=centers, cluster_std=0.3, random_state=0)  #随机性也控制住

#----------定义两点之间的距离计算、高斯权重计算、每个点的meanshift计算----------------------
def distance(a, b):   #求两个点之间的欧式距离
    distance = np.linalg.norm(np.array(a)-np.array(b))    #向量a与向量b对应位置元素相减再平方再相加再开方
    return distance

def gaussian_kernal(distance, sigma):    #给一个sigma超参数,计算高斯权重。说明:如果两个点距离distance很远,那么权重就越小,所以这就类似于上面讲原理里面的画圆圈圈
    gaussian_weight = (1/(sigma*np.sqrt(2*np.pi)))*np.exp(-distance/(2*sigma**2))
    return gaussian_weight

def mean_shift(center_point, all_points, gaussian_sigma): #计算所有点的meanshift向量(其实是计算在一定范围内的点,可以理解为圆圈圈住的点,因为离中心点太远的点的高斯权重趋于0了)
    mean_shift_px = 0.0
    mean_shift_py = 0.0
    sum_weight = 0.0
    for point in all_points:
        dis = distance(center_point, point)
        gaussian_weight = gaussian_kernal(dis, gaussian_sigma)
        mean_shift_px += gaussian_weight*point[0]
        mean_shift_py += gaussian_weight*point[1]
        sum_weight += gaussian_weight
    mean_shift_px /= sum_weight
    mean_shift_py /= sum_weight
    return [mean_shift_px, mean_shift_py]   

#----------迭代10次-------------------------
iteration_points = []
initial_points = np.array(x)
iteration = 0
while True:
    new_points = initial_points.copy()
    for i in range(0, 300):
        new_points[i] = mean_shift(center_point = new_points[i], all_points = initial_points, gaussian_sigma=0.3)
    iteration_points.append(new_points)
    initial_points = new_points
    iteration+=1
    if iteration>10:     #先迭代10次看看效果
        break
#--------确定类别号---------------------------
cluster_center=[]
cluster_id = []
initial_id = 0
distance_threshold = 1e-1
for i, point in enumerate(iteration_points[-1]):  #把最后一次迭代的点切出来做循环
    if i==0:
        culsterid.append(initial_id)
        cluster_center.append(point)
    else:
        for j in cluster_center:
            dis = distance(point, j)
            if dis < distance_threshold:
                
        
        
        
        
#----------可视化看看效果--------------
fig, axes = plt.subplots(2,6, figsize=(16,4), dpi=100)
axes[0,0].scatter(x[:,0],x[:,1],s=1)#原图
axes[0,1].scatter(iteration_points[0][:,0],iteration_points[0][:,1],s=1)  
axes[0,2].scatter(iteration_points[1][:,0],iteration_points[1][:,1],s=1)  
axes[0,3].scatter(iteration_points[2][:,0],iteration_points[2][:,1],s=1)  
axes[0,4].scatter(iteration_points[3][:,0],iteration_points[3][:,1],s=1)  
axes[0,5].scatter(iteration_points[4][:,0],iteration_points[4][:,1],s=1)  
axes[1,0].scatter(iteration_points[5][:,0],iteration_points[5][:,1],s=1)  
axes[1,1].scatter(iteration_points[6][:,0],iteration_points[6][:,1],s=1)  
axes[1,2].scatter(iteration_points[7][:,0],iteration_points[7][:,1],s=1)  
axes[1,3].scatter(iteration_points[8][:,0],iteration_points[8][:,1],s=1)  
axes[1,4].scatter(iteration_points[9][:,0],iteration_points[9][:,1],s=1)  
axes[1,5].scatter(iteration_points[10][:,0],iteration_points[10][:,1],s=1)  
plt.show()

Logo

开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!

更多推荐