Fix gather when collecting 'num_input_tokens_seen' (#31974)

* Move token count to device before gathering

* Run 'make style; make quality'
This commit is contained in:
Alexander Wettig 2024-07-16 20:35:10 +02:00 committed by GitHub
parent c22efa6196
commit e391706420
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2245,12 +2245,17 @@ class Trainer:
"a `main_input_name` attribute to the model class you are using." "a `main_input_name` attribute to the model class you are using."
) )
else: else:
input_device = inputs[main_input_name].device self.state.num_input_tokens_seen += (
self.state.num_input_tokens_seen += torch.sum( torch.sum(
self.accelerator.gather( self.accelerator.gather(
torch.tensor(inputs[main_input_name].numel(), device=input_device, dtype=torch.int64) torch.tensor(
inputs[main_input_name].numel(), device=self.args.device, dtype=torch.int64
)
)
) )
).item() .cpu()
.item()
)
if rng_to_sync: if rng_to_sync:
self._load_rng_state(resume_from_checkpoint) self._load_rng_state(resume_from_checkpoint)
rng_to_sync = False rng_to_sync = False