PyTorch model loading error: RuntimeError: Error(s) in loading state_dict for Model: Missing key(s) in state_dict

When a model is saved, its parameters are stored as key-value pairs (the state_dict). When loading, PyTorch looks up each key of the current network in the saved state_dict and copies the matching value. This error is generally reported because the keys of the saved model and the keys of the network do not match.
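To see which keys fail to match, one option is to compare the two key sets directly before calling load_state_dict. The sketch below is illustrative only: TinyNet is a hypothetical network, and the checkpoint is simulated in memory rather than read from a file.

```python
import torch.nn as nn

class TinyNet(nn.Module):  # hypothetical network, for illustration only
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3)
        self.fc = nn.Linear(8, 2)

net = TinyNet()

# Simulate a checkpoint whose keys carry a 'module.' prefix.
checkpoint = {'module.' + k: v for k, v in net.state_dict().items()}

model_keys = set(net.state_dict().keys())
ckpt_keys = set(checkpoint.keys())

missing = sorted(model_keys - ckpt_keys)      # the network expects these, the checkpoint lacks them
unexpected = sorted(ckpt_keys - model_keys)   # the checkpoint has these, the network does not

print('missing:', missing)
print('unexpected:', unexpected)
```

If every key appears in both lists and the two differ only by a common prefix, the cause is most likely the prefix mismatch described in section 1.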

1. The most common problem: the keys have an extra or missing 'module.' prefix

In this case, the model was saved after DataParallel or DDP training, so its keys start with 'module.', while the keys of the corresponding network have no 'module.' prefix.

1) You can wrap the network so that its keys also get the 'module.' prefix:

model = nn.DataParallel(model)

After wrapping, the keys of the network match those of the saved model.
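A minimal sketch of option 1, assuming a small stand-in network (an nn.Sequential with a single Linear layer) and a checkpoint simulated in memory with 'module.'-prefixed keys:

```python
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 4))  # stand-in network; keys: '0.weight', '0.bias'

# Simulated checkpoint from DataParallel training: every key starts with 'module.'
saved = {'module.' + k: v.clone() for k, v in net.state_dict().items()}

wrapped = nn.DataParallel(net)            # the wrapper's keys now start with 'module.' too
result = wrapped.load_state_dict(saved)   # default strict loading succeeds

wrapped_keys = list(wrapped.state_dict().keys())
print(wrapped_keys)
```

The underlying network stays reachable as wrapped.module if you need the unwrapped model later.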

2) You can also modify the keys by iterating over the key-value pairs of the saved model.

For example, strip the redundant 'module.' prefix when loading the model. The code is as follows:

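import torch

state_dict = torch.load(load_path)
# Build a new dict: mutating state_dict while iterating over it is unsafe
# and typically raises a RuntimeError.
new_state_dict = {}
for key, param in state_dict.items():
    if key.startswith('module.'):
        key = key[len('module.'):]  # strip the 'module.' prefix
    new_state_dict[key] = param
net.load_state_dict(new_state_dict)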
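Recent PyTorch releases also ship a helper for this, torch.nn.modules.utils.consume_prefix_in_state_dict_if_present, which strips a prefix from the keys in place. The sketch below assumes your installed version provides it (to my knowledge it was added around PyTorch 1.9) and uses a stand-in Linear layer with a checkpoint simulated in memory.

```python
import torch.nn as nn
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

net = nn.Linear(4, 2)  # stand-in for the real network

# Simulated checkpoint whose keys all start with 'module.'
state_dict = {'module.' + k: v.clone() for k, v in net.state_dict().items()}

# Strips the prefix in place; keys without the prefix are left untouched.
consume_prefix_in_state_dict_if_present(state_dict, 'module.')

result = net.load_state_dict(state_dict)  # strict loading now succeeds
stripped_keys = sorted(state_dict.keys())
print(stripped_keys)
```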
        

2. A closer look at the strict=False parameter: load_state_dict(state_dict, False)

Many tutorials say that if the names do not match, you can simply pass strict=False, but there is a big pitfall here.

If the keys of the saved model do not match the keys of the network, the network will not load the pretrained parameters, even though no error is reported.

strict=False enables non-strict key matching when loading a model. Its behavior can be analyzed in the following cases.

1) The saved model contains the parameters of the network

For example, the saved model is resnet101 and your current network is resnet50. Assuming that the parameter names of resnet50 are all included among those of resnet101, passing strict=False loads the parameters with matching keys into your resnet50. This saves you from looping over every key-value pair of resnet101 to check whether resnet50 needs it. Note that strict=False only relaxes key matching: a key that matches but whose tensor shape differs still raises an error.
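A small sketch of this case, using stacks of Linear layers as stand-ins for resnet101 and resnet50 (make_net is a hypothetical helper, not part of torchvision). It also shows a caveat: strict=False does not skip a key whose tensor shape differs.

```python
import torch
import torch.nn as nn

def make_net(num_layers):
    # Hypothetical helper: keys are '0.weight', '0.bias', '1.weight', ...
    return nn.Sequential(*[nn.Linear(4, 4) for _ in range(num_layers)])

big = make_net(4)    # plays the role of the saved, larger model
small = make_net(2)  # plays the role of the current, smaller network

result = small.load_state_dict(big.state_dict(), strict=False)
print(result.missing_keys)      # keys of `small` that were not found in `big`
print(result.unexpected_keys)   # keys of `big` that `small` does not have
loaded = torch.equal(small[0].weight, big[0].weight)

# A matching key with a different shape is still an error, even with strict=False.
try:
    nn.Linear(4, 3).load_state_dict(nn.Linear(4, 2).state_dict(), strict=False)
    shape_error = False
except RuntimeError:
    shape_error = True
print(loaded, shape_error)
```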

2) The saved model contains none of the parameters of the network

As in section 1, suppose the saved model has 100 parameters whose keys all start with 'module.', and the network also has 100 parameters, none of whose keys start with 'module.'. With strict=False, no key matches, so the network loads no parameters at all.
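This silent failure can be detected from the return value of load_state_dict, which lists the keys that did not match. A minimal sketch, assuming two stand-in Linear layers and a 'module.'-prefixed checkpoint simulated in memory:

```python
import torch
import torch.nn as nn

source = nn.Linear(4, 2)  # stand-in for the model that was saved
target = nn.Linear(4, 2)  # stand-in for the current network
before = target.weight.detach().clone()

# Simulated checkpoint whose keys all start with 'module.'
prefixed = {'module.' + k: v for k, v in source.state_dict().items()}

result = target.load_state_dict(prefixed, strict=False)  # no error is raised
print(result.missing_keys)      # network keys that received no value
print(result.unexpected_keys)   # checkpoint keys that were ignored

# The network still holds its original initialization.
weights_unchanged = torch.equal(target.weight, before)
print(weights_unchanged)
```

If missing_keys lists every parameter of the network, nothing was loaded.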

3) Another usage scenario for strict=False

For example, in the distillation network PISR, the teacher network consists of an encoder and a decoder, while the student network consists of a decoder only. Therefore, when training the student network, if you want to load the pretrained model saved by the teacher network, passing strict=False matches the decoder keys, which are the same in both networks, and loads those parameters.

To sum up, with strict=False the parameters are still loaded by key: only the parameters whose keys match are loaded.
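A stricter alternative for this scenario is to keep only the decoder entries of the teacher checkpoint and load them with the default strict matching, so that a naming mismatch fails loudly. The sketch below is not PISR's actual code: Teacher, Student, and the layer sizes are hypothetical stand-ins.

```python
import torch
import torch.nn as nn

class Teacher(nn.Module):  # hypothetical stand-in: encoder + decoder
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)
        self.decoder = nn.Linear(8, 4)

class Student(nn.Module):  # hypothetical stand-in: decoder only
    def __init__(self):
        super().__init__()
        self.decoder = nn.Linear(8, 4)

teacher, student = Teacher(), Student()

# Keep only the entries the student actually has.
decoder_only = {k: v for k, v in teacher.state_dict().items()
                if k.startswith('decoder.')}

result = student.load_state_dict(decoder_only)  # strict=True by default
decoder_matches = torch.equal(student.decoder.weight, teacher.decoder.weight)
print(sorted(decoder_only), decoder_matches)
```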

3. Parameters can be loaded as long as their sizes match

For example, suppose I have a saved 10-layer model and a current 3-layer network, and I want to load the parameters of layer 9 into layer 1 of the current network. If the parameters have the same size, you can iterate over the key-value pairs and store each parameter under the desired key.

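import torch

state_dict = torch.load(load_path)
new_state_dict = {}  # must be a dict, not a list
for key, param in state_dict.items():
    if 'conv9' in key:
        # store the conv9 parameter under the corresponding conv1 key
        new_state_dict[key.replace('conv9', 'conv1')] = param
# only the conv1 keys are present, so disable strict key matching
net.load_state_dict(new_state_dict, strict=False)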
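A runnable sketch of the same idea with an explicit shape check, assuming hypothetical stand-in networks in which the layers named conv9 and conv1 happen to have identical shapes:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins: the saved model has a 'conv9' layer,
# the current network has 'conv1' and 'conv2'.
saved_model = nn.ModuleDict({'conv9': nn.Conv2d(8, 8, kernel_size=3)})
net = nn.ModuleDict({
    'conv1': nn.Conv2d(8, 8, kernel_size=3),
    'conv2': nn.Conv2d(8, 8, kernel_size=3),
})

target_shapes = {k: v.shape for k, v in net.state_dict().items()}
remapped = {}
for key, param in saved_model.state_dict().items():
    new_key = key.replace('conv9', 'conv1')
    # Copy only when the renamed key exists and the shapes agree.
    if new_key in target_shapes and target_shapes[new_key] == param.shape:
        remapped[new_key] = param

result = net.load_state_dict(remapped, strict=False)
conv1_loaded = torch.equal(net['conv1'].weight, saved_model['conv9'].weight)
print(sorted(remapped), conv1_loaded)
```

Here result.missing_keys names the layers (conv2 in this sketch) that keep their initial values.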