Add check for if num_items_in_batch is not None (#35102)

This commit is contained in:
Zach Mueller 2025-01-06 10:11:21 -05:00 committed by GitHub
parent 203e978826
commit a821b9c7ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -5162,7 +5162,7 @@ class Trainer:
except (TypeError, AttributeError):
pass
if self.args.average_tokens_across_devices:
if self.args.average_tokens_across_devices and num_items_in_batch is not None:
num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item()
if torch.is_tensor(num_items_in_batch):