mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix gather when collecting 'num_input_tokens_seen' (#31974)
* Move token count to device before gathering * Run 'make style; make quality'
This commit is contained in:
parent
c22efa6196
commit
e391706420
@ -2245,12 +2245,17 @@ class Trainer:
|
||||
"a `main_input_name` attribute to the model class you are using."
|
||||
)
|
||||
else:
|
||||
input_device = inputs[main_input_name].device
|
||||
self.state.num_input_tokens_seen += torch.sum(
|
||||
self.accelerator.gather(
|
||||
torch.tensor(inputs[main_input_name].numel(), device=input_device, dtype=torch.int64)
|
||||
self.state.num_input_tokens_seen += (
|
||||
torch.sum(
|
||||
self.accelerator.gather(
|
||||
torch.tensor(
|
||||
inputs[main_input_name].numel(), device=self.args.device, dtype=torch.int64
|
||||
)
|
||||
)
|
||||
)
|
||||
).item()
|
||||
.cpu()
|
||||
.item()
|
||||
)
|
||||
if rng_to_sync:
|
||||
self._load_rng_state(resume_from_checkpoint)
|
||||
rng_to_sync = False
|
||||
|
Loading…
Reference in New Issue
Block a user