When a model is saved, its parameters are stored as key-value pairs (the state_dict). When loading, PyTorch takes each key of the current network, looks up the same key in the saved model, and copies the corresponding value. An error is usually reported because the keys of the saved model and the keys of the network do not match.
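To see what these key-value pairs look like, the following minimal sketch lists the keys and tensor shapes of a small network. The network is a hypothetical toy model chosen for illustration, not one from this article.

```python
import torch.nn as nn

# Hypothetical toy network, used only to show what a state_dict contains.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# state_dict() maps parameter names (keys) to tensors (values).
shapes = {key: tuple(tensor.shape) for key, tensor in net.state_dict().items()}

for key, shape in shapes.items():
    print(key, shape)
# 0.weight (8, 4)
# 0.bias (8,)
# 2.weight (2, 8)
# 2.bias (2,)
```

Printing the keys of both the saved model and the current network in this way is usually the quickest way to spot where they differ.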
1. The most common problem is an extra or missing module. prefix in the keys
In this case, the model was saved after DataParallel or DDP training, so its keys carry the module. prefix, while the keys of the corresponding network do not.
1) You can wrap the network in DataParallel:
model = nn.DataParallel(model)
This adds the module. prefix to the keys of the network, so they match the keys of the saved model.
2) You can also rename the keys by traversing the key-value pairs of the saved model.
For example, to remove the redundant module. prefix when loading the model, the code is as follows:
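import torch

# load_path and net are assumed to be defined elsewhere
state_dict = torch.load(load_path)
new_state_dict = {}  # build a new dict; changing a dict while iterating over it raises an error
for key, param in state_dict.items():
    # Drop the leading 'module.' (7 characters) added by DataParallel/DDP
    if key.startswith('module.'):
        key = key[7:]
    new_state_dict[key] = param
net.load_state_dict(new_state_dict)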
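The two options above can be checked on a toy example. This is only a sketch: the `nn.Linear` layers stand in for a real network, and the `saved` dict stands in for a checkpoint file that would normally come from `torch.load`.

```python
import torch
import torch.nn as nn

# Stand-in for a checkpoint saved from a DataParallel-wrapped model:
# every key starts with 'module.'.
saved = nn.DataParallel(nn.Linear(4, 2)).state_dict()

# Option 1: wrap the network so that its keys also carry the prefix.
net_a = nn.DataParallel(nn.Linear(4, 2))
net_a.load_state_dict(saved)

# Option 2: strip the prefix and load into the plain network.
prefix = "module."
stripped = {
    (key[len(prefix):] if key.startswith(prefix) else key): value
    for key, value in saved.items()
}
net_b = nn.Linear(4, 2)
net_b.load_state_dict(stripped)

# Both networks now hold the saved weights.
same = torch.equal(net_a.module.weight.detach().cpu(), net_b.weight.detach().cpu())
print(sorted(saved.keys()), sorted(stripped.keys()), same)
```

Both calls use the default strict loading, so a remaining key mismatch would raise an error instead of passing silently.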
2. The False parameter of load_state_dict(state_dict, False) in detail
Many tutorials say that if the names do not match, you can simply pass False (that is, strict=False), but there is a major pitfall here.
If the keys of the saved model do not match the keys of the network, the network will not load the pretrained parameters, even though no error is reported.
The False parameter enables non-strict matching when loading a model. Its behavior can be analyzed in the following cases.
1) The saved model contains the parameters the network needs
For example, the saved model is resnet101 and your current network is resnet50. Assuming that the parameter names of resnet50 are all included among the parameter names of resnet101, passing False loads into your resnet50 network exactly the parameters whose keys are the same. This saves you from looping over every key-value pair of resnet101 to check whether resnet50 needs it.
2) The saved model contains none of the parameters of the network
As in problem 1 above, suppose the saved model has 100 parameters whose keys all contain 'module.', and the network also has 100 parameters, none of whose keys contain 'module.'. In this case, if the parameter is set to False, no key can be matched, so the network will not load any parameters.
3) Another usage scenario for False
For example, in the distillation network PISR, the teacher network consists of an encoder and a decoder, and the student network consists of a decoder only. Therefore, when training the student network, if you want to load the pretrained model saved by the teacher network, passing False matches the decoder keys, which are the same in both networks, and loads those parameters.
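To sum up, after setting the False parameter, the parameters are still loaded according to their keys: only as many parameters are loaded as there are matching keys. Note that False relaxes key matching only; a key that matches but whose tensor has a different shape still raises a size mismatch error.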
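Cases 2 and 3 can be told apart by inspecting the value returned by `load_state_dict`, which has the fields `missing_keys` and `unexpected_keys`. The sketch below uses hypothetical `Teacher` and `Student` classes built from single `nn.Linear` layers; they only mimic the encoder/decoder layout described above and are not the PISR code.

```python
import torch
import torch.nn as nn

class Teacher(nn.Module):
    """Hypothetical teacher: encoder plus decoder."""
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)
        self.decoder = nn.Linear(8, 2)

class Student(nn.Module):
    """Hypothetical student: decoder only, with the same layer name."""
    def __init__(self):
        super().__init__()
        self.decoder = nn.Linear(8, 2)

teacher_state = Teacher().state_dict()

# Case 3: the decoder keys match and are loaded; the encoder keys are skipped.
student = Student()
partial = student.load_state_dict(teacher_state, strict=False)
print(partial.missing_keys)     # keys the network has but the saved model lacks
print(partial.unexpected_keys)  # keys the saved model has but the network lacks

# Case 2: every saved key carries 'module.', so nothing matches.
prefixed_state = {"module." + key: value for key, value in teacher_state.items()}
untouched = Student()
before = untouched.decoder.weight.detach().clone()
silent = untouched.load_state_dict(prefixed_state, strict=False)  # no error raised
print(silent.missing_keys)      # every key of the network is reported as missing
```

Checking that `missing_keys` is empty, or contains only the keys you expect, after a non-strict load is a simple guard against the silent failure of case 2.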
3. As long as the parameter size is the same, it can be loaded
For example, suppose I have a saved 10-layer network model and a current 3-layer network, and I want to load the parameters of layer 9 into layer 1 of the current network. If the parameters have the same size, you can traverse the key-value pairs and store each parameter under the desired key.
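import torch

# load_path and net are assumed to be defined elsewhere
state_dict = torch.load(load_path)
new_state_dict = {}  # must be a dict, not a list, to be indexed by key
for key, param in state_dict.items():
    if 'conv9' in key:
        # Store the layer 9 parameter under the corresponding layer 1 key
        new_state_dict[key.replace('conv9', 'conv1')] = param
# Only the conv1 keys are present, so the other layers must be allowed to stay unloaded
net.load_state_dict(new_state_dict, strict=False)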
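The remapping above can be made a little safer by checking the target key and the tensor shape before copying. The following is a sketch under stated assumptions: `SourceNet` and `TargetNet` are hypothetical two-layer stand-ins for the 10-layer and 3-layer networks, and only the layer names `conv1` and `conv9` come from the example above.

```python
import torch
import torch.nn as nn

class SourceNet(nn.Module):
    """Hypothetical stand-in for the saved model (only two of its layers shown)."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3)
        self.conv9 = nn.Conv2d(3, 8, 3)

class TargetNet(nn.Module):
    """Hypothetical stand-in for the current, smaller network."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3)
        self.conv2 = nn.Conv2d(8, 8, 3)

source_state = SourceNet().state_dict()
target = TargetNet()
target_state = target.state_dict()

remapped = {}
for key, param in source_state.items():
    if key.startswith("conv9."):
        new_key = key.replace("conv9.", "conv1.", 1)
        # Copy only when the target key exists and the shapes agree.
        if new_key in target_state and target_state[new_key].shape == param.shape:
            remapped[new_key] = param

result = target.load_state_dict(remapped, strict=False)
print(sorted(remapped.keys()))  # keys that were remapped and loaded
print(result.missing_keys)      # layers of the target that were left unloaded
```

Here `conv2` is reported in `missing_keys` and keeps its initial values, which is the intended outcome when only one layer is transplanted.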