Z Algorithm-扩展KMP

zz 函数,z[i]z[i] 表示字符串 sss[i,n1]s[i,n-1] 的最长公共前缀长度。计算 z[]z[] 的算法为 Z Algorithm ,又称 扩展KMP

分析

主要参考 Z 函数|OI Wiki。学习时发现,该算法的核心思想跟Manacher算法如出一辙,都是利用到了已有的信息来加速计算

先举个例子🌰

1
2
3
4
5
6
7
8
9
10
11
s = "aaabaab"
n = len(s) = 7
z[] = [0, 2, 1, 0, 2, 1, 0]
Ex:
z[1] = 2
s[1, 6]="aabaab", s="aaabaab"
最长公共前缀 len("aa") = 2

z[4] = 2
s[4, 6]="aab", s="aaabaab"
最长公共前缀 len("aa") = 2

还是老规矩,先思考暴力做法,如何求 sss[i,n1]s[i,n-1] 的最长公共前缀

1
2
3
4
5
char[] cs = s.toCharArray();
int[] z = new int[cs.length];
for (int i = 1; i < n; ++i) {
while (i + z[i] < n && cs[z[i]] == cs[i + z[i]]) z[i]++;
}

二重循环下来,整体的时间复杂度达到了 O(n2)\Omicron(n^2)

Manacher 算法中,我们维护了具有最大回文半径的回文中心 imim 和右端点 rmrm 。在计算每一个 f[i]f[i] 时,如果 i<rmi < rm ,则利用回文串的对称性,其关于 imim 左侧的对称点 i1=2×imii^1=2\times im - i ,具有和 ii 相同的回文半径(*在 [0,rm][0,rm] 内)。而 f[i1]f[i^1] 是已知的,可以将此作为 f[i]f[i] 的初始值,继续向外拓展,更新 rmrmimim ,由于在整个过程中,维护的 rmrm 是不后退的,因此时间复杂度为 O(n)\Omicron(n)

扩展KMP算法类似的原理,记匹配段为 [l,r][l,r] (即 s[l,r]=s[0,rl]s[l,r] = s[0, r-l]),维护右端点 rr 最大的匹配段 [lm,rm][lm,rm]

在计算 z[i]z[i] 时,如果 i<rmi < rm ,由于匹配段的特性,s[i,rm]=s[ilm,rmlm]s[i,rm] = s[i-lm, rm-lm] ,可以直接利用已经计算出来的 z[ilm]z[i-lm] (其他细节见代码)

image-20230109222221999

这相当于用 O(1)\Omicron(1) 的操作获取了 O(n)\Omicron(n) 的信息,由于在计算过程中,维护的 rmrm 没有后退,整体时间复杂度为 O(n)\Omicron(n)

算法流程

第一节的分析遗漏了很多细节,接下来给出算法完整流程

来自Z函数#线性算法|OI Wiki

z[i]z[i] 表示 s[i,n1]s[i, n - 1]ss 的最长公共前缀的长度,对于 ii ,称区间 [i,i+z[i]1][i,i+z[i]-1]ii 的匹配段,或者叫做Z-box

所有的匹配段 [l,r][l,r] 中,维护右端点 rr 最大的匹配段 [lm,rm][lm, rm] ,最大(最靠右)的右端点为 rmrm

计算 z[i]z[i]

  • 如果 i<=rmi <= rm ,此时有 s[i,rm]=s[ilm,rmlm]s[i, rm] = s[i-lm, rm-lm]
    • 如果 z[ilm]<rmi+1z[i-lm] < rm-i+1 ,则 z[i]=z[ilm]z[i] = z[i-lm]
    • 如果 z[ilm]>=rmi+1z[i-lm] >= rm-i +1,此时可以将 rmi+1rm-i+1 作为 z[i]z[i] 的初始值,向右暴力枚举(此过程可以视作 rmrm 前进)
  • 如果 i>rmi > rm,说明之前的信息没有可以利用的,直接将 z[i]z[i] 初始值设为0,向右暴力枚举求出 z[i]z[i]
  • 每一轮都需要更新 rmrm ,如果当前匹配段右端点大于 rmrm ,更新 rm=i+z[i]1rm=i+z[i]-1

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
int[] Z(String s) {
int n = s.length();
int[] z = new int[n];
char[] cs = s.toCharArray();
for (int i = 1, l = 0, r = 0; i < n; ++i) {
if (i <= r && z[i - l] < r - i + 1) z[i] = z[i - l];
else {
z[i] = Math.max(0, r - i + 1);
while (i + z[i] < n && cs[z[i]] == cs[i + z[i]]) z[i]++;
}
if (i + z[i] - 1 > r) {
l = i;
r = i + z[i] - 1;
}
}
return z;

}

例题

2223. 构造字符串的总得分和

题目描述

需要从空字符串开始 构造 一个长度为 n 的字符串 s ,构造的过程为每次给当前字符串 前面 添加 一个 字符。构造过程中得到的所有字符串编号为 1n ,其中长度为 i 的字符串编号为 si

  • 比方说,s = "abaca"s1 == "a"s2 == "ca"s3 == "aca" 依次类推。

si得分sisn最长公共前缀 的长度(注意 s == sn )。

给你最终的字符串 s ,请你返回每一个 si得分之和

示例

1
2
3
4
5
6
7
8
9
输入:s = "babab"
输出:9
解释:
s1 == "b" ,最长公共前缀是 "b" ,得分为 1 。
s2 == "ab" ,没有公共前缀,得分为 0 。
s3 == "bab" ,最长公共前缀为 "bab" ,得分为 3 。
s4 == "abab" ,没有公共前缀,得分为 0 。
s5 == "babab" ,最长公共前缀为 "babab" ,得分为 5 。
得分和为 1 + 0 + 3 + 0 + 5 = 9 ,所以我们返回 9 。

代码

Z算法裸题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class Solution {
public long sumScores(String s) {
int n = s.length();
int[] z = new int[n];
z[0] = n;
char[] cs = s.toCharArray();
for (int i = 1, l = 0, r = 0; i < n; ++i) {
if (i <= r && z[i - l] < r - i + 1) z[i] = z[i - l];
else {
z[i] = Math.max(0, r - i + 1);
while (i + z[i] < n && cs[z[i]] == cs[i + z[i]]) z[i]++;
}
if (i + z[i] - 1 > r) {
l = i;
r = i + z[i] - 1;
}
}
long ans = 0;
for (int x : z) ans += x;
return ans;
}
}

字符串哈希+二分

  1. 对于每一个 sis_i,通过二分搜索,查找其与 SS 最长的相等前缀的截止位置

    • 如果当前 sis_i 的前缀 si.substr(i,mid)S.substr(0,mids_i.substr(i, mid) \neq S.substr(0, mid-i)i),则缩小前缀的长度,即向左侧收缩搜索区间,新的搜索区间更新为 [l,mid[l,mid-1]1]
    • 如果当前 sis_i 的前缀 si.substr(i,mid)=S.substr(0,mids_i.substr(i, mid) = S.substr(0, mid-i)i),则尝试增大前缀的长度,判断是否还能相等,即向右侧收缩搜索区间,更新为 [mid+1,r][mid+1,r]
  2. 判断字符串是否相等的过程,用字符串哈希的方式进行优化

时间复杂度为 O(nlog2n)\Omicron(n\log_2{n})

Java\textit Java 代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class Solution {
static final int N = (int) 1e5+10;
static long[] h = new long[N];
static long[] p = new long[N];
static final int P = 131;
public long get(int l, int r) {
return h[r + 1] - h[l] * p[r - l + 1];
}
public long sumScores(String s) {
p[0] = 1;
int n = s.length();
for (int i = 0; i < n; ++i) {
h[i + 1] = h[i] * P + s.charAt(i);
p[i + 1] = p[i] * P;
}
long score = n;
for (int i = 1; i < n; ++i) {
int l = i, r = n - 1;
while (l <= r) {
int mid = l + r >> 1; //
if (check(i, mid)) l = mid + 1;
else r = mid - 1;
}
score += (r - i + 1);
}
return score;
}
public boolean check(int l, int r) {
return get(0, r - l) == get(l, r);
}
}