Mirror of https://github.com/huggingface/transformers.git
Synced 2025-08-03 03:31:05 +06:00
Add TAPAS MLM-only models (#13408)
* Add conversion of TapasForMaskedLM

* Add copied from statements
This commit is contained in:
parent 2dd975b235
commit 5642a555ae
src/transformers/__init__.py
@@ -1132,6 +1132,7 @@ if is_torch_available():
+            "TapasForMaskedLM",
             "TapasForSequenceClassification",
             "TapasModel",
             "TapasPreTrainedModel",
             "load_tf_weights_in_tapas",
         ]
     )
     _import_structure["models.transfo_xl"].extend(
@@ -2771,6 +2772,7 @@ if TYPE_CHECKING:
+        TapasForMaskedLM,
         TapasForSequenceClassification,
         TapasModel,
         TapasPreTrainedModel,
         load_tf_weights_in_tapas,
     )
     from .models.transfo_xl import (
         TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
src/transformers/models/tapas/__init__.py
@@ -34,6 +34,7 @@ if is_torch_available():
+        "TapasForMaskedLM",
         "TapasForSequenceClassification",
         "TapasModel",
         "TapasPreTrainedModel",
         "load_tf_weights_in_tapas",
     ]
@@ -49,6 +50,7 @@ if TYPE_CHECKING:
+            TapasForMaskedLM,
             TapasForSequenceClassification,
             TapasModel,
             TapasPreTrainedModel,
             load_tf_weights_in_tapas,
         )

 else:
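Note: with the four registration hunks above applied, the new head is importable both lazily from the top-level package and directly from the subpackage. A minimal smoke test of that wiring, assuming a PyTorch install (the symbol sits behind is_torch_available()):

    # Both import paths should resolve to the same class object.
    from transformers import TapasForMaskedLM
    from transformers.models.tapas import TapasForMaskedLM as SubpackageTapasForMaskedLM

    assert TapasForMaskedLM is SubpackageTapasForMaskedLM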
src/transformers/models/tapas/convert_tapas_original_tf_checkpoint_to_pytorch.py
@@ -81,22 +81,21 @@ def convert_tf_checkpoint_to_pytorch(
         model = TapasForMaskedLM(config=config)
     elif task == "INTERMEDIATE_PRETRAINING":
         model = TapasModel(config=config)
     else:
         raise ValueError(f"Task {task} not supported.")

     print(f"Building PyTorch model from configuration: {config}")

     # Load weights from tf checkpoint
     load_tf_weights_in_tapas(model, config, tf_checkpoint_path)

     # Save pytorch-model (weights and configuration)
     print(f"Save PyTorch model to {pytorch_dump_path}")
-    model.save_pretrained(pytorch_dump_path[:-17])
+    model.save_pretrained(pytorch_dump_path)

     # Save tokenizer files
-    dir_name = r"C:\Users\niels.rogge\Documents\Python projecten\tensorflow\Tensorflow models\SQA\Base\tapas_sqa_inter_masklm_base_reset"
-    tokenizer = TapasTokenizer(vocab_file=dir_name + r"\vocab.txt", model_max_length=512)
     print(f"Save tokenizer files to {pytorch_dump_path}")
-    tokenizer.save_pretrained(pytorch_dump_path[:-17])
+    tokenizer = TapasTokenizer(vocab_file=tf_checkpoint_path[:-10] + "vocab.txt", model_max_length=512)
+    tokenizer.save_pretrained(pytorch_dump_path)

     print("Used relative position embeddings:", model.config.reset_position_index_per_cell)
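Two details in this hunk: the hard-coded Windows path and the [:-17] suffix hacks are gone, and vocab.txt is now derived from the checkpoint location, since tf_checkpoint_path[:-10] strips a trailing "model.ckpt" (10 characters). A hedged invocation sketch; the keyword names follow the truncated signature above plus the script's argparse options (they are assumptions here), and all paths are placeholders:

    from transformers.models.tapas.convert_tapas_original_tf_checkpoint_to_pytorch import (
        convert_tf_checkpoint_to_pytorch,
    )

    # vocab.txt must sit next to model.ckpt because of the [:-10] slice above.
    convert_tf_checkpoint_to_pytorch(
        task="MLM",
        reset_position_index_per_cell=True,
        tf_checkpoint_path="/path/to/tapas_masklm_base_reset/model.ckpt",
        tapas_config_file="/path/to/tapas_masklm_base_reset/bert_config.json",
        pytorch_dump_path="/path/to/pytorch_dump",
    )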
src/transformers/models/tapas/modeling_tapas.py
@@ -192,6 +192,11 @@ def load_tf_weights_in_tapas(model, config, tf_checkpoint_path):
         if any(n in ["output_bias", "output_weights", "output_bias_cls", "output_weights_cls"] for n in name):
             logger.info(f"Skipping {'/'.join(name)}")
             continue
+        # in case the model is TapasForMaskedLM, we skip the pooler
+        if isinstance(model, TapasForMaskedLM):
+            if any(n in ["pooler"] for n in name):
+                logger.info(f"Skipping {'/'.join(name)}")
+                continue
         # if first scope name starts with "bert", change it to "tapas"
         if name[0] == "bert":
             name[0] = "tapas"
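The pooler skip is needed because TapasForMaskedLM builds its base model with add_pooling_layer=False (see the constructor hunk further down), so pooler variables have no destination parameter. A toy illustration of the membership test, using a made-up TF variable name:

    # TF variable names are split on "/" into scope lists before this loop runs.
    name = "bert/pooler/dense/kernel".split("/")
    print(any(n in ["pooler"] for n in name))  # True -> the variable is skipped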
@@ -207,7 +212,10 @@ def load_tf_weights_in_tapas(model, config, tf_checkpoint_path):
                 pointer = getattr(pointer, "bias")
             # cell selection heads
             elif scope_names[0] == "output_bias":
-                pointer = getattr(pointer, "output_bias")
+                if not isinstance(model, TapasForMaskedLM):
+                    pointer = getattr(pointer, "output_bias")
+                else:
+                    pointer = getattr(pointer, "bias")
             elif scope_names[0] == "output_weights":
                 pointer = getattr(pointer, "output_weights")
             elif scope_names[0] == "column_output_bias":
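For an MLM checkpoint, the output_bias variable under cls/predictions has to land on the LM head's plain bias parameter rather than on a cell-selection head, hence the isinstance branch. A toy sketch of the getattr-based pointer walk the surrounding loop performs (stand-in module, not TAPAS code):

    import torch.nn as nn

    decoder = nn.Linear(8, 4)           # stands in for the LM head decoder
    pointer = decoder
    pointer = getattr(pointer, "bias")  # the MLM branch above resolves here
    print(pointer.shape)                # torch.Size([4])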
@@ -697,6 +705,56 @@ class TapasPooler(nn.Module):
         return pooled_output


+# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->Tapas
+class TapasPredictionHeadTransform(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        if isinstance(config.hidden_act, str):
+            self.transform_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.transform_act_fn = config.hidden_act
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+    def forward(self, hidden_states):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.transform_act_fn(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states)
+        return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->Tapas
+class TapasLMPredictionHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.transform = TapasPredictionHeadTransform(config)
+
+        # The output weights are the same as the input embeddings, but there is
+        # an output-only bias for each token.
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+        self.decoder.bias = self.bias
+
+    def forward(self, hidden_states):
+        hidden_states = self.transform(hidden_states)
+        hidden_states = self.decoder(hidden_states)
+        return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->Tapas
+class TapasOnlyMLMHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.predictions = TapasLMPredictionHead(config)
+
+    def forward(self, sequence_output):
+        prediction_scores = self.predictions(sequence_output)
+        return prediction_scores
+
+
 class TapasPreTrainedModel(PreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
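The three classes mirror BERT's MLM head stack: a transform (dense + activation + LayerNorm), a decoder projecting to vocabulary size with a separately stored bias, and a thin wrapper. A shape sketch with a deliberately tiny, hypothetical config; the head classes are module-level, so they can be imported from modeling_tapas directly:

    import torch
    from transformers import TapasConfig
    from transformers.models.tapas.modeling_tapas import TapasOnlyMLMHead

    config = TapasConfig(hidden_size=64, vocab_size=100)
    head = TapasOnlyMLMHead(config)
    sequence_output = torch.randn(2, 8, 64)  # (batch, seq_len, hidden)
    print(head(sequence_output).shape)       # torch.Size([2, 8, 100])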
@@ -947,15 +1005,15 @@ class TapasForMaskedLM(TapasPreTrainedModel):
         super().__init__(config)

         self.tapas = TapasModel(config, add_pooling_layer=False)
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
+        self.cls = TapasOnlyMLMHead(config)

         self.init_weights()

     def get_output_embeddings(self):
-        return self.lm_head
+        return self.cls.predictions.decoder

-    def set_output_embeddings(self, word_embeddings):
-        self.lm_head = word_embeddings
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings

     @add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
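Routing get_output_embeddings/set_output_embeddings through cls.predictions.decoder keeps the library's weight tying and resize_token_embeddings working with the new head: init_weights() ties the decoder to the input embeddings because tie_word_embeddings defaults to True. A quick check on a randomly initialized model, using a small hypothetical config (TAPAS also needed torch-scatter at the time):

    from transformers import TapasConfig, TapasForMaskedLM

    config = TapasConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=4, intermediate_size=128)
    model = TapasForMaskedLM(config)
    # Tying leaves the same Parameter object on both ends:
    assert model.cls.predictions.decoder.weight is model.tapas.embeddings.word_embeddings.weight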
@@ -1020,7 +1078,7 @@ class TapasForMaskedLM(TapasPreTrainedModel):
         )

         sequence_output = outputs[0]
-        prediction_scores = self.lm_head(sequence_output)
+        prediction_scores = self.cls(sequence_output)

         masked_lm_loss = None
         if labels is not None:
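With the head swapped in, TapasForMaskedLM exposes the standard masked-LM interface. A usage sketch on a randomly initialized model, shapes only (a converted checkpoint would go through from_pretrained instead; assumes pandas, torch and, at the time, torch-scatter are installed):

    import pandas as pd
    import torch
    from transformers import TapasConfig, TapasForMaskedLM, TapasTokenizer

    tokenizer = TapasTokenizer.from_pretrained("google/tapas-base")
    model = TapasForMaskedLM(TapasConfig())  # random weights, shape check only

    table = pd.DataFrame({"Actors": ["Brad Pitt", "Leonardo Di Caprio"], "Age": ["56", "45"]})
    inputs = tokenizer(table=table, queries="How old is [MASK] Pitt?", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    print(outputs.logits.shape)  # (1, sequence_length, vocab_size)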
src/transformers/utils/dummy_pt_objects.py
@@ -3406,6 +3406,10 @@ class TapasPreTrainedModel:
         requires_backends(cls, ["torch"])


+class TapasForMaskedLM:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 def load_tf_weights_in_tapas(*args, **kwargs):
     requires_backends(load_tf_weights_in_tapas, ["torch"])


 TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST = None
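The dummy object keeps `from transformers import TapasForMaskedLM` importable in a torch-less environment while failing loudly on use; with torch installed, requires_backends is a no-op. A behavior sketch importing the dummy module directly:

    from transformers.utils.dummy_pt_objects import TapasForMaskedLM

    try:
        TapasForMaskedLM()  # raises ImportError when PyTorch is missing
    except ImportError as err:
        print(err)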