「前k个高频元素」解法优化过程

Top K 问题,时间复杂度从 O(nlogn)\Omicron(n\log n) 逐渐优化到 O(n)\Omicron(n)

题目来自leetcode347. 前 K 个高频元素,与215. 数组中的第K个最大元素思路相同

问题

问题描述

给你一个整数数组 nums 和一个整数 k ,请你返回其中出现频率前 k 高的元素。你可以按 任意顺序 返回答案。

示例 1:

1
2
输入: nums = [1,1,1,2,2,3], k = 2
输出: [1,2]

示例 2:

1
2
输入: nums = [1], k = 1
输出: [1]

分析

思路很简单

  1. 使用哈希表统计不同元素的频次
  2. 计算出前 kk 高频率的元素

直观思路:大根堆

<元素-频次>哈希表构建

1
2
3
4
5
6
Map<Integer, Integer> mp = new HashMap<>();
for (int x : nums) {
int freq = mp.getOrDefault(x, 0);
freq++;
mp.put(x, freq);
}

以频次为索引构建一个大根堆,然后弹出 kk 次堆顶元素,即为所求

优先队列实现

1
2
3
4
5
6
7
8
PriorityQueue<int[]> q = new PriorityQueue<>((a,b) -> {return b[1] - a[1];});
for (Map.Entry<Integer, Integer> entry : mp.entrySet()) {
q.add(new int[]{entry.getKey(), entry.getValue()});
}
int[] res = new int[k];
for (int i = 0; i < k; ++i) {
res[i] = q.poll()[0];
}

手写大根堆

参考acwing算法基础课笔记-堆排序,使用数组模拟大根堆

1
2
3
4
5
6
7
8
9
10
11
12
int n = mp.size();
size = n;
int idx = 1;
for (Map.Entry<Integer, Integer> entry : mp.entrySet()) {
h[idx++] = new int[]{entry.getKey(), entry.getValue()};
}
for (int i = n / 2; i >= 1; --i) down(i);
int[] res = new int[k];
for (int i = 0; i < k; ++i) {
res[i] = h[1][0];
h[1] = h[size]; size--; down(1); // 弹出堆顶元素
}

完整代码

使用二维数组 h[n][2]h[n][2] 保存大根堆,h[idx][0]代表元素,h[idx][1]代表频次

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class Solution {
static final int N = (int) 1e5+10;
static int[][] h = new int[N][2];
int size;
void down(int u) {
int t = u;
if (2 * u <= size && h[2 * u][1] > h[t][1]) t = 2 * u;
if (2 * u + 1 <= size && h[2 * u + 1][1] > h[t][1]) t = 2 * u + 1;
if (t != u) {
int[] tmp = h[t]; h[t] = h[u]; h[u] = tmp;
down(t);
}
}
public int[] topKFrequent(int[] nums, int k) {
Map<Integer, Integer> mp = new HashMap<>();
for (int x : nums) {
int freq = mp.getOrDefault(x, 0);
freq++;
mp.put(x, freq);
}
int n = mp.size();
size = n;
int idx = 1;
for (Map.Entry<Integer, Integer> entry : mp.entrySet()) {
h[idx++] = new int[]{entry.getKey(), entry.getValue()};
}
for (int i = n / 2; i >= 1; --i) down(i);
int[] res = new int[k];
for (int i = 0; i < k; ++i) {
res[i] = h[1][0];
h[1] = h[size]; size--; down(1);
}
return res;
}
}

时间复杂度分析

建立频次哈希表的复杂度是 O(n)\Omicron(n)

大根堆的建立过程中,执行了 n / 2 次向下调整

1
for (int i = n / 2; i >= 1; --i) down(i);

复杂度为 O(nlogn)\Omicron(n\log n)

输出前 kk 高频次的元素,复杂度是 O(klogn)\Omicron(k\log n)

整体时间复杂度为 O(nlogn)\Omicron(n\log n)

不满足题目要求

这样还不如直接排序,然后取前 kk 个元素来得直接点

1
Arrays.sort(q, (a, b) -> {return b[1] - a[1];})

改用小根堆

优化思路

不如换个思路,将堆转为小根堆,始终维持 kk 个元素

当堆的元素个数小于 kk 时,直接插入

等于 kk 时:

  • 若当前待插入元素的频次小于堆顶元素的,由于是小根堆,说明至少有 kk 个元素的频次比当前元素高,可以直接忽略
  • 若频次大于堆顶元素时,将堆顶元素弹出,并插入当前元素

因为堆中始终只有 kk 个元素,所以时间复杂度降为 O(nlogk)\Omicron(n\log k)

代码实现

优先队列实现版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class Solution {
public int[] topKFrequent(int[] nums, int k) {
Map<Integer, Integer> mp = new HashMap<>();
for (int x : nums) {
int freq = mp.getOrDefault(x, 0);
freq++;
mp.put(x, freq);
}
PriorityQueue<int[]> q = new PriorityQueue<>((a,b) -> {return a[1] - b[1];});
for (Map.Entry<Integer, Integer> entry : mp.entrySet()) {
int key = entry.getKey(), val = entry.getValue();
if (q.size() < k) {
q.add(new int[]{key, val});
} else {
if (q.peek()[1] < val) {
q.poll();
q.add(new int[]{key, val});
}
}
}
int[] res = new int[k];
int i = 0;
for (int[] x : q) {
res[i++] = x[0];
}
return res;
}
}

