diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py
index a5bae1fe9b7..88dfe4640ca 100644
--- a/src/transformers/integrations/executorch.py
+++ b/src/transformers/integrations/executorch.py
@@ -10,6 +10,7 @@
 # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations under the License.
 
+import logging
 from typing import Optional
 
 import torch
@@ -50,14 +51,22 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
         """
         super().__init__()
 
-        if model.config.cache_implementation == "static":
+        if not hasattr(model.config, "use_cache") or model.config.use_cache is False:
+            raise ValueError("The model must have caching enabled to be performant.")
+
+        if not hasattr(model.config, "cache_implementation"):
+            # If `cache_implementation` is not specified explicitly in the config, `DynamicCache` will
+            # be used by default, so export will use `StaticCache` by default.
+            logging.info("Using `StaticCache` for export as `cache_implementation` is not specified in the config.")
             self.model = TorchExportableModuleWithStaticCache(model)
-        elif model.config.cache_implementation == "hybrid":
-            self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
         else:
-            raise ValueError(
-                f"Unsupported cache implementation in this export recipe: '{model.config.cache_implementation}'"
-            )
+            if model.config.cache_implementation == "hybrid":
+                self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
+            else:
+                raise ValueError(
+                    f"Unsupported cache implementation: {model.config.cache_implementation}. "
+                    "Please use `hybrid` or `static`."
+                )
 
     def forward(
         self,
@@ -462,6 +471,8 @@ def convert_and_export_with_cache(
     model: PreTrainedModel,
     example_input_ids: Optional[torch.Tensor] = None,
     example_cache_position: Optional[torch.Tensor] = None,
+    dynamic_shapes: Optional[dict] = None,
+    strict: Optional[bool] = None,
 ):
     """
     Convert a `PreTrainedModel` into an exportable module and export it using `torch.export`,
@@ -469,8 +480,10 @@ def convert_and_export_with_cache(
 
     Args:
         model (`PreTrainedModel`): The pretrained model to be exported.
-        example_input_ids (`torch.Tensor`): Example input token id used by `torch.export`.
-        example_cache_position (`torch.Tensor`): Example current cache position used by `torch.export`.
+        example_input_ids (`Optional[torch.Tensor]`): Example input token id used by `torch.export`.
+        example_cache_position (`Optional[torch.Tensor]`): Example current cache position used by `torch.export`.
+        dynamic_shapes (`Optional[dict]`): Dynamic shapes used by `torch.export`.
+        strict (`Optional[bool]`): Flag to instruct `torch.export` to use `torchdynamo`.
 
     Returns:
         Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
@@ -489,14 +502,21 @@ def convert_and_export_with_cache(
         example_cache_position if example_cache_position is not None else torch.tensor([0], dtype=torch.long)
     )
 
-    if is_torch_greater_or_equal("2.5.0"):
+    if is_torch_greater_or_equal("2.6.0"):
         exported_program = torch.export.export(
             TorchExportableModuleWithStaticCache(model),
-            args=(example_input_ids,),
-            kwargs={"cache_position": example_cache_position},
-            strict=True,
+            args=(example_input_ids, example_cache_position),
+            kwargs={},
+            dynamic_shapes=dynamic_shapes,
+            strict=strict if strict is not None else True,
         )
     else:
+        if dynamic_shapes is not None:
+            logging.warning(
+                "Dynamic shapes spec will be ignored by convert_and_export_with_cache for torch < 2.6.0."
+            )
+        if strict is not None:
+            logging.warning("The strict flag will be ignored by convert_and_export_with_cache for torch < 2.6.0.")
         # We have to keep this path for BC.
         #
         # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal
diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py
index e5b43bec921..3fbc2993555 100644
--- a/tests/utils/test_cache_utils.py
+++ b/tests/utils/test_cache_utils.py
@@ -25,7 +25,6 @@ from transformers.testing_utils import (
     is_torch_available,
     require_gptq,
     require_non_xpu,
-    require_read_token,
     require_torch,
     require_torch_accelerator,
     require_torch_gpu,
@@ -693,8 +692,6 @@ class CacheExportIntegrationTest(unittest.TestCase):
         for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
             self.assertTrue(torch.allclose(v1, v2))
 
-    @slow
-    @require_read_token
     def test_static_cache_exportability(self):
         """
         Tests that static cache works with `torch.export()`
@@ -709,8 +706,9 @@ class CacheExportIntegrationTest(unittest.TestCase):
         attn_implementation = "sdpa"  # Export and ExecuTorch only works for SdpaAttention
         batch_size = 1
         max_cache_len = 1234
+        model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
         model = AutoModelForCausalLM.from_pretrained(
-            "google/gemma-2b",
+            model_id,
             device_map=device,
             torch_dtype=dtype,
             attn_implementation=attn_implementation,
@@ -748,3 +746,59 @@ class CacheExportIntegrationTest(unittest.TestCase):
             n_static_value_caches = n_static_value_caches + 1
         self.assertEqual(n_static_key_caches, model.config.num_hidden_layers)
         self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)
+
+        # Export with dynamic shapes using Dim.AUTO
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        input_ids = tokenizer("Here's everything I know", return_tensors="pt").input_ids
+        dynamic_shapes = {"input_ids": {1: torch.export.Dim.AUTO}, "cache_position": None}
+        exported_program = convert_and_export_with_cache(
+            model,
+            example_input_ids=input_ids,
+            dynamic_shapes=dynamic_shapes,
+            strict=False,
+        )
+
+    def test_hybrid_cache_exportability(self):
+        """
+        Tests that hybrid cache works with `torch.export()`
+        """
+        if not is_torch_greater_or_equal("2.6"):
+            self.skipTest(reason="This test requires torch >= 2.6 to run.")
+
+        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
+
+        set_seed(0)
+        model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+        model.eval()
+        self.assertEqual(model.config.use_cache, True)
+        self.assertEqual(model.config.cache_implementation, "hybrid")
+
+        # Export + HybridCache
+        model.eval()
+        max_batch_size = 1
+        max_cache_len = 23
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model, max_batch_size, max_cache_len)
+        exported_program = exportable_module.export()
+        n_g_key_caches = n_g_value_caches = 0
+        for buffer_name, buffer in exported_program.named_buffers():
+            if buffer_name.startswith("key_cache"):
+                self.assertTrue(buffer.shape[0] == max_batch_size)
+                self.assertTrue(buffer.shape[2] == max_cache_len)
+                n_g_key_caches = n_g_key_caches + 1
+            if buffer_name.startswith("value_cache"):
+                self.assertTrue(buffer.shape[0] == max_batch_size)
+                self.assertTrue(buffer.shape[2] == max_cache_len)
+                n_g_value_caches = n_g_value_caches + 1
+        self.assertEqual(n_g_key_caches, model.config.num_hidden_layers)
+        self.assertEqual(n_g_value_caches, model.config.num_hidden_layers)
+
+        # Export with dynamic shapes using Dim.AUTO
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        input_ids = tokenizer("Here's everything I know", return_tensors="pt").input_ids
+        dynamic_shapes = {"input_ids": {1: torch.export.Dim.AUTO}, "cache_position": None}
+        exported_program = exportable_module.export(
+            input_ids=input_ids,
+            dynamic_shapes=dynamic_shapes,
+            strict=False,
+        )
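
For reference, below is a minimal usage sketch (not part of the patch) of the new `dynamic_shapes` and `strict` arguments of `convert_and_export_with_cache`. It mirrors the updated `test_static_cache_exportability`: the `GenerationConfig`/`cache_config` setup is copied from the existing test body, which lies outside the hunks above, and torch >= 2.6 is assumed so the new arguments are not ignored by the fallback path.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.integrations.executorch import convert_and_export_with_cache

# Tiny model used by the updated test; any decoder-only LM with caching enabled should work.
model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
max_cache_len = 1234

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="sdpa",  # the export path currently supports SDPA attention only
    generation_config=GenerationConfig(
        use_cache=True,
        cache_implementation="static",
        max_length=max_cache_len,
        cache_config={"batch_size": 1, "max_cache_len": max_cache_len},
    ),
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_id)
input_ids = tokenizer("Here's everything I know", return_tensors="pt").input_ids

# Mark the sequence dimension of `input_ids` as dynamic; `cache_position` stays static.
dynamic_shapes = {"input_ids": {1: torch.export.Dim.AUTO}, "cache_position": None}

exported_program = convert_and_export_with_cache(
    model,
    example_input_ids=input_ids,
    dynamic_shapes=dynamic_shapes,
    strict=False,  # non-strict export, i.e. without torchdynamo tracing
)
print(exported_program)

The same `dynamic_shapes` and `strict` keywords are forwarded by `TorchExportableModuleForDecoderOnlyLM.export`, as exercised by the new hybrid-cache test above.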