From 319beb64eb5f8d1f302160d2bd5bbbc7272b8896 Mon Sep 17 00:00:00 2001 From: Dhananjay Shettigar <39980717+djroxx2000@users.noreply.github.com> Date: Thu, 7 Oct 2021 18:39:01 +0530 Subject: [PATCH] #12789 Replace assert statements with exceptions (#13909) * #12789 Replace assert statements with exceptions * fix-copies: made copy changes to utils_qa.py in examples/pytorch/question-answering and examples/tensorflow/question-answering * minor refactor for clarity --- examples/flax/question-answering/utils_qa.py | 20 +++++++++++-------- .../pytorch/question-answering/utils_qa.py | 20 +++++++++++-------- .../tensorflow/question-answering/utils_qa.py | 20 +++++++++++-------- 3 files changed, 36 insertions(+), 24 deletions(-) diff --git a/examples/flax/question-answering/utils_qa.py b/examples/flax/question-answering/utils_qa.py index 1157849c991..82b935f86f3 100644 --- a/examples/flax/question-answering/utils_qa.py +++ b/examples/flax/question-answering/utils_qa.py @@ -73,10 +73,12 @@ def postprocess_qa_predictions( log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``): ``logging`` log level (e.g., ``logging.WARNING``) """ - assert len(predictions) == 2, "`predictions` should be a tuple with two elements (start_logits, end_logits)." + if len(predictions) != 2: + raise ValueError("`predictions` should be a tuple with two elements (start_logits, end_logits).") all_start_logits, all_end_logits = predictions - assert len(predictions[0]) == len(features), f"Got {len(predictions[0])} predictions and {len(features)} features." + if len(predictions[0]) != len(features): + raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.") # Build a map example to its corresponding features. example_id_to_index = {k: i for i, k in enumerate(examples["id"])} @@ -212,7 +214,8 @@ def postprocess_qa_predictions( # If we have an output_dir, let's save all those dicts. if output_dir is not None: - assert os.path.isdir(output_dir), f"{output_dir} is not a directory." + if not os.path.isdir(output_dir): + raise EnvironmentError(f"{output_dir} is not a directory.") prediction_file = os.path.join( output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json" @@ -283,12 +286,12 @@ def postprocess_qa_predictions_with_beam_search( log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``): ``logging`` log level (e.g., ``logging.WARNING``) """ - assert len(predictions) == 5, "`predictions` should be a tuple with five elements." + if len(predictions) != 5: + raise ValueError("`predictions` should be a tuple with five elements.") start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = predictions - assert len(predictions[0]) == len( - features - ), f"Got {len(predictions[0])} predicitions and {len(features)} features." + if len(predictions[0]) != len(features): + raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.") # Build a map example to its corresponding features. example_id_to_index = {k: i for i, k in enumerate(examples["id"])} @@ -400,7 +403,8 @@ def postprocess_qa_predictions_with_beam_search( # If we have an output_dir, let's save all those dicts. if output_dir is not None: - assert os.path.isdir(output_dir), f"{output_dir} is not a directory." + if not os.path.isdir(output_dir): + raise EnvironmentError(f"{output_dir} is not a directory.") prediction_file = os.path.join( output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json" diff --git a/examples/pytorch/question-answering/utils_qa.py b/examples/pytorch/question-answering/utils_qa.py index 1157849c991..82b935f86f3 100644 --- a/examples/pytorch/question-answering/utils_qa.py +++ b/examples/pytorch/question-answering/utils_qa.py @@ -73,10 +73,12 @@ def postprocess_qa_predictions( log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``): ``logging`` log level (e.g., ``logging.WARNING``) """ - assert len(predictions) == 2, "`predictions` should be a tuple with two elements (start_logits, end_logits)." + if len(predictions) != 2: + raise ValueError("`predictions` should be a tuple with two elements (start_logits, end_logits).") all_start_logits, all_end_logits = predictions - assert len(predictions[0]) == len(features), f"Got {len(predictions[0])} predictions and {len(features)} features." + if len(predictions[0]) != len(features): + raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.") # Build a map example to its corresponding features. example_id_to_index = {k: i for i, k in enumerate(examples["id"])} @@ -212,7 +214,8 @@ def postprocess_qa_predictions( # If we have an output_dir, let's save all those dicts. if output_dir is not None: - assert os.path.isdir(output_dir), f"{output_dir} is not a directory." + if not os.path.isdir(output_dir): + raise EnvironmentError(f"{output_dir} is not a directory.") prediction_file = os.path.join( output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json" @@ -283,12 +286,12 @@ def postprocess_qa_predictions_with_beam_search( log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``): ``logging`` log level (e.g., ``logging.WARNING``) """ - assert len(predictions) == 5, "`predictions` should be a tuple with five elements." + if len(predictions) != 5: + raise ValueError("`predictions` should be a tuple with five elements.") start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = predictions - assert len(predictions[0]) == len( - features - ), f"Got {len(predictions[0])} predicitions and {len(features)} features." + if len(predictions[0]) != len(features): + raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.") # Build a map example to its corresponding features. example_id_to_index = {k: i for i, k in enumerate(examples["id"])} @@ -400,7 +403,8 @@ def postprocess_qa_predictions_with_beam_search( # If we have an output_dir, let's save all those dicts. if output_dir is not None: - assert os.path.isdir(output_dir), f"{output_dir} is not a directory." + if not os.path.isdir(output_dir): + raise EnvironmentError(f"{output_dir} is not a directory.") prediction_file = os.path.join( output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json" diff --git a/examples/tensorflow/question-answering/utils_qa.py b/examples/tensorflow/question-answering/utils_qa.py index 1157849c991..82b935f86f3 100644 --- a/examples/tensorflow/question-answering/utils_qa.py +++ b/examples/tensorflow/question-answering/utils_qa.py @@ -73,10 +73,12 @@ def postprocess_qa_predictions( log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``): ``logging`` log level (e.g., ``logging.WARNING``) """ - assert len(predictions) == 2, "`predictions` should be a tuple with two elements (start_logits, end_logits)." + if len(predictions) != 2: + raise ValueError("`predictions` should be a tuple with two elements (start_logits, end_logits).") all_start_logits, all_end_logits = predictions - assert len(predictions[0]) == len(features), f"Got {len(predictions[0])} predictions and {len(features)} features." + if len(predictions[0]) != len(features): + raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.") # Build a map example to its corresponding features. example_id_to_index = {k: i for i, k in enumerate(examples["id"])} @@ -212,7 +214,8 @@ def postprocess_qa_predictions( # If we have an output_dir, let's save all those dicts. if output_dir is not None: - assert os.path.isdir(output_dir), f"{output_dir} is not a directory." + if not os.path.isdir(output_dir): + raise EnvironmentError(f"{output_dir} is not a directory.") prediction_file = os.path.join( output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json" @@ -283,12 +286,12 @@ def postprocess_qa_predictions_with_beam_search( log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``): ``logging`` log level (e.g., ``logging.WARNING``) """ - assert len(predictions) == 5, "`predictions` should be a tuple with five elements." + if len(predictions) != 5: + raise ValueError("`predictions` should be a tuple with five elements.") start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = predictions - assert len(predictions[0]) == len( - features - ), f"Got {len(predictions[0])} predicitions and {len(features)} features." + if len(predictions[0]) != len(features): + raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.") # Build a map example to its corresponding features. example_id_to_index = {k: i for i, k in enumerate(examples["id"])} @@ -400,7 +403,8 @@ def postprocess_qa_predictions_with_beam_search( # If we have an output_dir, let's save all those dicts. if output_dir is not None: - assert os.path.isdir(output_dir), f"{output_dir} is not a directory." + if not os.path.isdir(output_dir): + raise EnvironmentError(f"{output_dir} is not a directory.") prediction_file = os.path.join( output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"