fix a typo in flax T5 attention - attention_mask variable is misnamed (#26663)

* fix a typo in flax t5 attention

* fix the typo in flax longt5 attention
This commit is contained in:
théo gigant 2023-10-10 20:36:32 +02:00 committed by GitHub
parent e8fdd7875d
commit 975003eacb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 2 deletions

View File

@@ -545,7 +545,7 @@ class FlaxLongT5Attention(nn.Module):
# During fast autoregressive decoding, we feed one position at a time,
# and cache the keys and values step by step.
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
-            key_states, value_states, attention_attention_mask = self._concatenate_to_cache(
+            key_states, value_states, attention_mask = self._concatenate_to_cache(
key_states, value_states, query_states, attention_mask
)

View File

@@ -405,7 +405,7 @@ class FlaxT5Attention(nn.Module):
# During fast autoregressive decoding, we feed one position at a time,
# and cache the keys and values step by step.
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
-            key_states, value_states, attention_attention_mask = self._concatenate_to_cache(
+            key_states, value_states, attention_mask = self._concatenate_to_cache(
key_states, value_states, query_states, attention_mask
)