transformers/docs/source/en/model_doc/electra.md
Surya Garikipati 8dd0a2b89c
Update model card for electra (#37063)
* Update ELECTRA model card with new format

* Update docs/source/en/model_doc/electra.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* close hfoption block

---------

Co-authored-by: Wun0 <f20191221@hyderabad.bits-pilani.ac.in>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2025-04-03 10:45:35 -07:00

PyTorch TensorFlow Flax SDPA

ELECTRA

ELECTRA modifies the pretraining objective of traditional masked language models like BERT. Instead of just masking tokens and asking the model to predict them, ELECTRA trains two models, a generator and a discriminator. The generator, a small masked language model, replaces some tokens with plausible alternatives, and the discriminator (the model you'll actually use) learns to detect which tokens are original and which were replaced.

This training approach is efficient because ELECTRA learns from every token in the input, not just the masked ones. As a result, even the small ELECTRA models can match or outperform much larger models while using considerably less compute.
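To make the "every token" point concrete, here is a minimal pure-Python sketch of how replaced-token-detection labels can be derived. It assumes the generator has already filled the masked positions; the `corrupt` and `rtd_labels` helpers and the example sentence are hypothetical and not part of Transformers. As in the ELECTRA paper, a position where the generator happens to produce the original token counts as original, so the label is 1 only where the token actually changed. This 0/1 convention (0 = original, 1 = replaced) matches the `labels` argument of [ElectraForPreTraining].

```python
from typing import Dict, List


def corrupt(tokens: List[str], replacements: Dict[int, str]) -> List[str]:
    """Hypothetical stand-in for the generator: overwrite the chosen positions."""
    corrupted = list(tokens)
    for position, token in replacements.items():
        corrupted[position] = token
    return corrupted


def rtd_labels(original: List[str], corrupted: List[str]) -> List[int]:
    """Per-token discriminator targets: 1 = replaced, 0 = original."""
    if len(original) != len(corrupted):
        raise ValueError("sequences must have the same length")
    return [int(o != c) for o, c in zip(original, corrupted)]


original = ["the", "chef", "cooked", "the", "meal"]
# Positions 1 and 2 were masked. The generator guessed "chef" correctly
# at position 1 but produced "ate" at position 2.
corrupted = corrupt(original, {1: "chef", 2: "ate"})
labels = rtd_labels(original, corrupted)
print(corrupted)  # ['the', 'chef', 'ate', 'the', 'meal']
print(labels)     # [0, 0, 1, 0, 0]
```

The discriminator receives a target for all five positions, whereas a masked language model like BERT only gets a training signal from the masked positions (1 and 2 here).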

You can find all the original ELECTRA checkpoints under the ELECTRA release.

Tip

Click on the ELECTRA models in the right sidebar for more examples of how to apply ELECTRA to different language tasks like sequence classification, token classification, and question answering.

The example below demonstrates how to classify text with [Pipeline], the [AutoModel] class, or from the command line.

Pipeline

import torch
from transformers import pipeline

classifier = pipeline(
    task="text-classification",
    model="bhadresh-savani/electra-base-emotion",
    torch_dtype=torch.float16,
    device=0
)
classifier("This restaurant has amazing food!")

AutoModel

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained(
    "bhadresh-savani/electra-base-emotion",
)
model = AutoModelForSequenceClassification.from_pretrained(
    "bhadresh-savani/electra-base-emotion",
    torch_dtype=torch.float16,
    device_map="auto"
)
# move the inputs to the same device as the model
inputs = tokenizer("ELECTRA is more efficient than BERT", return_tensors="pt").to(model.device)

with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits
    predicted_class_id = logits.argmax(dim=-1).item()
    predicted_label = model.config.id2label[predicted_class_id]
print(f"Predicted label: {predicted_label}")

transformers-cli

echo -e "This restaurant has amazing food." | transformers-cli run --task text-classification --model bhadresh-savani/electra-base-emotion --device 0
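The AutoModel example keeps only the argmax of the logits, while the text-classification pipeline also reports a score; for a single-label model with several classes, that score is a softmax probability. Below is a minimal pure-Python sketch of that post-processing step. The logits and the three-entry `id2label` mapping are made up for illustration and are not the emotion model's real configuration, and `postprocess` is a hypothetical helper, not a Transformers function.

```python
import math
from typing import Dict, List


def postprocess(logits: List[float], id2label: Dict[int, str]) -> Dict[str, object]:
    """Turn raw classification logits into a {label, score} dict."""
    # Subtract the max logit before exponentiating for numerical stability.
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Index of the most probable class (same result as argmax over the logits).
    best = max(range(len(probs)), key=probs.__getitem__)
    return {"label": id2label[best], "score": probs[best]}


# Hypothetical values for illustration only.
id2label = {0: "sadness", 1: "joy", 2: "anger"}
result = postprocess([0.1, 2.5, -1.0], id2label)
print(result["label"])  # joy
```

Because softmax is monotonic, the predicted label is always the same as the argmax in the AutoModel example; the softmax only adds a calibrated-looking probability on top.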

Notes

  • ELECTRA consists of two transformer models, a generator (G) and a discriminator (D). For most downstream tasks, use the discriminator model (as indicated by *-discriminator in the name) rather than the generator.

  • ELECTRA comes in three sizes: small (14M parameters), base (110M parameters), and large (335M parameters).

  • ELECTRA can use an embedding size smaller than the hidden size to reduce the number of embedding parameters. When embedding_size is smaller than hidden_size in the configuration, a linear projection layer maps the embeddings up to hidden_size.

  • When using batched inputs with padding, pass the attention mask to the model so it does not attend to padding tokens. The tokenizer returns an attention_mask with the encoded inputs, and unpacking them with ** forwards it to the model.

    # Example of properly handling padding with attention masks
    inputs = tokenizer(
        ["Short text", "This is a much longer text that needs padding"],
        padding=True,
        return_tensors="pt",
    ).to(model.device)
    outputs = model(**inputs)  # **inputs forwards the attention_mask along with the input_ids

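  • When fine-tuning the discriminator on a downstream task, you can load it into any of the ELECTRA task classes ([ElectraForSequenceClassification], [ElectraForTokenClassification], etc.). The task-specific head is newly initialized, so it needs to be fine-tuned before it produces useful predictions.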
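The savings from a smaller embedding size are easy to estimate. The pure-Python sketch below counts only the word-embedding matrix plus the projection layer (weight and bias) and assumes a 30,522-token vocabulary with embedding_size=128 and hidden_size=256. Treat these numbers as illustrative and check the config of the checkpoint you use for its actual values; `embedding_params` is a hypothetical helper.

```python
def embedding_params(vocab_size: int, embedding_size: int, hidden_size: int) -> int:
    """Word-embedding parameters, plus a linear projection when the sizes differ."""
    params = vocab_size * embedding_size
    if embedding_size != hidden_size:
        # Projection from embedding_size to hidden_size: weight matrix + bias.
        params += embedding_size * hidden_size + hidden_size
    return params


factorized = embedding_params(30522, embedding_size=128, hidden_size=256)
full = embedding_params(30522, embedding_size=256, hidden_size=256)
print(factorized)         # 3939840
print(full)               # 7813632
print(full - factorized)  # 3873792
```

Under these assumptions, roughly half of the word-embedding parameters are saved, at the cost of one extra matrix multiplication after the embedding lookup.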
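Conceptually, padding and masking work as in the pure-Python sketch below. Shorter sequences are right-padded with a pad id, the mask marks real tokens with 1 and padding with 0, and inside attention the 0 positions receive a large negative bias so they end up with (near) zero weight after the softmax. The token ids, the pad id of 0, and the helper names are assumptions for illustration; in practice the tokenizer and model handle all of this for you.

```python
import math
from typing import List, Tuple


def pad_batch(batch: List[List[int]], pad_id: int = 0) -> Tuple[List[List[int]], List[List[int]]]:
    """Right-pad every sequence to the longest one and build an attention mask."""
    max_len = max(len(seq) for seq in batch)
    input_ids, attention_mask = [], []
    for seq in batch:
        n_pad = max_len - len(seq)
        input_ids.append(seq + [pad_id] * n_pad)
        attention_mask.append([1] * len(seq) + [0] * n_pad)
    return input_ids, attention_mask


def masked_softmax(scores: List[float], mask: List[int]) -> List[float]:
    """Softmax over attention scores where masked (0) positions get ~zero weight."""
    # Add a large negative bias to padded positions before normalizing.
    biased = [s if m else s - 1e9 for s, m in zip(scores, mask)]
    top = max(biased)
    exps = [math.exp(b - top) for b in biased]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up token ids for two sequences of different lengths.
ids, mask = pad_batch([[101, 2460, 3793, 102], [101, 2023, 102]])
print(ids)   # [[101, 2460, 3793, 102], [101, 2023, 102, 0]]
print(mask)  # [[1, 1, 1, 1], [1, 1, 1, 0]]
weights = masked_softmax([0.5, 1.0, 0.2, 3.0], mask[1])
print(round(weights[3], 6))  # 0.0
```

Without the mask, the padded position in this example would have received the largest attention weight, since its raw score (3.0) is the highest.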

ElectraConfig

autodoc ElectraConfig

ElectraTokenizer

autodoc ElectraTokenizer

ElectraTokenizerFast

autodoc ElectraTokenizerFast

Electra specific outputs

autodoc models.electra.modeling_electra.ElectraForPreTrainingOutput

autodoc models.electra.modeling_tf_electra.TFElectraForPreTrainingOutput

ElectraModel

autodoc ElectraModel - forward

ElectraForPreTraining

autodoc ElectraForPreTraining - forward

ElectraForCausalLM

autodoc ElectraForCausalLM - forward

ElectraForMaskedLM

autodoc ElectraForMaskedLM - forward

ElectraForSequenceClassification

autodoc ElectraForSequenceClassification - forward

ElectraForMultipleChoice

autodoc ElectraForMultipleChoice - forward

ElectraForTokenClassification

autodoc ElectraForTokenClassification - forward

ElectraForQuestionAnswering

autodoc ElectraForQuestionAnswering - forward

TFElectraModel

autodoc TFElectraModel - call

TFElectraForPreTraining

autodoc TFElectraForPreTraining - call

TFElectraForMaskedLM

autodoc TFElectraForMaskedLM - call

TFElectraForSequenceClassification

autodoc TFElectraForSequenceClassification - call

TFElectraForMultipleChoice

autodoc TFElectraForMultipleChoice - call

TFElectraForTokenClassification

autodoc TFElectraForTokenClassification - call

TFElectraForQuestionAnswering

autodoc TFElectraForQuestionAnswering - call

FlaxElectraModel

autodoc FlaxElectraModel - call

FlaxElectraForPreTraining

autodoc FlaxElectraForPreTraining - call

FlaxElectraForCausalLM

autodoc FlaxElectraForCausalLM - call

FlaxElectraForMaskedLM

autodoc FlaxElectraForMaskedLM - call

FlaxElectraForSequenceClassification

autodoc FlaxElectraForSequenceClassification - call

FlaxElectraForMultipleChoice

autodoc FlaxElectraForMultipleChoice - call

FlaxElectraForTokenClassification

autodoc FlaxElectraForTokenClassification - call

FlaxElectraForQuestionAnswering

autodoc FlaxElectraForQuestionAnswering - call