mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Remove call to .item
in get_batch_samples
(#36861)
This commit is contained in:
parent
6bb8565f0c
commit
0adbc873d0
@ -2505,7 +2505,7 @@ class Trainer:
|
||||
for _ in range(total_updates):
|
||||
update_step += 1
|
||||
num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
|
||||
batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
|
||||
batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, args.device)
|
||||
for i, inputs in enumerate(batch_samples):
|
||||
step += 1
|
||||
do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
|
||||
@ -5216,7 +5216,7 @@ class Trainer:
|
||||
self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
|
||||
)
|
||||
|
||||
def get_batch_samples(self, epoch_iterator, num_batches):
|
||||
def get_batch_samples(self, epoch_iterator, num_batches, device):
|
||||
batch_samples = []
|
||||
num_items_in_batch = None
|
||||
for _ in range(num_batches):
|
||||
@ -5232,11 +5232,12 @@ class Trainer:
|
||||
except (TypeError, AttributeError):
|
||||
pass
|
||||
|
||||
if self.args.average_tokens_across_devices and num_items_in_batch is not None:
|
||||
num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item()
|
||||
if num_items_in_batch is not None:
|
||||
if self.args.average_tokens_across_devices:
|
||||
num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum()
|
||||
|
||||
if torch.is_tensor(num_items_in_batch):
|
||||
num_items_in_batch = num_items_in_batch.item()
|
||||
if torch.is_tensor(num_items_in_batch):
|
||||
num_items_in_batch = num_items_in_batch.to(device)
|
||||
|
||||
return batch_samples, num_items_in_batch
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user