Statement

给出 nn 个模板串 SiS_i,求合法字符串个数,满足:可以被划分成非空的两段,使得其中每一段都是一个模板串的前缀。

数据范围:n104n\le10^4Si30|S_i|\le30Σ26|\Sigma|\le26


Solution

如果看过网上的一些题解觉得没懂,那么你可以考虑看看这篇题解。

考虑用字符串总数减去重复字符串个数。对模板串建立 Trie,全集为 Trie 树节点数(不含根节点)的平方。

重复字符串形如:1 2 341 23 412 3 4\begin{aligned}1~2~3|4\\1~2|3~4\\1|2~3~4\end{aligned},其中每个数字表示一段字符串。这个字符串在枚举 443 43~42 3 42~3~4 时都会被枚举到。

可以发现 AC 自动机上 3 43~4 的 fail 正好指向 442 3 42~3~4 的 fail 正好指向 3 43~4;所以考虑在枚举到 3 43~4 时统计前两个字符串的重复,在枚举到 2 3 42~3~4 时统计后两个字符串的重复。这个重复,就是 1 21~211 对应的模板串前缀个数。

假设枚举到 2 3 42~3~4,找到其 fail 指向 3 43~4,那么如何统计合法的 11 的个数呢?显然合法的 11 需要是 1 21~2 的前缀。所以合法的 11 的个数等于 22 在 fail 树上的子树大小(减一,因为 11 不能为空)。

注意到 22 一定是 2 3 42~3~4 的祖先,且在 Trie 上的深度为 dep(2 3 4)dep(3 4)\text{dep}(2~3~4)-\text{dep}(3~4)。所以可以在遍历 Trie 树时维护最右链。


Code

核心代码

1
2
3
4
5
6
void dfs(int u) {
anc[dep[u]] = u;
if (fail[u]) ans -= siz[anc[dep[u] - dep[fail[u]]]] - 1;
for (int i = 0; i < 26; i++)
if (trie[u][i]) dfs(trie[u][i]);
}

完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#include <bits/stdc++.h>
typedef long long LL;
const int N = 3e5 + 7;
namespace AC {
int son[N][26], pnodes = 0, trie[N][26];
void insert(char* str) {
int u = 0;
while (*str) {
if (!son[u][*str - 'a']) son[u][*str - 'a'] = ++pnodes;
u = son[u][*str - 'a'], str++;
}
}
int fail[N], dep[N], siz[N];
std::vector<int> edge[N];
void build() {
std::queue<int> q;
memcpy(trie, son, sizeof trie);
for (int i = 0; i < 26; i++)
if (son[0][i]) q.push(son[0][i]);
while (!q.empty()) {
int u = q.front();
q.pop();
for (int i = 0; i < 26; i++)
if (son[u][i])
fail[son[u][i]] = son[fail[u]][i], q.push(son[u][i]);
else
son[u][i] = son[fail[u]][i];
}
for (int i = 1; i <= pnodes; i++) edge[fail[i]].push_back(i);
}
void pre1(int u) {
for (int i = 0; i < 26; i++)
if (trie[u][i]) dep[trie[u][i]] = dep[u] + 1, pre1(trie[u][i]);
}
void pre2(int u) {
siz[u] = 1;
for (int v : edge[u]) pre2(v), siz[u] += siz[v];
}
LL ans = 0;
int anc[N];
void dfs(int u) {
anc[dep[u]] = u;
if (fail[u]) ans -= siz[anc[dep[u] - dep[fail[u]]]] - 1;
for (int i = 0; i < 26; i++)
if (trie[u][i]) dfs(trie[u][i]);
}
LL solve() { return ans = (LL)pnodes * pnodes, pre1(0), pre2(0), dfs(0), ans; }
} // namespace AC
using namespace AC;
int n;
char str[N];
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) scanf("%s", str), insert(str);
build(), printf("%lld\n", solve());
return 0;
}