Replacement of 20 asserts with exceptions (#24757)
* initial replacements of asserts with errors/exceptions
* replace assert with exception in generation, align and bart
* reset formatting change
* reset another formatting issue
* Apply suggestion

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* don't touch this file
* change to 'is not False'
* fix type

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
commit fc9e387dc0 (parent 430a04a75a)
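The pattern applied across all twenty call sites is the same, and worth spelling out once: an `assert` is stripped entirely when Python runs with the `-O` flag, so validation that users rely on must be an explicit `raise` of a catchable exception. A minimal sketch of the before/after, using the extension check from the first hunk below (the helper name and standalone form are illustrative, not taken from any one file in this diff):

def _check_extension(train_file: str) -> None:
    extension = train_file.split(".")[-1]
    # Before this commit: `assert extension in [...]`, which `python -O`
    # removes entirely, silently skipping the check.
    # After: an explicit check that always runs and raises a ValueError
    # that callers can catch, instead of an AssertionError.
    if extension not in ["csv", "json", "txt"]:
        raise ValueError("`train_file` should be a csv, json or text file.")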
@@ -242,10 +242,12 @@ class DataTrainingArguments:
         else:
             if self.train_file is not None:
                 extension = self.train_file.split(".")[-1]
-                assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
+                if extension not in ["csv", "json", "txt"]:
+                    raise ValueError("`train_file` should be a csv, json or text file.")
             if self.validation_file is not None:
                 extension = self.validation_file.split(".")[-1]
-                assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
+                if extension not in ["csv", "json", "txt"]:
+                    raise ValueError("`validation_file` should be a csv, json or text file.")
 
 
 @flax.struct.dataclass
@@ -251,10 +251,12 @@ class DataTrainingArguments:
         else:
             if self.train_file is not None:
                 extension = self.train_file.split(".")[-1]
-                assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
+                if extension not in ["csv", "json", "txt"]:
+                    raise ValueError("`train_file` should be a csv, json or text file.")
             if self.validation_file is not None:
                 extension = self.validation_file.split(".")[-1]
-                assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
+                if extension not in ["csv", "json", "txt"]:
+                    raise ValueError("`validation_file` should be a csv, json or text file.")
 
 
 class TrainState(train_state.TrainState):
@@ -147,11 +147,12 @@ class BenchmarkArguments:
         return json.dumps(dataclasses.asdict(self), indent=2)
 
     @property
-    def model_names(self):
-        assert len(self.models) > 0, (
-            "Please make sure you provide at least one model name / model identifier, *e.g.* `--models"
-            " bert-base-cased` or `args.models = ['bert-base-cased']."
-        )
+    def model_names(self) -> List[str]:
+        if len(self.models) <= 0:
+            raise ValueError(
+                "Please make sure you provide at least one model name / model identifier, *e.g.* `--models"
+                " bert-base-cased` or `args.models = ['bert-base-cased']."
+            )
         return self.models
 
     @property
@@ -60,9 +60,10 @@ def run_with_tf_optimizations(do_eager_mode: bool, use_xla: bool):
             return func(*args, **kwargs)
 
         if do_eager_mode is True:
-            assert (
-                use_xla is False
-            ), "Cannot run model in XLA, if `args.eager_mode` is set to `True`. Please set `args.eager_mode=False`."
+            if use_xla is not False:
+                raise ValueError(
+                    "Cannot run model in XLA, if `args.eager_mode` is set to `True`. Please set `args.eager_mode=False`."
+                )
             return run_in_eager_mode
         else:
             return run_in_graph_mode
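One detail from the hunk above that the commit message calls out ("change to 'is not False'"): the original `assert use_xla is False` rejected every value except the literal `False`, so the faithful inversion is `use_xla is not False`, not plain truthiness. A small runnable illustration (the standalone function and test values are hypothetical, for demonstration only):

def check(use_xla):
    # Faithful inversion of `assert use_xla is False`: only the literal
    # False passes; True, None, 1, etc. are all rejected, whereas a plain
    # `if use_xla:` would silently let None and 0 through.
    if use_xla is not False:
        raise ValueError("Cannot run model in XLA, if `args.eager_mode` is set to `True`.")

check(False)    # passes
# check(None)   # raises ValueError, matching the old assert's behavior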
@@ -88,13 +89,15 @@ class TensorFlowBenchmark(Benchmark):
     def _inference_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
         # initialize GPU on separate process
         strategy = self.args.strategy
-        assert strategy is not None, "A device strategy has to be initialized before using TensorFlow."
+        if strategy is None:
+            raise ValueError("A device strategy has to be initialized before using TensorFlow.")
         _inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
         return self._measure_speed(_inference)
 
     def _train_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
         strategy = self.args.strategy
-        assert strategy is not None, "A device strategy has to be initialized before using TensorFlow."
+        if strategy is None:
+            raise ValueError("A device strategy has to be initialized before using TensorFlow.")
         _train = self._prepare_train_func(model_name, batch_size, sequence_length)
         return self._measure_speed(_train)
 
@@ -105,7 +108,8 @@ class TensorFlowBenchmark(Benchmark):
         if self.args.is_gpu:
             tf.config.experimental.set_memory_growth(self.args.gpu_list[self.args.device_idx], True)
         strategy = self.args.strategy
-        assert strategy is not None, "A device strategy has to be initialized before using TensorFlow."
+        if strategy is None:
+            raise ValueError("A device strategy has to be initialized before using TensorFlow.")
         _inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
         return self._measure_memory(_inference)
 
@@ -115,7 +119,8 @@ class TensorFlowBenchmark(Benchmark):
         if self.args.is_gpu:
             tf.config.experimental.set_memory_growth(self.args.gpu_list[self.args.device_idx], True)
         strategy = self.args.strategy
-        assert strategy is not None, "A device strategy has to be initialized before using TensorFlow."
+        if strategy is None:
+            raise ValueError("A device strategy has to be initialized before using TensorFlow.")
 
         _train = self._prepare_train_func(model_name, batch_size, sequence_length)
         return self._measure_memory(_train)
@@ -164,9 +169,8 @@ class TensorFlowBenchmark(Benchmark):
     def _prepare_train_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
         config = self.config_dict[model_name]
 
-        assert (
-            self.args.eager_mode is False
-        ), "Training cannot be done in eager mode. Please make sure that `args.eager_mode = False`."
+        if self.args.eager_mode is not False:
+            raise ValueError("Training cannot be done in eager mode. Please make sure that `args.eager_mode = False`.")
 
         if self.args.fp16:
             raise NotImplementedError("Mixed precision is currently not supported.")
@@ -240,10 +244,11 @@ class TensorFlowBenchmark(Benchmark):
         with self.args.strategy.scope():
             try:
                 if self.args.trace_memory_line_by_line:
-                    assert self.args.eager_mode, (
-                        "`args.eager_mode` is set to `False`. Make sure to run model in eager mode to measure memory"
-                        " consumption line by line."
-                    )
+                    if not self.args.eager_mode:
+                        raise ValueError(
+                            "`args.eager_mode` is set to `False`. Make sure to run model in eager mode to measure memory"
+                            " consumption line by line."
+                        )
                     trace = start_memory_tracing("transformers")
 
                 if self.args.is_tpu:
@@ -890,7 +890,8 @@ class Benchmark(ABC):
             return
         self.print_fn("Saving results to csv.")
         with open(filename, mode="w") as csv_file:
-            assert len(self.args.model_names) > 0, f"At least 1 model should be defined, but got {self.model_names}"
+            if len(self.args.model_names) <= 0:
+                raise ValueError(f"At least 1 model should be defined, but got {self.model_names}")
 
             fieldnames = ["model", "batch_size", "sequence_length"]
             writer = csv.DictWriter(csv_file, fieldnames=fieldnames + ["result"])
@@ -90,7 +90,8 @@ def glue_compute_metrics(task_name, preds, labels):
 def xnli_compute_metrics(task_name, preds, labels):
     warnings.warn(DEPRECATION_WARNING, FutureWarning)
     requires_backends(xnli_compute_metrics, "sklearn")
-    assert len(preds) == len(labels), f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}"
+    if len(preds) != len(labels):
+        raise ValueError(f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}")
     if task_name == "xnli":
         return {"acc": simple_accuracy(preds, labels)}
     else:
@@ -380,7 +380,8 @@ class BeamSearchScorer(BeamScorer):
 
         # shorter batches are padded if needed
         if sent_lengths.min().item() != sent_lengths.max().item():
-            assert pad_token_id is not None, "`pad_token_id` has to be defined"
+            if pad_token_id is None:
+                raise ValueError("`pad_token_id` has to be defined")
             decoded.fill_(pad_token_id)
 
         if indices is not None:
@@ -855,7 +856,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
         decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
         # shorter batches are padded if needed
         if sent_lengths.min().item() != sent_lengths.max().item():
-            assert pad_token_id is not None, "`pad_token_id` has to be defined"
+            if pad_token_id is None:
+                raise ValueError("`pad_token_id` has to be defined")
             decoded.fill_(pad_token_id)
 
         # fill with hypotheses and eos_token_id if the latter fits in
@@ -346,8 +346,10 @@ def convert_align_checkpoint(checkpoint_path, pytorch_dump_folder_path, save_mod
     text_features = tf.nn.l2_normalize(text_features, axis=-1)
 
     # Check whether original and HF model outputs match -> np.allclose
-    assert np.allclose(image_features, hf_image_features, atol=1e-3), "The predicted image features are not the same."
-    assert np.allclose(text_features, hf_text_features, atol=1e-3), "The predicted text features are not the same."
+    if not np.allclose(image_features, hf_image_features, atol=1e-3):
+        raise ValueError("The predicted image features are not the same.")
+    if not np.allclose(text_features, hf_text_features, atol=1e-3):
+        raise ValueError("The predicted text features are not the same.")
     print("Model outputs match!")
 
     if save_model:
@@ -101,7 +101,10 @@ def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkp
     config = BartConfig.from_pretrained(hf_checkpoint_name)
     tokens = bart.encode(SAMPLE_TEXT).unsqueeze(0)
     tokens2 = BartTokenizer.from_pretrained(hf_checkpoint_name).encode(SAMPLE_TEXT, return_tensors="pt").unsqueeze(0)
-    assert torch.eq(tokens, tokens2).all()
+    if not torch.eq(tokens, tokens2).all():
+        raise ValueError(
+            f"converted tokenizer and pretrained tokenizer returned different output: {tokens} != {tokens2}"
+        )
 
     if checkpoint_path == "bart.large.mnli":
         state_dict = bart.state_dict()
@@ -130,8 +133,12 @@ def convert_bart_checkpoint(checkpoint_path, pytorch_dump_folder_path, hf_checkp
     new_model_outputs = model.model(tokens)[0]
 
     # Check results
-    assert fairseq_output.shape == new_model_outputs.shape
-    assert (fairseq_output == new_model_outputs).all().item()
+    if fairseq_output.shape != new_model_outputs.shape:
+        raise ValueError(
+            f"`fairseq_output` shape and `new_model_output` shape are different: {fairseq_output.shape=}, {new_model_outputs.shape}"
+        )
+    if (fairseq_output != new_model_outputs).any().item():
+        raise ValueError("Some values in `fairseq_output` are different from `new_model_outputs`")
     Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
     model.save_pretrained(pytorch_dump_folder_path)
 
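In the last hunk, note how the single bare `assert` pair splits into a shape check and a value check with informative messages. The elementwise form `(a != b).any().item()` collapses a boolean tensor to a single Python bool. A self-contained sketch, with hypothetical stand-in tensors for `fairseq_output` and `new_model_outputs`:

import torch

# Hypothetical stand-ins; the real values come from the fairseq and HF models.
a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
b = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

if a.shape != b.shape:
    raise ValueError(f"shapes differ: {a.shape} vs {b.shape}")
# `a != b` is an elementwise bool tensor; `.any().item()` reduces it to a
# Python bool, mirroring the check added in the diff above.
if (a != b).any().item():
    raise ValueError("Some values differ")
print("outputs match")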