★ 引子
上一篇文章講了 Comba 乘法的原理,這次來講講如何實現。為了方便移植和充分發揮不同平台下的性能,暫時用了三種不同的實現方式:
1、單雙精度變量都有的情況。
2、只有單精度變量的情況。
3、可以使用內聯匯編的情況。
前面已經介紹了 Comba 乘法的原理和實現思路,為了方便,再把它貼到這裡:
計算 c = a * b,c0,c1,c2 為單精度變量。
1. 增加 c 到所需要的精度,並且令 c = 0,c->used = a->used + b->used。
2. 令 c2 = c1 = c0 = 0。
3. 對於 i 從 0 到 c->used - 1 循環,執行如下操作:
3.1 ty = BN_MIN(i, b->used - 1)
3.2 tx = i - ty
3.3 k = BN_MIN(a->used - tx, ty + 1)
3.4 三精度變量右移一個數位:(c2 || c1 || c0) = (0 || c2 || c1)
3.5 對於 j 從 0 到 k - 1 之間執行循環,計算:
(c2 || c1 || c0) = (c2 || c1 || c0) + a(tx + j) * b(ty - j)
3.6 c(i) = c0
4. 壓縮多余位,返回 c。
上面所說的三種不同實現方式,主要體現在 3.5 中的循環。 Comba 乘法的實現代碼如下:(暫時不考慮 x = x * y 這種輸入和輸出是同一個變量的情況)
static void bn_mul_comba(bignum *z, const bignum *x, const bignum *y) { bn_digit c0, c1, c2; bn_digit *px, *py, *pz; size_t nc, i, j, k, tx, ty; pz = z->dp; nc = z->used; c0 = c1 = c2 = 0; for(i = 0; i < nc; i++) { ty = BN_MIN(i, y->used - 1); tx = i - ty; k = BN_MIN(x->used - tx, ty + 1); px = x->dp + tx; py = y->dp + ty; c0 = c1; c1 = c2; c2 = 0U; //Comba 32 for(j = k; j >= 32; j -= 32) { COMBA_INIT COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_STOP } //Comba 16 for(; j >= 16; j -= 16) { COMBA_INIT COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_STOP } //Comba 8 for(; j >= 8; j -= 8) { COMBA_INIT COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_STOP } //Comba 4 for(; j >= 4; j -= 4) { COMBA_INIT COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_MULADDC COMBA_STOP } //Comba 1 for(; j > 0; j--) { COMBA_INIT COMBA_MULADDC COMBA_STOP } *pz++ = c0; } //for i bn_clamp(z); }
在 3.5 步中的內循環,乘法的計算按照 1,4,8,16,32 的步進展開,展開的過程需要比較大的空間,但是減少了循環控制的開銷,所以可以節省大量時間。執行乘法的關鍵代碼都定義在 COMBA_INIT,COMBA_MULADDC 和 COMBA_STOP 這三個宏內。COMBA_INIT 用於變量定義或者初始化,COMBA_MULADDC 用於計算單精度乘法和累加,COMBA_STOP 用於存儲當前計算結果。
★ 單雙精度變量都有的情況
這種情況下,因為 bn_digit 和 bn_udbl 同時有定義,所以是最容易實現的一種。三個宏的定義如下:
#define COMBA_INIT \ { \ bn_udbl r; \ #define COMBA_MULADDC \ \ r = (bn_udbl)(*px++) * (*py--) + c0; \ c0 = (bn_digit)r; \ r = (bn_udbl)c1 + (r >> biL); \ c1 = (bn_digit)r; \ c2 += (bn_digit)(r >> biL); \ #define COMBA_STOP \ }
在乘法器中,定義雙精度變量 r 存儲單精度乘法結果。在計算中,先計算單精度乘法,然後在與三精度變量 c2||c1||c0 相加。因為 c0,c1,c2 都是單精度數,為了避免溢出,先把累加計算的結果放到 r 上,然後提取低半部分作為當前數位的值。這裡暫時不需要存儲計算結果,所以 COMBA_STOP 沒有什麼操作。
★ 只有單精度變量的情況
這種情況下,bn_udbl 沒有定義,每個數被拆成高半部分和低半部分,要執行 4 次單精度乘法,所以計算會比較復雜。
#define COMBA_INIT \ { \ bn_digit a0, a1, b0, b1; \ bn_digit t0, t1, r0, r1; \ #define COMBA_MULADDC \ \ a0 = (*px << biLH) >> biLH; \ b0 = (*py << biLH) >> biLH; \ a1 = *px++ >> biLH; \ b1 = *py-- >> biLH; \ r0 = a0 * b0; \ r1 = a1 * b1; \ t0 = a1 * b0; \ t1 = a0 * b1; \ r1 += (t0 >> biLH); \ r1 += (t1 >> biLH); \ t0 <<= biLH; \ t1 <<= biLH; \ r0 += t0; \ r1 += (r0 < t0); \ r0 += t1; \ r1 += (r0 < t1); \ c0 += r0; \ c1 += (c0 < r0); \ c1 += r1; \ c2 += (c1 < r1); \ #define COMBA_STOP \ }
a1,a0 分別存儲 x 的高半部分和低半部分,b1,b0 分別存儲 y 的高半部分和低半部分,這兩個操作可以通過將 x 和 y 向左向右移動半個數位得到。biLH 前面提到過,這裡在說一下:biLH 是每個數位的比特大小的一半。上面這個乘法計算比較復雜,所以我先講講原理。
設 f(x),g(x),h(x) 分別表示單精度數 a,b 和 c,要計算 c = a * b。
a 和 b 用 f(x) 和 g(x) 表示:f(x) = a1 * x + a0,g(x) = b1 * x + b0,其中這裡的 x 就是 2 ^ (n / 2)。
那麼 h(x) = f(x) * g(x) = a1 * b1 * (x ^ 2) + (a1 * b0 + a0 * b1) * x + a0 * b0。這裡默認 a0,a1,b0,b1 都是半精度的變量,所以他們的乘積可以用一個單精度變量存儲。
令 t0 = a1 * b0 = p * x + q,t1 = a0 * b1 = r * x + s,r0 = a0 * b0,r1 = a1 * b1,其中 p,q,r,s 是半精度數,那麼 h(x) 可以改寫成:
h(x) = a1 * b1 * (x ^ 2) + (p * x + q + r * x + s) * x + a0 * b0
= a1 * b1 * (x ^ 2) + ((p + r) * x + (q + s)) * x + a0 * b0
= a1 * b1 * (x ^ 2) + (p + r) * (x ^ 2) + (q + s) * x + a0 * b0
= (a1 * b1 + p + r) * (x ^ 2) + (q + s) * x + a0 * b0
= (r1 + p + r) * (x ^ 2) + (q + s) * x + r0
所以,要計算 (r1 + p + r) * (x ^ 2),需要提取 t0,t1 的高半部分和 r1 相加(因為 r1 和 p,r 都是同類項,所以直接計算 x^2 的系數即可)。計算剩余項 (q + s) * x 和 r0 的和,需要將 t0 和 t1 的低半部分提取出來,左移半個數位後和 r0 相加(r0 是單精度數,而 q 和 s 是半個精度的,乘以 x 要左移半個數位)。計算 (q + s) * x + r0 後結果可能會溢出,表明產生了進位,所以還要把進位要傳遞到 (r1 + p + r) * (x ^ 2),最終乘積 c 就可以用 r0,r1 組成的雙精度變量 r1||r0 表示。最後將乘積 c 和三精度變量 c2||c1||c0 累加即可。
★ 使用內聯匯編的情況
注:本小節涉及到 x86 的匯編,如果對匯編不了解的話,可以跳過,免得浪費心情。對於 C 的內聯匯編,細節暫時不說了,參考下面兩篇文章:
【GCC 和 VC 的內聯匯編】 https://github.com/1184893257/simplelinux/blob/master/inlineasm.md
【Linux 中 x86 的內聯匯編】 http://www.ibm.com/developerworks/cn/linux/sdk/assemble/inline/
如果編譯環境可以使用內聯匯編,則使用匯編指令執行乘法會加快計算速度。由於我只接觸過 x86 平台的匯編,所以暫時只考慮 x86 平台,ARM,PowerPC 或者 MIPS 之類的架構暫時不管。 在 x86 環境下,需要考慮到 GCC 和 VC 環境下的內聯匯編,畢竟他們的語法差別很大。
A. VC 環境下的內聯匯編:(和 GCC 的比起來要簡單得多)
#define COMBA_INIT \ { \ __asm mov esi, px \ __asm mov edi, py \ #define COMBA_MULADDC \ \ __asm lodsd \ __asm mov ebx, [edi] \ __asm mul ebx \ __asm add c0, eax \ __asm adc c1, edx \ __asm adc c2, 0 \ __asm sub edi, 4 \ #define COMBA_STOP \ \ __asm mov px, esi \ __asm mov py, edi \ }
1. 首先把指針 px 和 py 的地址用 MOV 指令分別送到源變址寄存器 ESI 和目的變址寄存器 EDI。
2. 執行單精度乘法和累加:
A. 執行 LODSD 指令,把 ESI 寄存器指向的數據段中某單元的內容(也就是 x 的某個數位的值)送到 EAX 寄存器中,並根據方向標志和數據類型修改 ESI 寄存器的內容,具體的操作是將 ESI 中 的地址增加 4,相當於 C 中的 px++。一般情況下,方向標志 DF = 0, ESI 中的地址增加。如果 DF = 1,要執行 CLD 指令把 DF 置為 0,不過一般情況下不會出現這種情況的。
B. 用 MOV 指令把 EDI 寄存器所指向的數據段中某單元的內容(也就是 y 的某個數位的值)送到 EBX 寄存器中。
C. 執行 MUL 指令執行單精度乘法:(EDX,EAX)= EAX * EBX,計算 EAX 和 EBX 的乘積,結果的高位放在 EDX 寄存器中,結果的低位放在 EAX 寄存器中。
D. 執行 ADD 加法指令,將乘積的低位累加到三精度變量的最低位 c0 上。
E. 執行 ADC 帶進位的加法指令,將乘積的高位和上一步加法產生的進位累加到三精度變量的第二個數位 c1 上。
F. 執行 ADC 帶進位的加法指令,將剩余的進位加到三精度變量的最高位 c2 上。
G. 執行 SUB 減法指令,將 EDI 中的地址往前挪 4 字節,相當於 C 中的 py--。
3. 存儲計算結果:用 MOV 指令把 ESI 和 EDI 中的地址送回到指針 px 和 py 中,之所以要做這一步,是因為在循環中,循環控制可能會修改 CPU 寄存器的值,如果在計算結束後不存儲 ESI 和 EDI 的地址值,那麼地址可能就會因為寄存器被修改而丟失。
B. GCC 環境下的內聯匯編:(比較復雜)
#define COMBA_INIT \ { \ asm \ ( \ "movl %5, %%esi \n\t" \ "movl %6, %%edi \n\t" \ #define COMBA_MULADDC \ \ "lodsl \n\t" \ "movl (%%edi), %%ebx \n\t" \ "mull %%ebx \n\t" \ "addl %%eax, %2 \n\t" \ "adcl %%edx, %3 \n\t" \ "adcl $0, %4 \n\t" \ "subl $4, %%edi \n\t" \ #define COMBA_STOP \ \ "movl %%esi, %0 \n\t" \ "movl %%edi, %1 \n\t" \ :"=m"(px),"=m"(py),"=m"(c0),"=m"(c1),"=m"(c2) \ :"m"(px),"m"(py) \ :"%eax","%ebx","%ecx","%edx","%esi","%edi" \ ); \ }
這段 GCC 的內聯匯編指令和上面的 VC 內聯匯編指令操作是一樣的,只是語法不同而已,具體的語法請自行參考上面的資料或者 Google 搜索 :)
★ 總結
Comba 乘法大概就講這麼多吧。由於 Comba 乘法的時間復雜度仍然是 O(n^2),所以當輸入的規模 n 越大時,所需的時間仍然會急劇增加。下一篇文章將講講如何使用分治的方式降低乘法的時間復雜度。
【回到本系列目錄】
版權聲明
原創博文,轉載必須包含本聲明,保持本文完整,並以超鏈接形式注明作者Starrybird和本文原始地址:http://www.cnblogs.com/starrybird/p/4441022.html