1.1 檢驗模型參數的更新幅度
optimizer.zero_grad()
model_output, pooler_output = model(input_data)
Before = list(model.parameters())[0].clone() # 獲取更新前模型的第0層權重
loss = criterion(model_output, label)
loss.backward()
# nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2) # 梯度截斷
optimizer.step()
# 檢驗模型的學習情況
After = list(model.parameters())[0].clone() # 獲取更新後模型的第0層權重
predicted_label = torch.argmax(model_output, -1)
acc = accuracy_score(label.float().cpu(), predicted_label.view(-1).float().cpu())
print(loss,acc) # 打印mini-batch的損失值以及准確率
print('模型的第0層更新幅度:',torch.sum(After-Before))
如果:模型更新幅度非常小,其絕對值<0.01, 很可能是梯度消失了; 如果絕對值>1000,很可能是梯度爆炸;
具體阈值需要自行去調節,只是提供了一種思路
1.2 解決
(1)梯度爆炸
梯度爆炸常見原因:使用了深層網絡、參數初始化過大,解決方案:
1)更換優化器
2)學習率調低
3)梯度截斷
4)使用正則化
(2)梯度消失
梯度消失很有可能是:深層網絡、使用了sigmoid激活函數,解決方案:
1)使用Batch Norm 批標准化
BN將網絡中每一層的輸出標准化為正態分布,並使用縮放和平移參數對標准化之後的數據分布進行調整,可以將集中在梯度飽和區的原始輸出拉向線性變化區,增大梯度值,緩解梯度消失問題,並加快網絡的學習速度。
2)選用Relu()激活函數
3)使用殘差網絡ResNet
使用ResNet可以輕松搭建幾百層、上千層的網絡,而不用擔心梯度消失問題.