目录

Top K 问题

Top K 问题有两种不同的解法,一种解法使用堆(优先队列),另一种解法使用类似快速排序的分治法。

解题思路:

  • 用快排最高效解决 Top K 问题
  • 大根堆(前 K 小) / 小根堆(前 K 大), Java中有现成的 PriorityQueue,实现起来最简单
  • 二叉搜索树也可以 解决 TopK 问题
  • 数据范围有限时直接计数排序就行了

最小的 k 个数

输入整数数组 arr ,找出其中最小的 k 个数。例如,输入4、5、1、6、2、7、3、8这8个数字,则最小的4个数字是1、2、3、4。

快排变形

在计算机科学中,快速选择(英语:Quickselect)是一种从无序列表找到第k小元素的选择算法。

Top K 问题的另一个解法就比较难想到,需要在平时有算法的积累。实际上,“查找第 k 大的元素”是一类算法问题,称为选择问题。找第 k 大的数,或者找前 k 大的数,有一个经典的 quick select(快速选择)算法。这个名字和 quick sort(快速排序)看起来很像,算法的思想也和快速排序类似,都是分治法的思想。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
public int[] getLeastNumbers(int[] arr, int k) {
    if (k == 0) {
        return new int[0];
    } else if (arr.length <= k) {
        return arr;
    }
    
    // 原地不断划分数组
    partitionArray(arr, 0, arr.length - 1, k);
    
    // 数组的前 k 个数此时就是最小的 k 个数,将其存入结果
    int[] res = new int[k];
    for (int i = 0; i < k; i++) {
        res[i] = arr[i];
    }
    return res;
}

void partitionArray(int[] arr, int lo, int hi, int k) {
    // 做一次 partition 操作
    int m = partition(arr, lo, hi);
    // 此时数组前 m 个数,就是最小的 m 个数
    if (k == m) {
        // 正好找到最小的 k(m) 个数
        return;
    } else if (k < m) {
        // 最小的 k 个数一定在前 m 个数中,递归划分
        partitionArray(arr, lo, m-1, k);
    } else {
        // 在右侧数组中寻找最小的 k-m 个数
        partitionArray(arr, m+1, hi, k);
    }
}

// partition 函数和快速排序中相同,具体可参考快速排序相关的资料
// 代码参考 Sedgewick 的《算法4》
int partition(int[] a, int lo, int hi) {
    int i = lo;
    int j = hi + 1;
    int v = a[lo];
    while (true) { 
        while (a[++i] < v) {
            if (i == hi) {
                break;
            }
        }
        while (a[--j] > v) {
            if (j == lo) {
                break;
            }
        }

        if (i >= j) {
            break;
        }
        swap(a, i, j);
    }
    swap(a, lo, j);

    // a[lo .. j-1] <= a[j] <= a[j+1 .. hi]
    return j;
}

void swap(int[] a, int i, int j) {
    int temp = a[i];
    a[i] = a[j];
    a[j] = temp;
}

算法的复杂度分析:

  • 空间复杂度 $O(1)$,不需要额外空间。
  • 时间复杂度的分析方法和快速排序类似。由于快速选择只需要递归一边的数组,时间复杂度小于快速排序,期望时间复杂度为 $O(n)$,最坏情况下的时间复杂度为 $O(n^2)$。

大根堆(前 K 小) / 小根堆(前 K 大)

比较直观的想法是使用堆数据结构来辅助得到最小的 k 个数。堆的性质是每次可以找出最大或最小的元素。我们可以使用一个大小为 k 的最大堆(大顶堆),将数组中的元素依次入堆,当堆的大小超过 k 时,便将多出的元素从堆顶弹出。这样,由于每次从堆顶弹出的数都是堆中最大的,最小的 k 个元素一定会留在堆里。

Java 中的 PriorityQueue 就是堆的数据结构

优化,如果当前数字不小于堆顶元素,数字可以直接丢掉,不入堆。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
public int[] getLeastNumbers(int[] arr, int k) {
    if (k == 0) {
        return new int[0];
    }
    // 使用一个最大堆(大顶堆)
    // Java 的 PriorityQueue 默认是小顶堆,添加 comparator 参数使其变成最大堆
    Queue<Integer> heap = new PriorityQueue<>(k, (i1, i2) -> Integer.compare(i2, i1));

    for (int e : arr) {
        // 当前数字小于堆顶元素才会入堆
        if (heap.isEmpty() || heap.size() < k || e < heap.peek()) {
            heap.offer(e);
        }
        if (heap.size() > k) {
            heap.poll(); // 删除堆顶最大元素
        }
    }

    // 将堆中的元素存入数组
    int[] res = new int[heap.size()];
    int j = 0;
    for (int e : heap) {
        res[j++] = e;
    }
    return res;
}

