Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-03 12:50:06 +06:00
Fix Qwen models export with torch 2.7 (#37985)
Co-authored-by: Guang Yang <guangyang@fb.com>
This commit is contained in:
parent 3c0796aaea
commit 0b037fd425
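Both test files work around https://github.com/pytorch/pytorch/issues/150994, which surfaces when exporting these models with torch 2.7, by choosing the torch.export `strict` flag based on the installed torch version instead of relying on the default. A minimal, hypothetical sketch of that gate, following the Qwen3 hunk below (the helper name `export_strict_flag` is ours, not the test's):

    from transformers.utils.import_utils import is_torch_greater_or_equal

    def export_strict_flag() -> bool:
        # Non-strict export sidesteps pytorch/pytorch#150994 on torch >= 2.7;
        # older torch versions keep the previous strict tracing path.
        return not is_torch_greater_or_equal("2.7.0")

    print(export_strict_flag())  # False on torch >= 2.7.0, True otherwise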
tests/models/qwen2/test_modeling_qwen2.py

@@ -31,6 +31,7 @@ from transformers.testing_utils import (
     slow,
     torch_device,
 )
+from transformers.utils.import_utils import is_torch_greater_or_equal
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -490,7 +491,8 @@ class Qwen2IntegrationTest(unittest.TestCase):
         max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
 
         # Static Cache + export
-        exported_program = convert_and_export_with_cache(model)
+        strict = is_torch_greater_or_equal("2.7.0")  # Due to https://github.com/pytorch/pytorch/issues/150994
+        exported_program = convert_and_export_with_cache(model, strict=strict)
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
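For context, here is a condensed, hedged sketch of the test flow around the changed call; it is not the verbatim test. The checkpoint id, generation budget, and cache_config keys are assumptions patterned on the variables visible in the Qwen3 hunks below; only the export and generate calls are taken directly from the diff:

    import torch
    from transformers import AutoTokenizer, GenerationConfig, Qwen2ForCausalLM
    from transformers.integrations.executorch import (
        TorchExportableModuleWithStaticCache,
        convert_and_export_with_cache,
    )
    from transformers.utils.import_utils import is_torch_greater_or_equal

    model_id = "Qwen/Qwen2-0.5B"  # assumed checkpoint; the test's id is not shown in this hunk
    max_generation_length = 30    # assumed budget; the test derives it from the expected text

    tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token="</s>", padding_side="right")
    model = Qwen2ForCausalLM.from_pretrained(
        model_id,
        device_map="cpu",
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
        generation_config=GenerationConfig(
            use_cache=True,
            cache_implementation="static",
            max_length=max_generation_length,
            cache_config={"batch_size": 1, "max_cache_len": max_generation_length},
        ),
    )

    prompt_token_ids = tokenizer("My favourite condiment is ", return_tensors="pt").input_ids
    max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]

    # The fix: pass `strict` explicitly, gated on the torch version.
    strict = is_torch_greater_or_equal("2.7.0")  # as in the Qwen2 hunk; due to pytorch/pytorch#150994
    exported_program = convert_and_export_with_cache(model, strict=strict)
    ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
        exported_program=exported_program,
        prompt_token_ids=prompt_token_ids,
        max_new_tokens=max_new_tokens,
    )
    print(tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True))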
tests/models/qwen3/test_modeling_qwen3.py

@@ -31,6 +31,7 @@ from transformers.testing_utils import (
     slow,
     torch_device,
 )
+from transformers.utils.import_utils import is_torch_greater_or_equal
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -453,17 +454,23 @@ class Qwen3IntegrationTest(unittest.TestCase):
         qwen_model = "Qwen/Qwen3-0.6B-Base"
 
         tokenizer = AutoTokenizer.from_pretrained(qwen_model, pad_token="</s>", padding_side="right")
-        EXPECTED_TEXT_COMPLETION = ["My favourite condiment is 100% plain, unflavoured, and unadulterated. It is"]
+        if is_torch_greater_or_equal("2.7.0"):
+            strict = False  # Due to https://github.com/pytorch/pytorch/issues/150994
+            EXPECTED_TEXT_COMPLETION = ["My favourite condiment is 100% plain, unsalted, unsweetened, and unflavored."]
+        else:
+            strict = True
+            EXPECTED_TEXT_COMPLETION = ["My favourite condiment is 100% plain, unflavoured, and unadulterated. It is"]
+
         max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
             "input_ids"
         ].shape[-1]
 
         # Load model
         device = "cpu"
         dtype = torch.bfloat16
         cache_implementation = "static"
         attn_implementation = "sdpa"
         batch_size = 1
         model = Qwen3ForCausalLM.from_pretrained(
             qwen_model,
             device_map=device,
@@ -486,7 +493,7 @@ class Qwen3IntegrationTest(unittest.TestCase):
         max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
 
         # Static Cache + export
-        exported_program = convert_and_export_with_cache(model)
+        exported_program = convert_and_export_with_cache(model, strict=strict)
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
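The assertion tail of these tests falls outside the hunks shown; presumably it decodes the exported program's output and compares it against the version-specific expectation chosen by the gate above. A hypothetical helper sketching that check (the name `assert_completion` and the `skip_special_tokens` choice are ours, not the test's):

    def assert_completion(tokenizer, ep_generated_ids, expected_text_completion):
        # Decode the ids produced through the exported program and compare them
        # with the torch-version-specific expected completion.
        ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)
        assert ep_generated_text == expected_text_completion, (ep_generated_text, expected_text_completion)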