Fix Qwen models export with torch 2.7 (#37985)

Co-authored-by: Guang Yang <guangyang@fb.com>

Commit 0b037fd425 (parent 3c0796aaea) in huggingface/transformers (https://github.com/huggingface/transformers.git).
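
For readers outside the test suite, here is a minimal sketch of the pattern this commit applies: gate torch.export's strict mode on the installed torch version, export the model with a static cache, and generate from the exported program. The checkpoint, prompt, and cache sizes below are illustrative, and the version gate follows the Qwen3 test's branch (non-strict export on torch >= 2.7.0); this is a sketch, not the test code itself.

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.integrations.executorch import (
    TorchExportableModuleWithStaticCache,
    convert_and_export_with_cache,
)
from transformers.utils.import_utils import is_torch_greater_or_equal

model_id = "Qwen/Qwen3-0.6B-Base"  # illustrative checkpoint
max_generation_length = 30  # illustrative cache/generation budget

tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token="</s>", padding_side="right")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cpu",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    generation_config=GenerationConfig(
        use_cache=True,
        cache_implementation="static",
        max_length=max_generation_length,
        cache_config={"batch_size": 1, "max_cache_len": max_generation_length},
    ),
)

prompt_token_ids = tokenizer("My favourite condiment is ", return_tensors="pt").input_ids
max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]

# torch 2.7's strict-mode export trips over these models
# (https://github.com/pytorch/pytorch/issues/150994), so export non-strict there.
strict = not is_torch_greater_or_equal("2.7.0")
exported_program = convert_and_export_with_cache(model, strict=strict)

ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
    exported_program=exported_program,
    prompt_token_ids=prompt_token_ids,
    max_new_tokens=max_new_tokens,
)
print(tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True))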
tests/models/qwen2/test_modeling_qwen2.py:

@@ -31,6 +31,7 @@ from transformers.testing_utils import (
     slow,
     torch_device,
 )
+from transformers.utils.import_utils import is_torch_greater_or_equal
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
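
The newly imported helper compares the installed torch version against a minimum. As a rough, simplified stand-in (the real implementation in transformers.utils.import_utils may differ in detail):

import torch
from packaging import version

def is_torch_greater_or_equal_sketch(min_version: str) -> bool:
    # Simplified stand-in for transformers.utils.import_utils.is_torch_greater_or_equal.
    return version.parse(torch.__version__) >= version.parse(min_version)

print(is_torch_greater_or_equal_sketch("2.7.0"))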
@@ -490,7 +491,8 @@ class Qwen2IntegrationTest(unittest.TestCase):
         max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
 
         # Static Cache + export
-        exported_program = convert_and_export_with_cache(model)
+        strict = is_torch_greater_or_equal("2.7.0")  # Due to https://github.com/pytorch/pytorch/issues/150994
+        exported_program = convert_and_export_with_cache(model, strict=strict)
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
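
For reference, the strict flag being threaded through here is torch.export's own: strict=True traces the module with TorchDynamo, while strict=False uses the more permissive non-strict (Python-level) tracing that sidesteps the torch 2.7 regression referenced above. A toy, self-contained illustration unrelated to the Qwen models:

import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.silu(x) * 2

example_args = (torch.randn(2, 4),)
ep_strict = torch.export.export(Toy(), example_args, strict=True)      # TorchDynamo tracing
ep_nonstrict = torch.export.export(Toy(), example_args, strict=False)  # non-strict tracing

# Both exported programs compute the same function.
x = torch.ones(2, 4)
assert torch.allclose(ep_strict.module()(x), ep_nonstrict.module()(x))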
tests/models/qwen3/test_modeling_qwen3.py:

@@ -31,6 +31,7 @@ from transformers.testing_utils import (
     slow,
     torch_device,
 )
+from transformers.utils.import_utils import is_torch_greater_or_equal
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -453,17 +454,23 @@ class Qwen3IntegrationTest(unittest.TestCase):
         qwen_model = "Qwen/Qwen3-0.6B-Base"
 
         tokenizer = AutoTokenizer.from_pretrained(qwen_model, pad_token="</s>", padding_side="right")
-        EXPECTED_TEXT_COMPLETION = ["My favourite condiment is 100% plain, unflavoured, and unadulterated. It is"]
+        if is_torch_greater_or_equal("2.7.0"):
+            strict = False  # Due to https://github.com/pytorch/pytorch/issues/150994
+            EXPECTED_TEXT_COMPLETION = ["My favourite condiment is 100% plain, unsalted, unsweetened, and unflavored."]
+        else:
+            strict = True
+            EXPECTED_TEXT_COMPLETION = ["My favourite condiment is 100% plain, unflavoured, and unadulterated. It is"]
 
         max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[
             "input_ids"
         ].shape[-1]
 
-        # Load model
         device = "cpu"
         dtype = torch.bfloat16
         cache_implementation = "static"
         attn_implementation = "sdpa"
         batch_size = 1
 
+        # Load model
         model = Qwen3ForCausalLM.from_pretrained(
             qwen_model,
             device_map=device,
@@ -486,7 +493,7 @@ class Qwen3IntegrationTest(unittest.TestCase):
         max_new_tokens = max_generation_length - prompt_token_ids.shape[-1]
 
         # Static Cache + export
-        exported_program = convert_and_export_with_cache(model)
+        exported_program = convert_and_export_with_cache(model, strict=strict)
         ep_generated_ids = TorchExportableModuleWithStaticCache.generate(
             exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens
         )
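
Usage note: the touched tests are slow integration tests (note the slow import from transformers.testing_utils above), so they are exercised only when slow tests are enabled in the environment via RUN_SLOW=1; a normal CI unit-test pass skips them.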