算法的复杂度分析:

  • 由于使用了一个大小为 k 的堆,空间复杂度为 $O(k)$;
  • 入堆和出堆操作的时间复杂度均为 $O(logk)$,每个元素都需要进行一次入堆操作,故算法的时间复杂度为 $O(nlogk)$。

看起来分治法的快速选择算法的时间、空间复杂度都优于使用堆的方法,但是要注意到快速选择算法的几点局限性:

  • 第一,算法需要修改原数组,如果原数组不能修改的话,还需要拷贝一份数组,空间复杂度就上去了。
  • 第二,算法需要保存所有的数据。如果把数据看成输入流的话,使用堆的方法是来一个处理一个,不需要保存数据,只需要保存 k 个元素的最大堆。而快速选择的方法需要先保存下来所有的数据,再运行算法。当数据量非常大的时候,甚至内存都放不下的时候,就麻烦了。所以当数据量大的时候还是用基于堆的方法比较好。

数据可以分区然后每个分区求 Top K, 最后在所有的 Top K 中找 Top K。

二叉搜索树

时间复杂度 $O(NlogK)$

BST 相对于前两种方法没那么常见,但是也很简单,和大根堆的思路差不多。 要提的是,与前两种方法相比,BST 有一个好处是求得的前K大的数字是有序的。

因为有重复的数字,所以用的是 TreeMap 而不是 TreeSet。

TreeMap的key 是数字,value 是该数字的个数。

我们遍历数组中的数字,维护一个数字总个数为 K 的 TreeMap:

  1. 若目前 map 中数字个数小于 K,则将 map 中当前数字对应的个数 +1;
  2. 否则,判断当前数字与 map 中最大的数字的大小关系:若当前数字大于等于 map 中的最大数字,就直接跳过该数字;若当前数字小于 map 中的最大数字,则将 map 中当前数字对应的个数 +1,并将 map 中最大数字对应的个数减 1。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class Solution {
    public int[] getLeastNumbers(int[] arr, int k) {
        if (k == 0 || arr.length == 0) {
            return new int[0];
        }
        // TreeMap的key是数字, value是该数字的个数。
        // cnt表示当前map总共存了多少个数字。
        TreeMap<Integer, Integer> map = new TreeMap<>();
        int cnt = 0;
        for (int num: arr) {
            // 1. 遍历数组,若当前map中的数字个数小于k,则map中当前数字对应个数+1
            if (cnt < k) {
                map.put(num, map.getOrDefault(num, 0) + 1);
                cnt++;
                continue;
            } 
            // 2. 否则,取出map中最大的Key(即最大的数字), 判断当前数字与map中最大数字的大小关系:
            //    若当前数字比map中最大的数字还大,就直接忽略;
            //    若当前数字比map中最大的数字小,则将当前数字加入map中,并将map中的最大数字的个数-1。
            Map.Entry<Integer, Integer> entry = map.lastEntry();
            if (entry.getKey() > num) {
                map.put(num, map.getOrDefault(num, 0) + 1);
                if (entry.getValue() == 1) {
                    map.pollLastEntry();
                } else {
                    map.put(entry.getKey(), entry.getValue() - 1);
                }
            }
            
        }

        // 最后返回map中的元素
        int[] res = new int[k];
        int idx = 0;
        for (Map.Entry<Integer, Integer> entry: map.entrySet()) {
            int freq = entry.getValue();
            while (freq-- > 0) {
                res[idx++] = entry.getKey();
            }
        }
        return res;
    }
}

数据范围有限时直接计数排序

时间复杂度 $O(N)$

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class Solution {
    public int[] getLeastNumbers(int[] arr, int k) {
        if (k == 0 || arr.length == 0) {
            return new int[0];
        }
        // 统计每个数字出现的次数
        int[] counter = new int[10001];
        for (int num: arr) {
            counter[num]++;
        }
        // 根据counter数组从头找出k个数作为返回结果
        int[] res = new int[k];
        int idx = 0;
        for (int num = 0; num < counter.length; num++) {
            while (counter[num]-- > 0 && idx < k) {
                res[idx++] = num;
            }
            if (idx == k) {
                break;
            }
        }
        return res;
    }
}

前K个高频元素

