程序師世界是廣大編程愛好者互助、分享、學習的平台,程序師世界有你更精彩!
首頁
編程語言
C語言|JAVA編程
Python編程
網頁編程
ASP編程|PHP編程
JSP編程
數據庫知識
MYSQL數據庫|SqlServer數據庫
Oracle數據庫|DB2數據庫
 程式師世界 >> 編程語言 >> C語言 >> C++ >> C++入門知識 >> HDU 1402 FFT 求 大數乘法

HDU 1402 FFT 求 大數乘法

編輯:C++入門知識

這題的數據量是5w, 也就是傳統意義上的n^2算法是不可取的。這裡就用到了FFT

 


FFT一般的作用就是使得多項式乘法的復雜度降到nlogn。利用FFT可以快速求出循環卷積。

那麼卷積又是什麼樣一個東西。

 往往是在連續的情形,
  兩個函數f(x),g(x)的卷積,是∫f(u)g(x-u)du
  當然,證明卷積的一些性質並不困難,比如交換,結合等等,但是對於卷積運算的來處,初學者就不甚了了。
  
  其實,從離散的情形看卷積,或許更加清楚,
  對於兩個序列f[n],g[n],一般可以將其卷積定義為s[x]= ∑f[k]g[x-k]
  
  卷積的一個典型例子,其實就是初中就學過的多項式相乘的運算,
  比如(x*x+3*x+2)(2*x+5)
  一般計算順序是這樣,
  (x*x+3*x+2)(2*x+5)
  = (x*x+3*x+2)*2*x+(x*x+3*x+2)*5
  = 2*x*x*x+3*2*x*x+2*2*x+ 5*x*x+3*5*x+10
  然後合並同類項的系數,
  2 x*x*x
  3*2+1*5 x*x
  2*2+3*5 x
  2*5
  ----------
  2*x*x*x+11*x*x+19*x+10
  
  實際上,從線性代數可以知道,多項式構成一個向量空間,其基底可選為
  {1,x,x*x,x*x*x,...}
  如此,則任何多項式均可與無窮維空間中的一個坐標向量相對應,
  如,(x*x+3*x+2)對應於
  (1 3 2),
  (2*x+5)對應於
  (2,5).
  
  線性空間中沒有定義兩個向量間的卷積運算,而只有加法,數乘兩種運算,而實際上,多項式的乘法,就無法在線性空間中說明.可見線性空間的理論多麼局限了.
  但如果按照我們上面對向量卷積的定義來處理坐標向量,
  (1 3 2)*(2 5)
  則有
  2 3 1
  _ _ 2 5
  --------
      2
  
  
  2 3 1
  _ 2 5
  -----
    6+5=11
  
  2 3 1
  2 5
  -----
  4+15 =19
  
  
  _ 2 3 1
  2 5
  -------
    10
  
   或者說,
  (1 3 2)*(2 5)=(2 11 19 10)
  
  回到多項式的表示上來,
  (x*x+3*x+2)(2*x+5)= 2*x*x*x+11*x*x+19*x+10
  
  似乎很神奇,結果跟我們用傳統辦法得到的是完全一樣的.
  換句話,多項式相乘,相當於系數向量的卷積.
  
  其實,琢磨一下,道理也很簡單,
  卷積運算實際上是分別求 x*x*x ,x*x,x,1的系數,也就是說,他把加法和求和雜合在一起做了。(傳統的辦法是先做乘法,然後在合並同類項的時候才作加法)
  以x*x的系數為例,得到x*x,或者是用x*x乘5,或者是用3x乘2x,也就是
  2 3 1
  _ 2 5
  -----
   6+5=11
  其實,這正是向量的內積.如此則,卷積運算,可以看作是一串內積運算.既然是一串內積運算,則我們可以試圖用矩陣表示上述過程。
  
  [ 2 3 1 0 0 0]
  [ 0 2 3 1 0 0]==A
  [ 0 0 2 3 1 0]
  [ 0 0 0 2 3 1]
  
  [0 0 2 5 0 0]' == x
  
  b= Ax=[ 2 11 19 10]'
  
  采用行的觀點看Ax,則b的每行都是一個內積。
  A的每一行都是序列[2 3 1]的一個移動位置。
  
  ---------
  
  顯然,在這個特定的背景下,我們知道,卷積滿足交換,結合等定律,因為,眾所周知的,多項式的乘法滿足交換律,結合律.在一般情形下,其實也成立.
  
  在這裡,我們發現多項式,除了構成特定的線性空間外,基與基之間還存在某種特殊的聯系,正是這種聯系,給予多項式空間以特殊的性質.
  
  在學向量的時候,一般都會舉這個例子,甲有三個蘋果,5個橘子,乙有5個蘋果,三個橘子,則共有幾個蘋果,橘子。老師反復告誡,橘子就是橘子,蘋果就是蘋果,可不能混在一起。所以有(3,5)+(5,3)=(8,8).是的,橘子和蘋果無論怎麼加,都不會出什麼問題的,但是,如果考慮橘子乘橘子,或者橘子乘蘋果,這問題就不大容易說清了。
  
  又如復數,如果僅僅定義復數為數對(a,b),僅僅在線性空間的層面看待C2,那就未免太簡單了。實際上,只要加上一條(a,b)*(c,d)=(ac-bd,ad+bc)
  則情況馬上改觀,復變函數的內容多麼豐富多彩,是眾所周知的。
  
  另外,回想信號處理裡面的一條基本定理,頻率域的乘積,相當於時域或空域信號的卷積.恰好跟這裡的情形完全對等.這後面存在什麼樣的隱態聯系,需要繼續參詳.
  
  從這裡看,高等的卷積運算其實不過是一種初等的運算的抽象而已.中學學過的數學裡面,其實還蘊涵著許多高深的內容(比如交換代數)。溫故而知新,斯言不謬.
  
  其實這道理一點也不復雜,人類繁衍了多少萬年了,但過去n多年,人們只知道男女媾精,乃能繁衍後代。精子,卵子的發現,生殖機制的研究,也就是最近多少年的事情。
  
  孔子說,道在人倫日用中,看來我們應該多用審視的眼光看待周圍,乃至自身,才能知其然,而知其所以然。

 

 

