Fix attribute error problem (#20765)

fix: 修复Trainer无法使用use_legacy_prediction_loop参数的问题

解决使用use_legacy_prediction_loop参数在predict阶段使用prediction_loop进行预测时,遇到AttributeError: 'PredictionOutput' object has no attribute 'num_samples'的问题

Co-authored-by: ZhouHang <zhouhang@idataway.com>
This commit is contained in:
casuallyName 2022-12-14 22:26:06 +08:00 committed by GitHub
parent 11745b4e45
commit dfd818420d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -3514,7 +3514,7 @@ class Trainer:
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
) -> PredictionOutput:
) -> EvalLoopOutput:
"""
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
@ -3651,7 +3651,7 @@ class Trainer:
if not key.startswith(f"{metric_key_prefix}_"):
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
return EvalLoopOutput(predictions=preds, label_ids=label_ids, metrics=metrics, num_samples=num_examples)
def _gather_and_numpify(self, tensors, name):
"""