给你一个整数数组 nums 和一个整数 k ,请你返回其中出现频率前 k 高的元素。你可以按 任意顺序 返回答案。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class Solution {
    public int[] topKFrequent(int[] nums, int k) {
        Map<Integer, Integer> occurrences = new HashMap<Integer, Integer>();
        for (int num : nums) {
            occurrences.put(num, occurrences.getOrDefault(num, 0) + 1);
        }

        // int[] 的第一个元素代表数组的值,第二个元素代表了该值出现的次数
        PriorityQueue<int[]> queue = new PriorityQueue<int[]>(new Comparator<int[]>() {
            public int compare(int[] m, int[] n) {
                return m[1] - n[1];
            }
        });
        for (Map.Entry<Integer, Integer> entry : occurrences.entrySet()) {
            int num = entry.getKey(), count = entry.getValue();
            if (queue.size() == k) {
                if (queue.peek()[1] < count) {
                    queue.poll();
                    queue.offer(new int[]{num, count});
                }
            } else {
                queue.offer(new int[]{num, count});
            }
        }
        int[] ret = new int[k];
        for (int i = 0; i < k; ++i) {
            ret[i] = queue.poll()[0];
        }
        return ret;
    }
}

复杂度分析

  • 时间复杂度:$O(Nlogk)$,其中 N 为数组的长度。我们首先遍历原数组,并使用哈希表记录出现次数,每个元素需要 $O(1)$ 的时间,共需 $O(N)$ 的时间。随后,我们遍历「出现次数数组」,由于堆的大小至多为 k,因此每次堆操作需要 $O(logk)$ 的时间,共需 $O(Nlogk)$ 的时间。二者之和为 $O(Nlogk)$。
  • 空间复杂度:$O(N)$。哈希表的大小为 $O(N)$,而堆的大小为 $O(k)$,共计为 $O(N)$。

前K个高频单词

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class Solution {
    public List<String> topKFrequent(String[] words, int k) {
        Map<String, Integer> cnt = new HashMap<String, Integer>();
        for (String word : words) {
            cnt.put(word, cnt.getOrDefault(word, 0) + 1);
        }
        PriorityQueue<Map.Entry<String, Integer>> pq = new PriorityQueue<Map.Entry<String, Integer>>(new Comparator<Map.Entry<String, Integer>>() {
            public int compare(Map.Entry<String, Integer> entry1, Map.Entry<String, Integer> entry2) {
                return entry1.getValue() == entry2.getValue() ? entry2.getKey().compareTo(entry1.getKey()) : entry1.getValue() - entry2.getValue();
            }
        });
        for (Map.Entry<String, Integer> entry : cnt.entrySet()) {
            pq.offer(entry);
            if (pq.size() > k) {
                pq.poll();
            }
        }
        List<String> ret = new ArrayList<String>();
        while (!pq.isEmpty()) {
            ret.add(pq.poll().getKey());
        }
        Collections.reverse(ret);
        return ret;
    }
}

复杂度分析

  • 时间复杂度:$O(l×n+m×llogk)$,其中 n 表示给定字符串序列的长度,m 表示实际字符串种类数,l 表示字符串的平均长度。我们需要 l×n 的时间将字符串插入到哈希表中,以及每次插入元素到优先队列中都需要 $llogk$ 的时间,共需要插入 m 次。
  • 空间复杂度:$O(l×(m+k))$,其中 l 表示字符串的平均长度,m 表示实际字符串种类数。哈希表空间占用为 $O(l×m)$,优先队列空间占用为 $O(l×k)$。

数组中的第K个最大的元素

给定整数数组 nums 和整数 k,请返回数组中第 k 个最大的元素。

请注意,你需要找的是数组排序后的第 k 个最大的元素,而不是第 k 个不同的元素。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
class Solution {
    public int findKthLargest(int[] nums, int k) {
        int heapSize = nums.length;
        buildMaxHeap(nums, heapSize);
        for (int i = nums.length - 1; i >= nums.length - k + 1; --i) {
            swap(nums, 0, i);
            --heapSize;
            maxHeapify(nums, 0, heapSize);
        }
        return nums[0];
    }

    public void buildMaxHeap(int[] a, int heapSize) {
        for (int i = heapSize / 2; i >= 0; --i) {
            maxHeapify(a, i, heapSize);
        } 
    }

    public void maxHeapify(int[] a, int i, int heapSize) {
        int l = i * 2 + 1, r = i * 2 + 2, largest = i;
        if (l < heapSize && a[l] > a[largest]) {
            largest = l;
        } 
        if (r < heapSize && a[r] > a[largest]) {
            largest = r;
        }
        if (largest != i) {
            swap(a, i, largest);
            maxHeapify(a, largest, heapSize);
        }
    }

    public void swap(int[] a, int i, int j) {
        int temp = a[i];
        a[i] = a[j];
        a[j] = temp;
    }
}

复杂度分析

  • 时间复杂度:$(nlogn)$,建堆的时间代价是 $O(n)$,删除的总代价是 $O(klogn)$,因为 k<n,故渐进时间复杂为 $O(n+klogn)=O(nlogn)$。
  • 空间复杂度:$O(logn)$,即递归使用栈空间的空间代价。

附录