diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 351b3f7ae85..f843d9be759 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2505,7 +2505,7 @@ class Trainer:
             for _ in range(total_updates):
                 update_step += 1
                 num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
-                batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
+                batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, args.device)
                 for i, inputs in enumerate(batch_samples):
                     step += 1
                     do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
@@ -5216,7 +5216,7 @@ class Trainer:
                     self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
                 )
 
-    def get_batch_samples(self, epoch_iterator, num_batches):
+    def get_batch_samples(self, epoch_iterator, num_batches, device):
         batch_samples = []
         num_items_in_batch = None
         for _ in range(num_batches):
@@ -5232,11 +5232,12 @@ class Trainer:
             except (TypeError, AttributeError):
                 pass
 
-        if self.args.average_tokens_across_devices and num_items_in_batch is not None:
-            num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item()
+        if num_items_in_batch is not None:
+            if self.args.average_tokens_across_devices:
+                num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum()
 
-        if torch.is_tensor(num_items_in_batch):
-            num_items_in_batch = num_items_in_batch.item()
+            if torch.is_tensor(num_items_in_batch):
+                num_items_in_batch = num_items_in_batch.to(device)
 
         return batch_samples, num_items_in_batch
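
The effect of this patch: previously, `num_items_in_batch` was collapsed to a Python scalar via `.item()`, which forces a blocking host-device copy once per gradient-accumulation window. The patched code keeps the count as a tensor, and `get_batch_samples` now receives the training `device` so the count is moved to where the loss lives. A minimal sketch of the difference, assuming PyTorch and an available CUDA device; the values and variable names below are illustrative, not taken from the patch:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Illustrative stand-in for the token count computed in get_batch_samples.
    num_items_in_batch = torch.tensor(4096)

    # Before: .item() copies the scalar back to the host, draining the device
    # queue (a blocking synchronization) every gradient-accumulation window.
    host_count = num_items_in_batch.item()

    # After: the count stays a tensor and is moved to the training device, so
    # normalizing the loss by it is pure device-side arithmetic with no sync.
    device_count = num_items_in_batch.to(device)
    loss = torch.tensor(2.5, device=device)
    loss = loss / device_count

This also explains why the `average_tokens_across_devices` branch drops its trailing `.item()`: the gathered sum can remain a tensor, since the `torch.is_tensor` check below it now moves it to `device` instead of pulling it to the host.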