PyTorch torch.load() / load_state_dict() error: Missing key(s) in state_dict
Error condition: loading a pre-trained model fails with:
RuntimeError: Error(s) in loading state_dict for :
    Missing key(s) in state_dict: "features.0.weight" …
    Unexpected key(s) in state_dict: "module.features.0.weight" …
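Error reason:
When a model is wrapped with nn.DataParallel, the original network is stored as the wrapper's submodule "module", so every key in its state_dict gains an extra "module." prefix (for example, "features.0.weight" becomes "module.features.0.weight"). A checkpoint saved from the wrapped model therefore does not match the keys of the same network without the wrapper, and vice versa.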
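Why one extra prefix produces both "Missing" and "Unexpected" lists can be seen with plain key sets. The sketch below is a simplified model of the strict key check, not PyTorch's actual implementation, and compare_keys is a hypothetical helper: keys the model expects but the file lacks are "missing", keys the file has but the model does not expect are "unexpected". When every saved key carries an extra prefix, every key ends up in both lists.

```python
def compare_keys(model_keys, checkpoint_keys):
    """Simplified model of the strict key check done by load_state_dict.

    Returns (missing, unexpected) as sorted lists:
      missing    = keys the model expects but the checkpoint lacks
      unexpected = keys the checkpoint has but the model does not expect
    """
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Keys from the error message: the model is not wrapped,
# but the checkpoint was saved from a DataParallel-wrapped model.
missing, unexpected = compare_keys(["features.0.weight"], ["module.features.0.weight"])
print("Missing:", missing)        # Missing: ['features.0.weight']
print("Unexpected:", unexpected)  # Unexpected: ['module.features.0.weight']
```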
Solution:
1. Loading a model trained with nn.DataParallel(net) into a plain net
Delete the "module." prefix
import torch
from collections import OrderedDict

# original checkpoint saved from a DataParallel-wrapped model
state_dict = torch.load('model_path')
# build a new OrderedDict whose keys do not start with `module.`
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    # remove `module.` only where it is actually present
    name = k[len('module.'):] if k.startswith('module.') else k
    new_state_dict[name] = v
# load params into the unwrapped net
net.load_state_dict(new_state_dict)
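The same renaming can be packaged as a reusable function. This is a minimal sketch; strip_prefix is a hypothetical helper name, and the placeholder integers stand in for tensors because only the keys matter. It leaves keys without the prefix untouched, so running it on an already-clean checkpoint is harmless.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from keys that start with it.

    Keys without the prefix are kept as they are; key order is preserved.
    """
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )


# Placeholder values stand in for tensors.
saved = OrderedDict([("module.features.0.weight", 1), ("module.features.0.bias", 2)])
print(list(strip_prefix(saved)))  # ['features.0.weight', 'features.0.bias']
```

In a real script it would be called as net.load_state_dict(strip_prefix(torch.load('model_path'))). Recent PyTorch releases also include torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, prefix), which strips the prefix in place; check that your installed version provides it before relying on it.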
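Alternatively, rename the keys in place:
import torch

checkpoint = torch.load('model_path')
# iterate over a copy of the keys, because the loop adds and deletes entries
for key in list(checkpoint.keys()):
    # DataParallel uses `module.`; if your checkpoint uses another prefix such as `model.`, use that instead
    if key.startswith('module.'):
        checkpoint[key[len('module.'):]] = checkpoint.pop(key)
net.load_state_dict(checkpoint)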
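Wrap the net in nn.DataParallel before loading the weights
import torch

checkpoint = torch.load('model_path')
# after wrapping, the net's keys also start with `module.`, so they match the checkpoint
net = torch.nn.DataParallel(net)
net.load_state_dict(checkpoint)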
2. Loading a model trained without nn.DataParallel into nn.DataParallel(net)
Save the weights without the "module." prefix
When saving weights with torch.save(), use net.module.state_dict() to get the weights of the underlying network, so the saved keys carry no "module." prefix:
torch.save(net.module.state_dict(), 'model_path')
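net.module.state_dict() only works when net is actually wrapped. A small helper can pick the inner network when a wrapper is present and the model itself otherwise. This is a sketch; unwrap is a hypothetical helper name, and the Fake* classes are stand-ins used only so the example runs without PyTorch.

```python
def unwrap(model, wrapper_types):
    """Return the network inside a wrapper, or the model itself if it is not wrapped.

    wrapper_types: a class or tuple of classes whose instances keep the
    real network in their `.module` attribute.
    """
    return model.module if isinstance(model, wrapper_types) else model


class FakeNet:  # stand-in for an nn.Module, for illustration only
    pass


class FakeWrapper:  # stand-in for nn.DataParallel, for illustration only
    def __init__(self, module):
        self.module = module


net = FakeNet()
print(unwrap(FakeWrapper(net), FakeWrapper) is net)  # True
print(unwrap(net, FakeWrapper) is net)               # True
```

In a training script this might be used as torch.save(unwrap(net, (nn.DataParallel, nn.parallel.DistributedDataParallel)).state_dict(), 'model_path'); both wrappers keep the original network in .module, so the saved keys are prefix-free either way. Checking the type with isinstance, rather than just looking for a .module attribute, avoids misfiring on a plain model that happens to have a submodule named "module".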
Load the weights before wrapping the net with nn.DataParallel
import torch
import torch.nn as nn

net.load_state_dict(torch.load('model_path'))
net = nn.DataParallel(net, device_ids=[0, 1])
Add the "module." prefix manually
import torch
import torch.nn as nn
from collections import OrderedDict

net = nn.DataParallel(net)
state_dict = torch.load(savepath)  # savepath: path to the pre-trained weights
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    if 'features.module.' in k:
        # checkpoint where only net.features was wrapped in DataParallel:
        # swap the positions of `features.` and `module.`
        k = k.replace('features.module.', 'module.features.')
    elif not k.startswith('module.'):
        # add "module." manually
        k = 'module.' + k
    new_state_dict[k] = v
net.load_state_dict(new_state_dict)
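Both directions can be handled by one function that looks at the keys the target model expects and adds or strips the prefix to match. This is a sketch under the assumption that a wrapped model has the prefix on all of its keys; align_prefix is a hypothetical helper name, and it does not cover the "features.module." partial-wrapping case above.

```python
from collections import OrderedDict


def align_prefix(state_dict, model_keys, prefix="module."):
    """Add or strip `prefix` so checkpoint keys follow the target model's convention.

    model_keys: keys of the target model, e.g. net.state_dict().keys().
    If every model key starts with the prefix, it is added where missing;
    otherwise it is stripped where present.
    """
    model_keys = list(model_keys)
    model_wants_prefix = bool(model_keys) and all(k.startswith(prefix) for k in model_keys)
    aligned = OrderedDict()
    for k, v in state_dict.items():
        has_prefix = k.startswith(prefix)
        if model_wants_prefix and not has_prefix:
            k = prefix + k          # plain checkpoint -> wrapped model
        elif not model_wants_prefix and has_prefix:
            k = k[len(prefix):]     # wrapped checkpoint -> plain model
        aligned[k] = v
    return aligned


plain = OrderedDict([("features.0.weight", 1)])
print(list(align_prefix(plain, ["module.features.0.weight"])))  # ['module.features.0.weight']
```

Used as net.load_state_dict(align_prefix(torch.load('model_path'), net.state_dict().keys())), the same line should work whether or not net has been wrapped.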