From 0b037fd425da1f98ea54e6b63d11e7d223b1e516 Mon Sep 17 00:00:00 2001
From: Guang Yang <42389959+guangy10@users.noreply.github.com>
Date: Wed, 7 May 2025 00:13:08 -0700
Subject: [PATCH] Fix Qwen models export with torch 2.7 (#37985)

Co-authored-by: Guang Yang
---
 tests/models/qwen2/test_modeling_qwen2.py |  4 +++-
 tests/models/qwen3/test_modeling_qwen3.py | 15 +++++++++++----
 2 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py
index 1339a09b64f..acb784f6ab2 100644
--- a/tests/models/qwen2/test_modeling_qwen2.py
+++ b/tests/models/qwen2/test_modeling_qwen2.py
@@ -31,6 +31,7 @@ from transformers.testing_utils import (
     slow,
     torch_device,
 )
+from transformers.utils.import_utils import is_torch_greater_or_equal
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -490,7 +491,8 @@ class Qwen2IntegrationTest(unittest.TestCase):
         max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
 
         # Static Cache + export
-        exported_program = convert_and_export_with_cache(model)
+        strict = is_torch_greater_or_equal("2.7.0")  # Due to https://github.com/pytorch/pytorch/issues/150994
+        exported_program = convert_and_export_with_cache(model, strict=strict)
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
diff --git a/tests/models/qwen3/test_modeling_qwen3.py b/tests/models/qwen3/test_modeling_qwen3.py
index 0a5660ecd2c..44eb7474fa8 100644
--- a/tests/models/qwen3/test_modeling_qwen3.py
+++ b/tests/models/qwen3/test_modeling_qwen3.py
@@ -31,6 +31,7 @@ from transformers.testing_utils import (
     slow,
     torch_device,
 )
+from transformers.utils.import_utils import is_torch_greater_or_equal
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -453,17 +454,23 @@ class Qwen3IntegrationTest(unittest.TestCase):
         qwen_model = "Qwen/Qwen3-0.6B-Base"
 
         tokenizer = AutoTokenizer.from_pretrained(qwen_model, pad_token="", padding_side="right")
-        EXPECTED_TEXT_COMPLETION = ["My favourite condiment is 100% plain, unflavoured, and unadulterated. It is"]
+        if is_torch_greater_or_equal("2.7.0"):
+            strict = False  # Due to https://github.com/pytorch/pytorch/issues/150994
+            EXPECTED_TEXT_COMPLETION = ["My favourite condiment is 100% plain, unsalted, unsweetened, and unflavored."]
+        else:
+            strict = True
+            EXPECTED_TEXT_COMPLETION = ["My favourite condiment is 100% plain, unflavoured, and unadulterated. It is"]
+
         max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
             "input_ids"
         ].shape[-1]
-
-        # Load model
         device = "cpu"
         dtype = torch.bfloat16
         cache_implementation = "static"
         attn_implementation = "sdpa"
         batch_size = 1
+
+        # Load model
         model = Qwen3ForCausalLM.from_pretrained(
             qwen_model,
             device_map=device,
@@ -486,7 +493,7 @@ class Qwen3IntegrationTest(unittest.TestCase):
         max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
 
         # Static Cache + export
-        exported_program = convert_and_export_with_cache(model)
+        exported_program = convert_and_export_with_cache(model, strict=strict)
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
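
Usage note: the sketch below shows the pattern this patch exercises outside of the test suite, following the Qwen3 test's branch (non-strict export on torch >= 2.7 because of pytorch/pytorch#150994, strict export otherwise). The imports, `is_torch_greater_or_equal`, `convert_and_export_with_cache(model, strict=...)`, and `TorchExportableModuleWithStaticCache.generate(...)` are taken from the patch itself; the prompt, the generation-length budget, and the static-cache `GenerationConfig` settings are illustrative assumptions, not part of the diff.

```python
import torch
from transformers import AutoTokenizer, GenerationConfig, Qwen3ForCausalLM
from transformers.integrations.executorch import (
    TorchExportableModuleWithStaticCache,
    convert_and_export_with_cache,
)
from transformers.utils.import_utils import is_torch_greater_or_equal

model_id = "Qwen/Qwen3-0.6B-Base"
max_generation_length = 30  # illustrative budget: prompt tokens + new tokens

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Static-cache generation config mirroring the test setup (values are assumptions).
model = Qwen3ForCausalLM.from_pretrained(
    model_id,
    device_map="cpu",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    generation_config=GenerationConfig(
        use_cache=True,
        cache_implementation="static",
        max_length=max_generation_length,
        cache_config={"batch_size": 1, "max_cache_len": max_generation_length},
    ),
)

# Mirror the Qwen3 test: strict torch.export breaks for these models on torch 2.7
# (https://github.com/pytorch/pytorch/issues/150994), so fall back to non-strict there.
strict = not is_torch_greater_or_equal("2.7.0")

prompt_token_ids = tokenizer(["My favourite condiment is "], return_tensors="pt")["input_ids"]
max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]

# Export with a static cache, then generate from the exported program.
exported_program = convert_and_export_with_cache(model, strict=strict)
generated_ids = TorchExportableModuleWithStaticCache.generate(
    exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))
```

Keying the flag off the installed torch version at runtime lets the same test body pass on both pre-2.7 and 2.7+ installs, which is also why the Qwen3 test keeps two expected completions, one per export mode.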