RuntimeError: Error(s) in loading state_dict for Sequential:
This error is usually related to training with nn.DataParallel.
It means that the parameter keys in the model's state_dict do not match the keys in the checkpoint loaded by torch.load.
Therefore, we just need to modify the dict returned by torch.load so that its keys match the model's.
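To see the mismatch concretely, you can print the keys on both sides. The following is a minimal sketch with a toy model (in a real case the checkpoint keys would come from your own training run):

import torch
import torch.nn as nn

# Wrapping a model in nn.DataParallel is what introduces the
# 'module.' prefix into the state_dict keys.
model = nn.Sequential(nn.Linear(4, 2))
wrapped = nn.DataParallel(model)

print(list(model.state_dict().keys()))    # ['0.weight', '0.bias']
print(list(wrapped.state_dict().keys()))  # ['module.0.weight', 'module.0.bias']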
Example:
When a model trained with nn.DataParallel is saved via torch.save, every key in its state_dict is automatically prefixed with 'module.'.
Therefore, after torch.load, we need to remove the 'module.' prefix from each key.
The method is as follows.
import torch

model = Model()  # Model is your own model class
model_para_dict_temp = torch.load('xxx.pth')

# Strip the leading 'module.' (7 characters) from every key
model_para_dict = {}
for key_i in model_para_dict_temp.keys():
    model_para_dict[key_i[7:]] = model_para_dict_temp[key_i]  # Delete 'module.'
del model_para_dict_temp

model.load_state_dict(model_para_dict)
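On Python 3.9+, key_i.removeprefix('module.') is a safer alternative to key_i[7:], since it leaves keys that lack the prefix untouched. You can also avoid the prefix entirely at save time: nn.DataParallel stores the original model in its .module attribute, so saving that inner model's state_dict produces clean keys. A sketch, reusing the placeholder path 'xxx.pth':

# At save time, save the inner model rather than the wrapper.
wrapped = nn.DataParallel(model)
# ... train wrapped ...
torch.save(wrapped.module.state_dict(), 'xxx.pth')

# Later, a plain (unwrapped) model can load it directly, with no key surgery.
model.load_state_dict(torch.load('xxx.pth'))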