torch.load() error in PyTorch: Missing key(s) in state_dict
Error condition: loading a pre-trained model fails with an error like:
RuntimeError: Error(s) in loading state_dict for : Missing key(s) in state_dict: "features.0.weight" … Unexpected key(s) in state_dict: "module.features.0.weight" …
Error reason:
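The state_dict keys of a model wrapped with nn.DataParallel carry an extra "module." prefix compared with the keys of the same model without the wrapper. This is because nn.DataParallel stores the original network as its module attribute, so every parameter name is prefixed with "module.". Loading weights saved from one form into the other therefore reports the unprefixed keys as missing and the prefixed keys as unexpected.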
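To see the mismatch concretely, the check that load_state_dict performs in strict mode can be approximated with two set differences. The sketch below is plain Python: the key lists are hypothetical stand-ins for net.state_dict().keys() and torch.load('model_path').keys(), and the prefix test is a heuristic of this sketch, not something PyTorch reports.

```python
def compare_keys(model_keys, checkpoint_keys, prefix="module."):
    """Report missing/unexpected keys, roughly as a strict load_state_dict does."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)     # expected by the model, absent from the file
    unexpected = sorted(checkpoint_keys - model_keys)  # present in the file, unknown to the model
    # Heuristic: if every unexpected key starts with the prefix and stripping it
    # yields exactly the missing keys, the file was saved from a DataParallel wrapper.
    stripped = sorted(k[len(prefix):] for k in unexpected)
    prefix_mismatch = (
        bool(unexpected)
        and all(k.startswith(prefix) for k in unexpected)
        and stripped == missing
    )
    return missing, unexpected, prefix_mismatch


model_keys = ["features.0.weight", "features.0.bias"]
checkpoint_keys = ["module.features.0.weight", "module.features.0.bias"]
missing, unexpected, prefix_mismatch = compare_keys(model_keys, checkpoint_keys)
print(missing)          # ['features.0.bias', 'features.0.weight']
print(unexpected)       # ['module.features.0.bias', 'module.features.0.weight']
print(prefix_mismatch)  # True
```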
Solution:
1. Loading a model trained as nn.DataParallel(net) into a plain net
Delete the "module." prefix:
import torch
from collections import OrderedDict

# load the state_dict that was saved from an nn.DataParallel-wrapped model
state_dict = torch.load('model_path')
# build a new OrderedDict whose keys no longer start with `module.`
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] if k.startswith('module.') else k  # strip the 7-character `module.` prefix
    new_state_dict[name] = v
# load the parameters into the unwrapped network
net.load_state_dict(new_state_dict)
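The same loop can be wrapped in a reusable helper. A minimal sketch, assuming the checkpoint is a flat key-to-tensor mapping; strip_prefix is a hypothetical helper name, and the integer values below only stand in for tensors. Recent PyTorch releases also ship torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, prefix), which strips the prefix in place; check that your installed version has it before relying on it.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from the keys that start with it.

    Keys without the prefix are kept unchanged, so running it on a checkpoint
    that was never wrapped is harmless (unlike an unconditional k[7:]).
    """
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith(prefix):
            k = k[len(prefix):]
        new_state_dict[k] = v
    return new_state_dict


# Placeholder values; in a real checkpoint these would be tensors.
ckpt = OrderedDict([("module.features.0.weight", 1), ("module.features.0.bias", 2)])
print(list(strip_prefix(ckpt)))  # ['features.0.weight', 'features.0.bias']
```

With PyTorch, usage would be net.load_state_dict(strip_prefix(torch.load('model_path'))).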
Alternatively, rename the keys in place:
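import torch

PREFIX = 'module.'  # change this if your checkpoint uses another prefix, e.g. 'model.'
checkpoint = torch.load('model_path')
for key in list(checkpoint.keys()):  # iterate over a copy, since the dict is modified in the loop
    if key.startswith(PREFIX):
        checkpoint[key[len(PREFIX):]] = checkpoint.pop(key)
net.load_state_dict(checkpoint)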
Or wrap net with nn.DataParallel before loading, so that its keys also start with "module.":
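import torch

checkpoint = torch.load('model_path')
net = torch.nn.DataParallel(net)  # keys of net now carry the `module.` prefix
net.load_state_dict(checkpoint)
# net.module still gives access to the unwrapped network afterwards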
2. Loading a model trained as a plain net into nn.DataParallel(net)
Save through .module
When saving weights from a DataParallel model with torch.save(), use net.module.state_dict() instead of net.state_dict(), so that the saved keys carry no "module." prefix:
torch.save(net.module.state_dict(), 'model_path')
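If the same training script may run with or without nn.DataParallel, a small helper can pick the right object to save in both cases. A minimal sketch, assuming the wrapper exposes the original network as .module (as nn.DataParallel and torch.nn.parallel.DistributedDataParallel do); unwrap_model is a hypothetical helper name, not a PyTorch API.

```python
def unwrap_model(model):
    """Return the network inside a DataParallel-style wrapper, or `model` unchanged.

    Duck-typed: any object with a `.module` attribute is treated as a wrapper.
    A plain model that has its own child called `module` would be unwrapped by
    mistake; use an isinstance check against nn.DataParallel and
    DistributedDataParallel instead if that can happen in your code.
    """
    return model.module if hasattr(model, "module") else model


# Usage (assumes torch is imported; net may or may not be wrapped):
# torch.save(unwrap_model(net).state_dict(), 'model_path')
```

Saving this way means the file always uses the plain key names, which loads directly into an unwrapped net.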
Alternatively, load the weights into the plain net first, and only then wrap it with nn.DataParallel:
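import torch
import torch.nn as nn

net.load_state_dict(torch.load('model_path'))  # keys match because net is not wrapped yet
net = nn.DataParallel(net, device_ids=[0, 1])  # assumes two GPUs with ids 0 and 1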
Or add the "module." prefix manually:
import torch
import torch.nn as nn
from collections import OrderedDict

net = nn.DataParallel(net)
state_dict = torch.load(savepath)  # savepath: path to the pre-trained weights
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    if k.startswith('features.module.'):
        # only `features` was wrapped in DataParallel: swap `features` and `module`
        k = k.replace('features.module.', 'module.features.', 1)
    elif not k.startswith('module.'):
        # checkpoint saved from a plain net: add "module." manually
        k = 'module.' + k
    new_state_dict[k] = v
net.load_state_dict(new_state_dict)
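Both directions can also be folded into one helper that looks at the keys the target model expects and adds or strips the prefix accordingly. A sketch under the assumption that the target counts as wrapped exactly when all of its own keys start with "module."; match_module_prefix is a hypothetical name, it does not handle the "features.module." layout, and the integer values stand in for tensors.

```python
from collections import OrderedDict


def match_module_prefix(state_dict, model_keys, prefix="module."):
    """Add or strip `prefix` so the checkpoint keys follow the target model's convention.

    model_keys would normally be net.state_dict().keys().
    """
    model_keys = list(model_keys)
    wrapped = bool(model_keys) and all(k.startswith(prefix) for k in model_keys)
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if wrapped and not k.startswith(prefix):
            k = prefix + k          # plain checkpoint -> wrapped model
        elif not wrapped and k.startswith(prefix):
            k = k[len(prefix):]     # wrapped checkpoint -> plain model
        new_state_dict[k] = v
    return new_state_dict


plain_ckpt = OrderedDict([("features.0.weight", 1)])
print(list(match_module_prefix(plain_ckpt, ["module.features.0.weight"])))
# ['module.features.0.weight']
```

With PyTorch, usage would be net.load_state_dict(match_module_prefix(torch.load('model_path'), net.state_dict().keys())), which works whether or not net is wrapped.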