mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix the tests for Electra (#6284)
* Fix the tests for Electra * Apply style
This commit is contained in:
parent
6ba540b747
commit
0e36e51515
@ -857,7 +857,7 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel):
|
||||
super().__init__(config)
|
||||
|
||||
self.electra = ElectraModel(config)
|
||||
self.summary = SequenceSummary(config)
|
||||
self.sequence_summary = SequenceSummary(config)
|
||||
self.classifier = nn.Linear(config.hidden_size, 1)
|
||||
|
||||
self.init_weights()
|
||||
@ -915,7 +915,7 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel):
|
||||
|
||||
sequence_output = discriminator_hidden_states[0]
|
||||
|
||||
pooled_output = self.summary(sequence_output)
|
||||
pooled_output = self.sequence_summary(sequence_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
reshaped_logits = logits.view(-1, num_choices)
|
||||
|
||||
|
@ -63,6 +63,7 @@ class TFElectraModelTester:
|
||||
self.num_labels = 3
|
||||
self.num_choices = 4
|
||||
self.scope = None
|
||||
self.embedding_size = 128
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
@ -194,7 +195,14 @@ class TFElectraModelTester:
|
||||
class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (
|
||||
(TFElectraModel, TFElectraForMaskedLM, TFElectraForPreTraining, TFElectraForTokenClassification,)
|
||||
(
|
||||
TFElectraModel,
|
||||
TFElectraForMaskedLM,
|
||||
TFElectraForPreTraining,
|
||||
TFElectraForTokenClassification,
|
||||
TFElectraForMultipleChoice,
|
||||
TFElectraForSequenceClassification,
|
||||
)
|
||||
if is_tf_available()
|
||||
else ()
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user