Generate using exported model and enable gemma2-2b in ExecuTorch (#33707)

* Generate using exported model and enable gemma2-2b in ExecuTorch

* [run_slow] gemma, gemma2

* truncate expected output message

* Bump required torch version to support gemma2 export

* [run_slow] gemma, gemma2

---------

Co-authored-by: Guang Yang <guangyang@fb.com>
Guang Yang 2024-10-11 01:16:31 -07:00 committed by GitHub
parent 70b07d97cf
commit 7d97cca8dd
3 changed files with 168 additions and 0 deletions
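
A minimal sketch of the flow this commit enables, condensed from the new integration tests below. The checkpoint, prompt, and generation lengths are illustrative, and the Gemma checkpoints are gated, so authenticated Hub access is assumed:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.integrations.executorch import (
    TorchExportableModuleWithStaticCache,
    convert_and_export_with_cache,
)

# The model must be loaded with a static cache configuration so it can be exported
# with torch.export (torch >= 2.3 for gemma, >= 2.5 for gemma2, per the tests below).
max_length = 64
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b",
    device_map="cpu",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    generation_config=GenerationConfig(
        use_cache=True,
        cache_implementation="static",
        max_length=max_length,
        cache_config={"batch_size": 1, "max_cache_len": max_length},
    ),
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
prompt_token_ids = tokenizer(["Hello I am doing"], return_tensors="pt")["input_ids"]

# Export once, then generate with the exported program.
exported_program = convert_and_export_with_cache(model)
generated_ids = TorchExportableModuleWithStaticCache.generate(
    exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=32
)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])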


@@ -114,6 +114,56 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
        )
        return outs.logits

    @staticmethod
    def generate(
        exported_program: torch.export.ExportedProgram, prompt_token_ids: torch.Tensor, max_new_tokens: int
    ) -> torch.Tensor:
        """
        Generate a sequence of tokens using an exported program.

        This utility function is designed to test exported models by simulating the generation process.
        It processes the input prompt tokens sequentially (no parallel prefill). It is not intended to
        replace the original `generate` method; support for leveraging the original `generate` may be
        added in the future.

        Args:
            exported_program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
            prompt_token_ids (`torch.Tensor`): Tensor representing the input prompt token IDs.
            max_new_tokens (`int`): Maximum number of new tokens to generate. Note that the total generation
                length is limited by both `max_new_tokens` and the model's cache size.

        Returns:
            torch.Tensor: A tensor containing the generated sequence of token IDs, including the original prompt tokens.
        """
        prompt_token_len = prompt_token_ids.shape[-1]
        max_generation_length = prompt_token_len + max_new_tokens
        for buffer_name, buffer in exported_program.named_buffers():
            if buffer_name.startswith("static_cache.key_cache"):
                max_cache_len = buffer.shape[2]
                max_generation_length = min(max_generation_length, max_cache_len)
                break

        response_tokens = []
        for input_pos in range(min(max_generation_length, prompt_token_len)):
            result = exported_program.module().forward(
                input_ids=prompt_token_ids[:, input_pos : input_pos + 1],
                cache_position=torch.tensor([input_pos], dtype=torch.long),
            )
            response_tokens.append(prompt_token_ids[0][input_pos].item())

        current_token = torch.argmax(result[:, -1, :], dim=-1).item()
        response_tokens.append(current_token)

        while len(response_tokens) < max_generation_length:
            result = exported_program.module().forward(
                input_ids=torch.tensor([[current_token]], dtype=torch.long),
                cache_position=torch.tensor([len(response_tokens)], dtype=torch.long),
            )
            current_token = torch.argmax(result[:, -1, :], dim=-1).item()
            response_tokens.append(current_token)

        return torch.tensor([response_tokens], dtype=torch.long)

def convert_and_export_with_cache(
    model: PreTrainedModel,
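
The docstring above notes that the total generation length is capped by both `max_new_tokens` and the exported model's static cache size. A small sketch of how that cache bound can be read back from an exported program, mirroring the buffer lookup inside `generate` above (the helper name is illustrative):

import torch

def exported_max_cache_len(exported_program: torch.export.ExportedProgram):
    """Return the static cache capacity baked into the exported program, or None if not found."""
    for name, buffer in exported_program.named_buffers():
        if name.startswith("static_cache.key_cache"):
            # StaticCache key buffers are shaped (batch_size, num_key_value_heads, max_cache_len, head_dim)
            return buffer.shape[2]
    return None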


@@ -21,6 +21,7 @@ import pytest
from packaging import version
from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available
from transformers.generation.configuration_utils import GenerationConfig
from transformers.testing_utils import (
    is_flaky,
    require_bitsandbytes,
@@ -841,6 +842,67 @@ class GemmaIntegrationTest(unittest.TestCase):
        static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text)

    @slow
    @require_read_token
    def test_export_static_cache(self):
        if version.parse(torch.__version__) < version.parse("2.3.0"):
            self.skipTest(reason="This test requires torch >= 2.3 to run.")

        from transformers.integrations.executorch import (
            TorchExportableModuleWithStaticCache,
            convert_and_export_with_cache,
        )

        tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", pad_token="</s>", padding_side="right")
        EXPECTED_TEXT_COMPLETION = [
            "Hello I am doing a project on the 1990s and I need to know what the most popular music was in the 1990s. I have looked on the internet and I have found",
        ]
        max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
            "input_ids"
        ].shape[-1]

        # Load model
        device = "cpu"
        dtype = torch.bfloat16
        cache_implementation = "static"
        attn_implementation = "sdpa"
        batch_size = 1
        model = GemmaForCausalLM.from_pretrained(
            "google/gemma-2b",
            device_map=device,
            torch_dtype=dtype,
            attn_implementation=attn_implementation,
            generation_config=GenerationConfig(
                use_cache=True,
                cache_implementation=cache_implementation,
                max_length=max_generation_length,
                cache_config={
                    "batch_size": batch_size,
                    "max_cache_len": max_generation_length,
                },
            ),
        )

        prompts = ["Hello I am doing"]
        prompt_tokens = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
        prompt_token_ids = prompt_tokens["input_ids"]
        max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]

        # Static Cache + eager
        eager_generated_ids = model.generate(
            **prompt_tokens, max_new_tokens=max_new_tokens, do_sample=False, cache_implementation=cache_implementation
        )
        eager_generated_text = tokenizer.batch_decode(eager_generated_ids, skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, eager_generated_text)

        # Static Cache + export
        exported_program = convert_and_export_with_cache(model)
        ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
            exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
        )
        ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)

    def test_model_2b_bf16_dola(self):
        model_id = "google/gemma-2b"
        # ground truth text generated with dola_layers="low", repetition_penalty=1.2


@@ -16,10 +16,12 @@
import unittest
from packaging import version
from parameterized import parameterized
from pytest import mark
from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, HybridCache, is_torch_available, pipeline
from transformers.generation.configuration_utils import GenerationConfig
from transformers.testing_utils import (
    require_flash_attn,
    require_read_token,
@@ -306,3 +308,57 @@ class Gemma2IntegrationTest(unittest.TestCase):
        output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
        self.assertEqual(output_text, EXPECTED_TEXTS)

    @slow
    @require_read_token
    def test_export_static_cache(self):
        if version.parse(torch.__version__) < version.parse("2.5.0"):
            self.skipTest(reason="This test requires torch >= 2.5 to run.")

        from transformers.integrations.executorch import (
            TorchExportableModuleWithStaticCache,
            convert_and_export_with_cache,
        )

        tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b", pad_token="</s>", padding_side="right")
        EXPECTED_TEXT_COMPLETION = [
            "Hello I am doing a project for my school and I need to know how to make a program that will take a number",
        ]
        max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
            "input_ids"
        ].shape[-1]

        # Load model
        device = "cpu"
        dtype = torch.bfloat16
        cache_implementation = "static"
        attn_implementation = "sdpa"
        batch_size = 1
        model = AutoModelForCausalLM.from_pretrained(
            "google/gemma-2-2b",
            device_map=device,
            torch_dtype=dtype,
            attn_implementation=attn_implementation,
            generation_config=GenerationConfig(
                use_cache=True,
                cache_implementation=cache_implementation,
                max_length=max_generation_length,
                cache_config={
                    "batch_size": batch_size,
                    "max_cache_len": max_generation_length,
                },
            ),
        )

        prompts = ["Hello I am doing"]
        prompt_tokens = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
        prompt_token_ids = prompt_tokens["input_ids"]
        max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]

        # Static Cache + export
        exported_program = convert_and_export_with_cache(model)
        ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
            exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
        )
        ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)