Gemma3 is Torch Exportable (#37728)

* Gemma3 is Torch Exportable

* Expand the support to other models using HybridCache

---------

Co-authored-by: Guang Yang <guangyang@fb.com>
Guang Yang 2025-04-28 00:36:46 -07:00 committed by GitHub
parent 397a5ede33
commit 816b37010c
9 changed files with 369 additions and 10 deletions


@@ -20,15 +20,207 @@ from ..utils.import_utils import is_torch_available
if is_torch_available():
from transformers import PreTrainedModel, StaticCache
from transformers import HybridCache, PreTrainedModel, StaticCache
from transformers.pytorch_utils import is_torch_greater_or_equal, is_torch_greater_or_equal_than_2_3
class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
"""
A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
specifically for decoder-only LM with cache. This module ensures that the
exported model is compatible with further lowering and execution in `ExecuTorch`.
"""
def __init__(
self,
model: PreTrainedModel,
max_batch_size: int = 1,
max_cache_len: int = 4096,
):
"""
Initializes the exportable module, selecting the cache wrapper (`StaticCache` or `HybridCache`) based on the model's `cache_implementation`.
Args:
model (`PreTrainedModel`): The pretrained model to wrap.
max_batch_size (int): Maximum batch size for the cache.
max_cache_len (int): Maximum sequence length for the cache.
Raises:
ValueError: If the model is configured with an unsupported cache implementation.
"""
super().__init__()
if model.config.cache_implementation == "static":
self.model = TorchExportableModuleWithStaticCache(model)
elif model.config.cache_implementation == "hybrid":
self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
else:
raise ValueError(
f"Unsupported cache implementation in this export recipe: '{model.config.cache_implementation}'"
)
def forward(
self,
input_ids: torch.Tensor,
cache_position: torch.Tensor,
) -> torch.Tensor:
"""
Forward pass of the module, which is compatible with the ExecuTorch llm runner.
Args:
input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
cache_position (`torch.Tensor`): Tensor representing current input position in the cache.
Returns:
torch.Tensor: Logits output from the model.
"""
return self.model.forward(input_ids, cache_position)
def export(
self,
input_ids: Optional[torch.Tensor] = None,
cache_position: Optional[torch.Tensor] = None,
dynamic_shapes: Optional[dict] = None,
strict: Optional[bool] = None,
) -> torch.export.ExportedProgram:
"""
Export the wrapped module using `torch.export`.
Args:
input_ids (`Optional[torch.Tensor]`):
Tensor representing current input token id to the module. If not provided, a default tensor will be used.
cache_position (`Optional[torch.Tensor]`):
Tensor representing current input position in the cache. If not provided, a default tensor will be used.
dynamic_shapes (`Optional[dict]`):
Dynamic shapes to use for export if specified.
strict (`Optional[bool]`):
Flag to instruct `torch.export` to use `torchdynamo`.
"""
example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long)
example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)
return torch.export.export(
self.model,
args=(example_input_ids, example_cache_position),
kwargs={},
dynamic_shapes=dynamic_shapes,
strict=strict if strict is not None else True,
)
@staticmethod
def generate(
exported_program: torch.export.ExportedProgram,
tokenizer,
prompt: str,
max_new_tokens: int = 20,
do_sample: bool = False,
temperature: float = 1.0,
top_k: int = 50,
top_p: float = 1.0,
device: str = "cpu",
) -> str:
"""
Generate a sequence of tokens using an exported program.
Args:
exported_program (`torch.export.ExportedProgram`): The exported program used for generation.
tokenizer: The tokenizer to use.
prompt (str): The input prompt.
max_new_tokens (int): Maximum number of new tokens to generate.
do_sample (bool): Whether to use sampling or greedy decoding.
temperature (float): The temperature for sampling.
top_k (int): The number of highest probability tokens to keep for top-k sampling.
top_p (float): The cumulative probability for nucleus sampling.
device (str): The device to use.
Returns:
str: The generated text.
"""
# Get the module from the exported program
exported_module = exported_program.module()
# Tokenize the prompt
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
# Initialize with the prompt
generated_ids = input_ids.clone()
# Process the prompt tokens first
curr_position = 0
for i in range(input_ids.shape[1]):
# Process one token at a time
curr_input_ids = input_ids[:, i : i + 1]
curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device)
# Forward pass
_ = exported_module(curr_input_ids, curr_cache_position)
curr_position += 1
# Generate new tokens
for _ in range(max_new_tokens):
# Get the last token as input
curr_input_ids = generated_ids[:, -1:]
curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device)
# Forward pass to get next token logits
outputs = exported_module(curr_input_ids, curr_cache_position)
# Get the next token ID
if do_sample:
# Apply temperature
if temperature > 0:
logits = outputs / temperature
else:
logits = outputs
# Apply top-k filtering
if top_k > 0:
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = float("-inf")
# Apply top-p (nucleus) filtering
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# Scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = float("-inf")
# Sample from the filtered distribution
probs = torch.softmax(logits, dim=-1)
next_token_id = torch.multinomial(probs, num_samples=1)
else:
# Greedy decoding
next_token_id = outputs.argmax(dim=-1, keepdim=True)
# Ensure next_token_id has the right shape before concatenation
if next_token_id.dim() > 2:
next_token_id = next_token_id.squeeze(-1)
# Append to the generated sequence
generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)
curr_position += 1
# Stop if we generate an EOS token
if next_token_id.item() == tokenizer.eos_token_id:
break
# Decode the generated text
return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
class TorchExportableModuleWithStaticCache(torch.nn.Module):
"""
A wrapper module designed to make a `PreTrainedModel` exportable with `torch.export`,
specifically for use with static caching. This module ensures that the exported model
is compatible with further lowering and execution in `ExecuTorch`.
A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
specifically for decoder-only LM with `StaticCache`. This module ensures that the
exported model is compatible with further lowering and execution in `ExecuTorch`.
Note:
This class is specifically designed to support export process using `torch.export`
@@ -178,6 +370,94 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
return torch.tensor([response_tokens], dtype=torch.long)
class TorchExportableModuleWithHybridCache(torch.nn.Module):
"""
A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
specifically for decoder-only LM with `HybridCache`. This module ensures that the
exported model is compatible with further lowering and execution in `ExecuTorch`.
"""
def __init__(
self,
model: PreTrainedModel,
max_batch_size: int = 1,
max_cache_len: int = 4096,
):
"""
Initializes the exportable module with `HybridCache`.
Args:
model (`PreTrainedModel`): The pretrained model to wrap.
max_batch_size (int): Maximum batch size for the cache.
max_cache_len (int): Maximum sequence length for the cache.
Raises:
AssertionError: If the model doesn't have the expected configuration for HybridCache.
"""
super().__init__()
self.model = model
# Verify the model is configured for HybridCache
if not self.model.config.use_cache:
raise AssertionError("Model must have caching enabled")
if (
not hasattr(self.model.config, "cache_implementation")
or self.model.config.cache_implementation != "hybrid"
):
raise AssertionError("Model must use 'hybrid' cache implementation")
# Initialize the HybridCache
self.cache = HybridCache(
config=self.model.config,
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
device=self.model.device,
dtype=self.model.dtype,
)
# Register all key and value cache tensors as buffers
for i in range(len(self.cache.key_cache)):
self.register_buffer(f"key_cache_{i}", self.cache.key_cache[i], persistent=False)
self.register_buffer(f"value_cache_{i}", self.cache.value_cache[i], persistent=False)
def forward(
self,
input_ids: torch.Tensor,
cache_position: torch.Tensor,
) -> torch.Tensor:
"""
Forward pass of the module, which is compatible with the ExecuTorch llm runner.
Args:
input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
cache_position (`torch.Tensor`): Tensor representing current input position in the cache.
Returns:
torch.Tensor: Logits output from the model.
"""
batch_size, seq_len = input_ids.shape
# Generate position_ids from cache_position
position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)
# Create attention mask (always ones for token-by-token generation)
attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long, device=input_ids.device)
# Forward pass with the model
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=self.cache,
use_cache=True,
cache_position=cache_position,
)
# Return only the logits to simplify the export
return outputs.logits
def convert_and_export_with_cache(
model: PreTrainedModel,
example_input_ids: Optional[torch.Tensor] = None,

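For reference, a minimal end-to-end sketch of the new recipe, mirroring the integration tests added further down in this commit. It assumes torch >= 2.6 and access to the gated `google/gemma-3-1b-it` checkpoint; since Gemma3 ships with `cache_implementation == "hybrid"`, the wrapper dispatches to `TorchExportableModuleWithHybridCache`.

```python
# Minimal usage sketch of the new export recipe (mirrors the Gemma3 test below);
# assumes torch >= 2.6 and access to the gated "google/gemma-3-1b-it" checkpoint.
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

model_id = "google/gemma-3-1b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
model.eval()

# cache_implementation == "hybrid" for Gemma3, so this wraps the model
# in TorchExportableModuleWithHybridCache under the hood.
exportable_module = TorchExportableModuleForDecoderOnlyLM(model, max_batch_size=1, max_cache_len=4096)
exported_program = exportable_module.export()

# Drive token-by-token generation through the exported program.
generated = TorchExportableModuleForDecoderOnlyLM.generate(
    exported_program, tokenizer, "What is the capital of France?", max_new_tokens=20
)
print(generated)
```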

@@ -351,7 +351,7 @@ class Cohere2DecoderLayer(GradientCheckpointingLayer):
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset)
offset = torch.clamp(offset, min=0)
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(

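The one-line modeling change above, repeated below for Cohere2, Gemma2, and Gemma3 (in both the modeling and modular files), swaps Python's builtin `max()` for `torch.clamp()`. With a tensor offset, `max(0, offset)` forces a data-dependent comparison and `bool()` conversion that `torch.export`/`torch.compile` cannot trace symbolically, whereas `torch.clamp` keeps everything a tensor op. A small illustrative sketch (not the modeling code itself; `effective_seq_len` here is just a plain int argument):

```python
# Illustrative sketch of why the sliding-window offset switches from
# Python's builtin max() to torch.clamp().
import torch

def offset_with_builtin_max(cache_position: torch.Tensor, effective_seq_len: int):
    offset = cache_position[-1] - effective_seq_len + 1
    # max(0, offset) compares an int against a 0-d tensor, which requires a
    # data-dependent bool() on the tensor -- hostile to torch.export/compile.
    return max(0, offset)

def offset_with_clamp(cache_position: torch.Tensor, effective_seq_len: int):
    offset = cache_position[-1] - effective_seq_len + 1
    # torch.clamp keeps the whole computation as traceable tensor ops.
    return torch.clamp(offset, min=0)

cache_position = torch.arange(10)
# Both variants agree in eager mode, inside and beyond the sliding window.
assert offset_with_builtin_max(cache_position, 4) == offset_with_clamp(cache_position, 4)
assert offset_with_builtin_max(cache_position, 16) == offset_with_clamp(cache_position, 16)
```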

@@ -400,7 +400,7 @@ class Cohere2DecoderLayer(CohereDecoderLayer):
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset)
offset = torch.clamp(offset, min=0)
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(


@@ -317,7 +317,7 @@ class Gemma2DecoderLayer(nn.Module):
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset)
offset = torch.clamp(offset, min=0)
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(


@@ -364,7 +364,7 @@ class Gemma2DecoderLayer(nn.Module):
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset)
offset = torch.clamp(offset, min=0)
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(


@@ -410,7 +410,7 @@ class Gemma3DecoderLayer(nn.Module):
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset)
offset = torch.clamp(offset, min=0)
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(


@@ -494,7 +494,7 @@ class Gemma3DecoderLayer(nn.Module):
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset)
offset = torch.clamp(offset, min=0)
# equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indexes = torch.arange(


@@ -337,6 +337,44 @@ class Gemma2IntegrationTest(unittest.TestCase):
ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)
@slow
@require_read_token
def test_export_hybrid_cache(self):
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
from transformers.pytorch_utils import is_torch_greater_or_equal
if not is_torch_greater_or_equal("2.6.0"):
self.skipTest(reason="This test requires torch >= 2.6 to run.")
model_id = "google/gemma-2-2b"
model = AutoModelForCausalLM.from_pretrained(model_id)
self.assertEqual(model.config.cache_implementation, "hybrid")
# Export + HybridCache
model.eval()
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export()
# Test generation with the exported model
prompt = "What is the capital of France?"
max_new_tokens_to_generate = 20
# Generate text with the exported model
tokenizer = AutoTokenizer.from_pretrained(model_id)
export_generated_text = TorchExportableModuleForDecoderOnlyLM.generate(
exported_program, tokenizer, prompt, max_new_tokens=max_new_tokens_to_generate
)
input_text = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
eager_outputs = model.generate(
**input_text,
max_new_tokens=max_new_tokens_to_generate,
do_sample=False, # Use greedy decoding to match the exported model
)
eager_generated_text = tokenizer.decode(eager_outputs[0], skip_special_tokens=True)
self.assertEqual(export_generated_text, eager_generated_text)
@require_read_token
@tooslow
def test_model_9b_bf16_flex_attention(self):


@@ -13,6 +13,7 @@
# limitations under the License.
"""Testing suite for the PyTorch Gemma3 model."""
import logging
import tempfile
import unittest
@@ -52,6 +53,7 @@ if is_torch_available():
Gemma3Processor,
Gemma3TextModel,
)
from transformers.pytorch_utils import is_torch_greater_or_equal
class Gemma3ModelTester(GemmaModelTester):
@@ -664,3 +666,42 @@ class Gemma3IntegrationTest(unittest.TestCase):
model.generation_config.transformers_version = "4.49.0"
with self.assertRaises(RuntimeError): # errors out because it is not using hybrid cache
out = model.generate(**inputs, generation_config=generation_config)
def test_export_text_only_with_hybrid_cache(self):
if not is_torch_greater_or_equal("2.6.0"):
self.skipTest(reason="This test requires torch >= 2.6 to run.")
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
model_id = "google/gemma-3-1b-it"
model = AutoModelForCausalLM.from_pretrained(model_id)
self.assertEqual(model.config.cache_implementation, "hybrid")
# Export + HybridCache
model.eval()
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export()
logging.info(f"\nExported program: {exported_program}")
# Test generation with the exported model
prompt = "What is the capital of France?"
max_new_tokens_to_generate = 20
# Generate text with the exported model
tokenizer = AutoTokenizer.from_pretrained(model_id)
export_generated_text = TorchExportableModuleForDecoderOnlyLM.generate(
exported_program, tokenizer, prompt, max_new_tokens=max_new_tokens_to_generate
)
logging.info(f"\nExport generated texts: '{export_generated_text}'")
input_text = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
eager_outputs = model.generate(
**input_text,
max_new_tokens=max_new_tokens_to_generate,
do_sample=False, # Use greedy decoding to match the exported model
)
eager_generated_text = tokenizer.decode(eager_outputs[0], skip_special_tokens=True)
logging.info(f"\nEager generated texts: '{eager_generated_text}'")
self.assertEqual(export_generated_text, eager_generated_text)
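The new `export()` method also threads a `dynamic_shapes` argument through to `torch.export.export`. Below is a hedged sketch of how the sequence dimension could be marked dynamic with `torch.export.Dim`; the example shapes, bounds, and variable names are illustrative assumptions, and this path is not exercised by the tests in this commit.

```python
# Hedged sketch: exporting with a dynamic sequence length via torch.export.Dim.
# Assumes the wrapped model handles multi-token inputs; bounds/names are illustrative.
import torch
from transformers import AutoModelForCausalLM
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

model = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it")
model.eval()
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)

seq_len = torch.export.Dim("seq_len", max=128)
example_input_ids = torch.zeros((1, 3), dtype=torch.long)
example_cache_position = torch.arange(3, dtype=torch.long)

# The dict keys match the wrapped forward's argument names.
exported_program = exportable_module.export(
    input_ids=example_input_ids,
    cache_position=example_cache_position,
    dynamic_shapes={"input_ids": {1: seq_len}, "cache_position": {0: seq_len}},
    strict=True,
)
```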