在线/离线算法

在开头,我们先来看看机器学习中在线/离线的概念。虽然没有必要和算法中在线/离线做直接联系,但是我觉得对理解在线/离线算法有一定帮助。

在ML系统中,比如说用推荐系统选择要给用户选择他们感兴趣的内容,我们需要通过了解用户的过往信息和此时的状态/需求(Query),通过算法来找到最好的内容(Item)。简单的概括,这样的ML系统最重要的两部分工作是:

  1. 启动/调整:用过往或实时的数据来启动/调整算法

  2. 应用:用算法来对新的Query推荐Item

这两部在机器学习中分别叫做训练(training period)和运行(prediction period)。

机器学习中的在线/离线算法

在这样的推荐算法中,有在线和离线的区分。无论是ML系统的训练(training)还是运行(prediction),都可以在线和离线执行,也各自有它们的缺点(备注在括号中):

  • 在线训练(online training):用实时的数据来调整ML系统(算法的不稳定)

  • 离线训练(offline training):线下用大批量的数据来训练ML系统(对新的需求不敏感)

  • 在线推荐(online prediction):用实时的数据来做推荐(延时)

  • 离线推荐(offline prediction):线下对可能的query进行推荐(对新的需求不敏感)

在算法题中也一样,对于很多个query,类比于online prediction vs. offline prediction:可以一个一个处理,或者可以整理好了结果以后再返回。

在线算法

经典的在线算法::用historical queries启动算法,对每一个新的query进行求解。

在线算法的情况很多时候是因为数据太大没办法一股放进内存,只能按照先来后到的顺序来一个个处理数据流(data/network stream)再返回答案。如果每一个新的query的结果都与之前的query有关,此时的在线算法又叫强制在线。

一般会用到的数据类型有:

  • Heap来缩小缓存大小

  • Heap,UnionFind,OrderedDict(或者双向链表)来加速查找。

一些经典题目包括:

  • minheap做缓存,找到数据流中K个最大的数 [LC 703],用对顶堆(minheap & maxheap)找到数据流中的中位数 [LC 295]

  • set/disjoint sets/sliding window找数据流中的连续数字 [LC 128]

  • minheaplinear scanning来安排进程 [LC 253]

例题:蓄水池抽样

蓄水池抽样(Reservoir sampling)有效地处理这样的问题:从N个数字中随机选择k个数。我们维护一个大小为k的蓄水池,对每一个新来的第i数字以 i/ki/k 的概率接受并替换蓄水池中的数字。参考代码:

res = S[:k]
for i in range(k, len(S)):
    j = random.randint(1, i):
    if j<=k: res[j-1] = S[i]

例题:数据流中的滑动窗口

给定一个字符串 s ,找出 刚好含有 M个不同字符的最长子串 T。

输入: s = "ababcbcbaaabbdef", M = 2
输出: "baaabb"
  • 原题[类似LC 340]:找到字符串中含有M个不同字符的substring。滑动窗口+Hashmap

  • 数据流:找到数据流中含有M个不同字符的连续数据。由于如果M比较大的话,没办法把整个可能包含M个字符的substring存入buffer,所以我们要修改Hashmap。不再让Hashmap存储每个字符的频率,而存储每个字符最后一次出现的下标。这样一来,如果不同字符超过M个,我们可以让左指针直接移动到第一个distinct字符。如果要继续加速,快速找到第一个distinct字符,我们可以用OrderedDict来实现。

def longest(s, k):
  res = ""
  left = right = 0
  window = {}
  while right < len(s):
    c = s[right]
    right += 1
    if c in window:
      window[c] += 1
    else:
      window[c] = 1

    while len(window) > k:
      d = s[left]
      left += 1
      window[d] -= 1
      if window[d]==0:
        del window[d]
    if right-left > len(res):
      res = s[left:right]
  return res

例题:数据流中的最长连续序列

给定一个未排序的整数数组 nums ,找出数字连续的最长序列(不要求序列元素在原数组中连续)的长度。

输入:nums = [100,4,200,1,3,2]
输出:4
解释:最长数字连续序列是 [1, 2, 3, 4]。它的长度为 4。
  • 原题[LC 128]:找到数字中最长连续序列

  • 数据流:如果数字很多,不能把所有数字存在buffer中。用并查集来找到最大的连通图,其中连通性由“连续数字“定义。

def longestConsecutive(nums):
    res = 0
    s = set(nums)
    for n in nums:
        if n-1 in s: continue
        cur, streak = n, 1
        while cur+1 in s:
            cur += 1
            streak += 1
        res = max(res, streak)
    return res

离线算法

经典的离线算法:整理好Queries,启动算法,再对每一个Query求解。

比如下面这道题,如果用离线算法先对Queries进行排序再最后求解,时间复杂度就能大大降低。

包含每个查询的最小区间 [LC5748]

给你一个二维整数数组 intervals ,其中 intervals[i] = [lefti, righti] 
表示第 i 个区间开始于 left 、结束于 right(包含两侧取值,闭区间)。
区间的 长度 定义为区间中包含的整数数目,更正式地表达是 right - left + 1 。

再给你一个整数数组 queries 。
第 j 个查询的答案是满足 lefti <= queries[j] <= righti 的 长度最小区间 i 的长度 。
如果不存在这样的区间,那么答案是 -1 。

输入:intervals = [[1,4],[2,4],[3,6],[4,4]], queries = [2,3,4,5]
输出:[3,3,1,4]
解释:查询处理如下:
- Query = 2 :区间 [2,4] 是包含 2 的最小区间,答案为 4 - 2 + 1 = 3 。
- Query = 3 :区间 [2,4] 是包含 3 的最小区间,答案为 4 - 2 + 1 = 3 。
- Query = 4 :区间 [4,4] 是包含 4 的最小区间,答案为 4 - 4 + 1 = 1 。
- Query = 5 :区间 [3,6] 是包含 5 的最小区间,答案为 6 - 3 + 1 = 4 。

如果想到每个query都有可能查询任何一个interval覆盖的区域,就要提前处理interval存储很多信息。但如果把QueriesIntervals都先排序以后从左到右依次处理,可以用线性的遍历来完成查询。

参考代码:

def minInterval(intervals, queries):
    order = sorted([i for i in range(len(queries))],
                    key=lambda i: queries[i])
    intervals.sort()
    p = -1
    res = [-1 for _ in range(len(queries))]
    s = SortedSet()
    for i in order:
        q = queries[i]
        while p+1 < len(intervals) and intervals[p+1][0]<=q:
            p += 1
            s.add((intervals[p][1]-intervals[p][0]+1,
                   intervals[p][1]))      # 按照interval的长短排序
        while s and s[0][1] < q: s.pop(0) # 如果和最早的query无关了
                                          # 那和之后的query也无关了
        if s: res[i] = s[0][0]
    return res

Last updated

Was this helpful?