Fix DataCollatorForSeq2Seq when labels are supplied as Numpy array instead of list (#13582)

* Fix issue when labels are supplied as Numpy array instead of list

* Fix issue when labels are supplied as Numpy array instead of list
This commit is contained in:
Matt 2021-09-16 15:35:57 +01:00 committed by GitHub
parent 421929b556
commit 5c5937182a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -517,6 +517,8 @@ class DataCollatorForSeq2Seq:
return_tensors: str = "pt"
def __call__(self, features, return_tensors=None):
import numpy as np
if return_tensors is None:
return_tensors = self.return_tensors
labels = [feature["labels"] for feature in features] if "labels" in features[0].keys() else None
@ -527,9 +529,14 @@ class DataCollatorForSeq2Seq:
padding_side = self.tokenizer.padding_side
for feature in features:
remainder = [self.label_pad_token_id] * (max_label_length - len(feature["labels"]))
feature["labels"] = (
feature["labels"] + remainder if padding_side == "right" else remainder + feature["labels"]
)
if isinstance(feature["labels"], list):
feature["labels"] = (
feature["labels"] + remainder if padding_side == "right" else remainder + feature["labels"]
)
elif padding_side == "right":
feature["labels"] = np.concatenate([feature["labels"], remainder]).astype(np.int64)
else:
feature["labels"] = np.concatenate([remainder, feature["labels"]]).astype(np.int64)
features = self.tokenizer.pad(
features,