Fix Seq2SeqTrainer crash when BatchEncoding data is None (#31418)

Avoid a crash in BatchEncoding.to() when a value in self.data is None: skip None entries instead of calling .to(device) on them.
This commit is contained in:
Dingli Yang 2024-07-08 17:51:23 +08:00 committed by GitHub
parent 06fd7972ac
commit c1cda0ee2c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -800,7 +800,7 @@ class BatchEncoding(UserDict):
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
# into a HalfTensor
if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
self.data = {k: v.to(device=device) for k, v in self.data.items()}
self.data = {k: v.to(device=device) for k, v in self.data.items() if v is not None}
else:
logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
return self