[docs] update input documentation for MAMBA2 and MISTRAL models to include cache_position and attention_mask details (#34322)

* [docs] update input documentation for MAMBA2 and MISTRAL models to include cache_position and attention_mask details

* [docs] correct input documentation for MISTRAL model to reference `input_ids` instead of `decoder_input_ids`

* [docs] clarify cache_position description in MISTRAL model documentation
Author: Vijay (2024-10-28 21:44:07 +05:30), committed by GitHub
parent c1753436db
commit fc1ae7f30f
2 changed files with 15 additions and 1 deletion

src/transformers/models/mamba2/modeling_mamba2.py

@@ -805,6 +805,16 @@ MAMBA2_INPUTS_DOCSTRING = r"""
             more detail.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        cache_position (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            The position of the current input in the cache. This is used to ensure that the cache is correctly updated.
+            If `cache_params` is passed, `cache_position` should also be passed.
+        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
 """

src/transformers/models/mistral/modeling_mistral.py

@@ -619,7 +619,7 @@ MISTRAL_INPUTS_DOCSTRING = r"""
             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
             [`PreTrainedTokenizer.__call__`] for details.
 
-            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
             `past_key_values`).
 
             If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
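
To make the corrected sentence concrete, here is a hedged sketch of incremental decoding, where only the newest `input_ids` is fed once `past_key_values` exists. The checkpoint id is again a placeholder:

```python
from transformers import AutoTokenizer, MistralForCausalLM

# Placeholder id: substitute any Mistral checkpoint.
repo = "your-org/your-mistral-checkpoint"
tokenizer = AutoTokenizer.from_pretrained(repo)
model = MistralForCausalLM.from_pretrained(repo)

inputs = tokenizer("The capital of France is", return_tensors="pt")

# First pass: run the full prompt and build the key/value cache.
out = model(**inputs, use_cache=True)
next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)

# Subsequent pass: with past_key_values, only the last input_ids is needed.
out = model(
    input_ids=next_token,
    past_key_values=out.past_key_values,
    use_cache=True,
)
```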
@@ -666,6 +666,10 @@ MISTRAL_INPUTS_DOCSTRING = r"""
             more detail.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+            Indices indicating the position of the input sequence tokens in the sequence. Unlike `position_ids`,
+            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+            the complete sequence length.
 """