Make RwkvModel accept attention_mask but discard it internally (#23442)

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
commit 21f7e81b6b (parent cf43200861)
Author: Yih-Dar
Date:   2023-05-18 17:14:25 +02:00 (committed by GitHub)

src/transformers/models/rwkv/modeling_rwkv.py
@@ -565,6 +565,15 @@ RWKV_INPUTS_DOCSTRING = r"""
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            This is currently not used by `RwkvModel`, but will be supported in the future.

            [What are attention masks?](../glossary#attention-mask)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
@@ -617,6 +626,7 @@ class RwkvModel(RwkvPreTrainedModel):
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,  # noqa
        inputs_embeds: Optional[torch.FloatTensor] = None,
        state: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
@@ -750,7 +760,7 @@ class RwkvForCausalLM(RwkvPreTrainedModel):
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,  # noqa
        inputs_embeds: Optional[torch.FloatTensor] = None,
        state: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
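
For reference, a minimal usage sketch of the behavior this commit enables. The checkpoint name `RWKV/rwkv-4-169m-pile` and the pad-token setup are illustrative assumptions, not part of this diff; the point is that the `attention_mask` returned by a padding tokenizer can now be forwarded to `RwkvModel` without a `TypeError`, even though the model ignores it internally.

```python
# Sketch only: checkpoint name and padding setup are assumptions for illustration.
import torch
from transformers import AutoTokenizer, RwkvModel

tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
if tokenizer.pad_token is None:
    # The RWKV tokenizer has no pad token by default; reuse EOS so padding works.
    tokenizer.pad_token = tokenizer.eos_token

model = RwkvModel.from_pretrained("RWKV/rwkv-4-169m-pile")

# `padding=True` makes the tokenizer return an `attention_mask` alongside `input_ids`.
inputs = tokenizer(["Hello, world!", "Hi"], padding=True, return_tensors="pt")

# Before this change, `RwkvModel.forward` had no `attention_mask` parameter, so
# unpacking the tokenizer output with ** raised a TypeError. Now the argument is
# accepted and simply discarded inside the model.
with torch.no_grad():
    outputs = model(**inputs)

print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)
```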