Fix torch.compile with fullgraph=True when attention_mask input is used (#29211)

* fix torch.export.export for llama

* do not change doc title

* make fix copies
This commit is contained in:
fxmarty 2024-02-22 16:40:06 +01:00 committed by GitHub
parent dabe855668
commit 2cc8cf6ce7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 39 additions and 33 deletions

View File

@ -184,7 +184,7 @@ For now, Transformers supports SDPA inference and training for the following arc
<Tip>
FlashAttention can only be used for models with the `fp16` or `bf16` torch type, so make sure to cast your model to the appropriate type first.
FlashAttention can only be used for models with the `fp16` or `bf16` torch type, so make sure to cast your model to the appropriate type first. The memory-efficient attention backend is able to handle `fp32` models.
</Tip>

View File

@ -529,24 +529,6 @@ And for Pytorch DeepSpeed has built one as well: [DeepSpeed-MoE: Advancing Mixtu
## Using PyTorch native attention and Flash Attention
PyTorch 2.0 released a native [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA),
that allows using fused GPU kernels such as [memory-efficient attention](https://arxiv.org/abs/2112.05682) and [flash attention](https://arxiv.org/abs/2205.14135).
After installing the [`optimum`](https://github.com/huggingface/optimum) package, the relevant internal modules can be
replaced to use PyTorch's native attention with:
```python
model = model.to_bettertransformer()
```
Once converted, train the model as usual.
<Tip warning={true}>
The PyTorch-native `scaled_dot_product_attention` operator can only dispatch to Flash Attention if no `attention_mask` is provided.
By default, in training mode, the BetterTransformer integration **drops the mask support and can only be used for training that does not require a padding mask for batched training**. This is the case, for example, during masked language modeling or causal language modeling. BetterTransformer is not suited for fine-tuning models on tasks that require a padding mask.
</Tip>
PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers and is used by default for `torch>=2.1.1` when an implementation is available. Please refer to [PyTorch scaled dot product attention](https://huggingface.co/docs/transformers/perf_infer_gpu_one#pytorch-scaled-dot-product-attention) for a list of supported models and more details.
Check out this [blogpost](https://pytorch.org/blog/out-of-the-box-acceleration/) to learn more about acceleration and memory-savings with SDPA.

View File

@ -349,8 +349,12 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
# torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
# used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
# TODO: Fix this as well when using torchdynamo with fullgraph=True.
is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy)
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
is_tracing = (
torch.jit.is_tracing()
or isinstance(inputs_embeds, torch.fx.Proxy)
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)
if attention_mask is not None:
# 4d mask is passed through
@ -448,10 +452,14 @@ def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype,
batch_size, key_value_length = mask.shape
tgt_len = tgt_len if tgt_len is not None else key_value_length
# torch.jit.trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
# torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
# used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
# TODO: Fix this as well when using torchdynamo with fullgraph=True.
is_tracing = torch.jit.is_tracing()
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
is_tracing = (
torch.jit.is_tracing()
or isinstance(mask, torch.fx.Proxy)
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)
if torch.all(mask == 1):
if is_tracing:

View File

@ -969,10 +969,18 @@ class GemmaModel(GemmaPreTrainedModel):
padding_mask, torch.finfo(dtype).min
)
if self.config._attn_implementation == "sdpa":
is_tracing = torch.jit.is_tracing() or isinstance(input_tensor, torch.fx.Proxy)
if not is_tracing and attention_mask is not None and torch.any(attention_mask != 1):
causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1)[..., None]).to(
if self.config._attn_implementation == "sdpa" and attention_mask is not None:
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
is_tracing = (
torch.jit.is_tracing()
or isinstance(input_tensor, torch.fx.Proxy)
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)
if not is_tracing and torch.any(attention_mask != 1):
# Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1, keepdim=True)).to(
dtype
)

View File

@ -1076,10 +1076,18 @@ class LlamaModel(LlamaPreTrainedModel):
padding_mask, torch.finfo(dtype).min
)
if self.config._attn_implementation == "sdpa":
is_tracing = torch.jit.is_tracing() or isinstance(input_tensor, torch.fx.Proxy)
if not is_tracing and attention_mask is not None and torch.any(attention_mask != 1):
causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1)[..., None]).to(
if self.config._attn_implementation == "sdpa" and attention_mask is not None:
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
is_tracing = (
torch.jit.is_tracing()
or isinstance(input_tensor, torch.fx.Proxy)
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)
if not is_tracing and torch.any(attention_mask != 1):
# Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1, keepdim=True)).to(
dtype
)