手写小根堆代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
class Solution {
static final int N = (int) 1e5+10;
static int[][] h = new int[N][2];
int size;
void up(int u) {
while (u / 2 >= 1 && h[u / 2][1] > h[u][1]) {
int[] tmp = h[u / 2]; h[u / 2] = h[u]; h[u] = tmp;
u >>= 1;
}
}
void down(int u) {
int t = u;
if (2 * u <= size && h[2 * u][1] < h[t][1]) t = 2 * u;
if (2 * u + 1 <= size && h[2 * u + 1][1] < h[t][1]) t = 2 * u + 1;
if (t != u) {
int[] tmp = h[t]; h[t] = h[u]; h[u] = tmp;
down(t);
}
}
public int[] topKFrequent(int[] nums, int k) {
Map<Integer, Integer> mp = new HashMap<>();
for (int x : nums) {
int freq = mp.getOrDefault(x, 0);
freq++;
mp.put(x, freq);
}
int n = mp.size();
size = 0;
int idx = 1;
for (Map.Entry<Integer, Integer> entry : mp.entrySet()) {
int key = entry.getKey(), val = entry.getValue();
if (size < k) {
h[++size] = new int[]{key, val}; up(size);
} else {
if (h[1][1] <= val) {
h[1] = h[size]; size--; down(1);
h[++size] = new int[]{key, val}; up(size);
}
}
}

int[] res = new int[k];
for (int i = 0; i < k; ++i) {
res[i] = h[i+1][0];
}
return res;
}
}

最优解法:快速选择算法

参考 acwing算法基础课笔记-快速排序

核心代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
void quickSelect(List<Integer> res, int[][] q, int l, int r, int k) {
if (l == r) {
res.add(q[l][0]);
return;
}
int x = q[l + r >> 1][1], i = l - 1, j = r + 1;
while (i < j) {
while (q[++i][1] > x);
while (q[--j][1] < x);
if (i < j) {
int[] tmp = q[i]; q[i] = q[j]; q[j] = tmp;
}
}
int kl = j - l + 1;
if (k <= kl) {
quickSelect(res, q, l, j, k);
} else {
for (int t = l; t <= j; ++t) res.add(q[t][0]);
quickSelect(res, q, j + 1, r, k - kl);
}
}

快速选择算法是基于快速排序的

每一次根据枢轴值pivot划分区间,分成左右两部分

而快选是根据 kk 值,每次只选择其中一个部分,继续下一次划分

假设每次等半划分,也就是说,遍历的长度依次为 n, n2, n41n,\ \frac{n}2,\ \frac{n}4\dots1

求和得 n(112x)112=2n(112x)\frac{n(1-{\frac{1}{2}}^x)}{1-\frac{1}{2}}=2n(1-{\frac{1}{2}}^x) ,其中 xx 为划分次数,x=log2nx=\lceil{\log_{2}n}\rceil

带入,总的遍历长度大致为 2n2n-22 ,也就是说时间复杂度为 O(n)\Omicron(n)

当然,这是理想下的平均情况,实际复杂度与每次枢轴pivot的选取密切相关,

如果每次枢轴都将整体区间划分得极不均衡,例如 [1,,n1,pivot][1,\dots,n-1,pivot] 位于端点这种情况

时间复杂度就退化为了 O(n2)\Omicron(n^2)

完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class Solution {
public int[] topKFrequent(int[] nums, int k) {
Map<Integer, Integer> mp = new HashMap<>();
for (int x : nums) {
int freq = mp.getOrDefault(x, 0);
freq++;
mp.put(x, freq);
}
int n = mp.size();
int[][] q = new int[n][2];
int idx = 0;
for (Map.Entry<Integer, Integer> entry : mp.entrySet()) {
q[idx++] = new int[]{entry.getKey(), entry.getValue()};
}
List<Integer> res = new ArrayList<>();
quickSelect(res, q, 0, n - 1, k);
int[] ans = new int[k];
for (int i = 0; i < k; ++i) ans[i] = res.get(i);
return ans;
}
void quickSelect(List<Integer> res, int[][] q, int l, int r, int k) {
if (l == r) {
res.add(q[l][0]);
return;
}
int x = q[l + r >> 1][1], i = l - 1, j = r + 1;
while (i < j) {
while (q[++i][1] > x);
while (q[--j][1] < x);
if (i < j) {
int[] tmp = q[i]; q[i] = q[j]; q[j] = tmp;
}
}
int kl = j - l + 1;
if (k <= kl) {
quickSelect(res, q, l, j, k);
} else {
for (int t = l; t <= j; ++t) res.add(q[t][0]);
quickSelect(res, q, j + 1, r, k - kl);
}
}
}

注意:本题要求结果按照任意顺序返回即可

快速选择算法求出的结果不一定是有序的,因为是按枢轴划分,在 pivot 之前的比 pivot 小,但不保证顺序!