diff --git a/examples/tensorflow/question-answering/utils_qa.py b/examples/tensorflow/question-answering/utils_qa.py index 2f8f0a60c45..1157849c991 100644 --- a/examples/tensorflow/question-answering/utils_qa.py +++ b/examples/tensorflow/question-answering/utils_qa.py @@ -38,7 +38,7 @@ def postprocess_qa_predictions( null_score_diff_threshold: float = 0.0, output_dir: Optional[str] = None, prefix: Optional[str] = None, - is_world_process_zero: bool = True, + log_level: Optional[int] = logging.WARNING, ): """ Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the @@ -70,8 +70,8 @@ def postprocess_qa_predictions( answers, are saved in `output_dir`. prefix (:obj:`str`, `optional`): If provided, the dictionaries mentioned above are saved with `prefix` added to their names. - is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`): - Whether this process is the main process or not (used to determine if logging/saves should be done). + log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``): + ``logging`` log level (e.g., ``logging.WARNING``) """ assert len(predictions) == 2, "`predictions` should be a tuple with two elements (start_logits, end_logits)." all_start_logits, all_end_logits = predictions @@ -91,7 +91,7 @@ def postprocess_qa_predictions( scores_diff_json = collections.OrderedDict() # Logging. - logger.setLevel(logging.INFO if is_world_process_zero else logging.WARN) + logger.setLevel(log_level) logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.") # Let's loop over all the examples! @@ -250,7 +250,7 @@ def postprocess_qa_predictions_with_beam_search( end_n_top: int = 5, output_dir: Optional[str] = None, prefix: Optional[str] = None, - is_world_process_zero: bool = True, + log_level: Optional[int] = logging.WARNING, ): """ Post-processes the predictions of a question-answering model with beam search to convert them to answers that are substrings of the @@ -280,8 +280,8 @@ def postprocess_qa_predictions_with_beam_search( answers, are saved in `output_dir`. prefix (:obj:`str`, `optional`): If provided, the dictionaries mentioned above are saved with `prefix` added to their names. - is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`): - Whether this process is the main process or not (used to determine if logging/saves should be done). + log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``): + ``logging`` log level (e.g., ``logging.WARNING``) """ assert len(predictions) == 5, "`predictions` should be a tuple with five elements." start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = predictions @@ -302,7 +302,7 @@ def postprocess_qa_predictions_with_beam_search( scores_diff_json = collections.OrderedDict() if version_2_with_negative else None # Logging. - logger.setLevel(logging.INFO if is_world_process_zero else logging.WARN) + logger.setLevel(log_level) logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.") # Let's loop over all the examples! @@ -413,14 +413,14 @@ def postprocess_qa_predictions_with_beam_search( output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json" ) - print(f"Saving predictions to {prediction_file}.") + logger.info(f"Saving predictions to {prediction_file}.") with open(prediction_file, "w") as writer: writer.write(json.dumps(all_predictions, indent=4) + "\n") - print(f"Saving nbest_preds to {nbest_file}.") + logger.info(f"Saving nbest_preds to {nbest_file}.") with open(nbest_file, "w") as writer: writer.write(json.dumps(all_nbest_json, indent=4) + "\n") if version_2_with_negative: - print(f"Saving null_odds to {null_odds_file}.") + logger.info(f"Saving null_odds to {null_odds_file}.") with open(null_odds_file, "w") as writer: writer.write(json.dumps(scores_diff_json, indent=4) + "\n")