Top K 问题有两种不同的解法,一种解法使用堆(优先队列),另一种解法使用类似快速排序的分治法。
解题思路:
- 用快排最高效解决 Top K 问题
- 大根堆(前 K 小) / 小根堆(前 K 大), Java中有现成的 PriorityQueue,实现起来最简单
- 二叉搜索树也可以 解决 TopK 问题
- 数据范围有限时直接计数排序就行了
最小的 k 个数
输入整数数组 arr ,找出其中最小的 k 个数。例如,输入4、5、1、6、2、7、3、8这8个数字,则最小的4个数字是1、2、3、4。
快排变形
在计算机科学中,快速选择(英语:Quickselect)是一种从无序列表找到第k小元素的选择算法。
Top K 问题的另一个解法就比较难想到,需要在平时有算法的积累。实际上,“查找第 k 大的元素”是一类算法问题,称为选择问题。找第 k 大的数,或者找前 k 大的数,有一个经典的 quick select(快速选择)算法。这个名字和 quick sort(快速排序)看起来很像,算法的思想也和快速排序类似,都是分治法的思想。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
|
public int[] getLeastNumbers(int[] arr, int k) {
if (k == 0) {
return new int[0];
} else if (arr.length <= k) {
return arr;
}
// 原地不断划分数组
partitionArray(arr, 0, arr.length - 1, k);
// 数组的前 k 个数此时就是最小的 k 个数,将其存入结果
int[] res = new int[k];
for (int i = 0; i < k; i++) {
res[i] = arr[i];
}
return res;
}
void partitionArray(int[] arr, int lo, int hi, int k) {
// 做一次 partition 操作
int m = partition(arr, lo, hi);
// 此时数组前 m 个数,就是最小的 m 个数
if (k == m) {
// 正好找到最小的 k(m) 个数
return;
} else if (k < m) {
// 最小的 k 个数一定在前 m 个数中,递归划分
partitionArray(arr, lo, m-1, k);
} else {
// 在右侧数组中寻找最小的 k-m 个数
partitionArray(arr, m+1, hi, k);
}
}
// partition 函数和快速排序中相同,具体可参考快速排序相关的资料
// 代码参考 Sedgewick 的《算法4》
int partition(int[] a, int lo, int hi) {
int i = lo;
int j = hi + 1;
int v = a[lo];
while (true) {
while (a[++i] < v) {
if (i == hi) {
break;
}
}
while (a[--j] > v) {
if (j == lo) {
break;
}
}
if (i >= j) {
break;
}
swap(a, i, j);
}
swap(a, lo, j);
// a[lo .. j-1] <= a[j] <= a[j+1 .. hi]
return j;
}
void swap(int[] a, int i, int j) {
int temp = a[i];
a[i] = a[j];
a[j] = temp;
}
|
算法的复杂度分析:
- 空间复杂度 $O(1)$,不需要额外空间。
- 时间复杂度的分析方法和快速排序类似。由于快速选择只需要递归一边的数组,时间复杂度小于快速排序,期望时间复杂度为 $O(n)$,最坏情况下的时间复杂度为 $O(n^2)$。
大根堆(前 K 小) / 小根堆(前 K 大)
比较直观的想法是使用堆数据结构来辅助得到最小的 k 个数。堆的性质是每次可以找出最大或最小的元素。我们可以使用一个大小为 k 的最大堆(大顶堆),将数组中的元素依次入堆,当堆的大小超过 k 时,便将多出的元素从堆顶弹出。这样,由于每次从堆顶弹出的数都是堆中最大的,最小的 k 个元素一定会留在堆里。
Java 中的 PriorityQueue 就是堆的数据结构
优化,如果当前数字不小于堆顶元素,数字可以直接丢掉,不入堆。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
|
public int[] getLeastNumbers(int[] arr, int k) {
if (k == 0) {
return new int[0];
}
// 使用一个最大堆(大顶堆)
// Java 的 PriorityQueue 默认是小顶堆,添加 comparator 参数使其变成最大堆
Queue<Integer> heap = new PriorityQueue<>(k, (i1, i2) -> Integer.compare(i2, i1));
for (int e : arr) {
// 当前数字小于堆顶元素才会入堆
if (heap.isEmpty() || heap.size() < k || e < heap.peek()) {
heap.offer(e);
}
if (heap.size() > k) {
heap.poll(); // 删除堆顶最大元素
}
}
// 将堆中的元素存入数组
int[] res = new int[heap.size()];
int j = 0;
for (int e : heap) {
res[j++] = e;
}
return res;
}
|
算法的复杂度分析:
- 由于使用了一个大小为 k 的堆,空间复杂度为 $O(k)$;
- 入堆和出堆操作的时间复杂度均为 $O(logk)$,每个元素都需要进行一次入堆操作,故算法的时间复杂度为 $O(nlogk)$。
看起来分治法的快速选择算法的时间、空间复杂度都优于使用堆的方法,但是要注意到快速选择算法的几点局限性:
- 第一,算法需要修改原数组,如果原数组不能修改的话,还需要拷贝一份数组,空间复杂度就上去了。
- 第二,算法需要保存所有的数据。如果把数据看成输入流的话,使用堆的方法是来一个处理一个,不需要保存数据,只需要保存 k 个元素的最大堆。而快速选择的方法需要先保存下来所有的数据,再运行算法。当数据量非常大的时候,甚至内存都放不下的时候,就麻烦了。所以当数据量大的时候还是用基于堆的方法比较好。
数据可以分区然后每个分区求 Top K, 最后在所有的 Top K 中找 Top K。
二叉搜索树
时间复杂度 $O(NlogK)$
BST 相对于前两种方法没那么常见,但是也很简单,和大根堆的思路差不多。
要提的是,与前两种方法相比,BST 有一个好处是求得的前K大的数字是有序的。
因为有重复的数字,所以用的是 TreeMap 而不是 TreeSet。
TreeMap的key 是数字,value 是该数字的个数。
我们遍历数组中的数字,维护一个数字总个数为 K 的 TreeMap:
- 若目前 map 中数字个数小于 K,则将 map 中当前数字对应的个数 +1;
- 否则,判断当前数字与 map 中最大的数字的大小关系:若当前数字大于等于 map 中的最大数字,就直接跳过该数字;若当前数字小于 map 中的最大数字,则将 map 中当前数字对应的个数 +1,并将 map 中最大数字对应的个数减 1。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
|
class Solution {
public int[] getLeastNumbers(int[] arr, int k) {
if (k == 0 || arr.length == 0) {
return new int[0];
}
// TreeMap的key是数字, value是该数字的个数。
// cnt表示当前map总共存了多少个数字。
TreeMap<Integer, Integer> map = new TreeMap<>();
int cnt = 0;
for (int num: arr) {
// 1. 遍历数组,若当前map中的数字个数小于k,则map中当前数字对应个数+1
if (cnt < k) {
map.put(num, map.getOrDefault(num, 0) + 1);
cnt++;
continue;
}
// 2. 否则,取出map中最大的Key(即最大的数字), 判断当前数字与map中最大数字的大小关系:
// 若当前数字比map中最大的数字还大,就直接忽略;
// 若当前数字比map中最大的数字小,则将当前数字加入map中,并将map中的最大数字的个数-1。
Map.Entry<Integer, Integer> entry = map.lastEntry();
if (entry.getKey() > num) {
map.put(num, map.getOrDefault(num, 0) + 1);
if (entry.getValue() == 1) {
map.pollLastEntry();
} else {
map.put(entry.getKey(), entry.getValue() - 1);
}
}
}
// 最后返回map中的元素
int[] res = new int[k];
int idx = 0;
for (Map.Entry<Integer, Integer> entry: map.entrySet()) {
int freq = entry.getValue();
while (freq-- > 0) {
res[idx++] = entry.getKey();
}
}
return res;
}
}
|
数据范围有限时直接计数排序
时间复杂度 $O(N)$
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
|
class Solution {
public int[] getLeastNumbers(int[] arr, int k) {
if (k == 0 || arr.length == 0) {
return new int[0];
}
// 统计每个数字出现的次数
int[] counter = new int[10001];
for (int num: arr) {
counter[num]++;
}
// 根据counter数组从头找出k个数作为返回结果
int[] res = new int[k];
int idx = 0;
for (int num = 0; num < counter.length; num++) {
while (counter[num]-- > 0 && idx < k) {
res[idx++] = num;
}
if (idx == k) {
break;
}
}
return res;
}
}
|
前K个高频元素
给你一个整数数组 nums 和一个整数 k ,请你返回其中出现频率前 k 高的元素。你可以按 任意顺序 返回答案。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
|
class Solution {
public int[] topKFrequent(int[] nums, int k) {
Map<Integer, Integer> occurrences = new HashMap<Integer, Integer>();
for (int num : nums) {
occurrences.put(num, occurrences.getOrDefault(num, 0) + 1);
}
// int[] 的第一个元素代表数组的值,第二个元素代表了该值出现的次数
PriorityQueue<int[]> queue = new PriorityQueue<int[]>(new Comparator<int[]>() {
public int compare(int[] m, int[] n) {
return m[1] - n[1];
}
});
for (Map.Entry<Integer, Integer> entry : occurrences.entrySet()) {
int num = entry.getKey(), count = entry.getValue();
if (queue.size() == k) {
if (queue.peek()[1] < count) {
queue.poll();
queue.offer(new int[]{num, count});
}
} else {
queue.offer(new int[]{num, count});
}
}
int[] ret = new int[k];
for (int i = 0; i < k; ++i) {
ret[i] = queue.poll()[0];
}
return ret;
}
}
|
复杂度分析
- 时间复杂度:$O(Nlogk)$,其中 N 为数组的长度。我们首先遍历原数组,并使用哈希表记录出现次数,每个元素需要 $O(1)$ 的时间,共需 $O(N)$ 的时间。随后,我们遍历「出现次数数组」,由于堆的大小至多为 k,因此每次堆操作需要 $O(logk)$ 的时间,共需 $O(Nlogk)$ 的时间。二者之和为 $O(Nlogk)$。
- 空间复杂度:$O(N)$。哈希表的大小为 $O(N)$,而堆的大小为 $O(k)$,共计为 $O(N)$。
前K个高频单词
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
|
class Solution {
public List<String> topKFrequent(String[] words, int k) {
Map<String, Integer> cnt = new HashMap<String, Integer>();
for (String word : words) {
cnt.put(word, cnt.getOrDefault(word, 0) + 1);
}
PriorityQueue<Map.Entry<String, Integer>> pq = new PriorityQueue<Map.Entry<String, Integer>>(new Comparator<Map.Entry<String, Integer>>() {
public int compare(Map.Entry<String, Integer> entry1, Map.Entry<String, Integer> entry2) {
return entry1.getValue() == entry2.getValue() ? entry2.getKey().compareTo(entry1.getKey()) : entry1.getValue() - entry2.getValue();
}
});
for (Map.Entry<String, Integer> entry : cnt.entrySet()) {
pq.offer(entry);
if (pq.size() > k) {
pq.poll();
}
}
List<String> ret = new ArrayList<String>();
while (!pq.isEmpty()) {
ret.add(pq.poll().getKey());
}
Collections.reverse(ret);
return ret;
}
}
|
复杂度分析
- 时间复杂度:$O(l×n+m×llogk)$,其中 n 表示给定字符串序列的长度,m 表示实际字符串种类数,l 表示字符串的平均长度。我们需要 l×n 的时间将字符串插入到哈希表中,以及每次插入元素到优先队列中都需要 $llogk$ 的时间,共需要插入 m 次。
- 空间复杂度:$O(l×(m+k))$,其中 l 表示字符串的平均长度,m 表示实际字符串种类数。哈希表空间占用为 $O(l×m)$,优先队列空间占用为 $O(l×k)$。
数组中的第K个最大的元素
给定整数数组 nums 和整数 k,请返回数组中第 k 个最大的元素。
请注意,你需要找的是数组排序后的第 k 个最大的元素,而不是第 k 个不同的元素。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
|
class Solution {
public int findKthLargest(int[] nums, int k) {
int heapSize = nums.length;
buildMaxHeap(nums, heapSize);
for (int i = nums.length - 1; i >= nums.length - k + 1; --i) {
swap(nums, 0, i);
--heapSize;
maxHeapify(nums, 0, heapSize);
}
return nums[0];
}
public void buildMaxHeap(int[] a, int heapSize) {
for (int i = heapSize / 2; i >= 0; --i) {
maxHeapify(a, i, heapSize);
}
}
public void maxHeapify(int[] a, int i, int heapSize) {
int l = i * 2 + 1, r = i * 2 + 2, largest = i;
if (l < heapSize && a[l] > a[largest]) {
largest = l;
}
if (r < heapSize && a[r] > a[largest]) {
largest = r;
}
if (largest != i) {
swap(a, i, largest);
maxHeapify(a, largest, heapSize);
}
}
public void swap(int[] a, int i, int j) {
int temp = a[i];
a[i] = a[j];
a[j] = temp;
}
}
|
复杂度分析
- 时间复杂度:$(nlogn)$,建堆的时间代价是 $O(n)$,删除的总代价是 $O(klogn)$,因为
k<n
,故渐进时间复杂为 $O(n+klogn)=O(nlogn)$。
- 空间复杂度:$O(logn)$,即递归使用栈空间的空间代价。
附录