并查集 Union Find 算法

定义

并查集(Disjoint-Set)是一种可以动态维护若干个不重叠的集合,并支持合并查询两种操作的一种数据结构

基本操作

  1. 合并(Union):合并两个集合。
  2. 查询(Find):查询元素所属集合。
    实际操作时,我们会使用一个点来代表整个集合,即一个元素的根结点(可以理解为父亲)。

具体实现

我们建立一个数组father_dict字典表示一个并查集,father_dict[i]表示i的父节点。并且设置一个size_dict字典,size_dict[i]表示i的后代节点的个数,包括其本身。

  1. 初始化:每一个点都是一个集合,因此自己的父节点就是自己father_dict[i]=i,size_dict[i]=1.
  2. 查询:每一个节点不断寻找自己的父节点,若此时自己的父节点就是自己,那么该点为集合的根结点,返回该点。
  3. 合并:合并两个集合只需要合并两个集合的根结点,size_dict大吃小,为了尽可能的降低树高。

路径压缩:实际上,我们在查询过程中只关心根结点是什么,并不关心这棵树的形态(有一些题除外)。因此我们可以在查询操作的时候将访问过的每个点都指向树根,这样的方法叫做路径压缩,单次操作复杂度为O(logN)

路径压缩示例:经过一次FIND节点H,后数据结构的变化。
在这里插入图片描述

代码实现

class UnionFindSet(object):
    """并查集"""
    def __init__(self, data_list):
        """初始化两个字典,一个保存节点的父节点,另外一个保存父节点的大小
        初始化的时候,将节点的父节点设为自身,size设为1"""
        self.father_dict = {}
        self.size_dict = {}

        for node in data_list:
            self.father_dict[node] = node
            self.size_dict[node] = 1

    def find(self, node):
        """使用递归的方式来查找父节点

        在查找父节点的时候,顺便把当前节点移动到父节点上面
        这个操作算是一个优化
        """
        father = self.father_dict[node]
        if(node != father):
            if father != self.father_dict[father]:    # 在降低树高优化时,确保父节点大小字典正确
                self.size_dict[father] -= 1
            father = self.find(father)
        self.father_dict[node] = father
        return father

    def is_same_set(self, node_a, node_b):
        """查看两个节点是不是在一个集合里面"""
        return self.find(node_a) == self.find(node_b)

    def union(self, node_a, node_b):
        """将两个集合合并在一起"""
        if node_a is None or node_b is None:
            return

        a_head = self.find(node_a)
        b_head = self.find(node_b)

        if(a_head != b_head):
            a_set_size = self.size_dict[a_head]
            b_set_size = self.size_dict[b_head]
            if(a_set_size >= b_set_size):
                self.father_dict[b_head] = a_head
                self.size_dict[a_head] = a_set_size + b_set_size
            else:
                self.father_dict[a_head] = b_head
                self.size_dict[b_head] = a_set_size + b_set_size

应用

朋友圈问题

假如已知有n个人和m对好友关系(存于数组r)如果两个人是直接或间接的好友(好友的好友的好友…),则认为他们属于同一个朋友圈,请写程序求出这n个人里一共有多少个朋友圈。假如:n = 5,m = 3,r = {{1 , 2} , {2 , 3} , {4 , 5}}表示有5个人,1和2是好友,2和3是好友4和5是好友,则1、2、3属于一个朋友圈4、5属于另一个朋友圈,结果为2个朋友圈。

输入:
1
10
0 1
2 3
4 6
2 5
5 4
1 6
10 11
7 9
8 10
7 11
输出:
2
N = int(input())

for _ in range(N):
    M = int(input())

    nums = []
    maxNum = 0
    for _ in range(M):
        tempTwoNum = list(map(int, input().split()))
        maxNum = max(maxNum, max(tempTwoNum))
        nums.append(tempTwoNum)

    union_find_set = UnionFindSet(list(range(maxNum+1)))
    for i in range(M):
        union_find_set.union(nums[i][0], nums[i][1])


    res_dict = {}
    for i in union_find_set.father_dict:
        rootNode = union_find_set.find(i)
        if  rootNode in res_dict:
            res_dict[rootNode].append(i)
        else:
            res_dict[rootNode] = [i]

    print(res_dict)
    print('朋友圈个数:', len(res_dict.keys()))

亲戚

若某个家族人员过于庞大,要判断两个是否是亲戚,确实还很不容易,现在给出某个亲戚关系图,求任意给出的两个人是否具有亲戚关系。原题链接

class UnionFindSet(object):
    """并查集"""
    def __init__(self, data_list):
        """初始化两个字典,一个保存节点的父节点,另外一个保存父节点的大小
        初始化的时候,将节点的父节点设为自身,size设为1"""
        self.father_dict = {}
        self.size_dict = {}

        for node in data_list:
            self.father_dict[node] = node
            self.size_dict[node] = 1

    def find(self, node):
        """使用递归的方式来查找父节点

        在查找父节点的时候,顺便把当前节点移动到父节点上面
        这个操作算是一个优化
        """
        father = self.father_dict[node]
        if(node != father):
            if father != self.father_dict[father]:    # 在降低树高优化时,确保父节点大小字典正确
                self.size_dict[father] -= 1
            father = self.find(father)
        self.father_dict[node] = father
        return father

    def is_same_set(self, node_a, node_b):
        """查看两个节点是不是在一个集合里面"""
        return self.find(node_a) == self.find(node_b)

    def union(self, node_a, node_b):
        """将两个集合合并在一起"""
        if node_a is None or node_b is None:
            return

        a_head = self.find(node_a)
        b_head = self.find(node_b)

        if(a_head != b_head):
            a_set_size = self.size_dict[a_head]
            b_set_size = self.size_dict[b_head]
            if(a_set_size >= b_set_size):
                self.father_dict[b_head] = a_head
                self.size_dict[a_head] = a_set_size + b_set_size
            else:
                self.father_dict[a_head] = b_head
                self.size_dict[b_head] = a_set_size + b_set_size



