Heap:贪心算法助手

Greedy Gas Fill

Alice和Bob要去roadtrip,问题来了:因为路途有点远,我们估计要在路上停下来加油。那么去哪里加油能让我们停下次数最少呢?不同的加油站离出发距离不一样,能加油的量也不一样。就有了这样的对话:

Bob提议:“别一边开一边计划了,我们就一脚油门开到底,不行了再“后悔“去错过的那些能加最多油的地方加。“

Alice说:“那你怎么知道哪些地方油最多呢?“

Bob回答道:“我们在路上拿个小本本记下来就是了“。

这个“小本本“,把油最多的地方记下来的数据结构,就可以是堆(Heap)。

最低加油次数

输入:target = 100, startFuel = 10, stations = [[10,60],[20,30],[30,30],[60,40]]
输出:2
解释:
我们出发时有 10 升燃料。
我们开车来到距起点 10 英里处的加油站,消耗 10 升燃料。将汽油从 0 升加到 60 升。
然后,我们从 10 英里处的加油站开到 60 英里处的加油站(消耗 50 升燃料),
并将汽油从 10 升加到 50 升。然后我们开车抵达目的地。
我们沿途在1两个加油站停靠,所以返回 2

因为python默认最小对,所以要反转一下符号

def minRefuelStops(target, startFuel, stations):
    far = stop = 0
    queue = []
    stations.append((target, float('inf')))
    for dist, fuel in stations:
        startFuel -= (dist-far)       # 一脚油门开到这儿
        while startFuel < 0 and queue:# 不行,得去小本本上油最多的地方加油
            startFuel -= heapq.heappop(queue)
            stop += 1
        if startFuel < 0: return -1
        heapq.heappush(queue, -fuel)  # 拿小本本记下最多可以加多少油
        far = dist
    return stop

二叉堆

堆Heap,又叫优先队列Priority Queue。本质是一棵满二叉树。它把数组中dd[0]空出来,然后每一个数据就很方便的能找到它的左右儿子d[left]=d[i*2]d[right]=d[i*2+1]。一个最小堆的重要性质是每个节点都小于等于它的子节点

二叉堆:来源https://oi-wiki.org/ds/binary-heap/

为啥说堆是贪心算法的助手呢?因为它在堆顶总是存着最大或者最小的那个数字。比如在区间调度中,我们先处理最先来、腾出最先结束的。又比如贪心优化啊的问题中,我们先考虑优化空间最大的(见练习:最大平均分)。类似的情景中,我们都可以用堆来保留这些“最值“。

注意Python中默认的heapq永远是最小堆,如果要把它当作最大堆来用,就在heappush的时候加入val的负值-val。别在heappop的时候忘记把负号反过来就好了:)

heapq库中除了常规的heappushheappop的操作,还几个方便的函数,可以让你少写几行代码:

  • heappushpop:和先push再pop等价

  • heapreplace:和先pop再push等价,不过更高效一些

时间复杂度上,heappushheappop都是O(logn)的复杂度,然而原地建成一个堆的heapifyO(n)的操作(粗略证明)。

还有一些其他的堆类型,在其他操作时有更快的时间效率(WikipediaOI Wiki配对堆等介绍)。

数组简单实现

下面是手动实现的最小堆。最大堆只需把less函数变成more就好了。

主要功能:

  • insert(val)O(logn)插入一个元素 实现:在满二叉树的最后插入一个元素,然后用swim(k)让它游上来

  • pop()O(logn)得到并删除最小元素 实现:得到满二叉树的根值,交换根与尾元素交换然后删除新的尾元素,然后让新的根用sink(k)

    沉下去

  • swim(k):不断地把第k个元素val和父亲比较,如果val小就游上去

  • sink(k):不断地把第k个元素val和小儿子比较,如果val大就沉下去

class MinHeap:
    def __init__(self, cap):
        self.d = [None] * (cap+1)
        self.N = 0

    def insert(self, val):
        self.N += 1
        self.d[self.N] = val  # 加入数据到尾巴
        self.swim(self.N)     # 让它游上来

    def pop(self):
        val = self.d[1]
        self.swap(1, self.N)  # 交换头尾
        self.d[self.N] = None # 删除尾巴
        self.N -= 1
        self.sink(1)          # 新的头数据沉下去
        return val
        
    def swim(self, k):
        while k>1 and self.less(k, self.parent(k)):
            self.swap(k, self.parent(k))
            k = self.parent(k)
    
    def sink(self, k):
        while self.left(k)<=self.N:
            smaller = self.left(k)
            if self.right(k)<=self.N and \
                self.less(self.right(k), smaller):
                smaller = self.right(k)
            if self.less(k, smaller): break
            self.swap(k, smaller)
            k = smaller
    
    def swap(self, i, j):
        self.d[i],self.d[j]=self.d[j],self.d[i]
    def less(self, i, j): # 判断d[i]是否d[j]
        return self.d[i]<self.d[j]
    def more(self, i, j):
        return self.d[j]<self.d[i]
    def left(self, i): return i*2
    def right(self, i): return i*2 + 1
    def parent(self, i): return i//2

