How does pytorch DataLoader gather data from Dataset into batches?
Dataset and DataLoader are the basic methods shipped with pytorch for preparing and feeding data when training models. The official docs do a great job of showing how these two interact to provide an easier, cleaner way to feed data.
But even after following this great tutorial, I still wasn't sure how exactly DataLoader gathers the data returned by Dataset into a batch.
The Dataset doesn't restrict how the data should be returned; it can return one object or multiple objects. So how does the DataLoader know how to bundle the returned object or objects?
This is done by the default collate function in DataLoader, and it turns out the default collate function is written well enough to handle whatever the Dataset throws at it.
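To see this in action, here is a minimal sketch (the ToyDataset class and its field names are made up for illustration) of a Dataset that returns a dict per sample; a DataLoader left with its default collate function bundles those dicts into a single dict of batched tensors.

import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """Hypothetical dataset: each sample is a dict with a feature tensor and an int label."""
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return {"x": torch.randn(3), "y": idx % 2}

loader = DataLoader(ToyDataset(), batch_size=4)  # no collate_fn given
batch = next(iter(loader))
print(batch["x"].shape, batch["y"])  # torch.Size([4, 3]) tensor([0, 1, 0, 1])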
The following are the parts of the pytorch source code related to this topic.
if collate_fn is None:
    if self._auto_collation:
        collate_fn = _utils.collate.default_collate
If the DataLoader has not been given a custom collate function, it will use the default one.
def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]
This is the core of the default collate function. We can see that it packs the given batch of data into torch datatypes. Even if each individual sample in the batch is a Python list, tuple, or dictionary, the function will recursively call itself on the elements until it packs the final elements into tensors.
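You can watch this recursion yourself by calling the function directly on a handful of samples. Depending on your pytorch version it is exposed as torch.utils.data.default_collate or lives in torch.utils.data._utils.collate; the sample structure below is just an illustration.

import torch
from torch.utils.data import default_collate  # older versions: torch.utils.data._utils.collate.default_collate

samples = [
    {"img": torch.zeros(2, 2), "meta": (0, 1.5)},
    {"img": torch.ones(2, 2), "meta": (1, 2.5)},
]
batch = default_collate(samples)
# The dict and tuple structure is preserved, while the leaves are stacked or converted:
print(batch["img"].shape)  # torch.Size([2, 2, 2])
print(batch["meta"][0])    # tensor([0, 1])
print(batch["meta"][1])    # tensor([1.5000, 2.5000], dtype=torch.float64)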
This recursive behavior is useful, but it forces each sample in the batch to have equal length, since the default collate function will eventually apply torch.stack to the leaves. If you want to avoid this behavior for some specific items in your batch data, you have to write a custom collate function, like the sketch below.
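As an example of what that could look like, here is a hypothetical pad_collate function that pads variable-length sequences to the longest one in the batch, assuming each sample is a (sequence, label) pair; padding with torch.nn.utils.rnn.pad_sequence is just one possible strategy.

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def pad_collate(batch):
    # batch: list of (sequence, label) pairs where the sequences have different lengths
    seqs, labels = zip(*batch)
    lengths = torch.tensor([len(s) for s in seqs])
    padded = pad_sequence(seqs, batch_first=True)  # pad every sequence to the longest in the batch
    return padded, lengths, torch.tensor(labels)

# toy data: variable-length 1-D float sequences with integer labels
data = [(torch.randn(n), n % 2) for n in (3, 5, 2, 4)]
loader = DataLoader(data, batch_size=4, collate_fn=pad_collate)
padded, lengths, labels = next(iter(loader))
print(padded.shape, lengths, labels)  # torch.Size([4, 5]) tensor([3, 5, 2, 4]) tensor([1, 1, 0, 0])

Here a plain Python list stands in for a Dataset, since DataLoader only needs something that supports len() and integer indexing.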
Seeing how robustly the default collate function is written, I don't think I will often face a situation where I need to give a DataLoader a custom collate function. But when I do, I will know how it should be done.