程序師世界是廣大編程愛好者互助、分享、學習的平台,程序師世界有你更精彩!
首頁
編程語言
C語言|JAVA編程
Python編程
網頁編程
ASP編程|PHP編程
JSP編程
數據庫知識
MYSQL數據庫|SqlServer數據庫
Oracle數據庫|DB2數據庫
您现在的位置: 程式師世界 >> 編程語言 >  >> 更多編程語言 >> Python

什麼?Python太慢?試試Numba庫吧!

編輯:Python

什麼?Python太慢?試試Numba庫吧!

  • 官方文檔
  • Python編譯過程和執行原理
  • Numba簡介
  • Numba在何時是有效的
  • @jit裝飾器
    • signature參數(數據類型控制)
    • nopython、forceobj參數(編譯模式選擇)
    • nogil參數(全局進程鎖限制)
    • cache參數(保存為文件緩存)
    • parallel參數(並行化參數)
    • error_model參數
    • fastmath參數
    • locals參數
    • boundscheck參數
  • @generated_jit()裝飾器
  • @vectorize()裝飾器
  • @jitclass()裝飾器
  • @cfunc()裝飾器
  • 編寫規范
  • 性能對比的例子:SVD算法

官方文檔

官方文檔入口
有需要的小伙伴請點入享用

Python編譯過程和執行原理

本節參考了傳送門1和傳送門2
C/C++之類的編譯性語言編寫的程序,將源文件轉換成計算機使用的機器語言,經過鏈接器鏈接之後形成了二進制的可執行文件。運行該程序的時候,就可以把二進制程序從硬盤載入到內存中並運行。而在Python作為解釋型語言,沒有編譯這一步,而是由解釋器將源代碼轉換為字節碼,然後再由解釋器來執行這些字節碼,因此Python中不用擔心程序的編譯、庫的鏈接加載等問題。

Python解釋器(如CPython、IPython、PyPy、Jython、IronPython)執行Python代碼的四個過程

  1. 詞法分析
    檢查關鍵字等是否正確
  2. 語法分析
    語句格式是否正確
  3. 生成字節碼
    生成.pyc文件(PyCodeObject對象)。在編譯代碼的過程中,首先會將代碼中的函數、類等對象分類處理,然後生成字節碼文件。
  4. 執行
    Python解釋器對字節碼進行解釋,將每一行的字節碼解釋成CPU可以直接識別的機器碼,執行

常見的cpython解釋器是用c語言的方式來解釋字節碼的,
而numba則是使用LLVM編譯技術來解釋字節碼的。

我們之前寫的CPython本質上還是通過C的編譯器來替換掉CPython底層的復雜代碼從而實現了加速(比如Python的動態類型,涉及到一大堆的類型檢查、多態、溢出檢查等耗時非常多,但是如果使用CPython的靜態類型就沒有這麼多麻煩的問題),

而numba的思路則不太一樣,numba是在一個叫做LLVM的編譯器上進行編譯,Numba將Python字節碼轉換為LLVM中間表示(IR),請注意,LLVM IR是一種低級編程語言,與匯編語法類似,與Python無關。

Numba簡介

Numba是Python的一個即時編譯器,它最適合使用NumPy數組、函數和循環的代碼。使用Numba最常見的方式是通過它的裝飾器集合,這些裝飾器可以應用到您的函數中,以指示Numba編譯它們。當調用一個Numba修飾函數時,它被編譯成機器代碼以便及時執行,並且您的全部或部分代碼隨後可以以本機機器代碼的速度運行。

Numba在何時是有效的

這取決於您的代碼是什麼樣子的,如果您的代碼是面向數字的(做了很多數學工作),經常使用NumPy和/或有很多循環,那麼Numba通常是一個不錯的選擇。在這些示例中,我們將應用最基本的Numba的JIT裝飾器@jit來嘗試加速一些函數,以演示哪些工作正常,哪些工作不正常。

對於如下代碼,Numba很有有效

from numba import jit
import numpy as np
x = np.arange(100).reshape(10, 10)
@jit(nopython=True) # Set "nopython" mode for best performance, equivalent to @njit
def go_fast(a): # Function is compiled to machine code when called the first time
trace = 0
for i in range(a.shape[0]): # Numba likes loops
trace += np.tanh(a[i, i]) # Numba likes NumPy functions
return a + trace # Numba likes NumPy broadcasting
print(go_fast(x))

