POJ 3376

Trie + Manacher

题意

http://poj.org/problem?id=3376 给$n$个串,将它们两两(有序地)拼接,一共$n^2$种拼接方式,求其中有多少种拼接结果是回文串。所有字符串的长度之和不超过$2\times 10^6$。

Solution

在考虑第$i$个串$t$和多少个串$s$拼接是回文时,计算$s+t$是回文的数量,将所有数量累加就是最终的回文数。

首先考虑什么情况下,两个串拼接会是一个回文串。

case 1: s=reverse(t)

$$ \begin{cases} s = a_0\dots a_{n-1} \newline t = a_{n-1}\dots a_0 \end{cases} $$

case 2

$$ \begin{cases} \begin{matrix} &\mathrm{palindrome} \newline s = a_0\dots a_i &\overbrace{a_{i+1}\dots a_{n-1}} \end{matrix} \newline \newline t = a_i\dots a_0 \end{cases} $$ 以及对称的情况 $$ \begin{cases} s = a_{n-1}\dots a_i \newline \begin{matrix} \ \ \ \ \ \mathrm{palindrome} &\newline t = \overbrace{a_0\dots a_{i-1}} &a_i\dots a_{n-1} \end{matrix} \end{cases} $$

因此可以对每个串,用 Manacher 求出它的每个前缀和每个后缀是否是回文。再把所有的串插入 Trie 树,然后拿每个串的反串在 Trie 上匹配,统计答案。

用stl可能会TLE

#include <cstdio>
#include <stack>
#include <set>
#include <cmath>
#include <map>
#include <time.h>
#include <vector>
#include <iostream>
#include <string>
#include <cstring>
#include <algorithm>
//#include <memory.h>
#include <cstdlib>
#include <queue>
#include <iomanip>
#include <cassert>
// #include <unordered_map>
// Type shorthands. Of these, only LL appears in the code visible in this file.
#define P pair<int, int>
#define LL long long
#define LD long double
#define PLL pair<LL, LL>
// mset(a, b): memset over the whole object `a`. It relies on sizeof(a), so `a`
// must be an actual array (not a pointer) for the full extent to be filled.
#define mset(a, b) memset(a, b, sizeof(a))
// rep(i, a, b): declares int i and iterates it over the half-open range [a, b).
#define rep(i, a, b) for (int i = a; i < b; i++)
#define PI acos(-1.0)
// random(x): expands to `rand() % x` without parentheses around x or the result.
#define random(x) rand() % x
// debug(x): streams the expression text, a space, its value and a newline to cout.
#define debug(x) cout << #x << " " << x << "\n"
using namespace std;
const int inf = 0x3f3f3f3f;
const LL __64inf = 0x3f3f3f3f3f3f3f3f;
// MAX is the capacity used to size the global arrays declared after this
// preamble. Both preprocessor branches currently define the same value.
#ifdef DEBUG
const int MAX = 2e6 + 50;
#else
const int MAX = 2e6  + 50;
#endif
const int mod = 998244353;

// ---- Global state shared by the trie, the Manacher pass and main ----

// str_len[i]: declared length of the i-th input string.
int str_len[MAX];
// All input strings stored back to back; string i starts at str_p[i] and is
// NOT NUL-terminated once the next string has been read over its terminator.
char str[MAX<<1];
char* str_p[MAX];
// Flag pools, one int per input character. They are addressed through the
// per-string views below:
//   sufix_palin_p[i][j]   == 1  when the suffix  s_i[j .. len-1] is a palindrome
//   prefix_parlin_p[i][j] == 1  when the prefix  s_i[0 .. j]     is a palindrome
int sufix_palin[MAX];
int prefix_parlin[MAX];
int *sufix_palin_p[MAX];
int *prefix_parlin_p[MAX];
// Manacher radii, indexed by position in the '#'-interleaved string. That
// string has 2*len + 1 characters, so the buffer must be about twice as large
// as the longest possible input string; sizing it MAX overflowed for any
// single string longer than roughly MAX/2.
int radius[(MAX << 1) + 1];
// Number of input strings.
int n;
// Number of input strings whose declared length is 0.
int zero;
// Number of non-empty input strings that are themselves palindromes.
int palin;
// Array-backed trie over the letters 'a'..'z'. Node 0 is the root; a child
// index of 0 means "no such child". Per-node counters are filled by add().
struct Tire
{
    int tot;               // index the next freshly created node will get
    int nex[MAX][26];      // nex[v][c]: child of node v along letter c, or 0
    int siz[MAX];          // how many inserted strings end exactly at the node
    int _siz[MAX];         // how many inserted strings reach the node with their
                           // second-to-last character (not read in this file)
    int rest_palin[MAX];   // how many inserted strings pass through the node and
                           // have a non-empty palindromic remainder after it

    // Zeroes the child table and every counter of node `v`.
    void reset_node(const int v){
        memset(nex[v], 0, sizeof(nex[v]));
        siz[v] = 0;
        _siz[v] = 0;
        rest_palin[v] = 0;
    }

    // Puts the trie into the "root only" state. Other nodes are reset lazily
    // at the moment add() allocates them.
    void init(){
        tot = 1;
        reset_node(0);
    }

