Statement
给出 n n n 个模板串 S i S_i S i ,求合法字符串个数,满足:可以被划分成非空的两段,使得其中每一段都是一个模板串的前缀。
数据范围:n ≤ 1 0 4 n\le10^4 n ≤ 1 0 4 ,∣ S i ∣ ≤ 30 |S_i|\le30 ∣ S i ∣ ≤ 3 0 ,∣ Σ ∣ ≤ 26 |\Sigma|\le26 ∣ Σ ∣ ≤ 2 6 。
Solution
如果看过网上的一些题解觉得没懂,那么你可以考虑看看这篇题解。
考虑用字符串总数减去重复字符串个数。对模板串建立 Trie,全集为 Trie 树节点数(不含根节点)的平方。
重复字符串形如:1 2 3 ∣ 4 1 2 ∣ 3 4 1 ∣ 2 3 4 \begin{aligned}1~2~3|4\\1~2|3~4\\1|2~3~4\end{aligned} 1 2 3 ∣ 4 1 2 ∣ 3 4 1 ∣ 2 3 4 ,其中每个数字表示一段字符串。这个字符串在枚举 4 4 4 、3 4 3~4 3 4 和 2 3 4 2~3~4 2 3 4 时都会被枚举到。
可以发现 AC 自动机上 3 4 3~4 3 4 的 fail 正好指向 4 4 4 ,2 3 4 2~3~4 2 3 4 的 fail 正好指向 3 4 3~4 3 4 ;所以考虑在枚举到 3 4 3~4 3 4 时统计前两个字符串的重复,在枚举到 2 3 4 2~3~4 2 3 4 时统计后两个字符串的重复。这个重复,就是 1 2 1~2 1 2 和 1 1 1 对应的模板串前缀个数。
假设枚举到 2 3 4 2~3~4 2 3 4 ,找到其 fail 指向 3 4 3~4 3 4 ,那么如何统计合法的 1 1 1 的个数呢?显然合法的 1 1 1 需要是 1 2 1~2 1 2 的前缀。所以合法的 1 1 1 的个数等于 2 2 2 在 fail 树上的子树大小(减一,因为 1 1 1 不能为空)。
注意到 2 2 2 一定是 2 3 4 2~3~4 2 3 4 的祖先,且在 Trie 上的深度为 dep ( 2 3 4 ) − dep ( 3 4 ) \text{dep}(2~3~4)-\text{dep}(3~4) dep ( 2 3 4 ) − dep ( 3 4 ) 。所以可以在遍历 Trie 树时维护最右链。
Code
核心代码
1 2 3 4 5 6 void dfs (int u) { anc[dep[u]] = u; if (fail[u]) ans -= siz[anc[dep[u] - dep[fail[u]]]] - 1 ; for (int i = 0 ; i < 26 ; i++) if (trie[u][i]) dfs (trie[u][i]); }
完整代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 #include <bits/stdc++.h> typedef long long LL;const int N = 3e5 + 7 ;namespace AC {int son[N][26 ], pnodes = 0 , trie[N][26 ];void insert (char * str) { int u = 0 ; while (*str) { if (!son[u][*str - 'a' ]) son[u][*str - 'a' ] = ++pnodes; u = son[u][*str - 'a' ], str++; } } int fail[N], dep[N], siz[N];std::vector<int > edge[N]; void build () { std::queue<int > q; memcpy (trie, son, sizeof trie); for (int i = 0 ; i < 26 ; i++) if (son[0 ][i]) q.push (son[0 ][i]); while (!q.empty ()) { int u = q.front (); q.pop (); for (int i = 0 ; i < 26 ; i++) if (son[u][i]) fail[son[u][i]] = son[fail[u]][i], q.push (son[u][i]); else son[u][i] = son[fail[u]][i]; } for (int i = 1 ; i <= pnodes; i++) edge[fail[i]].push_back (i); } void pre1 (int u) { for (int i = 0 ; i < 26 ; i++) if (trie[u][i]) dep[trie[u][i]] = dep[u] + 1 , pre1 (trie[u][i]); } void pre2 (int u) { siz[u] = 1 ; for (int v : edge[u]) pre2 (v), siz[u] += siz[v]; } LL ans = 0 ; int anc[N];void dfs (int u) { anc[dep[u]] = u; if (fail[u]) ans -= siz[anc[dep[u] - dep[fail[u]]]] - 1 ; for (int i = 0 ; i < 26 ; i++) if (trie[u][i]) dfs (trie[u][i]); } LL solve () { return ans = (LL)pnodes * pnodes, pre1 (0 ), pre2 (0 ), dfs (0 ), ans; }} using namespace AC;int n;char str[N];int main () { scanf ("%d" , &n); for (int i = 1 ; i <= n; i++) scanf ("%s" , str), insert (str); build (), printf ("%lld\n" , solve ()); return 0 ; }