mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
Place inputs on device when include_inputs_for_metrics is True (#18046)
This commit is contained in:
parent
870ff9e1da
commit
1b5ea74783
@ -2804,7 +2804,7 @@ class Trainer:
|
|||||||
|
|
||||||
# Prediction step
|
# Prediction step
|
||||||
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
|
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
|
||||||
inputs_decode = inputs["input_ids"] if args.include_inputs_for_metrics else None
|
inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
|
||||||
|
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
xm.mark_step()
|
xm.mark_step()
|
||||||
@ -3352,7 +3352,7 @@ class Trainer:
|
|||||||
|
|
||||||
for step, inputs in enumerate(dataloader):
|
for step, inputs in enumerate(dataloader):
|
||||||
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
|
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
|
||||||
inputs_decode = inputs["input_ids"] if args.include_inputs_for_metrics else None
|
inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
|
||||||
|
|
||||||
if loss is not None:
|
if loss is not None:
|
||||||
losses = loss.repeat(batch_size)
|
losses = loss.repeat(batch_size)
|
||||||
|
Loading…
Reference in New Issue
Block a user