Mirror of https://github.com/huggingface/transformers.git
Allow override inputs to export recipe (#37508)

Add option to specify dynamic shapes during export.
Co-authored-by: Guang Yang <guangyang@fb.com>

parent 481de7204c
commit a57274466f
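In practice, the two new keyword arguments thread straight through to `torch.export.export`. A minimal sketch of the intended call, lifted from the new tests below (`model` is assumed to be a decoder-only `PreTrainedModel` already loaded with a static-cache generation config, as in the existing static-cache test):

    import torch
    from transformers import AutoTokenizer
    from transformers.integrations.executorch import convert_and_export_with_cache

    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
    input_ids = tokenizer("Here's everything I know", return_tensors="pt").input_ids

    # Mark the sequence dimension of input_ids as dynamic; keep cache_position static.
    dynamic_shapes = {"input_ids": {1: torch.export.Dim.AUTO}, "cache_position": None}
    exported_program = convert_and_export_with_cache(
        model,
        example_input_ids=input_ids,
        dynamic_shapes=dynamic_shapes,
        strict=False,  # non-strict export; True routes tracing through torchdynamo
    )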
@@ -10,6 +10,7 @@
 # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations under the License.
 
+import logging
 from typing import Optional
 
 import torch
@@ -50,13 +51,21 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
         """
         super().__init__()
 
-        if model.config.cache_implementation == "static":
+        if not hasattr(model.config, "use_cache") or model.config.use_cache is False:
+            raise ValueError("The model must have caching enabled to be performant.")
+
+        if not hasattr(model.config, "cache_implementation"):
+            # If `cache_implementation` is not specified explicitly in the config, `DynamicCache` will
+            # be used by default, so export will use `StaticCache` by default.
+            logging.info("Using `StaticCache` for export as `cache_implementation` is not specified in the config.")
             self.model = TorchExportableModuleWithStaticCache(model)
-        elif model.config.cache_implementation == "hybrid":
-            self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
         else:
-            raise ValueError(
-                f"Unsupported cache implementation in this export recipe: '{model.config.cache_implementation}'"
-            )
+            if model.config.cache_implementation == "hybrid":
+                self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
+            else:
+                raise ValueError(
+                    f"Unsupported cache implementation: {model.config.cache_implementation}. "
+                    "Please use `hybrid` or `static`."
+                )
 
     def forward(
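Concretely, the dispatch above means a config without `cache_implementation` falls through to the static-cache path, while `"hybrid"` additionally needs the batch-size and cache-length bounds. A minimal sketch mirroring the new `test_hybrid_cache_exportability` test at the bottom of this diff:

    from transformers import AutoModelForCausalLM
    from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

    # The tiny Gemma3 checkpoint ships with cache_implementation="hybrid".
    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-Gemma3ForCausalLM")
    model.eval()
    max_batch_size, max_cache_len = 1, 23
    exportable_module = TorchExportableModuleForDecoderOnlyLM(model, max_batch_size, max_cache_len)
    exported_program = exportable_module.export()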
@@ -462,6 +471,8 @@ def convert_and_export_with_cache(
     model: PreTrainedModel,
     example_input_ids: Optional[torch.Tensor] = None,
     example_cache_position: Optional[torch.Tensor] = None,
+    dynamic_shapes: Optional[dict] = None,
+    strict: Optional[bool] = None,
 ):
     """
     Convert a `PreTrainedModel` into an exportable module and export it using `torch.export`,
@@ -469,8 +480,10 @@ def convert_and_export_with_cache(
 
     Args:
         model (`PreTrainedModel`): The pretrained model to be exported.
-        example_input_ids (`torch.Tensor`): Example input token id used by `torch.export`.
-        example_cache_position (`torch.Tensor`): Example current cache position used by `torch.export`.
+        example_input_ids (`Optional[torch.Tensor]`): Example input token id used by `torch.export`.
+        example_cache_position (`Optional[torch.Tensor]`): Example current cache position used by `torch.export`.
+        dynamic_shapes (`Optional[dict]`): Dynamic shapes used by `torch.export`.
+        strict (`Optional[bool]`): Flag to instruct `torch.export` to use `torchdynamo`.
 
     Returns:
         Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
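The `dynamic_shapes` dict uses the standard `torch.export` spec: one entry per positional argument of `forward`, mapping a dimension index to a `Dim` object (or `None` to keep that argument fully static). For example, to let only the sequence dimension of `input_ids` vary, as the new tests do:

    import torch

    # dim 1 (sequence length) is dynamic; cache_position keeps its example shape
    dynamic_shapes = {"input_ids": {1: torch.export.Dim.AUTO}, "cache_position": None}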
@@ -489,14 +502,21 @@ def convert_and_export_with_cache(
         example_cache_position if example_cache_position is not None else torch.tensor([0], dtype=torch.long)
     )
 
-    if is_torch_greater_or_equal("2.5.0"):
+    if is_torch_greater_or_equal("2.6.0"):
         exported_program = torch.export.export(
             TorchExportableModuleWithStaticCache(model),
-            args=(example_input_ids,),
-            kwargs={"cache_position": example_cache_position},
-            strict=True,
+            args=(example_input_ids, example_cache_position),
+            kwargs={},
+            dynamic_shapes=dynamic_shapes,
+            strict=strict if strict is not None else True,
         )
     else:
+        if dynamic_shapes is not None:
+            logging.warning(
+                "Dynamic shapes spec will be ignored by convert_and_export_with_cache for torch < 2.6.0."
+            )
+        if strict is not None:
+            logging.warning("The strict flag will be ignored by convert_and_export_with_cache for torch < 2.6.0.")
         # We have to keep this path for BC.
         #
         # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal
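Because torch < 2.6.0 drops the new arguments with a warning rather than an error, callers that must support both paths can gate on the version themselves; a sketch, assuming `is_torch_greater_or_equal` is importable from `transformers.utils` as used in this module:

    from transformers.utils import is_torch_greater_or_equal

    export_kwargs = {}
    if is_torch_greater_or_equal("2.6.0"):
        # Older torch warns and ignores these, so only pass them when honored.
        export_kwargs = {"dynamic_shapes": dynamic_shapes, "strict": False}
    exported_program = convert_and_export_with_cache(model, example_input_ids=input_ids, **export_kwargs)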
@@ -25,7 +25,6 @@ from transformers.testing_utils import (
     is_torch_available,
     require_gptq,
-    require_non_xpu,
     require_read_token,
     require_torch,
     require_torch_accelerator,
     require_torch_gpu,
@@ -693,8 +692,6 @@ class CacheExportIntegrationTest(unittest.TestCase):
         for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
             self.assertTrue(torch.allclose(v1, v2))
 
-    @slow
-    @require_read_token
     def test_static_cache_exportability(self):
         """
         Tests that static cache works with `torch.export()`
@@ -709,8 +706,9 @@ class CacheExportIntegrationTest(unittest.TestCase):
         attn_implementation = "sdpa"  # Export and ExecuTorch only works for SdpaAttention
         batch_size = 1
         max_cache_len = 1234
+        model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
         model = AutoModelForCausalLM.from_pretrained(
-            "google/gemma-2b",
+            model_id,
             device_map=device,
             torch_dtype=dtype,
             attn_implementation=attn_implementation,
@@ -748,3 +746,59 @@ class CacheExportIntegrationTest(unittest.TestCase):
             n_static_value_caches = n_static_value_caches + 1
         self.assertEqual(n_static_key_caches, model.config.num_hidden_layers)
         self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)
+
+        # Export with dynamic shapes using Dim.AUTO
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        input_ids = tokenizer("Here's everything I know", return_tensors="pt").input_ids
+        dynamic_shapes = {"input_ids": {1: torch.export.Dim.AUTO}, "cache_position": None}
+        exported_program = convert_and_export_with_cache(
+            model,
+            example_input_ids=input_ids,
+            dynamic_shapes=dynamic_shapes,
+            strict=False,
+        )
+
+    def test_hybrid_cache_exportability(self):
+        """
+        Tests that hybrid cache works with `torch.export()`
+        """
+        if not is_torch_greater_or_equal("2.6"):
+            self.skipTest(reason="This test requires torch >= 2.6 to run.")
+
+        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
+
+        set_seed(0)
+        model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+        model.eval()
+        self.assertEqual(model.config.use_cache, True)
+        self.assertEqual(model.config.cache_implementation, "hybrid")
+
+        # Export + HybridCache
+        model.eval()
+        max_batch_size = 1
+        max_cache_len = 23
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model, max_batch_size, max_cache_len)
+        exported_program = exportable_module.export()
+        n_g_key_caches = n_g_value_caches = 0
+        for buffer_name, buffer in exported_program.named_buffers():
+            if buffer_name.startswith("key_cache"):
+                self.assertTrue(buffer.shape[0] == max_batch_size)
+                self.assertTrue(buffer.shape[2] == max_cache_len)
+                n_g_key_caches = n_g_key_caches + 1
+            if buffer_name.startswith("value_cache"):
+                self.assertTrue(buffer.shape[0] == max_batch_size)
+                self.assertTrue(buffer.shape[2] == max_cache_len)
+                n_g_value_caches = n_g_value_caches + 1
+        self.assertEqual(n_g_key_caches, model.config.num_hidden_layers)
+        self.assertEqual(n_g_value_caches, model.config.num_hidden_layers)
+
+        # Export with dynamic shapes using Dim.AUTO
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        input_ids = tokenizer("Here's everything I know", return_tensors="pt").input_ids
+        dynamic_shapes = {"input_ids": {1: torch.export.Dim.AUTO}, "cache_position": None}
+        exported_program = exportable_module.export(
+            input_ids=input_ids,
+            dynamic_shapes=dynamic_shapes,
+            strict=False,
+        )
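If the `Dim.AUTO` spec took effect, the exported graph should accept a prompt of a different length at runtime. A small sketch against the static-cache export above, assuming it succeeded; per the diff, `example_cache_position` defaulted to `torch.tensor([0])` and its spec was `None` (static), so it must keep that shape (`ExportedProgram.module()` returns the runnable graph module):

    import torch

    other_ids = tokenizer("A different prompt", return_tensors="pt").input_ids
    # input_ids dim 1 was exported dynamic; a static sequence dim would raise a guard error here.
    logits = exported_program.module()(other_ids, torch.tensor([0], dtype=torch.long))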