Allow override inputs to export recipe (#37508)

Add option to specify dynamic shapes during export.

Co-authored-by: Guang Yang <guangyang@fb.com>

Parent: 481de7204c
Commit: a57274466f
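The commit touches the ExecuTorch export recipe in `transformers.integrations.executorch` and its cache-export tests. As orientation, here is a minimal sketch of the new surface, modeled directly on the test added below (the checkpoint, prompt, and `Dim.AUTO` spec are taken from that test; the static-cache configuration noted in the comment is an assumption, elided from the shown hunks):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.integrations.executorch import convert_and_export_with_cache

model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
# NOTE: assumption, the full test additionally configures the model for a
# static cache (use_cache=True, cache_implementation="static") before export.
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
input_ids = tokenizer("Here's everything I know", return_tensors="pt").input_ids

# Mark dimension 1 (sequence length) of input_ids as dynamic; cache_position stays static.
dynamic_shapes = {"input_ids": {1: torch.export.Dim.AUTO}, "cache_position": None}

exported_program = convert_and_export_with_cache(
    model,
    example_input_ids=input_ids,
    dynamic_shapes=dynamic_shapes,
    strict=False,  # strict=None falls back to the old default, strict=True
)

On torch < 2.6.0 both new arguments are ignored with a warning, as the diff below shows.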
@@ -10,6 +10,7 @@
 # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations under the License.
 
+import logging
 from typing import Optional
 
 import torch
@@ -50,13 +51,21 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
         """
         super().__init__()
 
-        if model.config.cache_implementation == "static":
+        if not hasattr(model.config, "use_cache") or model.config.use_cache is False:
+            raise ValueError("The model must have caching enabled to be performant.")
+
+        if not hasattr(model.config, "cache_implementation"):
+            # If `cache_implementation` is not specified explicitly in the config, `DynamicCache` will
+            # be used by default, so export will use `StaticCache` by default.
+            logging.info("Using `StaticCache` for export as `cache_implementation` is not specified in the config.")
             self.model = TorchExportableModuleWithStaticCache(model)
-        elif model.config.cache_implementation == "hybrid":
-            self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
         else:
-            raise ValueError(
-                f"Unsupported cache implementation in this export recipe: '{model.config.cache_implementation}'"
-            )
+            if model.config.cache_implementation == "hybrid":
+                self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
+            else:
+                raise ValueError(
+                    f"Unsupported cache implementation: {model.config.cache_implementation}. "
+                    "Please use `hybrid` or `static`."
+                )
 
     def forward(
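The rewritten constructor validates `use_cache` first, falls back to `StaticCache` whenever `cache_implementation` is absent from the config, and only raises for an explicit unsupported value. For the hybrid path, the wrapper is driven exactly as in the new `test_hybrid_cache_exportability` test further down; a minimal sketch (model id and cache sizes taken from that test, keyword names from the constructor above):

from transformers import AutoModelForCausalLM
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

# This Gemma-3 test config sets cache_implementation="hybrid", so the wrapper
# selects TorchExportableModuleWithHybridCache under the hood.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-Gemma3ForCausalLM")
model.eval()

exportable_module = TorchExportableModuleForDecoderOnlyLM(model, max_batch_size=1, max_cache_len=23)
exported_program = exportable_module.export()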
@@ -462,6 +471,8 @@ def convert_and_export_with_cache(
     model: PreTrainedModel,
     example_input_ids: Optional[torch.Tensor] = None,
     example_cache_position: Optional[torch.Tensor] = None,
+    dynamic_shapes: Optional[dict] = None,
+    strict: Optional[bool] = None,
 ):
     """
     Convert a `PreTrainedModel` into an exportable module and export it using `torch.export`,
@@ -469,8 +480,10 @@ def convert_and_export_with_cache(
 
     Args:
         model (`PreTrainedModel`): The pretrained model to be exported.
-        example_input_ids (`torch.Tensor`): Example input token id used by `torch.export`.
-        example_cache_position (`torch.Tensor`): Example current cache position used by `torch.export`.
+        example_input_ids (`Optional[torch.Tensor]`): Example input token id used by `torch.export`.
+        example_cache_position (`Optional[torch.Tensor]`): Example current cache position used by `torch.export`.
+        dynamic_shapes (`Optional[dict]`): Dynamic shapes used by `torch.export`.
+        strict (`Optional[bool]`): Flag to instruct `torch.export` to use `torchdynamo`.
 
     Returns:
         Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
@@ -489,14 +502,21 @@ def convert_and_export_with_cache(
         example_cache_position if example_cache_position is not None else torch.tensor([0], dtype=torch.long)
     )
 
-    if is_torch_greater_or_equal("2.5.0"):
+    if is_torch_greater_or_equal("2.6.0"):
         exported_program = torch.export.export(
             TorchExportableModuleWithStaticCache(model),
-            args=(example_input_ids,),
-            kwargs={"cache_position": example_cache_position},
-            strict=True,
+            args=(example_input_ids, example_cache_position),
+            kwargs={},
+            dynamic_shapes=dynamic_shapes,
+            strict=strict if strict is not None else True,
         )
     else:
+        if dynamic_shapes is not None:
+            logging.warning(
+                "Dynamic shapes spec will be ignored by convert_and_export_with_cache for torch < 2.6.0."
+            )
+        if strict is not None:
+            logging.warning("The strict flag will be ignored by convert_and_export_with_cache for torch < 2.6.0.")
         # We have to keep this path for BC.
         #
         # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal
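On torch >= 2.6.0 the spec is the standard `torch.export` one, keyed by the wrapper's forward arguments (`input_ids`, `cache_position`); on older torch the recipe keeps the previous export path for backward compatibility and warns that `dynamic_shapes` and `strict` are ignored. Besides `Dim.AUTO`, an explicit named dimension also works; a hedged sketch (the name and bounds here are illustrative assumptions, not from the commit):

import torch

# Equivalent to the Dim.AUTO spec used in the tests, but with an explicit
# named dimension and illustrative bounds.
seq_len = torch.export.Dim("seq_len", min=1, max=128)
dynamic_shapes = {"input_ids": {1: seq_len}, "cache_position": None}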
@@ -25,7 +25,6 @@ from transformers.testing_utils import (
     is_torch_available,
     require_gptq,
     require_non_xpu,
-    require_read_token,
     require_torch,
     require_torch_accelerator,
     require_torch_gpu,
@@ -693,8 +692,6 @@ class CacheExportIntegrationTest(unittest.TestCase):
         for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
             self.assertTrue(torch.allclose(v1, v2))
 
-    @slow
-    @require_read_token
     def test_static_cache_exportability(self):
         """
         Tests that static cache works with `torch.export()`
@@ -709,8 +706,9 @@ class CacheExportIntegrationTest(unittest.TestCase):
         attn_implementation = "sdpa"  # Export and ExecuTorch only work for SdpaAttention
         batch_size = 1
         max_cache_len = 1234
+        model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
         model = AutoModelForCausalLM.from_pretrained(
-            "google/gemma-2b",
+            model_id,
             device_map=device,
             torch_dtype=dtype,
             attn_implementation=attn_implementation,
@@ -748,3 +746,59 @@ class CacheExportIntegrationTest(unittest.TestCase):
             n_static_value_caches = n_static_value_caches + 1
         self.assertEqual(n_static_key_caches, model.config.num_hidden_layers)
         self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)
+
+        # Export with dynamic shapes using Dim.AUTO
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        input_ids = tokenizer("Here's everything I know", return_tensors="pt").input_ids
+        dynamic_shapes = {"input_ids": {1: torch.export.Dim.AUTO}, "cache_position": None}
+        exported_program = convert_and_export_with_cache(
+            model,
+            example_input_ids=input_ids,
+            dynamic_shapes=dynamic_shapes,
+            strict=False,
+        )
+
+    def test_hybrid_cache_exportability(self):
+        """
+        Tests that hybrid cache works with `torch.export()`
+        """
+        if not is_torch_greater_or_equal("2.6"):
+            self.skipTest(reason="This test requires torch >= 2.6 to run.")
+
+        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
+
+        set_seed(0)
+        model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+        model.eval()
+        self.assertEqual(model.config.use_cache, True)
+        self.assertEqual(model.config.cache_implementation, "hybrid")
+
+        # Export + HybridCache
+        model.eval()
+        max_batch_size = 1
+        max_cache_len = 23
+        exportable_module = TorchExportableModuleForDecoderOnlyLM(model, max_batch_size, max_cache_len)
+        exported_program = exportable_module.export()
+        n_g_key_caches = n_g_value_caches = 0
+        for buffer_name, buffer in exported_program.named_buffers():
+            if buffer_name.startswith("key_cache"):
+                self.assertTrue(buffer.shape[0] == max_batch_size)
+                self.assertTrue(buffer.shape[2] == max_cache_len)
+                n_g_key_caches = n_g_key_caches + 1
+            if buffer_name.startswith("value_cache"):
+                self.assertTrue(buffer.shape[0] == max_batch_size)
+                self.assertTrue(buffer.shape[2] == max_cache_len)
+                n_g_value_caches = n_g_value_caches + 1
+        self.assertEqual(n_g_key_caches, model.config.num_hidden_layers)
+        self.assertEqual(n_g_value_caches, model.config.num_hidden_layers)
+
+        # Export with dynamic shapes using Dim.AUTO
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        input_ids = tokenizer("Here's everything I know", return_tensors="pt").input_ids
+        dynamic_shapes = {"input_ids": {1: torch.export.Dim.AUTO}, "cache_position": None}
+        exported_program = exportable_module.export(
+            input_ids=input_ids,
+            dynamic_shapes=dynamic_shapes,
+            strict=False,
+        )