From 98d5b72727637ee068fa86667fb80fc6f693b50c Mon Sep 17 00:00:00 2001
From: Thomas Wang <24695242+thomasw21@users.noreply.github.com>
Date: Wed, 8 Feb 2023 18:31:10 +0100
Subject: [PATCH] Update OPT conversion script to work for OPT-IML (#21519)

---
 ..._original_pytorch_checkpoint_to_pytorch.py | 21 +++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/src/transformers/models/opt/convert_opt_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/opt/convert_opt_original_pytorch_checkpoint_to_pytorch.py
index ec1749daeff..2a84641ce07 100644
--- a/src/transformers/models/opt/convert_opt_original_pytorch_checkpoint_to_pytorch.py
+++ b/src/transformers/models/opt/convert_opt_original_pytorch_checkpoint_to_pytorch.py
@@ -53,6 +53,27 @@ def load_checkpoint(checkpoint_path):
         if old_key in sd:
             sd[new_key] = sd.pop(old_key)
 
+    keys = list(sd.keys())
+    for key in keys:
+        if ".qkv_proj." in key:
+            value = sd[key]
+            # Split the fused QKV projection into separate Q, K and V weights.
+
+            q_name = key.replace(".qkv_proj.", ".q_proj.")
+            k_name = key.replace(".qkv_proj.", ".k_proj.")
+            v_name = key.replace(".qkv_proj.", ".v_proj.")
+
+            depth = value.shape[0]
+            assert depth % 3 == 0
+            # `SequeuceParallelTransformerBlock` stores the fused QKV weight in K, V, Q order despite the naming:
+            # https://cs.github.com/facebookresearch/metaseq/blob/51871bd73cd04c038f239ea2a26db1d7f6b37927/metaseq/modules/sequence_parallel_transformer_layer.py#L97
+            k, v, q = torch.split(value, depth // 3, dim=0)
+
+            sd[q_name] = q
+            sd[k_name] = k
+            sd[v_name] = v
+            del sd[key]
+
     return sd
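
For illustration (not part of the patch): a minimal standalone sketch of the split performed above, assuming a toy fused weight of shape (3 * hidden, hidden); the names `hidden` and `fused_qkv` are hypothetical and stand in for the real checkpoint tensors. It shows why `torch.split` along dim 0 with a chunk size of `depth // 3`, unpacked as (k, v, q), recovers the three projection weights.

import torch

# Toy dimensions for illustration only; real checkpoints use the model's hidden size.
hidden = 4
fused_qkv = torch.arange(3 * hidden * hidden, dtype=torch.float32).reshape(3 * hidden, hidden)

depth = fused_qkv.shape[0]
assert depth % 3 == 0
# The fused projection concatenates the blocks in K, V, Q order along dim 0,
# so splitting into thirds and unpacking as (k, v, q) recovers each weight.
k, v, q = torch.split(fused_qkv, depth // 3, dim=0)

assert torch.equal(k, fused_qkv[:hidden])
assert torch.equal(v, fused_qkv[hidden : 2 * hidden])
assert torch.equal(q, fused_qkv[2 * hidden :])
print(k.shape, v.shape, q.shape)  # each block is (hidden, hidden)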