[Solved] RuntimeError during DCGAN training: Found dtype Long but expected Float

When training a DCGAN, the following error occurs:

RuntimeError: Found dtype Long but expected Float

The code that raises this error is as follows:

# real_label is an int here (e.g. 1), so torch.full infers dtype torch.int64 (Long)
label = torch.full((b_size,), real_label, device=device)
# Forward the batch of real samples through the discriminator netD and flatten the result into output
output = netD(real_cpu).view(-1)

# Calculate the loss (fails: label is Long, but the criterion expects Float)
errD_real = criterion(output, label)

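The cause is that the tensors passed to the loss function do not have the dtype it requires. The criterion (nn.BCELoss in the standard PyTorch DCGAN tutorial) expects float tensors for both the network output and the target, but label is a Long (torch.int64) tensor: in recent PyTorch versions, torch.full infers its dtype from the fill value, so an integer real_label such as 1 produces an integer tensor. The discriminator output, on the other hand, is already float.
Therefore, the data passed into the loss function needs to be converted to float.
The modified code is as follows: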

label = torch.full((b_size,), real_label, device=device)
# Forward the batch of real samples through the discriminator netD and flatten the result into output
output = netD(real_cpu).view(-1)
# Convert the data passed to the loss function to float32
# (netD's output is normally float32 already; casting the label is what fixes the error)
output = output.to(torch.float32)
label = label.to(torch.float32)
# Calculate the loss
errD_real = criterion(output, label)
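To check the explanation above outside the training loop, here is a minimal, self-contained sketch. It assumes nn.BCELoss as the criterion and replaces netD(real_cpu).view(-1) with a sigmoid over random numbers; b_size, device and the random output are illustrative stand-ins, not values from the original training script. It shows the dtype torch.full infers for an integer fill value, and two ways to get a float label: casting afterwards (as above) or passing dtype=torch.float to torch.full when the label is created.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the training-loop variables
b_size = 4
device = torch.device("cpu")
criterion = nn.BCELoss()  # assumption: the usual DCGAN criterion

# With an integer fill value, recent PyTorch versions infer an integer dtype
real_label = 1
label_long = torch.full((b_size,), real_label, device=device)
print(label_long.dtype)  # torch.int64 on recent PyTorch versions

# Stand-in for netD(real_cpu).view(-1): sigmoid outputs in (0, 1), already float32
output = torch.sigmoid(torch.randn(b_size, device=device))
print(output.dtype)  # torch.float32

# Fix 1 (the approach above): cast the label before computing the loss
errD_real = criterion(output, label_long.to(torch.float32))

# Fix 2: create the label as float from the start
label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
errD_real_2 = criterion(output, label)
print(label.dtype)  # torch.float32
```

Passing dtype=torch.float when creating the label (or defining real_label as the float 1.) avoids an extra cast on every iteration; both fixes produce the same loss value.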

Problem solved!