N,M,R = map(int, input().split())

uf = UnionFindSet(list(range(1, N+1)))


for _ in range(M):
    num1, num2 = map(int, input().split())
    uf.union(num1, num2)

for _ in range(R):
    num1, num2 = map(int, input().split())
    if uf.is_same_set(num1, num2):
        print('Yes')
    else:
        print('No')

最长连续序列

leetcode 原题链接:https://leetcode-cn.com/problems/longest-consecutive-sequence/

给定一个未排序的整数数组,找出最长连续序列的长度。

要求算法的时间复杂度为 O(n)。

示例:
输入: [100, 4, 200, 1, 3, 2]
输出: 4
解释: 最长连续序列是 [1, 2, 3, 4]。它的长度为 4。

思路:利用哈希字典,从小往大的FIND,与并查集查找异曲同工,时间复杂度O(1).

class Solution:
    def longestConsecutive(self, nums: List[int]) -> int:
        pre_dict = {}
        for i in nums:
            pre_dict[i] = 1

        res = 0
        for i in pre_dict:
            if i - 1 not in pre_dict:
                y = i + 1
                while y in pre_dict:
                    y += 1

                res = max(res, y - i)

        return res

被围绕的区域

给定一个二维的矩阵,包含 ‘X’ 和 ‘O’(字母 O)。
找到所有被 ‘X’ 围绕的区域,并将这些区域里所有的 ‘O’ 用 ‘X’ 填充。

示例:
X X X X
X O O X
X X O X
X O X X
运行你的函数后,矩阵变为:
X X X X
X X X X
X X X X
X O X X

解释:

被围绕的区间不会存在于边界上,换句话说,任何边界上的 ‘O’ 都不会被填充为 ‘X’。 任何不在边界上,或不与边界上的 ‘O’ 相连的 ‘O’ 最终都会被填充为 ‘X’。如果两个元素在水平或垂直方向相邻,则称它们是“相连”的。

class UnionFindSet(object):
    """并查集"""
    def __init__(self, data_list):
        """初始化两个字典,一个保存节点的父节点,另外一个保存父节点的大小
        初始化的时候,将节点的父节点设为自身,size设为1"""
        self.father_dict = {}
        self.size_dict = {}

        for node in data_list:
            self.father_dict[node] = node
            self.size_dict[node] = 1

    def find(self, node):
        """使用递归的方式来查找父节点

        在查找父节点的时候,顺便把当前节点移动到父节点上面
        这个操作算是一个优化
        """
        father = self.father_dict[node]
        if(node != father):
            if father != self.father_dict[father]:    # 在降低树高优化时,确保父节点大小字典正确
                self.size_dict[father] -= 1
            father = self.find(father)
        self.father_dict[node] = father
        return father

    def is_same_set(self, node_a, node_b):
        """查看两个节点是不是在一个集合里面"""
        return self.find(node_a) == self.find(node_b)

    def union(self, node_a, node_b):
        """将两个集合合并在一起"""
        if node_a is None or node_b is None:
            return

        a_head = self.find(node_a)
        b_head = self.find(node_b)

        if(a_head != b_head):
            a_set_size = self.size_dict[a_head]
            b_set_size = self.size_dict[b_head]
            if(a_set_size >= b_set_size):
                self.father_dict[b_head] = a_head
                self.size_dict[a_head] = a_set_size + b_set_size
            else:
                self.father_dict[a_head] = b_head
                self.size_dict[b_head] = a_set_size + b_set_size


class Solution:
    def solve(self, board: List[List[str]]) -> None:
        """
        Do not return anything, modify board in-place instead.
        """
        if not board:
            return
        m, n = len(board), len(board[0])  # row col
        ufs = UnionFindSet(list(range(m*n+1)))


        for i in range(m):
            for j in range(n):
                if board[i][j] == 'O':
                    if i==0 or i==m-1 or j==0 or j==n-1:
                        ufs.union(m*n, i*n+j)
                    else:
                        if board[i+1][j] == 'O':
                            ufs.union((i+1) * n + j, i*n+j)
                        if board[i-1][j] == 'O':
                            ufs.union((i-1) * n + j, i*n+j)
                        if board[i][j-1] == 'O':
                            ufs.union(i * n + (j-1), i*n+j)
                        if board[i][j+1] == 'O':
                            ufs.union(i * n + (j+1), i*n+j)

        for i in range(m):
            for j in range(n):
                rootEdage = ufs.find(m*n)
                if ufs.find(i*n+j) != rootEdage and board[i][j] == 'O':
                    board[i][j] = 'X'
Logo

开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!

更多推荐