題目大意:
給一個n個數的序列a1, a2, ..., an ,這些數的范圍是0~n-1, 可以把前面m個數移動到後面去,形成新序列:
a1, a2, ..., an-1, an (where m = 0 - the initial seqence)
a2, a3, ..., an, a1 (where m = 1)
a3, a4, ..., an, a1, a2 (where m = 2)
...
an, a1, a2, ..., an-1 (where m = n-1)
求這些序列中,逆序數最少的是多少?
分析與總結:
(1) 初看這題時,覺得眼熟,於是翻看了交題記錄,原來去年做過 = =,不過那時是直接暴力求出逆序數水過的...
然後就想這樣用線段樹來優化。所謂的求逆序數,其實就是序列中的每一個數,它前面的比他大的數的數量之和。
用線段樹記錄下各個數,區間的值【a,b】表示數字a~b的已經出現了多少次,所以對於ai,只需要查詢【ai, n】有多少個(ai之前的比ai大的有多少個),就是代表ai有多少個逆序數了。
(2) 求出a1, a2, ..., an-1, an的逆序數之後,就可以遞推求出其他序列的逆序數。 假設要把a1移動到an之後,那麼我們把這個過程拆分成兩步:
1. 把a1去除掉。通過觀察可以發現,(a1-1)是0~n-1中比a1小的數字的個數,由於a1在序列的第一個所以a1之後共有(a1-1)個比a1小,所以形成了(a1-1)對逆序數,當去除掉a1時,原序列的逆序數總數也就減少了(a1-1)個逆序數。
2. 把a1加到an之後。0~n-1中,比a1大的數共有(n-a1)個數,由於a1現在在最後一個,也就是它前面共有(n-a1)個數比它大,即增加了(n-a1)對逆序數。
綜合1,2兩步, 設原序列逆序數為sum, 當把原序列第一個移動到最後位置時,逆序數變為:sum = sum-(ai-1)+(n-ai);
代碼:
[cpp]
#include<iostream>
#include<cstdio>
#include<cstring>
#define lson(x) (x<<1)
#define rson(x) (lson(x)|1)
using namespace std;
const int MAX_NODE = 5005 << 2;
int arr[MAX_NODE];
struct node{
int left, right;
int num;
int mid(){return (left+right)>>1;}
bool buttom(){return left==right;}
};
class SegTree{
public:
void build(int cur,int left,int right){
t[cur].left = left;
t[cur].right = right;
if(left == right){
t[cur].num = 0;
return;
}
int m = t[cur].mid();
build(lson(cur),left,m);
build(rson(cur),m+1,right);
push_up(cur);
}
void update(int cur,int data){
++t[cur].num;
if(t[cur].buttom()){
return;
}
int m = t[cur].mid();
if(data <= m)
update(lson(cur),data);
else
update(rson(cur),data);
}
int query(int cur,int left,int right){
if(t[cur].left==left && t[cur].right==right){
return t[cur].num;
}
int m=t[cur].mid();
if(right <= m)
return query(lson(cur),left,right);
else if(left > m)
return query(rson(cur),left,right);
else
return query(lson(cur),left,m)+query(rson(cur),m+1,right);
}
private:
void push_up(int cur){
t[cur].num = t[lson(cur)].num+t[rson(cur)].num;
}
node t[MAX_NODE];
};
SegTree st;
int main(){
int n,x;
while(~scanf("%d",&n)){
st.build(1,1,n);
int sum=0;
for(int i=0; i<n; ++i){
scanf("%d",&arr[i]);
++arr[i];
int tmp = st.query(1,arr[i],n);
sum += tmp;
st.update(1,arr[i]);
}
int _min = sum;
for(int i=0; i<n-1; ++i){
sum = sum-(arr[i]-1)+(n-arr[i]);
if(sum < _min) _min=sum;
}
printf("%d\n",_min);
}
return 0;
}