Add Doc Tests for Reformer PyTorch (#16565)
* start working
* fix: ReformerForQA doctest
* fix: ReformerModelWithLMHead doctest
* fix: ReformerModelForSC doctest
* fix: ReformerModelForMLM doctest
* add: documentation_tests.txt
* make fixup
* change: ReformerModelForSC doctest
* change: checkpoint
parent d7f7f29f29
commit 1bac40db8a
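To try the doc tests added here locally, pytest's built-in doctest runner is one option. A minimal sketch, assuming it is run from the root of a transformers checkout (the project's CI may apply extra preprocessing first):

```python
# Minimal sketch: run the doctests embedded in the Reformer modeling file.
# Assumes the current working directory is the root of a transformers clone.
import pytest

pytest.main(
    [
        "--doctest-modules",  # collect doctests from the module's docstrings
        "src/transformers/models/reformer/modeling_reformer.py",
        "-sv",
        "--doctest-continue-on-failure",  # report all failing examples, not just the first
    ]
)
```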
src/transformers/models/reformer/modeling_reformer.py

@@ -40,6 +40,7 @@ from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     logging,
+    replace_return_docstrings,
 )
 from .configuration_reformer import ReformerConfig
 
@@ -2311,12 +2312,7 @@ class ReformerForMaskedLM(ReformerPreTrainedModel):
         self.lm_head.decoder = new_embeddings
 
     @add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
-    @add_code_sample_docstrings(
-        processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
-        output_type=MaskedLMOutput,
-        config_class=_CONFIG_FOR_DOC,
-    )
+    @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
@@ -2335,6 +2331,44 @@ class ReformerForMaskedLM(ReformerPreTrainedModel):
             Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
             config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked),
             the loss is only computed for the tokens with labels
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> import torch
+        >>> from transformers import ReformerTokenizer, ReformerForMaskedLM
+
+        >>> tokenizer = ReformerTokenizer.from_pretrained("hf-internal-testing/tiny-random-reformer")
+        >>> model = ReformerForMaskedLM.from_pretrained("hf-internal-testing/tiny-random-reformer")
+
+        >>> # add mask_token
+        >>> tokenizer.add_special_tokens({"mask_token": "[MASK]"})  # doctest: +IGNORE_RESULT
+        >>> inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
+
+        >>> with torch.no_grad():
+        ...     logits = model(**inputs).logits
+
+        >>> # retrieve index of [MASK]
+        >>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
+
+        >>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
+        >>> tokenizer.decode(predicted_token_id)
+        'it'
+        ```
+
+        ```python
+        >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
+        >>> # mask labels of non-[MASK] tokens
+        >>> labels = torch.where(
+        ...     inputs.input_ids == tokenizer.mask_token_id, labels[:, : inputs["input_ids"].shape[-1]], -100
+        ... )
+
+        >>> outputs = model(**inputs, labels=labels)
+        >>> round(outputs.loss.item(), 2)
+        7.09
+        ```
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
@@ -2393,12 +2427,7 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel):
         self.post_init()
 
     @add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
-    @add_code_sample_docstrings(
-        processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
-        output_type=SequenceClassifierOutput,
-        config_class=_CONFIG_FOR_DOC,
-    )
+    @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
@@ -2417,6 +2446,79 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+        Returns:
+
+        Example of single-label classification:
+
+        ```python
+        >>> import torch
+        >>> from transformers import ReformerTokenizer, ReformerForSequenceClassification
+
+        >>> tokenizer = ReformerTokenizer.from_pretrained("hf-internal-testing/tiny-random-reformer")
+        >>> model = ReformerForSequenceClassification.from_pretrained("hf-internal-testing/tiny-random-reformer")
+
+        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+
+        >>> with torch.no_grad():
+        ...     logits = model(**inputs).logits
+
+        >>> predicted_class_id = logits.argmax().item()
+        >>> model.config.id2label[predicted_class_id]
+        'LABEL_1'
+        ```
+
+        ```python
+        >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
+        >>> num_labels = len(model.config.id2label)
+        >>> model = ReformerForSequenceClassification.from_pretrained(
+        ...     "hf-internal-testing/tiny-random-reformer", num_labels=num_labels
+        ... )
+
+        >>> labels = torch.tensor(1)
+        >>> loss = model(**inputs, labels=labels).loss
+        >>> round(loss.item(), 2)
+        0.69
+        ```
+
+        Example of multi-label classification:
+
+        ```python
+        >>> import torch
+        >>> from transformers import ReformerTokenizer, ReformerForSequenceClassification
+
+        >>> tokenizer = ReformerTokenizer.from_pretrained("hf-internal-testing/tiny-random-reformer")
+        >>> model = ReformerForSequenceClassification.from_pretrained(
+        ...     "hf-internal-testing/tiny-random-reformer", problem_type="multi_label_classification"
+        ... )
+
+        >>> # add pad_token
+        >>> tokenizer.add_special_tokens({"pad_token": "[PAD]"})  # doctest: +IGNORE_RESULT
+        >>> inputs = tokenizer("Hello, my dog is cute", max_length=100, padding="max_length", return_tensors="pt")
+
+        >>> with torch.no_grad():
+        ...     logits = model(**inputs).logits
+
+        >>> predicted_class_id = logits.argmax().item()
+        >>> model.config.id2label[predicted_class_id]
+        'LABEL_1'
+        ```
+
+        ```python
+        >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
+        >>> num_labels = len(model.config.id2label)
+        >>> model = ReformerForSequenceClassification.from_pretrained(
+        ...     "hf-internal-testing/tiny-random-reformer", num_labels=num_labels
+        ... )
+        >>> model.train()  # doctest: +IGNORE_RESULT
+
+        >>> num_labels = len(model.config.id2label)
+        >>> labels = torch.nn.functional.one_hot(torch.tensor([predicted_class_id]), num_classes=num_labels).to(
+        ...     torch.float
+        ... )
+        >>> loss = model(**inputs, labels=labels).loss
+        >>> loss.backward()  # doctest: +IGNORE_RESULT
+        ```
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
@@ -2514,9 +2616,11 @@ class ReformerForQuestionAnswering(ReformerPreTrainedModel):
     @add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
+        checkpoint="hf-internal-testing/tiny-random-reformer",
         output_type=QuestionAnsweringModelOutput,
         config_class=_CONFIG_FOR_DOC,
+        expected_output="''",
+        expected_loss=3.28,
     )
     def forward(
         self,
utils/documentation_tests.txt

@@ -24,6 +24,7 @@ src/transformers/models/mobilebert/modeling_tf_mobilebert.py
 src/transformers/models/pegasus/modeling_pegasus.py
 src/transformers/models/plbart/modeling_plbart.py
 src/transformers/models/poolformer/modeling_poolformer.py
+src/transformers/models/reformer/modeling_reformer.py
 src/transformers/models/resnet/modeling_resnet.py
 src/transformers/models/roberta/modeling_roberta.py
 src/transformers/models/roberta/modeling_tf_roberta.py
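`utils/documentation_tests.txt` acts as the allow-list consumed by the doctest CI job: every file registered there gets its docstring examples executed. A rough sketch of that consumption, assuming one source path per line (the real job's filtering and environment setup may differ):

```python
# Sketch: run doctests for every file registered in utils/documentation_tests.txt.
# Assumes one path per line; skips blanks and comment lines.
import pytest

with open("utils/documentation_tests.txt") as f:
    paths = [line.strip() for line in f if line.strip() and not line.lstrip().startswith("#")]

for path in paths:
    pytest.main(["--doctest-modules", path, "-sv", "--doctest-continue-on-failure"])
```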