mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
refactor: replace assert
with ValueError
(#14970)
This commit is contained in:
parent
600496fa50
commit
04cddaf402
@ -114,13 +114,8 @@ def load_tf_weights_in_bert_generation(
|
||||
else:
|
||||
model_pointer = model_pointer.weight
|
||||
|
||||
try:
|
||||
assert (
|
||||
model_pointer.shape == array.shape
|
||||
), f"Pointer shape {model_pointer.shape} and array shape {array.shape} mismatched"
|
||||
except AssertionError as e:
|
||||
e.args += (model_pointer.shape, array.shape)
|
||||
raise
|
||||
if model_pointer.shape != array.shape:
|
||||
raise ValueError(f"Pointer shape {model_pointer.shape} and array shape {array.shape} mismatched")
|
||||
logger.info(f"Initialize PyTorch weight {key}")
|
||||
|
||||
model_pointer.data = torch.from_numpy(array.astype(np.float32))
|
||||
|
Loading…
Reference in New Issue
Block a user