线性DP问题:最长上升子序列(LIS)、最短编辑距离、最长公共子序列(LCS)
leetcode→ 300. 最长递增子序列
问题描述
上述例子中的最长子序列为 [1,2,5,6] 长度为4
Y总板书
转移方程
f[i]=max{f[j]+1},j=0,1,2⋯i−1,when qi>qj
代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
| import java.util.*;
class Main{ static final int N = 1010; static int[] q = new int[N]; static int[] f = new int[N]; public static void main(String[] args) { Scanner cin = new Scanner(System.in); int n; n = cin.nextInt(); for (int i = 0; i < n; ++i) q[i] = cin.nextInt(); Arrays.fill(f, 1); for (int i = 1; i < n; ++i) { for (int j = 0; j < i; ++j) { if (q[i] > q[j]) f[i] = Math.max(f[i], f[j] + 1); } } int res = 0; for (int i = 0; i < n; ++i) res = Math.max(res,f[i]); System.out.print(res); } }
|
优化
上述算法的时间复杂度是 O(n2) 的,如果数据量过大,会超时。
思路
贪心
为什么这么做?
因为末尾元素越小,越容易在后接上一个较大的元素,构成更长的上升子序列。
维持一个数组 q , q[k] 代表着长度为 k 的上升子序列的末尾元素的最小值
这个数组,是严格单调递增的。
也就是说 长度为6的上升子序列末尾元素的最小值 一定 严格大于 长度为5的上升子序列末尾元素最小值
简单地反证一下:
假设 q[k]=q[k−1] ,考虑 k 长度的上升子序列,在其末尾元素之前,一定存在一个最近的元素 ai<q[k] ,且 ai<q[k−1] , 则长度为 k-1 的上升子序列末尾元素最小值为 ai ,这就与 q[k−1] 的意义矛盾了。
一旦一个数组具有单调性,则其具备二段性
- 序号大于 k 的元素,一定比 q[k] 大
- 序号小于 k 的元素,一定比 q[k] 小
则可以使用二分法
1 2 3 4 5 6 7 8 9 10 11
| int len = 0;
int l = 0, r = len; while (l < r) { int mid = l + r + 1 >> 1; if (q[mid] < a[i]) l = mid; else r = mid - 1; } q[r + 1] = a[i]; len = Math.max(len, r + 1);
|
完整代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
| import java.util.*;
class Main{ static final int N = 1010; static int[] a = new int[N]; static int[] q = new int[N]; public static void main(String[] args) { Scanner cin = new Scanner(System.in); int n; n = cin.nextInt(); for (int i = 0; i < n; ++i) a[i] = cin.nextInt(); int len = 0; for (int i = 0; i < n; ++i) { int l = 0, r = len; while (l < r) { int mid = l + r + 1 >> 1; if (q[mid] < a[i]) l = mid; else r = mid - 1; } len = Math.max(len, r + 1); q[r + 1] = a[i]; } System.out.print(len); } }
|
二分查找部分的思考
二分查找部分,可以换种写法
查找大于等于a[i]的最小值,替换掉
1 2 3 4 5 6 7 8
| int l = 1, r = len; while (l < r) { int mid = l + r >> 1; if (q[mid] >= a[i]) r = mid; else l = mid + 1; } q[l] = a[i];
|
为什么在求”小于a[i]的最大值"时,要将左边界设为0呢?
因为如果q中所有值都比a[i]大或者相等,合法的更新位置为初始位置 1,
如果左边界设为1的话,终止循环时,r指针为1,r+1就为2了,错误!
也可以换种写法规避这种情况
1 2 3 4 5 6 7 8
| int l = 1, r = len; while (l <= r) { int mid = l + r >> 1; if (q[mid] < a[i]) l = mid + 1; else r = mid - 1; } q[r + 1] = a[i];
|
其实,「找到小于x的最大值,并在右边插入x」等价于「找到大于等于x的最小值,用x替换」
参考一下链接
二分查找有几种写法?它们的区别是什么? - Jason Li的回答 - 知乎
题目描述
dp问题
需要积累经验
Y总板书
状态有四种,以此进行分析,得出状态转移方程。
- 包含 ai 和 bj
- 包含 ai 不包含 bj
- 不包含 ai 包含 bj
- 既不包含 ai 也不包含 bj
状态1,即当前公共子序列的集合包含 ai 和 bj ,只有一种可能:a[i]=b[j]
状态2,即当前公共子序列的集合里面不包含 ai ,包含 bj
- 状态2是 f[i−1][j] 的子集
- 在计算最长公共子序列时,是求 max ,
- 父集合的最大值一定也是子集合的最大值。
- 状态3和状态2对称,思路一一样
状态4,即既不包含 ai 也不包含 bj ,等价于 f[i−1][j−1] ,但这是前面状态的子集,可以合并到前面一起计算
代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
| import java.util.*;
class Main{ static final int N = 1010; static int[][] dp = new int[N][N]; public static void main(String[] args) { Scanner cin = new Scanner(System.in); int n, m; n = cin.nextInt(); m = cin.nextInt(); String str1 = cin.next(); String str2 = cin.next(); for (int i = 0; i < n; ++i ) { for(int j = 0; j < m; ++j ) { if (str1.charAt(i) == str2.charAt(j)) { dp[i + 1][j + 1] = dp[i][j] + 1; } else { dp[i + 1][j + 1] = Math.max(dp[i][j + 1], dp[i + 1][j]); } } } System.out.print(dp[n][m]); } }
|
题目描述
思路
状态表示 f[i][j]
- 集合: 所有将 a[1,i] 转变为 b[1,j] 的操作
- 属性: 操作次数的最小值
状态计算
- 删除: f[i-1][j] + 1
- 插入: f[i][j-1] + 1
- 替换: 如果 a[i]=b[j],f[i-1][j-1] + 1,如果 a[i]=b[j] ,无需操作
代码
不要忘记初始化
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
| int[][] dp = new int[n + 1][m + 1];
for (int j = 0; j <= m; ++j) { dp[0][j] = j; } for (int i = 0; i <= n; ++i) { dp[i][0] = i; } for (int i = 0; i < n; ++i) { for (int j = 0; j < m; ++j) { dp[i + 1][j + 1] = Math.min(dp[i][j + 1], dp[i + 1][j]) + 1; if (a.charAt(i) == b.charAt(j)) { dp[i + 1][j + 1] = Math.min(dp[i + 1][j + 1], dp[i][j]); } else dp[i + 1][j + 1] = Math.min(dp[i + 1][j + 1], dp[i][j] + 1); } }
|
空间优化: 滚动数组
dp[n][m] -> dp[2][m] 没啥必要
拓展
acwing 899. 编辑距离
代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40
| import java.io.*; import java.util.*; class Main{ static BufferedReader read = new BufferedReader(new InputStreamReader(System.in)); static int INF = 0x3f3f3f3f; public static int minEdit(String s1, String s2){ int n1 = s1.length(), n2 = s2.length(); int[][] dp = new int[n1 + 1][n2 + 1]; for(int i = 1; i <= n1; i++) dp[i][0] = i; for(int i = 1; i <= n2; i++) dp[0][i] = i; for(int i = 1; i <= n1; i++){ for(int j = 1; j <= n2; j++){ dp[i][j] = INF; dp[i][j] = Math.min(dp[i - 1][j] + 1, dp[i][j - 1] + 1); dp[i][j] = Math.min(dp[i][j], dp[i - 1][j - 1] + (s1.charAt(i - 1) == s2.charAt(j - 1) ? 0: 1 )); } } return dp[n1][n2]; } public static void main(String[] args) throws Exception{ String[] ss = read.readLine().split(" "); int n = Integer.valueOf(ss[0]); int m = Integer.valueOf(ss[1]); String[] s = new String[n]; for(int i = 0; i < n; i++){ s[i] = read.readLine(); } while(m -- > 0){ int ans = 0; ss = read.readLine().split(" "); int limit = Integer.valueOf(ss[1]); for(int i = 0; i < n; i++){ if(minEdit(s[i], ss[0]) <= limit) ans++; } System.out.println(ans); }
} }
|