PyTorch: raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format

A common cause of this error is a mismatch between how the model was trained and how it is loaded, so first check whether the model was trained on multiple GPUs.

If the model was trained on a single GPU (or on the CPU), load it normally:

model.load_state_dict(torch.load(model_path))

If the model was trained on multiple GPUs, it was wrapped in DataParallel during training:

model = torch.nn.DataParallel(model, device_ids=range(opt.ngpu))
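To see what this wrapping does to the saved keys, here is a minimal sketch using a toy `nn.Linear` model (an assumption for illustration, not the model from this post). DataParallel keeps the wrapped network as a submodule named `module`, so every key in its state_dict gains a `module.` prefix, and loading those keys into an unwrapped model raises the error from the title. On recent PyTorch versions this also runs on a CPU-only machine.

```python
import torch.nn as nn

# Hypothetical toy model standing in for the real network.
model = nn.Linear(4, 2)
print(list(model.state_dict().keys()))  # ['weight', 'bias']

# DataParallel registers the wrapped network as the submodule `module`,
# so every key of its state_dict gains a "module." prefix.
checkpoint = nn.DataParallel(nn.Linear(4, 2)).state_dict()
print(list(checkpoint.keys()))  # ['module.weight', 'module.bias']

# Loading the prefixed keys into the unwrapped model fails: the message starts
# with "Error(s) in loading state_dict for Linear:" and lists the missing
# and unexpected keys.
try:
    model.load_state_dict(checkpoint)
except RuntimeError as err:
    print(err)
```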

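If the checkpoint was saved from such a wrapped model, every key in its state_dict starts with `module.`, so strip that prefix when loading it into an unwrapped model: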

model.load_state_dict({(k[len('module.'):] if k.startswith('module.') else k): v for k, v in torch.load(model_path).items()})
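For reuse, the prefix stripping can be wrapped in a small function. This is a sketch, not part of the original post: `strip_module_prefix` and `load_checkpoint` are hypothetical names, and `map_location="cpu"` is an added assumption so that a checkpoint saved on GPUs can also be loaded on a CPU-only machine. Only a leading `module.` is removed, so keys that merely contain it (for example `encoder.submodule.weight`) are left intact.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of `state_dict` with a leading `prefix` removed from each key."""
    # Only the start of each key is checked, so names that merely contain
    # the prefix (e.g. "encoder.submodule.weight") are left untouched.
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )


def load_checkpoint(model: nn.Module, model_path: str) -> nn.Module:
    # map_location="cpu" is an assumption: it lets a checkpoint saved on GPUs
    # load on any machine. load_state_dict then copies the values into the
    # model's existing parameters, on whatever device the model already uses.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(strip_module_prefix(state_dict))
    return model
```

With a plain model, `load_checkpoint(model, model_path)` takes the place of the direct `load_state_dict` call; since the values are copied into the existing parameters, the model stays on the device it was already on.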

