Extend Script to enable conversion of Encoder Only T5x Models to Pytorch (#20907)

* add converter for t5x_retrieval model

* update args

* Update src/transformers/models/t5/convert_t5x_checkpoint_to_pytorch.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* style  editing -> convert t5x to pytorch

* make style

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Ogundepo Odunayo 2023-01-23 08:41:43 -05:00 committed by GitHub
parent 91ff7efeeb
commit 96b2b2de12
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -35,7 +35,7 @@ import torch
from flax import traverse_util
from t5x import checkpoints
from transformers import T5Config, T5ForConditionalGeneration
from transformers import T5Config, T5EncoderModel, T5ForConditionalGeneration
from transformers.utils import logging
@ -69,7 +69,7 @@ def t5x_layer_norm_lookup(params, i, prefix, layer_name):
return params[f"{prefix}/layers_{i}/{layer_name}/scale"]
def convert_t5x_to_pytorch(variables: dict, *, num_layers: int):
def convert_t5x_to_pytorch(variables: dict, *, num_layers: int, is_encoder_only: bool):
"""Converts the parameters from T5X-Flax to Transformers-PyTorch."""
old = traverse_util.flatten_dict(variables["target"])
old = {"/".join(k): v for k, v in old.items()}
@ -110,50 +110,51 @@ def convert_t5x_to_pytorch(variables: dict, *, num_layers: int):
].T
new["encoder.final_layer_norm.weight"] = old["encoder/encoder_norm/scale"]
# Decoder.
for i in range(num_layers):
# Block i, layer 0 (Self Attention).
layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_self_attention_layer_norm")
k, o, q, v = t5x_attention_lookup(old, i, "decoder", "self_attention")
new[f"decoder.block.{i}.layer.0.layer_norm.weight"] = layer_norm
new[f"decoder.block.{i}.layer.0.SelfAttention.k.weight"] = k.T
new[f"decoder.block.{i}.layer.0.SelfAttention.o.weight"] = o.T
new[f"decoder.block.{i}.layer.0.SelfAttention.q.weight"] = q.T
new[f"decoder.block.{i}.layer.0.SelfAttention.v.weight"] = v.T
if not is_encoder_only:
# Decoder.
for i in range(num_layers):
# Block i, layer 0 (Self Attention).
layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_self_attention_layer_norm")
k, o, q, v = t5x_attention_lookup(old, i, "decoder", "self_attention")
new[f"decoder.block.{i}.layer.0.layer_norm.weight"] = layer_norm
new[f"decoder.block.{i}.layer.0.SelfAttention.k.weight"] = k.T
new[f"decoder.block.{i}.layer.0.SelfAttention.o.weight"] = o.T
new[f"decoder.block.{i}.layer.0.SelfAttention.q.weight"] = q.T
new[f"decoder.block.{i}.layer.0.SelfAttention.v.weight"] = v.T
# Block i, layer 1 (Cross Attention).
layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_cross_attention_layer_norm")
k, o, q, v = t5x_attention_lookup(old, i, "decoder", "encoder_decoder_attention")
new[f"decoder.block.{i}.layer.1.layer_norm.weight"] = layer_norm
new[f"decoder.block.{i}.layer.1.EncDecAttention.k.weight"] = k.T
new[f"decoder.block.{i}.layer.1.EncDecAttention.o.weight"] = o.T
new[f"decoder.block.{i}.layer.1.EncDecAttention.q.weight"] = q.T
new[f"decoder.block.{i}.layer.1.EncDecAttention.v.weight"] = v.T
# Block i, layer 1 (Cross Attention).
layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_cross_attention_layer_norm")
k, o, q, v = t5x_attention_lookup(old, i, "decoder", "encoder_decoder_attention")
new[f"decoder.block.{i}.layer.1.layer_norm.weight"] = layer_norm
new[f"decoder.block.{i}.layer.1.EncDecAttention.k.weight"] = k.T
new[f"decoder.block.{i}.layer.1.EncDecAttention.o.weight"] = o.T
new[f"decoder.block.{i}.layer.1.EncDecAttention.q.weight"] = q.T
new[f"decoder.block.{i}.layer.1.EncDecAttention.v.weight"] = v.T
# Block i, layer 2 (MLP).
layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_mlp_layer_norm")
wi, wo = t5x_mlp_lookup(old, i, "decoder", split_mlp_wi)
new[f"decoder.block.{i}.layer.2.layer_norm.weight"] = layer_norm
if split_mlp_wi:
new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_0.weight"] = wi[0].T
new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_1.weight"] = wi[1].T
else:
new[f"encoder.block.{i}.layer.2.DenseReluDense.wi.weight"] = wi.T
new[f"decoder.block.{i}.layer.2.DenseReluDense.wo.weight"] = wo.T
# Block i, layer 2 (MLP).
layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_mlp_layer_norm")
wi, wo = t5x_mlp_lookup(old, i, "decoder", split_mlp_wi)
new[f"decoder.block.{i}.layer.2.layer_norm.weight"] = layer_norm
if split_mlp_wi:
new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_0.weight"] = wi[0].T
new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_1.weight"] = wi[1].T
else:
new[f"encoder.block.{i}.layer.2.DenseReluDense.wi.weight"] = wi.T
new[f"decoder.block.{i}.layer.2.DenseReluDense.wo.weight"] = wo.T
new["decoder.final_layer_norm.weight"] = old["decoder/decoder_norm/scale"]
new["decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = old[
"decoder/relpos_bias/rel_embedding"
].T
new["decoder.final_layer_norm.weight"] = old["decoder/decoder_norm/scale"]
new["decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = old[
"decoder/relpos_bias/rel_embedding"
].T
# LM Head (only in v1.1 checkpoints, in v1.0 embeddings are used instead)
if "decoder/logits_dense/kernel" in old:
new["lm_head.weight"] = old["decoder/logits_dense/kernel"].T
# LM Head (only in v1.1 checkpoints, in v1.0 embeddings are used instead)
if "decoder/logits_dense/kernel" in old:
new["lm_head.weight"] = old["decoder/logits_dense/kernel"].T
return new
def make_state_dict(converted_params):
def make_state_dict(converted_params, is_encoder_only: bool):
"""Prepares a state dict for the PyTorch model."""
# Make a state dict with torch tensors.
state_dict = collections.OrderedDict([(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()])
@ -162,35 +163,41 @@ def make_state_dict(converted_params):
if "encoder.embed_tokens.weight" not in state_dict:
state_dict["encoder.embed_tokens.weight"] = state_dict["shared.weight"]
if "decoder.embed_tokens.weight" not in state_dict:
state_dict["decoder.embed_tokens.weight"] = state_dict["shared.weight"]
if not is_encoder_only:
if "decoder.embed_tokens.weight" not in state_dict:
state_dict["decoder.embed_tokens.weight"] = state_dict["shared.weight"]
if "lm_head.weight" not in state_dict: # For old 1.0 models.
print("Using shared word embeddings as lm_head.")
state_dict["lm_head.weight"] = state_dict["shared.weight"]
if "lm_head.weight" not in state_dict: # For old 1.0 models.
print("Using shared word embeddings as lm_head.")
state_dict["lm_head.weight"] = state_dict["shared.weight"]
return state_dict
def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path):
def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path, is_encoder_only):
"""Replaces the params in model witht the T5X converted params."""
variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
converted = convert_t5x_to_pytorch(variables, num_layers=config.num_layers)
state_dict = make_state_dict(converted)
converted = convert_t5x_to_pytorch(variables, num_layers=config.num_layers, is_encoder_only=is_encoder_only)
state_dict = make_state_dict(converted, is_encoder_only)
model.load_state_dict(state_dict, strict=True)
def convert_t5x_checkpoint_to_pytorch(t5x_checkpoint_path, config_file, pytorch_dump_path):
def convert_t5x_checkpoint_to_pytorch(
t5x_checkpoint_path, config_file, pytorch_dump_path, is_encoder_only: bool = False
):
"""Loads the config and model, converts the T5X checkpoint, and saves a PyTorch checkpoint."""
# Initialise PyTorch model
config = T5Config.from_json_file(config_file)
print(f"Building PyTorch model from configuration: {config}")
# Non-v1.1 checkpoints could also use T5Model, but this works for all.
# The v1.0 checkpoints will simply have an LM head that is the word embeddings.
model = T5ForConditionalGeneration(config)
if is_encoder_only:
model = T5EncoderModel(config)
else:
model = T5ForConditionalGeneration(config)
# Load weights from tf checkpoint
load_t5x_weights_in_t5(model, config, t5x_checkpoint_path)
load_t5x_weights_in_t5(model, config, t5x_checkpoint_path, is_encoder_only)
# Save pytorch-model
print(f"Save PyTorch model to {pytorch_dump_path}")
@ -217,5 +224,10 @@ if __name__ == "__main__":
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
parser.add_argument(
"--is_encoder_only", action="store_true", help="Check if the model is encoder-decoder model", default=False
)
args = parser.parse_args()
convert_t5x_checkpoint_to_pytorch(args.t5x_checkpoint_path, args.config_file, args.pytorch_dump_path)
convert_t5x_checkpoint_to_pytorch(
args.t5x_checkpoint_path, args.config_file, args.pytorch_dump_path, args.is_encoder_only
)