mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
[qwen3] fix generation tests (#37142)
* do not skip tests * fix qwen3-moe as well * fixup * fixup
This commit is contained in:
parent
e686fed635
commit
8805600406
@ -352,7 +352,6 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
|
||||
def test_Mistral_sequence_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
print(config)
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
|
@ -351,7 +351,6 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
|
||||
def test_Mixtral_sequence_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
print(config)
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
|
@ -363,7 +363,6 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
|
||||
def test_Qwen2_sequence_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
print(config)
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
|
@ -391,7 +391,6 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
||||
|
||||
def test_Qwen2Moe_sequence_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
print(config)
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
|
@ -62,7 +62,7 @@ class Qwen3ModelTester:
|
||||
use_token_type_ids=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
hidden_size=64,
|
||||
num_hidden_layers=5,
|
||||
max_window_layers=3,
|
||||
use_sliding_window=True,
|
||||
@ -348,42 +348,6 @@ class Qwen3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
self.model_tester = Qwen3ModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=Qwen3Config, hidden_size=37)
|
||||
|
||||
@unittest.skip("TODO: ask the contributor to take a look")
|
||||
def test_beam_search_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("TODO: ask the contributor to take a look")
|
||||
def test_assisted_decoding_matches_greedy_search_0_random(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("TODO: ask the contributor to take a look")
|
||||
def test_contrastive_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("TODO: ask the contributor to take a look")
|
||||
def test_assisted_decoding_matches_greedy_search_1_same(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("TODO: ask the contributor to take a look")
|
||||
def test_assisted_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("TODO: ask the contributor to take a look")
|
||||
def test_dola_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("TODO: ask the contributor to take a look")
|
||||
def test_greedy_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("TODO: ask the contributor to take a look")
|
||||
def test_prompt_lookup_decoding_matches_greedy_search(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("TODO: ask the contributor to take a look")
|
||||
def test_generate_compilation_all_outputs(self):
|
||||
pass
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
@ -402,7 +366,6 @@ class Qwen3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
|
||||
def test_Qwen3_sequence_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
print(config)
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
@ -461,9 +424,9 @@ class Qwen3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Qwen3 uses GQA on all models so the KV cache is a non standard format")
|
||||
# Ignore copy
|
||||
def test_past_key_values_format(self):
|
||||
pass
|
||||
super().test_past_key_values_format()
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@ -487,7 +450,6 @@ class Qwen3IntegrationTest(unittest.TestCase):
|
||||
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2)
|
||||
# slicing logits[0, 0, 0:30]
|
||||
EXPECTED_SLICE = torch.tensor([5.9062, 6.0938, 5.5625, 3.8594, 2.6094, 1.9531, 4.3125, 4.9375, 3.8906, 3.1094, 3.6719, 5.1562, 6.9062, 5.7500, 5.4062, 7.0625, 8.7500, 8.7500, 8.1250, 7.9375, 8.0625, 7.5312, 7.3750, 7.2188, 7.2500, 5.8750, 2.8750, 4.3438, 2.3438, 2.2500]) # fmt: skip
|
||||
print(out[0, 0, :30])
|
||||
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-4, atol=1e-4)
|
||||
|
||||
del model
|
||||
|
@ -60,7 +60,7 @@ class Qwen3MoeModelTester:
|
||||
use_token_type_ids=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
hidden_size=64,
|
||||
num_hidden_layers=5,
|
||||
max_window_layers=3,
|
||||
use_sliding_window=True,
|
||||
@ -367,38 +367,6 @@ class Qwen3MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
||||
self.model_tester = Qwen3MoeModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=Qwen3MoeConfig, hidden_size=37)
|
||||
|
||||
@unittest.skip("TODO: ask the contributor to take a look")
|
||||
def test_beam_search_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("TODO: ask the contributor to take a look")
|
||||
def test_assisted_decoding_matches_greedy_search_0_random(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("TODO: ask the contributor to take a look")
|
||||
def test_contrastive_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("TODO: ask the contributor to take a look")
|
||||
def test_assisted_decoding_matches_greedy_search_1_same(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("TODO: ask the contributor to take a look")
|
||||
def test_assisted_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("TODO: ask the contributor to take a look")
|
||||
def test_dola_decoding_sample(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("TODO: ask the contributor to take a look")
|
||||
def test_greedy_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("TODO: ask the contributor to take a look")
|
||||
def test_prompt_lookup_decoding_matches_greedy_search(self):
|
||||
pass
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
@ -417,7 +385,6 @@ class Qwen3MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
||||
|
||||
def test_Qwen3Moe_sequence_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
print(config)
|
||||
config.num_labels = 3
|
||||
input_ids = input_dict["input_ids"]
|
||||
attention_mask = input_ids.ne(1).to(torch_device)
|
||||
@ -476,9 +443,9 @@ class Qwen3MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Qwen3Moe uses GQA on all models so the KV cache is a non standard format")
|
||||
# Ignore copy
|
||||
def test_past_key_values_format(self):
|
||||
pass
|
||||
super().test_past_key_values_format()
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@ -539,7 +506,6 @@ class Qwen3MoeIntegrationTest(unittest.TestCase):
|
||||
torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2)
|
||||
# slicing logits[0, 0, 0:30]
|
||||
EXPECTED_SLICE = torch.tensor([7.5938, 2.6094, 4.0312, 4.0938, 2.5156, 2.7812, 2.9688, 1.5547, 1.3984, 2.2344, 3.0156, 3.1562, 1.1953, 3.2500, 1.0938, 8.4375, 9.5625, 9.0625, 7.5625, 7.5625, 7.9062, 7.2188, 7.0312, 6.9375, 8.0625, 1.7266, 0.9141, 3.7969, 5.3438, 3.9844]) # fmt: skip
|
||||
print(out[0, 0, :30])
|
||||
torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-4, atol=1e-4)
|
||||
|
||||
del model
|
||||
|
Loading…
Reference in New Issue
Block a user