From 04cddaf402591e9f5bdb5f116a111d829a0ce4f4 Mon Sep 17 00:00:00 2001 From: Jake Tae Date: Thu, 30 Dec 2021 00:09:54 +0900 Subject: [PATCH] refactor: replace `assert` with `ValueError` (#14970) --- .../models/bert_generation/modeling_bert_generation.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index 5a910af959a..631d73784bb 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -114,13 +114,8 @@ def load_tf_weights_in_bert_generation( else: model_pointer = model_pointer.weight - try: - assert ( - model_pointer.shape == array.shape - ), f"Pointer shape {model_pointer.shape} and array shape {array.shape} mismatched" - except AssertionError as e: - e.args += (model_pointer.shape, array.shape) - raise + if model_pointer.shape != array.shape: + raise ValueError(f"Pointer shape {model_pointer.shape} and array shape {array.shape} mismatched") logger.info(f"Initialize PyTorch weight {key}") model_pointer.data = torch.from_numpy(array.astype(np.float32))