RuntimeError: each element in list of batch should be of equal size
After defining my own Dataset class and returning the data I needed from it, I ran into the following error:

RuntimeError: each element in list of batch should be of equal size

A Baidu search suggested that the most direct fix is to set batch_size to 1, which does make the error go away. But my goal is to train a model, not just to silence an error, and training with batch_size set to 1 is hardly practical, so I decided to study the error itself.

Original Traceback (most recent call last):
  File "/home/cv/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 202, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/cv/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/home/cv/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 83, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/home/cv/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 83, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/home/cv/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 83, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/home/cv/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 83, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/home/cv/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 81, in default_collate
    raise RuntimeError('each element in list of batch should be of equal size')

The traceback shows where the error comes from: it is raised inside the default_collate() function in collate.py. A Baidu search showed that default_collate is the default batching method of the DataLoader class. If no function is passed through the collate_fn parameter when the DataLoader is created, the function in the source code below is called by default. If you get the error above, it is raised on the fourth line from the end of this function.

def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, container_abcs.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))

This function receives one batch as a list in which each element is whatever the __getitem__() method of your Dataset class returns for one sample. The length of the list is the batch_size you set. In each batch that the DataLoader yields, however, a given field is the corresponding field of all batch_size samples joined together. So when this function is called by default, the first call reaches the penultimate line, return [default_collate(samples) for samples in transposed], after zip(*batch) has regrouped the batch by field. Each group of same-field values is then passed recursively into default_collate(), which looks at the first element to decide which branch applies. If that element is one of the types listed above, the field is collated and the dataset content is returned correctly.
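As a rough illustration of the regrouping step, the sketch below applies zip(*batch) to a made-up batch of (image, label) samples. The sample values are placeholders, and plain Python lists stand in for image tensors.

```python
# A made-up batch of 3 samples, each shaped like a __getitem__ return value:
# (image, label). Nested lists stand in for image tensors here.
batch = [
    ([0.1, 0.2], 0),
    ([0.3, 0.4], 1),
    ([0.5, 0.6], 0),
]

# zip(*batch) regroups the batch by field instead of by sample.
transposed = list(zip(*batch))

images, labels = transposed
print(images)  # ([0.1, 0.2], [0.3, 0.4], [0.5, 0.6])
print(labels)  # (0, 1, 0)
```

Each of these groups is what the next, recursive call to default_collate() receives as its batch.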
If the batch is processed in the order described above, the error does not occur. If, on the second recursion, the first element is still not one of the listed types (for example, it is itself a list or tuple), the function recurses again, a third time. At that point, even if the data is returned without an error, the result is not what we want, and the error is generally raised at this third level. So to fix this error, check carefully the data types of the fields returned by the Dataset class you defined. You can also print the batch inside default_collate() before and after processing and follow the function's flow, which helps you find the returned field whose data type is wrong.
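To follow that flow without touching the installed library, something like the sketch below may help. It is a simplified pure-Python stand-in for the Sequence branch only, not the real default_collate: the name trace_collate, the depth counter, and the sample data are all hypothetical, and leaves are returned as lists rather than tensors.

```python
from collections.abc import Sequence


def trace_collate(batch, depth=1):
    """Simplified stand-in for the Sequence branch of default_collate.

    Prints what each recursion level receives; leaves are returned as
    plain lists instead of being converted to tensors.
    """
    elem = batch[0]
    print(f"level {depth}: {type(elem).__name__} elements, batch = {batch!r}")
    if isinstance(elem, Sequence) and not isinstance(elem, str):
        sizes = [len(e) for e in batch]
        if len(set(sizes)) != 1:
            raise RuntimeError(
                f"level {depth}: each element in list of batch should be "
                f"of equal size, got lengths {sizes}"
            )
        # Regroup by field and recurse, as the real function does.
        return [trace_collate(list(samples), depth + 1) for samples in zip(*batch)]
    return list(batch)  # numbers, strings, etc.: nothing left to unpack


# Hypothetical samples of the form (features, (boxes, labels)), where the
# number of boxes differs from sample to sample.
ragged_batch = [
    ([1.0, 2.0], ([[0, 0, 5, 5]], [1])),
    ([3.0, 4.0], ([[1, 1, 4, 4], [2, 2, 6, 6]], [1, 2])),
]

try:
    trace_collate(ragged_batch)
except RuntimeError as err:
    print("RuntimeError:", err)
```

With this made-up data the check fails at level 3, when the per-sample box lists of different lengths are compared. This mirrors the three nested default_collate frames in the traceback above.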
Tip: don't modify default_collate() in the installed source file. Instead, copy the code, define your own collate_fn function, and pass it through the collate_fn parameter when you instantiate the DataLoader class.
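A minimal sketch of such a function is shown below, assuming each sample is a (sequence, label) pair whose sequence length varies. The name pad_collate, the pad_value parameter, and the sample data are hypothetical. The code deliberately stays in plain Python so it runs without PyTorch, which means it returns lists rather than tensors.

```python
def pad_collate(batch, pad_value=0):
    """Hypothetical collate_fn for samples of the form (sequence, label).

    Sequences of different lengths are right-padded to the longest one in
    the batch, so every row ends up the same size.
    """
    sequences, labels = zip(*batch)  # regroup by field
    lengths = [len(seq) for seq in sequences]
    max_len = max(lengths)
    padded = [list(seq) + [pad_value] * (max_len - len(seq)) for seq in sequences]
    return padded, lengths, list(labels)


# Made-up samples whose first field varies in length.
samples = [([1, 2, 3], 0), ([4], 1), ([5, 6], 0)]
padded, lengths, labels = pad_collate(samples)
print(padded)   # [[1, 2, 3], [4, 0, 0], [5, 6, 0]]
print(lengths)  # [3, 1, 2]
print(labels)   # [0, 1, 0]
```

For real training you would probably wrap the returned lists with torch.tensor(...) inside the function, then pass it in with DataLoader(dataset, batch_size=4, collate_fn=pad_collate). Whether padding is the right choice depends on your data. For some tasks it may be better to return the variable-length field as a plain list.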
I hope you can fix the bug quickly and get your model running!
