Repository: https://github.com/huggingface/transformers.git
Extend Script to enable conversion of Encoder Only T5x Models to Pytorch (#20907)
* add converter for t5x_retrieval model
* update args
* Update src/transformers/models/t5/convert_t5x_checkpoint_to_pytorch.py
* style editing -> convert t5x to pytorch
* make style

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 91ff7efeeb
commit 96b2b2de12
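For reference, here is a minimal sketch of how the extended entry point can be driven from Python once this change is in place. Only the function name, its parameters, and the new is_encoder_only flag come from the diff below; the module import path mirrors the file being edited, and all checkpoint/config/output paths are placeholders.

# Usage sketch (assumes t5x and flax are installed; all paths are placeholders).
from transformers.models.t5.convert_t5x_checkpoint_to_pytorch import convert_t5x_checkpoint_to_pytorch

convert_t5x_checkpoint_to_pytorch(
    t5x_checkpoint_path="/path/to/t5x_checkpoint",  # placeholder: T5X checkpoint directory
    config_file="/path/to/config.json",             # placeholder: T5Config JSON file
    pytorch_dump_path="/path/to/pytorch_dump",      # placeholder: output directory
    is_encoder_only=True,                           # new flag: build a T5EncoderModel instead of T5ForConditionalGeneration
)

The same conversion is exposed on the command line through the --is_encoder_only flag added at the bottom of the diff.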
--- a/src/transformers/models/t5/convert_t5x_checkpoint_to_pytorch.py
+++ b/src/transformers/models/t5/convert_t5x_checkpoint_to_pytorch.py
@@ -35,7 +35,7 @@ import torch
 from flax import traverse_util
 from t5x import checkpoints
 
-from transformers import T5Config, T5ForConditionalGeneration
+from transformers import T5Config, T5EncoderModel, T5ForConditionalGeneration
 from transformers.utils import logging
 
 
@@ -69,7 +69,7 @@ def t5x_layer_norm_lookup(params, i, prefix, layer_name):
     return params[f"{prefix}/layers_{i}/{layer_name}/scale"]
 
 
-def convert_t5x_to_pytorch(variables: dict, *, num_layers: int):
+def convert_t5x_to_pytorch(variables: dict, *, num_layers: int, is_encoder_only: bool):
     """Converts the parameters from T5X-Flax to Transformers-PyTorch."""
     old = traverse_util.flatten_dict(variables["target"])
     old = {"/".join(k): v for k, v in old.items()}
@@ -110,50 +110,51 @@ def convert_t5x_to_pytorch(variables: dict, *, num_layers: int):
     ].T
     new["encoder.final_layer_norm.weight"] = old["encoder/encoder_norm/scale"]
 
-    # Decoder.
-    for i in range(num_layers):
-        # Block i, layer 0 (Self Attention).
-        layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_self_attention_layer_norm")
-        k, o, q, v = t5x_attention_lookup(old, i, "decoder", "self_attention")
-        new[f"decoder.block.{i}.layer.0.layer_norm.weight"] = layer_norm
-        new[f"decoder.block.{i}.layer.0.SelfAttention.k.weight"] = k.T
-        new[f"decoder.block.{i}.layer.0.SelfAttention.o.weight"] = o.T
-        new[f"decoder.block.{i}.layer.0.SelfAttention.q.weight"] = q.T
-        new[f"decoder.block.{i}.layer.0.SelfAttention.v.weight"] = v.T
+    if not is_encoder_only:
+        # Decoder.
+        for i in range(num_layers):
+            # Block i, layer 0 (Self Attention).
+            layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_self_attention_layer_norm")
+            k, o, q, v = t5x_attention_lookup(old, i, "decoder", "self_attention")
+            new[f"decoder.block.{i}.layer.0.layer_norm.weight"] = layer_norm
+            new[f"decoder.block.{i}.layer.0.SelfAttention.k.weight"] = k.T
+            new[f"decoder.block.{i}.layer.0.SelfAttention.o.weight"] = o.T
+            new[f"decoder.block.{i}.layer.0.SelfAttention.q.weight"] = q.T
+            new[f"decoder.block.{i}.layer.0.SelfAttention.v.weight"] = v.T
 
-        # Block i, layer 1 (Cross Attention).
-        layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_cross_attention_layer_norm")
-        k, o, q, v = t5x_attention_lookup(old, i, "decoder", "encoder_decoder_attention")
-        new[f"decoder.block.{i}.layer.1.layer_norm.weight"] = layer_norm
-        new[f"decoder.block.{i}.layer.1.EncDecAttention.k.weight"] = k.T
-        new[f"decoder.block.{i}.layer.1.EncDecAttention.o.weight"] = o.T
-        new[f"decoder.block.{i}.layer.1.EncDecAttention.q.weight"] = q.T
-        new[f"decoder.block.{i}.layer.1.EncDecAttention.v.weight"] = v.T
+            # Block i, layer 1 (Cross Attention).
+            layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_cross_attention_layer_norm")
+            k, o, q, v = t5x_attention_lookup(old, i, "decoder", "encoder_decoder_attention")
+            new[f"decoder.block.{i}.layer.1.layer_norm.weight"] = layer_norm
+            new[f"decoder.block.{i}.layer.1.EncDecAttention.k.weight"] = k.T
+            new[f"decoder.block.{i}.layer.1.EncDecAttention.o.weight"] = o.T
+            new[f"decoder.block.{i}.layer.1.EncDecAttention.q.weight"] = q.T
+            new[f"decoder.block.{i}.layer.1.EncDecAttention.v.weight"] = v.T
 
-        # Block i, layer 2 (MLP).
-        layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_mlp_layer_norm")
-        wi, wo = t5x_mlp_lookup(old, i, "decoder", split_mlp_wi)
-        new[f"decoder.block.{i}.layer.2.layer_norm.weight"] = layer_norm
-        if split_mlp_wi:
-            new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_0.weight"] = wi[0].T
-            new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_1.weight"] = wi[1].T
-        else:
-            new[f"encoder.block.{i}.layer.2.DenseReluDense.wi.weight"] = wi.T
-        new[f"decoder.block.{i}.layer.2.DenseReluDense.wo.weight"] = wo.T
+            # Block i, layer 2 (MLP).
+            layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_mlp_layer_norm")
+            wi, wo = t5x_mlp_lookup(old, i, "decoder", split_mlp_wi)
+            new[f"decoder.block.{i}.layer.2.layer_norm.weight"] = layer_norm
+            if split_mlp_wi:
+                new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_0.weight"] = wi[0].T
+                new[f"decoder.block.{i}.layer.2.DenseReluDense.wi_1.weight"] = wi[1].T
+            else:
+                new[f"encoder.block.{i}.layer.2.DenseReluDense.wi.weight"] = wi.T
+            new[f"decoder.block.{i}.layer.2.DenseReluDense.wo.weight"] = wo.T
 
-    new["decoder.final_layer_norm.weight"] = old["decoder/decoder_norm/scale"]
-    new["decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = old[
-        "decoder/relpos_bias/rel_embedding"
-    ].T
+        new["decoder.final_layer_norm.weight"] = old["decoder/decoder_norm/scale"]
+        new["decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"] = old[
+            "decoder/relpos_bias/rel_embedding"
+        ].T
 
-    # LM Head (only in v1.1 checkpoints, in v1.0 embeddings are used instead)
-    if "decoder/logits_dense/kernel" in old:
-        new["lm_head.weight"] = old["decoder/logits_dense/kernel"].T
+        # LM Head (only in v1.1 checkpoints, in v1.0 embeddings are used instead)
+        if "decoder/logits_dense/kernel" in old:
+            new["lm_head.weight"] = old["decoder/logits_dense/kernel"].T
 
     return new
 
 
-def make_state_dict(converted_params):
+def make_state_dict(converted_params, is_encoder_only: bool):
     """Prepares a state dict for the PyTorch model."""
     # Make a state dict with torch tensors.
     state_dict = collections.OrderedDict([(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()])
@@ -162,35 +163,41 @@ def make_state_dict(converted_params):
     if "encoder.embed_tokens.weight" not in state_dict:
         state_dict["encoder.embed_tokens.weight"] = state_dict["shared.weight"]
 
-    if "decoder.embed_tokens.weight" not in state_dict:
-        state_dict["decoder.embed_tokens.weight"] = state_dict["shared.weight"]
+    if not is_encoder_only:
+        if "decoder.embed_tokens.weight" not in state_dict:
+            state_dict["decoder.embed_tokens.weight"] = state_dict["shared.weight"]
 
-    if "lm_head.weight" not in state_dict:  # For old 1.0 models.
-        print("Using shared word embeddings as lm_head.")
-        state_dict["lm_head.weight"] = state_dict["shared.weight"]
+        if "lm_head.weight" not in state_dict:  # For old 1.0 models.
+            print("Using shared word embeddings as lm_head.")
+            state_dict["lm_head.weight"] = state_dict["shared.weight"]
 
     return state_dict
 
 
-def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path):
+def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path, is_encoder_only):
     """Replaces the params in model witht the T5X converted params."""
     variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
-    converted = convert_t5x_to_pytorch(variables, num_layers=config.num_layers)
-    state_dict = make_state_dict(converted)
+    converted = convert_t5x_to_pytorch(variables, num_layers=config.num_layers, is_encoder_only=is_encoder_only)
+    state_dict = make_state_dict(converted, is_encoder_only)
     model.load_state_dict(state_dict, strict=True)
 
 
-def convert_t5x_checkpoint_to_pytorch(t5x_checkpoint_path, config_file, pytorch_dump_path):
+def convert_t5x_checkpoint_to_pytorch(
+    t5x_checkpoint_path, config_file, pytorch_dump_path, is_encoder_only: bool = False
+):
     """Loads the config and model, converts the T5X checkpoint, and saves a PyTorch checkpoint."""
     # Initialise PyTorch model
     config = T5Config.from_json_file(config_file)
     print(f"Building PyTorch model from configuration: {config}")
     # Non-v1.1 checkpoints could also use T5Model, but this works for all.
     # The v1.0 checkpoints will simply have an LM head that is the word embeddings.
-    model = T5ForConditionalGeneration(config)
+    if is_encoder_only:
+        model = T5EncoderModel(config)
+    else:
+        model = T5ForConditionalGeneration(config)
 
     # Load weights from tf checkpoint
-    load_t5x_weights_in_t5(model, config, t5x_checkpoint_path)
+    load_t5x_weights_in_t5(model, config, t5x_checkpoint_path, is_encoder_only)
 
     # Save pytorch-model
     print(f"Save PyTorch model to {pytorch_dump_path}")
@@ -217,5 +224,10 @@ if __name__ == "__main__":
     parser.add_argument(
         "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
     )
+    parser.add_argument(
+        "--is_encoder_only", action="store_true", help="Check if the model is encoder-decoder model", default=False
+    )
     args = parser.parse_args()
-    convert_t5x_checkpoint_to_pytorch(args.t5x_checkpoint_path, args.config_file, args.pytorch_dump_path)
+    convert_t5x_checkpoint_to_pytorch(
+        args.t5x_checkpoint_path, args.config_file, args.pytorch_dump_path, args.is_encoder_only
+    )
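As a quick sanity check after conversion, the dumped weights can be reloaded as an encoder-only model. This is a sketch under the assumption that the script persists the model with save_pretrained (the save call itself sits outside the hunks shown above); the dump path is a placeholder and the tokenizer choice is arbitrary.

# Sketch: reload the converted encoder-only checkpoint and run a forward pass.
import torch
from transformers import T5EncoderModel, T5Tokenizer

model = T5EncoderModel.from_pretrained("/path/to/pytorch_dump")  # placeholder: output dir from the converter
tokenizer = T5Tokenizer.from_pretrained("t5-small")              # assumption: any T5 tokenizer matching the vocab

inputs = tokenizer("T5X encoder-only conversion check.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# T5EncoderModel returns only encoder hidden states: (batch, seq_len, d_model).
print(outputs.last_hidden_state.shape)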