Mirror of https://github.com/huggingface/transformers.git
Gemma3 is Torch Exportable (#37728)
* Gemma3 is Torch Exportable
* Expand the support to other models using HybridCache
---------
Co-authored-by: Guang Yang <guangyang@fb.com>
This commit is contained in:
parent 397a5ede33
commit 816b37010c
@@ -20,15 +20,207 @@ from ..utils.import_utils import is_torch_available

if is_torch_available():
-    from transformers import PreTrainedModel, StaticCache
+    from transformers import HybridCache, PreTrainedModel, StaticCache
    from transformers.pytorch_utils import is_torch_greater_or_equal, is_torch_greater_or_equal_than_2_3


class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
    """
    A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
    specifically for decoder-only LM with cache. This module ensures that the
    exported model is compatible with further lowering and execution in `ExecuTorch`.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        max_batch_size: int = 1,
        max_cache_len: int = 4096,
    ):
        """
        Initializes the exportable module with `HybridCache`.

        Args:
            model (`PreTrainedModel`): The pretrained model to wrap.
            max_batch_size (int): Maximum batch size for the cache.
            max_cache_len (int): Maximum sequence length for the cache.

        Raises:
            ValueError: If the model is configured with an unsupported cache implementation.
        """
        super().__init__()

        if model.config.cache_implementation == "static":
            self.model = TorchExportableModuleWithStaticCache(model)
        elif model.config.cache_implementation == "hybrid":
            self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
        else:
            raise ValueError(
                f"Unsupported cache implementation in this export recipe: '{model.config.cache_implementation}'"
            )

    def forward(
        self,
        input_ids: torch.Tensor,
        cache_position: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass of the module, which is compatible with the ExecuTorch llm runner.

        Args:
            input_ids (`torch.Tensor`): Tensor representing the current input token id to the module.
            cache_position (`torch.Tensor`): Tensor representing the current input position in the cache.

        Returns:
            torch.Tensor: Logits output from the model.
        """
        return self.model.forward(input_ids, cache_position)

    def export(
        self,
        input_ids: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        dynamic_shapes: Optional[dict] = None,
        strict: Optional[bool] = None,
    ) -> torch.export.ExportedProgram:
        """
        Export the wrapped module using `torch.export`.

        Args:
            input_ids (`Optional[torch.Tensor]`):
                Tensor representing the current input token id to the module. If not provided, a default tensor will be used.
            cache_position (`Optional[torch.Tensor]`):
                Tensor representing the current input position in the cache. If not provided, a default tensor will be used.
            dynamic_shapes (`Optional[dict]`):
                Dynamic shapes to use for export if specified.
            strict (`Optional[bool]`):
                Flag to instruct `torch.export` to use `torchdynamo`.
        """
        example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long)
        example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)

        return torch.export.export(
            self.model,
            args=(example_input_ids, example_cache_position),
            kwargs={},
            dynamic_shapes=dynamic_shapes,
            strict=strict if strict is not None else True,
        )

    @staticmethod
    def generate(
        exported_program: torch.export.ExportedProgram,
        tokenizer,
        prompt: str,
        max_new_tokens: int = 20,
        do_sample: bool = False,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 1.0,
        device: str = "cpu",
    ) -> str:
        """
        Generate a sequence of tokens using an exported program.

        Args:
            exported_program (`torch.export.ExportedProgram`): The exported model used for generation.
            tokenizer: The tokenizer to use.
            prompt (str): The input prompt.
            max_new_tokens (int): Maximum number of new tokens to generate.
            do_sample (bool): Whether to use sampling or greedy decoding.
            temperature (float): The temperature for sampling.
            top_k (int): The number of highest-probability tokens to keep for top-k sampling.
            top_p (float): The cumulative probability for nucleus sampling.
            device (str): The device to use.

        Returns:
            str: The generated text.
        """
        # Get the module from the exported program
        exported_module = exported_program.module()

        # Tokenize the prompt
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

        # Initialize with the prompt
        generated_ids = input_ids.clone()

        # Process the prompt tokens first
        curr_position = 0
        for i in range(input_ids.shape[1]):
            # Process one token at a time
            curr_input_ids = input_ids[:, i : i + 1]
            curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device)

            # Forward pass
            _ = exported_module(curr_input_ids, curr_cache_position)
            curr_position += 1

        # Generate new tokens
        for _ in range(max_new_tokens):
            # Get the last token as input
            curr_input_ids = generated_ids[:, -1:]
            curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device)

            # Forward pass to get next token logits
            outputs = exported_module(curr_input_ids, curr_cache_position)

            # Get the next token ID
            if do_sample:
                # Apply temperature
                if temperature > 0:
                    logits = outputs / temperature
                else:
                    logits = outputs

                # Apply top-k filtering
                if top_k > 0:
                    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
                    logits[indices_to_remove] = float("-inf")

                # Apply top-p (nucleus) filtering
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

                    # Remove tokens with cumulative probability above the threshold
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift the indices to the right to keep also the first token above the threshold
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0

                    # Scatter sorted tensors to original indexing
                    indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
                    logits[indices_to_remove] = float("-inf")

                # Sample from the filtered distribution
                probs = torch.softmax(logits, dim=-1)
                next_token_id = torch.multinomial(probs, num_samples=1)
            else:
                # Greedy decoding
                next_token_id = outputs.argmax(dim=-1, keepdim=True)

            # Ensure next_token_id has the right shape before concatenation
            if next_token_id.dim() > 2:
                next_token_id = next_token_id.squeeze(-1)

            # Append to the generated sequence
            generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)
            curr_position += 1

            # Stop if we generate an EOS token
            if next_token_id.item() == tokenizer.eos_token_id:
                break

        # Decode the generated text
        return tokenizer.decode(generated_ids[0], skip_special_tokens=True)

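A minimal usage sketch of the new recipe, mirroring the integration tests added later in this commit. Illustrative only, not part of the diff; the Gemma checkpoints are gated, so an authenticated Hugging Face login is assumed.

# Example (not part of this commit): wrap, export, and generate with the new recipe.
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

model_id = "google/gemma-3-1b-it"  # a hybrid-cache model, as used in the tests below
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
model.eval()

# The recipe picks StaticCache or HybridCache based on model.config.cache_implementation
exportable_module = TorchExportableModuleForDecoderOnlyLM(model, max_batch_size=1, max_cache_len=4096)
exported_program = exportable_module.export()

# Drive the exported program token by token with the bundled helper
text = TorchExportableModuleForDecoderOnlyLM.generate(
    exported_program, tokenizer, "What is the capital of France?", max_new_tokens=20
)
print(text)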
class TorchExportableModuleWithStaticCache(torch.nn.Module):
    """
-    A wrapper module designed to make a `PreTrainedModel` exportable with `torch.export`,
-    specifically for use with static caching. This module ensures that the exported model
-    is compatible with further lowering and execution in `ExecuTorch`.
+    A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
+    specifically for decoder-only LM with `StaticCache`. This module ensures that the
+    exported model is compatible with further lowering and execution in `ExecuTorch`.

    Note:
        This class is specifically designed to support export process using `torch.export`
@@ -178,6 +370,94 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
        return torch.tensor([response_tokens], dtype=torch.long)


class TorchExportableModuleWithHybridCache(torch.nn.Module):
    """
    A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
    specifically for decoder-only LM with `HybridCache`. This module ensures that the
    exported model is compatible with further lowering and execution in `ExecuTorch`.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        max_batch_size: int = 1,
        max_cache_len: int = 4096,
    ):
        """
        Initializes the exportable module with `HybridCache`.

        Args:
            model (`PreTrainedModel`): The pretrained model to wrap.
            max_batch_size (int): Maximum batch size for the cache.
            max_cache_len (int): Maximum sequence length for the cache.

        Raises:
            AssertionError: If the model doesn't have the expected configuration for HybridCache.
        """
        super().__init__()
        self.model = model

        # Verify the model is configured for HybridCache
        if not self.model.config.use_cache:
            raise AssertionError("Model must have caching enabled")

        if (
            not hasattr(self.model.config, "cache_implementation")
            or self.model.config.cache_implementation != "hybrid"
        ):
            raise AssertionError("Model must use 'hybrid' cache implementation")

        # Initialize the HybridCache
        self.cache = HybridCache(
            config=self.model.config,
            max_batch_size=max_batch_size,
            max_cache_len=max_cache_len,
            device=self.model.device,
            dtype=self.model.dtype,
        )

        # Register all key and value cache tensors as buffers
        for i in range(len(self.cache.key_cache)):
            self.register_buffer(f"key_cache_{i}", self.cache.key_cache[i], persistent=False)
            self.register_buffer(f"value_cache_{i}", self.cache.value_cache[i], persistent=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        cache_position: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass of the module, which is compatible with the ExecuTorch llm runner.

        Args:
            input_ids (`torch.Tensor`): Tensor representing the current input token id to the module.
            cache_position (`torch.Tensor`): Tensor representing the current input position in the cache.

        Returns:
            torch.Tensor: Logits output from the model.
        """
        batch_size, seq_len = input_ids.shape

        # Generate position_ids from cache_position
        position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)

        # Create attention mask (always ones for token-by-token generation)
        attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long, device=input_ids.device)

        # Forward pass with the model
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=self.cache,
            use_cache=True,
            cache_position=cache_position,
        )

        # Return only the logits to simplify the export
        return outputs.logits

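Registering every `HybridCache` key/value tensor as a non-persistent buffer ties the cache state to the wrapper module itself, so that `torch.export` can capture the cache alongside the graph instead of treating it as an external input. A rough sketch of driving the wrapper directly, one token per step (illustrative, not part of the diff; the model id is the one used in the tests):

# Example (not part of this commit): single-step forward through the hybrid-cache wrapper.
import torch
from transformers import AutoModelForCausalLM
from transformers.integrations.executorch import TorchExportableModuleWithHybridCache

model = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it")  # hybrid cache_implementation assumed
model.eval()
wrapper = TorchExportableModuleWithHybridCache(model, max_batch_size=1, max_cache_len=4096)

input_ids = torch.tensor([[2]], dtype=torch.long)      # an arbitrary single token id, for illustration
cache_position = torch.tensor([0], dtype=torch.long)   # its position in the cache
logits = wrapper(input_ids, cache_position)            # shape [1, 1, vocab_size]
next_token = logits[:, -1, :].argmax(dim=-1)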
def convert_and_export_with_cache(
    model: PreTrainedModel,
    example_input_ids: Optional[torch.Tensor] = None,
@@ -351,7 +351,7 @@ class Cohere2DecoderLayer(GradientCheckpointingLayer):
                # In case we are beyond the sliding window, we need to correctly offset the mask slicing
                offset = cache_position[-1] - effective_seq_len + 1
                # Should only be used when beyond the sliding window (i.e. offset > 0)
-                offset = max(0, offset)
+                offset = torch.clamp(offset, min=0)
                # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
                # but without data-dependent slicing (i.e. torch.compile friendly)
                mask_indexes = torch.arange(

@@ -400,7 +400,7 @@ class Cohere2DecoderLayer(CohereDecoderLayer):
                # In case we are beyond the sliding window, we need to correctly offset the mask slicing
                offset = cache_position[-1] - effective_seq_len + 1
                # Should only be used when beyond the sliding window (i.e. offset > 0)
-                offset = max(0, offset)
+                offset = torch.clamp(offset, min=0)
                # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
                # but without data-dependent slicing (i.e. torch.compile friendly)
                mask_indexes = torch.arange(

@@ -317,7 +317,7 @@ class Gemma2DecoderLayer(nn.Module):
                # In case we are beyond the sliding window, we need to correctly offset the mask slicing
                offset = cache_position[-1] - effective_seq_len + 1
                # Should only be used when beyond the sliding window (i.e. offset > 0)
-                offset = max(0, offset)
+                offset = torch.clamp(offset, min=0)
                # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
                # but without data-dependent slicing (i.e. torch.compile friendly)
                mask_indexes = torch.arange(

@@ -364,7 +364,7 @@ class Gemma2DecoderLayer(nn.Module):
                # In case we are beyond the sliding window, we need to correctly offset the mask slicing
                offset = cache_position[-1] - effective_seq_len + 1
                # Should only be used when beyond the sliding window (i.e. offset > 0)
-                offset = max(0, offset)
+                offset = torch.clamp(offset, min=0)
                # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
                # but without data-dependent slicing (i.e. torch.compile friendly)
                mask_indexes = torch.arange(

@@ -410,7 +410,7 @@ class Gemma3DecoderLayer(nn.Module):
                # In case we are beyond the sliding window, we need to correctly offset the mask slicing
                offset = cache_position[-1] - effective_seq_len + 1
                # Should only be used when beyond the sliding window (i.e. offset > 0)
-                offset = max(0, offset)
+                offset = torch.clamp(offset, min=0)
                # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
                # but without data-dependent slicing (i.e. torch.compile friendly)
                mask_indexes = torch.arange(

@@ -494,7 +494,7 @@ class Gemma3DecoderLayer(nn.Module):
                # In case we are beyond the sliding window, we need to correctly offset the mask slicing
                offset = cache_position[-1] - effective_seq_len + 1
                # Should only be used when beyond the sliding window (i.e. offset > 0)
-                offset = max(0, offset)
+                offset = torch.clamp(offset, min=0)
                # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
                # but without data-dependent slicing (i.e. torch.compile friendly)
                mask_indexes = torch.arange(
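The same one-line change, replacing Python's built-in `max` with `torch.clamp`, is applied to every sliding-window decoder layer above (Cohere2, Gemma2, Gemma3). A minimal sketch of why it matters for export (illustrative code, not part of the diff): `max(0, offset)` has to evaluate a tensor comparison eagerly, which introduces data-dependent Python control flow that `torch.export` / `torch.compile` cannot keep in the traced graph, while `torch.clamp(offset, min=0)` expresses the same floor-at-zero as a tensor op.

# Example (not part of this commit): eager equivalence of the two formulations.
import torch

def offset_python_max(cache_position, effective_seq_len):
    offset = cache_position[-1] - effective_seq_len + 1
    # Built-in max() evaluates `offset > 0` as a Python bool, i.e. a data-dependent
    # branch that breaks ahead-of-time tracing.
    return max(0, offset)

def offset_clamp(cache_position, effective_seq_len):
    offset = cache_position[-1] - effective_seq_len + 1
    # torch.clamp keeps the floor inside the graph, so the layer stays export/compile friendly.
    return torch.clamp(offset, min=0)

pos = torch.tensor([3], dtype=torch.long)
assert int(offset_python_max(pos, 10)) == int(offset_clamp(pos, 10)) == 0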
@@ -337,6 +337,44 @@ class Gemma2IntegrationTest(unittest.TestCase):
        ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)

    @slow
    @require_read_token
    def test_export_hybrid_cache(self):
        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
        from transformers.pytorch_utils import is_torch_greater_or_equal

        if not is_torch_greater_or_equal("2.6.0"):
            self.skipTest(reason="This test requires torch >= 2.6 to run.")

        model_id = "google/gemma-2-2b"
        model = AutoModelForCausalLM.from_pretrained(model_id)
        self.assertEqual(model.config.cache_implementation, "hybrid")

        # Export + HybridCache
        model.eval()
        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
        exported_program = exportable_module.export()

        # Test generation with the exported model
        prompt = "What is the capital of France?"
        max_new_tokens_to_generate = 20
        # Generate text with the exported model
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        export_generated_text = TorchExportableModuleForDecoderOnlyLM.generate(
            exported_program, tokenizer, prompt, max_new_tokens=max_new_tokens_to_generate
        )

        input_text = tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            eager_outputs = model.generate(
                **input_text,
                max_new_tokens=max_new_tokens_to_generate,
                do_sample=False,  # Use greedy decoding to match the exported model
            )

        eager_generated_text = tokenizer.decode(eager_outputs[0], skip_special_tokens=True)
        self.assertEqual(export_generated_text, eager_generated_text)

    @require_read_token
    @tooslow
    def test_model_9b_bf16_flex_attention(self):
@@ -13,6 +13,7 @@
# limitations under the License.
"""Testing suite for the PyTorch Gemma3 model."""

import logging
import tempfile
import unittest

@@ -52,6 +53,7 @@ if is_torch_available():
        Gemma3Processor,
        Gemma3TextModel,
    )
    from transformers.pytorch_utils import is_torch_greater_or_equal


class Gemma3ModelTester(GemmaModelTester):
@@ -664,3 +666,42 @@ class Gemma3IntegrationTest(unittest.TestCase):
        model.generation_config.transformers_version = "4.49.0"
        with self.assertRaises(RuntimeError):  # errors out because it is not using hybrid cache
            out = model.generate(**inputs, generation_config=generation_config)

    def test_export_text_only_with_hybrid_cache(self):
        if not is_torch_greater_or_equal("2.6.0"):
            self.skipTest(reason="This test requires torch >= 2.6 to run.")

        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

        model_id = "google/gemma-3-1b-it"
        model = AutoModelForCausalLM.from_pretrained(model_id)
        self.assertEqual(model.config.cache_implementation, "hybrid")

        # Export + HybridCache
        model.eval()
        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
        exported_program = exportable_module.export()
        logging.info(f"\nExported program: {exported_program}")

        # Test generation with the exported model
        prompt = "What is the capital of France?"
        max_new_tokens_to_generate = 20
        # Generate text with the exported model
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        export_generated_text = TorchExportableModuleForDecoderOnlyLM.generate(
            exported_program, tokenizer, prompt, max_new_tokens=max_new_tokens_to_generate
        )
        logging.info(f"\nExport generated texts: '{export_generated_text}'")

        input_text = tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            eager_outputs = model.generate(
                **input_text,
                max_new_tokens=max_new_tokens_to_generate,
                do_sample=False,  # Use greedy decoding to match the exported model
            )

        eager_generated_text = tokenizer.decode(eager_outputs[0], skip_special_tokens=True)
        logging.info(f"\nEager generated texts: '{eager_generated_text}'")

        self.assertEqual(export_generated_text, eager_generated_text)
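The docstrings above stress that the exported program is intended for "further lowering and execution in ExecuTorch". As a rough, hedged sketch of that next step, assuming the separate `executorch` package and its `to_edge` API (both outside this commit, and version-dependent), the output file name is illustrative:

# Illustrative only: lowering the ExportedProgram produced above to an ExecuTorch .pte file.
# Assumes the `executorch` package is installed; exact API names may differ between versions.
from executorch.exir import to_edge

edge_program = to_edge(exported_program)           # ExportedProgram -> edge dialect
executorch_program = edge_program.to_executorch()  # edge dialect -> ExecuTorch program
with open("gemma3_1b_it.pte", "wb") as f:          # hypothetical output path
    f.write(executorch_program.buffer)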