題意:給你一個字符串S,然後定義每一個字符是”好的‘或是”壞的“,求S中包含不超過k個壞字符的不同字串的個數。
思路:這道題可以用哈希和SET過的,但是太慢啦,我覺得正解應該是SA或者SAM,下面介紹SAM的做法。
我們構造S的SAM,然後在SAM的每一個狀態維護sum,表示該狀態 下的子串包含多少個”不好“的字符,po表示該狀態所表示的子串出現的位置中的一個(隨便哪一個)。我們再將SAM進行拓撲排序,然後自頂下下遍歷,我們遍歷到一個狀態p的時候,我們檢查該狀態的par節點的sum值,若sum已經超過k,則顯然這個狀態的所有子串均不滿足要求,,我們不妨把p的sum設為k+1,然後繼續遍歷下一個節點,否則,我們設tmp=sum,mi為p表示的子串的最小長度(p->par->val+1),ma為p所表示的子串的最大長度(p->val),由小到大開始枚舉每一個子串,即從p->po-mi+1到p->po-ma+1,若發現一個”好的“字符,則ans+=1,否則tmp++,若tmp超過了k,則設p->sum=k+1,跳過該狀態,否則ans+=1,最後設sum=tmp。繼續遍歷下一個狀態。最後我們輸出ans即可。代碼如下:
[cpp]
#include <iostream>
#include <stdio.h>
#include <string.h>
#include <algorithm>
#define maxn 3010
#define Smaxn 26
using namespace std;
struct node
{
node *par,*go[Smaxn];
int po;
int sum;
int val;
}*root,*tail,que[maxn],*top[maxn];
int tot;
char str[maxn>>1];
int vis[26];
void add(int c,int l,int po)
{
node *p=tail,*np=&que[tot++];
np->val=l;
np->po=po;
while(p&&p->go[c]==NULL)
p->go[c]=np,p=p->par;
if(p==NULL) np->par=root;
else
{
node *q=p->go[c];
if(p->val+1==q->val) np->par=q;
else
{
node *nq=&que[tot++];
*nq=*q;
nq->val=p->val+1;
np->par=q->par=nq;
while(p&&p->go[c]==q) p->go[c]=nq,p=p->par;
}
}
tail=np;
}
int c[maxn],len;
void init()
{
len=1;
tot=0;
memset(que,0,sizeof(que));
root=tail=&que[tot++];
}
void solve(int limit)
{
int i,j;
memset(c,0,sizeof(c));
for(i=0;i<tot;i++)
c[que[i].val]++;
for(i=1;i<len;i++)
c[i]+=c[i-1];
for(i=0;i<tot;i++)
top[--c[que[i].val]]=&que[i];
int sum=0;
for(i=1;i<tot;i++)
{
node *p=top[i];
if(p->par->sum>limit)
{
p->sum=limit+1;
continue;
}
int mi=p->par->val+1,ma=p->val,tmp=p->par->sum,po=p->po;
for(j=mi;j<=ma;j++)
{
if(vis[str[po-j+1]-'a'])
{
tmp++;
if(tmp>limit)
{
break;
}
else
sum++;
}
else
sum++;
}
p->sum=tmp;
}
printf("%d\n",sum);
}
int main()
{
//freopen("dd.txt","r",stdin);
scanf("%s",str);
int i,k,l=strlen(str);
init();
for(i=0;i<l;i++)
{
add(str[i]-'a',len++,i);
}
char tmp[26];
scanf("%s",tmp);
for(i=0;i<26;i++)
vis[i]=1-(tmp[i]-'0');
scanf("%d",&k);
solve(k);
return 0;
}
#include <iostream>
#include <stdio.h>
#include <string.h>
#include <algorithm>
#define maxn 3010
#define Smaxn 26
using namespace std;
struct node
{
node *par,*go[Smaxn];
int po;
int sum;
int val;
}*root,*tail,que[maxn],*top[maxn];
int tot;
char str[maxn>>1];
int vis[26];
void add(int c,int l,int po)
{
node *p=tail,*np=&que[tot++];
np->val=l;
np->po=po;
while(p&&p->go[c]==NULL)
p->go[c]=np,p=p->par;
if(p==NULL) np->par=root;
else
{
node *q=p->go[c];
if(p->val+1==q->val) np->par=q;
else
{
node *nq=&que[tot++];
*nq=*q;
nq->val=p->val+1;
np->par=q->par=nq;
while(p&&p->go[c]==q) p->go[c]=nq,p=p->par;
}
}
tail=np;
}
int c[maxn],len;
void init()
{
len=1;
tot=0;
memset(que,0,sizeof(que));
root=tail=&que[tot++];
}
void solve(int limit)
{
int i,j;
memset(c,0,sizeof(c));
for(i=0;i<tot;i++)
c[que[i].val]++;
for(i=1;i<len;i++)
c[i]+=c[i-1];
for(i=0;i<tot;i++)
top[--c[que[i].val]]=&que[i];
int sum=0;
for(i=1;i<tot;i++)
{
node *p=top[i];
if(p->par->sum>limit)
{
p->sum=limit+1;
continue;
}
int mi=p->par->val+1,ma=p->val,tmp=p->par->sum,po=p->po;
for(j=mi;j<=ma;j++)
{
if(vis[str[po-j+1]-'a'])
{
tmp++;
if(tmp>limit)
{
break;
}
else
sum++;
}
else
sum++;
}
p->sum=tmp;
}
printf("%d\n",sum);
}
int main()
{
//freopen("dd.txt","r",stdin);
scanf("%s",str);
int i,k,l=strlen(str);
init();
for(i=0;i<l;i++)
{
add(str[i]-'a',len++,i);
}
char tmp[26];
scanf("%s",tmp);
for(i=0;i<26;i++)
vis[i]=1-(tmp[i]-'0');
scanf("%d",&k);
solve(k);
return 0;
}