矩陣最近鄰填充是指對矩陣中指定元素取值用周圍最近鄰的元素取值進行替換。下面介紹三種實現方法。前兩種方法適合較小的輸入矩陣,第三種方法速度最快。
采用for循環的方式,逐個計算待替換元素位置與剩余非替換元素位置的距離,選擇出其中距離最小位置的元素為替換目標。該方法用於采用for循環的形式,所以顯存占用較少。mei函數定義如下所示:
def nearest_fill(x, a, b):
"""
Parameters
----------
x : array,輸入矩陣
a : float,待替換元素取值下限
b : float,待替換元素取值上限
Returns
-------
x : array,替換後矩陣
"""
x = x.copy()
index_r = np.array(np.where((x>=a) & (x <=b))).T
index_t = np.array(np.where((x<a) | (x >b))).T
for ir in index_r:
d = (ir - index_t) ** 2
d = d[:, 0] + d[:, 1]
p = index_t[np.argmin(d)]
x[ir[0], ir[1]] = x[p[0], p[1]]
return x
采用矩陣的方式,逐個計算待替換元素位置與剩余非替換元素位置的距離,選擇出其中距離最小位置的元素為替換目標。這種方式會存儲全部距離矩陣,占用的內存或者顯存較大。可以通過設置batch的大小來降低內存或者顯存占用。函數定義如下所示:
x : array,輸入矩陣
a : float,待替換元素取值下限
b : float,待替換元素取值上限
batch:設置距離矩陣一次操作最多填充的數量,設置較小的值則內存或顯存較小
def nearest_fill_batch(x, a, b, batch=4):
"""
Parameters
----------
x : array,輸入矩陣
a : float,待替換元素取值下限
b : float,待替換元素取值上限
batch:設置距離矩陣一次操作最多填充的數量,設置較小的值則內存或顯存較小
Returns
-------
x : array,替換後矩陣
"""
x = x.copy()
index_r = np.array(np.where((x>=a) & (x <=b))).T
index_t = np.array(np.where((x<a) | (x >b))).T
index_r1 = np.expand_dims(index_r, 1)
N = index_r.shape[0] // batch
for i in range(N):
index_r2 = index_r1[i*batch:(i+1)*batch]
distances = index_r2 - index_t
distances = distances ** 2
distances = np.sum(distances, -1)
pos = np.argmin(distances, 1)
x[index_r2[:, 0, 0], index_r2[:, 0, 1]] = x[index_t[pos][:, 0], index_t[pos][:, 1]]
if index_r.shape[0] % batch > 0:
index_r2 = index_r1[(index_r.shape[0]- index_r.shape[0] % batch):]
distances = index_r2 - index_t
distances = distances ** 2
distances = np.sum(distances, -1)
pos = np.argmin(distances, 1)
x[index_r2[:, 0, 0], index_r2[:, 0, 1]] = x[index_t[pos][:, 0], index_t[pos][:, 1]]
return x
前兩種方法都會計算每個待替換元素與剩余所有非替換元素的位置距離,計算量大,且內存或顯存占用大。這種方法適合維度較小的矩陣元素替換。對於較大維度的矩陣,以上兩種方法的計算時間也會顯著增加。第三種方法,僅僅計算其中位置距離較小的若干個元素的距離矩陣,運行速度大大提升。函數定義如下所示:
x : array,輸入矩陣
a : float,待替換元素取值下限
b : float,待替換元素取值上限
def nearest_fill_fast(x, a, b):
"""
Parameters
----------
x : array,輸入矩陣
a : float,待替換元素取值下限
b : float,待替換元素取值上限
Returns
-------
y : array,替換後矩陣
"""
y = x.copy()
h, w = x.shape[:2]
index_r = np.array(np.where((x>=a) & (x <=b))).T
index_t = np.array(np.where((x<a) | (x >b))).T
for ir in index_r:
i, j = ir
for k in range(3, max(w, h)):
kernel = [k, k]
x1 = max(j-kernel[0]//2, 0)
x2 = min(x1 + kernel[0], w)
y1 = max(i-kernel[1]//2, 0)
y2 = min(y1 + kernel[1], h)
x_tmp = x[y1:y2, x1:x2]
index_t = np.array(np.where((x_tmp<a) | (x_tmp >b))).T
if len(index_t) < 1:
continue
index_t += [y1, x1]
d = (ir - index_t) ** 2
d = d[:, 0] + d[:, 1]
p = index_t[np.argmin(d)]
y[ir[0], ir[1]] = x[p[0], p[1]]
break
return y
耗時測試結果:
更多三維、二維感知算法和金融量化分析算法請關注“樂樂感知學堂”微信公眾號,並將持續進行更新。