[Solved] pytorch Load Error: “RuntimeError: Error(s) in loading state_dict for Sequential:”

RuntimeError: Error(s) in loading state_dict for Sequential:

This error usually appears when the model was trained with nn.DataParallel.

It means that the parameter keys in the model's state_dict do not match the keys in the dictionary returned by torch.load.
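To see the mismatch concretely, you can compare the two key sets yourself. The sketch below is a minimal, library-free illustration. The helper name `diff_keys` and the sample keys (those of a small hypothetical nn.Sequential with Linear layers at indices 0 and 2) are assumptions for illustration. With a real model you would pass `model.state_dict().keys()` and `torch.load('xxx.pth').keys()` instead.

```python
from typing import Iterable, List, Tuple


def diff_keys(model_keys: Iterable[str], checkpoint_keys: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) keys, similar in spirit to load_state_dict's error report.

    missing:    keys the model expects but the checkpoint does not contain
    unexpected: keys in the checkpoint that the model does not know about
    """
    model_set = set(model_keys)
    ckpt_set = set(checkpoint_keys)
    missing = sorted(model_set - ckpt_set)
    unexpected = sorted(ckpt_set - model_set)
    return missing, unexpected


# Hypothetical keys of a small nn.Sequential, and the same keys as saved from a
# DataParallel-wrapped copy of it (every key gains the 'module.' prefix).
model_keys = ["0.weight", "0.bias", "2.weight", "2.bias"]
checkpoint_keys = ["module." + k for k in model_keys]

missing, unexpected = diff_keys(model_keys, checkpoint_keys)
print("Missing:", missing)
print("Unexpected:", unexpected)
```

If every "unexpected" key is a "missing" key with `module.` in front, the checkpoint came from a DataParallel-wrapped model and only the key names need fixing.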

Therefore, we only need to modify the dict returned by torch.load so that its keys match the model's.

 

Example

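When I saved the state_dict of a model trained with nn.DataParallel using torch.save, every parameter key was automatically prefixed with 'module.', because nn.DataParallel keeps the original network in its .module attribute.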

Therefore, after torch.load, we need to remove 'module.' from the beginning of each key.

The method is as follows.

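import torch

model = Model()  # your network class, the same one used during training
model_para_dict_temp = torch.load('xxx.pth')
model_para_dict = {}
for key_i, value in model_para_dict_temp.items():
    # Delete the 'module.' prefix added by nn.DataParallel, but only if it is present
    if key_i.startswith('module.'):
        key_i = key_i[len('module.'):]
    model_para_dict[key_i] = value
del model_para_dict_temp
model.load_state_dict(model_para_dict)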
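If you load such checkpoints often, the loop above can be wrapped in a small reusable function. This is a minimal sketch, not part of PyTorch: the name `strip_prefix` is hypothetical, and plain integers stand in for the tensors a real checkpoint would contain. Keys that lack the prefix are left untouched, so it is also safe on checkpoints saved without nn.DataParallel.

```python
from typing import Any, Dict, Mapping


def strip_prefix(state_dict: Mapping[str, Any], prefix: str = "module.") -> Dict[str, Any]:
    """Return a new dict with `prefix` removed from every key that starts with it.

    Key order is preserved (Python dicts keep insertion order), and the input
    mapping is not modified.
    """
    cleaned: Dict[str, Any] = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value
    return cleaned


# Plain integers stand in for tensors; a real checkpoint maps names to torch.Tensor objects.
checkpoint = {"module.0.weight": 1, "module.0.bias": 2, "head.weight": 3}
cleaned = strip_prefix(checkpoint)
print(list(cleaned.keys()))
```

With a real model this would be used as `model.load_state_dict(strip_prefix(torch.load('xxx.pth')))`. Recent PyTorch releases also include `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, prefix)`, which performs a similar rewrite in place; check that it is available in your installed version before relying on it.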
