Update OPT conversion script to work for OPT-IML (#21519)

This commit is contained in:
Thomas Wang 2023-02-08 18:31:10 +01:00 committed by GitHub
parent fe616f35c8
commit 98d5b72727
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -53,6 +53,27 @@ def load_checkpoint(checkpoint_path):
if old_key in sd:
sd[new_key] = sd.pop(old_key)
keys = list(sd.keys())
for key in keys:
if ".qkj_proj." in key:
value = sd[key]
# We split QKV in seperate Q,K,V
q_name = key.replace(".qkv_proj.", ".q_proj.")
k_name = key.replace(".qkv_proj.", ".k_proj.")
v_name = key.replace(".qkv_proj.", ".v_proj.")
depth = value.shape[0]
assert depth % 3 == 0
# `SequeuceParallelTransformerBlock` has QKV weight is separated in K,V,Q despite the naming:
# https://cs.github.com/facebookresearch/metaseq/blob/51871bd73cd04c038f239ea2a26db1d7f6b37927/metaseq/modules/sequence_parallel_transformer_layer.py#L97
k, v, q = torch.split(value, depth // 3, dim=0)
sd[q_name] = q
sd[k_name] = k
sd[v_name] = v
del sd[key]
return sd