Mirror of https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
symbolic_trace: add past_key_values, llama, sdpa support (#28447)
* torch.fx: add pkv, llama, sdpa support
* Update src/transformers/models/opt/modeling_opt.py
* remove spaces
* trigger ci
* use explicit variable names
This commit is contained in:
parent 09eb11a1bd
commit a6adc05e6b
src/transformers/modeling_attn_mask_utils.py

@@ -132,6 +132,7 @@ class AttentionMaskConverter:
         expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
             attention_mask_2d.device
         )

         if causal_4d_mask is not None:
             expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)

@@ -346,10 +347,10 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
     key_value_length = input_shape[-1] + past_key_values_length
     batch_size, query_length = input_shape

-    # torch.jit.trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
+    # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
     # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
     # TODO: Fix this as well when using torchdynamo with fullgraph=True.
-    is_tracing = torch.jit.is_tracing()
+    is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy)

     if attention_mask is not None:
         # 4d mask is passed through
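The second change is how tracing is detected: `torch.jit.is_tracing()` only covers `torch.jit.trace`, whereas during torch.fx symbolic tracing the forward pass receives `torch.fx.Proxy` objects rather than real tensors, so the input type itself signals that a trace is in progress. A minimal sketch of that detection outside of Transformers (the toy module and names are illustrative only):

import torch
import torch.fx


class ToyBlock(torch.nn.Module):
    def forward(self, hidden_states):
        # While torch.fx.symbolic_trace runs, `hidden_states` is a torch.fx.Proxy,
        # which is what the new `is_tracing` check relies on.
        is_tracing = torch.jit.is_tracing() or isinstance(hidden_states, torch.fx.Proxy)
        print("tracing:", is_tracing)  # side effect at trace time only; not recorded in the graph
        return hidden_states + 1


graph_module = torch.fx.symbolic_trace(ToyBlock())  # prints "tracing: True"
print(graph_module(torch.ones(2)))                  # runs the recorded graph: tensor([2., 2.])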
@@ -367,10 +368,8 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
                 )
             return attention_mask

-        elif torch.all(attention_mask == 1):
-            if is_tracing:
-                pass
-            elif query_length == 1:
+        elif not is_tracing and torch.all(attention_mask == 1):
+            if query_length == 1:
                 # For query_length == 1, causal attention and bi-directional attention are the same.
                 attention_mask = None
             elif key_value_length == query_length:
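Checking `not is_tracing` before `torch.all(attention_mask == 1)` matters because the latter is data-dependent: under torch.fx the mask is a Proxy with no concrete values, and feeding it into Python control flow raises a TraceError. The short-circuit keeps that branch entirely out of the traced path. A small standalone sketch of the failure mode being avoided (toy module, not Transformers code):

import torch
import torch.fx


class MaskBranch(torch.nn.Module):
    def forward(self, attention_mask):
        # Branching on tensor values cannot be recorded: during tracing the mask is
        # a Proxy with no data, and bool(Proxy) raises a TraceError.
        if torch.all(attention_mask == 1):
            return None
        return attention_mask


try:
    torch.fx.symbolic_trace(MaskBranch())
except torch.fx.proxy.TraceError as err:
    print("cannot trace data-dependent control flow:", err)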
@@ -405,7 +404,11 @@ def _prepare_4d_causal_attention_mask_for_sdpa(

             # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
             # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
-            if query_length > 1:
+            #
+            # This fix is not applied in case we are tracing with torch.jit.trace or symbolic_trace, as _unmask_unattended has a data-dependent
+            # controlflow that can not be captured properly.
+            # TODO: _unmask_unattended does not work either with torch.compile when using fullgraph=True. We should find a way to detect this case.
+            if query_length > 1 and not is_tracing:
                 expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
                     expanded_4d_mask, attention_mask, unmasked_value=0.0
                 )
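The `_unmask_unattended` workaround exists because a query row whose positions are all masked leaves the attention softmax with a zero normalizer, which surfaces as NaNs with some SDPA backends (see the linked PyTorch issue); since the workaround itself contains data-dependent control flow, it is now skipped while tracing. A minimal standalone reproduction of the underlying NaN issue (shapes and values are arbitrary):

import torch
import torch.nn.functional as F

query = torch.randn(1, 1, 4, 8)
key = torch.randn(1, 1, 4, 8)
value = torch.randn(1, 1, 4, 8)

# Row 0 of the mask attends to nothing at all: the softmax normalizer for that row is zero.
attn_mask = torch.ones(1, 1, 4, 4, dtype=torch.bool)
attn_mask[..., 0, :] = False

out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
print(out[..., 0, :].isnan().any())  # tensor(True) on the affected backends/versions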
src/transformers/utils/fx.py

@@ -131,6 +131,7 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
     "gptj",
     "hubert",
     "layoutlm",
+    "llama",
     "lxmert",
     "m2m_100",
     "marian",
@@ -156,6 +157,8 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
     # "xlnet",
 ]

+_FX_SUPPORTED_MODELS_WITH_KV_CACHE = ["llama", "opt"]
+
 _REGULAR_SUPPORTED_MODELS = []
 for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS:
     if isinstance(item, dict):
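With llama added both to the supported-model list and to `_FX_SUPPORTED_MODELS_WITH_KV_CACHE`, a Llama model can now be symbolically traced with a KV-cache input. A hedged usage sketch of what this enables; the tiny randomly initialized config is only for illustration and is not part of this commit:

import torch
from transformers import LlamaConfig, LlamaForCausalLM
from transformers.utils.fx import symbolic_trace

# Deliberately tiny model so the sketch stays cheap to run.
config = LlamaConfig(hidden_size=64, intermediate_size=128, num_hidden_layers=2, num_attention_heads=4)
model = LlamaForCausalLM(config)

# "past_key_values" is accepted as a traced input now that llama is in _FX_SUPPORTED_MODELS_WITH_KV_CACHE.
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask", "past_key_values"])
print(isinstance(traced, torch.fx.GraphModule))  # True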
@@ -514,6 +517,14 @@ def torch_nn_functional_one_hot(tensor, num_classes=-1):
     return torch.empty(shape, device="meta")


+def torch_nn_functional_scaled_dot_product_attention(
+    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
+):
+    target_length = query.shape[-2]
+    head_dim = value.shape[-1]
+    return torch.empty((*query.shape[:-2], target_length, head_dim), device="meta")
+
+
 def torch_nn_mseloss(self, input, target):
     if self.reduction == "none":
         shape = target.shape
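The new meta override never runs attention; it only reproduces the output shape of `scaled_dot_product_attention` (the query's leading dimensions, the query length, and the value's head dimension) on the meta device so the tracer can propagate shapes. A quick sketch checking that shape rule against eager SDPA, with arbitrary sizes:

import torch
import torch.nn.functional as F

query = torch.randn(2, 8, 5, 64)   # (batch, heads, query_length, head_dim)
key = torch.randn(2, 8, 7, 64)     # key/value length may differ from the query length
value = torch.randn(2, 8, 7, 32)   # the value head_dim may differ as well

real = F.scaled_dot_product_attention(query, key, value)

# Same shape rule as the meta override above.
target_length = query.shape[-2]
head_dim = value.shape[-1]
meta_like = torch.empty((*query.shape[:-2], target_length, head_dim), device="meta")

print(real.shape == meta_like.shape)  # True -> both torch.Size([2, 8, 5, 32])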
@@ -597,6 +608,7 @@ _MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
     torch.Tensor.unsqueeze: torch_tensor_unsqueeze,
     torch.unique_consecutive: torch_unique_consecutive,
     torch.nn.functional.one_hot: torch_nn_functional_one_hot,
+    torch.nn.functional.scaled_dot_product_attention: torch_nn_functional_scaled_dot_product_attention,
     torch.nn.MSELoss: torch_nn_mseloss,
     torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
     torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
@@ -868,6 +880,23 @@ class HFTracer(Tracer):
             inputs_dict[input_name] = torch.zeros(batch_size, seq_length, dtype=torch.float, device=device)
         elif "mask" in input_name or "ids" in input_name:
             inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
+        elif "past_key_values" in input_name:
+            if model.config.model_type not in _FX_SUPPORTED_MODELS_WITH_KV_CACHE:
+                raise NotImplementedError(
+                    f"Symbolic trace with past_key_values input is not supported yet for the model {model.config.model_type}. Please open an issue or a PR in Transformers repository if you would like to see the support added."
+                )
+            num_heads = model.config.num_attention_heads
+            head_dim = model.config.hidden_size // model.config.num_attention_heads
+
+            cache_shape = (shape[0], num_heads, 0, head_dim)
+            pkv = tuple(
+                (
+                    torch.rand(cache_shape, dtype=torch.float, device=device),
+                    torch.rand(cache_shape, dtype=torch.float, device=device),
+                )
+                for i in range(model.config.num_hidden_layers)
+            )
+            inputs_dict[input_name] = pkv
         else:
             shape_with_hidden_size = shape + [model.config.hidden_size]
             inputs_dict[input_name] = torch.zeros(shape_with_hidden_size, dtype=torch.float, device=device)
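The dummy cache built here is a tuple of per-layer (key, value) pairs whose sequence dimension is zero, so position bookkeeping still matches the real `input_ids` while the past_key_values code path gets exercised. A standalone sketch of the resulting structure, with made-up sizes standing in for the config values:

import torch

# Made-up sizes in place of model.config values, just to show the structure:
# each layer gets a (key, value) pair of shape (batch, num_heads, past_len=0, head_dim).
batch_size, num_heads, head_dim, num_layers = 2, 4, 8, 3

cache_shape = (batch_size, num_heads, 0, head_dim)
past_key_values = tuple(
    (torch.rand(cache_shape), torch.rand(cache_shape)) for _ in range(num_layers)
)

print(len(past_key_values), past_key_values[0][0].shape)  # 3 torch.Size([2, 4, 0, 8])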
tests/models/llama/test_modeling_llama.py

@@ -292,6 +292,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     )
     test_headmasking = False
     test_pruning = False
+    fx_compatible = True

     def setUp(self):
         self.model_tester = LlamaModelTester(self)
tests/test_modeling_common.py

@@ -118,7 +118,7 @@ if is_flax_available():
     )

 if is_torch_fx_available():
-    from transformers.utils.fx import symbolic_trace
+    from transformers.utils.fx import _FX_SUPPORTED_MODELS_WITH_KV_CACHE, symbolic_trace


 def _config_zero_init(config):
@@ -1004,7 +1004,9 @@ class ModelTesterMixin:

     def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=False):
         if not is_torch_fx_available() or not self.fx_compatible:
-            return
+            self.skipTest(
+                f"Either torch.fx is not available, or the model type {config.model_type} is not compatible with torch.fx"
+            )

         configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
         configs_no_init.return_dict = False
@@ -1060,6 +1062,26 @@ class ModelTesterMixin:
            if end_positions is not None:
                input_names.append("end_positions")

+           if model.config.model_type in _FX_SUPPORTED_MODELS_WITH_KV_CACHE:
+               input_names.append("past_key_values")
+
+               # Generally model_tester.prepare_config_and_inputs_for_common seem not to generate past key values inputs.
+               if "past_key_values" not in inputs:
+                   batch_size = inputs[next(iter(inputs))].shape[0]
+                   num_heads = model.config.num_attention_heads
+                   head_dim = model.config.hidden_size // model.config.num_attention_heads
+
+                   cache_shape = (batch_size, num_heads, 0, head_dim)
+                   pkv = tuple(
+                       (
+                           torch.rand(cache_shape, dtype=torch.float, device=torch_device),
+                           torch.rand(cache_shape, dtype=torch.float, device=torch_device),
+                       )
+                       for i in range(model.config.num_hidden_layers)
+                   )
+
+                   inputs["past_key_values"] = pkv
+
            filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
            input_names = list(filtered_inputs.keys())
@@ -1069,8 +1091,10 @@ class ModelTesterMixin:
                model.config.problem_type = "single_label_classification"

            traced_model = symbolic_trace(model, input_names)
-           traced_output = traced_model(**filtered_inputs)
-           model_output = model(**filtered_inputs)
+
+           with torch.no_grad():
+               traced_output = traced_model(**filtered_inputs)
+               model_output = model(**filtered_inputs)

        except Exception as e:
            self.fail(f"Couldn't trace module: {e}")