Add TAPAS MLM-only models (#13408)

* Add conversion of TapasForMaskedLM

* Add copied from statements
NielsRogge 2021-09-06 19:19:30 +02:00 committed by GitHub
parent 2dd975b235
commit 5642a555ae
5 changed files with 77 additions and 12 deletions

src/transformers/__init__.py

@@ -1132,6 +1132,7 @@ if is_torch_available():
"TapasForSequenceClassification",
"TapasModel",
"TapasPreTrainedModel",
"load_tf_weights_in_tapas",
]
)
_import_structure["models.transfo_xl"].extend(
@@ -2771,6 +2772,7 @@ if TYPE_CHECKING:
TapasForSequenceClassification,
TapasModel,
TapasPreTrainedModel,
load_tf_weights_in_tapas,
)
from .models.transfo_xl import (
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,

src/transformers/models/tapas/__init__.py

@@ -34,6 +34,7 @@ if is_torch_available():
"TapasForSequenceClassification",
"TapasModel",
"TapasPreTrainedModel",
"load_tf_weights_in_tapas",
]
@@ -49,6 +50,7 @@ if TYPE_CHECKING:
TapasForSequenceClassification,
TapasModel,
TapasPreTrainedModel,
load_tf_weights_in_tapas,
)
else:
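The two __init__.py changes above make load_tf_weights_in_tapas part of the public API, so a TensorFlow checkpoint can be loaded into a freshly initialised model directly, without going through the conversion script. A minimal sketch (the checkpoint and config paths are placeholders, and TensorFlow must be installed for the weight loading to work):

from transformers import TapasConfig, TapasForMaskedLM, load_tf_weights_in_tapas

config = TapasConfig.from_json_file("tapas_masklm_base/tapas_config.json")  # placeholder path
model = TapasForMaskedLM(config=config)

# Copy the original TensorFlow variables into the PyTorch module in place.
load_tf_weights_in_tapas(model, config, "tapas_masklm_base/model.ckpt")  # placeholder path

model.save_pretrained("tapas_masklm_base_pt")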

src/transformers/models/tapas/convert_tapas_original_tf_checkpoint_to_pytorch.py

@@ -81,22 +81,21 @@ def convert_tf_checkpoint_to_pytorch(
model = TapasForMaskedLM(config=config)
elif task == "INTERMEDIATE_PRETRAINING":
model = TapasModel(config=config)
else:
raise ValueError(f"Task {task} not supported.")
print(f"Building PyTorch model from configuration: {config}")
# Load weights from tf checkpoint
load_tf_weights_in_tapas(model, config, tf_checkpoint_path)
# Save pytorch-model (weights and configuration)
print(f"Save PyTorch model to {pytorch_dump_path}")
model.save_pretrained(pytorch_dump_path[:-17])
model.save_pretrained(pytorch_dump_path)
# Save tokenizer files
dir_name = r"C:\Users\niels.rogge\Documents\Python projecten\tensorflow\Tensorflow models\SQA\Base\tapas_sqa_inter_masklm_base_reset"
tokenizer = TapasTokenizer(vocab_file=dir_name + r"\vocab.txt", model_max_length=512)
print(f"Save tokenizer files to {pytorch_dump_path}")
tokenizer.save_pretrained(pytorch_dump_path[:-17])
tokenizer = TapasTokenizer(vocab_file=tf_checkpoint_path[:-10] + "vocab.txt", model_max_length=512)
tokenizer.save_pretrained(pytorch_dump_path)
print("Used relative position embeddings:", model.config.reset_position_index_per_cell)
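The tokenizer is now built from the vocab.txt sitting next to the TensorFlow checkpoint instead of a hard-coded local path. The [:-10] slice assumes the checkpoint path ends with the 10-character file name "model.ckpt", so stripping it leaves the checkpoint directory. A short sketch of that assumption (the path is a placeholder):

# "model.ckpt" is 10 characters, so the slice yields the containing directory.
tf_checkpoint_path = "tapas_masklm_base/model.ckpt"  # placeholder path
vocab_file = tf_checkpoint_path[:-10] + "vocab.txt"  # -> "tapas_masklm_base/vocab.txt"

A more path-agnostic variant would be os.path.join(os.path.dirname(tf_checkpoint_path), "vocab.txt"), but the slice matches the model.ckpt naming used by the original TAPAS checkpoints.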

src/transformers/models/tapas/modeling_tapas.py

@@ -192,6 +192,11 @@ def load_tf_weights_in_tapas(model, config, tf_checkpoint_path):
if any(n in ["output_bias", "output_weights", "output_bias_cls", "output_weights_cls"] for n in name):
logger.info(f"Skipping {'/'.join(name)}")
continue
# in case the model is TapasForMaskedLM, we skip the pooler
if isinstance(model, TapasForMaskedLM):
if any(n in ["pooler"] for n in name):
logger.info(f"Skipping {'/'.join(name)}")
continue
# if first scope name starts with "bert", change it to "tapas"
if name[0] == "bert":
name[0] = "tapas"
@@ -207,7 +212,10 @@ def load_tf_weights_in_tapas(model, config, tf_checkpoint_path):
pointer = getattr(pointer, "bias")
# cell selection heads
elif scope_names[0] == "output_bias":
pointer = getattr(pointer, "output_bias")
if not isinstance(model, TapasForMaskedLM):
pointer = getattr(pointer, "output_bias")
else:
pointer = getattr(pointer, "bias")
elif scope_names[0] == "output_weights":
pointer = getattr(pointer, "output_weights")
elif scope_names[0] == "column_output_bias":
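For an MLM model, the TensorFlow variable behind the "output_bias" scope is now routed to the vocabulary-sized bias of the prediction head rather than to the cell-selection head. A quick sketch of where that parameter lives (randomly initialised model, default config):

from transformers import TapasConfig, TapasForMaskedLM

model = TapasForMaskedLM(TapasConfig())
# The new branch effectively resolves to getattr(model.cls.predictions, "bias"),
# a parameter of shape [vocab_size].
print(model.cls.predictions.bias.shape)  # torch.Size([30522]) with the default config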
@@ -697,6 +705,56 @@ class TapasPooler(nn.Module):
return pooled_output
# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->Tapas
class TapasPredictionHeadTransform(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
if isinstance(config.hidden_act, str):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->Tapas
class TapasLMPredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
self.transform = TapasPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->Tapas
class TapasOnlyMLMHead(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = TapasLMPredictionHead(config)
def forward(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
return prediction_scores
class TapasPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
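The self.decoder.bias = self.bias link in TapasLMPredictionHead, together with get_output_embeddings returning cls.predictions.decoder (see the next hunk), lets the generic PreTrainedModel machinery tie the decoder weight to the input word embeddings and keep the bias in sync when the vocabulary is resized. A small sanity-check sketch (randomly initialised model):

from transformers import TapasConfig, TapasForMaskedLM

model = TapasForMaskedLM(TapasConfig())

# tie_weights() makes the decoder share storage with the input word embeddings.
assert model.cls.predictions.decoder.weight.data_ptr() == model.tapas.embeddings.word_embeddings.weight.data_ptr()

# decoder.bias and the head's own bias are the same Parameter, so resizing the
# vocabulary resizes both together.
assert model.cls.predictions.decoder.bias is model.cls.predictions.bias
model.resize_token_embeddings(model.config.vocab_size + 8)
print(model.cls.predictions.decoder.bias.shape)  # vocab_size + 8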
@@ -947,15 +1005,15 @@ class TapasForMaskedLM(TapasPreTrainedModel):
super().__init__(config)
self.tapas = TapasModel(config, add_pooling_layer=False)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
self.cls = TapasOnlyMLMHead(config)
self.init_weights()
def get_output_embeddings(self):
return self.lm_head
return self.cls.predictions.decoder
def set_output_embeddings(self, word_embeddings):
self.lm_head = word_embeddings
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
@add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
@@ -1020,7 +1078,7 @@ class TapasForMaskedLM(TapasPreTrainedModel):
)
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)
prediction_scores = self.cls(sequence_output)
masked_lm_loss = None
if labels is not None:
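With the standard cls MLM head in place, TapasForMaskedLM can be used like any other masked-LM model, except that the tokenizer also expects a table. A usage sketch; the checkpoint name below is an assumption, and any TAPAS MLM checkpoint produced by the conversion script above would do:

import pandas as pd
import torch
from transformers import TapasForMaskedLM, TapasTokenizer

model_name = "google/tapas-base-masklm"  # assumed checkpoint name
tokenizer = TapasTokenizer.from_pretrained(model_name)
model = TapasForMaskedLM.from_pretrained(model_name)

# TAPAS tokenizers take a table (all values as strings) together with the query.
table = pd.DataFrame({"City": ["Paris", "Berlin"], "Country": ["France", "Germany"]})
inputs = tokenizer(table=table, queries="Paris is the capital of [MASK].", return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

# Decode the prediction at the [MASK] position.
mask_positions = (inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
predicted_ids = logits[0, mask_positions].argmax(dim=-1)
print(tokenizer.decode(predicted_ids))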

src/transformers/utils/dummy_pt_objects.py

@@ -3406,6 +3406,10 @@ class TapasPreTrainedModel:
requires_backends(cls, ["torch"])
def load_tf_weights_in_tapas(*args, **kwargs):
requires_backends(load_tf_weights_in_tapas, ["torch"])
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST = None