How to handle variable-length data with the PyTorch DataLoader

Time:2021-6-11

The official PyTorch documentation already explains how to write a custom Dataset class and how to load data iteratively with a DataLoader, so I won't repeat that here.

Now the question is: sometimes, especially in NLP tasks, the input data is not of fixed length. For example, different sentences rarely contain the same number of tokens. In that case the DataLoader's default collation, which tries to stack each field of the batch into a single tensor, cannot handle the variable-length sequences and raises an error.
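To see the failure concretely, here is a minimal sketch using `torch.utils.data.default_collate` (the batching function the DataLoader uses when no custom `collate_fn` is given; importable from `torch.utils.data` in recent PyTorch versions) on two samples with different sequence lengths:

```python
import torch
from torch.utils.data import default_collate

# Two samples whose 'token_list' tensors have different lengths
batch = [
    {'token_list': torch.tensor([5, 2, 4, 1, 9, 8]), 'label': 5},
    {'token_list': torch.tensor([3, 7]), 'label': 2},
]

try:
    default_collate(batch)
except RuntimeError as err:
    # torch.stack expects every tensor to have the same size,
    # so stacking sequences of unequal length fails
    print('default collation failed:', err)
```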

The solution is to override the DataLoader's collate_fn. The specific approach is as follows:

# Suppose each sample looks like this:
sample = {
    # the ID of each word in the sentence
    'token_list': [5, 2, 4, 1, 9, 8],
    # the target label y
    'label': 5,
}


# Override collate_fn; its input is the list of samples in one batch
import torch

def collate_fn(batch):
    # token_list is variable-length, so collect the batch's token lists
    # into a plain Python list instead of stacking them into a tensor
    token_lists = [item['token_list'] for item in batch]

    # each label is an int; gather all labels in the batch
    labels = [item['label'] for item in batch]
    # convert the labels to a tensor
    labels = torch.Tensor(labels)
    return {
        'token_list': token_lists,
        'label': labels,
    }


# When creating the DataLoader, pass the overridden function via the collate_fn parameter
DataLoader(trainset, batch_size=4, shuffle=True, num_workers=4, collate_fn=collate_fn)
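Putting the pieces together, here is a minimal end-to-end sketch (the ToyDataset class and its samples are made up for illustration; num_workers is set to 0 and shuffle to False just to keep the demo deterministic):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """A tiny dataset of dict samples with variable-length token lists."""
    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

def collate_fn(batch):
    # keep variable-length token lists as a plain list,
    # convert the fixed-size labels to a tensor
    return {
        'token_list': [item['token_list'] for item in batch],
        'label': torch.Tensor([item['label'] for item in batch]),
    }

trainset = ToyDataset([
    {'token_list': [5, 2, 4, 1, 9, 8], 'label': 5},
    {'token_list': [3, 7], 'label': 2},
    {'token_list': [1, 1, 1], 'label': 0},
])
loader = DataLoader(trainset, batch_size=2, shuffle=False,
                    num_workers=0, collate_fn=collate_fn)

first = next(iter(loader))
print(first['token_list'])  # [[5, 2, 4, 1, 9, 8], [3, 7]]
print(first['label'])       # tensor([5., 2.])
```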

With this approach, the DataLoader still loads one batch of data at a time, but each batch is now exactly the dictionary returned by the overridden collate_fn.
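If a downstream model needs a single padded tensor rather than a list of lists, a common follow-up (not part of the method above, but a natural extension) is to pad inside collate_fn with torch.nn.utils.rnn.pad_sequence. A minimal sketch:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def collate_fn_padded(batch):
    # convert each variable-length token list to a tensor
    token_tensors = [torch.tensor(item['token_list']) for item in batch]
    # pad shorter sequences with 0 so they stack into one (batch, max_len) tensor
    token_batch = pad_sequence(token_tensors, batch_first=True, padding_value=0)
    # keep the true lengths so the model can mask out the padding
    lengths = torch.tensor([len(item['token_list']) for item in batch])
    labels = torch.Tensor([item['label'] for item in batch])
    return {'token_list': token_batch, 'length': lengths, 'label': labels}

batch = [
    {'token_list': [5, 2, 4, 1, 9, 8], 'label': 5},
    {'token_list': [3, 7], 'label': 2},
]
out = collate_fn_padded(batch)
print(out['token_list'].shape)      # torch.Size([2, 6])
print(out['token_list'][1])         # tensor([3, 7, 0, 0, 0, 0])
```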

That is the whole method for handling variable-length data with the PyTorch DataLoader. I hope it serves as a useful reference.