[Solved] Object Detection with torch: RuntimeError: CUDA error: device-side assert triggered

When training torchvision's Mask R-CNN on a custom dataset, the following error is reported:

Traceback (most recent call last):
  File "main_train_detection.py", line 232, in <module>
    main(params)
  File "main_train_detection.py", line 201, in main
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
  File "/raid/huaqing/tyler/suzhou/code/utils/engine.py", line 37, in train_one_epoch
    loss_dict = model(images, targets)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torchvision/models/detection/generalized_rcnn.py", line 97, in forward
    detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torchvision/models/detection/roi_heads.py", line 760, in forward
    loss_classifier, loss_box_reg = fastrcnn_loss(
  File "/usr/local/lib/python3.8/dist-packages/torchvision/models/detection/roi_heads.py", line 40, in fastrcnn_loss
    sampled_pos_inds_subset = torch.where(labels > 0)[0]
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
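As the message suggests, making CUDA kernel launches synchronous moves the error to the operation that actually failed, so the traceback points at the right line. A minimal sketch, assuming the variable is set at the very top of the training script (here main_train_detection.py), before torch initializes the GPU:

```python
import os

# Make every CUDA kernel launch synchronous so that a device-side assert is
# reported at the operation that triggered it. This only takes effect if it
# is set before CUDA is initialized, so do it before importing torch.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# import torch  # import torch (and everything that uses it) only after this line
```

Equivalently, set it from the shell: CUDA_LAUNCH_BLOCKING=1 python main_train_detection.py. Synchronous launches slow training down, so this is meant for debugging only.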

The root cause is that the class labels do not fit the configured number of classes. There are three kinds of target to detect, so the total number of classes was set to 3, and the class labels were mapped to categories as follows:

cls_dict = {'holes':1, 'marker':2, 'band':3}

The classifier indexes labels from 0: with 3 classes, the valid labels are 0, 1, and 2, and there is no label 3. With the cls_dict above, every 'band' box therefore carries an out-of-range label, which triggers the device-side assert. One way to remove the assert is to number the labels from 0:

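cls_dict = {'holes':0, 'marker':1, 'band':2}

Note, however, that torchvision's detection models reserve label 0 for the background class; this is also why fastrcnn_loss in the traceback selects positive samples with labels > 0. With the mapping above, 'holes' boxes would be treated as background and never learned. To keep all three classes trainable, keep the original labels {'holes':1, 'marker':2, 'band':3} and set num_classes to the number of object classes plus one for the background, i.e. num_classes=4.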
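Checking the mapping on the CPU before training avoids hitting this as an opaque CUDA assert. The sketch below uses a hypothetical helper, check_label_mapping (not part of torchvision), and assumes, as in torchvision's detection models, that label 0 is reserved for background:

```python
def check_label_mapping(cls_dict, num_classes, reserve_background=True):
    """Return a list of problems found in a class-name -> label mapping.

    Labels must lie in 0..num_classes-1. If reserve_background is True
    (the torchvision detection convention), label 0 must not be used by
    any object class.
    """
    problems = []
    for name, label in cls_dict.items():
        if not 0 <= label < num_classes:
            problems.append(f"{name!r}: label {label} is outside 0..{num_classes - 1}")
        elif reserve_background and label == 0:
            problems.append(f"{name!r}: label 0 is reserved for background")
    return problems


# The original mapping with num_classes=3: 'band' overflows.
print(check_label_mapping({'holes': 1, 'marker': 2, 'band': 3}, 3))
# → ["'band': label 3 is outside 0..2"]

# Labels from 0 with num_classes=3: no overflow, but 'holes' collides with background.
print(check_label_mapping({'holes': 0, 'marker': 1, 'band': 2}, 3))
# → ["'holes': label 0 is reserved for background"]

# Labels from 1 with num_classes=4 (3 classes + background): no problems.
print(check_label_mapping({'holes': 1, 'marker': 2, 'band': 3}, 4))
# → []
```

Calling such a check once when the dataset is built is cheap and turns the GPU-side assert into a readable message.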
