Fix Qwen models export with torch 2.7 (#37985)

Co-authored-by: Guang Yang <guangyang@fb.com>
Guang Yang, 2025-05-07 00:13:08 -07:00, committed by GitHub
parent 3c0796aaea
commit 0b037fd425
2 changed files with 14 additions and 5 deletions

tests/models/qwen2/test_modeling_qwen2.py

@@ -31,6 +31,7 @@ from transformers.testing_utils import (
     slow,
     torch_device,
 )
+from transformers.utils.import_utils import is_torch_greater_or_equal
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -490,7 +491,8 @@ class Qwen2IntegrationTest(unittest.TestCase):
         max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
 
         # Static Cache + export
-        exported_program = convert_and_export_with_cache(model)
+        strict = is_torch_greater_or_equal("2.7.0")  # Due to https://github.com/pytorch/pytorch/issues/150994
+        exported_program = convert_and_export_with_cache(model, strict=strict)
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )

tests/models/qwen3/test_modeling_qwen3.py

@@ -31,6 +31,7 @@ from transformers.testing_utils import (
     slow,
     torch_device,
 )
+from transformers.utils.import_utils import is_torch_greater_or_equal
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -453,17 +454,23 @@ class Qwen3IntegrationTest(unittest.TestCase):
         qwen_model = "Qwen/Qwen3-0.6B-Base"
 
         tokenizer = AutoTokenizer.from_pretrained(qwen_model, pad_token="</s>", padding_side="right")
-        EXPECTED_TEXT_COMPLETION = ["My favourite condiment is 100% plain, unflavoured, and unadulterated. It is"]
+        if is_torch_greater_or_equal("2.7.0"):
+            strict = False  # Due to https://github.com/pytorch/pytorch/issues/150994
+            EXPECTED_TEXT_COMPLETION = ["My favourite condiment is 100% plain, unsalted, unsweetened, and unflavored."]
+        else:
+            strict = True
+            EXPECTED_TEXT_COMPLETION = ["My favourite condiment is 100% plain, unflavoured, and unadulterated. It is"]
         max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
             "input_ids"
         ].shape[-1]
 
-        # Load model
         device = "cpu"
         dtype = torch.bfloat16
         cache_implementation = "static"
         attn_implementation = "sdpa"
         batch_size = 1
+
+        # Load model
         model = Qwen3ForCausalLM.from_pretrained(
             qwen_model,
             device_map=device,
@@ -486,7 +493,7 @@ class Qwen3IntegrationTest(unittest.TestCase):
         max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
 
         # Static Cache + export
-        exported_program = convert_and_export_with_cache(model)
+        exported_program = convert_and_export_with_cache(model, strict=strict)
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
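
For reference, below is a minimal end-to-end sketch of the export path these tests exercise, following the Qwen3 test's gating (non-strict tracing on torch >= 2.7). The checkpoint, prompt, and generation settings are illustrative assumptions mirroring the test setup, not part of the commit itself:

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.integrations.executorch import (
    TorchExportableModuleWithStaticCache,
    convert_and_export_with_cache,
)
from transformers.utils.import_utils import is_torch_greater_or_equal

model_id = "Qwen/Qwen3-0.6B-Base"  # assumed checkpoint
max_generation_length = 30  # assumed total budget for prompt + new tokens

tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token="</s>", padding_side="right")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cpu",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    # Static cache sized up front so the exported graph has fixed shapes.
    generation_config=GenerationConfig(
        use_cache=True,
        cache_implementation="static",
        max_length=max_generation_length,
        cache_config={"batch_size": 1, "max_cache_len": max_generation_length},
    ),
)

prompt_token_ids = tokenizer("My favourite condiment is", return_tensors="pt").input_ids
max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]

# Strict export fails for these models on torch >= 2.7
# (https://github.com/pytorch/pytorch/issues/150994), so fall back to non-strict tracing there.
strict = not is_torch_greater_or_equal("2.7.0")
exported_program = convert_and_export_with_cache(model, strict=strict)

generated_ids = TorchExportableModuleWithStaticCache.generate(
    exported_program=exported_program,
    prompt_token_ids=prompt_token_ids,
    max_new_tokens=max_new_tokens,
)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))

Note that the Qwen3 test also pins a different expected completion per torch version, since the generated text differs between the two setups.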