程序師世界是廣大編程愛好者互助、分享、學習的平台,程序師世界有你更精彩!
首頁
編程語言
C語言|JAVA編程
Python編程
網頁編程
ASP編程|PHP編程
JSP編程
數據庫知識
MYSQL數據庫|SqlServer數據庫
Oracle數據庫|DB2數據庫
您现在的位置: 程式師世界 >> 編程語言 >  >> 更多編程語言 >> Python

python項目調優合集(長期更新):梯度爆炸、消失

編輯:Python

文章目錄

    • 1. loss基本保持不變,acc又很低

1. loss基本保持不變,acc又很低

1.1 檢驗模型參數的更新幅度

 optimizer.zero_grad()
model_output, pooler_output = model(input_data)
Before = list(model.parameters())[0].clone() # 獲取更新前模型的第0層權重
loss = criterion(model_output, label)
loss.backward()
# nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2) # 梯度截斷
optimizer.step()
# 檢驗模型的學習情況
After = list(model.parameters())[0].clone() # 獲取更新後模型的第0層權重
predicted_label = torch.argmax(model_output, -1)
acc = accuracy_score(label.float().cpu(), predicted_label.view(-1).float().cpu())
print(loss,acc) # 打印mini-batch的損失值以及准確率
print('模型的第0層更新幅度:',torch.sum(After-Before))

如果:模型更新幅度非常小,其絕對值<0.01, 很可能是梯度消失了; 如果絕對值>1000,很可能是梯度爆炸;

具體阈值需要自行去調節,只是提供了一種思路

1.2 解決
(1)梯度爆炸
梯度爆炸常見原因:使用了深層網絡、參數初始化過大,解決方案:
1)更換優化器
2)學習率調低
3)梯度截斷
4)使用正則化
(2)梯度消失
梯度消失很有可能是:深層網絡、使用了sigmoid激活函數,解決方案:
1)使用Batch Norm 批標准化
BN將網絡中每一層的輸出標准化為正態分布,並使用縮放和平移參數對標准化之後的數據分布進行調整,可以將集中在梯度飽和區的原始輸出拉向線性變化區,增大梯度值,緩解梯度消失問題,並加快網絡的學習速度。

2)選用Relu()激活函數
3)使用殘差網絡ResNet
使用ResNet可以輕松搭建幾百層、上千層的網絡,而不用擔心梯度消失問題.


  1. 上一篇文章:
  2. 下一篇文章:
Copyright © 程式師世界 All Rights Reserved