Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 11:11:05 +06:00)
Llama: always convert the causal mask in the SDPA code path (#29663)
* always convert the mask
* rebase and fix copies

This commit is contained in:
parent 5ffef2a978
commit ee38fc31fb
src/transformers/models/cohere/modeling_cohere.py
@@ -1005,17 +1005,10 @@ class CohereModel(CoherePreTrainedModel):
             and attention_mask is not None
             and attention_mask.device.type == "cuda"
         ):
-            # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
-            is_tracing = (
-                torch.jit.is_tracing()
-                or isinstance(input_tensor, torch.fx.Proxy)
-                or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
-            )
-            if not is_tracing and torch.any(attention_mask != 1):
-                # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-                # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-                # Details: https://github.com/pytorch/pytorch/issues/110213
-                causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+            # Details: https://github.com/pytorch/pytorch/issues/110213
+            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
 
         return causal_mask
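The change drops the tracing guard and the data-dependent `torch.any(attention_mask != 1)` check, so `AttentionMaskConverter._unmask_unattended` now runs whenever SDPA is used with a CUDA attention mask. To make the comments above concrete, here is a minimal, self-contained sketch of what "unmasking unattended rows" means; it is an illustration only, not the transformers helper, and the shape assumptions are mine:

```python
import torch


def unmask_fully_masked_rows(causal_mask: torch.Tensor, min_dtype: float) -> torch.Tensor:
    """Illustrative stand-in for AttentionMaskConverter._unmask_unattended (not the library code).

    `causal_mask` is assumed to be a 4D additive mask [batch, 1, query_len, kv_len]
    whose masked positions hold `min_dtype`. Query rows in which every key position
    is masked (for example the leading rows produced by left padding) make SDPA's
    memory-efficient kernel return NaNs (pytorch/pytorch#110213), so those rows are
    reset to zero, i.e. they attend to all tokens.
    """
    fully_masked = (causal_mask == min_dtype).all(dim=-1, keepdim=True)
    # Multiplying by the negated boolean zeroes fully masked rows and leaves other rows untouched.
    return causal_mask.mul(~fully_masked)
```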
src/transformers/models/gemma/modeling_gemma.py
@@ -1011,17 +1011,10 @@ class GemmaModel(GemmaPreTrainedModel):
             and attention_mask is not None
             and attention_mask.device.type == "cuda"
         ):
-            # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
-            is_tracing = (
-                torch.jit.is_tracing()
-                or isinstance(input_tensor, torch.fx.Proxy)
-                or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
-            )
-            if not is_tracing and torch.any(attention_mask != 1):
-                # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-                # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-                # Details: https://github.com/pytorch/pytorch/issues/110213
-                causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+            # Details: https://github.com/pytorch/pytorch/issues/110213
+            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
 
         return causal_mask
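The diff's comments single out left padding as the situation that creates fully masked rows. A small standalone example (the tensors and names here are purely illustrative) shows how a left-padded batch leads to query rows in the 4D additive mask that attend to nothing at all:

```python
import torch

# Batch of 2, sequence length 4; the first sample is left-padded with 2 tokens.
attention_mask = torch.tensor([[0, 0, 1, 1],
                               [1, 1, 1, 1]])
min_dtype = torch.finfo(torch.float32).min
seq_len = attention_mask.shape[1]

# 4D additive mask [batch, 1, seq_len, seq_len]: masked positions hold min_dtype, visible ones 0.
causal = torch.triu(torch.full((seq_len, seq_len), min_dtype), diagonal=1)
causal = causal[None, None, :, :].expand(2, 1, seq_len, seq_len).clone()
causal = causal.masked_fill(attention_mask[:, None, None, :] == 0, min_dtype)

# With left padding, the first query rows of sample 0 are fully masked; these are exactly
# the rows _unmask_unattended resets so SDPA's memory-efficient path does not produce NaNs.
fully_masked_rows = (causal == min_dtype).all(dim=-1)
print(fully_masked_rows[0, 0])  # tensor([ True,  True, False, False])
print(fully_masked_rows[1, 0])  # tensor([False, False, False, False])
```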
src/transformers/models/llama/modeling_llama.py
@@ -1100,17 +1100,10 @@ class LlamaModel(LlamaPreTrainedModel):
             and attention_mask is not None
             and attention_mask.device.type == "cuda"
         ):
-            # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
-            is_tracing = (
-                torch.jit.is_tracing()
-                or isinstance(input_tensor, torch.fx.Proxy)
-                or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
-            )
-            if not is_tracing and torch.any(attention_mask != 1):
-                # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-                # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-                # Details: https://github.com/pytorch/pytorch/issues/110213
-                causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+            # Details: https://github.com/pytorch/pytorch/issues/110213
+            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
 
         return causal_mask
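For completeness, a hedged usage sketch of the helper these hunks keep calling, assuming the two-argument `(expanded_mask, min_dtype)` signature that appears in the diff; `_unmask_unattended` is a private transformers utility, so its signature may differ in other library versions:

```python
import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

dtype = torch.float32
min_dtype = torch.finfo(dtype).min  # the same "smallest value of the dtype" the modeling code passes

# Toy 4D additive mask [batch=1, heads=1, q_len=2, kv_len=2]; the first query row is fully masked.
causal_mask = torch.tensor([[[[min_dtype, min_dtype],
                              [0.0, min_dtype]]]], dtype=dtype)

converted = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
print(converted[0, 0, 0])  # fully masked row is reset, so it can attend to every token
print(converted[0, 0, 1])  # partially masked row is left as-is
```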