diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 5f698b86ddd..e817a9e1413 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2804,7 +2804,7 @@ class Trainer: # Prediction step loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) - inputs_decode = inputs["input_ids"] if args.include_inputs_for_metrics else None + inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None if is_torch_tpu_available(): xm.mark_step() @@ -3352,7 +3352,7 @@ class Trainer: for step, inputs in enumerate(dataloader): loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) - inputs_decode = inputs["input_ids"] if args.include_inputs_for_metrics else None + inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None if loss is not None: losses = loss.repeat(batch_size)