[Solved] RuntimeError: 1only batches of spatial targets supported (non-empty 3D tensors) but got targets of s

The following error is reported when training UNet 3+ on a multi-class task:

RuntimeError: 1only batches of spatial targets supported (non-empty 3D tensors) but got targets of size xxx

After consulting many references and a good deal of trial and error, I arrived at the following solution:

The UNet 3+ source code has a small bug for multi-class tasks, which is patched here with a few small changes. (The UNet 3+ code comes from GitHub: avBuffer/UNet3plus_pth, which provides UNet3+/UNet++/UNet in PyTorch.)

1. Modify train.py

# line 56
if net.n_classes > 1:
    criterion = nn.CrossEntropyLoss()
else:
    criterion = nn.BCEWithLogitsLoss()

# line 79
# squeeze the label from [batch_size, 1, height, width] to [batch_size, height, width]
loss = criterion(masks_pred, torch.squeeze(true_masks))

# line 153
net = UNet(n_channels=3, n_classes=3)

torch.squeeze() is needed because, with nn.CrossEntropyLoss as the loss function, the network output output = net(input) has shape [batch_size, n_classes, height, width], while the label must have shape [batch_size, height, width]. The label is a single-channel grayscale map, so its channel dimension of size 1 has to be removed before the loss is computed.

Both BCELoss and cross-entropy loss are used for classification problems. BCELoss is the special case of cross-entropy loss for binary classification, whereas cross-entropy loss can be used for both binary and multi-class classification.
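The shape mismatch can be reproduced in isolation. The following is a minimal sketch, assuming random tensors in place of the real network output and masks; the sizes (a batch of 2, 3 classes, 4x4 images) are made up for illustration and do not come from the repository.

```python
import torch
import torch.nn as nn

batch_size, n_classes, height, width = 2, 3, 4, 4  # illustrative sizes only

# Stand-in for masks_pred = net(imgs): raw scores, [batch, n_classes, H, W]
masks_pred = torch.randn(batch_size, n_classes, height, width)

# Stand-in for a single-channel label map as loaded: [batch, 1, H, W]
true_masks = torch.randint(0, n_classes, (batch_size, 1, height, width))

criterion = nn.CrossEntropyLoss()

# Passing the 4D integer target directly is rejected by CrossEntropyLoss
try:
    criterion(masks_pred, true_masks)
    raised = False
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)

# Dropping the channel dimension gives the expected [batch, H, W] target.
# squeeze(1) removes only dim 1, so a batch of size 1 keeps its batch dim.
target = true_masks.squeeze(1)
loss = criterion(masks_pred, target)
print(tuple(target.shape), loss.item())
```

Note that torch.squeeze(true_masks) without a dim argument removes every dimension of size 1, including the batch dimension when the batch size is 1, so squeeze(1) may be the safer variant in that case.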

Example of computing the cross-entropy loss with nn.CrossEntropyLoss():
import torch
import torch.nn as nn

# output is the raw output of the network, size=[batch_size, n_classes]
# With a batch size of 128 and 10 classes, size=[128, 10]
output = torch.randn(128, 10)

# target holds the true label of each sample as an integer class index, size=[batch_size]
# With a batch size of 128, size=[128]
target = torch.randint(0, 10, (128,))

criterion = nn.CrossEntropyLoss()
crossentropyloss_output = criterion(output, target)
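For intuition about what the loss value means, the per-sample computation can be written out by hand. This is a sketch in plain Python, not PyTorch's implementation; the function names and the example scores are made up for illustration. It applies a log-softmax to the raw scores and takes the negative log-probability of the true class, which is the quantity nn.CrossEntropyLoss averages over the batch by default.

```python
import math

def cross_entropy_single(logits, target):
    """Cross-entropy of one sample: -log(softmax(logits)[target])."""
    # Subtract the max score first for numerical stability
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]

def cross_entropy_batch(batch_logits, targets):
    """Mean of the per-sample losses (the default 'mean' reduction)."""
    losses = [cross_entropy_single(l, t) for l, t in zip(batch_logits, targets)]
    return sum(losses) / len(losses)

# Hypothetical scores for 2 samples and 3 classes
batch_logits = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]
targets = [0, 2]  # one integer class index per sample
loss = cross_entropy_batch(batch_logits, targets)
print(loss)
```

Because the target is an index into the class scores, it has one dimension fewer than the network output. This is why a segmentation label must be [batch_size, height, width] when the output is [batch_size, n_classes, height, width].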

2. Modify predict.py

os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
if unet_type == 'v2':
    net = UNet2Plus(n_channels=3, n_classes=1)
elif unet_type == 'v3':
    # line 93: set n_classes to the number of classes in your dataset
    net = UNet3Plus(n_channels=3, n_classes=20)
    #net = UNet3Plus_DeepSup(n_channels=3, n_classes=1)
    #net = UNet3Plus_DeepSup_CGM(n_channels=3, n_classes=1)
else:
    net = UNet(n_channels=3, n_classes=1)

3. Modify eval.py

for true_mask, pred in zip(true_masks, mask_pred):
    pred = (pred > 0.5).float()
    if net.n_classes > 1:
        # line 26
        tot += F.cross_entropy(pred.unsqueeze(dim=0), true_mask).item()
        # tot += F.cross_entropy(pred.unsqueeze(dim=0), true_mask.unsqueeze(dim=0)).item()
    else:
        tot += dice_coeff(pred, true_mask.squeeze(dim=1)).item()
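The reason for dropping unsqueeze on true_mask becomes clearer when the shapes inside the loop are traced. The sketch below assumes the validation masks arrive as [batch_size, 1, height, width], consistent with the squeeze needed in train.py; the tensors are random stand-ins, and the thresholding step is left out to keep the focus on shapes.

```python
import torch
import torch.nn.functional as F

n_classes, height, width = 3, 4, 4  # illustrative sizes only

# Stand-ins for one validation batch of 2 images
mask_pred = torch.randn(2, n_classes, height, width)             # [B, C, H, W]
true_masks = torch.randint(0, n_classes, (2, 1, height, width))  # [B, 1, H, W]

tot = 0.0
for true_mask, pred in zip(true_masks, mask_pred):
    # Iterating over the batch drops dim 0:
    # pred is [C, H, W], true_mask is [1, H, W]
    logits = pred.unsqueeze(dim=0)  # [1, C, H, W]
    # true_mask's leading 1 already serves as the batch dimension,
    # so it is passed as-is; unsqueezing it would make it 4D again.
    tot += F.cross_entropy(logits, true_mask).item()

print(tot / len(true_masks))
```

Under this assumption, the commented-out variant with true_mask.unsqueeze(dim=0) would produce a [1, 1, H, W] target and trigger the same RuntimeError as in training.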

Then run the following command to start training:

python train.py -g 0 -u v3 -e 400 -b 2 -l 0.1 -s 0.5 -v 15.0
