Unbreak optimum-executorch (#38646)

* Unbreak optimum-executorch

* use static cache if has layer_types but no sliding_window

* revert view on kv_arange

---------

Co-authored-by: Guang Yang <guangyang@fb.com>
This commit is contained in:
Guang Yang 2025-06-13 02:13:32 -07:00 committed by GitHub
parent 5f59a9b439
commit 7f00b325f8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 64 additions and 39 deletions

View File

@ -56,13 +56,15 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
if not hasattr(model.config, "use_cache") or model.config.use_cache is False: if not hasattr(model.config, "use_cache") or model.config.use_cache is False:
raise ValueError("The model must have caching enabled to be performant.") raise ValueError("The model must have caching enabled to be performant.")
if not hasattr(model.config, "layer_types"): if hasattr(model.config, "layer_types") and getattr(model.config, "sliding_window", None) is not None:
# If `layer_types` is not specified explicitly in the config, there is only 1 type of layers, so
# export will use `StaticCache` by default.
logging.info("Using `StaticCache` for export as `layer_types` is not specified in the config.")
self.model = TorchExportableModuleWithStaticCache(model)
else:
self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len) self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
else:
# If `layer_types` is not specified explicitly in the config or `sliding_window` is null,
# there is only 1 type of layers, so export will use `StaticCache` by default.
logging.info(
"Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
)
self.model = TorchExportableModuleWithStaticCache(model)
def forward( def forward(
self, self,
@ -400,12 +402,6 @@ class TorchExportableModuleWithHybridCache(torch.nn.Module):
if not self.model.config.use_cache: if not self.model.config.use_cache:
raise AssertionError("Model must have caching enabled") raise AssertionError("Model must have caching enabled")
if (
not hasattr(self.model.config, "cache_implementation")
or self.model.config.cache_implementation != "hybrid"
):
raise AssertionError("Model must use 'hybrid' cache implementation")
# Initialize the HybridCache # Initialize the HybridCache
self.cache = HybridCache( self.cache = HybridCache(
config=self.model.config, config=self.model.config,

View File

@ -378,7 +378,6 @@ class GemmaIntegrationTest(unittest.TestCase):
from transformers.integrations.executorch import ( from transformers.integrations.executorch import (
TorchExportableModuleWithStaticCache, TorchExportableModuleWithStaticCache,
convert_and_export_with_cache,
) )
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", pad_token="</s>", padding_side="right") tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", pad_token="</s>", padding_side="right")
@ -424,7 +423,10 @@ class GemmaIntegrationTest(unittest.TestCase):
self.assertEqual(EXPECTED_TEXT_COMPLETION, eager_generated_text) self.assertEqual(EXPECTED_TEXT_COMPLETION, eager_generated_text)
# Static Cache + export # Static Cache + export
exported_program = convert_and_export_with_cache(model) from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export()
ep_generated_ids = TorchExportableModuleWithStaticCache.generate( ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
) )

View File

@ -313,7 +313,6 @@ class Gemma2IntegrationTest(unittest.TestCase):
from transformers.integrations.executorch import ( from transformers.integrations.executorch import (
TorchExportableModuleWithStaticCache, TorchExportableModuleWithStaticCache,
convert_and_export_with_cache,
) )
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b", pad_token="</s>", padding_side="right") tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b", pad_token="</s>", padding_side="right")
@ -363,7 +362,10 @@ class Gemma2IntegrationTest(unittest.TestCase):
max_new_tokens = max_generation_length - prompt_token_ids.shape[-1] max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
# Static Cache + export # Static Cache + export
exported_program = convert_and_export_with_cache(model) from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export()
ep_generated_ids = TorchExportableModuleWithStaticCache.generate( ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
) )

View File

@ -306,7 +306,6 @@ class LlamaIntegrationTest(unittest.TestCase):
from transformers.integrations.executorch import ( from transformers.integrations.executorch import (
TorchExportableModuleWithStaticCache, TorchExportableModuleWithStaticCache,
convert_and_export_with_cache,
) )
llama_models = { llama_models = {
@ -352,7 +351,10 @@ class LlamaIntegrationTest(unittest.TestCase):
max_new_tokens = max_generation_length - prompt_token_ids.shape[-1] max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
# Static Cache + export # Static Cache + export
exported_program = convert_and_export_with_cache(model) from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export()
ep_generated_ids = TorchExportableModuleWithStaticCache.generate( ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
) )

View File

@ -334,7 +334,6 @@ class OlmoIntegrationTest(unittest.TestCase):
from transformers.integrations.executorch import ( from transformers.integrations.executorch import (
TorchExportableModuleWithStaticCache, TorchExportableModuleWithStaticCache,
convert_and_export_with_cache,
) )
olmo_model = "allenai/OLMo-1B-hf" olmo_model = "allenai/OLMo-1B-hf"
@ -382,7 +381,10 @@ class OlmoIntegrationTest(unittest.TestCase):
self.assertEqual(EXPECTED_TEXT_COMPLETION, eager_generated_text) self.assertEqual(EXPECTED_TEXT_COMPLETION, eager_generated_text)
# Static Cache + export # Static Cache + export
exported_program = convert_and_export_with_cache(model) from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export()
ep_generated_ids = TorchExportableModuleWithStaticCache.generate( ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
) )

View File

@ -347,7 +347,6 @@ class Phi3IntegrationTest(unittest.TestCase):
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.integrations.executorch import ( from transformers.integrations.executorch import (
TorchExportableModuleWithStaticCache, TorchExportableModuleWithStaticCache,
convert_and_export_with_cache,
) )
model_id = "microsoft/Phi-4-mini-instruct" model_id = "microsoft/Phi-4-mini-instruct"
@ -399,7 +398,10 @@ class Phi3IntegrationTest(unittest.TestCase):
max_new_tokens = max_generation_length - prompt_token_ids.shape[-1] max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
# Static Cache + export # Static Cache + export
exported_program = convert_and_export_with_cache(model) from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export()
ep_generated_ids = TorchExportableModuleWithStaticCache.generate( ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
) )

View File

@ -31,7 +31,6 @@ from transformers.testing_utils import (
slow, slow,
torch_device, torch_device,
) )
from transformers.utils.import_utils import is_torch_greater_or_equal
if is_torch_available(): if is_torch_available():
@ -246,7 +245,6 @@ class Qwen2IntegrationTest(unittest.TestCase):
from transformers.integrations.executorch import ( from transformers.integrations.executorch import (
TorchExportableModuleWithStaticCache, TorchExportableModuleWithStaticCache,
convert_and_export_with_cache,
) )
qwen_model = "Qwen/Qwen2-0.5B" qwen_model = "Qwen/Qwen2-0.5B"
@ -287,8 +285,13 @@ class Qwen2IntegrationTest(unittest.TestCase):
max_new_tokens = max_generation_length - prompt_token_ids.shape[-1] max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
# Static Cache + export # Static Cache + export
strict = is_torch_greater_or_equal("2.7.0") # Due to https://github.com/pytorch/pytorch/issues/150994 from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
exported_program = convert_and_export_with_cache(model, strict=strict)
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
strict = version.parse(torch.__version__) != version.parse(
"2.7.0"
) # Due to https://github.com/pytorch/pytorch/issues/150994
exported_program = exportable_module.export(strict=strict)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate( ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
) )

View File

