mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
Replace assert statements with exceptions (#24856)
* Changed AssertionError to ValueError try-except block was using AssesrtionError in except statement while the expected error is value error. Fixed the same. * Changed AssertionError to ValueError try-except block was using AssesrtionError in except statement while the expected error is ValueError. Fixed the same. Note: While raising the ValueError args are passed to it, but later added again while handling the error (See the code snippet) * Changed AssertionError to ValueError try-except block was using AssesrtionError in except statement while the expected error is ValueError. Fixed the same. Note: While raising the ValueError args are passed to it, but later added again while handling the error (See the code snippet) * Changed AssertionError to ValueError * Changed AssertionError to ValueError * Changed AssertionError to ValueError * Changed AssertionError to ValueError * Changed AssertionError to ValueError * Changed assert statement to ValueError based * Changed assert statement to ValueError based * Changed assert statement to ValueError based * Changed incorrect error handling from AssertionError to ValueError * Undoed change from AssertionError to ValueError as it is not needed * Reverted back to using AssertionError as it is not necessary to make it into ValueError * Fixed erraneous comparision Changed == to != * Fixed erraneous comparision Changed == to != * formatted the code * Ran make fix-copies
This commit is contained in:
parent
12b908c659
commit
d0154015f7
@ -182,7 +182,7 @@ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
|
|||||||
try:
|
try:
|
||||||
if pointer.shape != array.shape:
|
if pointer.shape != array.shape:
|
||||||
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
|
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
|
||||||
except AssertionError as e:
|
except ValueError as e:
|
||||||
e.args += (pointer.shape, array.shape)
|
e.args += (pointer.shape, array.shape)
|
||||||
raise
|
raise
|
||||||
print(f"Initialize PyTorch weight {name} from {original_name}")
|
print(f"Initialize PyTorch weight {name} from {original_name}")
|
||||||
|
@ -337,18 +337,19 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Can't verify logits as model is not supported")
|
raise ValueError("Can't verify logits as model is not supported")
|
||||||
|
|
||||||
assert logits.shape == expected_shape, "Shape of logits not as expected"
|
if logits.shape != expected_shape:
|
||||||
|
raise ValueError(f"Shape of logits not as expected. {logits.shape=}, {expected_shape=}")
|
||||||
if not has_lm_head:
|
if not has_lm_head:
|
||||||
if is_semantic:
|
if is_semantic:
|
||||||
assert torch.allclose(
|
if not torch.allclose(logits[0, :3, :3, :3], expected_logits, atol=1e-3):
|
||||||
logits[0, :3, :3, :3], expected_logits, atol=1e-3
|
raise ValueError("First elements of logits not as expected")
|
||||||
), "First elements of logits not as expected"
|
|
||||||
else:
|
else:
|
||||||
print("Predicted class idx:", logits.argmax(-1).item())
|
print("Predicted class idx:", logits.argmax(-1).item())
|
||||||
assert torch.allclose(
|
|
||||||
logits[0, :3], expected_logits, atol=1e-3
|
if not torch.allclose(logits[0, :3], expected_logits, atol=1e-3):
|
||||||
), "First elements of logits not as expected"
|
raise ValueError("First elements of logits not as expected")
|
||||||
assert logits.argmax(-1).item() == expected_class_idx, "Predicted class index not as expected"
|
if logits.argmax(-1).item() != expected_class_idx:
|
||||||
|
raise ValueError("Predicted class index not as expected")
|
||||||
|
|
||||||
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
||||||
print(f"Saving model to {pytorch_dump_folder_path}")
|
print(f"Saving model to {pytorch_dump_folder_path}")
|
||||||
|
@ -169,7 +169,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
|||||||
try:
|
try:
|
||||||
if pointer.shape != array.shape:
|
if pointer.shape != array.shape:
|
||||||
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
|
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
|
||||||
except AssertionError as e:
|
except ValueError as e:
|
||||||
e.args += (pointer.shape, array.shape)
|
e.args += (pointer.shape, array.shape)
|
||||||
raise
|
raise
|
||||||
logger.info(f"Initialize PyTorch weight {name}")
|
logger.info(f"Initialize PyTorch weight {name}")
|
||||||
|
@ -227,7 +227,7 @@ def load_tf_weights_in_big_bird(model, tf_checkpoint_path, is_trivia_qa=False):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched of {txt_name}."
|
f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched of {txt_name}."
|
||||||
)
|
)
|
||||||
except AssertionError as e:
|
except ValueError as e:
|
||||||
e.args += (pointer.shape, array.shape)
|
e.args += (pointer.shape, array.shape)
|
||||||
raise
|
raise
|
||||||
pt_weight_name = ".".join(pt_name)
|
pt_weight_name = ".".join(pt_name)
|
||||||
|
@ -96,10 +96,9 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
|
|||||||
num = int(scope_names[1])
|
num = int(scope_names[1])
|
||||||
pointer = pointer[num]
|
pointer = pointer[num]
|
||||||
try:
|
try:
|
||||||
assert (
|
if pointer.shape != array.shape:
|
||||||
pointer.shape == array.shape
|
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
|
||||||
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
|
except ValueError as e:
|
||||||
except AssertionError as e:
|
|
||||||
e.args += (pointer.shape, array.shape)
|
e.args += (pointer.shape, array.shape)
|
||||||
raise
|
raise
|
||||||
logger.info(f"Initialize PyTorch weight {name}")
|
logger.info(f"Initialize PyTorch weight {name}")
|
||||||
|
@ -135,7 +135,7 @@ def load_tf_weights_in_electra(model, config, tf_checkpoint_path, discriminator_
|
|||||||
try:
|
try:
|
||||||
if pointer.shape != array.shape:
|
if pointer.shape != array.shape:
|
||||||
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
|
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
|
||||||
except AssertionError as e:
|
except ValueError as e:
|
||||||
e.args += (pointer.shape, array.shape)
|
e.args += (pointer.shape, array.shape)
|
||||||
raise
|
raise
|
||||||
print(f"Initialize PyTorch weight {name}", original_name)
|
print(f"Initialize PyTorch weight {name}", original_name)
|
||||||
|
@ -110,10 +110,9 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
|
|||||||
num = int(scope_names[1])
|
num = int(scope_names[1])
|
||||||
pointer = pointer[num]
|
pointer = pointer[num]
|
||||||
try:
|
try:
|
||||||
assert (
|
if pointer.shape != array.shape:
|
||||||
pointer.shape == array.shape
|
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
|
||||||
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
|
except ValueError as e:
|
||||||
except AssertionError as e:
|
|
||||||
e.args += (pointer.shape, array.shape)
|
e.args += (pointer.shape, array.shape)
|
||||||
raise
|
raise
|
||||||
logger.info(f"Initialize PyTorch weight {name}")
|
logger.info(f"Initialize PyTorch weight {name}")
|
||||||
|
@ -146,7 +146,7 @@ def load_tf_weights_in_roc_bert(model, config, tf_checkpoint_path):
|
|||||||
try:
|
try:
|
||||||
if pointer.shape != array.shape:
|
if pointer.shape != array.shape:
|
||||||
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
|
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
|
||||||
except AssertionError as e:
|
except ValueError as e:
|
||||||
e.args += (pointer.shape, array.shape)
|
e.args += (pointer.shape, array.shape)
|
||||||
raise
|
raise
|
||||||
logger.info(f"Initialize PyTorch weight {name}")
|
logger.info(f"Initialize PyTorch weight {name}")
|
||||||
|
Loading…
Reference in New Issue
Block a user