经典的例题有:

  • 进程安排:最小堆保存了最早结束时间

  • 最大的K个数:大小为K的最小堆作为缓存,来筛选是否加入新元素

  • 合并K个升序链表[LC 23]:大小为K的最小堆作为缓存排序。延伸:设计推特[LC 355]。

例题:进程安排

下面几题都用优先队列来帮助找到最先结束的进程。

给你一个会议时间安排的数组 intervals ,
每个会议时间都会包括开始和结束的时间 intervals[i] = [starti, endi] ,
为避免会议冲突,同时要考虑充分利用会议室资源.
请你计算至少需要多少间会议室,才能满足这些会议安排。

输入:intervals = [[0,30],[5,10],[15,20]]
输出:2
def minMeetingRooms(intervals):
    if not intervals: return 0
    intervals.sort(key=lambda x: x[0])
    queue = [intervals[0][1]]
    for left, right in intervals[1:]:
        if left < queue[0]:
            heapq.heappush(queue, right)
        else:
            heapq.heapreplace(queue, right)
    return len(queue)

queue = [] # 保存了最大的K个数字
for n in nums:
    if len(queue) < k: # 还没到K个数
        heapq.heappush(queue, n)
    elif n > queue[0]: # 至少你得比K个里面最小的要大
        heapq.heapreplace(queue, n)

一所学校里有一些班级,每个班级里有一些学生,现在每个班都会进行一场期末考试。
给你一个二维数组 classes ,其中 classes[i] = [passi, totali] ,
表示你提前知道了第 i 个班级总共有 totali 个学生,其中只有 passi 个学生可以通过考试。

给你一个整数 extraStudents ,表示额外有 extraStudents 个聪明的学生,
他们 一定 能通过任何班级的期末考。

你需要给这 extraStudents 个学生每人都安排一个班级,使得所有班级的平均通过率最大 。


输入:classes = [[1,2],[3,5],[2,2]], extraStudents = 2
输出:0.78333
解释:你可以将额外的两个学生都安排到第一个班级,
平均通过率为 (3/4 + 3/5 + 2/2) / 3 = 0.78333 。
def maxAverageRatio(classes, extraStudents):
    N = len(classes)
    def dratio(n, m): return n/m - (n+1)/(m+1)

    ratios = [n/m for n, m in classes]
    sratio = sum(ratios)
    queue = []
    for i, (n, m) in enumerate(classes):
        heapq.heappush(queue,
                      (dratio(n, m), n, m, i))

    students = extraStudents
    res = sratio
    while students:
        dr, n, m, i = heapq.heappop(queue)
        nn, nm = n+1, m+1
        r, nr = n/m, (n+1) / (m+1)
        
        heapq.heappush(queue,
                      (dratio(nn,nm), nn, nm, i))
        students -= 1
        
        sratio = sratio - r + nr
        res = max(res, sratio)
    return res / N

例题:车队 II

在一条单车道上有 n 辆车,它们朝着同样的方向行驶。
给你一个长度为 n 的数组 cars ,其中 cars[i] = [positioni, speedi] ,它表示:

- positioni 是第 i 辆车和道路起点之间的距离(单位:米)。
  题目保证 positioni < positioni+1 。
- speedi 是第 i 辆车的初始速度(单位:米/秒)。

简单起见,所有车子可以视为在数轴上移动的点。
当两辆车占据同一个位置时,我们称它们相遇了。
一旦两辆车相遇,它们会合并成一个车队,这个车队里的车有着同样的位置和相同的速度,
速度为这个车队里 最慢 一辆车的速度。

请你返回一个数组 answer ,其中 answer[i] 是第 i 辆车与下一辆车相遇的时间(单位:秒)
如果这辆车不会与下一辆车相遇,则 answer[i] 为 -1 。答案精度误差需在 10-5 以内。


输入:cars = [[1,2],[2,1],[4,3],[7,2]]
输出:[1.00000,-1.00000,3.00000,-1.00000]
解释:经过恰好 1 秒以后,第一辆车会与第二辆车相遇,并形成一个 1 m/s 的车队。
经过恰好 3 秒以后,第三辆车会与第四辆车相遇,并形成一个 2 m/s 的车队。
def getCollisionTimes(cars):
    def bump(i, j):
        car1, car2 = cars[i], cars[j]
        if car1[1] <= car2[1]: return -1
        return (car2[0] - car1[0]) / (car1[1] - car2[1])

    queue = [] # heap of the most recent bump time
    m = {}     # mapping from car to the car it will bump
    res = [-1 for _ in range(len(cars))]
    for i in range(len(cars)-1):
        bTime = bump(i, i+1)
        if bTime != -1: heapq.heappush(queue, (bTime, i, i+1))
        m[i+1] = i
    
    while queue:
        bTime, left, right = heapq.heappop(queue)
        if res[left] != -1: continue
        res[left] = bTime
        if left not in m: continue
        nLeft = m[left]
        m[right] = nLeft
        nTime = bump(nLeft, right)
        if nTime != -1: heapq.heappush(queue, (nTime, nLeft, right))
    return res

Last updated

Was this helpful?