----------------------------------------------------------完畢------------------------------

 


然後我們就知道卷積大概的作用了。

那麼FFT本來是信號裡面的東西,而我沒學過信號。 所以看的也不怎麼懂。

大概就是對離散的信號,先將其轉變為一些正弦函數,然後這些正弦函數疊加能構成這個離散信號,但是這些正弦函數易於處理。處理完之後就可以再轉變回來。

兩個過程叫做DFT和IDFT。

 

 

 

對於本道題。意義就很明顯了。

可以把兩個大整數相乘看做是多項式乘法。

最後求出各系數後再進位即可

 


代碼如下、


[cpp]
#include <iostream>  
#include <cstdio>  
#include <algorithm>  
#include <cstring>  
#include <cmath>  
#include <map>  
#include <queue>  
#include <set>  
#include <vector>  
using namespace std; 
#define L(x) (1 << (x))  
const double PI = acos(-1.0); 
const int Maxn = 133015; 
double ax[Maxn], ay[Maxn], bx[Maxn], by[Maxn]; 
char sa[Maxn/2],sb[Maxn/2]; 
int sum[Maxn]; 
int x1[Maxn],x2[Maxn]; 
int revv(int x, int bits) 

    int ret = 0; 
    for (int i = 0; i < bits; i++) 
    { 
        ret <<= 1; 
        ret |= x & 1; 
        x >>= 1; 
    } 
    return ret; 

void fft(double * a, double * b, int n, bool rev) 

    int bits = 0; 
    while (1 << bits < n) ++bits; 
    for (int i = 0; i < n; i++) 
    { 
        int j = revv(i, bits); 
        if (i < j) 
            swap(a[i], a[j]), swap(b[i], b[j]); 
    } 
    for (int len = 2; len <= n; len <<= 1) 
    { 
        int half = len >> 1; 
        double wmx = cos(2 * PI / len), wmy = sin(2 * PI / len); 
        if (rev) wmy = -wmy; 
        for (int i = 0; i < n; i += len) 
        { 
            double wx = 1, wy = 0; 
            for (int j = 0; j < half; j++) 
            { 
                double cx = a[i + j], cy = b[i + j]; 
                double dx = a[i + j + half], dy = b[i + j + half]; 
                double ex = dx * wx - dy * wy, ey = dx * wy + dy * wx; 
                a[i + j] = cx + ex, b[i + j] = cy + ey; 
                a[i + j + half] = cx - ex, b[i + j + half] = cy - ey; 
                double wnx = wx * wmx - wy * wmy, wny = wx * wmy + wy * wmx; 
                wx = wnx, wy = wny; 
            } 
        } 
    } 
    if (rev) 
    { 
        for (int i = 0; i < n; i++) 
            a[i] /= n, b[i] /= n; 
    } 

