程序師世界是廣大編程愛好者互助、分享、學習的平台,程序師世界有你更精彩!
首頁
編程語言
C語言|JAVA編程
Python編程
網頁編程
ASP編程|PHP編程
JSP編程
數據庫知識
MYSQL數據庫|SqlServer數據庫
Oracle數據庫|DB2數據庫
您现在的位置: 程式師世界 >> 編程語言 >  >> 更多編程語言 >> Python

【Pytorch學習筆記】7.繼承Module類構建模型時,子模塊的構建原理(基於OrderedDict)以及關於Python類的屬性賦值機制

編輯:Python

本文繼續探究學Pytorch時,涉及到的python底層的一些知識。

文章目錄

    • 問:繼承Module類構造模型時,背後是如何把自定義的子模塊組裝起來的?
    • 1 Module初始化時會構建多個OrderedDict(有序字典)存放子模塊
    • 2 Python的魔法方法__setattr__()定義了 類實例的屬性賦值 時的行為
    • 3 Module類實例在屬性賦值時會判斷屬性的類型存入對應的OrderedDict
    • 總結

問:繼承Module類構造模型時,背後是如何把自定義的子模塊組裝起來的?

我們學習模型構造時,是基於繼承nn.Module類定義模型來實現模型的構造的。
最簡單的構造方法一般就是2個:
①重寫父類的__init__構造函數,寫上自己想要的子模塊;②定義forward()正向傳播函數,將子模塊拼接起來。

比如我們構造一個多層感知機MLP模型:

import torch
from torch import nn
class MLP(nn.Module):
# __init__中聲明帶有模型參數的自定義層,這裡定義了兩個全連接層
def __init__(self, **kwargs):
# 調用父類Module的__init__進行必要的初始化。
super(MLP, self).__init__(**kwargs)
# 定義自己的子模塊。
self.hidden = nn.Linear(784, 256) # 隱藏層
self.act = nn.ReLU()
self.output = nn.Linear(256, 10) # 輸出層
# 定義模型的前向計算,即如何根據輸入X計算返回所需要的輸出。即拼接子模塊。
def forward(self, x):
a = self.act(self.hidden(x))
return self.output(a)
# 測試一下
X = torch.rand(2, 784)
net = MLP()
print(net)
net(X)

結果:

可以看到我們在__init__中定義的屬性hidden、act、output所指定的子模塊被讀取到了模型net的信息中。
這也是重寫Module的__init__構造函數時的常用方法,在Pytorch的文檔中,Module部分開頭便指出了繼承Module構建模型的寫法:

即前面提到的:1.重寫__init__構造函數,添加子模塊;2.定義forward()正向傳播。

那麼繼承Module類構造模型時,背後是如何把自定義的子模塊組裝起來的?
為什麼在__init__中 以類屬性的形式 添加子模塊就可以自動被讀取為模型的組成部分呢?

1 Module初始化時會構建多個OrderedDict(有序字典)存放子模塊

我們看Module的源碼,可以看到__init__時會 _construct 許多OrderedDict(有序字典):

有序字典顧名思義就是有先後順序的Dict。這裡創建了許多類別的屬性,有我們認識的 parameters、buffers 和 modules,我們就可以推斷 _module 屬性就存放了我們未來自定義的子模塊,如Linear()等。

我們還看到 add_module 方法可以手動添加module:

那麼我們可以推斷,初始化Module時__init__中自定義的子模塊肯定也有方法,將它們添加進生成的空OrderedDict中。
那麼是什麼方法來把這些變量初始存入這個 _modules 屬性中的呢?
這就和Python本身的類屬性的賦值機制有關了。

2 Python的魔法方法__setattr__()定義了 類實例的屬性賦值 時的行為

在Python的Object基類中,我們定義類的時候自帶了魔法方法__setattr__()。
它的作用:當對類的實例 的各個屬性進行賦值時,首先自動調用__setattr__()方法,在該方法中實現將屬性名和屬性值添加到類實例的__dict__屬性中。