@ -31,7 +31,6 @@ from transformers.testing_utils import (
slow, slow,
torch_device, torch_device,
) )
from transformers.utils.import_utils import is_torch_greater_or_equal
if is_torch_available(): if is_torch_available():
@ -240,13 +239,12 @@ class Qwen3IntegrationTest(unittest.TestCase):
from transformers.integrations.executorch import ( from transformers.integrations.executorch import (
TorchExportableModuleWithStaticCache, TorchExportableModuleWithStaticCache,
convert_and_export_with_cache,
) )
qwen_model = "Qwen/Qwen3-0.6B-Base" qwen_model = "Qwen/Qwen3-0.6B-Base"
tokenizer = AutoTokenizer.from_pretrained(qwen_model, pad_token="</s>", padding_side="right") tokenizer = AutoTokenizer.from_pretrained(qwen_model, pad_token="</s>", padding_side="right")
if is_torch_greater_or_equal("2.7.0"): if version.parse(torch.__version__) == version.parse("2.7.0"):
strict = False # Due to https://github.com/pytorch/pytorch/issues/150994 strict = False # Due to https://github.com/pytorch/pytorch/issues/150994
EXPECTED_TEXT_COMPLETION = ["My favourite condiment is 100% plain, unflavoured, and unadulterated."] EXPECTED_TEXT_COMPLETION = ["My favourite condiment is 100% plain, unflavoured, and unadulterated."]
else: else:
@ -285,7 +283,10 @@ class Qwen3IntegrationTest(unittest.TestCase):
max_new_tokens = max_generation_length - prompt_token_ids.shape[-1] max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
# Static Cache + export # Static Cache + export
exported_program = convert_and_export_with_cache(model, strict=strict) from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export(strict=strict)
ep_generated_ids = TorchExportableModuleWithStaticCache.generate( ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
) )

View File

@ -15,6 +15,7 @@
import copy import copy
import unittest import unittest
from packaging import version
from parameterized import parameterized from parameterized import parameterized
from transformers import set_seed from transformers import set_seed
@ -680,15 +681,27 @@ class CacheExportIntegrationTest(unittest.TestCase):
self.assertEqual(n_static_key_caches, model.config.num_hidden_layers) self.assertEqual(n_static_key_caches, model.config.num_hidden_layers)
self.assertEqual(n_static_value_caches, model.config.num_hidden_layers) self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)
# Export with dynamic shapes using Dim.AUTO # Export with dynamic shapes
tokenizer = AutoTokenizer.from_pretrained(model_id) input_ids = torch.zeros((1, 3), dtype=torch.long)
input_ids = tokenizer("Here's everything I know", return_tensors="pt").input_ids cache_position = torch.tensor([0, 1, 2], dtype=torch.long)
dynamic_shapes = {"input_ids": {1: torch.export.Dim.AUTO}, "cache_position": None} dynamic_shapes = {"input_ids": {1: torch.export.Dim.DYNAMIC}, "cache_position": {0: torch.export.Dim.DYNAMIC}}
strict = version.parse(torch.__version__) != version.parse("2.7.0")
exported_program = convert_and_export_with_cache( exported_program = convert_and_export_with_cache(
model, model,
example_input_ids=input_ids, example_input_ids=input_ids,
example_cache_position=cache_position,
dynamic_shapes=dynamic_shapes, dynamic_shapes=dynamic_shapes,
strict=False, strict=strict,
)
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
exported_program = exportable_module.export(
input_ids=input_ids,
cache_position=cache_position,
dynamic_shapes=dynamic_shapes,
strict=strict,
) )
def test_hybrid_cache_exportability(self): def test_hybrid_cache_exportability(self):
@ -727,13 +740,15 @@ class CacheExportIntegrationTest(unittest.TestCase):
self.assertEqual(n_g_value_caches, model.config.num_hidden_layers) self.assertEqual(n_g_value_caches, model.config.num_hidden_layers)
# Export with dynamic shapes using Dim.AUTO # Export with dynamic shapes using Dim.AUTO
tokenizer = AutoTokenizer.from_pretrained(model_id) input_ids = torch.zeros((1, 3), dtype=torch.long)
input_ids = tokenizer("Here's everything I know", return_tensors="pt").input_ids cache_position = torch.tensor([0, 1, 2], dtype=torch.long)
dynamic_shapes = {"input_ids": {1: torch.export.Dim.AUTO}, "cache_position": None} dynamic_shapes = {"input_ids": {1: torch.export.Dim.DYNAMIC}, "cache_position": {0: torch.export.Dim.DYNAMIC}}
strict = version.parse(torch.__version__) != version.parse("2.7.0")
exported_program = exportable_module.export( exported_program = exportable_module.export(
input_ids=input_ids, input_ids=input_ids,
cache_position=cache_position,
dynamic_shapes=dynamic_shapes, dynamic_shapes=dynamic_shapes,
strict=False, strict=strict,
) )