int solve(int a[],int na,int b[],int nb,int ans[]) 

    int len = max(na, nb), ln; 
    for(ln=0; L(ln)<len; ++ln); 
    len=L(++ln); 
    for (int i = 0; i < len ; ++i) 
    { 
        if (i >= na) ax[i] = 0, ay[i] =0; 
        else ax[i] = a[i], ay[i] = 0; 
    } 
    fft(ax, ay, len, 0); 
    for (int i = 0; i < len; ++i) 
    { 
        if (i >= nb) bx[i] = 0, by[i] = 0; 
        else bx[i] = b[i], by[i] = 0; 
    } 
    fft(bx, by, len, 0); 
    for (int i = 0; i < len; ++i) 
    { 
        double cx = ax[i] * bx[i] - ay[i] * by[i]; 
        double cy = ax[i] * by[i] + ay[i] * bx[i]; 
        ax[i] = cx, ay[i] = cy; 
    } 
    fft(ax, ay, len, 1); 
    for (int i = 0; i < len; ++i) 
        ans[i] = (int)(ax[i] + 0.5); 
    return len; 

 
int main() 

    int l1,l2,l; 
    int i; 
    while(gets(sa)) 
    { 
        gets(sb); 
        memset(sum, 0, sizeof(sum)); 
        l1 = strlen(sa); 
        l2 = strlen(sb); 
        for(i = 0; i < l1; i++) 
            x1[i] = sa[l1 - i - 1]-'0'; 
        for(i = 0; i < l2; i++) 
            x2[i] = sb[l2-i-1]-'0'; 
        l = solve(x1, l1, x2, l2, sum); 
        for(i = 0; i<l || sum[i] >= 10; i++) // 進位  
        { 
            sum[i + 1] += sum[i] / 10; 
            sum[i] %= 10; 
        } 
        l = i; 
        while(sum[l] <= 0 && l>0)    l--; // 檢索最高位  
        for(i = l; i >= 0; i--)    putchar(sum[i] + '0'); // 倒序輸出  
        putchar('\n'); 
    } 
    return 0; 

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <map>
#include <queue>
#include <set>
#include <vector>
using namespace std;
#define L(x) (1 << (x))
const double PI = acos(-1.0);
const int Maxn = 133015;
double ax[Maxn], ay[Maxn], bx[Maxn], by[Maxn];
char sa[Maxn/2],sb[Maxn/2];
int sum[Maxn];
int x1[Maxn],x2[Maxn];
int revv(int x, int bits)
{
    int ret = 0;
    for (int i = 0; i < bits; i++)
    {
        ret <<= 1;
        ret |= x & 1;
        x >>= 1;
    }
    return ret;
}
void fft(double * a, double * b, int n, bool rev)
{
    int bits = 0;
    while (1 << bits < n) ++bits;
    for (int i = 0; i < n; i++)
    {
        int j = revv(i, bits);
        if (i < j)
            swap(a[i], a[j]), swap(b[i], b[j]);
    }
    for (int len = 2; len <= n; len <<= 1)
    {
        int half = len >> 1;
        double wmx = cos(2 * PI / len), wmy = sin(2 * PI / len);
        if (rev) wmy = -wmy;
        for (int i = 0; i < n; i += len)
        {
            double wx = 1, wy = 0;
            for (int j = 0; j < half; j++)
            {
                double cx = a[i + j], cy = b[i + j];
                double dx = a[i + j + half], dy = b[i + j + half];
                double ex = dx * wx - dy * wy, ey = dx * wy + dy * wx;
                a[i + j] = cx + ex, b[i + j] = cy + ey;
                a[i + j + half] = cx - ex, b[i + j + half] = cy - ey;
                double wnx = wx * wmx - wy * wmy, wny = wx * wmy + wy * wmx;
                wx = wnx, wy = wny;
            }
        }
    }
    if (rev)
    {
        for (int i = 0; i < n; i++)
            a[i] /= n, b[i] /= n;
    }
}
int solve(int a[],int na,int b[],int nb,int ans[])
{
    int len = max(na, nb), ln;
    for(ln=0; L(ln)<len; ++ln);
    len=L(++ln);
    for (int i = 0; i < len ; ++i)
    {
        if (i >= na) ax[i] = 0, ay[i] =0;
        else ax[i] = a[i], ay[i] = 0;
    }
    fft(ax, ay, len, 0);
    for (int i = 0; i < len; ++i)
    {
        if (i >= nb) bx[i] = 0, by[i] = 0;
        else bx[i] = b[i], by[i] = 0;
    }
    fft(bx, by, len, 0);
    for (int i = 0; i < len; ++i)
    {
        double cx = ax[i] * bx[i] - ay[i] * by[i];
        double cy = ax[i] * by[i] + ay[i] * bx[i];
        ax[i] = cx, ay[i] = cy;
    }
    fft(ax, ay, len, 1);
    for (int i = 0; i < len; ++i)
        ans[i] = (int)(ax[i] + 0.5);
    return len;
}

