Fix new FA2 if is_causal is passed explicitly (#35390)

* fix

* Update modeling_decision_transformer.py

* Update flash_attention.py
Cyril Vallez, 2024-12-22 20:00:07 +01:00, committed by GitHub
parent 8f38f58f3d, commit 05260a1fc1
3 changed files with 9 additions and 6 deletions
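
The fix itself is small: flash_attention_forward already determines is_causal from the module and passes it to _flash_attention_forward explicitly, so an is_causal left in **kwargs by a caller would reach the inner function a second time and raise a TypeError. The added kwargs.pop removes that duplicate. Below is a minimal standalone sketch of the failure mode and of the fix, using toy functions rather than the real transformers signatures:

def _inner_attention(query, key, value, is_causal=True, **kwargs):
    # Toy stand-in for _flash_attention_forward; the signature is illustrative only.
    return f"is_causal={is_causal}"

def attention_wrapper(query, key, value, module_is_causal, **kwargs):
    # Toy stand-in for flash_attention_forward. Without the pop below, a caller's
    # explicit is_causal stays in kwargs and collides with the value the wrapper
    # passes itself:
    #   TypeError: _inner_attention() got multiple values for keyword argument 'is_causal'
    kwargs.pop("is_causal", None)
    return _inner_attention(query, key, value, is_causal=module_is_causal, **kwargs)

print(attention_wrapper("q", "k", "v", module_is_causal=True, is_causal=False))  # prints is_causal=True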

src/transformers/integrations/flash_attention.py

@@ -44,6 +44,9 @@ def flash_attention_forward(
     else:
         target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype
 
+    # FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
+    kwargs.pop("is_causal", None)
+
     attn_output = _flash_attention_forward(
         query,
         key,

src/transformers/models/decision_transformer/modeling_decision_transformer.py

@@ -285,9 +285,9 @@ class DecisionTransformerGPT2Attention(nn.Module):
         shape_q = (*query_states.shape[:-1], -1, self.head_dim)
         shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
 
-        query_states = query_states.reshape(shape_q).transpose(1, 2)
-        key_states = key_states.reshape(shape_kv).transpose(1, 2)
-        value_states = value_states.reshape(shape_kv).transpose(1, 2)
+        query_states = query_states.view(shape_q).transpose(1, 2)
+        key_states = key_states.view(shape_kv).transpose(1, 2)
+        value_states = value_states.view(shape_kv).transpose(1, 2)
 
         if layer_past is not None:
             past_key, past_value = layer_past

src/transformers/models/gpt2/modeling_gpt2.py

@@ -295,9 +295,9 @@ class GPT2Attention(nn.Module):
         shape_q = (*query_states.shape[:-1], -1, self.head_dim)
         shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
 
-        query_states = query_states.reshape(shape_q).transpose(1, 2)
-        key_states = key_states.reshape(shape_kv).transpose(1, 2)
-        value_states = value_states.reshape(shape_kv).transpose(1, 2)
+        query_states = query_states.view(shape_q).transpose(1, 2)
+        key_states = key_states.view(shape_kv).transpose(1, 2)
+        value_states = value_states.view(shape_kv).transpose(1, 2)
 
         if layer_past is not None:
             past_key, past_value = layer_past
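
The GPT2 and DecisionTransformer hunks replace reshape with view on the projected states. Both produce the same result here, since only the stride-1 last dimension is subdivided into (num_heads, head_dim), but view is guaranteed to alias the existing storage and fails loudly if it cannot, whereas reshape may silently fall back to a copy. A small sketch with illustrative shapes, not the models' actual configuration:

import torch

hidden = torch.randn(2, 3, 3 * 8)               # toy fused qkv output: batch=2, seq=3, 3 * hidden_size=24
q, k, v = hidden.split(8, dim=2)                 # slices along the last dim, so q/k/v are not contiguous

shape_q = (*q.shape[:-1], -1, 4)                 # split hidden_size=8 into (num_heads=2, head_dim=4)
q_heads = q.view(shape_q).transpose(1, 2)        # view is legal: only the stride-1 last dim is subdivided
print(q_heads.shape)                             # torch.Size([2, 2, 3, 4])
print(q_heads.data_ptr() == hidden.data_ptr())   # True: no copy was made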