# XLM-RoBERTa
[XLM-RoBERTa](https://huggingface.co/papers/1911.02116) is a large multilingual masked language model trained on 2.5TB of filtered CommonCrawl data across 100 languages. It shows that scaling the model provides strong performance gains on high-resource and low-resource languages. The model uses the [RoBERTa](./roberta) pretraining objectives on the [XLM](./xlm) model.
You can find all the original XLM-RoBERTa checkpoints under the [Facebook AI community](https://huggingface.co/FacebookAI) organization.
> [!TIP]
> Click on the XLM-RoBERTa models in the right sidebar for more examples of how to apply XLM-RoBERTa to different cross-lingual tasks like classification, translation, and question answering.
The example below demonstrates how to predict the `<mask>` token with [`Pipeline`], [`AutoModel`], and from the command line.
```python
import torch
from transformers import pipeline

pipeline = pipeline(
    task="fill-mask",
    model="FacebookAI/xlm-roberta-base",
    torch_dtype=torch.float16,
    device=0
)
# Example in French
pipeline("Bonjour, je suis un modèle <mask>.")
```

```python
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained(
    "FacebookAI/xlm-roberta-base"
)
model = AutoModelForMaskedLM.from_pretrained(
    "FacebookAI/xlm-roberta-base",
    torch_dtype=torch.float16,
    device_map="auto",
    attn_implementation="sdpa"
)

# Prepare input and move it to the same device as the model
inputs = tokenizer("Bonjour, je suis un modèle <mask>.", return_tensors="pt").to(model.device)

with torch.no_grad():
    outputs = model(**inputs)
    predictions = outputs.logits

# Find the position of the <mask> token and take the highest-scoring prediction
masked_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
predicted_token_id = predictions[0, masked_index].argmax(dim=-1)
predicted_token = tokenizer.decode(predicted_token_id)

print(f"The predicted token is: {predicted_token}")
```
```bash
echo -e "Plants create <mask> through a process known as photosynthesis." | transformers-cli run --task fill-mask --model FacebookAI/xlm-roberta-base --device 0
```
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [quantization guide](../quantization) overview for more available quantization backends.
The example below uses [bitsandbytes](../quantization/bitsandbytes) to quantize the weights to 4 bits.
```python
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",  # or "fp4" for 4-bit float quantization
    bnb_4bit_use_double_quant=True,  # also quantize the quantization constants to save memory
)
tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-large")
model = AutoModelForMaskedLM.from_pretrained(
    "FacebookAI/xlm-roberta-large",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa",
    quantization_config=quantization_config,
)

inputs = tokenizer("Bonjour, je suis un modèle <mask>.", return_tensors="pt").to(model.device)
with torch.no_grad():
    logits = model(**inputs).logits

# XLM-RoBERTa is an encoder-only masked LM, so predict the masked token instead of calling generate()
masked_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
predicted_token_id = logits[0, masked_index].argmax(dim=-1)
print(tokenizer.decode(predicted_token_id))
```
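A rough, weight-only estimate helps show why 4-bit loading reduces memory. The sketch below multiplies parameter counts by bits per parameter. The figures are assumptions for illustration: roughly 560M total parameters for `FacebookAI/xlm-roberta-large`, and a 250,002 x 1024 word-embedding matrix that is assumed to stay in 16-bit because bitsandbytes quantizes `nn.Linear` layers rather than embeddings. Activations, quantization constants, biases, and layer norms are ignored, so real usage will be somewhat higher.

```python
def weight_memory_gib(num_params: int, bits_per_param: int) -> float:
    """Approximate weight-only memory in GiB: parameters x bits / 8 bytes."""
    return num_params * bits_per_param / 8 / 1024**3


# Assumed, approximate sizes for XLM-RoBERTa large (illustration only)
TOTAL_PARAMS = 560_000_000
EMBEDDING_PARAMS = 250_002 * 1024  # assumed vocabulary size x hidden size
OTHER_PARAMS = TOTAL_PARAMS - EMBEDDING_PARAMS

# Every weight stored in float16/bfloat16
fp16_gib = weight_memory_gib(TOTAL_PARAMS, 16)

# Rough assumption: all non-embedding weights go to 4 bits, embeddings stay 16-bit
mixed_4bit_gib = weight_memory_gib(OTHER_PARAMS, 4) + weight_memory_gib(EMBEDDING_PARAMS, 16)

print(f"16-bit weights:                  ~{fp16_gib:.2f} GiB")
print(f"4-bit linear + 16-bit embedding: ~{mixed_4bit_gib:.2f} GiB")
```

Because the large multilingual vocabulary makes the embedding matrix a big share of the parameters, the savings from 4-bit quantization are smaller for XLM-RoBERTa than the raw 16-to-4-bit ratio would suggest.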
## Notes
- Unlike some XLM models, XLM-RoBERTa doesn't require `lang` tensors to understand which language is being used. It infers the language from the input IDs alone.
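A minimal sketch of this point, assuming recent `transformers` and PyTorch installs: it builds a tiny, randomly initialized XLM-RoBERTa from a hypothetical small `XLMRobertaConfig` (so no checkpoint is downloaded), checks that `forward()` has no `langs` argument, and runs a batch using only token IDs and an attention mask. The sizes and token IDs are arbitrary illustration values, and the outputs are meaningless apart from their shape.

```python
import inspect

import torch
from transformers import XLMRobertaConfig, XLMRobertaModel

# Hypothetical tiny config so the sketch runs offline with random weights
config = XLMRobertaConfig(
    vocab_size=1000,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    max_position_embeddings=64,
)
model = XLMRobertaModel(config).eval()

# forward() has no language-ID argument (XLMModel.forward, by contrast, accepts `langs`)
accepts_langs = "langs" in inspect.signature(model.forward).parameters
print("forward() accepts langs:", accepts_langs)

# Arbitrary token IDs (default config: 0 = <s>, 1 = <pad>, 2 = </s>); no language IDs are passed
input_ids = torch.tensor([[0, 5, 6, 7, 2], [0, 8, 9, 2, 1]])
attention_mask = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]])
with torch.no_grad():
    out = model(input_ids=input_ids, attention_mask=attention_mask)
print(out.last_hidden_state.shape)  # (batch, sequence length, hidden size)
```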
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with XLM-RoBERTa. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
**Text classification**

- A blog post on how to [fine-tune XLM-RoBERTa for multiclass classification with Habana Gaudi on AWS](https://www.philschmid.de/habana-distributed-training)
- [`XLMRobertaForSequenceClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification.ipynb).
- [`TFXLMRobertaForSequenceClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/text-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification-tf.ipynb).
- [`FlaxXLMRobertaForSequenceClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/flax/text-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification_flax.ipynb).
- [Text classification task guide](../tasks/sequence_classification)
**Token classification**

- [`XLMRobertaForTokenClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/token-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/token_classification.ipynb).
- [`TFXLMRobertaForTokenClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/token-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/token_classification-tf.ipynb).
- [`FlaxXLMRobertaForTokenClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/flax/token-classification).
- [Token classification](https://huggingface.co/course/chapter7/2?fw=pt) chapter of the 🤗 Hugging Face Course.
- [Token classification task guide](../tasks/token_classification)
**Text generation**

- [`XLMRobertaForCausalLM`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/language_modeling.ipynb).
- [Causal language modeling task guide](../tasks/language_modeling)
**Fill-mask**

- [`XLMRobertaForMaskedLM`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling#robertabertdistilbert-and-masked-language-modeling) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/language_modeling.ipynb).
- [`TFXLMRobertaForMaskedLM`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/language-modeling#run_mlmpy) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/language_modeling-tf.ipynb).
- [`FlaxXLMRobertaForMaskedLM`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/flax/language-modeling#masked-language-modeling) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/masked_language_modeling_flax.ipynb).
- [Masked language modeling](https://huggingface.co/course/chapter7/3?fw=pt) chapter of the 🤗 Hugging Face Course.
- [Masked language modeling task guide](../tasks/masked_language_modeling)
**Question answering**

- [`XLMRobertaForQuestionAnswering`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/question-answering) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/question_answering.ipynb).
- [`TFXLMRobertaForQuestionAnswering`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/question-answering) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/question_answering-tf.ipynb).
- [`FlaxXLMRobertaForQuestionAnswering`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/flax/question-answering).
- [Question answering](https://huggingface.co/course/chapter7/7?fw=pt) chapter of the 🤗 Hugging Face Course.
- [Question answering task guide](../tasks/question_answering)
**Multiple choice**
- [`XLMRobertaForMultipleChoice`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/multiple-choice) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/multiple_choice.ipynb).
- [`TFXLMRobertaForMultipleChoice`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/multiple-choice) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/multiple_choice-tf.ipynb).
- [Multiple choice task guide](../tasks/multiple_choice)
🚀 **Deploy**
- A blog post on how to [Deploy Serverless XLM RoBERTa on AWS Lambda](https://www.philschmid.de/multilingual-serverless-xlm-roberta-with-huggingface).
This implementation is the same as RoBERTa. Refer to the [RoBERTa documentation](roberta) for usage examples and details about the inputs and outputs.
## XLMRobertaConfig
[[autodoc]] XLMRobertaConfig
## XLMRobertaTokenizer
[[autodoc]] XLMRobertaTokenizer
- build_inputs_with_special_tokens
- get_special_tokens_mask
- create_token_type_ids_from_sequences
- save_vocabulary
## XLMRobertaTokenizerFast
[[autodoc]] XLMRobertaTokenizerFast
## XLMRobertaModel
[[autodoc]] XLMRobertaModel
- forward
## XLMRobertaForCausalLM
[[autodoc]] XLMRobertaForCausalLM
- forward
## XLMRobertaForMaskedLM
[[autodoc]] XLMRobertaForMaskedLM
- forward
## XLMRobertaForSequenceClassification
[[autodoc]] XLMRobertaForSequenceClassification
- forward
## XLMRobertaForMultipleChoice
[[autodoc]] XLMRobertaForMultipleChoice
- forward
## XLMRobertaForTokenClassification
[[autodoc]] XLMRobertaForTokenClassification
- forward
## XLMRobertaForQuestionAnswering
[[autodoc]] XLMRobertaForQuestionAnswering
- forward
## TFXLMRobertaModel
[[autodoc]] TFXLMRobertaModel
- call
## TFXLMRobertaForCausalLM
[[autodoc]] TFXLMRobertaForCausalLM
- call
## TFXLMRobertaForMaskedLM
[[autodoc]] TFXLMRobertaForMaskedLM
- call
## TFXLMRobertaForSequenceClassification
[[autodoc]] TFXLMRobertaForSequenceClassification
- call
## TFXLMRobertaForMultipleChoice
[[autodoc]] TFXLMRobertaForMultipleChoice
- call
## TFXLMRobertaForTokenClassification
[[autodoc]] TFXLMRobertaForTokenClassification
- call
## TFXLMRobertaForQuestionAnswering
[[autodoc]] TFXLMRobertaForQuestionAnswering
- call
## FlaxXLMRobertaModel
[[autodoc]] FlaxXLMRobertaModel
- __call__
## FlaxXLMRobertaForCausalLM
[[autodoc]] FlaxXLMRobertaForCausalLM
- __call__
## FlaxXLMRobertaForMaskedLM
[[autodoc]] FlaxXLMRobertaForMaskedLM
- __call__
## FlaxXLMRobertaForSequenceClassification
[[autodoc]] FlaxXLMRobertaForSequenceClassification
- __call__
## FlaxXLMRobertaForMultipleChoice
[[autodoc]] FlaxXLMRobertaForMultipleChoice
- __call__
## FlaxXLMRobertaForTokenClassification
[[autodoc]] FlaxXLMRobertaForTokenClassification
- __call__
## FlaxXLMRobertaForQuestionAnswering
[[autodoc]] FlaxXLMRobertaForQuestionAnswering
- __call__