mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Replace assert statements with exceptions (#24856)
* Changed AssertionError to ValueError try-except block was using AssesrtionError in except statement while the expected error is value error. Fixed the same. * Changed AssertionError to ValueError try-except block was using AssesrtionError in except statement while the expected error is ValueError. Fixed the same. Note: While raising the ValueError args are passed to it, but later added again while handling the error (See the code snippet) * Changed AssertionError to ValueError try-except block was using AssesrtionError in except statement while the expected error is ValueError. Fixed the same. Note: While raising the ValueError args are passed to it, but later added again while handling the error (See the code snippet) * Changed AssertionError to ValueError * Changed AssertionError to ValueError * Changed AssertionError to ValueError * Changed AssertionError to ValueError * Changed AssertionError to ValueError * Changed assert statement to ValueError based * Changed assert statement to ValueError based * Changed assert statement to ValueError based * Changed incorrect error handling from AssertionError to ValueError * Undoed change from AssertionError to ValueError as it is not needed * Reverted back to using AssertionError as it is not necessary to make it into ValueError * Fixed erraneous comparision Changed == to != * Fixed erraneous comparision Changed == to != * formatted the code * Ran make fix-copies
This commit is contained in:
parent
12b908c659
commit
d0154015f7
@ -182,7 +182,7 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
|
||||
try:
|
||||
if pointer.shape != array.shape:
|
||||
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
|
||||
except AssertionError as e:
|
||||
except ValueError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
print(f"Initialize PyTorch weight {name} from {original_name}")
|
||||
|
@ -337,18 +337,19 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
|
||||
else:
|
||||
raise ValueError("Can't verify logits as model is not supported")
|
||||
|
||||
assert logits.shape == expected_shape, "Shape of logits not as expected"
|
||||
if logits.shape != expected_shape:
|
||||
raise ValueError(f"Shape of logits not as expected. {logits.shape=}, {expected_shape=}")
|
||||
if not has_lm_head:
|
||||
if is_semantic:
|
||||
assert torch.allclose(
|
||||
logits[0, :3, :3, :3], expected_logits, atol=1e-3
|
||||
), "First elements of logits not as expected"
|
||||
if not torch.allclose(logits[0, :3, :3, :3], expected_logits, atol=1e-3):
|
||||
raise ValueError("First elements of logits not as expected")
|
||||
else:
|
||||
print("Predicted class idx:", logits.argmax(-1).item())
|
||||
assert torch.allclose(
|
||||
logits[0, :3], expected_logits, atol=1e-3
|
||||
), "First elements of logits not as expected"
|
||||
assert logits.argmax(-1).item() == expected_class_idx, "Predicted class index not as expected"
|
||||
|
||||
if not torch.allclose(logits[0, :3], expected_logits, atol=1e-3):
|
||||
raise ValueError("First elements of logits not as expected")
|
||||
if logits.argmax(-1).item() != expected_class_idx:
|
||||
raise ValueError("Predicted class index not as expected")
|
||||
|
||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||
print(f"Saving model to {pytorch_dump_folder_path}")
|
||||
|
@ -169,7 +169,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
||||
try:
|
||||
if pointer.shape != array.shape:
|
||||
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
|
||||
except AssertionError as e:
|
||||
except ValueError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
logger.info(f"Initialize PyTorch weight {name}")
|
||||
|
@ -227,7 +227,7 @@ def load_tf_weights_in_big_bird(model, tf_checkpoint_path, is_trivia_qa=False):
|
||||
raise ValueError(
|
||||
f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched of {txt_name}."
|
||||
)
|
||||
except AssertionError as e:
|
||||
except ValueError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
pt_weight_name = ".".join(pt_name)
|
||||
|
@ -96,10 +96,9 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
|
||||
num = int(scope_names[1])
|
||||
pointer = pointer[num]
|
||||
try:
|
||||
assert (
|
||||
pointer.shape == array.shape
|
||||
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
|
||||
except AssertionError as e:
|
||||
if pointer.shape != array.shape:
|
||||
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
|
||||
except ValueError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
logger.info(f"Initialize PyTorch weight {name}")
|
||||
|
@ -135,7 +135,7 @@ def load_tf_weights_in_electra(model, config, tf_checkpoint_path, discriminator_
|
||||
try:
|
||||
if pointer.shape != array.shape:
|
||||
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
|
||||
except AssertionError as e:
|
||||
except ValueError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
print(f"Initialize PyTorch weight {name}", original_name)
|
||||
|
@ -110,10 +110,9 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
|
||||
num = int(scope_names[1])
|
||||
pointer = pointer[num]
|
||||
try:
|
||||
assert (
|
||||
pointer.shape == array.shape
|
||||
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
|
||||
except AssertionError as e:
|
||||
if pointer.shape != array.shape:
|
||||
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
|
||||
except ValueError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
logger.info(f"Initialize PyTorch weight {name}")
|
||||
|
@ -146,7 +146,7 @@ def load_tf_weights_in_roc_bert(model, config, tf_checkpoint_path):
|
||||
try:
|
||||
if pointer.shape != array.shape:
|
||||
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
|
||||
except AssertionError as e:
|
||||
except ValueError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
logger.info(f"Initialize PyTorch weight {name}")
|
||||
|
Loading…
Reference in New Issue
Block a user