Fix the tests for Electra (#6284)

* Fix the tests for Electra

* Apply style
This commit is contained in:
Julien Plu 2020-08-07 15:30:57 +02:00 committed by GitHub
parent 6ba540b747
commit 0e36e51515
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 3 deletions

View File

@ -857,7 +857,7 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel):
super().__init__(config)
self.electra = ElectraModel(config)
self.summary = SequenceSummary(config)
self.sequence_summary = SequenceSummary(config)
self.classifier = nn.Linear(config.hidden_size, 1)
self.init_weights()
@ -915,7 +915,7 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel):
sequence_output = discriminator_hidden_states[0]
pooled_output = self.summary(sequence_output)
pooled_output = self.sequence_summary(sequence_output)
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, num_choices)

View File

@ -63,6 +63,7 @@ class TFElectraModelTester:
self.num_labels = 3
self.num_choices = 4
self.scope = None
self.embedding_size = 128
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@ -194,7 +195,14 @@ class TFElectraModelTester:
class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (
(TFElectraModel, TFElectraForMaskedLM, TFElectraForPreTraining, TFElectraForTokenClassification,)
(
TFElectraModel,
TFElectraForMaskedLM,
TFElectraForPreTraining,
TFElectraForTokenClassification,
TFElectraForMultipleChoice,
TFElectraForSequenceClassification,
)
if is_tf_available()
else ()
)