UnionFind:快查连通图

并查集UnionFind又叫Disjoint Sets。它的作用是快速更新和查询节点的连通性。

数组简单实现

简单的并查集实现如下:用数组表示祖节点。

主要功能:

  • find(p):O(N)找到p所在连通图的根 实现:不断向父节点查找

  • union(p, q):(N)把p和q连通 实现:把含有p的根接到含有q的根上

  • connected(p, q):O(N)判断p和q是否连通 实现:判断含有p和q的数是否有一样的根节点

  • self.count:O(1)得到连通图的个数

class UF:
    def __init__(self, n):
        self.parent = [x for x in range(n)]
        self.count = n

    def find(self, p):
        while self.parent[p]!=p:
            p = self.parent[p]
        return p

    def union(self, p, q):
        rootp, rootq = self.find(p), self.find(q)
        if rootp==rootq: return
        self.parent[rootp] = rootq
        self.count -= 1
        
    def connected(self, p, q):
        return self.find(p)==self.find(q)

并查集让一些图的查询和改动变得很方便。比如说我们可以用O(1)的时间查询有几个连通量。在处理很多数据(或者数据流)的时候,能让我们快速知道图的连通性。在《图的性质》章节,我们会细看图的连通性。

路径压缩 / 小树接大树

在这里可以有一些优化:在union过程中加入平衡考虑(小树接大树),以及在find过程中顺便加入路径压缩,可以让查找和union更快。具体可以参考此文。这样一来,findunionconnected的时间复杂度都降到了O(logN)。

class UF:
    def __init__(self, n):
        self.parent = [x for x in range(n)]
        self.count = n
        self.size = [1 for _ in range(n)] # size

    def find(self, p):
        while self.parent[p]!=p:
            self.parent[p] = self.find(self.parent[p]) # path compression
            p = self.parent[p]
        return p
    
    def connected(self, p, q):
        return self.find(p)==self.find(q)
    
    def union(self, p, q):
        if self.connected(p, q): return
        
        # tree balancing ==>
        rootp, rootq = self.find(p), self.find(q)
        if self.size[p]<self.size[q]:
            self.parent[rootp] = rootq
            self.size[rootq] += self.size[rootp]
        else:
            self.parent[rootq] = rootp
            self.size[rootp] += self.size[rootq]
        # <== tree balancing
        self.count -= 1

带权重的UnionFind

目前的UnionFind是一个无向的多叉树。有时候要求不仅能求连通性,而且要计算节点之间的连接的权重。这个时候就需要一个带权重的UnionFind,成为一个无环有向图。

下面的参考代码在“查询“操作的“路径压缩“优化中维护权值变化。

class UF(object):
    def __init__(self, n):
        self.parent = defaultdict(int)
        self.weight = defaultdict(float)
        
    def add(self, p):
        if p in self.parent: return
        self.parent[p] = p
        self.weight[p] = 1.0
        
    def find(self, p):
        if self.parent[p] != p:
            origin = self.parent[p]
            self.parent[p] = self.find(self.parent[p]) # path compression
            self.weight[p] *= self.weight[origin]      # and update weight
        return self.parent[p]
    
    def union(self, p, q, val):
        rootp, rootq = self.find(p), self.find(q) # update weights w/ path cmp
        if rootp==rootq: return
        self.parent[rootp] = rootq
        self.weight[rootp] = val*self.weight[q]/self.weight[p]

例题:快速除法

输入:equations = [["a","b"],["b","c"]], values = [2.0,3.0], 
     queries = [["a","c"],["b","a"],["a","e"],["a","a"],["x","x"]]
输出:[6.00000,0.50000,-1.00000,1.00000,-1.00000]
解释:
条件:a / b = 2.0, b / c = 3.0
问题:a / c = ?, b / a = ?, a / e = ?, a / a = ?, x / x = ?
结果:[6.0, 0.5, -1.0, 1.0, -1.0 ]
def calcEquation(self, equations, values, queries):
    uf = UF(100)
    for (p, q), val in zip(equations, values):
        uf.add(p);uf.add(q)
        uf.union(p, q, val)
    
    res = []
    for p, q in queries:
        rootp, rootq = uf.find(p), uf.find(q)
        if not rootp or not rootq or rootp!=rootq:
            res.append(-1.0)
        else:
            res.append(uf.weight[p]/uf.weight[q])
    return res

Last updated

Was this helpful?