而對於如下代碼,Numba作用不大

from numba import jit
import pandas as pd
x = {
'a': [1, 2, 3], 'b': [20, 30, 40]}
@jit
def use_pandas(a): # Function will not benefit from Numba jit
df = pd.DataFrame.from_dict(a) # Numba doesn't know about pd.DataFrame
df += 1 # Numba doesn't understand what this is
return df.cov() # or this!
print(use_pandas(x))

@jit裝飾器

@numba.jit(signature=None, nopython=False, nogil=False, cache=False, forceobj=False, parallel=False, error_model='python', fastmath=False, locals={
}, boundscheck=False)

jit裝飾器中所有參數都是可選的,如果僅使用@jit,則由系統自動確定如何進行優化

signature參數(數據類型控制)

首先,我們看一個簡單的例子

from numba import jit
@jit
def f(x, y):
# A somewhat trivial example
return x + y

此時,使用f(10,3)和f(1j,2)都可以運行。但是,倘若需要對輸入參數和輸出參數的數據類型進行控制呢?

from numba import jit, int32, float32, double
@jit(double(float32, int32))
def f(x, y):
# A somewhat trivial example
return x + y

其中,double(float32, int32)就是函數簽名,double控制輸出參數的數據類型,float32和int32分別控制x和y的數據類型。輸出參數的簽名可以缺省,有庫自動判斷。
顯然,此時再使用f(1j,2)便會報錯,實現了數據類型的控制。

常用的數據簽名:

  • void表示返回值為空
  • intp和uintp表示pointer-sized整數(分別表示有符號和無符號)
  • intel和uintc相當於C中的int和unsigned int
  • int8 uint8, int16, uint16, int32, uint32, int64, uint64是相應的固定寬度的整數位寬度
  • float32 float64單、雙精度浮點數
  • complex64和complex128是單精度和雙精度的復數
  • 數組類型,如float32[:]和int8[:,:]

nopython、forceobj參數(編譯模式選擇)

兩個參數都是布爾類型,nopython為True表示編譯時使用nopython模式,而forceobj為True表示使用object模式。

nopython模式:生成不訪問Python C API的代碼的Numba編譯模式。這種編譯模式生成了性能最高的代碼,但要求能夠推斷出函數中所有值的本機類型。除非另有指示,否則如果不能使用nopython模式,@jit裝飾器將自動退回到對象模式。

object模式:一種Numba編譯模式,它生成將所有值作為Python對象處理的代碼,並使用Python C API對這些對象執行所有操作。在對象模式下編譯的代碼通常不會比Python解釋的代碼運行得快,除非Numba編譯器能夠利用循環j。

一般情況下,建議使用nopython模式,畢竟我們使用Numba的目的就是提高運行速度,但在編碼規范上有相應限制。

nogil參數(全局進程鎖限制)

若nogil為True表示釋放全局進程鎖,從而可以有效利用多核系統,但只能在nopython模式下使用。另外,使用時需要注意多線程編程的常見陷阱(一致性、同步、競爭條件等)。

cache參數(保存為文件緩存)

若cache為True,則緩存啟用基於文件的緩存,以便在前一次調用中已編譯函數時縮短編譯時間。

parallel參數(並行化參數)

若parallel為True,那麼可以自動地並行化許多常見的Numpy構造,並融合相鄰的並行操作,從而最大化緩存的局部性。

error_model參數

‘python’or’numpy’,決定以哪個庫為准拋出異常

fastmath參數

fastmath支持使用LLVM文檔中描述的不安全的浮點轉換。此外,如果Intel SVML安裝得更快,但是使用了一些數學內部特性的不太精確的版本。

locals參數

指定特定局部變量的類型

boundscheck參數

是否進行數組邊界的索引檢查,建議不做即設置為True,避免影響速度

@generated_jit()裝飾器

@numba.generated_jit(nopython=False, nogil=False, cache=False, forceobj=False, locals={
})

generated_jit()裝飾器可以根據傳入參數的類型決定函數不同的實現方式,同時能夠保證jit()裝飾器的速度

# 返回給定值是否是缺失類型
import numpy as np
from numba import generated_jit, types
@generated_jit(nopython=True)
def is_missing(x):
""" Return True if the value is missing, False otherwise. """
if isinstance(x, types.Float):
return lambda x: np.isnan(x)
elif isinstance(x, (types.NPDatetime, types.NPTimedelta)):
# The corresponding Not-a-Time value
missing = x('NaT')
return lambda x: x == missing
else:
return lambda x: False