一般情況下我們不用重寫__setattr__(),實例化後的類當觸發屬性賦值事件時,會自動調用該方法,並存入__dict__屬性。舉個栗子:

class MyInfo:
ai = 'hello'
def __init__(self):
print(self.__dict__)
self.name = "Chopper"
print(self.__dict__)
self.age = 32
print(self.__dict__)
self.male = True
print(self.__dict__)
print(self.ai)
myinfo = MyInfo()
print('-----------------')
myinfo.female = False
myinfo.age = 20
print(myinfo.__dict__)
# 輸出:
{
}
{
'name': 'Chopper'}
{
'name': 'Chopper', 'age': 32}
{
'name': 'Chopper', 'age': 32, 'male': True}
hello
-----------------
{
'name': 'Chopper', 'age': 20, 'male': True, 'female': False}
# 這裡 ai 沒有觸發屬性賦值機制,所以不會存在於__dict__中。

3 Module類實例在屬性賦值時會判斷屬性的類型存入對應的OrderedDict

在Module中,我們定義子模塊的過程就是給Module定義屬性並賦值的過程(如self.hidden = nn.Linear(784, 256)),因此觸發__setattr__()。
在Module中,__setattr__()被重寫,屬性的值會被先拿來判斷一次,如被判斷為 Module、Parameter 或 Buffer,便存入__dict__事先建起來的有序字典_module_parameters_buffers中。剩下未分類屬性信息存入__dict__的末尾。
具體源碼可參考Module的__setattr__()部分,這裡不展示了。

我們把開頭的代碼修改一下,我們一次次地觀察__dict__的變化,可以看到:

import torch
from torch import nn
class MLP(nn.Module):
# __init__中聲明帶有模型參數的自定義層,這裡定義了兩個全連接層
def __init__(self, **kwargs):
# 調用父類Module的__init__進行必要的初始化。
super(MLP, self).__init__(**kwargs)
# 定義自己的子模塊。
print(self.__dict__ ,'\n')
self.hidden = nn.Linear(784, 256)
print(self.__dict__ ,'\n')
self.act = nn.ReLU()
print(self.__dict__ ,'\n')
self.output = nn.Linear(256, 10)
print(self.__dict__ ,'\n')
self.origin = 1.0
print(self.__dict__ ,'\n')
# 定義模型的前向計算,即如何根據輸入X計算返回所需要的輸出。即拼接子模塊。
def forward(self, x):
a = self.act(self.hidden(x))
return self.output(a)
# 測試一下
net = MLP()
print(net)
print(net._modules['hidden']._parameters)

結果:

可以看到net中的各個OrderedDict類別屬性和其他的屬性,子module模塊 Linear()、ReLU()等 被分類到了_module屬性中。
而打印net會顯示各個module的信息(由另一個魔法方法__str__定義)。

同時,module被存為了OrderedDict後,還可以方便通過標簽訪問該有序字典的子module中的信息(比如上面代碼最後一行)。
每個Module類都有這一套的OrderedDict,關系上層層套疊,也方便維護和管理。
不僅是Module本身,Module的Parameter、Buffer等也都可以在這個基於OrderedDict嵌套的樹形結構中維護,也體現了Module這個框架的精密性。

總結

  1. 我們基於繼承nn.Module構建模型時,一般實現以下2個步驟:
    ①重寫父類的__init__構造函數,寫上自己想要的子模塊;②定義forward()正向傳播函數,將子模塊拼接起來。
  2. Module初始化時會構建多個OrderedDict(有序字典)存放不同類別的模塊,如:子module、parameters、buffers等。
  3. Module重寫了魔法方法__setattr__(),類實例的屬性賦值後先判斷屬性類型,將 子module、parameter、buffer等存入__dict__(屬性信息)對應的OrderedDict中。OrderedDict類似樹形結構,方便維護和管理。

  1. 上一篇文章:
  2. 下一篇文章:
Copyright © 程式師世界 All Rights Reserved