[qwen3] fix generation tests (#37142)

* do not skip tests

* fix qwen3-moe as well

* fixup

* fixup
This commit is contained in:
Raushan Turganbay 2025-03-31 16:33:41 +02:00 committed by GitHub
parent e686fed635
commit 8805600406
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 6 additions and 82 deletions

View File

@@ -352,7 +352,6 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
def test_Mistral_sequence_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
print(config)
config.num_labels = 3
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)

View File

@@ -351,7 +351,6 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
def test_Mixtral_sequence_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
print(config)
config.num_labels = 3
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)

View File

@@ -363,7 +363,6 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
def test_Qwen2_sequence_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
print(config)
config.num_labels = 3
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)

View File

@@ -391,7 +391,6 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
def test_Qwen2Moe_sequence_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
print(config)
config.num_labels = 3
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)

View File

@@ -62,7 +62,7 @@ class Qwen3ModelTester:
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
hidden_size=64,
num_hidden_layers=5,
max_window_layers=3,
use_sliding_window=True,
@@ -348,42 +348,6 @@ class Qwen3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
self.model_tester = Qwen3ModelTester(self)
self.config_tester = ConfigTester(self, config_class=Qwen3Config, hidden_size=37)
@unittest.skip("TODO: ask the contributor to take a look")
def test_beam_search_generate_dict_outputs_use_cache(self):
pass
@unittest.skip("TODO: ask the contributor to take a look")
def test_assisted_decoding_matches_greedy_search_0_random(self):
pass
@unittest.skip("TODO: ask the contributor to take a look")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@unittest.skip("TODO: ask the contributor to take a look")
def test_assisted_decoding_matches_greedy_search_1_same(self):
pass
@unittest.skip("TODO: ask the contributor to take a look")
def test_assisted_decoding_sample(self):
pass
@unittest.skip("TODO: ask the contributor to take a look")
def test_dola_decoding_sample(self):
pass
@unittest.skip("TODO: ask the contributor to take a look")
def test_greedy_generate_dict_outputs_use_cache(self):
pass
@unittest.skip("TODO: ask the contributor to take a look")
def test_prompt_lookup_decoding_matches_greedy_search(self):
pass
@unittest.skip("TODO: ask the contributor to take a look")
def test_generate_compilation_all_outputs(self):
pass
def test_config(self):
self.config_tester.run_common_tests()
@@ -402,7 +366,6 @@ class Qwen3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
def test_Qwen3_sequence_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
print(config)
config.num_labels = 3
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
@@ -461,9 +424,9 @@ class Qwen3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
def test_save_load_fast_init_from_base(self):
pass
@unittest.skip(reason="Qwen3 uses GQA on all models so the KV cache is a non standard format")
# Ignore copy
def test_past_key_values_format(self):
pass
super().test_past_key_values_format()
@require_flash_attn
@require_torch_gpu
@@ -487,7 +450,6 @@ class Qwen3IntegrationTest(unittest.TestCase):
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2)
# slicing logits[0, 0, 0:30]
EXPECTED_SLICE = torch.tensor([5.9062, 6.0938, 5.5625, 3.8594, 2.6094, 1.9531, 4.3125, 4.9375, 3.8906, 3.1094, 3.6719, 5.1562, 6.9062, 5.7500, 5.4062, 7.0625, 8.7500, 8.7500, 8.1250, 7.9375, 8.0625, 7.5312, 7.3750, 7.2188, 7.2500, 5.8750, 2.8750, 4.3438, 2.3438, 2.2500]) # fmt: skip
print(out[0, 0, :30])
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-4, atol=1e-4)
del model

View File

@@ -60,7 +60,7 @@ class Qwen3MoeModelTester:
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
hidden_size=64,
num_hidden_layers=5,
max_window_layers=3,
use_sliding_window=True,
@@ -367,38 +367,6 @@ class Qwen3MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
self.model_tester = Qwen3MoeModelTester(self)
self.config_tester = ConfigTester(self, config_class=Qwen3MoeConfig, hidden_size=37)
@unittest.skip("TODO: ask the contributor to take a look")
def test_beam_search_generate_dict_outputs_use_cache(self):
pass
@unittest.skip("TODO: ask the contributor to take a look")
def test_assisted_decoding_matches_greedy_search_0_random(self):
pass
@unittest.skip("TODO: ask the contributor to take a look")
def test_contrastive_generate_dict_outputs_use_cache(self):
pass
@unittest.skip("TODO: ask the contributor to take a look")
def test_assisted_decoding_matches_greedy_search_1_same(self):
pass
@unittest.skip("TODO: ask the contributor to take a look")
def test_assisted_decoding_sample(self):
pass
@unittest.skip("TODO: ask the contributor to take a look")
def test_dola_decoding_sample(self):
pass
@unittest.skip("TODO: ask the contributor to take a look")
def test_greedy_generate_dict_outputs_use_cache(self):
pass
@unittest.skip("TODO: ask the contributor to take a look")
def test_prompt_lookup_decoding_matches_greedy_search(self):
pass
def test_config(self):
self.config_tester.run_common_tests()
@@ -417,7 +385,6 @@ class Qwen3MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
def test_Qwen3Moe_sequence_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
print(config)
config.num_labels = 3
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
@@ -476,9 +443,9 @@ class Qwen3MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
def test_save_load_fast_init_from_base(self):
pass
@unittest.skip(reason="Qwen3Moe uses GQA on all models so the KV cache is a non standard format")
# Ignore copy
def test_past_key_values_format(self):
pass
super().test_past_key_values_format()
@require_flash_attn
@require_torch_gpu
@@ -539,7 +506,6 @@ class Qwen3MoeIntegrationTest(unittest.TestCase):
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2)
# slicing logits[0, 0, 0:30]
EXPECTED_SLICE = torch.tensor([7.5938, 2.6094, 4.0312, 4.0938, 2.5156, 2.7812, 2.9688, 1.5547, 1.3984, 2.2344, 3.0156, 3.1562, 1.1953, 3.2500, 1.0938, 8.4375, 9.5625, 9.0625, 7.5625, 7.5625, 7.9062, 7.2188, 7.0312, 6.9375, 8.0625, 1.7266, 0.9141, 3.7969, 5.3438, 3.9844]) # fmt: skip
print(out[0, 0, :30])
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-4, atol=1e-4)
del model