mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 13:20:12 +06:00
Fix copies
This commit is contained in:
parent
27b6ac4611
commit
276bc149d2
@ -38,7 +38,7 @@ def postprocess_qa_predictions(
|
|||||||
null_score_diff_threshold: float = 0.0,
|
null_score_diff_threshold: float = 0.0,
|
||||||
output_dir: Optional[str] = None,
|
output_dir: Optional[str] = None,
|
||||||
prefix: Optional[str] = None,
|
prefix: Optional[str] = None,
|
||||||
is_world_process_zero: bool = True,
|
log_level: Optional[int] = logging.WARNING,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the
|
Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the
|
||||||
@ -70,8 +70,8 @@ def postprocess_qa_predictions(
|
|||||||
answers, are saved in `output_dir`.
|
answers, are saved in `output_dir`.
|
||||||
prefix (:obj:`str`, `optional`):
|
prefix (:obj:`str`, `optional`):
|
||||||
If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
|
If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
|
||||||
is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``):
|
||||||
Whether this process is the main process or not (used to determine if logging/saves should be done).
|
``logging`` log level (e.g., ``logging.WARNING``)
|
||||||
"""
|
"""
|
||||||
assert len(predictions) == 2, "`predictions` should be a tuple with two elements (start_logits, end_logits)."
|
assert len(predictions) == 2, "`predictions` should be a tuple with two elements (start_logits, end_logits)."
|
||||||
all_start_logits, all_end_logits = predictions
|
all_start_logits, all_end_logits = predictions
|
||||||
@ -91,7 +91,7 @@ def postprocess_qa_predictions(
|
|||||||
scores_diff_json = collections.OrderedDict()
|
scores_diff_json = collections.OrderedDict()
|
||||||
|
|
||||||
# Logging.
|
# Logging.
|
||||||
logger.setLevel(logging.INFO if is_world_process_zero else logging.WARN)
|
logger.setLevel(log_level)
|
||||||
logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")
|
logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")
|
||||||
|
|
||||||
# Let's loop over all the examples!
|
# Let's loop over all the examples!
|
||||||
@ -250,7 +250,7 @@ def postprocess_qa_predictions_with_beam_search(
|
|||||||
end_n_top: int = 5,
|
end_n_top: int = 5,
|
||||||
output_dir: Optional[str] = None,
|
output_dir: Optional[str] = None,
|
||||||
prefix: Optional[str] = None,
|
prefix: Optional[str] = None,
|
||||||
is_world_process_zero: bool = True,
|
log_level: Optional[int] = logging.WARNING,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Post-processes the predictions of a question-answering model with beam search to convert them to answers that are substrings of the
|
Post-processes the predictions of a question-answering model with beam search to convert them to answers that are substrings of the
|
||||||
@ -280,8 +280,8 @@ def postprocess_qa_predictions_with_beam_search(
|
|||||||
answers, are saved in `output_dir`.
|
answers, are saved in `output_dir`.
|
||||||
prefix (:obj:`str`, `optional`):
|
prefix (:obj:`str`, `optional`):
|
||||||
If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
|
If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
|
||||||
is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``):
|
||||||
Whether this process is the main process or not (used to determine if logging/saves should be done).
|
``logging`` log level (e.g., ``logging.WARNING``)
|
||||||
"""
|
"""
|
||||||
assert len(predictions) == 5, "`predictions` should be a tuple with five elements."
|
assert len(predictions) == 5, "`predictions` should be a tuple with five elements."
|
||||||
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = predictions
|
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = predictions
|
||||||
@ -302,7 +302,7 @@ def postprocess_qa_predictions_with_beam_search(
|
|||||||
scores_diff_json = collections.OrderedDict() if version_2_with_negative else None
|
scores_diff_json = collections.OrderedDict() if version_2_with_negative else None
|
||||||
|
|
||||||
# Logging.
|
# Logging.
|
||||||
logger.setLevel(logging.INFO if is_world_process_zero else logging.WARN)
|
logger.setLevel(log_level)
|
||||||
logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")
|
logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")
|
||||||
|
|
||||||
# Let's loop over all the examples!
|
# Let's loop over all the examples!
|
||||||
@ -413,14 +413,14 @@ def postprocess_qa_predictions_with_beam_search(
|
|||||||
output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json"
|
output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json"
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Saving predictions to {prediction_file}.")
|
logger.info(f"Saving predictions to {prediction_file}.")
|
||||||
with open(prediction_file, "w") as writer:
|
with open(prediction_file, "w") as writer:
|
||||||
writer.write(json.dumps(all_predictions, indent=4) + "\n")
|
writer.write(json.dumps(all_predictions, indent=4) + "\n")
|
||||||
print(f"Saving nbest_preds to {nbest_file}.")
|
logger.info(f"Saving nbest_preds to {nbest_file}.")
|
||||||
with open(nbest_file, "w") as writer:
|
with open(nbest_file, "w") as writer:
|
||||||
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
|
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
|
||||||
if version_2_with_negative:
|
if version_2_with_negative:
|
||||||
print(f"Saving null_odds to {null_odds_file}.")
|
logger.info(f"Saving null_odds to {null_odds_file}.")
|
||||||
with open(null_odds_file, "w") as writer:
|
with open(null_odds_file, "w") as writer:
|
||||||
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
|
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user