Remove call to .item in get_batch_samples (#36861)

This commit is contained in:
regisss 2025-03-21 03:14:26 -06:00 committed by GitHub
parent 6bb8565f0c
commit 0adbc873d0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -2505,7 +2505,7 @@ class Trainer:
for _ in range(total_updates):
update_step += 1
num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, args.device)
for i, inputs in enumerate(batch_samples):
step += 1
do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
@@ -5216,7 +5216,7 @@ class Trainer:
self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
)
def get_batch_samples(self, epoch_iterator, num_batches):
def get_batch_samples(self, epoch_iterator, num_batches, device):
batch_samples = []
num_items_in_batch = None
for _ in range(num_batches):
@@ -5232,11 +5232,12 @@ class Trainer:
except (TypeError, AttributeError):
pass
if self.args.average_tokens_across_devices and num_items_in_batch is not None:
num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item()
if num_items_in_batch is not None:
if self.args.average_tokens_across_devices:
num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum()
if torch.is_tensor(num_items_in_batch):
num_items_in_batch = num_items_in_batch.item()
if torch.is_tensor(num_items_in_batch):
num_items_in_batch = num_items_in_batch.to(device)
return batch_samples, num_items_in_batch