有序结构:SortedContainers

有序容器(List/Set/Dict)

Python中的SortedContainerPython中的SortedContainers类似于Java中的Treemap/Treeset,帮你打理好了O(logn)的写查删。

  • SortedSet:有序集合

  • SortedDict:有序字典

  • SortedList:可以看成允许重复元素的OrderedSet。

SortedDict和Ordered比较:

OrderedDict的键按照插入的先后次序排序,SortedDict的键按照它们的自然大小排序。比如说分别插入m[3]=0, m[2]=1, m[1]=1之后,结果分别为:

  • OrderedDict {3: 0, 2: 1, 1: 1},先来后到

  • SortedDict {1: 1, 2: 1, 3: 0},谁小谁先

注意和OrderedDict不一样的,SortedDict和其他SortedContainers一样,是有序数据结构,所以内部不是用哈希表来实现的,而是用平衡搜索树来实现的。

如果不能用类似的库,要么就要手撸平衡二叉树,要么就只能用List结合bisect来进行写查删。不过记住在List中很多操作受到插入和删除的机能,只能做到O(n)的时间复杂度:比如bisect.insort(List, x)list.remove(x)

二叉树简单实现(不含自平衡)

这里我们跳过自平衡。假设自平衡已经实现了(比如用红黑树),那么主要功能和时间复杂度是:

  • insert(val):O(logN)插入一个元素 实现:从根开始根据val的大小遍历,并在第一个空位新建一个节点

  • delete(val):O(logN)删除一个元素 [LC 450] 实现:从根开始根据val的大小遍历。三种情况:

    • 不存在,则返回

    • 存在,没有左儿子或者右儿子,用左儿子或右儿子或None代替自己。

    • 存在,有左儿子和右儿子,用右儿子代替自己。同时把左儿子插到右儿子的最小节点(leftmost node)的左边。

  • findMin():O(logN)找到最小值 实现:返回最左边的节点值

  • findMax():O(logN)找到最大值 实现:返回最右边的节点值

class Node:
    def __init__(self, val=None):
        self.val = val
        self.left = self.right = None

class Tree:
    """
    A left-skewed binary search tree.
    """
    def __init__(self):
        self.root = None

    # primary functions
    def insert(self, val):
        self.root = self.insertNode(self.root, val)
    
    def delete(self, val):
        self.root = self.deleteNode(self.root, val)

    def print(self):
        return self.printNode(self.root)

    def findMin(self):
        return self.findMinNode(self.root)
    
    def findMax(self):
        return self.findMaxNode(self.root)
    
    # helper functions
    def insertNode(self, n, v):
        if not n: n = Node(v)
        elif n.val >= v: n.left  = self.insertNode(n.left, v)
        elif n.val < v:  n.right = self.insertNode(n.right, v)
        return n
    
    def deleteNode(self, n, v):
        """
        Replace root with right child
        to maintain order, move left chid as left child of
        the leftmost child of right tree (m)
        """
        if not n: return
        if n.val > v: n.left  = self.deleteNode(n.left, v)
        elif n.val < v: n.right = self.deleteNode(n.right, v)
        elif not n.left or not n.right: return n.left or n.right
        else:
          m = n.right
          while m.left: m = m.left
          m.left = n.left
          n = n.right
        return n

    def printNode(self, n):
        if not n: return
        self.printNode(n.left)
        print(n.val)
        self.printNode(n.right)
        
    def findMinNode(self, n):
        if not n: return
        if not n.left: return n.val
        return self.findMinNode(n.left)
    
    def findMaxNode(self, n):
        if not n: return
        if not n.right: return n.val
        return self.findMaxNode(n.right)

"""
Example:

t = Tree()
t.insert(3)
t.insert(4)
t.insert(2)
t.insert(10)
t.insert(2)
t.print() # 2, 2, 3, 4, 10

t.delete(3)
t.print() # 2, 2, 4, 10

print(t.findMax()) # 10
"""

给你一个整数数组 nums 和两个整数 k 和 t 。
请你判断是否存在 两个不同下标 i 和 j,
使得 abs(nums[i] - nums[j]) <= t ,同时又满足 abs(i - j) <= k

输入:nums = [1,2,3,1], k = 3, t = 0
输出:true

参考代码(维护一个长度为k的有序数组):

def containsNearbyAlmostDuplicate(nums, k, t):
    arr = SortedList([])
    for i, n in enumerate(nums):
        j = bisect.bisect_left(arr, n-t)
        if j!=len(arr) and arr[j]<=n+t: 
            return True
        arr.add(n)
        if i>=k: arr.discard(nums[i-k])
    return False

给你一个 m x n 的矩阵 matrix 和一个整数 k ,
找出并返回矩阵内部矩形区域的不超过 k 的最大数值和。

题目数据保证总会存在一个数值和不超过 k 的矩形区域。

