Fix Qwen models export with torch 2.7 (#37985)

Co-authored-by: Guang Yang <guangyang@fb.com>
Guang Yang · 2025-05-07 00:13:08 -07:00 · committed by GitHub
parent 3c0796aaea
commit 0b037fd425
2 changed files with 14 additions and 5 deletions
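For context, the pattern these tests exercise is version-gated export: strict-mode torch.export breaks on these models under torch 2.7 (pytorch/pytorch#150994), so the `strict` flag is now derived from the installed torch version before exporting with a static cache. Below is a minimal standalone sketch of that flow; the static-cache GenerationConfig values and the non-strict-on-2.7+ gating follow the Qwen3 test, while the `transformers.integrations.executorch` import path and the AutoModel usage are assumptions, not part of this diff:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
# Assumed import path for the executorch helpers used by these tests.
from transformers.integrations.executorch import (
    TorchExportableModuleWithStaticCache,
    convert_and_export_with_cache,
)
from transformers.utils.import_utils import is_torch_greater_or_equal

repo = "Qwen/Qwen3-0.6B-Base"
max_len = 64  # illustrative cache/generation budget

tokenizer = AutoTokenizer.from_pretrained(repo, pad_token="</s>", padding_side="right")
model = AutoModelForCausalLM.from_pretrained(
    repo,
    device_map="cpu",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    # Static cache sized up front, as in the integration tests.
    generation_config=GenerationConfig(
        use_cache=True,
        cache_implementation="static",
        max_length=max_len,
        cache_config={"batch_size": 1, "max_cache_len": max_len},
    ),
)

# torch >= 2.7 hits https://github.com/pytorch/pytorch/issues/150994 under
# strict-mode export, so fall back to non-strict there (the Qwen3 test's logic).
strict = not is_torch_greater_or_equal("2.7.0")
exported_program = convert_and_export_with_cache(model, strict=strict)

prompt_token_ids = tokenizer("My favourite condiment is", return_tensors="pt").input_ids
generated_ids = TorchExportableModuleWithStaticCache.generate(
    exported_program=exported_program,
    prompt_token_ids=prompt_token_ids,
    max_new_tokens=max_len - prompt_token_ids.shape[-1],
)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))
```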

tests/models/qwen2/test_modeling_qwen2.py

@@ -31,6 +31,7 @@ from transformers.testing_utils import (
     slow,
     torch_device,
 )
+from transformers.utils.import_utils import is_torch_greater_or_equal
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -490,7 +491,8 @@ class Qwen2IntegrationTest(unittest.TestCase):
         max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
 
         # Static Cache + export
-        exported_program = convert_and_export_with_cache(model)
+        strict = is_torch_greater_or_equal("2.7.0")  # Due to https://github.com/pytorch/pytorch/issues/150994
+        exported_program = convert_and_export_with_cache(model, strict=strict)
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )

tests/models/qwen3/test_modeling_qwen3.py

@@ -31,6 +31,7 @@ from transformers.testing_utils import (
     slow,
     torch_device,
 )
+from transformers.utils.import_utils import is_torch_greater_or_equal
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -453,17 +454,23 @@ class Qwen3IntegrationTest(unittest.TestCase):
         qwen_model = "Qwen/Qwen3-0.6B-Base"
         tokenizer = AutoTokenizer.from_pretrained(qwen_model, pad_token="</s>", padding_side="right")
-        EXPECTED_TEXT_COMPLETION = ["My favourite condiment is 100% plain, unflavoured, and unadulterated. It is"]
+        if is_torch_greater_or_equal("2.7.0"):
+            strict = False  # Due to https://github.com/pytorch/pytorch/issues/150994
+            EXPECTED_TEXT_COMPLETION = ["My favourite condiment is 100% plain, unsalted, unsweetened, and unflavored."]
+        else:
+            strict = True
+            EXPECTED_TEXT_COMPLETION = ["My favourite condiment is 100% plain, unflavoured, and unadulterated. It is"]
         max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
             "input_ids"
         ].shape[-1]
+
+        # Load model
         device = "cpu"
         dtype = torch.bfloat16
         cache_implementation = "static"
         attn_implementation = "sdpa"
         batch_size = 1
-        # Load model
         model = Qwen3ForCausalLM.from_pretrained(
             qwen_model,
             device_map=device,
@@ -486,7 +493,7 @@ class Qwen3IntegrationTest(unittest.TestCase):
         max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
 
         # Static Cache + export
-        exported_program = convert_and_export_with_cache(model)
+        exported_program = convert_and_export_with_cache(model, strict=strict)
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )