Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Add ElectraForCausalLM -> Enable Electra encoder-decoder model (#14729)

* Add ElectraForCausalLM and cover some basic tests (a few tests still need fixing)
* Fix bugs
* make style
* make fix-copies
* Update doc
* Change docstring to markdown format
* Remove redundant update_keys_to_ignore
This commit is contained in:
parent b058490ceb
commit 501307b58b
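As the title says, the point of adding `ElectraForCausalLM` is to let ELECTRA checkpoints act as the decoder of an encoder-decoder model. The snippet below is a minimal sketch of that use and is not part of the commit; the checkpoint name is only an example, and `EncoderDecoderModel` is expected to load the decoder side through `AutoModelForCausalLM` with `is_decoder=True` and `add_cross_attention=True`.

```python
# Sketch only: wire two ELECTRA checkpoints into an encoder-decoder model.
# Assumes the AutoModelForCausalLM mapping added in this commit so the decoder
# resolves to ElectraForCausalLM; checkpoint names are illustrative.
from transformers import ElectraTokenizer, EncoderDecoderModel

tokenizer = ElectraTokenizer.from_pretrained("google/electra-base-generator")
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/electra-base-generator",  # encoder: plain ElectraModel
    "google/electra-base-generator",  # decoder: ElectraForCausalLM with cross-attention
)

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    decoder_input_ids=inputs.input_ids,
    labels=inputs.input_ids,
)
print(float(outputs.loss))  # untrained as a seq2seq model, so the loss is just a sanity check
```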
@@ -83,6 +83,11 @@ This model was contributed by [lysandre](https://huggingface.co/lysandre). The o
 [[autodoc]] ElectraForPreTraining
     - forward
 
+## ElectraForCausalLM
+
+[[autodoc]] ElectraForCausalLM
+    - forward
+
 ## ElectraForMaskedLM
 
 [[autodoc]] ElectraForMaskedLM
@@ -885,6 +885,7 @@ if is_torch_available():
     _import_structure["models.electra"].extend(
         [
             "ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "ElectraForCausalLM",
             "ElectraForMaskedLM",
             "ElectraForMultipleChoice",
             "ElectraForPreTraining",
@@ -2830,6 +2831,7 @@ if TYPE_CHECKING:
     )
     from .models.electra import (
         ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
+        ElectraForCausalLM,
         ElectraForMaskedLM,
         ElectraForMultipleChoice,
         ElectraForPreTraining,
@@ -218,6 +218,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
         ("transfo-xl", "TransfoXLLMHeadModel"),
         ("xlnet", "XLNetLMHeadModel"),
         ("xlm", "XLMWithLMHeadModel"),
+        ("electra", "ElectraForCausalLM"),
         ("ctrl", "CTRLLMHeadModel"),
         ("reformer", "ReformerModelWithLMHead"),
         ("bert-generation", "BertGenerationDecoder"),
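With the mapping entry above, `AutoModelForCausalLM` can resolve an ELECTRA config to the new class. A small sketch, not part of the commit (checkpoint name illustrative):

```python
# Sketch: the auto class now instantiates ElectraForCausalLM from an ELECTRA config.
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("google/electra-base-generator")
config.is_decoder = True  # causal (left-to-right) attention
model = AutoModelForCausalLM.from_config(config)
print(type(model).__name__)  # ElectraForCausalLM
```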
@@ -32,6 +32,7 @@ if is_tokenizers_available():
 if is_torch_available():
     _import_structure["modeling_electra"] = [
         "ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST",
+        "ElectraForCausalLM",
         "ElectraForMaskedLM",
         "ElectraForMultipleChoice",
         "ElectraForPreTraining",
@@ -79,6 +80,7 @@ if TYPE_CHECKING:
     if is_torch_available():
         from .modeling_electra import (
             ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
+            ElectraForCausalLM,
             ElectraForMaskedLM,
             ElectraForMultipleChoice,
             ElectraForPreTraining,
@@ -36,6 +36,7 @@ from ...file_utils import (
 from ...modeling_outputs import (
     BaseModelOutputWithCrossAttentions,
     BaseModelOutputWithPastAndCrossAttentions,
+    CausalLMOutputWithCrossAttentions,
     MaskedLMOutput,
     MultipleChoiceModelOutput,
     QuestionAnsweringModelOutput,
@@ -846,6 +847,10 @@ class ElectraModel(ElectraPreTrainedModel):
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_values=None,
+        use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
@@ -868,6 +873,9 @@ class ElectraModel(ElectraPreTrainedModel):
         batch_size, seq_length = input_shape
         device = input_ids.device if input_ids is not None else inputs_embeds.device
 
+        # past_key_values_length
+        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
         if attention_mask is None:
             attention_mask = torch.ones(input_shape, device=device)
         if token_type_ids is None:
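For reference (not part of the diff), the line above assumes the usual BERT-style cache layout: `past_key_values` is a tuple with one entry per layer, each entry a `(key, value)` pair of tensors shaped `(batch_size, num_heads, past_sequence_length, head_dim)`, so `past_key_values[0][0].shape[2]` reads the number of positions already processed. A minimal sketch:

```python
# Illustrative only: build a dummy cache with the layout ElectraModel expects and
# check that the past length is read from dimension 2 of the first layer's key tensor.
import torch

batch_size, num_heads, past_len, head_dim, num_layers = 2, 4, 7, 16, 3
past_key_values = tuple(
    (
        torch.zeros(batch_size, num_heads, past_len, head_dim),  # cached keys
        torch.zeros(batch_size, num_heads, past_len, head_dim),  # cached values
    )
    for _ in range(num_layers)
)

past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
assert past_key_values_length == past_len
```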
@@ -879,10 +887,26 @@ class ElectraModel(ElectraPreTrainedModel):
                 token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
 
         extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
+
+        # If a 2D or 3D attention mask is provided for the cross-attention
+        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if self.config.is_decoder and encoder_hidden_states is not None:
+            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+            if encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+        else:
+            encoder_extended_attention_mask = None
+
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
 
         hidden_states = self.embeddings(
-            input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
+            input_ids=input_ids,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
+            past_key_values_length=past_key_values_length,
         )
 
         if hasattr(self, "embeddings_project"):
@@ -892,6 +916,10 @@ class ElectraModel(ElectraPreTrainedModel):
             hidden_states,
             attention_mask=extended_attention_mask,
             head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_extended_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
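The hunks above thread `past_key_values` and `use_cache` through `ElectraModel` so a decoder can reuse cached keys and values between steps. A rough usage sketch, not part of the commit, assuming an ELECTRA checkpoint configured with `is_decoder=True` (the checkpoint name is only an example):

```python
# Sketch of incremental decoding with the new cache plumbing.
import torch
from transformers import ElectraConfig, ElectraForCausalLM, ElectraTokenizer

tokenizer = ElectraTokenizer.from_pretrained("google/electra-base-generator")
config = ElectraConfig.from_pretrained("google/electra-base-generator", is_decoder=True)
model = ElectraForCausalLM.from_pretrained("google/electra-base-generator", config=config)
model.eval()

inputs = tokenizer("Hello, my dog", return_tensors="pt")
with torch.no_grad():
    # First pass: full prompt, ask for the key/value cache back.
    out = model(**inputs, use_cache=True)
    next_token = out.logits[:, -1:].argmax(-1)

    # Second pass: feed only the new token plus the cache; past_key_values_length
    # inside ElectraModel offsets the position ids accordingly.
    out = model(input_ids=next_token, past_key_values=out.past_key_values, use_cache=True)

print(out.logits.shape)  # (1, 1, vocab_size)
```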
@@ -969,14 +997,14 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel):
 
         discriminator_hidden_states = self.electra(
             input_ids,
-            attention_mask,
-            token_type_ids,
-            position_ids,
-            head_mask,
-            inputs_embeds,
-            output_attentions,
-            output_hidden_states,
-            return_dict,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
         )
 
         sequence_output = discriminator_hidden_states[0]
@@ -1075,14 +1103,14 @@ class ElectraForPreTraining(ElectraPreTrainedModel):
 
         discriminator_hidden_states = self.electra(
             input_ids,
-            attention_mask,
-            token_type_ids,
-            position_ids,
-            head_mask,
-            inputs_embeds,
-            output_attentions,
-            output_hidden_states,
-            return_dict,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
         )
         discriminator_sequence_output = discriminator_hidden_states[0]
 
@@ -1166,14 +1194,14 @@ class ElectraForMaskedLM(ElectraPreTrainedModel):
 
         generator_hidden_states = self.electra(
             input_ids,
-            attention_mask,
-            token_type_ids,
-            position_ids,
-            head_mask,
-            inputs_embeds,
-            output_attentions,
-            output_hidden_states,
-            return_dict,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
         )
         generator_sequence_output = generator_hidden_states[0]
 
@@ -1247,14 +1275,14 @@ class ElectraForTokenClassification(ElectraPreTrainedModel):
 
         discriminator_hidden_states = self.electra(
             input_ids,
-            attention_mask,
-            token_type_ids,
-            position_ids,
-            head_mask,
-            inputs_embeds,
-            output_attentions,
-            output_hidden_states,
-            return_dict,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
         )
         discriminator_sequence_output = discriminator_hidden_states[0]
 
@@ -1481,3 +1509,152 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel):
             hidden_states=discriminator_hidden_states.hidden_states,
             attentions=discriminator_hidden_states.attentions,
         )
+
+
+@add_start_docstrings(
+    """ELECTRA Model with a `language modeling` head on top for CLM fine-tuning. """, ELECTRA_START_DOCSTRING
+)
+class ElectraForCausalLM(ElectraPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        if not config.is_decoder:
+            logger.warning("If you want to use `ElectraLMHeadModel` as a standalone, add `is_decoder=True.`")
+
+        self.electra = ElectraModel(config)
+        self.generator_predictions = ElectraGeneratorPredictions(config)
+        self.generator_lm_head = nn.Linear(config.embedding_size, config.vocab_size)
+
+        self.init_weights()
+
+    def get_output_embeddings(self):
+        return self.generator_lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.generator_lm_head = new_embeddings
+
+    @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        labels=None,
+        past_key_values=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        r"""
+        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
+            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
+            (those that don't have their past key value states given to this model) of shape `(batch_size, 1)`
+            instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up
+            decoding (see `past_key_values`).
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import ElectraTokenizer, ElectraForCausalLM, ElectraConfig
+        >>> import torch
+
+        >>> tokenizer = ElectraTokenizer.from_pretrained("google/electra-base-generator")
+        >>> config = ElectraConfig.from_pretrained("google/electra-base-generator")
+        >>> config.is_decoder = True
+        >>> model = ElectraForCausalLM.from_pretrained("google/electra-base-generator", config=config)

+        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+        >>> outputs = model(**inputs)
+
+        >>> prediction_logits = outputs.logits
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if labels is not None:
+            use_cache = False
+
+        outputs = self.electra(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+        prediction_scores = self.generator_lm_head(self.generator_predictions(sequence_output))
+
+        lm_loss = None
+        if labels is not None:
+            # we are doing next-token prediction; shift prediction scores and input ids by one
+            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
+            labels = labels[:, 1:].contiguous()
+            loss_fct = CrossEntropyLoss()
+            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+        if not return_dict:
+            output = (prediction_scores,) + outputs[1:]
+            return ((lm_loss,) + output) if lm_loss is not None else output
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=lm_loss,
+            logits=prediction_scores,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+    # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM.prepare_inputs_for_generation
+    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+        input_shape = input_ids.shape
+        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+        if attention_mask is None:
+            attention_mask = input_ids.new_ones(input_shape)
+
+        # cut decoder_input_ids if past is used
+        if past is not None:
+            input_ids = input_ids[:, -1:]
+
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
+
+    # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM._reorder_cache
+    def _reorder_cache(self, past, beam_idx):
+        reordered_past = ()
+        for layer_past in past:
+            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+        return reordered_past
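`prepare_inputs_for_generation` and `_reorder_cache` are the hooks that `generate()` relies on: the first feeds only the last token once a cache exists, the second re-indexes the cache when beam search reorders hypotheses. A hedged end-to-end sketch, not part of the commit (the checkpoint was not pretrained as a decoder, so the continuation will not be meaningful without fine-tuning):

```python
# Sketch: beam-search generation exercises both prepare_inputs_for_generation and _reorder_cache.
from transformers import ElectraConfig, ElectraForCausalLM, ElectraTokenizer

tokenizer = ElectraTokenizer.from_pretrained("google/electra-base-generator")
config = ElectraConfig.from_pretrained("google/electra-base-generator", is_decoder=True)
model = ElectraForCausalLM.from_pretrained("google/electra-base-generator", config=config)

input_ids = tokenizer("Hello, my dog is", return_tensors="pt").input_ids
generated = model.generate(
    input_ids,
    max_length=20,
    num_beams=4,  # beam search triggers _reorder_cache
    early_stopping=True,
    pad_token_id=tokenizer.pad_token_id,
)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```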
@@ -1924,6 +1924,18 @@ class DPRReader:
 ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST = None
 
 
+class ElectraForCausalLM:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    def forward(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class ElectraForMaskedLM:
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
@@ -21,7 +21,7 @@ from transformers.models.auto import get_values
 from transformers.testing_utils import require_torch, slow, torch_device
 
 from .test_configuration_common import ConfigTester
-from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
+from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
 
 
 if is_torch_available():
@@ -29,6 +29,7 @@ if is_torch_available():
 
     from transformers import (
         MODEL_FOR_PRETRAINING_MAPPING,
+        ElectraForCausalLM,
         ElectraForMaskedLM,
         ElectraForMultipleChoice,
         ElectraForPreTraining,
|
||||
initializer_range=self.initializer_range,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_decoder(self):
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
_,
|
||||
) = self.prepare_config_and_inputs()
|
||||
|
||||
config.is_decoder = True
|
||||
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
|
||||
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||
|
||||
return (
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
|
||||
def create_and_check_electra_model(
|
||||
self,
|
||||
config,
|
||||
@@ -136,6 +165,38 @@ class ElectraModelTester:
         result = model(input_ids)
         self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
 
+    def create_and_check_electra_model_as_decoder(
+        self,
+        config,
+        input_ids,
+        token_type_ids,
+        input_mask,
+        sequence_labels,
+        token_labels,
+        choice_labels,
+        encoder_hidden_states,
+        encoder_attention_mask,
+    ):
+        config.add_cross_attention = True
+        model = ElectraModel(config)
+        model.to(torch_device)
+        model.eval()
+        result = model(
+            input_ids,
+            attention_mask=input_mask,
+            token_type_ids=token_type_ids,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+        )
+        result = model(
+            input_ids,
+            attention_mask=input_mask,
+            token_type_ids=token_type_ids,
+            encoder_hidden_states=encoder_hidden_states,
+        )
+        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
+        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+
     def create_and_check_electra_for_masked_lm(
         self,
         config,
@@ -153,6 +214,24 @@ class ElectraModelTester:
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
 
+    def create_and_check_electra_for_causal_lm(
+        self,
+        config,
+        input_ids,
+        token_type_ids,
+        input_mask,
+        sequence_labels,
+        token_labels,
+        choice_labels,
+        encoder_hidden_states,
+        encoder_attention_mask,
+    ):
+        model = ElectraForCausalLM(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+
     def create_and_check_electra_for_token_classification(
         self,
         config,
@@ -281,6 +360,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
             ElectraModel,
             ElectraForPreTraining,
             ElectraForMaskedLM,
+            ElectraForCausalLM,
             ElectraForMultipleChoice,
             ElectraForTokenClassification,
             ElectraForSequenceClassification,
@@ -289,6 +369,8 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
+    all_generative_model_classes = (ElectraForCausalLM,) if is_torch_available() else ()
+
     fx_ready_model_classes = all_model_classes
     fx_dynamic_ready_model_classes = all_model_classes
 
@@ -314,6 +396,10 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_electra_model(*config_and_inputs)
 
+    def test_electra_model_as_decoder(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
+        self.model_tester.create_and_check_electra_model_as_decoder(*config_and_inputs)
+
     def test_electra_model_various_embeddings(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         for type in ["absolute", "relative_key", "relative_key_query"]:
@@ -350,6 +436,10 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
             model = ElectraModel.from_pretrained(model_name)
             self.assertIsNotNone(model)
 
+    def test_for_causal_lm(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
+        self.model_tester.create_and_check_electra_for_causal_lm(*config_and_inputs)
+
 
 @require_torch
 class ElectraModelIntegrationTest(unittest.TestCase):