Update OPT conversion script to work for OPT-IML (#21519)
commit 98d5b72727
parent fe616f35c8
@@ -53,6 +53,27 @@ def load_checkpoint(checkpoint_path):
         if old_key in sd:
             sd[new_key] = sd.pop(old_key)
 
+    keys = list(sd.keys())
+    for key in keys:
+        if ".qkv_proj." in key:
+            value = sd[key]
+            # We split the fused QKV weight into separate Q, K, V weights
+
+            q_name = key.replace(".qkv_proj.", ".q_proj.")
+            k_name = key.replace(".qkv_proj.", ".k_proj.")
+            v_name = key.replace(".qkv_proj.", ".v_proj.")
+
+            depth = value.shape[0]
+            assert depth % 3 == 0
+            # `SequeuceParallelTransformerBlock` stores the QKV weight in K, V, Q order despite the naming:
+            # https://cs.github.com/facebookresearch/metaseq/blob/51871bd73cd04c038f239ea2a26db1d7f6b37927/metaseq/modules/sequence_parallel_transformer_layer.py#L97
+            k, v, q = torch.split(value, depth // 3, dim=0)
+
+            sd[q_name] = q
+            sd[k_name] = k
+            sd[v_name] = v
+            del sd[key]
+
     return sd
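Since the K, V, Q storage order is the easy part to get wrong, below is a minimal standalone sketch of the same split, using a hypothetical toy key name and shapes (not taken from the script), which checks that the split round-trips against a fused tensor built in metaseq's order:

import torch

# Hypothetical toy dimensions and key name, for illustration only.
hidden = 4
k_w = torch.randn(hidden, hidden)
v_w = torch.randn(hidden, hidden)
q_w = torch.randn(hidden, hidden)

# metaseq's `SequeuceParallelTransformerBlock` concatenates the fused weight as K, V, Q.
sd = {"decoder.layers.0.self_attn.qkv_proj.weight": torch.cat([k_w, v_w, q_w], dim=0)}

for key in list(sd.keys()):
    if ".qkv_proj." in key:
        value = sd.pop(key)
        depth = value.shape[0]
        assert depth % 3 == 0
        # Split in storage order (K, V, Q), then write back under per-projection names.
        k, v, q = torch.split(value, depth // 3, dim=0)
        sd[key.replace(".qkv_proj.", ".q_proj.")] = q
        sd[key.replace(".qkv_proj.", ".k_proj.")] = k
        sd[key.replace(".qkv_proj.", ".v_proj.")] = v

# Each split weight should match the tensor it was fused from.
assert torch.equal(sd["decoder.layers.0.self_attn.q_proj.weight"], q_w)
assert torch.equal(sd["decoder.layers.0.self_attn.k_proj.weight"], k_w)
assert torch.equal(sd["decoder.layers.0.self_attn.v_proj.weight"], v_w)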