mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix attribute error problem (#20765)
fix: 修复Trainer无法使用use_legacy_prediction_loop参数的问题 解决使用use_legacy_prediction_loop参数在predict阶段使用prediction_loop进行预测时,遇到AttributeError: 'PredictionOutput' object has no attribute 'num_samples'的问题 Co-authored-by: ZhouHang <zhouhang@idataway.com>
This commit is contained in:
parent
11745b4e45
commit
dfd818420d
@ -3514,7 +3514,7 @@ class Trainer:
|
||||
prediction_loss_only: Optional[bool] = None,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
metric_key_prefix: str = "eval",
|
||||
) -> PredictionOutput:
|
||||
) -> EvalLoopOutput:
|
||||
"""
|
||||
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
|
||||
|
||||
@ -3651,7 +3651,7 @@ class Trainer:
|
||||
if not key.startswith(f"{metric_key_prefix}_"):
|
||||
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
|
||||
|
||||
return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
|
||||
return EvalLoopOutput(predictions=preds, label_ids=label_ids, metrics=metrics, num_samples=num_examples)
|
||||
|
||||
def _gather_and_numpify(self, tensors, name):
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user