    // Inserts string number `idx` (characters at `s`, length str_len[idx]),
    // creating nodes on demand and updating the counters along the path.
    // Reads sufix_palin_p[idx], so the palindrome flags must already be set.
    void add(const int idx, char *s){
        const int length = str_len[idx];
        const int *tail_is_palin = sufix_palin_p[idx];
        int node = 0;
        for(int pos = 0; pos < length; ++pos){
            const int parent = node;
            int &child = nex[parent][s[pos] - 'a'];
            if(child == 0){
                reset_node(tot);
                child = tot++;
            }
            node = child;
            if(pos + 2 == length) ++_siz[node];
            // s[pos..] is a palindrome: credit the node reached *before*
            // consuming s[pos], i.e. the node of the prefix s[0..pos-1].
            if(tail_is_palin[pos]) ++rest_palin[parent];
        }
        ++siz[node];
    }

    // Child table of node `idx`, so callers can write tire[v][c].
    int* operator [] (int idx){
        return nex[idx];
    }
}tire;

// Builds the Manacher helper string for the first `len` characters of `s`:
// a '#' is placed before, between and after the characters, so "abc" becomes
// "#a#b#c#". The result always has 2*len + 1 characters ("#" when len is 0).
// `s` does not need to be NUL-terminated; exactly `len` characters are read.
std::string manacherStr(int len, char *s){
    std::string res;
    // The final size is known up front; reserving once avoids the repeated
    // reallocation and copying that character-by-character growth causes on
    // very long inputs.
    if(len > 0) res.reserve(2 * static_cast<std::size_t>(len) + 1);
    res += '#';
    for(int i = 0; i < len; i++){
        res += s[i];
        res += '#';
    }
    return res;
}
void solve(const int idx, char *s){
    if(str_len[idx] == 0) return ;
    int R = -1;
    int C = -1;
    int Max = -1;
    string str = manacherStr(str_len[idx], s);
    for(int i = 0; i < str.size(); i++){
        radius[i] = R > i ? min(radius[2*C-i], R-i+1) : 1;
        while(i + radius[i] < str.size() and i - radius[i] > -1) {
            if(str[i-radius[i]] == str[i+radius[i]])
                radius[i]++;
            else break;
        }
        if(i + radius[i] > R) {
            R = i + radius[i]-1;
            C = i;
        }
        Max = max(Max, radius[i]);
    }
    if(Max - 1 == str_len[idx]) palin++;
    for(int i = 1; i+1 < str.size(); i++){
        if(str[i] == '#') {
            int len = radius[i]-1;
            int r = len / 2;
            int start_idx = (i-2) / 2 - r + 1;
            if(start_idx + len - 1 == str_len[idx]-1)
                sufix_palin_p[idx][start_idx] = 1;
            if(start_idx == 0) 
                prefix_parlin_p[idx][start_idx+len-1] = 1;
        }
        else {
            int len = radius[i]-1;
            int r = len / 2;
            int start_idx = (i-1)/2 - r;
            if(start_idx + len - 1 == str_len[idx]-1)
                sufix_palin_p[idx][start_idx] = 1;
            if(start_idx == 0)
                prefix_parlin_p[idx][start_idx+len-1] = 1;
        }   
    }
}
// Counts the inserted strings `u` for which u + t is a palindrome, where t is
// string number `idx` (characters at `s`). t is fed to the trie back to front:
//   * when all of t has been consumed, every string ending at the node
//     (u == reverse(t)) and every string continuing past it with a palindromic
//     remainder is counted;
//   * at an earlier position i, strings ending at the node are counted when
//     the still-unread prefix t[0..i-1] is a palindrome.
// The walk stops as soon as the trie has no matching child.
LL calcu(const int idx, char* s){
    const int *head_is_palin = prefix_parlin_p[idx];
    LL total = 0;
    int node = 0;
    int pos = str_len[idx];
    while(pos-- > 0){
        node = tire[node][s[pos] - 'a'];
        if(node == 0) break;
        if(pos == 0){
            total += (LL)tire.siz[node];
            total += (LL)tire.rest_palin[node];
        }
        else if(head_is_palin[pos - 1]){
            total += (LL)tire.siz[node];
        }
    }
    return total;
}
int main(){
#ifdef DEBUG
    freopen("in", "r", stdin);
#endif
    tire.init();
    scanf("%d", &n);
    int cur = 0;
    for(int i = 0; i < n; i++) {
        int x;
        scanf("%d%s", &x, &str[cur]);
        zero += x == 0;
        str_p[i] = str+cur;
        str_len[i] = x;
        prefix_parlin_p[i] = prefix_parlin+cur;
        sufix_palin_p[i] = sufix_palin +cur;
        cur += x;
    }   
    for(int i = 0; i < n; i++){
        solve(i, str_p[i]);
    }
    for(int i = 0; i < n; i++) 
        tire.add(i, str_p[i]);
    LL ans = 0;
    tire.rest_palin[0] = tire.siz[0] = 0;
    for(int i = 0; i < n; i++){
    #ifdef DEBUG
        printf("debug i:%d calcu:%d\n", i, calcu(i, str_p[i]));
    #endif
        ans += 1LL * calcu(i, str_p[i]);        
    }
    ans += 1LL * zero * zero;
    ans += 1LL * zero * palin * 2;
#ifdef DEBUG
    for(int i = 0; i < n; i++){
        for(int j = 0; j < str_len[i]; j++)
            cout << prefix_parlin_p[i][j] << " ";
        puts("");
        for(int j = 0; j < str_len[i]; j++)
            cout << sufix_palin_p[i][j] << " " ;
        puts("\n"); 
    }
#endif
    printf("%lld\n", ans);
    return 0;
}