Make RwkvModel accept attention_mask but discard it internally (#23442)

* fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>

commit 21f7e81b6b
parent cf43200861
src/transformers/models/rwkv/modeling_rwkv.py:

```diff
@@ -565,6 +565,15 @@ RWKV_INPUTS_DOCSTRING = r"""
             [`PreTrainedTokenizer.__call__`] for details.
 
             [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            This is currently not used by `RwkvModel`, but will be supported in the future.
+
+            [What are attention masks?](../glossary#attention-mask)
         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
             Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
             is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
```
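For context on the mask semantics documented above, here is a minimal sketch of how a padded batch produces such a mask. The `RWKV/rwkv-4-169m-pile` checkpoint name and the pad-token workaround are illustrative assumptions, not part of this commit:

```python
# Minimal sketch (assumed checkpoint name): a padded batch from the tokenizer
# carries an attention_mask with 1 for real tokens and 0 for padding.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")  # assumption
tokenizer.pad_token = tokenizer.eos_token  # assumed workaround: no pad token by default

batch = tokenizer(["Hello world", "Hi"], padding=True, return_tensors="pt")
print(batch["attention_mask"])  # e.g. tensor([[1, 1], [1, 0]]); 0 marks padding
```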
```diff
@@ -617,6 +626,7 @@ class RwkvModel(RwkvPreTrainedModel):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,  # noqa
         inputs_embeds: Optional[torch.FloatTensor] = None,
         state: Optional[List[torch.FloatTensor]] = None,
         use_cache: Optional[bool] = None,
```
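A minimal sketch of the behavior this hunk enables: `RwkvModel.forward` now accepts `attention_mask` but discards it, so passing one is a no-op rather than a `TypeError`. The checkpoint name is an assumption for illustration:

```python
# Sketch: attention_mask is accepted but discarded internally, so outputs
# are identical with and without it (checkpoint name is an assumption).
import torch
from transformers import RwkvModel

model = RwkvModel.from_pretrained("RWKV/rwkv-4-169m-pile")  # assumption
input_ids = torch.tensor([[11, 22, 33, 44]])
mask = torch.ones_like(input_ids)

with torch.no_grad():
    with_mask = model(input_ids=input_ids, attention_mask=mask)
    without_mask = model(input_ids=input_ids)

# The mask has no effect on the computation, so the hidden states match.
assert torch.equal(with_mask.last_hidden_state, without_mask.last_hidden_state)
```

The `# noqa` marker keeps linters from flagging the deliberately unused argument.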
```diff
@@ -750,7 +760,7 @@ class RwkvForCausalLM(RwkvPreTrainedModel):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,  # noqa
         inputs_embeds: Optional[torch.FloatTensor] = None,
         state: Optional[List[torch.FloatTensor]] = None,
         labels: Optional[torch.LongTensor] = None,
```
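With both signatures aligned, the full tokenizer output can be forwarded wholesale via `**`, without the call failing on an unexpected `attention_mask` keyword. A sketch under the same checkpoint assumption:

```python
# Sketch: the full tokenizer output, including attention_mask, can now be
# unpacked into either model's forward (checkpoint name is an assumption).
from transformers import AutoTokenizer, RwkvForCausalLM

tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")  # assumption
tokenizer.pad_token = tokenizer.eos_token
model = RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile")

inputs = tokenizer(["Hello world", "Hi"], padding=True, return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"])  # mask accepted, then discarded
print(outputs.loss)
```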