int main()
{
    int l1,l2,l;
    int i;
    while(gets(sa))
    {
        gets(sb);
        memset(sum, 0, sizeof(sum));
        l1 = strlen(sa);
        l2 = strlen(sb);
        for(i = 0; i < l1; i++)
            x1[i] = sa[l1 - i - 1]-'0';
        for(i = 0; i < l2; i++)
            x2[i] = sb[l2-i-1]-'0';
        l = solve(x1, l1, x2, l2, sum);
        for(i = 0; i<l || sum[i] >= 10; i++) // 進位
        {
            sum[i + 1] += sum[i] / 10;
            sum[i] %= 10;
        }
        l = i;
        while(sum[l] <= 0 && l>0)    l--; // 檢索最高位
        for(i = l; i >= 0; i--)    putchar(sum[i] + '0'); // 倒序輸出
        putchar('\n');
    }
    return 0;
}

 

 

 

然後模板來一發。

 

 

[cpp]
#define L(x) (1 << (x))  
const double PI = acos(-1.0); 
const int Maxn = 133015; 
double ax[Maxn], ay[Maxn], bx[Maxn], by[Maxn]; 
int revv(int x, int bits) 

    int ret = 0; 
    for (int i = 0; i < bits; i++) 
    { 
        ret <<= 1; 
        ret |= x & 1; 
        x >>= 1; 
    } 
    return ret; 

void fft(double * a, double * b, int n, bool rev) 

    int bits = 0; 
    while (1 << bits < n) ++bits; 
    for (int i = 0; i < n; i++) 
    { 
        int j = revv(i, bits); 
        if (i < j) 
            swap(a[i], a[j]), swap(b[i], b[j]); 
    } 
    for (int len = 2; len <= n; len <<= 1) 
    { 
        int half = len >> 1; 
        double wmx = cos(2 * PI / len), wmy = sin(2 * PI / len); 
        if (rev) wmy = -wmy; 
        for (int i = 0; i < n; i += len) 
        { 
            double wx = 1, wy = 0; 
            for (int j = 0; j < half; j++) 
            { 
                double cx = a[i + j], cy = b[i + j]; 
                double dx = a[i + j + half], dy = b[i + j + half]; 
                double ex = dx * wx - dy * wy, ey = dx * wy + dy * wx; 
                a[i + j] = cx + ex, b[i + j] = cy + ey; 
                a[i + j + half] = cx - ex, b[i + j + half] = cy - ey; 
                double wnx = wx * wmx - wy * wmy, wny = wx * wmy + wy * wmx; 
                wx = wnx, wy = wny; 
            } 
        } 
    } 
    if (rev) 
    { 
        for (int i = 0; i < n; i++) 
            a[i] /= n, b[i] /= n; 
    } 

int solve(int a[],int na,int b[],int nb,int ans[]) //兩個數組求卷積,有時ans數組要開成long long  

    int len = max(na, nb), ln; 
    for(ln=0; L(ln)<len; ++ln); 
    len=L(++ln); 
    for (int i = 0; i < len ; ++i) 
    { 
        if (i >= na) ax[i] = 0, ay[i] =0; 
        else ax[i] = a[i], ay[i] = 0; 
    } 
    fft(ax, ay, len, 0); 
    for (int i = 0; i < len; ++i) 
    { 
        if (i >= nb) bx[i] = 0, by[i] = 0; 
        else bx[i] = b[i], by[i] = 0; 
    } 
    fft(bx, by, len, 0); 
    for (int i = 0; i < len; ++i) 
    { 
        double cx = ax[i] * bx[i] - ay[i] * by[i]; 
        double cy = ax[i] * by[i] + ay[i] * bx[i]; 
        ax[i] = cx, ay[i] = cy; 
    } 
    fft(ax, ay, len, 1); 
    for (int i = 0; i < len; ++i) 
        ans[i] = (int)(ax[i] + 0.5); 
    return len; 

