# Trie：字符串匹配剪枝

## Autocomplete & Autocorrection

Bob和Alice在设计一个英文输入法产品，能把用户的输入和英文词典中的单词匹配。

**如何找到所有可能匹配的单词？**

> Bob：“我们不如把用户的输入和所有词典中的单词比较一下，然后找出匹配就好了。“
>
> Alice：“这样可以！不过感觉我在翻词典的时候会这样：如果找tea这个单词，先翻到t开头的一页，然后再翻到te开头那一页，里面有tea啊，ted啊之类的，这样比把每一个单词从词典里拿出来比较快多了！“

像这样把所有单词放在一个词典里的数据结构就是字典树，在字典树里查找一个单词模拟了上述翻词典的过程。

**参考代码（Trie的具体实现见后面）：**

```python
def getWords(swipes, trie):
    def helper(i, root, word):
        if i==len(swipes): return
        if root.isWord:
            res.append(word)
            return

        ch = swipes[i]
        if ch in root.children:
            helper(i+1, root.children[ch], word+ch)
        helper(i+1, root, word)

    res = []
    helper(0, trie.root, "")
    return res

words =  ["apple", "boba", "tea", "car"]
swipe = "bnhjkiocikjhanbvcxszarm"
getWords(swipe, Trie(words)) == {'boba', 'car'}

# 回溯写法
def getWords(swipes, trie):
    def helper(i, path, word):
        if i==len(swipes): return
        if path[-1].isWord:
            res.append(word)
            return
        
        ch = swipes[i]
        if ch in path[-1].children:
            path.append(path[-1].children[ch])
            helper(i+1, path, word+ch)
            path.pop()
        return

    res = []
    helper(0, [trie.root], "")
    return res
```

## 前缀树

字典树（Trie）又叫前缀树（Prefix Tree），本质是`Dict:char->Dict`。

当然我们可以把`Dict`抽象成一个`TrieNode`以便我们在Node当中存储其他数据，比如说isWord来表示是不是叶节点：如果是字典的话就表示我是不是形成一个单词。或者如下图所示保存单词的频率，例如tea出现了3次。

![前缀树表示：来源https://en.wikipedia.org/wiki/Trie](https://upload.wikimedia.org/wikipedia/commons/thumb/b/be/Trie_example.svg/250px-Trie_example.svg.png)

## Dict简单实现

下面是Trie用字典的实现。如果不需要在node上存更多比如频率的信息，可以偷懒实现：直接用特殊字符比如`#:True`来表示`isWord=True`。

主要功能：

* `insert(word)`：O(L)插入一个word，L为单词的长度。\
  实现：遍历Dict中含有word字符的节点TrieNode，没有则插入新的节点。标记最后的节点为终止字符`isWord=True`。
* `search(word)`：O(L)查找一个word，L为单词的长度。\
  实现：遍历Dict中含有word字符的节点TrieNode，没有则返回`False`，最后验证是否停留在终止字符返回`isWord`。
* `startsWith(prefix)`：O(L)查找一个prefix，L为前缀的长度。\
  实现：遍历Dict中含有word字符的节点TrieNode，没有则返回`False`，最后返回`True`。

{% tabs %}
{% tab title="字典实现" %}

```python
from collections import defaultdict

class TrieNode:
    def __init__(self):
        self.children = defaultdict(TrieNode)
        self.isWord = False

class Trie:
    def __init__(self, words):
        self.root = TrieNode()
        for w in words: self.insert(w)

    def insert(self, word: str) -> None:
        r = self.root
        for c in word:
            r = r.children[c]
        r.isWord = True

    def search(self, word: str) -> bool:
        r = self.root
        for c in word:
            if c not in r.children:
                return False
            r = r.children[c]
        return r.isWord

    def startsWith(self, prefix: str) -> bool:
        r = self.root
        for c in prefix:
            if c not in r.children:
                return False
            r = r.children[c]
        return True
```

{% endtab %}

{% tab title="偷懒实现" %}

```python
class Trie:
    def __init__(self):
        self.root = {}

    def insert(self, word: str) -> None:
        r = self.root
        for c in word:
            r = r.setdefault(c, {})
        r['#'] = True # 用#来表示isWord=True

    def search(self, word: str) -> bool:
        r = self.root
        for c in word:
            if c not in r: return False
            r = r[c]
        return r.get("#", False)

    def startsWith(self, prefix: str) -> bool:
        r = self.root
        for c in prefix:
            if c not in r: return False
            r = r[c]
        return True
```

{% endtab %}
{% endtabs %}

## 例题：[最大异或值](https://leetcode-cn.com/problems/maximum-xor-of-two-numbers-in-an-array/)

```python
给你一个整数数组 nums ，返回 nums[i] XOR nums[j] 的最大运算结果，
其中 0 ≤ i ≤ j < n 。

进阶：你可以在 O(n) 的时间解决这个问题吗？

输入：nums = [3,10,5,25,2,8]
输出：28
解释：最大运算结果是 5 XOR 25 = 28.
```

把二进制当成字符串，从右向左每一位插入Trie。建完Trie以后再找最大异或值，从右向左找最多 $$xor=1$$ 的儿子，1）找到就向儿子走，2）没找到就向bit走（因为有可能之后还可能找到$$xor=1$$的儿子）。

```python
class TrieNode:
    def __init__(self):
        self.children = defaultdict(TrieNode)

class Trie:
    def __init__(self, nums):
        self.root = TrieNode()
        for n in nums:
            self.insert(n)
    
    def insert(self, n): 
        cur = self.root
        for i in range(31, -1, -1):
            cur = cur.children[(n>>i) & 1]
    
    def searchMax(self, n):
        res, cur = 0, self.root
        for i in range(31, -1, -1): 
            bit = (n>>i) & 1
            if bit ^ 1 in cur.children:
                cur = cur.children[bit ^ 1]
                res += (1<<i)
            else:
                cur = cur.children[bit]
        return res

class Solution:
    def findMaximumXOR(self, nums: List[int]) -> int:
        t = Trie(nums)
        return max([t.searchMax(n) for n in nums])
```
