model.load_state_dict(torch.load('/home/yangjy/projects/Jane_git_tf/weights/con_model/best1_2022-12-02-09-36.pth', map_location=device))
Question?情形:
新的model是需要兩個模型作前期的處理后的結(jié)果,如model1得到feature1,model2得到feature2,最終(現(xiàn)在訓練的)model(model3)需要學習的是根據(jù)feature1和feature2進行整合和特征學習正確分辨出最終的結(jié)果。這個時候model3在第一次訓練做初始化的時候需要加載model1和model2的權重,但是后來訓練的時候如果初始權重是之前訓練好的model3的權重,就不要再加載model1和model2的權重后再加載model3的權重,機器在加載的過程中都是需要消耗時間的,一方面是資源成本的浪費,無論是時間成本還是內(nèi)存占用率都是很大的消耗;其次剛剛我發(fā)現(xiàn),這樣重復性加載時影響最終的模型訓練效果的,模型在加載權重的過程個人建議不要寫在模型初始化的過程中,這種不靈活的寫法,很可能會產(chǎn)生bias!!
class model3(nn.Module):
def __init__(self, num_classes, device,model1_path, model2_path,freeeze_pretain):
super(Conv_con, self).__init__()
self.device = device
self.model1= MiniConvNext(num_classes=5, depths=[3, 3, 9, 3],
dims=[96, 192, 384, 768], )
self.model2 = MiniConvNext(num_classes=1, depths=[3, 3, 9, 3],
dims=[96, 192, 384, 768], )
self.freeeze_pretain = freeeze_pretain
self._init_weights()
self.fctl = _FCtL(512, 512)
self.norm = LayerNorm(512, eps=1e-6, data_format="channels_last")
self.head = nn.Linear(512, num_classes)
self.model1_path= model1_path
self.model2_path= model2_path
self.set_pretrained_weight()
def set_pretrained_weight(self):
if self.model1_path:
pretrained_dict = torch.load(self.model1_path, map_location=self.device)
model_dict = self.model1.state_dict()
# 1. filter out unnecessary keys
pretrained_dict_b = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict_b)
self.model1.load_state_dict(model_dict)
self.model1.eval()
if self.model2_path:
pretrained_dict = torch.load(self.model2_path, map_location=self.device)
model_dict = self.model2.state_dict()
# 1. filter out unnecessary keys
pretrained_dict_b = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict_b)
self.model2.load_state_dict(model_dict)
if self.freeeze_pretain: # ?????fctl????????????
self.model2.eval()
if self.freeeze_pretain: # ??FCTL
for name, para in self.model2.named_parameters():
para.requires_grad = False
for name, para in self.model1.named_parameters():
para.requires_grad = False
else: # ??FCTL?global??
for name, para in self.model1.named_parameters():
para.requires_grad = False
def get_pretrained_weight(self):
for name, parm in self.model2.named_parameters():
print(f'{name}:{parm.requires_grad}')
def forward(self, x, y):
if self.freeeze_pretain: # ?????????????????????
with torch.no_grad():
feature1 = self.model1(x)
feature2 = self.model2(y)
else:
with torch.no_grad():
feature1 = self.model1(y)
feature2 = self.model2(x)
features = self.fctl(feature1 ,feature2 ,) # ??global?????roi????
features_1 = self.norm(features.mean([-2, -1]))
out = self.head(features_1)
return out
def _init_weights(self):
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.2)
nn.init.constant_(m.bias, 0)
勸大家不要這樣寫!不要把權重加載的事情放在初始化里面,追悔莫及!
思考:
其實我自己剛開始覺得這種重復加載權重應該是沒有問題的,因為model1和model2只是model3的一部分,我先加載model1和model2的權重,最后加載model3的權重也是會覆蓋剛剛加載的model1和model2的權重的,但是結(jié)果好像并不像我想想的那么簡單。因為我用訓練好的權重去預測,先加載model1,再加載model2,之后加載model3,之后得到的結(jié)果驚掉下巴!雖然再訓練過程中在驗證集上準確率不低,但是…所以用驗證集驗證是不是我的權重保存有問題。check后發(fā)現(xiàn)沒有問題,之后檢查數(shù)據(jù)集也沒有問題,代碼也沒問題,label的錯誤之前犯過了,也不妨再檢查一遍沒有問題。所以我又重新定義了不加載權重的predict_model ,直接加載model3的權重,這次在驗證集上的結(jié)果才是正常的。至于原因,個人還在探索,搞明白再和大家分享。
你是否還在尋找穩(wěn)定的海外服務器提供商?創(chuàng)新互聯(lián)www.cdcxhl.cn海外機房具備T級流量清洗系統(tǒng)配攻擊溯源,準確流量調(diào)度確保服務器高可用性,企業(yè)級服務器適合批量采購,新人活動首月15元起,快前往官網(wǎng)查看詳情吧
標題名稱:深度學習(16)——權重加載-創(chuàng)新互聯(lián)
分享路徑:http://vcdvsql.cn/article42/ggihc.html
成都網(wǎng)站建設公司_創(chuàng)新互聯(lián),為您提供域名注冊、企業(yè)網(wǎng)站制作、定制開發(fā)、網(wǎng)站建設、動態(tài)網(wǎng)站、網(wǎng)頁設計公司
聲明:本網(wǎng)站發(fā)布的內(nèi)容(圖片、視頻和文字)以用戶投稿、用戶轉(zhuǎn)載內(nèi)容為主,如果涉及侵權請盡快告知,我們將會在第一時間刪除。文章觀點不代表本網(wǎng)站立場,如需處理請聯(lián)系客服。電話:028-86922220;郵箱:631063699@qq.com。內(nèi)容未經(jīng)允許不得轉(zhuǎn)載,或轉(zhuǎn)載時需注明來源: 創(chuàng)新互聯(lián)
猜你還喜歡下面的內(nèi)容