题意:给定一棵树,每个节点有一个颜色,问树上有多少种子串(定义子串为某两个点上的路径),保证叶子节点数<=20。n<=10^5
题解:
叶子节点小于等于20,考虑将每个叶子节点作为根把树给提起来形成一棵trie,然后定义这棵树的子串为从上到下的一个串(深度从浅到深)。
这样做我们可以发现每个子串必定是某棵trie上的一段直线。统计20棵树的不同子串只需要把它们建到一个自动机上就行了,相当于把20棵trie合并成一棵大的。
对于每个节点x,它贡献的子串数量是max[x]-min[x],又因为min[x]=max[fa]+1,则=max[x]-max[fa],就是step[x]-step[fa];学会了怎样在sam上插入一颗trie,就直接记录一下父亲在sam上的节点作为p。注意每次都要新开一个点,不然会导致无意义的子串出现。例如一棵树 (括号内为i颜色)
1(0)
2(1)
3(2) 4(3)
2是1的孩子,3和4都是2的孩子。在以1为根节点的时候插入了这棵trie,在以3为根节点的时候son[root][2]已经存在,如果用它来当现在的点的话就会让一棵trie接在另一棵的末位,导致无意义的子串出现,答案偏大。
1 #include2 #include 3 #include 4 #include 5 using namespace std; 6 7 typedef long long LL; 8 const int N=20*100010; 9 int n,c,tot,len,last;10 int w[N],son[N][15],step[N],pre[N],first[N],cnt[N],id[N];11 struct node{12 int x,y,next;13 }a[2*N];14 15 void ins(int x,int y)16 {17 a[++len].x=x;a[len].y=y;18 a[len].next=first[x];first[x]=len;19 }20 21 int add_node(int x)22 {23 step[++tot]=x;24 return tot;25 }26 27 int extend(int p,int ch)28 {29 // int np;30 // if(son[p][ch]) return son[p][ch];31 // else np=add_node(step[p]+1);32 int np=add_node(step[p]+1);//debug 每次都要新开一个点33 34 while(p && !son[p][ch]) son[p][ch]=np,p=pre[p];35 if(p==0) pre[np]=1;36 else 37 {38 int q=son[p][ch];39 if(step[q]==step[p]+1) pre[np]=q;40 else 41 {42 int nq=add_node(step[p]+1);43 memcpy(son[nq],son[q],sizeof(son[q]));44 pre[nq]=pre[q];45 pre[np]=pre[q]=nq;46 while(son[p][ch]==q) son[p][ch]=nq,p=pre[p];47 }48 }49 last=np;50 return np;51 }52 53 void dfs(int x,int fa,int now)54 {55 int nt=extend(now,w[x]);56 // printf("%d\n",nt);57 for(int i=first[x];i;i=a[i].next)58 {59 int y=a[i].y;60 if(y!=fa) dfs(y,x,nt);61 }62 }63 64 int main()65 {66 freopen("a.in","r",stdin);67 scanf("%d%d",&n,&c);68 for(int i=1;i<=n;i++) scanf("%d",&w[i]);69 tot=0;len=0;70 memset(son,0,sizeof(son));71 memset(pre,0,sizeof(pre));72 memset(cnt,0,sizeof(cnt));73 memset(first,0,sizeof(first));74 step[++tot]=0;last=1;75 for(int i=1;i %d\n",a[i].x,a[i].y);83 for(int i=1;i<=n;i++)84 {85 if(cnt[i]==1) dfs(i,0,1);86 }87 // for(int i=1;i<=tot;i++) printf("%d ",id[i]);printf("\n");88 LL ans=0;89 for(int i=1;i<=tot;i++) ans+=(LL)(step[i]-step[pre[i]]);90 printf("%lld\n",ans);91 return 0;92 }