Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 19:21:31 +06:00)
Fix new FA2 if is_causal is passed explicitly (#35390)

* fix
* Update modeling_decision_transformer.py
* Update flash_attention.py
This commit is contained in:
parent 8f38f58f3d
commit 05260a1fc1
flash_attention.py:

@@ -44,6 +44,9 @@ def flash_attention_forward(
         else:
             target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype
 
+    # FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
+    kwargs.pop("is_causal", None)
+
     attn_output = _flash_attention_forward(
         query,
         key,
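Why the pop matters: if is_causal arrives through the attention kwargs while the FA2 integration also passes the module-level value explicitly, Python rejects the call with a duplicate-keyword TypeError. A minimal, self-contained sketch of that failure mode (illustrative names only, not the actual transformers internals):

    def _fa2_call(query, key, value, is_causal=True, **kwargs):
        # Stand-in for the real _flash_attention_forward; only the signature matters here.
        return query

    extra = {"is_causal": False}  # an explicit is_causal arrived via the caller's kwargs

    # Without the fix, is_causal reaches the call twice and Python raises:
    # TypeError: _fa2_call() got multiple values for keyword argument 'is_causal'
    try:
        _fa2_call("q", "k", "v", is_causal=True, **extra)
    except TypeError as err:
        print(err)

    # The fix: drop the duplicate before forwarding, so the value set in the
    # module is the single source of truth.
    extra.pop("is_causal", None)
    _fa2_call("q", "k", "v", is_causal=True, **extra)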
modeling_decision_transformer.py:

@@ -285,9 +285,9 @@ class DecisionTransformerGPT2Attention(nn.Module):
         shape_q = (*query_states.shape[:-1], -1, self.head_dim)
         shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
 
-        query_states = query_states.reshape(shape_q).transpose(1, 2)
-        key_states = key_states.reshape(shape_kv).transpose(1, 2)
-        value_states = value_states.reshape(shape_kv).transpose(1, 2)
+        query_states = query_states.view(shape_q).transpose(1, 2)
+        key_states = key_states.view(shape_kv).transpose(1, 2)
+        value_states = value_states.view(shape_kv).transpose(1, 2)
 
         if layer_past is not None:
             past_key, past_value = layer_past
modeling_gpt2.py:

@@ -295,9 +295,9 @@ class GPT2Attention(nn.Module):
         shape_q = (*query_states.shape[:-1], -1, self.head_dim)
         shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
 
-        query_states = query_states.reshape(shape_q).transpose(1, 2)
-        key_states = key_states.reshape(shape_kv).transpose(1, 2)
-        value_states = value_states.reshape(shape_kv).transpose(1, 2)
+        query_states = query_states.view(shape_q).transpose(1, 2)
+        key_states = key_states.view(shape_kv).transpose(1, 2)
+        value_states = value_states.view(shape_kv).transpose(1, 2)
 
         if layer_past is not None:
             past_key, past_value = layer_past
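For context on the reshape-to-view change in both hunks (an aside, not part of the commit): torch.Tensor.view never copies data and raises if the tensor's memory layout cannot express the requested shape, whereas reshape silently falls back to allocating a copy. A minimal PyTorch sketch of the difference:

    import torch

    x = torch.randn(2, 4, 8)
    t = x.transpose(1, 2)  # non-contiguous: strides no longer match a flat layout

    # view never copies; it raises when the shape cannot be expressed over the
    # existing strides, which makes unexpected layouts loud.
    try:
        t.view(2, 32)
    except RuntimeError as err:
        print(err)  # "view size is not compatible with input tensor's size and stride ..."

    # reshape succeeds on the same tensor by silently materializing a copy.
    y = t.reshape(2, 32)
    print(y.shape)  # torch.Size([2, 32])

Here the q/k/v projections can be reshaped with view because only the last dimension is split, so view doubles as a cheap assertion that no hidden copy is happening.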