refactor: replace assert with ValueError (#14970)

This commit is contained in:
Jake Tae 2021-12-30 00:09:54 +09:00 committed by GitHub
parent 600496fa50
commit 04cddaf402
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -114,13 +114,8 @@ def load_tf_weights_in_bert_generation(
else:
model_pointer = model_pointer.weight
try:
assert (
model_pointer.shape == array.shape
), f"Pointer shape {model_pointer.shape} and array shape {array.shape} mismatched"
except AssertionError as e:
e.args += (model_pointer.shape, array.shape)
raise
if model_pointer.shape != array.shape:
raise ValueError(f"Pointer shape {model_pointer.shape} and array shape {array.shape} mismatched")
logger.info(f"Initialize PyTorch weight {key}")
model_pointer.data = torch.from_numpy(array.astype(np.float32))