Replacement of 20 asserts with exceptions (#24757)

* initial replacements of asserts with errors/exceptions

* replace assert with exception in generation, align and bart

* reset formatting change

* reset another formatting issue

* Apply suggestion

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* don't touch this file

* change to 'is not False'

* fix type

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Bauke Brenninkmeijer 2023-07-12 13:45:09 +02:00 committed by GitHub
parent 430a04a75a
commit fc9e387dc0
9 changed files with 55 additions and 32 deletions
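The pattern repeated across all nine files is the same: a bare `assert` becomes an explicit check that raises `ValueError`. The motivation is that `assert` statements are stripped entirely when Python runs with the `-O` flag, and when they do fire they raise `AssertionError`, which is not a reasonable exception for library callers to catch. A minimal before/after sketch (the function names and message here are illustrative, not taken from the diff):

    # Before: skipped silently under `python -O`, raises AssertionError otherwise.
    def check_extension_before(extension: str) -> None:
        assert extension in ["csv", "json", "txt"], "unsupported file type"

    # After: always enforced, and raises a catchable, descriptive ValueError.
    def check_extension_after(extension: str) -> None:
        if extension not in ["csv", "json", "txt"]:
            raise ValueError("unsupported file type")

The diffs below apply this transformation case by case, keeping each original message.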


@@ -242,10 +242,12 @@ class DataTrainingArguments:
         else:
             if self.train_file is not None:
                 extension = self.train_file.split(".")[-1]
-                assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
+                if extension not in ["csv", "json", "txt"]:
+                    raise ValueError("`train_file` should be a csv, json or text file.")
             if self.validation_file is not None:
                 extension = self.validation_file.split(".")[-1]
-                assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
+                if extension not in ["csv", "json", "txt"]:
+                    raise ValueError("`validation_file` should be a csv, json or text file.")
 
 @flax.struct.dataclass


@@ -251,10 +251,12 @@ class DataTrainingArguments:
         else:
             if self.train_file is not None:
                 extension = self.train_file.split(".")[-1]
-                assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
+                if extension not in ["csv", "json", "txt"]:
+                    raise ValueError("`train_file` should be a csv, json or text file.")
             if self.validation_file is not None:
                 extension = self.validation_file.split(".")[-1]
-                assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
+                if extension not in ["csv", "json", "txt"]:
+                    raise ValueError("`validation_file` should be a csv, json or text file.")
 
 class TrainState(train_state.TrainState):


@@ -147,11 +147,12 @@ class BenchmarkArguments:
         return json.dumps(dataclasses.asdict(self), indent=2)
 
     @property
-    def model_names(self):
-        assert len(self.models) > 0, (
-            "Please make sure you provide at least one model name / model identifier, *e.g.* `--models"
-            " bert-base-cased` or `args.models = ['bert-base-cased']."
-        )
+    def model_names(self) -> List[str]:
+        if len(self.models) <= 0:
+            raise ValueError(
+                "Please make sure you provide at least one model name / model identifier, *e.g.* `--models"
+                " bert-base-cased` or `args.models = ['bert-base-cased']."
+            )
         return self.models
 
     @property


@@ -60,9 +60,10 @@ def run_with_tf_optimizations(do_eager_mode: bool, use_xla: bool):
             return func(*args, **kwargs)
 
         if do_eager_mode is True:
-            assert (
-                use_xla is False
-            ), "Cannot run model in XLA, if `args.eager_mode` is set to `True`. Please set `args.eager_mode=False`."
+            if use_xla is not False:
+                raise ValueError(
+                    "Cannot run model in XLA, if `args.eager_mode` is set to `True`. Please set `args.eager_mode=False`."
+                )
             return run_in_eager_mode
         else:
             return run_in_graph_mode
@@ -88,13 +89,15 @@ class TensorFlowBenchmark(Benchmark):
     def _inference_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
         # initialize GPU on separate process
         strategy = self.args.strategy
-        assert strategy is not None, "A device strategy has to be initialized before using TensorFlow."
+        if strategy is None:
+            raise ValueError("A device strategy has to be initialized before using TensorFlow.")
         _inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
         return self._measure_speed(_inference)
 
     def _train_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
         strategy = self.args.strategy
-        assert strategy is not None, "A device strategy has to be initialized before using TensorFlow."
+        if strategy is None:
+            raise ValueError("A device strategy has to be initialized before using TensorFlow.")
         _train = self._prepare_train_func(model_name, batch_size, sequence_length)
         return self._measure_speed(_train)
@@ -105,7 +108,8 @@ class TensorFlowBenchmark(Benchmark):
         if self.args.is_gpu:
             tf.config.experimental.set_memory_growth(self.args.gpu_list[self.args.device_idx], True)
         strategy = self.args.strategy
-        assert strategy is not None, "A device strategy has to be initialized before using TensorFlow."
+        if strategy is None:
+            raise ValueError("A device strategy has to be initialized before using TensorFlow.")
         _inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
         return self._measure_memory(_inference)
@@ -115,7 +119,8 @@ class TensorFlowBenchmark(Benchmark):
         if self.args.is_gpu:
             tf.config.experimental.set_memory_growth(self.args.gpu_list[self.args.device_idx], True)
         strategy = self.args.strategy
-        assert strategy is not None, "A device strategy has to be initialized before using TensorFlow."
+        if strategy is None:
+            raise ValueError("A device strategy has to be initialized before using TensorFlow.")
         _train = self._prepare_train_func(model_name, batch_size, sequence_length)
         return self._measure_memory(_train)
@@ -164,9 +169,8 @@ class TensorFlowBenchmark(Benchmark):
     def _prepare_train_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
         config = self.config_dict[model_name]
 
-        assert (
-            self.args.eager_mode is False
-        ), "Training cannot be done in eager mode. Please make sure that `args.eager_mode = False`."
+        if self.args.eager_mode is not False:
+            raise ValueError("Training cannot be done in eager mode. Please make sure that `args.eager_mode = False`.")
 
         if self.args.fp16:
             raise NotImplementedError("Mixed precision is currently not supported.")
@@ -240,10 +244,11 @@ class TensorFlowBenchmark(Benchmark):
         with self.args.strategy.scope():
             try:
                 if self.args.trace_memory_line_by_line:
-                    assert self.args.eager_mode, (
-                        "`args.eager_mode` is set to `False`. Make sure to run model in eager mode to measure memory"
-                        " consumption line by line."
-                    )
+                    if not self.args.eager_mode:
+                        raise ValueError(
+                            "`args.eager_mode` is set to `False`. Make sure to run model in eager mode to measure memory"
+                            " consumption line by line."
+                        )
                     trace = start_memory_tracing("transformers")
 
                 if self.args.is_tpu:
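Note the "change to 'is not False'" commit in the stack: the checks above use `use_xla is not False` and `self.args.eager_mode is not False` rather than a plain truthiness test, which preserves the exact semantics of the original `assert ... is False`. A small illustration (the variable here is illustrative):

    use_xla = None  # flag never set explicitly

    # A plain truthiness test would let None slip through unnoticed:
    if use_xla:  # False for None, so no error is raised
        ...

    # The identity test used in the diff mirrors `assert use_xla is False`
    # exactly: any value other than the literal False is rejected, including None.
    if use_xla is not False:
        raise ValueError("`use_xla` must be explicitly False here.")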


@@ -890,7 +890,8 @@ class Benchmark(ABC):
             return
 
         self.print_fn("Saving results to csv.")
         with open(filename, mode="w") as csv_file:
-            assert len(self.args.model_names) > 0, f"At least 1 model should be defined, but got {self.model_names}"
+            if len(self.args.model_names) <= 0:
+                raise ValueError(f"At least 1 model should be defined, but got {self.model_names}")
             fieldnames = ["model", "batch_size", "sequence_length"]
             writer = csv.DictWriter(csv_file, fieldnames=fieldnames + ["result"])


@@ -90,7 +90,8 @@ def glue_compute_metrics(task_name, preds, labels):
 def xnli_compute_metrics(task_name, preds, labels):
     warnings.warn(DEPRECATION_WARNING, FutureWarning)
     requires_backends(xnli_compute_metrics, "sklearn")
-    assert len(preds) == len(labels), f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}"
+    if len(preds) != len(labels):
+        raise ValueError(f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}")
     if task_name == "xnli":
         return {"acc": simple_accuracy(preds, labels)}
     else:


@@ -380,7 +380,8 @@ class BeamSearchScorer(BeamScorer):
 
         # shorter batches are padded if needed
         if sent_lengths.min().item() != sent_lengths.max().item():
-            assert pad_token_id is not None, "`pad_token_id` has to be defined"
+            if pad_token_id is None:
+                raise ValueError("`pad_token_id` has to be defined")
             decoded.fill_(pad_token_id)
 
         if indices is not None:
@@ -855,7 +856,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
         decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
 
         # shorter batches are padded if needed
         if sent_lengths.min().item() != sent_lengths.max().item():
-            assert pad_token_id is not None, "`pad_token_id` has to be defined"
+            if pad_token_id is None:
+                raise ValueError("`pad_token_id` has to be defined")
             decoded.fill_(pad_token_id)
 
         # fill with hypotheses and eos_token_id if the latter fits in


@@ -346,8 +346,10 @@ def convert_align_checkpoint(checkpoint_path, pytorch_dump_folder_path, save_mod
     text_features = tf.nn.l2_normalize(text_features, axis=-1)
 
     # Check whether original and HF model outputs match -> np.allclose
-    assert np.allclose(image_features, hf_image_features, atol=1e-3), "The predicted image features are not the same."
-    assert np.allclose(text_features, hf_text_features, atol=1e-3), "The predicted text features are not the same."
+    if not np.allclose(image_features, hf_image_features, atol=1e-3):
+        raise ValueError("The predicted image features are not the same.")
+    if not np.allclose(text_features, hf_text_features, atol=1e-3):
+        raise ValueError("The predicted text features are not the same.")
     print("Model outputs match!")
 
     if save_model:


@@ -101,7 +101,10 @@ def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkp
     config = BartConfig.from_pretrained(hf_checkpoint_name)
     tokens = bart.encode(SAMPLE_TEXT).unsqueeze(0)
     tokens2 = BartTokenizer.from_pretrained(hf_checkpoint_name).encode(SAMPLE_TEXT, return_tensors="pt").unsqueeze(0)
-    assert torch.eq(tokens, tokens2).all()
+    if not torch.eq(tokens, tokens2).all():
+        raise ValueError(
+            f"converted tokenizer and pretrained tokenizer returned different output: {tokens} != {tokens2}"
+        )
 
     if checkpoint_path == "bart.large.mnli":
         state_dict = bart.state_dict()
@@ -130,8 +133,12 @@ def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkp
         new_model_outputs = model.model(tokens)[0]
 
     # Check results
-    assert fairseq_output.shape == new_model_outputs.shape
-    assert (fairseq_output == new_model_outputs).all().item()
+    if fairseq_output.shape != new_model_outputs.shape:
+        raise ValueError(
+            f"`fairseq_output` shape and `new_model_outputs` shape are different: {fairseq_output.shape=}, {new_model_outputs.shape=}"
+        )
+    if (fairseq_output != new_model_outputs).any().item():
+        raise ValueError("Some values in `fairseq_output` are different from `new_model_outputs`")
 
     Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
     model.save_pretrained(pytorch_dump_folder_path)
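The two conversion scripts check outputs differently: the ALIGN script above compares TensorFlow and PyTorch features with `np.allclose(..., atol=1e-3)`, since cross-framework numerics drift slightly, while this BART script demands exact equality, which is reasonable because the weights are copied tensor-for-tensor within PyTorch. If exact equality ever proved too strict, a tolerance-based variant would look like the sketch below (hypothetical, not part of this diff):

    import torch

    # Hypothetical relaxation: tolerate small float differences instead of
    # requiring bit-identical outputs.
    if not torch.allclose(fairseq_output, new_model_outputs, atol=1e-3):
        raise ValueError("`fairseq_output` and `new_model_outputs` differ beyond atol=1e-3")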