官方的 pretrain model 去除指定层可以参考链接 https://blog.csdn.net/KHFlash/article/details/82345441 ;这里主要针对非官方的 pretrain model,如下:
import torch
from collections import OrderedDict
import os
import torch.nn as nn
import torch.nn.init as init
from xxx import new_VGG
def init_weight(modules):
    """Re-initialize the parameters of Conv2d, BatchNorm2d and Linear layers in place.

    Args:
        modules: Iterable of ``nn.Module`` objects (for example the result of
            ``model.modules()``). Modules of any other type are left untouched.

    Initialization scheme:
        * ``nn.Conv2d``: Xavier-uniform weights, zero bias.
        * ``nn.BatchNorm2d``: weights set to 1, bias set to 0.
        * ``nn.Linear``: weights drawn from N(mean=0, std=0.01), zero bias.

    Parameters that are absent (``bias=False`` layers, or ``BatchNorm2d``
    built with ``affine=False``) are skipped instead of raising.
    """
    for m in modules:
        if isinstance(m, nn.Conv2d):
            init.xavier_uniform_(m.weight)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # affine=False leaves both weight and bias as None.
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            # The in-place sampler is `normal_`; tensors have no `normal`
            # method, so the previous call raised AttributeError here.
            init.normal_(m.weight, mean=0.0, std=0.01)
            if m.bias is not None:
                init.constant_(m.bias, 0)
def copyStateDict(state_dict):
if list(state_dict.keys())[0].startswith('module'):
start_idx = 1
else:
start_idx = 0
new_state_dict = OrderedDict()
for k,v in state