int solve(long long a[], int na, int ans[]) //自己跟自己求卷積,有時候ans數組要開成long long  

    int len = na, ln; 
    for(ln = 0; L(ln) < na; ++ln); 
    len=L(++ln); 
    for(int i = 0; i < len; ++i) 
    { 
        if (i >= na) ax[i] = 0, ay[i] = 0; 
        else ax[i] = a[i], ay[i] = 0; 
    } 
    fft(ax, ay, len, 0); 
    for(int i=0; i<len; ++i) 
    { 
        double cx = ax[i] * ax[i] - ay[i] * ay[i]; 
        double cy = 2 * ax[i] * ay[i]; 
        ax[i] = cx, ay[i] = cy; 
    } 
    fft(ax, ay, len, 1); 
 
    for(int i=0; i<len; ++i) 
        ans[i] = ax[i] + 0.5; 
    return len; 

#define L(x) (1 << (x))
const double PI = acos(-1.0);
const int Maxn = 133015;
double ax[Maxn], ay[Maxn], bx[Maxn], by[Maxn];
int revv(int x, int bits)
{
    int ret = 0;
    for (int i = 0; i < bits; i++)
    {
        ret <<= 1;
        ret |= x & 1;
        x >>= 1;
    }
    return ret;
}
void fft(double * a, double * b, int n, bool rev)
{
    int bits = 0;
    while (1 << bits < n) ++bits;
    for (int i = 0; i < n; i++)
    {
        int j = revv(i, bits);
        if (i < j)
            swap(a[i], a[j]), swap(b[i], b[j]);
    }
    for (int len = 2; len <= n; len <<= 1)
    {
        int half = len >> 1;
        double wmx = cos(2 * PI / len), wmy = sin(2 * PI / len);
        if (rev) wmy = -wmy;
        for (int i = 0; i < n; i += len)
        {
            double wx = 1, wy = 0;
            for (int j = 0; j < half; j++)
            {
                double cx = a[i + j], cy = b[i + j];
                double dx = a[i + j + half], dy = b[i + j + half];
                double ex = dx * wx - dy * wy, ey = dx * wy + dy * wx;
                a[i + j] = cx + ex, b[i + j] = cy + ey;
                a[i + j + half] = cx - ex, b[i + j + half] = cy - ey;
                double wnx = wx * wmx - wy * wmy, wny = wx * wmy + wy * wmx;
                wx = wnx, wy = wny;
            }
        }
    }
    if (rev)
    {
        for (int i = 0; i < n; i++)
            a[i] /= n, b[i] /= n;
    }
}
int solve(int a[],int na,int b[],int nb,int ans[]) //兩個數組求卷積,有時ans數組要開成long long
{
    int len = max(na, nb), ln;
    for(ln=0; L(ln)<len; ++ln);
    len=L(++ln);
    for (int i = 0; i < len ; ++i)
    {
        if (i >= na) ax[i] = 0, ay[i] =0;
        else ax[i] = a[i], ay[i] = 0;
    }
    fft(ax, ay, len, 0);
    for (int i = 0; i < len; ++i)
    {
        if (i >= nb) bx[i] = 0, by[i] = 0;
        else bx[i] = b[i], by[i] = 0;
    }
    fft(bx, by, len, 0);
    for (int i = 0; i < len; ++i)
    {
        double cx = ax[i] * bx[i] - ay[i] * by[i];
        double cy = ax[i] * by[i] + ay[i] * bx[i];
        ax[i] = cx, ay[i] = cy;
    }
    fft(ax, ay, len, 1);
    for (int i = 0; i < len; ++i)
        ans[i] = (int)(ax[i] + 0.5);
    return len;
}
int solve(long long a[], int na, int ans[]) //自己跟自己求卷積,有時候ans數組要開成long long
{
    int len = na, ln;
    for(ln = 0; L(ln) < na; ++ln);
    len=L(++ln);
    for(int i = 0; i < len; ++i)
    {
        if (i >= na) ax[i] = 0, ay[i] = 0;
        else ax[i] = a[i], ay[i] = 0;
    }
    fft(ax, ay, len, 0);
    for(int i=0; i<len; ++i)
    {
        double cx = ax[i] * ax[i] - ay[i] * ay[i];
        double cy = 2 * ax[i] * ay[i];
        ax[i] = cx, ay[i] = cy;
    }
    fft(ax, ay, len, 1);

    for(int i=0; i<len; ++i)
        ans[i] = ax[i] + 0.5;
    return len;
}

 

  1. 上一頁:
  2. 下一頁:
Copyright © 程式師世界 All Rights Reserved