输入:matrix = [[1,0,1],[0,-2,3]], k = 2
输出:2

除了用前缀和把二维区域和转化成一维数组,这道题的本质是在一位数组中求连续和<=k的最大值。用一位数组s[i]表示前缀和之后,问题就是:

max(sjsi), given sjsiksmallest si>=sjkmax({s_j-s_i}) \text{, given } s_j-s_i\le k \\ ⇒ \text{smallest } s_i >=sj-k

用有序集合来存储之前看到的s[i],就能用O(nlogn)的算法枚举s[j]二分查找s[i]

参考代码:

def closestDiff(nums, k):
    # look for smallest si s.t. si>=sj-k
    res = float('-inf')
    s = SortedList([0])
    presum = 0
    for n in nums:
        presum += n
        i = s.bisect_left(presum-k)
        if i!=len(s):
            res = max(res, presum-s[i])
        s.add(presum)
    return res

def maxSumSubmatrix(matrix, k):
    m, n = len(matrix), len(matrix[0])
    res = float("-inf")
    for i in range(m):
        row = [0] * n
        for j in range(i, m):
            for c in range(n):
                row[c] += matrix[j][c]
            res = max(res, closestDiff(row))
    return res

在考场里,一排有 N 个座位,分别编号为 0, 1, 2, ..., N-1 。
当学生进入考场后,他必须坐在能够使他与离他最近的人之间的距离达到最大化的座位上。
如果有多个这样的座位,他会坐在编号最小的座位上。
(另外,如果考场里没有人,那么学生就坐在 0 号座位上。)

输入:["ExamRoom","seat","seat","seat","seat","leave","seat"], 
     [[10],[],[],[],[],[4],[]]
输出:[null,0,9,4,2,null,5]
解释:
ExamRoom(10) -> null
seat() -> 0,没有人在考场里,那么学生坐在 0 号座位上。
seat() -> 9,学生最后坐在 9 号座位上。
seat() -> 4,学生最后坐在 4 号座位上。
seat() -> 2,学生最后坐在 2 号座位上。
leave(4) -> null
seat() -> 5,学生最后坐在 5 号座位上。

延伸阅读:Urinal protocol vulnerability

参考代码:

from sortedcontainers import SortedSet

class ExamRoom:
    def __init__(self, N: int):
        self.N = N
        self.opens = SortedSet(
            [(-1, N)], key=cmp_to_key(self.compare))
        self.start = {}
        self.end = {}

    def compare(self, a, b):
        def dist(x):
            start, end = x
            if start==-1: return end
            if end==self.N: return self.N-1-start
            return (end-start) // 2
        distA, distB = dist(a), dist(b)
        if distA==distB: return b[0]-a[0]
        return distA-distB

    def removeInterval(self, interval):
        self.opens.discard(interval)
        self.start.pop(interval[0], None)
        self.end.pop(interval[1], None)
    
    def addInterval(self,interval):
        self.opens.add(interval)
        self.start[interval[0]] = interval
        self.end[interval[1]] = interval

    def seat(self) -> int:
        start, end = self.opens.pop()
        if start == -1: seat = 0
        elif end == self.N: seat = self.N-1
        else: seat = (start + end) >> 1
        self.addInterval((start, seat))
        self.addInterval((seat, end))
        return seat

    def leave(self, p: int) -> None:
        right, left = self.start[p], self.end[p]
        self.removeInterval(left)
        self.removeInterval(right)
        self.addInterval((left[0], right[1]))

给你一个区间数组 intervals ,其中 intervals[i] = [starti, endi] ,
且每个 starti 都 不同 。

区间 i 的 右侧区间 可以记作区间 j ,并满足 startj >= endi ,且 startj 最小化 。

返回一个由每个区间 i 的 右侧区间 的最小起始位置组成的数组。
如果某个区间 i 不存在对应的 右侧区间 ,则下标 i 处的值设为 -1 。

输入:intervals = [[3,4],[2,3],[1,2]]
输出:[-1, 0, 1]
解释:对于 [3,4] ,没有满足条件的“右侧”区间。
对于 [2,3] ,区间[3,4]具有最小的“右”起点;
对于 [1,2] ,区间[2,3]具有最小的“右”起点。

来源:力扣(LeetCode)
链接:https://leetcode-cn.com/problems/find-right-interval
著作权归领扣网络所有。商业转载请联系官方授权,非商业转载请注明出处。

参考代码:

from sortedcontainers import SortedDict

def findRightInterval(intervals):
    m = SortedDict([])
    for i, interval in enumerate(intervals):
        m[interval[0]] = i

    res = []
    for interval in intervals:
        i = m.bisect_left(interval[1])
        if i == len(m): res.append(-1)
        else: res.append(m[m.keys()[i]])
    return res

Last updated

Was this helpful?