fix to ensure that returned tensors after tokenization are Long (#7039)

* fix to ensure that returned tensors after tokenization are Long

Co-authored-by: Ashwin Geet Dsa <adsa@grvingt-6.nancy.grid5000.fr>
Ashwin Geet Dsa 2020-09-10 17:04:03 +02:00 committed by GitHub
parent 9ccdb1d517
commit 66a5a6fda8

@@ -149,7 +149,7 @@ class DataCollatorForLanguageModeling:
     ) -> torch.Tensor:
         # In order to accept both lists of lists and lists of Tensors
         if isinstance(examples[0], (list, tuple)):
-            examples = [torch.Tensor(e) for e in examples]
+            examples = [torch.tensor(e, dtype=torch.long) for e in examples]
         length_of_first = examples[0].size(0)
         are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
         if are_tensors_same_length:
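
For context, a minimal standalone sketch of why the explicit dtype matters (assuming only that PyTorch is installed; the token id values below are made up). The legacy torch.Tensor constructor is an alias for the default float tensor type, so it silently casts token ids to float32, whereas ids must be Long (int64) to serve as embedding indices and language-modeling loss targets.

import torch

token_ids = [101, 2054, 2003, 102]  # hypothetical token ids from a tokenizer

as_float = torch.Tensor(token_ids)                   # legacy constructor -> torch.float32
as_long = torch.tensor(token_ids, dtype=torch.long)  # explicit dtype -> torch.int64

print(as_float.dtype)  # torch.float32
print(as_long.dtype)   # torch.int64

# Embedding lookups (and cross-entropy targets) require integer indices,
# so the float tensor would fail downstream while the Long tensor works.
embedding = torch.nn.Embedding(num_embeddings=30522, embedding_dim=8)
hidden = embedding(as_long)  # ok: shape (4, 8)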