Add ElectraForCausalLM -> Enable Electra encoder-decoder model (#14729)

* Add ElectraForCausalLM and cover some basic tests & need to fix a few tests

* Fix bugs

* make style

* make fix-copies

* Update doc

* Change docstring to markdown format

* Remove redundant update_keys_to_ignore
This commit is contained in:
Daniel Stancl 2021-12-27 12:37:52 +01:00 committed by GitHub
parent b058490ceb
commit 501307b58b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 323 additions and 34 deletions

View File

@ -83,6 +83,11 @@ This model was contributed by [lysandre](https://huggingface.co/lysandre). The o
[[autodoc]] ElectraForPreTraining
- forward
## ElectraForCausalLM
[[autodoc]] ElectraForCausalLM
- forward
## ElectraForMaskedLM
[[autodoc]] ElectraForMaskedLM

View File

@ -885,6 +885,7 @@ if is_torch_available():
_import_structure["models.electra"].extend(
[
"ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST",
"ElectraForCausalLM",
"ElectraForMaskedLM",
"ElectraForMultipleChoice",
"ElectraForPreTraining",
@ -2830,6 +2831,7 @@ if TYPE_CHECKING:
)
from .models.electra import (
ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
ElectraForCausalLM,
ElectraForMaskedLM,
ElectraForMultipleChoice,
ElectraForPreTraining,

View File

@ -218,6 +218,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("transfo-xl", "TransfoXLLMHeadModel"),
("xlnet", "XLNetLMHeadModel"),
("xlm", "XLMWithLMHeadModel"),
("electra", "ElectraForCausalLM"),
("ctrl", "CTRLLMHeadModel"),
("reformer", "ReformerModelWithLMHead"),
("bert-generation", "BertGenerationDecoder"),

View File

@ -32,6 +32,7 @@ if is_tokenizers_available():
if is_torch_available():
_import_structure["modeling_electra"] = [
"ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST",
"ElectraForCausalLM",
"ElectraForMaskedLM",
"ElectraForMultipleChoice",
"ElectraForPreTraining",
@ -79,6 +80,7 @@ if TYPE_CHECKING:
if is_torch_available():
from .modeling_electra import (
ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
ElectraForCausalLM,
ElectraForMaskedLM,
ElectraForMultipleChoice,
ElectraForPreTraining,

View File

@ -36,6 +36,7 @@ from ...file_utils import (
from ...modeling_outputs import (
BaseModelOutputWithCrossAttentions,
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
MaskedLMOutput,
MultipleChoiceModelOutput,
QuestionAnsweringModelOutput,
@ -846,6 +847,10 @@ class ElectraModel(ElectraPreTrainedModel):
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
@ -868,6 +873,9 @@ class ElectraModel(ElectraPreTrainedModel):
batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if attention_mask is None:
attention_mask = torch.ones(input_shape, device=device)
if token_type_ids is None:
@ -879,10 +887,26 @@ class ElectraModel(ElectraPreTrainedModel):
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
hidden_states = self.embeddings(
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
if hasattr(self, "embeddings_project"):
@ -892,6 +916,10 @@ class ElectraModel(ElectraPreTrainedModel):
hidden_states,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
@ -969,14 +997,14 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel):
discriminator_hidden_states = self.electra(
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
inputs_embeds,
output_attentions,
output_hidden_states,
return_dict,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = discriminator_hidden_states[0]
@ -1075,14 +1103,14 @@ class ElectraForPreTraining(ElectraPreTrainedModel):
discriminator_hidden_states = self.electra(
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
inputs_embeds,
output_attentions,
output_hidden_states,
return_dict,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
discriminator_sequence_output = discriminator_hidden_states[0]
@ -1166,14 +1194,14 @@ class ElectraForMaskedLM(ElectraPreTrainedModel):
generator_hidden_states = self.electra(
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
inputs_embeds,
output_attentions,
output_hidden_states,
return_dict,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
generator_sequence_output = generator_hidden_states[0]
@ -1247,14 +1275,14 @@ class ElectraForTokenClassification(ElectraPreTrainedModel):
discriminator_hidden_states = self.electra(
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
inputs_embeds,
output_attentions,
output_hidden_states,
return_dict,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
discriminator_sequence_output = discriminator_hidden_states[0]
@ -1481,3 +1509,152 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel):
hidden_states=discriminator_hidden_states.hidden_states,
attentions=discriminator_hidden_states.attentions,
)
@add_start_docstrings(
"""ELECTRA Model with a `language modeling` head on top for CLM fine-tuning. """, ELECTRA_START_DOCSTRING
)
class ElectraForCausalLM(ElectraPreTrainedModel):
def __init__(self, config):
super().__init__(config)
if not config.is_decoder:
logger.warning("If you want to use `ElectraLMHeadModel` as a standalone, add `is_decoder=True.`")
self.electra = ElectraModel(config)
self.generator_predictions = ElectraGeneratorPredictions(config)
self.generator_lm_head = nn.Linear(config.embedding_size, config.vocab_size)
self.init_weights()
def get_output_embeddings(self):
return self.generator_lm_head
def set_output_embeddings(self, new_embeddings):
self.generator_lm_head = new_embeddings
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
`[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
(those that don't have their past key value states given to this model) of shape `(batch_size, 1)`
instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up
decoding (see `past_key_values`).
Returns:
Example:
```python
>>> from transformers import ElectraTokenizer, ElectraForCausalLM, ElectraConfig
>>> import torch
>>> tokenizer = ElectraTokenizer.from_pretrained("google/electra-base-generator")
>>> config = ElectraConfig.from_pretrained("google/electra-base-generator")
>>> config.is_decoder = True
>>> model = ElectraForCausalLM.from_pretrained("google/electra-base-generator", config=config)
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> prediction_logits = outputs.logits
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None:
use_cache = False
outputs = self.electra(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
prediction_scores = self.generator_lm_head(self.generator_predictions(sequence_output))
lm_loss = None
if labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (prediction_scores,) + outputs[1:]
return ((lm_loss,) + output) if lm_loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=lm_loss,
logits=prediction_scores,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
# Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# cut decoder_input_ids if past is used
if past is not None:
input_ids = input_ids[:, -1:]
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
# Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM._reorder_cache
def _reorder_cache(self, past, beam_idx):
reordered_past = ()
for layer_past in past:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past

View File

@ -1924,6 +1924,18 @@ class DPRReader:
ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST = None
class ElectraForCausalLM:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
def forward(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ElectraForMaskedLM:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])

View File

@ -21,7 +21,7 @@ from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_torch_available():
@ -29,6 +29,7 @@ if is_torch_available():
from transformers import (
MODEL_FOR_PRETRAINING_MAPPING,
ElectraForCausalLM,
ElectraForMaskedLM,
ElectraForMultipleChoice,
ElectraForPreTraining,
@ -117,6 +118,34 @@ class ElectraModelTester:
initializer_range=self.initializer_range,
)
def prepare_config_and_inputs_for_decoder(self):
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
_,
) = self.prepare_config_and_inputs()
config.is_decoder = True
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
return (
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
)
def create_and_check_electra_model(
self,
config,
@ -136,6 +165,38 @@ class ElectraModelTester:
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_electra_model_as_decoder(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.add_cross_attention = True
model = ElectraModel(config)
model.to(torch_device)
model.eval()
result = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
)
result = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
encoder_hidden_states=encoder_hidden_states,
)
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_electra_for_masked_lm(
self,
config,
@ -153,6 +214,24 @@ class ElectraModelTester:
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_electra_for_causal_lm(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
model = ElectraForCausalLM(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_electra_for_token_classification(
self,
config,
@ -281,6 +360,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
ElectraModel,
ElectraForPreTraining,
ElectraForMaskedLM,
ElectraForCausalLM,
ElectraForMultipleChoice,
ElectraForTokenClassification,
ElectraForSequenceClassification,
@ -289,6 +369,8 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available()
else ()
)
all_generative_model_classes = (ElectraForCausalLM,) if is_torch_available() else ()
fx_ready_model_classes = all_model_classes
fx_dynamic_ready_model_classes = all_model_classes
@ -314,6 +396,10 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_electra_model(*config_and_inputs)
def test_electra_model_as_decoder(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_electra_model_as_decoder(*config_and_inputs)
def test_electra_model_various_embeddings(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
for type in ["absolute", "relative_key", "relative_key_query"]:
@ -350,6 +436,10 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
model = ElectraModel.from_pretrained(model_name)
self.assertIsNotNone(model)
def test_for_causal_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_electra_for_causal_lm(*config_and_inputs)
@require_torch
class ElectraModelIntegrationTest(unittest.TestCase):