
[Solved] PyTorch Load Model Error: Missing key(s) RuntimeError: Error(s) in loading state_dict for

PyTorch: load_state_dict() reports missing key(s) after torch.load()

Error condition: loading a pre-trained model fails with:

RuntimeError: Error(s) in loading state_dict for :
Missing key(s) in state_dict: "features.0.weight" …
Unexpected key(s) in state_dict: "module.features.0.weight" …


Error reason:

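The state_dict keys of a model wrapped in nn.DataParallel carry an extra "module." prefix that the keys of the same model without the wrapper do not have. DataParallel stores the original network in its .module attribute, so "features.0.weight" becomes "module.features.0.weight". When the keys in the saved file and the keys of the model disagree, load_state_dict() reports the model's keys as missing and the file's keys as unexpected.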
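A minimal sketch that shows the prefix and reproduces the error. TinyNet is a hypothetical stand-in for the real network (only its key names matter here); nn.DataParallel can be constructed on a CPU-only machine, where it simply keeps the net in .module.

```python
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real network; only the key names matter."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(4, 4), nn.ReLU())

    def forward(self, x):
        return self.features(x)


plain = TinyNet()
wrapped = nn.DataParallel(TinyNet())  # the wrapper stores the net as .module

print(list(plain.state_dict().keys()))
# ['features.0.weight', 'features.0.bias']
print(list(wrapped.state_dict().keys()))
# ['module.features.0.weight', 'module.features.0.bias']

# Loading the wrapped weights into the plain net raises the RuntimeError above.
message = ""
try:
    plain.load_state_dict(wrapped.state_dict())
except RuntimeError as err:
    message = str(err)
print(message)
```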



1. Loading weights trained with nn.DataParallel(net) into a plain net

Remove the "module." prefix:

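import torch
from collections import OrderedDict

# original file saved from a DataParallel-wrapped model
state_dict = torch.load('model_path')
# build a new OrderedDict whose keys do not start with `module.`
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] if k.startswith('module.') else k  # remove `module.` only if present
    new_state_dict[name] = v
# load params into the plain (unwrapped) net
net.load_state_dict(new_state_dict)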

Alternatively, rename the keys in place (code adapted from another source):

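import torch

checkpoint = torch.load('model_path')
for key in list(checkpoint.keys()):
    # strip only a leading `module.` (the prefix nn.DataParallel adds)
    if key.startswith('module.'):
        checkpoint[key[len('module.'):]] = checkpoint.pop(key)
net.load_state_dict(checkpoint)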
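Recent PyTorch releases (1.9 and later) also ship a helper that strips a prefix in place, consume_prefix_in_state_dict_if_present in torch.nn.modules.utils. It is not part of the main documented API, so treat this as a sketch and confirm the function exists in your installed version. The nn.Sequential model is a hypothetical stand-in for your network.

```python
import torch.nn as nn
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present


def make_net():
    # Hypothetical architecture used only for this illustration.
    return nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))


# Weights as they would come out of a DataParallel-wrapped model.
state_dict = nn.DataParallel(make_net()).state_dict()  # keys like "module.0.weight"

# Strips "module." in place; keys without the prefix are left untouched.
consume_prefix_in_state_dict_if_present(state_dict, "module.")

fresh = make_net()
fresh.load_state_dict(state_dict)  # strict=True by default, so a mismatch would raise
print(sorted(state_dict.keys()))
```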


Alternatively, wrap the net in nn.DataParallel before loading, so its keys also start with "module.":

import torch
checkpoint = torch.load('model_path')
net = torch.nn.DataParallel(net)  # keys of net now start with `module.`
net.load_state_dict(checkpoint)
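When it is unclear which side has the prefix, load_state_dict(..., strict=False) returns the missing and unexpected keys instead of raising, which makes the mismatch easy to inspect. The sketch below uses a hypothetical make_net() model and a checkpoint taken from a DataParallel-wrapped copy; wrapping the net before loading leaves both lists empty.

```python
import torch.nn as nn


def make_net():
    # Hypothetical architecture used only for this illustration.
    return nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))


# Checkpoint produced by a DataParallel-wrapped model.
checkpoint = nn.DataParallel(make_net()).state_dict()

# strict=False: report mismatched keys instead of raising RuntimeError.
plain_result = make_net().load_state_dict(checkpoint, strict=False)
print(plain_result.missing_keys)     # keys the plain net expected but did not get
print(plain_result.unexpected_keys)  # "module."-prefixed keys it did not recognize

wrapped_result = nn.DataParallel(make_net()).load_state_dict(checkpoint, strict=False)
print(wrapped_result.missing_keys, wrapped_result.unexpected_keys)  # [] []
```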

2. Loading weights from a plain net into nn.DataParallel(net)

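Avoid the prefix at save time by adding .module: if the network is wrapped in nn.DataParallel, pass net.module.state_dict() to torch.save() instead of net.state_dict(). The saved keys then have no "module." prefix, just like weights saved from a plain net.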

torch.save(net.module.state_dict(), 'model_path')
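A minimal round-trip sketch of saving through .module and then loading into the plain net before wrapping it. An in-memory io.BytesIO buffer stands in for 'model_path', and make_net() is a hypothetical architecture.

```python
import io

import torch
import torch.nn as nn


def make_net():
    # Hypothetical architecture used only for this illustration.
    return nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))


net = nn.DataParallel(make_net())

buffer = io.BytesIO()  # stands in for 'model_path'
torch.save(net.module.state_dict(), buffer)  # .module: keys saved without the prefix
buffer.seek(0)
saved = torch.load(buffer)
print(list(saved.keys()))  # ['0.weight', '0.bias', '2.weight', '2.bias']

# The prefix-free weights load straight into a plain net ...
plain = make_net()
plain.load_state_dict(saved)
# ... which can then be wrapped, so no key renaming is ever needed.
rewrapped = nn.DataParallel(plain)
```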

Load the weights into the plain net first, then wrap it with nn.DataParallel:

net.load_state_dict(torch.load('model_path'))  # keys have no `module.` prefix yet
net = nn.DataParallel(net, device_ids=[0, 1])

Add the "module." prefix manually:

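import torch
import torch.nn as nn
from collections import OrderedDict

net = nn.DataParallel(net)
state_dict = torch.load(savepath)  # savepath: path to the pre-trained model
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    # swap "features.module." to "module.features." (e.g. checkpoints saved
    # when only the features submodule was wrapped in DataParallel)
    k = k.replace('features.module.', 'module.features.')
    # add "module." manually where it is still missing
    if not k.startswith('module.'):
        k = 'module.' + k
    new_state_dict[k] = v
net.load_state_dict(new_state_dict)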
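Both cases can be folded into one helper that inspects the target model and adds or strips the prefix as needed. match_dataparallel_prefix is a hypothetical name, and the sketch assumes the wrapper prefix is the only difference between checkpoint and model; it also treats DistributedDataParallel, which adds the same "module." prefix, as wrapped.

```python
from collections import OrderedDict

import torch.nn as nn

PREFIX = "module."


def match_dataparallel_prefix(state_dict, model):
    """Hypothetical helper: add or strip "module." so the keys fit `model`."""
    wrapped = isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel))
    fixed = OrderedDict()
    for k, v in state_dict.items():
        if wrapped and not k.startswith(PREFIX):
            k = PREFIX + k  # target is wrapped and expects the prefix
        elif not wrapped and k.startswith(PREFIX):
            k = k[len(PREFIX):]  # target is a plain net
        fixed[k] = v
    return fixed


def make_net():
    # Hypothetical architecture used only for this illustration.
    return nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))


plain_sd = make_net().state_dict()
wrapped_sd = nn.DataParallel(make_net()).state_dict()

# Case 2: plain weights into a DataParallel net.
target = nn.DataParallel(make_net())
target.load_state_dict(match_dataparallel_prefix(plain_sd, target))

# Case 1: DataParallel weights into a plain net.
plain_target = make_net()
plain_target.load_state_dict(match_dataparallel_prefix(wrapped_sd, plain_target))
```

Checking the model type, rather than guessing from the checkpoint alone, means the same call works whichever way the mismatch runs.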