@vectorize()裝飾器

@numba.vectorize(*, signatures=[], identity=None, nopython=True, target='cpu', forceobj=False, cache=False, locals={
})

編譯修飾函數,並將其包裝為Numpy ufunc或Numba DUFunc

@jitclass()裝飾器

jitclass()對將類中的函數都使用nopython模式進行編譯

import numpy as np
from numba import jitclass # import the decorator
from numba import int32, float32 # import the types
spec = [
('value', int32), # a simple scalar field
('array', float32[:]), # an array field
]
@jitclass(spec)
class Bag(object):
def __init__(self, value):
self.value = value
self.array = np.zeros(value, dtype=np.float32)
@property
def size(self):
return self.array.size
def increment(self, val):
for i in range(self.size):
self.array[i] = val
return self.array

@cfunc()裝飾器

cfunc()創建一個可以使用外部的C語言代碼進行調用的編譯後的程序,從而可以與用C或C++編寫的庫進行交互。考慮到很多Python庫的底層是C或C++,該功能十分有用。
例如,scipy.integrate.quad函數即可以接受普通的Python回調,也可以接受包裝在ctypes回調對象中的C回調。
使用普通的Python回調

import numpy as np
import scipy.integrate as si
def integrand(t):
return np.exp(-t) / t ** 2
def do_integrate(func):
""" Integrate the given function from 1.0 to +inf. """
return si.quad(func, 1, np.inf)
do_integrate(integrand)

使用cfunc()裝飾器

import numpy as np
import scipy.integrate as si
from numba import cfunc
def integrand(t):
return np.exp(-t) / t ** 2
def do_integrate(func):
""" Integrate the given function from 1.0 to +inf. """
return si.quad(func, 1, np.inf)
nb_integrand = cfunc("float64(float64)")(integrand)
do_integrate(nb_integrand.ctypes)

編寫規范

一部分報錯源於對數據類型不匹配,根據報錯調整即可;另一部分報錯可能源於numba不主持某些函數,具體可以查閱文檔,支持的python特性的鏈接,支持的numpy特性的鏈接

性能對比的例子:SVD算法

以推薦系統中的SVD算法為例,展示numba庫對速度的提升。

import numpy as np
import time
import pandas as pd
from numba import jit, prange
@jit(nopython=True, cache=True, nogil=True, parallel=True)
def svd(users, items, iterations, lr, reg, factors, avg, data):
# initialization
bu = np.random.normal(loc=0, scale=0.1, size=(users, 1))
bi = np.random.normal(loc=0, scale=0.1, size=(items, 1))
p = np.random.normal(loc=0, scale=0.1, size=(users, factors))
q = np.random.normal(loc=0, scale=0.1, size=(items, factors))
# iteration
for iteration in prange(iterations):
# error use: for u, i, r in trainset:
for line in prange(data.shape[0]):
u, i, r = data[line]
rp = avg + bu[u] + bi[i] + np.dot(q[i], p[u])
e_ui = r - rp
bu[u] += lr * (e_ui - reg * bu[u])
bi[i] += lr * (e_ui - reg * bi[i])
p[u] += lr * (e_ui * q[i] - reg * p[u])
q[i] += lr * (e_ui * p[u] - reg * q[i])
nUsers = 100 # number of users
nItems = 100 # number of items
iteration = 30 # number of iterations
lr = 0.01 # learning rate
reg = 0.002 # regularization rate
factor = 5 # number of factors
trainset = pd.read_csv("D:/py3/trainset.txt", sep=' ', header=None).values
aver = np.mean(trainset[:, 2]) # average rating
start = time.clock()
svd(nUsers, nItems, iteration, lr, reg, factor, aver, trainset)
end = time.clock()
print("training time: %s seconds" % (end - start))

倘若去掉@jit(nopython=True, cache=True, nogil=True),即不使用numba加速的結果為

training time: 17.564660734381764 seconds

倘若使用numba加速,第一次運行時因為需要編譯

training time: 9.588356848133849 seconds

之後,再次運行,用時就可以穩定在

training time: 0.18296860820512134 seconds

numba在速度上提升了96倍,一般來說numba可以提升一到兩個數量級。


  1. 上一篇文章:
  2. 下一篇文章:
Copyright © 程式師世界 All Rights Reserved