Fix doc formatting in forward passes & modular (#36243)

* fix indentation issues + modular without magic keyword

* style

* Update doc.py

* style

* Fix all decorators indentation

* all models

* style

* style

* Update doc.py

* fix

* general fix

* style
Cyril Vallez 2025-02-25 11:09:01 +01:00 committed by GitHub
parent 92abc0dae8
commit da4ab2a1b6
67 changed files with 83 additions and 90 deletions

@@ -349,7 +349,6 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
num_logits_to_keep: int = 0,
) -> Union[Tuple, NewTaskModelCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
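
The `-100` convention these docstrings reference comes from PyTorch's cross-entropy loss, which skips that target index by default; that is what makes the masking work. A minimal sketch (not part of the commit; shapes and values are illustrative):

```python
import torch
import torch.nn.functional as F

vocab_size = 8
logits = torch.randn(1, 4, vocab_size)       # (batch, seq_len, vocab)
labels = torch.tensor([[2, 5, -100, -100]])  # last two positions masked out

# F.cross_entropy ignores targets equal to ignore_index (default -100),
# so the loss is averaged over the two unmasked tokens only.
loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1), ignore_index=-100)
print(loss.item())
```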

@@ -1193,7 +1193,6 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
@@ -1458,7 +1457,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
**loss_kwargs,
) -> Union[Tuple, AriaCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `AriaForConditionalGeneration`).

@@ -1437,7 +1437,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
**loss_kwargs,
) -> Union[Tuple, AriaCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `AriaForConditionalGeneration`).

@@ -1495,7 +1495,6 @@ class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin):
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -1205,7 +1205,6 @@ class BambaForCausalLM(LlamaForCausalLM):
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -1554,7 +1554,6 @@ class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixin):
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -833,7 +833,6 @@ class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin):
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -321,7 +321,6 @@ class CohereForCausalLM(LlamaForCausalLM):
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -834,7 +834,6 @@ class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin):
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -1283,9 +1283,7 @@ class DbrxForCausalLM(DbrxPreTrainedModel, GenerationMixin):
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""Forward function for causal language modeling.
Args:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -716,7 +716,6 @@ class OpenLlamaForCausalLM(OpenLlamaPreTrainedModel):
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -1070,7 +1070,6 @@ class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin):
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -1650,7 +1650,6 @@ class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin):
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
@@ -1878,7 +1877,6 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
logits_to_keep: Union[int, torch.Tensor] = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -1077,7 +1077,6 @@ class Emu3ForCausalLM(LlamaForCausalLM, Emu3PreTrainedModel, GenerationMixin):
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="Emu3TextConfig")
def forward(**super_kwargs):
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
@@ -1186,7 +1185,6 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
logits_to_keep: Union[int, torch.Tensor] = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -803,7 +803,6 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -483,7 +483,6 @@ class GemmaModel(LlamaModel):
class GemmaForCausalLM(LlamaForCausalLM):
def forward(**super_kwargs):
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -841,7 +841,6 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
@@ -859,9 +858,9 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
Example:
```python
->>> from transformers import AutoTokenizer, GemmaForCausalLM
+>>> from transformers import AutoTokenizer, Gemma2ForCausalLM
->>> model = GemmaForCausalLM.from_pretrained("google/gemma-2-9b")
+>>> model = Gemma2ForCausalLM.from_pretrained("google/gemma-2-9b")
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
>>> prompt = "What is your favorite condiment?"

@@ -591,10 +591,26 @@ class Gemma2ForCausalLM(GemmaForCausalLM):
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
-    ```python
-    >>> from transformers import AutoTokenizer, GemmaForCausalLM
+    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-    >>> model = GemmaForCausalLM.from_pretrained("google/gemma-2-9b")
+    logits_to_keep (`int` or `torch.Tensor`, *optional*):
+        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+        This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, Gemma2ForCausalLM
+    >>> model = Gemma2ForCausalLM.from_pretrained("google/gemma-2-9b")
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
>>> prompt = "What is your favorite condiment?"
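
A sketch of the slicing behavior the `logits_to_keep` docstring above describes (illustrative only; the models apply this selection to the hidden states before the LM head):

```python
import torch

hidden = torch.randn(2, 10, 16)  # (batch, seq_len, hidden)

# int: keep only the last N positions; 0 is the special case meaning "keep all"
logits_to_keep = 1
kept = hidden[:, slice(-logits_to_keep, None) if logits_to_keep > 0 else slice(None), :]
print(kept.shape)  # torch.Size([2, 1, 16])

# 1D tensor: explicit indices to keep, for packed inputs
# (batch folded into the sequence dimension)
keep_idx = torch.tensor([3, 7, 9])
print(hidden[:, keep_idx, :].shape)  # torch.Size([2, 3, 16])
```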

@@ -812,7 +812,6 @@ class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin):
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -769,7 +769,6 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
logits_to_keep: Union[int, torch.Tensor] = 0,
) -> Union[Tuple, GotOcr2CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -848,7 +848,6 @@ class GotOcr2ForConditionalGeneration(LlavaForConditionalGeneration):
logits_to_keep: Union[int, torch.Tensor] = 0,
) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -815,7 +815,6 @@ class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin):
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -1287,7 +1287,6 @@ class GraniteMoeForCausalLM(GraniteMoePreTrainedModel, GenerationMixin):
**kwargs,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -1313,7 +1313,6 @@ class GraniteMoeSharedForCausalLM(GraniteMoeSharedPreTrainedModel, GenerationMixin):
**kwargs,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -799,7 +799,6 @@ class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin):
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -1559,7 +1559,6 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel, GenerationMixin):
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, IdeficsCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -1687,7 +1687,6 @@ class TFIdeficsForVisionText2Text(TFPreTrainedModel, TFCausalLanguageModelingLoss):
training=False,
) -> Union[TFIdeficsCausalLMOutputWithPast, Tuple[tf.Tensor]]:
r"""
Args:
labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -1537,7 +1537,6 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin):
logits_to_keep: Union[int, torch.Tensor] = 0,
) -> Union[Tuple, Idefics2CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics2ForConditionalGeneration`).

@@ -1121,7 +1121,6 @@ class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin):
logits_to_keep: Union[int, torch.Tensor] = 0,
) -> Union[Tuple, Idefics3CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics3ForConditionalGeneration`).

@@ -1456,7 +1456,6 @@ class JambaForCausalLM(JambaPreTrainedModel, GenerationMixin):
**loss_kwargs,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -1299,7 +1299,6 @@ class JetMoeForCausalLM(JetMoePreTrainedModel, GenerationMixin):
**kwargs,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -801,7 +801,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -348,7 +348,6 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
**lm_kwargs,
) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -561,7 +561,6 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin):
**lm_kwargs,
) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -601,7 +601,6 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin):
**lm_kwargs,
) -> Union[Tuple, LlavaNextVideoCausalLMOutputWithPast]:
r"""
Args:
pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, image_size, image_size)`):
The tensors corresponding to the input videos. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`LlavaNextVideoVideoProcessor.__call__`] for details. [`LlavaProcessor`] uses

@@ -360,7 +360,6 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration):
**lm_kwargs,
) -> Union[Tuple, LlavaNextVideoCausalLMOutputWithPast]:
r"""
Args:
pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, image_size, image_size)`):
The tensors corresponding to the input videos. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`LlavaNextVideoVideoProcessor.__call__`] for details. [`LlavaProcessor`] uses
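
For reference, the 5D shape documented above, built as a dummy tensor (values and sizes illustrative, not from the commit):

```python
import torch

batch_size, num_frames, num_channels, image_size = 1, 8, 3, 224
pixel_values_videos = torch.rand(batch_size, num_frames, num_channels, image_size, image_size)
print(pixel_values_videos.shape)  # torch.Size([1, 8, 3, 224, 224])
```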

@@ -623,7 +623,6 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, GenerationMixin):
**lm_kwargs,
) -> Union[Tuple, LlavaOnevisionCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -802,7 +802,6 @@ class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin):
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -849,11 +849,10 @@ class TFMistralForCausalLM(TFMistralPreTrainedModel, TFCausalLanguageModelingLoss):
return_dict: Optional[bool] = None,
) -> Union[Tuple, TFCausalLMOutputWithPast]:
r"""
-Args:
-    labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-        Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
-        or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+    Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
+    or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+    (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
"""
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
@@ -975,11 +974,10 @@ class TFMistralForSequenceClassification(TFMistralPreTrainedModel, TFSequenceClassificationLoss):
return_dict: Optional[bool] = None,
) -> Union[Tuple, TFSequenceClassifierOutputWithPast]:
r"""
-Args:
-    labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+    Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+    config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+    (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
"""
transformer_outputs = self.model(

@@ -1022,7 +1022,6 @@ class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin):
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -480,7 +480,6 @@ class MixtralForCausalLM(MistralForCausalLM):
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -1901,7 +1901,6 @@ class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin):
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
@@ -2048,7 +2047,6 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
logits_to_keep: Union[int, torch.Tensor] = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -1813,7 +1813,6 @@ class MoshiForCausalLM(MoshiPreTrainedModel, GenerationMixin):
**kwargs,
) -> Union[Tuple, MoshiCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -1047,7 +1047,6 @@ class NemotronForCausalLM(NemotronPreTrainedModel, GenerationMixin):
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -777,7 +777,6 @@ class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin):
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -778,7 +778,6 @@ class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin):
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -1206,7 +1206,6 @@ class OlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin):
**loss_kwargs,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -438,7 +438,6 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):
**lm_kwargs,
) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -852,7 +852,6 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel, GenerationMixin):
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -775,7 +775,6 @@ class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin):
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -877,7 +877,6 @@ class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin):
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -1388,7 +1388,6 @@ class PhimoeForCausalLM(PhimoePreTrainedModel, GenerationMixin):
**loss_kwargs,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -815,7 +815,6 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -1742,7 +1742,6 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin):
second_per_grid_ts: Optional[torch.Tensor] = None,
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -608,7 +608,6 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
second_per_grid_ts: Optional[torch.Tensor] = None,
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -1112,7 +1112,6 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
) -> Union[Tuple, Qwen2AudioCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -1272,7 +1272,6 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin):
**loss_kwargs,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -1619,7 +1619,6 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -821,7 +821,6 @@ class RecurrentGemmaForCausalLM(RecurrentGemmaPreTrainedModel, GenerationMixin):
**kwargs,
) -> Union[Tuple, CausalLMOutput]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -1109,7 +1109,6 @@ class StableLmForCausalLM(StableLmPreTrainedModel, GenerationMixin):
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -798,7 +798,6 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin):
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -383,7 +383,6 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMixin):
**lm_kwargs,
) -> Union[Tuple, VideoLlavaCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -323,7 +323,6 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin):
**lm_kwargs,
) -> Union[Tuple, VipLlavaCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -1228,7 +1228,6 @@ class ZambaForCausalLM(ZambaPreTrainedModel, GenerationMixin):
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -1665,7 +1665,6 @@ class Zamba2ForCausalLM(Zamba2PreTrainedModel, GenerationMixin):
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored

@@ -16,10 +16,23 @@ Doc utilities: Utilities related to documentation
"""
import functools
import inspect
import re
import textwrap
import types
def get_docstring_indentation_level(func):
"""Return the indentation level of the start of the docstring of a class or function (or method)."""
# We assume classes are always defined in the global scope
if inspect.isclass(func):
return 4
source = inspect.getsource(func)
first_line = source.splitlines()[0]
function_def_level = len(first_line) - len(first_line.lstrip())
return 4 + function_def_level
def add_start_docstrings(*docstr):
def docstring_decorator(fn):
fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
@@ -30,10 +43,8 @@ def add_start_docstrings(*docstr):
def add_start_docstrings_to_model_forward(*docstr):
    def docstring_decorator(fn):
-        docstring = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
        class_name = f"[`{fn.__qualname__.split('.')[0]}`]"
-        intro = f" The {class_name} forward method, overrides the `__call__` special method."
-        note = r"""
+        intro = rf""" The {class_name} forward method, overrides the `__call__` special method.

    <Tip>
@@ -44,7 +55,23 @@ def add_start_docstrings_to_model_forward(*docstr):
    </Tip>
    """
-        fn.__doc__ = intro + note + docstring
+        correct_indentation = get_docstring_indentation_level(fn)
+        current_doc = fn.__doc__ if fn.__doc__ is not None else ""
+        try:
+            first_non_empty = next(line for line in current_doc.splitlines() if line.strip() != "")
+            doc_indentation = len(first_non_empty) - len(first_non_empty.lstrip())
+        except StopIteration:
+            doc_indentation = correct_indentation
+        docs = docstr
+        # In this case, the correct indentation level (class method, 2 Python levels) was respected, and we should
+        # correctly reindent everything. Otherwise, the doc uses a single indentation level
+        if doc_indentation == 4 + correct_indentation:
+            docs = [textwrap.indent(textwrap.dedent(doc), " " * correct_indentation) for doc in docstr]
+        intro = textwrap.indent(textwrap.dedent(intro), " " * correct_indentation)
+        docstring = "".join(docs) + current_doc
+        fn.__doc__ = intro + docstring
        return fn

    return docstring_decorator
@@ -1153,6 +1180,7 @@ def add_code_sample_docstrings(
            built_doc = built_doc.replace(
                f'from_pretrained("{checkpoint}")', f'from_pretrained("{checkpoint}", revision="{revision}")'
            )
        fn.__doc__ = func_doc + output_doc + built_doc
        return fn
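
To see what the new helper and the reindentation branch above compute, here is a standalone sketch (reimplemented rather than imported from `transformers`; the class and doc text are made up, and it must run from a file so `inspect.getsource` can find the source):

```python
import inspect
import textwrap

def docstring_indentation_level(func) -> int:
    # Same rule as get_docstring_indentation_level above: the docstring sits one
    # Python level (4 spaces) deeper than the `def` line itself.
    first_line = inspect.getsource(func).splitlines()[0]
    return 4 + len(first_line) - len(first_line.lstrip())

class DummyModel:
    def forward(self):
        pass

level = docstring_indentation_level(DummyModel.forward)
print(level)  # 8: a method of a top-level class, one level deep plus the body level

doc = """
    labels (`torch.LongTensor`):
        Labels for computing the loss.
"""
# The fix in a nutshell: dedent whatever indentation the docstring arrived with,
# then reindent it to the level the target function expects.
print(textwrap.indent(textwrap.dedent(doc), " " * level))
```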

@@ -253,10 +253,29 @@ def get_docstring_indent(docstring):
    return 0

+def is_full_docstring(new_docstring: str) -> bool:
+    """Check if `new_docstring` is a full docstring, or if it is only part of a docstring that should then
+    be merged with the existing old one.
+    """
+    # libcst returns the docstrings with literal `r"""` quotes in front
+    new_docstring = new_docstring.split('"""', 1)[1]
+    # The docstring contains an Args definition, so it is self-contained
+    if re.search(r"\n\s*Args:\n", new_docstring):
+        return True
+    # If it contains Returns, but starts with text indented with an additional 4 spaces before, it is self-contained
+    # (this is the scenario when using `@add_start_docstrings_to_model_forward`, but adding more args to docstring)
+    match_object = re.search(r"\n([^\S\n]*)Returns:\n", new_docstring)
+    if match_object is not None:
+        full_indent = match_object.group(1)
+        striped_doc = new_docstring.strip("\n")
+        if striped_doc.startswith(full_indent + " " * 4) or striped_doc.startswith(full_indent + "\t"):
+            return True
+    return False

def merge_docstrings(original_docstring, updated_docstring):
    # indent_level = get_docstring_indent(updated_docstring)
    original_level = get_docstring_indent(original_docstring)
-    if not re.findall(r"\n\s*Args:\n", updated_docstring):
+    if not is_full_docstring(updated_docstring):
        # Split the docstring at the example section, assuming `"""` is used to define the docstring
        parts = original_docstring.split("```")
        if "```" in updated_docstring and len(parts) > 1: