diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index 14741d862ae..9f30335f86b 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -149,7 +149,7 @@ class DataCollatorForLanguageModeling: ) -> torch.Tensor: # In order to accept both lists of lists and lists of Tensors if isinstance(examples[0], (list, tuple)): - examples = [torch.Tensor(e) for e in examples] + examples = [torch.tensor(e, dtype=torch.long) for e in examples] length_of_first = examples[0].size(0) are_tensors_same_length = all(x.size(0) == length_of_first for x in examples) if are_tensors_same_length: