mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix convert_opt_original_pytorch_checkpoint_to_pytorch.py typo (#22526)
`load_checkpoint()` silently fails because `".qkj_proj." in key` is always `False`, but will eventually cause an error at `model.load_state_dict(state_dict)`.
This commit is contained in:
parent
a55a822adf
commit
9419f144ad
@ -55,9 +55,9 @@ def load_checkpoint(checkpoint_path):
|
||||
|
||||
keys = list(sd.keys())
|
||||
for key in keys:
|
||||
if ".qkj_proj." in key:
|
||||
if ".qkv_proj." in key:
|
||||
value = sd[key]
|
||||
# We split QKV in seperate Q,K,V
|
||||
# We split QKV in separate Q,K,V
|
||||
|
||||
q_name = key.replace(".qkv_proj.", ".q_proj.")
|
||||
k_name = key.replace(".qkv_proj.", ".k_proj.")
|
||||
|
Loading…
Reference in New Issue
Block a user