fix error message (#12148)

This commit is contained in:
Suraj Patil 2021-06-14 18:42:18 +05:30 committed by GitHub
parent 9de62cfbce
commit fe3576488a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -106,7 +106,8 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
if pt_tuple_key in random_flax_state_dict:
if pt_tensor.shape != random_flax_state_dict[pt_tuple_key].shape:
raise ValueError(
"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape {random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
f"{random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
)
# also add unexpected weight so that warning is thrown