mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
Fix gather when collecting 'num_input_tokens_seen' (#31974)
* Move token count to device before gathering * Run 'make style; make quality'
This commit is contained in:
parent
c22efa6196
commit
e391706420
@ -2245,12 +2245,17 @@ class Trainer:
|
|||||||
"a `main_input_name` attribute to the model class you are using."
|
"a `main_input_name` attribute to the model class you are using."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
input_device = inputs[main_input_name].device
|
self.state.num_input_tokens_seen += (
|
||||||
self.state.num_input_tokens_seen += torch.sum(
|
torch.sum(
|
||||||
self.accelerator.gather(
|
self.accelerator.gather(
|
||||||
torch.tensor(inputs[main_input_name].numel(), device=input_device, dtype=torch.int64)
|
torch.tensor(
|
||||||
|
inputs[main_input_name].numel(), device=self.args.device, dtype=torch.int64
|
||||||
|
)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
).item()
|
.cpu()
|
||||||
|
.item()
|
||||||
|
)
|
||||||
if rng_to_sync:
|
if rng_to_sync:
|
||||||
self._load_rng_state(resume_from_checkpoint)
|
self._load_rng_state(resume_from_checkpoint)
|
||||||
rng_to_sync = False
|
rng_to_sync = False
|
||||||
|
Loading…
Reference in New Issue
Block a user