From 1b5ea7478327e1b8c4dbcb85be52f345c8c1cf8d Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Thu, 7 Jul 2022 08:17:49 -0400 Subject: [PATCH] Place inputs on device when include_inputs_for_metrics is True (#18046) --- src/transformers/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 5f698b86ddd..e817a9e1413 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2804,7 +2804,7 @@ class Trainer: # Prediction step loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) - inputs_decode = inputs["input_ids"] if args.include_inputs_for_metrics else None + inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None if is_torch_tpu_available(): xm.mark_step() @@ -3352,7 +3352,7 @@ class Trainer: for step, inputs in enumerate(dataloader): loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) - inputs_decode = inputs["input_ids"] if args.include_inputs_for_metrics else None + inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None if loss is not None: losses = loss.repeat(batch_size)