Fix convert_opt_original_pytorch_checkpoint_to_pytorch.py typo (#22526)

`load_checkpoint()` silently fails because `".qkj_proj." in key` is always `False`, but will eventually cause an error at `model.load_state_dict(state_dict)`.
This commit is contained in:
larekrow 2023-04-03 22:06:52 +08:00 committed by GitHub
parent a55a822adf
commit 9419f144ad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -55,9 +55,9 @@ def load_checkpoint(checkpoint_path):
keys = list(sd.keys())
for key in keys:
if ".qkj_proj." in key:
if ".qkv_proj." in key:
value = sd[key]
# We split QKV in seperate Q,K,V
# We split QKV in separate Q,K,V
q_name = key.replace(".qkv_proj.", ".q_proj.")
k_name = key.replace(".qkv_proj.", ".k_proj.")