mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Place inputs on device when include_inputs_for_metrics is True (#18046)
This commit is contained in:
parent
870ff9e1da
commit
1b5ea74783
@ -2804,7 +2804,7 @@ class Trainer:
|
||||
|
||||
# Prediction step
|
||||
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
|
||||
inputs_decode = inputs["input_ids"] if args.include_inputs_for_metrics else None
|
||||
inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
|
||||
|
||||
if is_torch_tpu_available():
|
||||
xm.mark_step()
|
||||
@ -3352,7 +3352,7 @@ class Trainer:
|
||||
|
||||
for step, inputs in enumerate(dataloader):
|
||||
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
|
||||
inputs_decode = inputs["input_ids"] if args.include_inputs_for_metrics else None
|
||||
inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
|
||||
|
||||
if loss is not None:
|
||||
losses = loss.repeat(batch_size)
|
||||
|
Loading…
Reference in New Issue
Block a user