When training a DCGAN, the following error occurs:
RuntimeError: Found dtype Long but expected Float
The code that triggers the error is as follows:
label = torch.full((b_size,), real_label, device=device)
# Forward the batch of real samples through the discriminator and store the flattened result in output
output = netD(real_cpu).view(-1)
# Calculate the loss
errD_real = criterion(output, label)
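The cause is a dtype mismatch in the tensors passed to the loss function: the loss expects float tensors, but a Long (int64) tensor is passed in. The Long tensor is typically label: if real_label is an integer (for example real_label = 1), torch.full infers the dtype from the fill value and creates an int64 tensor, while the discriminator's output is normally already float.
Therefore, the tensors passed to the loss function must be converted to float.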
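The sketch below shows where the Long dtype comes from. It uses hypothetical stand-ins for the training-loop variables: b_size and real_label are made up, a sigmoid of random numbers replaces the discriminator output, and criterion is assumed to be nn.BCELoss, as in the standard PyTorch DCGAN tutorial.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the training-loop variables; real_label is
# assumed to be an integer, as in many DCGAN examples.
b_size = 4
real_label = 1

# torch.full infers its dtype from the fill value: an int yields int64 (Long).
label = torch.full((b_size,), real_label)

# Simulated discriminator output: sigmoid probabilities, already float32.
output = torch.sigmoid(torch.randn(b_size))

print(label.dtype, output.dtype)  # torch.int64 torch.float32

# Casting the target to float32 gives BCELoss the dtype it expects.
criterion = nn.BCELoss()
loss = criterion(output, label.to(torch.float32))
print(loss.item())
```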
The modified code is as follows:
label = torch.full((b_size,), real_label, device=device)
# Forward the batch of real samples through the discriminator and store the flattened result in output
output = netD(real_cpu).view(-1)
# Convert both tensors to float32, the dtype the loss function expects
output = output.to(torch.float32)
label = label.to(torch.float32)
# Calculate the loss
errD_real = criterion(output, label)
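An alternative, and the approach used in the official PyTorch DCGAN tutorial, is to request a float tensor when the label is created, so no cast is needed afterwards. This is a minimal sketch, again with hypothetical stand-ins for b_size, real_label and the discriminator output, and with criterion assumed to be nn.BCELoss:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the training-loop variables.
device = torch.device("cpu")
b_size = 4
real_label = 1  # assumed integer, which is what would otherwise produce a Long tensor

# Ask torch.full for a float tensor up front instead of casting later.
label = torch.full((b_size,), real_label, dtype=torch.float, device=device)

# Simulated discriminator output (float32 probabilities).
output = torch.sigmoid(torch.randn(b_size, device=device))

criterion = nn.BCELoss()
errD_real = criterion(output, label)
print(label.dtype, errD_real.item())
```

Fixing the dtype at creation avoids an extra cast on every iteration. Casting output is harmless but unnecessary when the discriminator already returns float32.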
Problem solved!