Fix torch.fx symbolic tracing for LLama (#30047)

* [WIP] fix fx

* [WIP] fix fx

* [WIP] fix fx

* [WIP] fix fx

* [WIP] fix fx

* Apply changes to other models
This commit is contained in:
Michael Benayoun 2024-04-05 15:14:09 +02:00 committed by GitHub
parent 48795317a2
commit 17cd7a9d28
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 23 additions and 18 deletions

View File

@ -908,7 +908,9 @@ class CohereModel(CoherePreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1]
)
# embed positions
hidden_states = inputs_embeds
@ -976,7 +978,7 @@ class CohereModel(CoherePreTrainedModel):
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
@ -989,7 +991,7 @@ class CohereModel(CoherePreTrainedModel):
target_length = self.config.max_position_embeddings
else: # dynamic cache
target_length = (
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
)
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)

View File

@ -888,7 +888,9 @@ class GemmaModel(GemmaPreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1]
)
# embed positions
hidden_states = inputs_embeds
@ -962,7 +964,7 @@ class GemmaModel(GemmaPreTrainedModel):
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
@ -975,7 +977,7 @@ class GemmaModel(GemmaPreTrainedModel):
target_length = self.config.max_position_embeddings
else: # dynamic cache
target_length = (
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
)
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)

View File

@ -987,7 +987,9 @@ class LlamaModel(LlamaPreTrainedModel):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1]
)
# embed positions
hidden_states = inputs_embeds
@ -1055,7 +1057,7 @@ class LlamaModel(LlamaPreTrainedModel):
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
def _update_causal_mask(self, attention_mask, input_tensor, cache_position, current_length):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
@ -1068,7 +1070,7 @@ class LlamaModel(LlamaPreTrainedModel):
target_length = self.config.max_position_embeddings
else: # dynamic cache
target_length = (
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else current_length + 1
)
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)

View File

@ -260,11 +260,14 @@ def torch_arange(*args, **kwargs):
def torch_full(*args, **kwargs):
args = list(args)
if isinstance(args[1], torch.Tensor) and args[1].device == torch.device("meta"):
args[1] = 1 # Any value.
# We set the fill value to 1 as its value is not important as long as it's not a tensor on the `meta` device.
if len(args) > 1:
args[1] = 1
else:
kwargs["fill_value"] = 1
kwargs_without_device = dict(kwargs)
kwargs_without_device.pop("device", None)
return torch.full(*args, **kwargs_without_device)
return torch.full(*args, **kwargs_without_device, device="meta")
def torch_cat(tensors, dim=None, axis=None, *, out=None):

View File

@ -283,9 +283,7 @@ class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
)
test_headmasking = False
test_pruning = False
fx_compatible = (
False # FIXME @michaelbenayoun or @fxmarty from https://github.com/huggingface/transformers/pull/29753
)
fx_compatible = True
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
# This is because we are hitting edge cases with the causal_mask buffer

View File

@ -305,9 +305,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
)
test_headmasking = False
test_pruning = False
fx_compatible = (
False # FIXME @michaelbenayoun or @fxmarty from https://github.com/huggingface/transformers/pull/29753
)
fx_compatible = True
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
# This is because we are hitting edge cases with the causal_mask buffer