Fix doctest (#20843)

* fix docs for generation, dinat, nat and prelayernorm

* style

* update

* fix copies

* use auto config and auto tokenizer (see the sketch below, just before the diffs)

Co-authored-by: sgugger <sylvain.gugger@gmail.com>

* also modify roberta and the dependent models

Co-authored-by: sgugger <sylvain.gugger@gmail.com>
Arthur 2022-12-21 16:34:31 +01:00 committed by GitHub
parent aaa6296de2
commit 76d02feadb
7 changed files with 18 additions and 21 deletions
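For context, the doctests touched below all converge on the same Auto* loading pattern mentioned in the "use auto config and auto tokenizer" bullet. A minimal sketch of that pattern, using roberta-base as the checkpoint; the forward pass at the end is illustrative and not part of this diff:

```python
>>> from transformers import AutoConfig, AutoTokenizer, RobertaForCausalLM

>>> tokenizer = AutoTokenizer.from_pretrained("roberta-base")
>>> config = AutoConfig.from_pretrained("roberta-base")
>>> config.is_decoder = True  # needed so the encoder checkpoint can run as a causal LM
>>> model = RobertaForCausalLM.from_pretrained("roberta-base", config=config)

>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> prediction_logits = outputs.logits  # shape: (batch, sequence_length, vocab_size)
```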


@@ -2264,6 +2264,7 @@ class GenerationMixin:
>>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token
>>> model.config.pad_token_id = model.config.eos_token_id
+ >>> model.generation_config.pad_token_id = model.config.eos_token_id
>>> input_prompt = "Today is a beautiful day, and"
>>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
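The hunk stops just before the snippet generates; generate() reads its defaults from model.generation_config, which is why the doctest now sets pad_token_id on both the model config and the generation config. A purely illustrative continuation (the generation arguments below are not part of this diff):

```python
>>> outputs = model.generate(input_ids, max_new_tokens=10)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
```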


@@ -1502,11 +1502,11 @@ class CamembertForCausalLM(CamembertPreTrainedModel):
Example:
```python
- >>> from transformers import CamembertTokenizer, CamembertForCausalLM, CamembertConfig
+ >>> from transformers import AutoTokenizer, CamembertForCausalLM, AutoConfig
>>> import torch
- >>> tokenizer = CamembertTokenizer.from_pretrained("camembert-base")
- >>> config = CamembertConfig.from_pretrained("camembert-base")
+ >>> tokenizer = AutoTokenizer.from_pretrained("camembert-base")
+ >>> config = AutoConfig.from_pretrained("camembert-base")
>>> config.is_decoder = True
>>> model = CamembertForCausalLM.from_pretrained("camembert-base", config=config)


@@ -943,7 +943,7 @@ class DinatBackbone(DinatPreTrainedModel, BackboneMixin):
>>> processor = AutoImageProcessor.from_pretrained("shi-labs/nat-mini-in1k-224")
>>> model = AutoBackbone.from_pretrained(
... "shi-labs/nat-mini-in1k-2240", out_features=["stage1", "stage2", "stage3", "stage4"]
... "shi-labs/nat-mini-in1k-224", out_features=["stage1", "stage2", "stage3", "stage4"]
... )
>>> inputs = processor(image, return_tensors="pt")
@@ -952,7 +952,7 @@ class DinatBackbone(DinatPreTrainedModel, BackboneMixin):
>>> feature_maps = outputs.feature_maps
>>> list(feature_maps[-1].shape)
- [1, 2048, 7, 7]
+ [1, 512, 7, 7]
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = (
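For reference, the corrected expected value [1, 512, 7, 7] follows from the backbone configuration rather than the ResNet-style 2048 channels the old value suggests. A back-of-the-envelope sketch, assuming the shi-labs/nat-mini-in1k-224 checkpoint uses embed_dim=64, four stages, and 32x overall downsampling (an assumption about the checkpoint, not something stated in this diff):

```python
# Hypothetical sanity check for the doctest expectation, under the assumptions above.
embed_dim = 64                                # assumed nat-mini embedding dim
num_stages = 4                                # stage1 ... stage4
channels = embed_dim * 2 ** (num_stages - 1)  # 64 -> 128 -> 256 -> 512
spatial = 224 // 32                           # 4x patch embed, then /2 after each of the first 3 stages
print([1, channels, spatial, spatial])        # [1, 512, 7, 7]
```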


@@ -921,7 +921,7 @@ class NatBackbone(NatPreTrainedModel, BackboneMixin):
>>> processor = AutoImageProcessor.from_pretrained("shi-labs/nat-mini-in1k-224")
>>> model = AutoBackbone.from_pretrained(
... "shi-labs/nat-mini-in1k-2240", out_features=["stage1", "stage2", "stage3", "stage4"]
... "shi-labs/nat-mini-in1k-224", out_features=["stage1", "stage2", "stage3", "stage4"]
... )
>>> inputs = processor(image, return_tensors="pt")
@@ -930,7 +930,7 @@ class NatBackbone(NatPreTrainedModel, BackboneMixin):
>>> feature_maps = outputs.feature_maps
>>> list(feature_maps[-1].shape)
- [1, 2048, 7, 7]
+ [1, 512, 7, 7]
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = (


@@ -956,11 +956,11 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
Example:
```python
- >>> from transformers import RobertaTokenizer, RobertaForCausalLM, RobertaConfig
+ >>> from transformers import AutoTokenizer, RobertaForCausalLM, AutoConfig
>>> import torch
- >>> tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
- >>> config = RobertaConfig.from_pretrained("roberta-base")
+ >>> tokenizer = AutoTokenizer.from_pretrained("roberta-base")
+ >>> config = AutoConfig.from_pretrained("roberta-base")
>>> config.is_decoder = True
>>> model = RobertaForCausalLM.from_pretrained("roberta-base", config=config)


@@ -885,7 +885,7 @@ class RobertaPreLayerNormModel(RobertaPreLayerNormPreTrainedModel):
"""RoBERTa-PreLayerNorm Model with a `language modeling` head on top for CLM fine-tuning.""",
ROBERTA_PRELAYERNORM_START_DOCSTRING,
)
- # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM with roberta-base->andreasmadsen/efficient_mlm_m0.40,ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm
+ # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM with roberta-base->andreasmadsen/efficient_mlm_m0.40,ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm, RobertaPreLayerNormTokenizer->RobertaTokenizer
class RobertaPreLayerNormForCausalLM(RobertaPreLayerNormPreTrainedModel):
_keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
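The `# Copied from` comment change above is what the "fix copies" bullet refers to: the repository's copy checker applies the `with A->B` replacement patterns when comparing this class against the RoBERTa original, and the new `RobertaPreLayerNormTokenizer->RobertaTokenizer` pattern keeps the mapped names valid. A simplified, purely illustrative sketch of that substitution step (not the actual utils/check_copies.py implementation):

```python
# Illustrative only: apply "with A->B" patterns from a "# Copied from" comment.
replacements = [
    ("roberta-base", "andreasmadsen/efficient_mlm_m0.40"),
    ("ROBERTA", "ROBERTA_PRELAYERNORM"),
    ("Roberta", "RobertaPreLayerNorm"),
    ("roberta", "roberta_prelayernorm"),
    ("RobertaPreLayerNormTokenizer", "RobertaTokenizer"),  # pattern added in this commit
]

def map_copied_line(line: str) -> str:
    for old, new in replacements:
        line = line.replace(old, new)
    return line

# Produces the last line of the doctest hunk below.
print(map_copied_line('>>> model = RobertaForCausalLM.from_pretrained("roberta-base", config=config)'))
```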
@@ -963,15 +963,11 @@ class RobertaPreLayerNormForCausalLM(RobertaPreLayerNormPreTrainedModel):
Example:
```python
- >>> from transformers import (
- ... RobertaPreLayerNormTokenizer,
- ... RobertaPreLayerNormForCausalLM,
- ... RobertaPreLayerNormConfig,
- ... )
+ >>> from transformers import AutoTokenizer, RobertaPreLayerNormForCausalLM, AutoConfig
>>> import torch
- >>> tokenizer = RobertaPreLayerNormTokenizer.from_pretrained("andreasmadsen/efficient_mlm_m0.40")
- >>> config = RobertaPreLayerNormConfig.from_pretrained("andreasmadsen/efficient_mlm_m0.40")
+ >>> tokenizer = AutoTokenizer.from_pretrained("andreasmadsen/efficient_mlm_m0.40")
+ >>> config = AutoConfig.from_pretrained("andreasmadsen/efficient_mlm_m0.40")
>>> config.is_decoder = True
>>> model = RobertaPreLayerNormForCausalLM.from_pretrained("andreasmadsen/efficient_mlm_m0.40", config=config)


@@ -960,11 +960,11 @@ class XLMRobertaForCausalLM(XLMRobertaPreTrainedModel):
Example:
```python
- >>> from transformers import XLMRobertaTokenizer, XLMRobertaForCausalLM, XLMRobertaConfig
+ >>> from transformers import AutoTokenizer, XLMRobertaForCausalLM, AutoConfig
>>> import torch
- >>> tokenizer = XLMRobertaTokenizer.from_pretrained("roberta-base")
- >>> config = XLMRobertaConfig.from_pretrained("roberta-base")
+ >>> tokenizer = AutoTokenizer.from_pretrained("roberta-base")
+ >>> config = AutoConfig.from_pretrained("roberta-base")
>>> config.is_decoder = True
>>> model = XLMRobertaForCausalLM.from_pretrained("roberta-base", config=config)