Place inputs on device when include_inputs_for_metrics is True (#18046)

This commit is contained in:
Sylvain Gugger 2022-07-07 08:17:49 -04:00 committed by GitHub
parent 870ff9e1da
commit 1b5ea74783
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2804,7 +2804,7 @@ class Trainer:
# Prediction step
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
inputs_decode = inputs["input_ids"] if args.include_inputs_for_metrics else None
inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
if is_torch_tpu_available():
xm.mark_step()
@ -3352,7 +3352,7 @@ class Trainer:
for step, inputs in enumerate(dataloader):
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
inputs_decode = inputs["input_ids"] if args.include_inputs_for_metrics else None
inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
if loss is not None:
losses = loss.repeat(batch_size)