題目大意:
兩個操作
1 id op 把id的位置+op
2 id op 查詢在【id,op】之間的所有的數的差
思路:
關鍵是pushup函數。
自己退一下會發現,跟區間的總和,區間的節點個數有關。
比如如果左區間是 1 2 的話
右區間來一個 9
那麼
就要加上
9-1+9-2
#include#include #include #include #include #define inf 0x3f3f3f3f #define maxn 222222 #define keyTree (ch[ch[root][1]][0]) //當把l-1放在根節點 r+1放在根節點的右子樹 //那麼根節點的右子樹的左子樹就是[l,r] 這個區間的所有值 using namespace std; typedef long long LL; int S[maxn],que[maxn],ch[maxn][2],pre[maxn],siz[maxn]; int root,top1,top2; LL ans[maxn],val[maxn],a[maxn],b[maxn]; LL sum[maxn]; set tab; void Treaval(int x) { if(x) { Treaval(ch[x][0]); printf("%I64d ",val[x]); Treaval(ch[x][1]); } } void debug() { // printf("root=%d\n",root); Treaval(root); puts(""); } void New(int &x,int PRE,LL v) { if(top2)x=S[--top2]; else x=++top1; ch[x][0]=ch[x][1]=0; siz[x]=1; pre[x]=PRE; /*special*/ sum[x]=v; ans[x]=0; val[x]=v; } void pushup(int x)/*special*/ { siz[x]=siz[ch[x][0]]+siz[ch[x][1]]+1; sum[x]=sum[ch[x][0]]+sum[ch[x][1]]+val[x]; ans[x]=ans[ch[x][0]]+ans[ch[x][1]]+siz[ch[x][0]]*val[x]-sum[ch[x][0]]+sum[ch[x][1]]-val[x]*siz[ch[x][1]]+siz[ch[x][0]]*sum[ch[x][1]]-siz[ch[x][1]]*sum[ch[x][0]]; } void pushdown(int x)/*special*/ { } void build(int &x,int s,int e,int f) { if(s>e)return; int mid=(s+e)>>1; New(x,f,a[mid]); if(s mid)build(ch[x][1],mid+1,e,x); pushup(x); } void Rotate(int x,int kind) { int y=pre[x]; pushdown(x); pushdown(y); ch[y][!kind]=ch[x][kind]; pre[ch[x][kind]]=y; if(pre[y])ch[pre[y]][ch[pre[y]][1]==y]=x; pre[x]=pre[y]; ch[x][kind]=y; pre[y]=x; pushup(y); } void Splay(int x,int goal) { pushdown(x); while(pre[x]!=goal) { if(pre[pre[x]]==goal) Rotate(x,ch[pre[x]][0]==x); else { int y=pre[x]; int kind=ch[pre[y]][0]==y; if(ch[y][kind]==x){ Rotate(x,!kind); Rotate(x,kind); } else { Rotate(y,kind); Rotate(x,kind); } } } pushup(x); if(goal==0)root=x; } void RotateTo(int k,int goal) { int r=root; pushdown(r); while(siz[ch[r][0]]!=k) { if(k val[t])return find(x,ch[t][1]); else return find(x,ch[t][0]); } set ::iterator it; int add(int x,LL Num,int pos) { if(Num<=val[x]) { if(ch[x][0]==0) { Splay(x,0); int S=siz[ch[root][0]]; RotateTo(S-1,0); RotateTo(S,root); New(keyTree,ch[root][1],Num); pushup(ch[root][1]); pushup(root); } else add(ch[x][0],Num,pos); } else { if(ch[x][1]==0) { Splay(x,0); int S=siz[ch[root][0]]; RotateTo(S,0); RotateTo(S+1,root); New(keyTree,ch[root][1],Num); pushup(ch[root][1]); pushup(root); } else add(ch[x][1],Num,siz[ch[x][0]]+1+pos); } } int main() { tab.clear(); int n; scanf("%d",&n); init(n); //debug(); int m; scanf("%d",&m); while(m--) { int op; int l,r; scanf("%d%d%d",&op,&l,&r); if(op==1) { it=tab.lower_bound(b[l-1]); //printf("*it = %I64d\n",*it); int pos=find(*it,root); //printf("---%I64d\n",val[pos]); Splay(pos,0); int Spos=siz[ch[root][0]]; RotateTo(Spos-1,0); RotateTo(Spos+1,root); //erase(keyTree); erase(keyTree); pushup(ch[root][1]); pushup(root); tab.erase(it); b[l-1]+=r; tab.insert(b[l-1]); add(root,b[l-1],0); } else { it=tab.lower_bound(l); it--; //printf("%I64d\n",*it); //printf("%I64d\n",val[]); Splay(find(*it,root),0); it=tab.upper_bound(r); Splay(find(*it,root),root); printf("%I64d\n",ans[keyTree]); } //debug(); } return 0; }