diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py
index bd4b30a3d12..a16114fe953 100644
--- a/src/transformers/integrations/executorch.py
+++ b/src/transformers/integrations/executorch.py
@@ -56,13 +56,15 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
         if not hasattr(model.config, "use_cache") or model.config.use_cache is False:
             raise ValueError("The model must have caching enabled to be performant.")
 
-        if not hasattr(model.config, "layer_types"):
-            # If `layer_types` is not specified explicitly in the config, there is only 1 type of layers, so
-            # export will use `StaticCache` by default.
-            logging.info("Using `StaticCache` for export as `layer_types` is not specified in the config.")
-            self.model = TorchExportableModuleWithStaticCache(model)
-        else:
+        if hasattr(model.config, "layer_types") and getattr(model.config, "sliding_window", None) is not None:
             self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
+        else:
+            # If `layer_types` is not specified explicitly in the config or `sliding_window` is null,
+            # there is only 1 type of layers, so export will use `StaticCache` by default.
+            logging.info(
+                "Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
+            )
+            self.model = TorchExportableModuleWithStaticCache(model)
 
     def forward(
         self,
@@ -400,12 +402,6 @@ class TorchExportableModuleWithHybridCache(torch.nn.Module):
         if not self.model.config.use_cache:
             raise AssertionError("Model must have caching enabled")
 
-        if (
-            not hasattr(self.model.config, "cache_implementation")
-            or self.model.config.cache_implementation != "hybrid"
-        ):
-            raise AssertionError("Model must use 'hybrid' cache implementation")
-
         # Initialize the HybridCache
         self.cache = HybridCache(
             config=self.model.config,
diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py
index f468d205ab7..3f28fcd7044 100644
--- a/tests/models/gemma/test_modeling_gemma.py
+++ b/tests/models/gemma/test_modeling_gemma.py
@@ -378,7 +378,6 @@ class GemmaIntegrationTest(unittest.TestCase):
 
         from transformers.integrations.executorch import (
             TorchExportableModuleWithStaticCache,
-            convert_and_export_with_cache,
         )
 
         tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", pad_token="</s>", padding_side="right")
@@ -424,7 +423,10 @@ class GemmaIntegrationTest(unittest.TestCase):
         self.assertEqual(EXPECTED_TEXT_COMPLETION, eager_generated_text)
 
         # Static Cache + export
-        exported_program = convert_and_export_with_cache(model)
+        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
+
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
+        exported_program = exportable_module.export()
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py
index 808646186c2..0f06ed3cea5 100644
--- a/tests/models/gemma2/test_modeling_gemma2.py
+++ b/tests/models/gemma2/test_modeling_gemma2.py
@@ -313,7 +313,6 @@ class Gemma2IntegrationTest(unittest.TestCase):
 
         from transformers.integrations.executorch import (
             TorchExportableModuleWithStaticCache,
-            convert_and_export_with_cache,
         )
 
         tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b", pad_token="</s>", padding_side="right")
@@ -363,7 +362,10 @@ class Gemma2IntegrationTest(unittest.TestCase):
         max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
 
         # Static Cache + export
-        exported_program = convert_and_export_with_cache(model)
+        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
+
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
+        exported_program = exportable_module.export()
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py
index a1e6c944470..743dd0e86f8 100644
--- a/tests/models/llama/test_modeling_llama.py
+++ b/tests/models/llama/test_modeling_llama.py
@@ -306,7 +306,6 @@ class LlamaIntegrationTest(unittest.TestCase):
 
         from transformers.integrations.executorch import (
             TorchExportableModuleWithStaticCache,
-            convert_and_export_with_cache,
         )
 
         llama_models = {
@@ -352,7 +351,10 @@ class LlamaIntegrationTest(unittest.TestCase):
         max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
 
         # Static Cache + export
-        exported_program = convert_and_export_with_cache(model)
+        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
+
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
+        exported_program = exportable_module.export()
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
diff --git a/tests/models/olmo/test_modeling_olmo.py b/tests/models/olmo/test_modeling_olmo.py
index ad6363ac679..4e94d231017 100644
--- a/tests/models/olmo/test_modeling_olmo.py
+++ b/tests/models/olmo/test_modeling_olmo.py
@@ -334,7 +334,6 @@ class OlmoIntegrationTest(unittest.TestCase):
 
         from transformers.integrations.executorch import (
             TorchExportableModuleWithStaticCache,
-            convert_and_export_with_cache,
         )
 
         olmo_model = "allenai/OLMo-1B-hf"
@@ -382,7 +381,10 @@ class OlmoIntegrationTest(unittest.TestCase):
         self.assertEqual(EXPECTED_TEXT_COMPLETION, eager_generated_text)
 
         # Static Cache + export
-        exported_program = convert_and_export_with_cache(model)
+        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
+
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
+        exported_program = exportable_module.export()
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
diff --git a/tests/models/phi3/test_modeling_phi3.py b/tests/models/phi3/test_modeling_phi3.py
index cb9dc86d43b..1f76a22bffb 100644
--- a/tests/models/phi3/test_modeling_phi3.py
+++ b/tests/models/phi3/test_modeling_phi3.py
@@ -347,7 +347,6 @@ class Phi3IntegrationTest(unittest.TestCase):
         from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
         from transformers.integrations.executorch import (
             TorchExportableModuleWithStaticCache,
-            convert_and_export_with_cache,
         )
 
         model_id = "microsoft/Phi-4-mini-instruct"
@@ -399,7 +398,10 @@ class Phi3IntegrationTest(unittest.TestCase):
         max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
 
         # Static Cache + export
-        exported_program = convert_and_export_with_cache(model)
+        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
+
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
+        exported_program = exportable_module.export()
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py
index a27695fa9d2..0f846a6a5e3 100644
--- a/tests/models/qwen2/test_modeling_qwen2.py
+++ b/tests/models/qwen2/test_modeling_qwen2.py
@@ -31,7 +31,6 @@ from transformers.testing_utils import (
     slow,
     torch_device,
 )
-from transformers.utils.import_utils import is_torch_greater_or_equal
 
 
 if is_torch_available():
@@ -246,7 +245,6 @@ class Qwen2IntegrationTest(unittest.TestCase):
 
         from transformers.integrations.executorch import (
             TorchExportableModuleWithStaticCache,
-            convert_and_export_with_cache,
         )
 
         qwen_model = "Qwen/Qwen2-0.5B"
@@ -287,8 +285,13 @@ class Qwen2IntegrationTest(unittest.TestCase):
         max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
 
         # Static Cache + export
-        strict = is_torch_greater_or_equal("2.7.0")  # Due to https://github.com/pytorch/pytorch/issues/150994
-        exported_program = convert_and_export_with_cache(model, strict=strict)
+        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
+
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
+        strict = version.parse(torch.__version__) != version.parse(
+            "2.7.0"
+        )  # Due to https://github.com/pytorch/pytorch/issues/150994
+        exported_program = exportable_module.export(strict=strict)
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
diff --git a/tests/models/qwen3/test_modeling_qwen3.py b/tests/models/qwen3/test_modeling_qwen3.py
index 7f4d147cb26..02841dcb6b4 100644
--- a/tests/models/qwen3/test_modeling_qwen3.py
+++ b/tests/models/qwen3/test_modeling_qwen3.py
@@ -31,7 +31,6 @@ from transformers.testing_utils import (
     slow,
     torch_device,
 )
-from transformers.utils.import_utils import is_torch_greater_or_equal
 
 
 if is_torch_available():
@@ -240,13 +239,12 @@ class Qwen3IntegrationTest(unittest.TestCase):
 
         from transformers.integrations.executorch import (
             TorchExportableModuleWithStaticCache,
-            convert_and_export_with_cache,
         )
 
         qwen_model = "Qwen/Qwen3-0.6B-Base"
 
         tokenizer = AutoTokenizer.from_pretrained(qwen_model, pad_token="<|endoftext|>", padding_side="right")
-        if is_torch_greater_or_equal("2.7.0"):
+        if version.parse(torch.__version__) == version.parse("2.7.0"):
             strict = False  # Due to https://github.com/pytorch/pytorch/issues/150994
             EXPECTED_TEXT_COMPLETION = ["My favourite condiment is 100% plain, unflavoured, and unadulterated."]
         else:
@@ -285,7 +283,10 @@ class Qwen3IntegrationTest(unittest.TestCase):
         max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
 
         # Static Cache + export
-        exported_program = convert_and_export_with_cache(model, strict=strict)
+        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
+
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
+        exported_program = exportable_module.export(strict=strict)
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py
index 9d435cb7ed1..150bd9e1b2b 100644
--- a/tests/utils/test_cache_utils.py
+++ b/tests/utils/test_cache_utils.py
@@ -15,6 +15,7 @@
 import copy
 import unittest
 
+from packaging import version
 from parameterized import parameterized
 
 from transformers import set_seed
@@ -680,15 +681,27 @@ class CacheExportIntegrationTest(unittest.TestCase):
         self.assertEqual(n_static_key_caches, model.config.num_hidden_layers)
         self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)
 
-        # Export with dynamic shapes using Dim.AUTO
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
-        input_ids = tokenizer("Here's everything I know", return_tensors="pt").input_ids
-        dynamic_shapes = {"input_ids": {1: torch.export.Dim.AUTO}, "cache_position": None}
+        # Export with dynamic shapes
+        input_ids = torch.zeros((1, 3), dtype=torch.long)
+        cache_position = torch.tensor([0, 1, 2], dtype=torch.long)
+        dynamic_shapes = {"input_ids": {1: torch.export.Dim.DYNAMIC}, "cache_position": {0: torch.export.Dim.DYNAMIC}}
+        strict = version.parse(torch.__version__) != version.parse("2.7.0")
         exported_program = convert_and_export_with_cache(
             model,
             example_input_ids=input_ids,
+            example_cache_position=cache_position,
             dynamic_shapes=dynamic_shapes,
-            strict=False,
+            strict=strict,
+        )
+
+        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
+
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
+        exported_program = exportable_module.export(
+            input_ids=input_ids,
+            cache_position=cache_position,
+            dynamic_shapes=dynamic_shapes,
+            strict=strict,
         )
 
     def test_hybrid_cache_exportability(self):
@@ -727,13 +740,15 @@ class CacheExportIntegrationTest(unittest.TestCase):
         self.assertEqual(n_g_value_caches, model.config.num_hidden_layers)
 
         # Export with dynamic shapes using Dim.AUTO
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
-        input_ids = tokenizer("Here's everything I know", return_tensors="pt").input_ids
-        dynamic_shapes = {"input_ids": {1: torch.export.Dim.AUTO}, "cache_position": None}
+        input_ids = torch.zeros((1, 3), dtype=torch.long)
+        cache_position = torch.tensor([0, 1, 2], dtype=torch.long)
+        dynamic_shapes = {"input_ids": {1: torch.export.Dim.DYNAMIC}, "cache_position": {0: torch.export.Dim.DYNAMIC}}
+        strict = version.parse(torch.__version__) != version.parse("2.7.0")
         exported_program = exportable_module.export(
             input_ids=input_ids,
+            cache_position=cache_position,
             dynamic_shapes=dynamic_shapes,
-            strict=False,
+            strict=strict,
         )
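For reference only (not part of the patch): a minimal usage sketch of the export entry point that the updated tests exercise. The checkpoint name, prompt, and max_new_tokens below are illustrative placeholders; the wrapper itself selects a static or hybrid cache from the model config (`layer_types` / `sliding_window`), as in the executorch.py change above.

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.integrations.executorch import (
        TorchExportableModuleForDecoderOnlyLM,
        TorchExportableModuleWithStaticCache,
    )

    # Placeholder checkpoint; any decoder-only LM with use_cache=True in its config follows the same path.
    model_id = "google/gemma-2b"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    # Wrap and export; the wrapper dispatches to StaticCache or HybridCache based on the config.
    exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
    exported_program = exportable_module.export()

    # Token-by-token generation through the exported program, mirroring the integration tests.
    prompt_token_ids = tokenizer("Hello I am doing", return_tensors="pt").input_ids  # placeholder prompt
    generated_ids = TorchExportableModuleWithStaticCache.generate(
        exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=20
    )
    print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))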