From 53fb245eb60364c7377c5f37fc37807a00e9b2e2 Mon Sep 17 00:00:00 2001
From: Matt
Date: Fri, 23 May 2025 18:29:31 +0100
Subject: [PATCH] :rotating_light: :rotating_light: Inherited CausalLM Tests (#37590)

* stash commit
* Experiment 1: Try just Gemma
* Experiment 1: Just try Gemma
* make fixup
* Trigger tests
* stash commit
* Try adding Gemma3 as well
* make fixup
* Correct attrib names
* Correct pipeline model mapping
* Add in all_model_classes for Gemma1 again
* Move the pipeline model mapping around again
* make fixup
* Revert Gemma3 changes since it's a VLM
* Let's try Falcon
* Correct attributes
* Correct attributes
* Let's try just overriding get_config() for now
* Do Nemotron too
* And Llama!
* Do llama/persimmon
* Correctly skip tests
* Fix Persimmon
* Include Phimoe
* Fix Gemma2
* Set model_tester_class correctly
* Add GLM
* More models!
* models models models
* make fixup
* Add Qwen3 + Qwen3MoE
* Correct import
* make fixup
* Add the QuestionAnswering classes
* Add the QuestionAnswering classes
* Move pipeline mapping to the right place
* Jetmoe too
* Stop RoPE testing models with no RoPE
* Fix up JetMOE a bit
* Fix up JetMOE a bit
* Can we just force pad_token_id all the time?
* make fixup
* fix starcoder2
* Move pipeline mapping
* Fix RoPE skipping
* Fix RecurrentGemma tests
* Fix Falcon tests
* Add MoE attributes
* Fix values for RoPE testing
* Make sure we set bos_token_id and eos_token_id in an appropriate range
* make fixup
* Fix GLM4
* Add mamba attributes
* Revert bits of JetMOE
* Re-add the JetMOE skips
* Update tests/causal_lm_tester.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Add licence

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
---
 tests/causal_lm_tester.py | 479 ++++++++++++++++++
 tests/models/dbrx/test_modeling_dbrx.py | 2 -
 tests/models/falcon/test_modeling_falcon.py | 277 +---------
 tests/models/gemma/test_modeling_gemma.py | 257 +---------
 tests/models/gemma2/test_modeling_gemma2.py | 27 +-
 tests/models/glm/test_modeling_glm.py | 248 +--------
 tests/models/glm4/test_modeling_glm4.py | 20 +-
 .../models/gpt_neox/test_modeling_gpt_neox.py | 1 -
 tests/models/jetmoe/test_modeling_jetmoe.py | 167 +-----
 tests/models/llama/test_modeling_llama.py | 358 +------------
 tests/models/mistral/test_modeling_mistral.py | 214 +-------
 tests/models/mixtral/test_modeling_mixtral.py | 223 +-------
 .../models/nemotron/test_modeling_nemotron.py | 55 +-
 .../persimmon/test_modeling_persimmon.py | 301 +----------
 tests/models/phi/test_modeling_phi.py | 277 +---------
 tests/models/phi3/test_modeling_phi3.py | 285 +----------
 tests/models/phimoe/test_modeling_phimoe.py | 285 +----------
 tests/models/qwen2/test_modeling_qwen2.py | 230 +--------
 .../qwen2_moe/test_modeling_qwen2_moe.py | 257 +---------
 tests/models/qwen3/test_modeling_qwen3.py | 234 +--------
 .../qwen3_moe/test_modeling_qwen3_moe.py | 256 +---------
 .../test_modeling_recurrent_gemma.py | 248 +++------
 .../models/stablelm/test_modeling_stablelm.py | 303 +----------
 .../starcoder2/test_modeling_starcoder2.py | 232 +--------
 tests/test_modeling_common.py | 2 +-
 25 files changed, 816 insertions(+), 4422 deletions(-)
 create mode 100644 tests/causal_lm_tester.py

diff --git a/tests/causal_lm_tester.py b/tests/causal_lm_tester.py
new file mode 100644
index 00000000000..2ef760e0fbe
--- /dev/null
+++ b/tests/causal_lm_tester.py
@@ -0,0 +1,479 @@
+# Copyright 2025 HuggingFace Inc.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +from inspect import signature + +import pytest +from parameterized import parameterized + +from transformers import set_seed +from transformers.testing_utils import ( + is_flaky, + require_flash_attn, + require_torch_accelerator, + require_torch_gpu, + require_torch_sdpa, + slow, +) + +from .test_configuration_common import ConfigTester +from .test_modeling_common import ( + GenerationTesterMixin, + ModelTesterMixin, + ids_tensor, + is_torch_available, + require_torch, + torch_device, +) +from .test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + +class CausalLMModelTester: + _required_attributes = ("base_model_class", "config_class", "causal_lm_class") + forced_config_args = [ + "pad_token_id" + ] # Arguments that should be passed to the config class even if not in its signature + config_class = None + base_model_class = None + causal_lm_class = None + sequence_classification_class = None + token_classification_class = None + question_answering_class = None + + def _verify_model_attributes(self): + for required_attribute in self._required_attributes: + if getattr(self, required_attribute) is None: + raise ValueError( + f"You have inherited from CausalLMModelTester but did not set the {required_attribute} attribute." 
+ ) + + @property + def all_model_classes(self): + return [ + model_class + for model_class in ( + self.base_model_class, + self.causal_lm_class, + self.sequence_classification_class, + self.token_classification_class, + ) + if model_class is not None + ] + + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=False, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + is_decoder=False, + scope=None, + expert_interval=1, + moe_intermediate_size=12, + shared_expert_intermediate_size=36, + shared_expert_gate=True, + num_experts_per_tok=2, + num_experts=8, + mamba_n_groups=1, + mamba_n_heads=16, + mamba_d_state=16, + mamba_d_conv=4, + mamba_expand=2, + mamba_chunk_size=16, + ): + self._verify_model_attributes() + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.scope = scope + self.head_dim = self.hidden_size // self.num_attention_heads + self.is_decoder = is_decoder + self.expert_interval = expert_interval + self.moe_intermediate_size = moe_intermediate_size + self.shared_expert_intermediate_size = shared_expert_intermediate_size + self.shared_expert_gate = shared_expert_gate + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.mamba_n_groups = mamba_n_groups + self.mamba_n_heads = mamba_n_heads + self.mamba_d_state = mamba_d_state + self.mamba_d_conv = mamba_d_conv + self.mamba_expand = mamba_expand + self.mamba_chunk_size = mamba_chunk_size + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config() + 
+ return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + + def get_config(self): + kwarg_names = list(signature(self.config_class.__init__).parameters.keys()) + kwargs = { + k: getattr(self, k) for k in kwarg_names + self.forced_config_args if hasattr(self, k) and k != "self" + } + return self.config_class(**kwargs) + + def create_and_check_model( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = self.base_model_class(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + +@require_torch +class CausalLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin): + test_headmasking = False + test_pruning = False + model_tester_class = None + all_model_classes = None + rotary_embedding_layer = None # Enables RoPE tests if set + pipeline_model_mapping = None + + def setUp(self): + if self.model_tester_class is None: + raise ValueError( + "You have inherited from CausalLMModelTest but did not set the model_tester_class attribute." + ) + self.model_tester = self.model_tester_class(self) + self.config_tester = ConfigTester(self, config_class=self.model_tester.config_class) + if self.all_model_classes is None: + self.all_model_classes = self.model_tester.all_model_classes + if self.pipeline_model_mapping is None: + raise ValueError( + "You have inherited from CausalLMModelTest but did not set the pipeline_model_mapping attribute." 
+ ) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_sequence_classification_model(self): + if self.model_tester.sequence_classification_class is None: + self.skipTest("Model does not support sequence classification") + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) + model = self.model_tester.sequence_classification_class(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + def test_sequence_classification_model_for_single_label(self): + if self.model_tester.sequence_classification_class is None: + self.skipTest("Model does not support sequence classification") + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + config.problem_type = "single_label_classification" + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) + model = self.model_tester.sequence_classification_class(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + def test_sequence_classification_model_for_multi_label(self): + if self.model_tester.sequence_classification_class is None: + self.skipTest("Model does not support sequence classification") + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + config.problem_type = "multi_label_classification" + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor( + [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size + ).to(torch.float) + model = self.model_tester.sequence_classification_class(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + def test_token_classification_model(self): + if self.model_tester.token_classification_class is None: + self.skipTest("Model does not support token classification") + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) + model = self.model_tester.token_classification_class(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=token_labels) + self.assertEqual( + result.logits.shape, + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), + ) + + 
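# [Editorial note, not part of the patch] The two base classes above are what the
# per-model test files rewritten later in this diff (Falcon, Gemma, GLM, ...) now
# inherit from. A minimal subclass looks roughly like the sketch below; Llama is
# used purely for illustration and decorators such as @require_torch are omitted.
#
#     import unittest
#
#     from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
#
#     from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
#
#     class LlamaModelTester(CausalLMModelTester):
#         config_class = LlamaConfig          # introspected by get_config()
#         base_model_class = LlamaModel       # required by _verify_model_attributes()
#         causal_lm_class = LlamaForCausalLM  # required by _verify_model_attributes()
#
#     class LlamaModelTest(CausalLMModelTest, unittest.TestCase):
#         model_tester_class = LlamaModelTester
#         pipeline_model_mapping = {
#             "feature-extraction": LlamaModel,
#             "text-generation": LlamaForCausalLM,
#         }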
@parameterized.expand([("linear",), ("dynamic",), ("yarn",)]) + def test_model_rope_scaling_from_config(self, scaling_type): + if self.rotary_embedding_layer is None: + self.skipTest("Rotary embedding layer not set") + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + short_input = ids_tensor([1, 10], config.vocab_size) + long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + original_model = self.model_tester_class.base_model_class(config) + original_model.to(torch_device) + original_model.eval() + original_short_output = original_model(short_input).last_hidden_state + original_long_output = original_model(long_input).last_hidden_state + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + config.rope_scaling = {"type": scaling_type, "factor": 10.0} + scaled_model = self.model_tester_class.base_model_class(config) + scaled_model.to(torch_device) + scaled_model.eval() + scaled_short_output = scaled_model(short_input).last_hidden_state + scaled_long_output = scaled_model(long_input).last_hidden_state + + # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original + # maximum sequence length, so the outputs for the short input should match. + if scaling_type == "dynamic": + torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5) + else: + self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + + # The output should be different for long inputs + self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + + def test_model_rope_scaling(self): + if self.rotary_embedding_layer is None: + self.skipTest("Rotary embedding layer not set") + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + scaling_factor = 10 + short_input_length = 10 + long_input_length = int(config.max_position_embeddings * 1.5) + + # Inputs + x = torch.randn( + 1, dtype=torch.float32, device=torch_device + ) # used exclusively to get the dtype and the device + position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device) + position_ids_short = position_ids_short.unsqueeze(0) + position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device) + position_ids_long = position_ids_long.unsqueeze(0) + + # Sanity check original RoPE + original_rope = self.rotary_embedding_layer(config=config).to(torch_device) + original_cos_short, original_sin_short = original_rope(x, position_ids_short) + original_cos_long, original_sin_long = original_rope(x, position_ids_long) + torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) + torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :]) + + # Sanity check linear RoPE scaling + # New position "x" should match original position with index "x/scaling_factor" + config.rope_scaling = {"type": "linear", "factor": scaling_factor} + linear_scaling_rope = self.rotary_embedding_layer(config=config).to(torch_device) + linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short) + linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long) + torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :]) + torch.testing.assert_close(linear_sin_short, linear_sin_long[:, 
:short_input_length, :]) + for new_position in range(0, long_input_length, scaling_factor): + original_position = int(new_position // scaling_factor) + torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :]) + torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :]) + + # Sanity check Dynamic NTK RoPE scaling + # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase + # with scaling_factor (or that `inv_freq` decreases) + config.rope_scaling = {"type": "dynamic", "factor": scaling_factor} + ntk_scaling_rope = self.rotary_embedding_layer(config=config).to(torch_device) + ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short) + ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long) + torch.testing.assert_close(ntk_cos_short, original_cos_short) + torch.testing.assert_close(ntk_sin_short, original_sin_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(ntk_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(ntk_sin_long, original_sin_long) + self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) + + # Sanity check Yarn RoPE scaling + # Scaling should be over the entire input + config.rope_scaling = {"type": "yarn", "factor": scaling_factor} + yarn_scaling_rope = self.rotary_embedding_layer(config=config).to(torch_device) + yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short) + yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long) + torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :]) + torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :]) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_cos_short, original_cos_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_sin_short, original_sin_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(yarn_sin_long, original_sin_long) + + @require_torch_sdpa + @require_torch_accelerator + @slow + def test_sdpa_equivalence(self): + for model_class in self.all_model_classes: + if not model_class._supports_sdpa: + self.skipTest(reason="Model does not support SDPA") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_sdpa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="sdpa" + ) + model_sdpa.to(torch_device) + + model = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="eager" + ) + model.to(torch_device) + + dummy_input = inputs_dict[model_class.main_input_name] + dummy_input = dummy_input.to(torch_device) + outputs = model(dummy_input, output_hidden_states=True) + outputs_sdpa = model_sdpa(dummy_input, output_hidden_states=True) + + logits = outputs.hidden_states[-1] + logits_sdpa = outputs_sdpa.hidden_states[-1] + + assert torch.allclose(logits_sdpa, logits, atol=2e-3) + + @require_flash_attn + @require_torch_gpu + @pytest.mark.flash_attn_test + @is_flaky() + @slow + def test_flash_attn_2_equivalence(self): + for model_class in self.all_model_classes: + if not 
model_class._supports_flash_attn_2: + self.skipTest(reason="Model does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="eager" + ) + model.to(torch_device) + + dummy_input = inputs_dict[model_class.main_input_name] + dummy_input = dummy_input.to(torch_device) + outputs = model(dummy_input, output_hidden_states=True) + outputs_fa = model_fa(dummy_input, output_hidden_states=True) + + logits = outputs.hidden_states[-1] + logits_fa = outputs_fa.hidden_states[-1] + + assert torch.allclose(logits_fa, logits, atol=2e-3) diff --git a/tests/models/dbrx/test_modeling_dbrx.py b/tests/models/dbrx/test_modeling_dbrx.py index 512bd6a02c0..7b6cfc5b081 100644 --- a/tests/models/dbrx/test_modeling_dbrx.py +++ b/tests/models/dbrx/test_modeling_dbrx.py @@ -179,7 +179,6 @@ class DbrxModelTester: ) return config - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Dbrx def create_and_check_model( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): @@ -190,7 +189,6 @@ class DbrxModelTester: result = model(input_ids) self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common with Llama->Dbrx def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( diff --git a/tests/models/falcon/test_modeling_falcon.py b/tests/models/falcon/test_modeling_falcon.py index 6a63177476b..661ba98cf16 100644 --- a/tests/models/falcon/test_modeling_falcon.py +++ b/tests/models/falcon/test_modeling_falcon.py @@ -15,14 +15,11 @@ import unittest -from parameterized import parameterized - from transformers import ( AutoModelForCausalLM, AutoTokenizer, FalconConfig, is_torch_available, - set_seed, ) from transformers.testing_utils import ( require_bitsandbytes, @@ -32,10 +29,7 @@ from transformers.testing_utils import ( torch_device, ) -from ...generation.test_utils import GenerationTesterMixin -from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask -from ...test_pipeline_mixin import PipelineTesterMixin +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester if is_torch_available(): @@ -48,126 +42,24 @@ if is_torch_available(): FalconForTokenClassification, FalconModel, ) - from transformers.models.falcon.modeling_falcon import ( - FalconRotaryEmbedding, - ) -class FalconModelTester: - def __init__( - self, - parent, - batch_size=3, - seq_length=7, - is_training=True, - use_input_mask=True, - use_token_type_ids=False, - use_labels=True, - vocab_size=99, - hidden_size=32, - num_hidden_layers=2, - num_attention_heads=4, - intermediate_size=37, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=16, - type_sequence_label_size=2, - initializer_range=0.02, - num_labels=3, - num_choices=4, - scope=None, - ): - self.parent 
= parent - self.batch_size = batch_size - self.seq_length = seq_length - self.is_training = is_training - self.use_input_mask = use_input_mask - self.use_token_type_ids = use_token_type_ids - self.use_labels = use_labels - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.intermediate_size = intermediate_size - self.hidden_act = hidden_act - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.type_sequence_label_size = type_sequence_label_size - self.initializer_range = initializer_range - self.num_labels = num_labels - self.num_choices = num_choices - self.scope = scope +class FalconModelTester(CausalLMModelTester): + if is_torch_available(): + config_class = FalconConfig + base_model_class = FalconModel + causal_lm_class = FalconForCausalLM + sequence_class = FalconForSequenceClassification + token_class = FalconForTokenClassification - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) - - input_mask = None - if self.use_input_mask: - input_mask = random_attention_mask([self.batch_size, self.seq_length]) - - token_type_ids = None - - sequence_labels = None - token_labels = None - choice_labels = None - if self.use_labels: - sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) - token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) - choice_labels = ids_tensor([self.batch_size], self.num_choices) - - config = self.get_config() - - return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - - def get_config(self): - return FalconConfig( - vocab_size=self.vocab_size, - hidden_size=self.hidden_size, - num_hidden_layers=self.num_hidden_layers, - num_attention_heads=self.num_attention_heads, - intermediate_size=self.intermediate_size, - hidden_act=self.hidden_act, - hidden_dropout_prob=self.hidden_dropout_prob, - attention_probs_dropout_prob=self.attention_probs_dropout_prob, - max_position_embeddings=self.max_position_embeddings, - type_vocab_size=self.type_vocab_size, - is_decoder=False, - initializer_range=self.initializer_range, - pad_token_id=1, - new_decoder_architecture=True, - ) - - def create_and_check_model( - self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - ): - model = FalconModel(config=config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=input_mask) - result = model(input_ids) - self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) - - def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() - ( - config, - input_ids, - token_type_ids, - input_mask, - sequence_labels, - token_labels, - choice_labels, - ) = config_and_inputs - inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} - return config, inputs_dict + def __init__(self, parent, new_decoder_architecture=True): + super().__init__(parent) + self.new_decoder_architecture = new_decoder_architecture @require_torch -class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): +class FalconModelTest(CausalLMModelTest, unittest.TestCase): + 
model_tester_class = FalconModelTester all_model_classes = ( ( FalconModel, @@ -182,10 +74,9 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix pipeline_model_mapping = ( { "feature-extraction": FalconModel, - "question-answering": FalconForQuestionAnswering, "text-classification": FalconForSequenceClassification, - "text-generation": FalconForCausalLM, "token-classification": FalconForTokenClassification, + "text-generation": FalconForCausalLM, "zero-shot": FalconForSequenceClassification, } if is_torch_available() @@ -207,146 +98,6 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix ): return True - def setUp(self): - self.model_tester = FalconModelTester(self) - self.config_tester = ConfigTester(self, config_class=FalconConfig, hidden_size=37) - - def test_config(self): - self.config_tester.run_common_tests() - - def test_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model(*config_and_inputs) - - def test_position_embedding_types(self): - config, *inputs = self.model_tester.prepare_config_and_inputs() - for alibi in [True, False]: - config.alibi = alibi - self.model_tester.create_and_check_model(config, *inputs) - - def test_falcon_sequence_classification_model(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = FalconForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_falcon_sequence_classification_model_for_single_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "single_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = FalconForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_falcon_sequence_classification_model_for_multi_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "multi_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor( - [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size - ).to(torch.float) - model = FalconForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - @parameterized.expand([("linear",), ("dynamic",)]) - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->Falcon - def 
test_model_rope_scaling_from_config(self, scaling_type): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - short_input = ids_tensor([1, 10], config.vocab_size) - long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) - - set_seed(42) # Fixed seed at init time so the two models get the same random weights - original_model = FalconModel(config) - original_model.to(torch_device) - original_model.eval() - original_short_output = original_model(short_input).last_hidden_state - original_long_output = original_model(long_input).last_hidden_state - - set_seed(42) # Fixed seed at init time so the two models get the same random weights - config.rope_scaling = {"type": scaling_type, "factor": 10.0} - scaled_model = FalconModel(config) - scaled_model.to(torch_device) - scaled_model.eval() - scaled_short_output = scaled_model(short_input).last_hidden_state - scaled_long_output = scaled_model(long_input).last_hidden_state - - # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original - # maximum sequence length, so the outputs for the short input should match. - if scaling_type == "dynamic": - torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5) - else: - self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) - - # The output should be different for long inputs - self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) - - # Copied from tests.models.gpt_neox.test_modeling_gpt_neox.GPTNeoXModelTest.test_model_rope_scaling with GPTNeoX->Falcon - def test_model_rope_scaling(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - scaling_factor = 10 - short_input_length = 10 - long_input_length = int(config.max_position_embeddings * 1.5) - - # Inputs - x = torch.randn( - 1, dtype=torch.float32, device=torch_device - ) # used exclusively to get the dtype and the device - position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device) - position_ids_short = position_ids_short.unsqueeze(0) - position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device) - position_ids_long = position_ids_long.unsqueeze(0) - - # Sanity check original RoPE - original_rope = FalconRotaryEmbedding(config).to(torch_device) - original_cos_short, original_sin_short = original_rope(x, position_ids_short) - original_cos_long, original_sin_long = original_rope(x, position_ids_long) - torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) - torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :]) - - # Sanity check linear RoPE scaling - # New position "x" should match original position with index "x/scaling_factor" - config.rope_scaling = {"type": "linear", "factor": scaling_factor} - linear_scaling_rope = FalconRotaryEmbedding(config).to(torch_device) - linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short) - linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long) - torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :]) - torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :]) - for new_position in range(0, long_input_length, scaling_factor): - original_position = int(new_position // scaling_factor) - torch.testing.assert_close(linear_cos_long[:, new_position, :], 
original_cos_long[:, original_position, :]) - torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :]) - - # Sanity check Dynamic NTK RoPE scaling - # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase - # with scaling_factor (or that `inv_freq` decreases) - config.rope_scaling = {"type": "dynamic", "factor": scaling_factor} - ntk_scaling_rope = FalconRotaryEmbedding(config).to(torch_device) - ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short) - ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long) - torch.testing.assert_close(ntk_cos_short, original_cos_short) - torch.testing.assert_close(ntk_sin_short, original_sin_short) - with self.assertRaises(AssertionError): - torch.testing.assert_close(ntk_cos_long, original_cos_long) - with self.assertRaises(AssertionError): - torch.testing.assert_close(ntk_sin_long, original_sin_long) - self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) - @require_torch class FalconLanguageGenerationTest(unittest.TestCase): diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py index 373b3ffbe22..649c837b9c2 100644 --- a/tests/models/gemma/test_modeling_gemma.py +++ b/tests/models/gemma/test_modeling_gemma.py @@ -33,10 +33,7 @@ from transformers.testing_utils import ( torch_device, ) -from ...generation.test_utils import GenerationTesterMixin -from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, ids_tensor -from ...test_pipeline_mixin import PipelineTesterMixin +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester if is_torch_available(): @@ -51,138 +48,17 @@ if is_torch_available(): @require_torch -class GemmaModelTester: +class GemmaModelTester(CausalLMModelTester): config_class = GemmaConfig if is_torch_available(): - model_class = GemmaModel - for_causal_lm_class = GemmaForCausalLM - for_sequence_class = GemmaForSequenceClassification - for_token_class = GemmaForTokenClassification - - def __init__( - self, - parent, - batch_size=13, - seq_length=7, - is_training=True, - use_input_mask=True, - use_token_type_ids=False, - use_labels=True, - vocab_size=99, - hidden_size=32, - num_hidden_layers=2, - num_attention_heads=4, - num_key_value_heads=2, - intermediate_size=37, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=16, - type_sequence_label_size=2, - initializer_range=0.02, - num_labels=3, - num_choices=4, - pad_token_id=0, - scope=None, - ): - self.parent = parent - self.batch_size = batch_size - self.seq_length = seq_length - self.is_training = is_training - self.use_input_mask = use_input_mask - self.use_token_type_ids = use_token_type_ids - self.use_labels = use_labels - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.num_key_value_heads = num_key_value_heads - self.intermediate_size = intermediate_size - self.hidden_act = hidden_act - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.type_sequence_label_size = type_sequence_label_size - self.initializer_range = initializer_range - self.num_labels = num_labels - 
self.num_choices = num_choices - self.pad_token_id = pad_token_id - self.scope = scope - self.head_dim = self.hidden_size // self.num_attention_heads - - # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.prepare_config_and_inputs - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) - - input_mask = None - if self.use_input_mask: - input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) - - token_type_ids = None - if self.use_token_type_ids: - token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) - - sequence_labels = None - token_labels = None - choice_labels = None - if self.use_labels: - sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) - token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) - choice_labels = ids_tensor([self.batch_size], self.num_choices) - - config = self.get_config() - - return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - - def get_config(self): - return self.config_class( - vocab_size=self.vocab_size, - hidden_size=self.hidden_size, - num_hidden_layers=self.num_hidden_layers, - num_attention_heads=self.num_attention_heads, - num_key_value_heads=self.num_key_value_heads, - intermediate_size=self.intermediate_size, - hidden_act=self.hidden_act, - hidden_dropout_prob=self.hidden_dropout_prob, - attention_probs_dropout_prob=self.attention_probs_dropout_prob, - max_position_embeddings=self.max_position_embeddings, - type_vocab_size=self.type_vocab_size, - is_decoder=False, - initializer_range=self.initializer_range, - pad_token_id=self.pad_token_id, - head_dim=self.head_dim, - ) - - def create_and_check_model( - self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - ): - model = self.model_class(config=config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=input_mask) - result = model(input_ids) - self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common with Llama->Gemma - def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() - ( - config, - input_ids, - token_type_ids, - input_mask, - sequence_labels, - token_labels, - choice_labels, - ) = config_and_inputs - inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} - return config, inputs_dict + base_model_class = GemmaModel + causal_lm_class = GemmaForCausalLM + sequence_classification_class = GemmaForSequenceClassification + token_classification_class = GemmaForTokenClassification @require_torch -class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): +class GemmaModelTest(CausalLMModelTest, unittest.TestCase): all_model_classes = ( (GemmaModel, GemmaForCausalLM, GemmaForSequenceClassification, GemmaForTokenClassification) if is_torch_available() @@ -199,12 +75,7 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi if is_torch_available() else {} ) - test_headmasking = False - test_pruning = False - - # Need to remove 0.9 in `test_cpu_offload` - # This is because we are hitting edge cases with the causal_mask buffer - model_split_percents = [0.5, 0.6] + model_tester_class = GemmaModelTester # used in 
`test_torch_compile_for_training` _torch_compile_train_cls = GemmaForCausalLM if is_torch_available() else None @@ -222,78 +93,6 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi ): return True - def setUp(self): - self.model_tester = GemmaModelTester(self) - self.config_tester = ConfigTester(self, config_class=GemmaConfig, hidden_size=37) - - def test_config(self): - self.config_tester.run_common_tests() - - def test_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model(*config_and_inputs) - - def test_model_various_embeddings(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - for type in ["absolute", "relative_key", "relative_key_query"]: - config_and_inputs[0].position_embedding_type = type - self.model_tester.create_and_check_model(*config_and_inputs) - - def test_Gemma_sequence_classification_model(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = self.model_tester.for_sequence_class(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_Gemma_sequence_classification_model_for_single_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "single_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = self.model_tester.for_sequence_class(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_Gemma_sequence_classification_model_for_multi_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "multi_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor( - [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size - ).to(torch.float) - model = self.model_tester.for_sequence_class(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_Gemma_token_classification_model(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) - model = self.model_tester.for_token_class(config=config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=token_labels) - self.assertEqual( - 
result.logits.shape, - (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), - ) - @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test @@ -301,46 +100,6 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest(reason="Gemma flash attention does not support right padding") - @require_torch_sdpa - @require_torch_accelerator - def test_sdpa_equivalence(self): - for model_class in self.all_model_classes: - if not model_class._supports_sdpa: - self.skipTest(reason="Model does not support SDPA") - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - model = model_class(config).to(torch_device) - dummy_input = inputs_dict[model_class.main_input_name].to(torch_device) - - model.config._attn_implementation = "sdpa" - states_sdpa = model(dummy_input, output_hidden_states=True).hidden_states[-1] - - model.config._attn_implementation = "eager" - states_eager = model(dummy_input, output_hidden_states=True).hidden_states[-1] - - torch.testing.assert_close(states_sdpa, states_eager, atol=1e-5, rtol=1e-5) - - @require_flash_attn - @require_torch_gpu - @pytest.mark.flash_attn_test - def test_flash_attn_2_equivalence(self): - for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: - self.skipTest(reason="Model does not support Flash Attention 2") - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - model = model_class(config).to(device=torch_device, dtype=torch.float16) - dummy_input = inputs_dict[model_class.main_input_name].to(torch_device) - - model.config._attn_implementation = "flash_attention_2" - states_sdpa = model(dummy_input, output_hidden_states=True).hidden_states[1] - - model.config._attn_implementation = "eager" - states_eager = model(dummy_input, output_hidden_states=True).hidden_states[1] - - # Here we use higher tolerance and the output of the 2nd layer because otherwise small diffs add-up - torch.testing.assert_close(states_sdpa, states_eager, atol=1e-3, rtol=1e-3) - @slow @require_torch_accelerator diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index 2561875f387..cb98a5a0e69 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -33,7 +33,7 @@ from transformers.testing_utils import ( torch_device, ) -from ...models.gemma.test_modeling_gemma import GemmaModelTest, GemmaModelTester +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester from ...test_configuration_common import ConfigTester @@ -48,17 +48,28 @@ if is_torch_available(): ) -class Gemma2ModelTester(GemmaModelTester): +class Gemma2ModelTester(CausalLMModelTester): if is_torch_available(): config_class = Gemma2Config - model_class = Gemma2Model - for_causal_lm_class = Gemma2ForCausalLM - for_sequence_class = Gemma2ForSequenceClassification - for_token_class = Gemma2ForTokenClassification + base_model_class = Gemma2Model + causal_lm_class = Gemma2ForCausalLM + sequence_class = Gemma2ForSequenceClassification + token_class = Gemma2ForTokenClassification + pipeline_model_mapping = ( + { + "feature-extraction": Gemma2Model, + "text-classification": Gemma2ForSequenceClassification, + "token-classification": Gemma2ForTokenClassification, + "text-generation": Gemma2ForCausalLM, + "zero-shot": Gemma2ForSequenceClassification, + } + if is_torch_available() + else {} 
+ ) @require_torch -class Gemma2ModelTest(GemmaModelTest, unittest.TestCase): +class Gemma2ModelTest(CausalLMModelTest, unittest.TestCase): all_model_classes = ( (Gemma2Model, Gemma2ForCausalLM, Gemma2ForSequenceClassification, Gemma2ForTokenClassification) if is_torch_available() @@ -75,10 +86,12 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase): if is_torch_available() else {} ) + test_headmasking = False test_pruning = False _is_stateful = True model_split_percents = [0.5, 0.6] + model_tester_class = Gemma2ModelTester def setUp(self): self.model_tester = Gemma2ModelTester(self) diff --git a/tests/models/glm/test_modeling_glm.py b/tests/models/glm/test_modeling_glm.py index 9e8eda5cb23..e246ea867a0 100644 --- a/tests/models/glm/test_modeling_glm.py +++ b/tests/models/glm/test_modeling_glm.py @@ -19,7 +19,6 @@ import pytest from transformers import AutoModelForCausalLM, AutoTokenizer, GlmConfig, is_torch_available from transformers.testing_utils import ( - is_flaky, require_flash_attn, require_torch, require_torch_large_accelerator, @@ -28,10 +27,7 @@ from transformers.testing_utils import ( torch_device, ) -from ...generation.test_utils import GenerationTesterMixin -from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, ids_tensor -from ...test_pipeline_mixin import PipelineTesterMixin +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester if is_torch_available(): @@ -46,133 +42,17 @@ if is_torch_available(): @require_torch -class GlmModelTester: +class GlmModelTester(CausalLMModelTester): config_class = GlmConfig if is_torch_available(): - model_class = GlmModel - for_causal_lm_class = GlmForCausalLM - for_sequence_class = GlmForSequenceClassification - for_token_class = GlmForTokenClassification - - def __init__( - self, - parent, - batch_size=13, - seq_length=7, - is_training=True, - use_input_mask=True, - use_token_type_ids=False, - use_labels=True, - vocab_size=99, - hidden_size=32, - num_hidden_layers=2, - num_attention_heads=4, - num_key_value_heads=2, - intermediate_size=37, - hidden_act="silu", - attention_dropout=0.1, - max_position_embeddings=512, - type_vocab_size=16, - type_sequence_label_size=2, - initializer_range=0.02, - num_labels=3, - num_choices=4, - pad_token_id=0, - scope=None, - ): - self.parent = parent - self.batch_size = batch_size - self.seq_length = seq_length - self.is_training = is_training - self.use_input_mask = use_input_mask - self.use_token_type_ids = use_token_type_ids - self.use_labels = use_labels - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.num_key_value_heads = num_key_value_heads - self.intermediate_size = intermediate_size - self.hidden_act = hidden_act - self.attention_dropout = attention_dropout - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.type_sequence_label_size = type_sequence_label_size - self.initializer_range = initializer_range - self.num_labels = num_labels - self.num_choices = num_choices - self.pad_token_id = pad_token_id - self.scope = scope - self.head_dim = self.hidden_size // self.num_attention_heads - - # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.prepare_config_and_inputs - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) - - input_mask = None - if self.use_input_mask: - 
input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) - - token_type_ids = None - if self.use_token_type_ids: - token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) - - sequence_labels = None - token_labels = None - choice_labels = None - if self.use_labels: - sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) - token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) - choice_labels = ids_tensor([self.batch_size], self.num_choices) - - config = self.get_config() - - return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - - def get_config(self): - return self.config_class( - vocab_size=self.vocab_size, - hidden_size=self.hidden_size, - num_hidden_layers=self.num_hidden_layers, - num_attention_heads=self.num_attention_heads, - num_key_value_heads=self.num_key_value_heads, - intermediate_size=self.intermediate_size, - hidden_act=self.hidden_act, - attention_dropout=self.attention_dropout, - max_position_embeddings=self.max_position_embeddings, - initializer_range=self.initializer_range, - pad_token_id=self.pad_token_id, - head_dim=self.head_dim, - ) - - def create_and_check_model( - self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - ): - model = self.model_class(config=config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=input_mask) - result = model(input_ids) - self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common with Llama->Glm - def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() - ( - config, - input_ids, - token_type_ids, - input_mask, - sequence_labels, - token_labels, - choice_labels, - ) = config_and_inputs - inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} - return config, inputs_dict + base_model_class = GlmModel + causal_lm_class = GlmForCausalLM + sequence_class = GlmForSequenceClassification + token_class = GlmForTokenClassification @require_torch -class GlmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): +class GlmModelTest(CausalLMModelTest, unittest.TestCase): all_model_classes = ( (GlmModel, GlmForCausalLM, GlmForSequenceClassification, GlmForTokenClassification) if is_torch_available() @@ -188,120 +68,10 @@ class GlmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, if is_torch_available() else {} ) + test_headmasking = False test_pruning = False - - def setUp(self): - self.model_tester = GlmModelTester(self) - self.config_tester = ConfigTester(self, config_class=GlmConfig, hidden_size=37) - - def test_config(self): - self.config_tester.run_common_tests() - - def test_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model(*config_and_inputs) - - def test_model_various_embeddings(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - for type in ["absolute", "relative_key", "relative_key_query"]: - config_and_inputs[0].position_embedding_type = type - self.model_tester.create_and_check_model(*config_and_inputs) - - def test_Glm_sequence_classification_model(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - print(config) - 
config.num_labels = 3 - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = self.model_tester.for_sequence_class(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_Glm_sequence_classification_model_for_single_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "single_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = self.model_tester.for_sequence_class(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_Glm_sequence_classification_model_for_multi_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "multi_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor( - [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size - ).to(torch.float) - model = self.model_tester.for_sequence_class(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_Glm_token_classification_model(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) - model = self.model_tester.for_token_class(config=config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=token_labels) - self.assertEqual( - result.logits.shape, - (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), - ) - - @is_flaky() - def test_custom_4d_attention_mask(self): - """Overwrite the common test to use atol=1e-3 instead of 1e-4. 
Can still rarely fail, thus flaky.""" - for model_class in self.all_generative_model_classes: - if not model_class._supports_static_cache: - self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks") - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - if getattr(config, "sliding_window", 0) is not None and getattr(config, "sliding_window", 0) > 0: - self.skipTest(f"{model_class.__name__} with sliding window attention is not supported by this test") - model = model_class(config).to(device=torch_device, dtype=torch.float32) - - ( - input_ids, - position_ids, - input_ids_shared_prefix, - mask_shared_prefix, - position_ids_shared_prefix, - ) = self._get_custom_4d_mask_test_data() - - logits = model.forward(input_ids, position_ids=position_ids).logits - # logits.shape == torch.Size([3, 4, ...]) - - logits_shared_prefix = model( - input_ids_shared_prefix, - attention_mask=mask_shared_prefix, - position_ids=position_ids_shared_prefix, - )[0] - # logits_shared_prefix.shape == torch.Size([1, 6, ...]) - - out_last_tokens = logits[:, -1, :] # last tokens in each batch line - out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens - - # comparing softmax-normalized logits: - normalized_0 = torch.nn.functional.softmax(out_last_tokens) - normalized_1 = torch.nn.functional.softmax(out_shared_prefix_last_tokens) - print(torch.abs(normalized_0 - normalized_1).max()) - - torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-3) + model_tester_class = GlmModelTester @slow diff --git a/tests/models/glm4/test_modeling_glm4.py b/tests/models/glm4/test_modeling_glm4.py index 547b696867d..295954fe20c 100644 --- a/tests/models/glm4/test_modeling_glm4.py +++ b/tests/models/glm4/test_modeling_glm4.py @@ -28,8 +28,7 @@ from transformers.testing_utils import ( torch_device, ) -from ...models.gemma.test_modeling_gemma import GemmaModelTest, GemmaModelTester -from ...test_configuration_common import ConfigTester +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester if is_torch_available(): @@ -43,17 +42,18 @@ if is_torch_available(): ) -class Glm4ModelTester(GemmaModelTester): +class Glm4ModelTester(CausalLMModelTester): if is_torch_available(): config_class = Glm4Config - model_class = Glm4Model - for_causal_lm_class = Glm4ForCausalLM - for_sequence_class = Glm4ForSequenceClassification - for_token_class = Glm4ForTokenClassification + base_model_class = Glm4Model + causal_lm_class = Glm4ForCausalLM + sequence_classification_class = Glm4ForSequenceClassification + token_classification_class = Glm4ForTokenClassification @require_torch -class Glm4ModelTest(GemmaModelTest, unittest.TestCase): +class Glm4ModelTest(CausalLMModelTest, unittest.TestCase): + model_tester_class = Glm4ModelTester all_model_classes = ( (Glm4Model, Glm4ForCausalLM, Glm4ForSequenceClassification, Glm4ForTokenClassification) if is_torch_available() @@ -75,10 +75,6 @@ class Glm4ModelTest(GemmaModelTest, unittest.TestCase): _is_stateful = True model_split_percents = [0.5, 0.6] - def setUp(self): - self.model_tester = Glm4ModelTester(self) - self.config_tester = ConfigTester(self, config_class=Glm4Config, hidden_size=37) - @slow @require_torch_large_gpu diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py index 33c79f2a7b1..b0a0a6a3ccb 100644 --- a/tests/models/gpt_neox/test_modeling_gpt_neox.py +++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py @@ -341,7 +341,6 @@ class 
GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi pass @parameterized.expand([("linear",), ("dynamic",)]) - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->GPTNeoX def test_model_rope_scaling_from_config(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() short_input = ids_tensor([1, 10], config.vocab_size) diff --git a/tests/models/jetmoe/test_modeling_jetmoe.py b/tests/models/jetmoe/test_modeling_jetmoe.py index 0dfc7e2cef9..7dd6ca728af 100644 --- a/tests/models/jetmoe/test_modeling_jetmoe.py +++ b/tests/models/jetmoe/test_modeling_jetmoe.py @@ -28,10 +28,7 @@ from transformers.testing_utils import ( torch_device, ) -from ...generation.test_utils import GenerationTesterMixin -from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, ids_tensor -from ...test_pipeline_mixin import PipelineTesterMixin +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester if is_torch_available(): @@ -44,7 +41,14 @@ if is_torch_available(): ) -class JetMoeModelTester: +class JetMoeModelTester(CausalLMModelTester): + config_class = JetMoeConfig + forced_config_args = ["pad_token_id"] + if is_torch_available(): + base_model_class = JetMoeModel + causal_lm_class = JetMoeForCausalLM + sequence_class = JetMoeForSequenceClassification + def __init__( self, parent, @@ -72,6 +76,7 @@ class JetMoeModelTester: pad_token_id=0, scope=None, ): + super().__init__(parent) self.parent = parent self.batch_size = batch_size self.seq_length = seq_length @@ -98,158 +103,28 @@ class JetMoeModelTester: self.pad_token_id = pad_token_id self.scope = scope - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) - - input_mask = None - if self.use_input_mask: - input_mask = torch.ones(self.batch_size, self.seq_length).to(torch_device) - - token_type_ids = None - if self.use_token_type_ids: - token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) - - sequence_labels = None - token_labels = None - choice_labels = None - if self.use_labels: - sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) - token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) - choice_labels = ids_tensor([self.batch_size], self.num_choices) - - config = self.get_config() - - return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - - def get_config(self): - return JetMoeConfig( - vocab_size=self.vocab_size, - hidden_size=self.hidden_size, - num_hidden_layers=self.num_hidden_layers, - num_key_value_heads=self.num_key_value_heads, - kv_channels=self.kv_channels, - intermediate_size=self.intermediate_size, - activation_function=self.hidden_act, - num_local_experts=self.num_local_experts, - num_experts_per_tok=self.num_experts_per_tok, - max_position_embeddings=self.max_position_embeddings, - type_vocab_size=self.type_vocab_size, - is_decoder=False, - initializer_range=self.initializer_range, - pad_token_id=self.pad_token_id, - ) - - def create_and_check_model( - self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - ): - model = JetMoeModel(config=config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=input_mask) - result = model(input_ids) - self.parent.assertEqual(result.last_hidden_state.shape, 
(self.batch_size, self.seq_length, self.hidden_size)) - - def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() - ( - config, - input_ids, - token_type_ids, - input_mask, - sequence_labels, - token_labels, - choice_labels, - ) = config_and_inputs - inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} - return config, inputs_dict - @require_torch -class JetMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): +class JetMoeModelTest(CausalLMModelTest, unittest.TestCase): all_model_classes = ( (JetMoeModel, JetMoeForCausalLM, JetMoeForSequenceClassification) if is_torch_available() else () ) - pipeline_model_mapping = ( - { - "feature-extraction": JetMoeModel, - "text-classification": JetMoeForSequenceClassification, - "text-generation": JetMoeForCausalLM, - "zero-shot": JetMoeForSequenceClassification, - } - if is_torch_available() - else {} - ) test_headmasking = False test_pruning = False test_mismatched_shapes = False test_cpu_offload = False test_disk_offload_bin = False test_disk_offload_safetensors = False - - def setUp(self): - self.model_tester = JetMoeModelTester(self) - self.config_tester = ConfigTester( - self, config_class=JetMoeConfig, common_properties=["hidden_size", "num_hidden_layers"] - ) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_config - def test_config(self): - self.config_tester.run_common_tests() - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model - def test_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model(*config_and_inputs) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_various_embeddings - def test_model_various_embeddings(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - for type in ["absolute", "relative_key", "relative_key_query"]: - config_and_inputs[0].position_embedding_type = type - self.model_tester.create_and_check_model(*config_and_inputs) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model with llama->jetmoe, Llama->JetMoe - def test_jetmoe_sequence_classification_model(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = JetMoeForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_single_label with llama->jetmoe, Llama->JetMoe - def test_jetmoe_sequence_classification_model_for_single_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "single_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = JetMoeForSequenceClassification(config) - 
model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_multi_label with llama->jetmoe, Llama->JetMoe - def test_jetmoe_sequence_classification_model_for_multi_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "multi_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor( - [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size - ).to(torch.float) - model = JetMoeForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + model_tester_class = JetMoeModelTester + pipeline_model_mapping = ( + { + "feature-extraction": JetMoeModel, + "text-classification": JetMoeForSequenceClassification, + "text-generation": JetMoeForCausalLM, + } + if is_torch_available() + else {} + ) @require_flash_attn @require_torch_gpu diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index c9b86c128ca..a1e6c944470 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -16,9 +16,8 @@ import unittest from packaging import version -from parameterized import parameterized -from transformers import AutoTokenizer, LlamaConfig, StaticCache, is_torch_available, set_seed +from transformers import AutoTokenizer, StaticCache, is_torch_available from transformers.generation.configuration_utils import GenerationConfig from transformers.testing_utils import ( Expectations, @@ -30,16 +29,14 @@ from transformers.testing_utils import ( torch_device, ) -from ...generation.test_utils import GenerationTesterMixin -from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, ids_tensor -from ...test_pipeline_mixin import PipelineTesterMixin +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester if is_torch_available(): import torch from transformers import ( + LlamaConfig, LlamaForCausalLM, LlamaForQuestionAnswering, LlamaForSequenceClassification, @@ -50,124 +47,17 @@ if is_torch_available(): from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding -class LlamaModelTester: - def __init__( - self, - parent, - batch_size=13, - seq_length=7, - is_training=True, - use_input_mask=True, - use_token_type_ids=False, - use_labels=True, - vocab_size=99, - hidden_size=32, - num_hidden_layers=2, - num_attention_heads=4, - intermediate_size=37, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=16, - type_sequence_label_size=2, - initializer_range=0.02, - num_labels=3, - num_choices=4, - pad_token_id=0, - scope=None, - ): - self.parent = parent - self.batch_size = batch_size - self.seq_length = seq_length - self.is_training = is_training - self.use_input_mask = use_input_mask - self.use_token_type_ids = use_token_type_ids - self.use_labels = use_labels - self.vocab_size = vocab_size - self.hidden_size 
= hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.intermediate_size = intermediate_size - self.hidden_act = hidden_act - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.type_sequence_label_size = type_sequence_label_size - self.initializer_range = initializer_range - self.num_labels = num_labels - self.num_choices = num_choices - self.pad_token_id = pad_token_id - self.scope = scope - - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) - - input_mask = None - if self.use_input_mask: - input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) - - token_type_ids = None - if self.use_token_type_ids: - token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) - - sequence_labels = None - token_labels = None - choice_labels = None - if self.use_labels: - sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) - token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) - choice_labels = ids_tensor([self.batch_size], self.num_choices) - - config = self.get_config() - - return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - - def get_config(self): - return LlamaConfig( - vocab_size=self.vocab_size, - hidden_size=self.hidden_size, - num_hidden_layers=self.num_hidden_layers, - num_attention_heads=self.num_attention_heads, - intermediate_size=self.intermediate_size, - hidden_act=self.hidden_act, - hidden_dropout_prob=self.hidden_dropout_prob, - attention_probs_dropout_prob=self.attention_probs_dropout_prob, - max_position_embeddings=self.max_position_embeddings, - type_vocab_size=self.type_vocab_size, - is_decoder=False, - initializer_range=self.initializer_range, - pad_token_id=self.pad_token_id, - ) - - def create_and_check_model( - self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - ): - model = LlamaModel(config=config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=input_mask) - result = model(input_ids) - self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) - - def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() - ( - config, - input_ids, - token_type_ids, - input_mask, - sequence_labels, - token_labels, - choice_labels, - ) = config_and_inputs - inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} - return config, inputs_dict +class LlamaModelTester(CausalLMModelTester): + if is_torch_available(): + config_class = LlamaConfig + base_model_class = LlamaModel + causal_lm_class = LlamaForCausalLM + sequence_class = LlamaForSequenceClassification + token_class = LlamaForTokenClassification @require_torch -class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): +class LlamaModelTest(CausalLMModelTest, unittest.TestCase): all_model_classes = ( ( LlamaModel, @@ -194,6 +84,8 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi test_headmasking = False test_pruning = False fx_compatible = False # Broken by attention refactor cc @Cyrilvallez + model_tester_class = LlamaModelTester + 
rotary_embedding_layer = LlamaRotaryEmbedding # Enables RoPE tests if set # Need to use `0.8` instead of `0.9` for `test_cpu_offload` # This is because we are hitting edge cases with the causal_mask buffer @@ -202,230 +94,6 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi # used in `test_torch_compile_for_training` _torch_compile_train_cls = LlamaForCausalLM if is_torch_available() else None - def setUp(self): - self.model_tester = LlamaModelTester(self) - self.config_tester = ConfigTester(self, config_class=LlamaConfig, hidden_size=37) - - def test_config(self): - self.config_tester.run_common_tests() - - def test_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model(*config_and_inputs) - - def test_model_various_embeddings(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - for type in ["absolute", "relative_key", "relative_key_query"]: - config_and_inputs[0].position_embedding_type = type - self.model_tester.create_and_check_model(*config_and_inputs) - - def test_llama_sequence_classification_model(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = LlamaForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_llama_sequence_classification_model_for_single_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "single_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = LlamaForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_llama_sequence_classification_model_for_multi_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "multi_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor( - [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size - ).to(torch.float) - model = LlamaForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_llama_token_classification_model(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) - model = 
LlamaForTokenClassification(config=config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=token_labels) - self.assertEqual( - result.logits.shape, - (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), - ) - - @parameterized.expand([("linear",), ("dynamic",), ("yarn",)]) - def test_model_rope_scaling_from_config(self, scaling_type): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - short_input = ids_tensor([1, 10], config.vocab_size) - long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) - - set_seed(42) # Fixed seed at init time so the two models get the same random weights - original_model = LlamaModel(config) - original_model.to(torch_device) - original_model.eval() - original_short_output = original_model(short_input).last_hidden_state - original_long_output = original_model(long_input).last_hidden_state - - set_seed(42) # Fixed seed at init time so the two models get the same random weights - config.rope_scaling = {"type": scaling_type, "factor": 10.0} - scaled_model = LlamaModel(config) - scaled_model.to(torch_device) - scaled_model.eval() - scaled_short_output = scaled_model(short_input).last_hidden_state - scaled_long_output = scaled_model(long_input).last_hidden_state - - # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original - # maximum sequence length, so the outputs for the short input should match. - if scaling_type == "dynamic": - torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5) - else: - self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) - - # The output should be different for long inputs - self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) - - def test_model_rope_scaling(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - scaling_factor = 10 - short_input_length = 10 - long_input_length = int(config.max_position_embeddings * 1.5) - - # Inputs - x = torch.randn( - 1, dtype=torch.float32, device=torch_device - ) # used exclusively to get the dtype and the device - position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device) - position_ids_short = position_ids_short.unsqueeze(0) - position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device) - position_ids_long = position_ids_long.unsqueeze(0) - - # Sanity check original RoPE - original_rope = LlamaRotaryEmbedding(config=config).to(torch_device) - original_cos_short, original_sin_short = original_rope(x, position_ids_short) - original_cos_long, original_sin_long = original_rope(x, position_ids_long) - torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) - torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :]) - - # Sanity check linear RoPE scaling - # New position "x" should match original position with index "x/scaling_factor" - config.rope_scaling = {"type": "linear", "factor": scaling_factor} - linear_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device) - linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short) - linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long) - torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :]) - 
torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :]) - for new_position in range(0, long_input_length, scaling_factor): - original_position = int(new_position // scaling_factor) - torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :]) - torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :]) - - # Sanity check Dynamic NTK RoPE scaling - # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase - # with scaling_factor (or that `inv_freq` decreases) - config.rope_scaling = {"type": "dynamic", "factor": scaling_factor} - ntk_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device) - ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short) - ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long) - torch.testing.assert_close(ntk_cos_short, original_cos_short) - torch.testing.assert_close(ntk_sin_short, original_sin_short) - with self.assertRaises(AssertionError): - torch.testing.assert_close(ntk_cos_long, original_cos_long) - with self.assertRaises(AssertionError): - torch.testing.assert_close(ntk_sin_long, original_sin_long) - self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) - - # Sanity check Yarn RoPE scaling - # Scaling should be over the entire input - config.rope_scaling = {"type": "yarn", "factor": scaling_factor} - yarn_scaling_rope = LlamaRotaryEmbedding(config=config).to(torch_device) - yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short) - yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long) - torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :]) - torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :]) - with self.assertRaises(AssertionError): - torch.testing.assert_close(yarn_cos_short, original_cos_short) - with self.assertRaises(AssertionError): - torch.testing.assert_close(yarn_sin_short, original_sin_short) - with self.assertRaises(AssertionError): - torch.testing.assert_close(yarn_cos_long, original_cos_long) - with self.assertRaises(AssertionError): - torch.testing.assert_close(yarn_sin_long, original_sin_long) - - def test_model_loading_old_rope_configs(self): - def _reinitialize_config(base_config, new_kwargs): - # Reinitialize the config with the new kwargs, forcing the config to go through its __init__ validation - # steps. 
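# A minimal, self-contained sketch (an annotation, not part of this patch) of the
# linear-RoPE-scaling invariant that the per-model tests deleted here used to assert,
# and which the shared CausalLMModelTest now exercises for any test class that sets
# `rotary_embedding_layer` (see `rotary_embedding_layer = LlamaRotaryEmbedding` above).
# `dim`, `base` and `factor` are arbitrary illustration values, not model settings.
import torch

dim, base, factor = 64, 10000.0, 10
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

def rope_angles(position_ids, scale=1.0):
    # Linear scaling simply divides the position index by `scale` before forming the
    # outer product with the inverse frequencies.
    return torch.outer(position_ids.float() / scale, inv_freq)

positions = torch.arange(128)
original = rope_angles(positions).cos()
scaled = rope_angles(positions, scale=factor).cos()

# Scaled position k * factor reproduces the original embedding at position k, which is
# exactly the per-position equality the deleted test_model_rope_scaling loop checked.
torch.testing.assert_close(scaled[::factor], original[: scaled[::factor].shape[0]])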
- base_config_dict = base_config.to_dict() - new_config = LlamaConfig.from_dict(config_dict={**base_config_dict, **new_kwargs}) - return new_config - - # from untouched config -> ✅ - base_config, model_inputs = self.model_tester.prepare_config_and_inputs_for_common() - original_model = LlamaForCausalLM(base_config).to(torch_device) - original_model(**model_inputs) - - # from a config with the expected rope configuration -> ✅ - config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0}}) - original_model = LlamaForCausalLM(config).to(torch_device) - original_model(**model_inputs) - - # from a config with the old rope configuration ('type' instead of 'rope_type') -> ✅ we gracefully handle BC - config = _reinitialize_config(base_config, {"rope_scaling": {"type": "linear", "factor": 10.0}}) - original_model = LlamaForCausalLM(config).to(torch_device) - original_model(**model_inputs) - - # from a config with both 'type' and 'rope_type' -> ✅ they can coexist (and both are present in the config) - config = _reinitialize_config( - base_config, {"rope_scaling": {"type": "linear", "rope_type": "linear", "factor": 10.0}} - ) - self.assertTrue(config.rope_scaling["type"] == "linear") - self.assertTrue(config.rope_scaling["rope_type"] == "linear") - original_model = LlamaForCausalLM(config).to(torch_device) - original_model(**model_inputs) - - # from a config with parameters in a bad range ('factor' should be >= 1.0) -> ⚠️ throws a warning - with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs: - config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": -999.0}}) - original_model = LlamaForCausalLM(config).to(torch_device) - original_model(**model_inputs) - self.assertEqual(len(logs.output), 1) - self.assertIn("factor field", logs.output[0]) - - # from a config with unknown parameters ('foo' isn't a rope option) -> ⚠️ throws a warning - with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs: - config = _reinitialize_config( - base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0, "foo": "bar"}} - ) - original_model = LlamaForCausalLM(config).to(torch_device) - original_model(**model_inputs) - self.assertEqual(len(logs.output), 1) - self.assertIn("Unrecognized keys", logs.output[0]) - - # from a config with specific rope type but missing one of its mandatory parameters -> ❌ throws exception - with self.assertRaises(KeyError): - config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear"}}) # missing "factor" - @require_torch_accelerator class LlamaIntegrationTest(unittest.TestCase): diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index 7eee96f2ef9..bb5a24c3cec 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -34,11 +34,6 @@ from transformers.testing_utils import ( torch_device, ) -from ...generation.test_utils import GenerationTesterMixin -from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, ids_tensor -from ...test_pipeline_mixin import PipelineTesterMixin - if is_torch_available(): import torch @@ -51,131 +46,21 @@ if is_torch_available(): MistralModel, ) +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester -class MistralModelTester: - def __init__( - self, - parent, - batch_size=13, - seq_length=7, - is_training=True, - 
use_input_mask=True, - use_token_type_ids=False, - use_labels=True, - vocab_size=99, - hidden_size=32, - num_hidden_layers=2, - num_attention_heads=4, - num_key_value_heads=2, - intermediate_size=37, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=16, - type_sequence_label_size=2, - initializer_range=0.02, - num_labels=3, - num_choices=4, - pad_token_id=0, - scope=None, - ): - self.parent = parent - self.batch_size = batch_size - self.seq_length = seq_length - self.is_training = is_training - self.use_input_mask = use_input_mask - self.use_token_type_ids = use_token_type_ids - self.use_labels = use_labels - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.num_key_value_heads = num_key_value_heads - self.intermediate_size = intermediate_size - self.hidden_act = hidden_act - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.type_sequence_label_size = type_sequence_label_size - self.initializer_range = initializer_range - self.num_labels = num_labels - self.num_choices = num_choices - self.pad_token_id = pad_token_id - self.scope = scope - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) - - input_mask = None - if self.use_input_mask: - input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) - - token_type_ids = None - if self.use_token_type_ids: - token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) - - sequence_labels = None - token_labels = None - choice_labels = None - if self.use_labels: - sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) - token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) - choice_labels = ids_tensor([self.batch_size], self.num_choices) - - config = self.get_config() - - return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - - def get_config(self): - return MistralConfig( - vocab_size=self.vocab_size, - hidden_size=self.hidden_size, - num_hidden_layers=self.num_hidden_layers, - num_attention_heads=self.num_attention_heads, - num_key_value_heads=self.num_key_value_heads, - intermediate_size=self.intermediate_size, - hidden_act=self.hidden_act, - hidden_dropout_prob=self.hidden_dropout_prob, - attention_probs_dropout_prob=self.attention_probs_dropout_prob, - max_position_embeddings=self.max_position_embeddings, - type_vocab_size=self.type_vocab_size, - is_decoder=False, - initializer_range=self.initializer_range, - pad_token_id=self.pad_token_id, - ) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Mistral - def create_and_check_model( - self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - ): - model = MistralModel(config=config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=input_mask) - result = model(input_ids) - self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) - - # Copied from 
tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common - def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() - ( - config, - input_ids, - token_type_ids, - input_mask, - sequence_labels, - token_labels, - choice_labels, - ) = config_and_inputs - inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} - return config, inputs_dict +class MistralModelTester(CausalLMModelTester): + config_class = MistralConfig + if is_torch_available(): + base_model_class = MistralModel + causal_lm_class = MistralForCausalLM + sequence_class = MistralForSequenceClassification + token_class = MistralForTokenClassification + question_answering_class = MistralForQuestionAnswering @require_torch -class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): +class MistralModelTest(CausalLMModelTest, unittest.TestCase): all_model_classes = ( ( MistralModel, @@ -193,7 +78,6 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi "text-classification": MistralForSequenceClassification, "token-classification": MistralForTokenClassification, "text-generation": MistralForCausalLM, - "zero-shot": MistralForSequenceClassification, "question-answering": MistralForQuestionAnswering, } if is_torch_available() @@ -201,7 +85,7 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi ) test_headmasking = False test_pruning = False - fx_compatible = False # Broken by attention refactor cc @Cyrilvallez + model_tester_class = MistralModelTester # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146 def is_pipeline_test_to_skip( @@ -216,82 +100,6 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi ): return True - def setUp(self): - self.model_tester = MistralModelTester(self) - self.config_tester = ConfigTester(self, config_class=MistralConfig, hidden_size=37) - - def test_config(self): - self.config_tester.run_common_tests() - - def test_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model(*config_and_inputs) - - def test_model_various_embeddings(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - for type in ["absolute", "relative_key", "relative_key_query"]: - config_and_inputs[0].position_embedding_type = type - self.model_tester.create_and_check_model(*config_and_inputs) - - def test_torch_fx_output_loss(self): - super().test_torch_fx_output_loss() - - def test_Mistral_sequence_classification_model(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = MistralForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_Mistral_sequence_classification_model_for_single_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = 
"single_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = MistralForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_Mistral_sequence_classification_model_for_multi_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "multi_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor( - [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size - ).to(torch.float) - model = MistralForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Mistral,llama->Mistral - def test_Mistral_token_classification_model(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) - model = MistralForTokenClassification(config=config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=token_labels) - self.assertEqual( - result.logits.shape, - (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), - ) - @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py index 8f4215c7205..532ebb7348a 100644 --- a/tests/models/mixtral/test_modeling_mixtral.py +++ b/tests/models/mixtral/test_modeling_mixtral.py @@ -27,11 +27,6 @@ from transformers.testing_utils import ( torch_device, ) -from ...generation.test_utils import GenerationTesterMixin -from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, ids_tensor -from ...test_pipeline_mixin import PipelineTesterMixin - if is_torch_available(): import torch @@ -44,137 +39,21 @@ if is_torch_available(): MixtralModel, ) +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester -class MixtralModelTester: - def __init__( - self, - parent, - batch_size=13, - seq_length=7, - is_training=True, - use_input_mask=True, - use_token_type_ids=False, - use_labels=True, - vocab_size=99, - hidden_size=32, - num_hidden_layers=2, - num_attention_heads=4, - num_key_value_heads=2, - intermediate_size=37, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=16, - type_sequence_label_size=2, - initializer_range=0.02, - num_labels=3, - num_choices=4, - pad_token_id=0, - scope=None, - router_jitter_noise=0.1, - ): - self.parent = parent - self.batch_size = batch_size - self.seq_length = 
seq_length - self.is_training = is_training - self.use_input_mask = use_input_mask - self.use_token_type_ids = use_token_type_ids - self.use_labels = use_labels - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.num_key_value_heads = num_key_value_heads - self.intermediate_size = intermediate_size - self.hidden_act = hidden_act - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.type_sequence_label_size = type_sequence_label_size - self.initializer_range = initializer_range - self.num_labels = num_labels - self.num_choices = num_choices - self.pad_token_id = pad_token_id - self.scope = scope - self.router_jitter_noise = router_jitter_noise - # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.prepare_config_and_inputs - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) - - input_mask = None - if self.use_input_mask: - input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) - - token_type_ids = None - if self.use_token_type_ids: - token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) - - sequence_labels = None - token_labels = None - choice_labels = None - if self.use_labels: - sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) - token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) - choice_labels = ids_tensor([self.batch_size], self.num_choices) - - config = self.get_config() - - return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - - def get_config(self): - return MixtralConfig( - vocab_size=self.vocab_size, - hidden_size=self.hidden_size, - num_hidden_layers=self.num_hidden_layers, - num_attention_heads=self.num_attention_heads, - num_key_value_heads=self.num_key_value_heads, - intermediate_size=self.intermediate_size, - hidden_act=self.hidden_act, - hidden_dropout_prob=self.hidden_dropout_prob, - attention_probs_dropout_prob=self.attention_probs_dropout_prob, - max_position_embeddings=self.max_position_embeddings, - type_vocab_size=self.type_vocab_size, - is_decoder=False, - initializer_range=self.initializer_range, - pad_token_id=self.pad_token_id, - num_experts_per_tok=2, - num_local_experts=2, - router_jitter_noise=self.router_jitter_noise, - ) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Mixtral - def create_and_check_model( - self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - ): - model = MixtralModel(config=config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=input_mask) - result = model(input_ids) - self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common with Llama->Mixtral - def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() - ( - config, - input_ids, - token_type_ids, - input_mask, - sequence_labels, - token_labels, - choice_labels, - ) = config_and_inputs - inputs_dict = {"input_ids": input_ids, 
"attention_mask": input_mask} - return config, inputs_dict +class MixtralModelTester(CausalLMModelTester): + config_class = MixtralConfig + if is_torch_available(): + base_model_class = MixtralModel + causal_lm_class = MixtralForCausalLM + sequence_class = MixtralForSequenceClassification + token_class = MixtralForTokenClassification + question_answering_class = MixtralForQuestionAnswering @require_torch -# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Mixtral -class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): +class MistralModelTest(CausalLMModelTest, unittest.TestCase): all_model_classes = ( ( MixtralModel, @@ -192,15 +71,15 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi "text-classification": MixtralForSequenceClassification, "token-classification": MixtralForTokenClassification, "text-generation": MixtralForCausalLM, - "zero-shot": MixtralForSequenceClassification, "question-answering": MixtralForQuestionAnswering, } if is_torch_available() else {} ) + test_headmasking = False test_pruning = False - fx_compatible = False # Broken by attention refactor cc @Cyrilvallez + model_tester_class = MixtralModelTester # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146 def is_pipeline_test_to_skip( @@ -215,88 +94,12 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi ): return True - def setUp(self): - self.model_tester = MixtralModelTester(self) - self.config_tester = ConfigTester(self, config_class=MixtralConfig, hidden_size=37) - - def test_config(self): - self.config_tester.run_common_tests() - - def test_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model(*config_and_inputs) - - def test_model_various_embeddings(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - for type in ["absolute", "relative_key", "relative_key_query"]: - config_and_inputs[0].position_embedding_type = type - self.model_tester.create_and_check_model(*config_and_inputs) - - def test_torch_fx_output_loss(self): - super().test_torch_fx_output_loss() - - def test_Mixtral_sequence_classification_model(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = MixtralForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_Mixtral_sequence_classification_model_for_single_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "single_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = MixtralForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, 
labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_Mixtral_sequence_classification_model_for_multi_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "multi_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor( - [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size - ).to(torch.float) - model = MixtralForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Mixtral,llama->Mixtral - def test_Mixtral_token_classification_model(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) - model = MixtralForTokenClassification(config=config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=token_labels) - self.assertEqual( - result.logits.shape, - (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), - ) - @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test @slow def test_flash_attn_2_inference_equivalence_right_padding(self): - self.skipTest(reason="Mixtral flash attention does not support right padding") + self.skipTest(reason="Mistral flash attention does not support right padding") # Ignore copy def test_load_balancing_loss(self): diff --git a/tests/models/nemotron/test_modeling_nemotron.py b/tests/models/nemotron/test_modeling_nemotron.py index 6dd2fb5cd65..9ef543edeb7 100644 --- a/tests/models/nemotron/test_modeling_nemotron.py +++ b/tests/models/nemotron/test_modeling_nemotron.py @@ -14,25 +14,19 @@ # limitations under the License. 
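# A minimal sketch (hypothetical `MyModel*` names, not taken from this diff) of the
# conversion pattern this patch applies to each model file above and below: the
# hand-written tester boilerplate (__init__, prepare_config_and_inputs, get_config,
# setUp, per-task classification tests) collapses into class attributes consumed by the
# shared CausalLMModelTester / CausalLMModelTest base classes.
import unittest

from transformers import MyModelConfig, is_torch_available
from transformers.testing_utils import require_torch

from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester

if is_torch_available():
    from transformers import MyModel, MyModelForCausalLM


class MyModelTester(CausalLMModelTester):
    config_class = MyModelConfig
    if is_torch_available():
        base_model_class = MyModel
        causal_lm_class = MyModelForCausalLM


@require_torch
class MyModelTest(CausalLMModelTest, unittest.TestCase):
    all_model_classes = (MyModel, MyModelForCausalLM) if is_torch_available() else ()
    pipeline_model_mapping = (
        {"feature-extraction": MyModel, "text-generation": MyModelForCausalLM}
        if is_torch_available()
        else {}
    )
    test_headmasking = False
    test_pruning = False
    model_tester_class = MyModelTester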
"""Testing suite for the PyTorch Nemotron model.""" -import tempfile import unittest -import pytest - from transformers import NemotronConfig, is_torch_available from transformers.testing_utils import ( Expectations, - is_flaky, - require_flash_attn, require_read_token, require_torch, require_torch_accelerator, - require_torch_gpu, slow, torch_device, ) -from ...models.gemma.test_modeling_gemma import GemmaModelTest, GemmaModelTester +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester from ...test_configuration_common import ConfigTester @@ -49,17 +43,18 @@ if is_torch_available(): ) -class NemotronModelTester(GemmaModelTester): +class NemotronModelTester(CausalLMModelTester): if is_torch_available(): config_class = NemotronConfig - model_class = NemotronModel - for_causal_lm_class = NemotronForCausalLM - for_sequence_class = NemotronForSequenceClassification - for_token_class = NemotronForTokenClassification + base_model_class = NemotronModel + causal_lm_class = NemotronForCausalLM + sequence_class = NemotronForSequenceClassification + token_class = NemotronForTokenClassification @require_torch -class NemotronModelTest(GemmaModelTest): +class NemotronModelTest(CausalLMModelTest, unittest.TestCase): + model_tester_class = NemotronModelTester # Need to use `0.8` instead of `0.9` for `test_cpu_offload` # This is because we are hitting edge cases with the causal_mask buffer model_split_percents = [0.5, 0.7, 0.8] @@ -101,40 +96,6 @@ class NemotronModelTest(GemmaModelTest): def test_model_outputs_equivalence(self, **kwargs): pass - @require_flash_attn - @require_torch_gpu - @pytest.mark.flash_attn_test - @is_flaky() - @slow - def test_flash_attn_2_equivalence(self): - for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: - self.skipTest(reason="Model does not support Flash Attention 2") - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model_fa = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2" - ) - model_fa.to(torch_device) - - model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager") - model.to(torch_device) - - dummy_input = inputs_dict[model_class.main_input_name] - dummy_input = dummy_input.to(torch_device) - outputs = model(dummy_input, output_hidden_states=True) - outputs_fa = model_fa(dummy_input, output_hidden_states=True) - - logits = outputs.hidden_states[-1] - logits_fa = outputs_fa.hidden_states[-1] - - # nemotron flash attention 2 needs a high tolerance - assert torch.allclose(logits_fa, logits, atol=1e-2) - @require_torch_accelerator class NemotronIntegrationTest(unittest.TestCase): diff --git a/tests/models/persimmon/test_modeling_persimmon.py b/tests/models/persimmon/test_modeling_persimmon.py index 3b50cb67052..2ac23a7e306 100644 --- a/tests/models/persimmon/test_modeling_persimmon.py +++ b/tests/models/persimmon/test_modeling_persimmon.py @@ -16,9 +16,7 @@ import gc import unittest -from parameterized import parameterized - -from transformers import PersimmonConfig, is_torch_available, set_seed +from transformers import PersimmonConfig, is_torch_available from transformers.testing_utils import ( backend_empty_cache, require_bitsandbytes, @@ -29,11 +27,6 @@ from transformers.testing_utils import ( torch_device, ) -from ...generation.test_utils import 
GenerationTesterMixin -from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, ids_tensor -from ...test_pipeline_mixin import PipelineTesterMixin - if is_torch_available(): import torch @@ -45,128 +38,22 @@ if is_torch_available(): PersimmonForTokenClassification, PersimmonModel, ) - from transformers.models.persimmon.modeling_persimmon import PersimmonRotaryEmbedding + +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester -# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester with Llama->Persimmon -class PersimmonModelTester: - def __init__( - self, - parent, - batch_size=13, - seq_length=7, - is_training=True, - use_input_mask=True, - use_token_type_ids=False, - use_labels=True, - vocab_size=99, - hidden_size=32, - num_hidden_layers=2, - num_attention_heads=4, - intermediate_size=37, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=16, - type_sequence_label_size=2, - initializer_range=0.02, - num_labels=3, - num_choices=4, - pad_token_id=0, - scope=None, - ): - self.parent = parent - self.batch_size = batch_size - self.seq_length = seq_length - self.is_training = is_training - self.use_input_mask = use_input_mask - self.use_token_type_ids = use_token_type_ids - self.use_labels = use_labels - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.intermediate_size = intermediate_size - self.hidden_act = hidden_act - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.type_sequence_label_size = type_sequence_label_size - self.initializer_range = initializer_range - self.num_labels = num_labels - self.num_choices = num_choices - self.pad_token_id = pad_token_id - self.scope = scope - - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) - - input_mask = None - if self.use_input_mask: - input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) - - token_type_ids = None - if self.use_token_type_ids: - token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) - - sequence_labels = None - token_labels = None - choice_labels = None - if self.use_labels: - sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) - token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) - choice_labels = ids_tensor([self.batch_size], self.num_choices) - - config = self.get_config() - - return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - - def get_config(self): - return PersimmonConfig( - vocab_size=self.vocab_size, - hidden_size=self.hidden_size, - num_hidden_layers=self.num_hidden_layers, - num_attention_heads=self.num_attention_heads, - intermediate_size=self.intermediate_size, - hidden_act=self.hidden_act, - hidden_dropout_prob=self.hidden_dropout_prob, - attention_probs_dropout_prob=self.attention_probs_dropout_prob, - max_position_embeddings=self.max_position_embeddings, - type_vocab_size=self.type_vocab_size, - is_decoder=False, - initializer_range=self.initializer_range, - pad_token_id=self.pad_token_id, - ) - - def create_and_check_model( - self, config, input_ids, 
token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - ): - model = PersimmonModel(config=config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=input_mask) - result = model(input_ids) - self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) - - def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() - ( - config, - input_ids, - token_type_ids, - input_mask, - sequence_labels, - token_labels, - choice_labels, - ) = config_and_inputs - inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} - return config, inputs_dict +class PersimmonModelTester(CausalLMModelTester): + if is_torch_available(): + config_class = PersimmonConfig + base_model_class = PersimmonModel + causal_lm_class = PersimmonForCausalLM + sequence_class = PersimmonForSequenceClassification + token_class = PersimmonForTokenClassification @require_torch -class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): +class PersimmonModelTest(CausalLMModelTest, unittest.TestCase): + model_tester_class = PersimmonModelTester all_model_classes = ( (PersimmonModel, PersimmonForCausalLM, PersimmonForSequenceClassification, PersimmonForTokenClassification) if is_torch_available() @@ -184,173 +71,11 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester if is_torch_available() else {} ) + model_tester_class = PersimmonModelTester test_headmasking = False test_pruning = False - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.setUp with Llama->Persimmon - def setUp(self): - self.model_tester = PersimmonModelTester(self) - self.config_tester = ConfigTester(self, config_class=PersimmonConfig, hidden_size=37) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_config - def test_config(self): - self.config_tester.run_common_tests() - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model - def test_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model(*config_and_inputs) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_various_embeddings - def test_model_various_embeddings(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - for type in ["absolute", "relative_key", "relative_key_query"]: - config_and_inputs[0].position_embedding_type = type - self.model_tester.create_and_check_model(*config_and_inputs) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model with Llama->Persimmon,llama->persimmon - def test_persimmon_sequence_classification_model(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = PersimmonForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_single_label 
with Llama->Persimmon,llama->persimmon - def test_persimmon_sequence_classification_model_for_single_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "single_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = PersimmonForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_multi_label with Llama->Persimmon,llama->persimmon - def test_persimmon_sequence_classification_model_for_multi_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "multi_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor( - [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size - ).to(torch.float) - model = PersimmonForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Persimmon,llama->persimmon - def test_persimmon_token_classification_model(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) - model = PersimmonForTokenClassification(config=config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=token_labels) - self.assertEqual( - result.logits.shape, - (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), - ) - - @parameterized.expand([("linear",), ("dynamic",)]) - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->Persimmon - def test_model_rope_scaling_from_config(self, scaling_type): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - short_input = ids_tensor([1, 10], config.vocab_size) - long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) - - set_seed(42) # Fixed seed at init time so the two models get the same random weights - original_model = PersimmonModel(config) - original_model.to(torch_device) - original_model.eval() - original_short_output = original_model(short_input).last_hidden_state - original_long_output = original_model(long_input).last_hidden_state - - set_seed(42) # Fixed seed at init time so the two models get the same random weights - config.rope_scaling = {"type": scaling_type, "factor": 10.0} - scaled_model = PersimmonModel(config) - scaled_model.to(torch_device) - scaled_model.eval() - scaled_short_output = 
scaled_model(short_input).last_hidden_state - scaled_long_output = scaled_model(long_input).last_hidden_state - - # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original - # maximum sequence length, so the outputs for the short input should match. - if scaling_type == "dynamic": - torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5) - else: - self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) - - # The output should be different for long inputs - self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) - - # Copied from tests.models.gpt_neox.test_modeling_gpt_neox.GPTNeoXModelTest.test_model_rope_scaling with GPTNeoX->Persimmon - def test_model_rope_scaling(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - scaling_factor = 10 - short_input_length = 10 - long_input_length = int(config.max_position_embeddings * 1.5) - - # Inputs - x = torch.randn( - 1, dtype=torch.float32, device=torch_device - ) # used exclusively to get the dtype and the device - position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device) - position_ids_short = position_ids_short.unsqueeze(0) - position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device) - position_ids_long = position_ids_long.unsqueeze(0) - - # Sanity check original RoPE - original_rope = PersimmonRotaryEmbedding(config).to(torch_device) - original_cos_short, original_sin_short = original_rope(x, position_ids_short) - original_cos_long, original_sin_long = original_rope(x, position_ids_long) - torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) - torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :]) - - # Sanity check linear RoPE scaling - # New position "x" should match original position with index "x/scaling_factor" - config.rope_scaling = {"type": "linear", "factor": scaling_factor} - linear_scaling_rope = PersimmonRotaryEmbedding(config).to(torch_device) - linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short) - linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long) - torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :]) - torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :]) - for new_position in range(0, long_input_length, scaling_factor): - original_position = int(new_position // scaling_factor) - torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :]) - torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :]) - - # Sanity check Dynamic NTK RoPE scaling - # Scaling should only be observed after a long input is fed. 
We can observe that the frequencies increase - # with scaling_factor (or that `inv_freq` decreases) - config.rope_scaling = {"type": "dynamic", "factor": scaling_factor} - ntk_scaling_rope = PersimmonRotaryEmbedding(config).to(torch_device) - ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short) - ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long) - torch.testing.assert_close(ntk_cos_short, original_cos_short) - torch.testing.assert_close(ntk_sin_short, original_sin_short) - with self.assertRaises(AssertionError): - torch.testing.assert_close(ntk_cos_long, original_cos_long) - with self.assertRaises(AssertionError): - torch.testing.assert_close(ntk_sin_long, original_sin_long) - self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) - @require_torch class PersimmonIntegrationTest(unittest.TestCase): diff --git a/tests/models/phi/test_modeling_phi.py b/tests/models/phi/test_modeling_phi.py index bda6cf98ec4..c2a7f26b31f 100644 --- a/tests/models/phi/test_modeling_phi.py +++ b/tests/models/phi/test_modeling_phi.py @@ -16,19 +16,14 @@ import unittest -from parameterized import parameterized - -from transformers import PhiConfig, is_torch_available, set_seed +from transformers import PhiConfig, is_torch_available from transformers.testing_utils import ( require_torch, slow, torch_device, ) -from ...generation.test_utils import GenerationTesterMixin -from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask -from ...test_pipeline_mixin import PipelineTesterMixin +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester if is_torch_available(): @@ -44,124 +39,17 @@ if is_torch_available(): from transformers.models.phi.modeling_phi import PhiRotaryEmbedding -class PhiModelTester: - def __init__( - self, - parent, - batch_size=13, - seq_length=7, - is_training=True, - use_input_mask=True, - use_token_type_ids=False, - use_labels=True, - vocab_size=99, - hidden_size=32, - num_hidden_layers=2, - num_attention_heads=4, - intermediate_size=37, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=16, - type_sequence_label_size=2, - initializer_range=0.02, - num_labels=3, - num_choices=4, - pad_token_id=0, - scope=None, - ): - self.parent = parent - self.batch_size = batch_size - self.seq_length = seq_length - self.is_training = is_training - self.use_input_mask = use_input_mask - self.use_token_type_ids = use_token_type_ids - self.use_labels = use_labels - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.intermediate_size = intermediate_size - self.hidden_act = hidden_act - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.type_sequence_label_size = type_sequence_label_size - self.initializer_range = initializer_range - self.num_labels = num_labels - self.num_choices = num_choices - self.pad_token_id = pad_token_id - self.scope = scope - - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) - - input_mask = None - if self.use_input_mask: - input_mask = random_attention_mask([self.batch_size, self.seq_length]) - - token_type_ids = 
None - if self.use_token_type_ids: - token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) - - sequence_labels = None - token_labels = None - choice_labels = None - if self.use_labels: - sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) - token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) - choice_labels = ids_tensor([self.batch_size], self.num_choices) - - config = self.get_config() - - return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - - def get_config(self): - return PhiConfig( - vocab_size=self.vocab_size, - hidden_size=self.hidden_size, - num_hidden_layers=self.num_hidden_layers, - num_attention_heads=self.num_attention_heads, - intermediate_size=self.intermediate_size, - hidden_act=self.hidden_act, - hidden_dropout_prob=self.hidden_dropout_prob, - attention_probs_dropout_prob=self.attention_probs_dropout_prob, - max_position_embeddings=self.max_position_embeddings, - type_vocab_size=self.type_vocab_size, - is_decoder=False, - initializer_range=self.initializer_range, - pad_token_id=self.pad_token_id, - ) - - def create_and_check_model( - self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - ): - model = PhiModel(config=config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=input_mask) - result = model(input_ids) - self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) - - def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() - ( - config, - input_ids, - token_type_ids, - input_mask, - sequence_labels, - token_labels, - choice_labels, - ) = config_and_inputs - inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} - return config, inputs_dict +class PhiModelTester(CausalLMModelTester): + config_class = PhiConfig + if is_torch_available(): + base_model_class = PhiModel + causal_lm_class = PhiForCausalLM + sequence_class = PhiForSequenceClassification + token_class = PhiForTokenClassification @require_torch -class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): +class PhiModelTest(CausalLMModelTest, unittest.TestCase): all_model_classes = ( (PhiModel, PhiForCausalLM, PhiForSequenceClassification, PhiForTokenClassification) if is_torch_available() @@ -171,9 +59,8 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, { "feature-extraction": PhiModel, "text-classification": PhiForSequenceClassification, - "text-generation": PhiForCausalLM, "token-classification": PhiForTokenClassification, - "zero-shot": PhiForSequenceClassification, + "text-generation": PhiForCausalLM, } if is_torch_available() else {} @@ -181,6 +68,8 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, test_headmasking = False test_pruning = False + model_tester_class = PhiModelTester + rotary_embedding_layer = PhiRotaryEmbedding # TODO (ydshieh): Check this. 
See https://app.circleci.com/pipelines/github/huggingface/transformers/79292/workflows/fa2ba644-8953-44a6-8f67-ccd69ca6a476/jobs/1012905 def is_pipeline_test_to_skip( @@ -195,146 +84,6 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, ): return True - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.setUp with Llama->Phi - def setUp(self): - self.model_tester = PhiModelTester(self) - self.config_tester = ConfigTester(self, config_class=PhiConfig, hidden_size=37) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_config - def test_config(self): - self.config_tester.run_common_tests() - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model - def test_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model(*config_and_inputs) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model with Llama->Phi,llama->phi - def test_phi_sequence_classification_model(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = PhiForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_single_label with Llama->Phi,llama->phi - def test_phi_sequence_classification_model_for_single_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "single_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = PhiForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_multi_label with Llama->Phi,llama->phi - def test_phi_sequence_classification_model_for_multi_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "multi_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor( - [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size - ).to(torch.float) - model = PhiForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - @parameterized.expand([("linear",), ("dynamic",)]) - # Copied from 
tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->Phi - def test_model_rope_scaling_from_config(self, scaling_type): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - short_input = ids_tensor([1, 10], config.vocab_size) - long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) - - set_seed(42) # Fixed seed at init time so the two models get the same random weights - original_model = PhiModel(config) - original_model.to(torch_device) - original_model.eval() - original_short_output = original_model(short_input).last_hidden_state - original_long_output = original_model(long_input).last_hidden_state - - set_seed(42) # Fixed seed at init time so the two models get the same random weights - config.rope_scaling = {"type": scaling_type, "factor": 10.0} - scaled_model = PhiModel(config) - scaled_model.to(torch_device) - scaled_model.eval() - scaled_short_output = scaled_model(short_input).last_hidden_state - scaled_long_output = scaled_model(long_input).last_hidden_state - - # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original - # maximum sequence length, so the outputs for the short input should match. - if scaling_type == "dynamic": - torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5) - else: - self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) - - # The output should be different for long inputs - self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) - - # Copied from tests.models.gpt_neox.test_modeling_gpt_neox.GPTNeoXModelTest.test_model_rope_scaling with GPTNeoX->Phi - def test_model_rope_scaling(self): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - scaling_factor = 10 - short_input_length = 10 - long_input_length = int(config.max_position_embeddings * 1.5) - - # Inputs - x = torch.randn( - 1, dtype=torch.float32, device=torch_device - ) # used exclusively to get the dtype and the device - position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device) - position_ids_short = position_ids_short.unsqueeze(0) - position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device) - position_ids_long = position_ids_long.unsqueeze(0) - - # Sanity check original RoPE - original_rope = PhiRotaryEmbedding(config).to(torch_device) - original_cos_short, original_sin_short = original_rope(x, position_ids_short) - original_cos_long, original_sin_long = original_rope(x, position_ids_long) - torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) - torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :]) - - # Sanity check linear RoPE scaling - # New position "x" should match original position with index "x/scaling_factor" - config.rope_scaling = {"type": "linear", "factor": scaling_factor} - linear_scaling_rope = PhiRotaryEmbedding(config).to(torch_device) - linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short) - linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long) - torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :]) - torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :]) - for new_position in range(0, long_input_length, scaling_factor): - original_position = 
int(new_position // scaling_factor) - torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :]) - torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :]) - - # Sanity check Dynamic NTK RoPE scaling - # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase - # with scaling_factor (or that `inv_freq` decreases) - config.rope_scaling = {"type": "dynamic", "factor": scaling_factor} - ntk_scaling_rope = PhiRotaryEmbedding(config).to(torch_device) - ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short) - ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long) - torch.testing.assert_close(ntk_cos_short, original_cos_short) - torch.testing.assert_close(ntk_sin_short, original_sin_short) - with self.assertRaises(AssertionError): - torch.testing.assert_close(ntk_cos_long, original_cos_long) - with self.assertRaises(AssertionError): - torch.testing.assert_close(ntk_sin_long, original_sin_long) - self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) - @slow @require_torch diff --git a/tests/models/phi3/test_modeling_phi3.py b/tests/models/phi3/test_modeling_phi3.py index c2af64ffd8a..cb9dc86d43b 100644 --- a/tests/models/phi3/test_modeling_phi3.py +++ b/tests/models/phi3/test_modeling_phi3.py @@ -16,9 +16,7 @@ import unittest -from parameterized import parameterized - -from transformers import Phi3Config, StaticCache, is_torch_available, set_seed +from transformers import Phi3Config, StaticCache, is_torch_available from transformers.models.auto.configuration_auto import AutoConfig from transformers.testing_utils import ( require_torch, @@ -26,10 +24,7 @@ from transformers.testing_utils import ( torch_device, ) -from ...generation.test_utils import GenerationTesterMixin -from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, ids_tensor -from ...test_pipeline_mixin import PipelineTesterMixin +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester if is_torch_available(): @@ -42,6 +37,7 @@ if is_torch_available(): Phi3ForTokenClassification, Phi3Model, ) + from transformers.models.phi3.modeling_phi3 import Phi3RotaryEmbedding end_of_text_token = 32000 @@ -93,127 +89,17 @@ if is_torch_available(): return response_tokens -class Phi3ModelTester: - def __init__( - self, - parent, - batch_size=13, - seq_length=7, - is_training=True, - use_input_mask=True, - use_token_type_ids=False, - use_labels=True, - vocab_size=99, - hidden_size=32, - num_hidden_layers=2, - num_attention_heads=4, - intermediate_size=37, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=16, - type_sequence_label_size=2, - initializer_range=0.02, - num_labels=3, - num_choices=4, - pad_token_id=0, - scope=None, - ): - self.parent = parent - self.batch_size = batch_size - self.seq_length = seq_length - self.is_training = is_training - self.use_input_mask = use_input_mask - self.use_token_type_ids = use_token_type_ids - self.use_labels = use_labels - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.intermediate_size = intermediate_size - self.hidden_act = hidden_act - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - 
self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.type_sequence_label_size = type_sequence_label_size - self.initializer_range = initializer_range - self.num_labels = num_labels - self.num_choices = num_choices - self.pad_token_id = pad_token_id - self.scope = scope - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) - - input_mask = None - if self.use_input_mask: - input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) - - token_type_ids = None - if self.use_token_type_ids: - token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) - - sequence_labels = None - token_labels = None - choice_labels = None - if self.use_labels: - sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) - token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) - choice_labels = ids_tensor([self.batch_size], self.num_choices) - - config = self.get_config() - - return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - - def get_config(self): - return Phi3Config( - vocab_size=self.vocab_size, - hidden_size=self.hidden_size, - num_hidden_layers=self.num_hidden_layers, - num_attention_heads=self.num_attention_heads, - intermediate_size=self.intermediate_size, - hidden_act=self.hidden_act, - hidden_dropout_prob=self.hidden_dropout_prob, - attention_probs_dropout_prob=self.attention_probs_dropout_prob, - max_position_embeddings=self.max_position_embeddings, - type_vocab_size=self.type_vocab_size, - is_decoder=False, - initializer_range=self.initializer_range, - pad_token_id=self.pad_token_id, - ) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Phi3 - def create_and_check_model( - self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - ): - model = Phi3Model(config=config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=input_mask) - result = model(input_ids) - self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common - def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() - ( - config, - input_ids, - token_type_ids, - input_mask, - sequence_labels, - token_labels, - choice_labels, - ) = config_and_inputs - inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} - return config, inputs_dict +class Phi3ModelTester(CausalLMModelTester): + config_class = Phi3Config + if is_torch_available(): + base_model_class = Phi3Model + causal_lm_class = Phi3ForCausalLM + sequence_class = Phi3ForSequenceClassification + token_class = Phi3ForTokenClassification @require_torch -class Phi3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): +class Phi3ModelTest(CausalLMModelTest, unittest.TestCase): all_model_classes = ( (Phi3Model, Phi3ForCausalLM, Phi3ForSequenceClassification, Phi3ForTokenClassification) if is_torch_available() @@ -223,9 +109,8 @@ class Phi3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin { "feature-extraction": Phi3Model, "text-classification": 
Phi3ForSequenceClassification, - "text-generation": Phi3ForCausalLM, "token-classification": Phi3ForTokenClassification, - "zero-shot": Phi3ForSequenceClassification, + "text-generation": Phi3ForCausalLM, } if is_torch_available() else {} @@ -233,150 +118,8 @@ class Phi3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin test_headmasking = False test_pruning = False - - # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79292/workflows/fa2ba644-8953-44a6-8f67-ccd69ca6a476/jobs/1012905 - def is_pipeline_test_to_skip( - self, - pipeline_test_case_name, - config_class, - model_architecture, - tokenizer_name, - image_processor_name, - feature_extractor_name, - processor_name, - ): - return True - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.setUp with Llama->Phi3 - def setUp(self): - self.model_tester = Phi3ModelTester(self) - self.config_tester = ConfigTester(self, config_class=Phi3Config, hidden_size=37) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_config - def test_config(self): - self.config_tester.run_common_tests() - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model - def test_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model(*config_and_inputs) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model with Llama->Phi3,llama->phi3 - def test_phi3_sequence_classification_model(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = Phi3ForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_single_label with Llama->Phi3,llama->phi3 - def test_phi3_sequence_classification_model_for_single_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "single_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = Phi3ForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_multi_label with Llama->Phi3,llama->phi3 - def test_phi3_sequence_classification_model_for_multi_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "multi_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor( - 
[self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size - ).to(torch.float) - model = Phi3ForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - @parameterized.expand([("longrope",)]) - def test_model_rope_scaling_from_config(self, scaling_type): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - short_input = ids_tensor([1, 10], config.vocab_size) - long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) - - set_seed(42) # Fixed seed at init time so the two models get the same random weights - original_model = Phi3Model(config) - original_model.to(torch_device) - original_model.eval() - original_short_output = original_model(short_input).last_hidden_state - original_long_output = original_model(long_input).last_hidden_state - - set_seed(42) # Fixed seed at init time so the two models get the same random weights - n_factors = config.hidden_size // config.num_attention_heads // 2 - config.rope_scaling = { - "type": scaling_type, - "short_factor": [5.0 for _ in range(n_factors)], - "long_factor": [5.0 for _ in range(n_factors)], - } - scaled_model = Phi3Model(config) - scaled_model.to(torch_device) - scaled_model.eval() - scaled_short_output = scaled_model(short_input).last_hidden_state - scaled_long_output = scaled_model(long_input).last_hidden_state - - # Scaling changes the RoPE embeddings, both for the short and long outputs - self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) - self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) - - @parameterized.expand([("longrope",)]) - def test_model_rope_scaling_short_long_factor(self, scaling_type): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - n_factors = config.hidden_size // config.num_key_value_heads // 2 - config.rope_scaling = { - "type": scaling_type, - "short_factor": [3.0 for _ in range(n_factors)], - "long_factor": [5.0 for _ in range(n_factors)], - } - input_tensor = ids_tensor([1, 4090], config.vocab_size) - # Make sure we don't have padding tokens. 
If this is the case, then the actual number of "true" tokens may be shorter - # than `config.original_max_position_embeddings + 5`, invalidating this test - input_tensor[input_tensor == config.pad_token_id] += 1 - model = Phi3ForCausalLM(config) - model.to(torch_device) - model.eval() - generation_args_short = { - "max_length": config.original_max_position_embeddings, - "temperature": 0.0, - "use_cache": True, - "do_sample": False, - "return_dict_in_generate": True, - } - output_with_short_factor = model.generate(input_tensor, **generation_args_short) - keys_with_short_factor = output_with_short_factor.past_key_values[0][0] - generation_args_long = { - "max_length": config.original_max_position_embeddings + 5, - "temperature": 0.0, - "use_cache": True, - "do_sample": False, - "return_dict_in_generate": True, - "output_logits": True, - } - output_with_long_factor = model.generate(input_tensor, **generation_args_long) - keys_with_long_factor = output_with_long_factor.past_key_values[0][0] - last_token_logits = output_with_long_factor.logits[-1][-1] - regenerated_last_token_logits = model(output_with_long_factor.sequences[:, :-1]).logits[0][-1] - keys_with_long_factor = keys_with_long_factor[:, :, : config.original_max_position_embeddings - 1, :] - - # KV cache is re-computed after reaching the (`config.original_max_position_embeddings`+1)th token position - self.assertFalse(torch.allclose(keys_with_short_factor, keys_with_long_factor, atol=1e-2, rtol=1e-2)) - # Last token generated using long factor - torch.testing.assert_close(last_token_logits, regenerated_last_token_logits, rtol=1e-2, atol=1e-2) + model_tester_class = Phi3ModelTester + rotary_embedding_layer = Phi3RotaryEmbedding @slow diff --git a/tests/models/phimoe/test_modeling_phimoe.py b/tests/models/phimoe/test_modeling_phimoe.py index 7f548bd2dc0..89bde307b6d 100644 --- a/tests/models/phimoe/test_modeling_phimoe.py +++ b/tests/models/phimoe/test_modeling_phimoe.py @@ -16,20 +16,14 @@ import unittest -from parameterized import parameterized - -from transformers import PhimoeConfig, StaticCache, is_torch_available, set_seed +from transformers import PhimoeConfig, StaticCache, is_torch_available from transformers.testing_utils import ( - is_flaky, require_torch, slow, torch_device, ) -from ...generation.test_utils import GenerationTesterMixin -from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, ids_tensor -from ...test_pipeline_mixin import PipelineTesterMixin +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester if is_torch_available(): @@ -92,138 +86,23 @@ if is_torch_available(): return response_tokens -class PhimoeModelTester: - def __init__( - self, - parent, - batch_size=13, - seq_length=7, - is_training=True, - use_input_mask=True, - use_token_type_ids=False, - use_labels=True, - vocab_size=99, - hidden_size=32, - num_hidden_layers=2, - num_attention_heads=4, - num_key_value_heads=4, - intermediate_size=37, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=131072, - type_vocab_size=16, - type_sequence_label_size=2, - initializer_range=0.02, - num_labels=3, - num_choices=4, - pad_token_id=0, - scope=None, - original_max_position_embeddings=4096, - ): - self.parent = parent - self.batch_size = batch_size - self.seq_length = seq_length - self.is_training = is_training - self.use_input_mask = use_input_mask - self.use_token_type_ids = use_token_type_ids - self.use_labels = use_labels - 
self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.num_key_value_heads = num_key_value_heads - self.intermediate_size = intermediate_size - self.hidden_act = hidden_act - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.type_sequence_label_size = type_sequence_label_size - self.initializer_range = initializer_range - self.num_labels = num_labels - self.num_choices = num_choices - self.pad_token_id = pad_token_id - self.scope = scope - self.original_max_position_embeddings = original_max_position_embeddings - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) - - input_mask = None - if self.use_input_mask: - input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) - - token_type_ids = None - if self.use_token_type_ids: - token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) - - sequence_labels = None - token_labels = None - choice_labels = None - if self.use_labels: - sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) - token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) - choice_labels = ids_tensor([self.batch_size], self.num_choices) - - config = self.get_config() - - return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - - def get_config(self): - return PhimoeConfig( - vocab_size=self.vocab_size, - hidden_size=self.hidden_size, - num_hidden_layers=self.num_hidden_layers, - num_attention_heads=self.num_attention_heads, - num_key_value_heads=self.num_key_value_heads, - intermediate_size=self.intermediate_size, - hidden_act=self.hidden_act, - hidden_dropout_prob=self.hidden_dropout_prob, - attention_probs_dropout_prob=self.attention_probs_dropout_prob, - max_position_embeddings=self.max_position_embeddings, - type_vocab_size=self.type_vocab_size, - is_decoder=False, - initializer_range=self.initializer_range, - pad_token_id=self.pad_token_id, - num_experts_per_tok=2, - num_local_experts=2, - original_max_position_embeddings=self.original_max_position_embeddings, - ) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Phimoe - def create_and_check_model( - self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - ): - model = PhimoeModel(config=config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=input_mask) - result = model(input_ids) - self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common - def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() - ( - config, - input_ids, - token_type_ids, - input_mask, - sequence_labels, - token_labels, - choice_labels, - ) = config_and_inputs - inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} - return config, inputs_dict +class PhimoeModelTester(CausalLMModelTester): + if is_torch_available(): + config_class = 
PhimoeConfig + base_model_class = PhimoeModel + causal_lm_class = PhimoeForCausalLM + sequence_class = PhimoeForSequenceClassification @require_torch -class PhimoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): +class PhimoeModelTest(CausalLMModelTest, unittest.TestCase): all_model_classes = ( (PhimoeModel, PhimoeForCausalLM, PhimoeForSequenceClassification) if is_torch_available() else () ) + + test_headmasking = False + test_pruning = False + model_tester_class = PhimoeModelTester pipeline_model_mapping = ( { "feature-extraction": PhimoeModel, @@ -235,150 +114,12 @@ class PhimoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix else {} ) - test_headmasking = False - test_pruning = False - # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79292/workflows/fa2ba644-8953-44a6-8f67-ccd69ca6a476/jobs/1012905 def is_pipeline_test_to_skip( self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name ): return True - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.setUp with Llama->Phimoe - def setUp(self): - self.model_tester = PhimoeModelTester(self) - self.config_tester = ConfigTester(self, config_class=PhimoeConfig, hidden_size=37) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_config - def test_config(self): - self.config_tester.run_common_tests() - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model - def test_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model(*config_and_inputs) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model with Llama->Phimoe,llama->phimoe - def test_phimoe_sequence_classification_model(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = PhimoeForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_single_label with Llama->Phimoe,llama->phimoe - def test_phimoe_sequence_classification_model_for_single_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "single_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = PhimoeForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_multi_label with Llama->Phimoe,llama->phimoe - def 
test_phimoe_sequence_classification_model_for_multi_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "multi_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor( - [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size - ).to(torch.float) - model = PhimoeForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - @parameterized.expand([("longrope",)]) - def test_model_rope_scaling_from_config(self, scaling_type): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - short_input = ids_tensor([1, 10], config.vocab_size) - long_input = ids_tensor([1, int(config.original_max_position_embeddings * 1.5)], config.vocab_size) - - set_seed(42) # Fixed seed at init time so the two models get the same random weights - original_model = PhimoeModel(config) - original_model.to(torch_device) - original_model.eval() - original_short_output = original_model(short_input).last_hidden_state - original_long_output = original_model(long_input).last_hidden_state - - set_seed(42) # Fixed seed at init time so the two models get the same random weights - n_factors = config.hidden_size // config.num_attention_heads // 2 - config.rope_scaling = { - "type": scaling_type, - "short_factor": [3.0 for _ in range(n_factors)], - "long_factor": [5.0 for _ in range(n_factors)], - "short_mscale": 1.243163121016122, - "long_mscale": 1.243163121016122, - "original_max_position_embeddings": 4096, - } - scaled_model = PhimoeModel(config) - scaled_model.to(torch_device) - scaled_model.eval() - scaled_short_output = scaled_model(short_input).last_hidden_state - scaled_long_output = scaled_model(long_input).last_hidden_state - - # Scaling changes the RoPE embeddings, both for the short and long outputs - self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) - self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) - - @parameterized.expand([("longrope",)]) - @is_flaky() # TODO (joao): unify rope tests in the mixin - def test_model_rope_scaling_short_long_factor(self, scaling_type): - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - n_factors = config.hidden_size // config.num_key_value_heads // 2 - config.rope_scaling = { - "type": scaling_type, - "short_factor": [3.0 for _ in range(n_factors)], - "long_factor": [5.0 for _ in range(n_factors)], - "short_mscale": 1.243163121016122, - "long_mscale": 1.243163121016122, - "original_max_position_embeddings": 4096, - } - input_tensor = ids_tensor([1, 4090], config.vocab_size) - model = PhimoeForCausalLM(config) - model.to(torch_device) - model.eval() - generation_args_short = { - "max_length": config.original_max_position_embeddings, - "temperature": 0.0, - "use_cache": True, - "do_sample": False, - "return_dict_in_generate": True, - } - output_with_short_factor = model.generate(input_tensor, **generation_args_short) - keys_with_short_factor = output_with_short_factor.past_key_values[0][0] - generation_args_long = { - "max_length": config.original_max_position_embeddings + 5, - "temperature": 0.0, - "use_cache": True, - "do_sample": False, - "return_dict_in_generate": 
True, - "output_logits": True, - } - output_with_long_factor = model.generate(input_tensor, **generation_args_long) - keys_with_long_factor = output_with_long_factor.past_key_values[0][0] - last_token_logits = output_with_long_factor.logits[-1][-1] - regenerated_last_token_logits = model(output_with_long_factor.sequences[:, :-1]).logits[0][-1] - keys_with_long_factor = keys_with_long_factor[:, :, : config.original_max_position_embeddings - 1, :] - - # KV cache is re-computed after reaching the (`config.original_max_position_embeddings`+1)th token position - self.assertFalse(torch.allclose(keys_with_short_factor, keys_with_long_factor, atol=1e-3, rtol=1e-3)) - # Last token generated using long factor - torch.testing.assert_close(last_token_logits, regenerated_last_token_logits, rtol=1e-2, atol=1e-2) - @slow @require_torch diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py index acb784f6ab2..a27695fa9d2 100644 --- a/tests/models/qwen2/test_modeling_qwen2.py +++ b/tests/models/qwen2/test_modeling_qwen2.py @@ -33,11 +33,6 @@ from transformers.testing_utils import ( ) from transformers.utils.import_utils import is_torch_greater_or_equal -from ...generation.test_utils import GenerationTesterMixin -from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, ids_tensor -from ...test_pipeline_mixin import PipelineTesterMixin - if is_torch_available(): import torch @@ -51,143 +46,21 @@ if is_torch_available(): ) -class Qwen2ModelTester: - def __init__( - self, - parent, - batch_size=13, - seq_length=7, - is_training=True, - use_input_mask=True, - use_token_type_ids=True, - use_labels=True, - vocab_size=99, - hidden_size=32, - num_hidden_layers=5, - max_window_layers=3, - use_sliding_window=True, - sliding_window=50, - num_attention_heads=4, - num_key_value_heads=2, - intermediate_size=37, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=16, - type_sequence_label_size=2, - initializer_range=0.02, - num_labels=3, - num_choices=4, - pad_token_id=0, - bos_token_id=1, - scope=None, - ): - self.parent = parent - self.batch_size = batch_size - self.seq_length = seq_length - self.is_training = is_training - self.use_input_mask = use_input_mask - self.use_token_type_ids = use_token_type_ids - self.use_labels = use_labels - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.max_window_layers = max_window_layers - self.use_sliding_window = use_sliding_window - self.sliding_window = sliding_window - self.num_attention_heads = num_attention_heads - self.num_key_value_heads = num_key_value_heads - self.intermediate_size = intermediate_size - self.hidden_act = hidden_act - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.type_sequence_label_size = type_sequence_label_size - self.initializer_range = initializer_range - self.num_labels = num_labels - self.num_choices = num_choices - self.pad_token_id = pad_token_id - self.bos_token_id = bos_token_id - self.scope = scope +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs - def prepare_config_and_inputs(self): - input_ids = 
ids_tensor([self.batch_size, self.seq_length], self.vocab_size) - input_mask = None - if self.use_input_mask: - input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) - - token_type_ids = None - if self.use_token_type_ids: - token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) - - sequence_labels = None - token_labels = None - choice_labels = None - if self.use_labels: - sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) - token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) - choice_labels = ids_tensor([self.batch_size], self.num_choices) - - config = self.get_config() - - return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - - def get_config(self): - return Qwen2Config( - vocab_size=self.vocab_size, - hidden_size=self.hidden_size, - num_hidden_layers=self.num_hidden_layers, - max_window_layers=self.max_window_layers, - use_sliding_window=self.use_sliding_window, - sliding_window=self.sliding_window, - num_attention_heads=self.num_attention_heads, - num_key_value_heads=self.num_key_value_heads, - intermediate_size=self.intermediate_size, - hidden_act=self.hidden_act, - hidden_dropout_prob=self.hidden_dropout_prob, - attention_probs_dropout_prob=self.attention_probs_dropout_prob, - max_position_embeddings=self.max_position_embeddings, - type_vocab_size=self.type_vocab_size, - is_decoder=False, - initializer_range=self.initializer_range, - pad_token_id=self.pad_token_id, - bos_token_id=self.bos_token_id, - ) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Qwen2 - def create_and_check_model( - self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - ): - model = Qwen2Model(config=config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=input_mask) - result = model(input_ids) - self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common - def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() - ( - config, - input_ids, - token_type_ids, - input_mask, - sequence_labels, - token_labels, - choice_labels, - ) = config_and_inputs - inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} - return config, inputs_dict +class Qwen2ModelTester(CausalLMModelTester): + config_class = Qwen2Config + if is_torch_available(): + base_model_class = Qwen2Model + causal_lm_class = Qwen2ForCausalLM + sequence_class = Qwen2ForSequenceClassification + token_class = Qwen2ForTokenClassification + question_answering_class = Qwen2ForQuestionAnswering @require_torch -# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2 -class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): +class Qwen2ModelTest(CausalLMModelTest, unittest.TestCase): all_model_classes = ( ( Qwen2Model, @@ -199,21 +72,20 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi if is_torch_available() else () ) + test_headmasking = False + test_pruning = False + model_tester_class = Qwen2ModelTester pipeline_model_mapping = ( { "feature-extraction": Qwen2Model, "text-classification": Qwen2ForSequenceClassification, "token-classification": 
Qwen2ForTokenClassification, "text-generation": Qwen2ForCausalLM, - "zero-shot": Qwen2ForSequenceClassification, "question-answering": Qwen2ForQuestionAnswering, } if is_torch_available() else {} ) - test_headmasking = False - test_pruning = False - fx_compatible = False # Broken by attention refactor cc @Cyrilvallez # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146 def is_pipeline_test_to_skip( @@ -228,82 +100,6 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi ): return True - def setUp(self): - self.model_tester = Qwen2ModelTester(self) - self.config_tester = ConfigTester(self, config_class=Qwen2Config, hidden_size=37) - - def test_config(self): - self.config_tester.run_common_tests() - - def test_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model(*config_and_inputs) - - def test_model_various_embeddings(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - for type in ["absolute", "relative_key", "relative_key_query"]: - config_and_inputs[0].position_embedding_type = type - self.model_tester.create_and_check_model(*config_and_inputs) - - def test_torch_fx_output_loss(self): - super().test_torch_fx_output_loss() - - def test_Qwen2_sequence_classification_model(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = Qwen2ForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_Qwen2_sequence_classification_model_for_single_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "single_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = Qwen2ForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_Qwen2_sequence_classification_model_for_multi_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "multi_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor( - [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size - ).to(torch.float) - model = Qwen2ForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - # Copied from 
tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Qwen2,llama->Qwen2 - def test_Qwen2_token_classification_model(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) - model = Qwen2ForTokenClassification(config=config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=token_labels) - self.assertEqual( - result.logits.shape, - (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), - ) - @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test diff --git a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py index 5a1e7615fff..dbfc7a1e684 100644 --- a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py +++ b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py @@ -30,11 +30,6 @@ from transformers.testing_utils import ( torch_device, ) -from ...generation.test_utils import GenerationTesterMixin -from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, ids_tensor -from ...test_pipeline_mixin import PipelineTesterMixin - if is_torch_available(): import torch @@ -48,173 +43,21 @@ if is_torch_available(): ) -class Qwen2MoeModelTester: - def __init__( - self, - parent, - batch_size=13, - seq_length=7, - is_training=True, - use_input_mask=True, - use_token_type_ids=True, - use_labels=True, - vocab_size=99, - hidden_size=32, - num_hidden_layers=5, - max_window_layers=3, - use_sliding_window=True, - sliding_window=50, - num_attention_heads=4, - num_key_value_heads=2, - intermediate_size=37, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - expert_interval=1, - moe_intermediate_size=12, - shared_expert_intermediate_size=36, - shared_expert_gate=True, - num_experts_per_tok=2, - num_experts=8, - norm_topk_prob=False, - output_router_logits=False, - router_aux_loss_coef=0.001, - type_vocab_size=16, - type_sequence_label_size=2, - initializer_range=0.02, - num_labels=3, - num_choices=4, - pad_token_id=0, - bos_token_id=1, - scope=None, - qkv_bias=False, - ): - self.parent = parent - self.batch_size = batch_size - self.seq_length = seq_length - self.is_training = is_training - self.use_input_mask = use_input_mask - self.use_token_type_ids = use_token_type_ids - self.use_labels = use_labels - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.max_window_layers = max_window_layers - self.use_sliding_window = use_sliding_window - self.sliding_window = sliding_window - self.num_attention_heads = num_attention_heads - self.num_key_value_heads = num_key_value_heads - self.intermediate_size = intermediate_size - self.hidden_act = hidden_act - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.type_sequence_label_size = type_sequence_label_size - self.initializer_range = initializer_range - self.num_labels = num_labels - self.num_choices = num_choices - self.pad_token_id = pad_token_id - self.bos_token_id = bos_token_id 
- self.scope = scope - self.expert_interval = expert_interval - self.moe_intermediate_size = moe_intermediate_size - self.shared_expert_intermediate_size = shared_expert_intermediate_size - self.shared_expert_gate = shared_expert_gate - self.num_experts_per_tok = num_experts_per_tok - self.num_experts = num_experts - self.norm_topk_prob = norm_topk_prob - self.output_router_logits = output_router_logits - self.router_aux_loss_coef = router_aux_loss_coef - self.qkv_bias = qkv_bias +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) - input_mask = None - if self.use_input_mask: - input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) - - token_type_ids = None - if self.use_token_type_ids: - token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) - - sequence_labels = None - token_labels = None - choice_labels = None - if self.use_labels: - sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) - token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) - choice_labels = ids_tensor([self.batch_size], self.num_choices) - - config = self.get_config() - - return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - - def get_config(self): - return Qwen2MoeConfig( - vocab_size=self.vocab_size, - hidden_size=self.hidden_size, - num_hidden_layers=self.num_hidden_layers, - max_window_layers=self.max_window_layers, - use_sliding_window=self.use_sliding_window, - sliding_window=self.sliding_window, - num_attention_heads=self.num_attention_heads, - num_key_value_heads=self.num_key_value_heads, - intermediate_size=self.intermediate_size, - hidden_act=self.hidden_act, - hidden_dropout_prob=self.hidden_dropout_prob, - attention_probs_dropout_prob=self.attention_probs_dropout_prob, - max_position_embeddings=self.max_position_embeddings, - expert_interval=self.expert_interval, - moe_intermediate_size=self.moe_intermediate_size, - shared_expert_intermediate_size=self.shared_expert_intermediate_size, - shared_expert_gate=self.shared_expert_gate, - num_experts_per_tok=self.num_experts_per_tok, - num_experts=self.num_experts, - norm_topk_prob=self.norm_topk_prob, - output_router_logits=self.output_router_logits, - router_aux_loss_coef=self.router_aux_loss_coef, - type_vocab_size=self.type_vocab_size, - is_decoder=False, - initializer_range=self.initializer_range, - pad_token_id=self.pad_token_id, - bos_token_id=self.bos_token_id, - qkv_bias=self.qkv_bias, - ) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Qwen2Moe - def create_and_check_model( - self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - ): - model = Qwen2MoeModel(config=config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=input_mask) - result = model(input_ids) - self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common - def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() - ( - config, - input_ids, - token_type_ids, - input_mask, - 
sequence_labels, - token_labels, - choice_labels, - ) = config_and_inputs - inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} - return config, inputs_dict +class Qwen2MoeModelTester(CausalLMModelTester): + config_class = Qwen2MoeConfig + if is_torch_available(): + base_model_class = Qwen2MoeModel + causal_lm_class = Qwen2MoeForCausalLM + sequence_class = Qwen2MoeForSequenceClassification + token_class = Qwen2MoeForTokenClassification + question_answering_class = Qwen2MoeForQuestionAnswering @require_torch -# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen2Moe -class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): +class Qwen2MoeModelTest(CausalLMModelTest, unittest.TestCase): all_model_classes = ( ( Qwen2MoeModel, @@ -232,15 +75,15 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM "text-classification": Qwen2MoeForSequenceClassification, "token-classification": Qwen2MoeForTokenClassification, "text-generation": Qwen2MoeForCausalLM, - "zero-shot": Qwen2MoeForSequenceClassification, "question-answering": Qwen2MoeForQuestionAnswering, } if is_torch_available() else {} ) + test_headmasking = False test_pruning = False - fx_compatible = False # Broken by attention refactor cc @Cyrilvallez + model_tester_class = Qwen2MoeModelTester # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146 def is_pipeline_test_to_skip( @@ -255,82 +98,6 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM ): return True - def setUp(self): - self.model_tester = Qwen2MoeModelTester(self) - self.config_tester = ConfigTester(self, config_class=Qwen2MoeConfig, hidden_size=37) - - def test_config(self): - self.config_tester.run_common_tests() - - def test_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model(*config_and_inputs) - - def test_model_various_embeddings(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - for type in ["absolute", "relative_key", "relative_key_query"]: - config_and_inputs[0].position_embedding_type = type - self.model_tester.create_and_check_model(*config_and_inputs) - - def test_torch_fx_output_loss(self): - super().test_torch_fx_output_loss() - - def test_Qwen2Moe_sequence_classification_model(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = Qwen2MoeForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_Qwen2Moe_sequence_classification_model_for_single_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "single_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = 
Qwen2MoeForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_Qwen2Moe_sequence_classification_model_for_multi_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "multi_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor( - [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size - ).to(torch.float) - model = Qwen2MoeForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Qwen2Moe,llama->Qwen2Moe - def test_Qwen2Moe_token_classification_model(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) - model = Qwen2MoeForTokenClassification(config=config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=token_labels) - self.assertEqual( - result.logits.shape, - (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), - ) - @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test diff --git a/tests/models/qwen3/test_modeling_qwen3.py b/tests/models/qwen3/test_modeling_qwen3.py index 44eb7474fa8..884c1ea077b 100644 --- a/tests/models/qwen3/test_modeling_qwen3.py +++ b/tests/models/qwen3/test_modeling_qwen3.py @@ -33,11 +33,6 @@ from transformers.testing_utils import ( ) from transformers.utils.import_utils import is_torch_greater_or_equal -from ...generation.test_utils import GenerationTesterMixin -from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, ids_tensor -from ...test_pipeline_mixin import PipelineTesterMixin - if is_torch_available(): import torch @@ -50,147 +45,21 @@ if is_torch_available(): Qwen3Model, ) +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester -class Qwen3ModelTester: - def __init__( - self, - parent, - batch_size=13, - seq_length=7, - is_training=True, - use_input_mask=True, - use_token_type_ids=True, - use_labels=True, - vocab_size=99, - hidden_size=64, - num_hidden_layers=5, - max_window_layers=3, - use_sliding_window=True, - sliding_window=50, - num_attention_heads=4, - num_key_value_heads=2, - head_dim=16, - intermediate_size=37, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=16, - type_sequence_label_size=2, - initializer_range=0.02, - num_labels=3, - num_choices=4, - pad_token_id=0, - bos_token_id=1, - scope=None, - ): - self.parent = parent - self.batch_size = batch_size - self.seq_length = seq_length - self.is_training = is_training - self.use_input_mask = use_input_mask - self.use_token_type_ids = 
use_token_type_ids - self.use_labels = use_labels - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.max_window_layers = max_window_layers - self.use_sliding_window = use_sliding_window - self.sliding_window = sliding_window - self.num_attention_heads = num_attention_heads - self.num_key_value_heads = num_key_value_heads - self.head_dim = head_dim - self.intermediate_size = intermediate_size - self.hidden_act = hidden_act - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.type_sequence_label_size = type_sequence_label_size - self.initializer_range = initializer_range - self.num_labels = num_labels - self.num_choices = num_choices - self.pad_token_id = pad_token_id - self.bos_token_id = bos_token_id - self.scope = scope - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) - - input_mask = None - if self.use_input_mask: - input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) - - token_type_ids = None - if self.use_token_type_ids: - token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) - - sequence_labels = None - token_labels = None - choice_labels = None - if self.use_labels: - sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) - token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) - choice_labels = ids_tensor([self.batch_size], self.num_choices) - - config = self.get_config() - - return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - - def get_config(self): - return Qwen3Config( - vocab_size=self.vocab_size, - hidden_size=self.hidden_size, - num_hidden_layers=self.num_hidden_layers, - max_window_layers=self.max_window_layers, - use_sliding_window=self.use_sliding_window, - sliding_window=self.sliding_window, - num_attention_heads=self.num_attention_heads, - num_key_value_heads=self.num_key_value_heads, - head_dim=self.head_dim, - intermediate_size=self.intermediate_size, - hidden_act=self.hidden_act, - hidden_dropout_prob=self.hidden_dropout_prob, - attention_probs_dropout_prob=self.attention_probs_dropout_prob, - max_position_embeddings=self.max_position_embeddings, - type_vocab_size=self.type_vocab_size, - is_decoder=False, - initializer_range=self.initializer_range, - pad_token_id=self.pad_token_id, - bos_token_id=self.bos_token_id, - ) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Qwen3 - def create_and_check_model( - self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - ): - model = Qwen3Model(config=config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=input_mask) - result = model(input_ids) - self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common - def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() - ( - config, - input_ids, - token_type_ids, - input_mask, - sequence_labels, - token_labels, - 
choice_labels, - ) = config_and_inputs - inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} - return config, inputs_dict +class Qwen3ModelTester(CausalLMModelTester): + config_class = Qwen3Config + if is_torch_available(): + base_model_class = Qwen3Model + causal_lm_class = Qwen3ForCausalLM + sequence_class = Qwen3ForSequenceClassification + token_class = Qwen3ForTokenClassification + question_answering_class = Qwen3ForQuestionAnswering @require_torch -# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen3 -class Qwen3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): +class Qwen3ModelTest(CausalLMModelTest, unittest.TestCase): all_model_classes = ( ( Qwen3Model, @@ -202,21 +71,20 @@ class Qwen3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi if is_torch_available() else () ) + test_headmasking = False + test_pruning = False + model_tester_class = Qwen3ModelTester pipeline_model_mapping = ( { "feature-extraction": Qwen3Model, "text-classification": Qwen3ForSequenceClassification, "token-classification": Qwen3ForTokenClassification, "text-generation": Qwen3ForCausalLM, - "zero-shot": Qwen3ForSequenceClassification, "question-answering": Qwen3ForQuestionAnswering, } if is_torch_available() else {} ) - test_headmasking = False - test_pruning = False - fx_compatible = False # Broken by attention refactor cc @Cyrilvallez # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146 def is_pipeline_test_to_skip( @@ -231,82 +99,6 @@ class Qwen3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi ): return True - def setUp(self): - self.model_tester = Qwen3ModelTester(self) - self.config_tester = ConfigTester(self, config_class=Qwen3Config, hidden_size=37) - - def test_config(self): - self.config_tester.run_common_tests() - - def test_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model(*config_and_inputs) - - def test_model_various_embeddings(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - for type in ["absolute", "relative_key", "relative_key_query"]: - config_and_inputs[0].position_embedding_type = type - self.model_tester.create_and_check_model(*config_and_inputs) - - def test_torch_fx_output_loss(self): - super().test_torch_fx_output_loss() - - def test_Qwen3_sequence_classification_model(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = Qwen3ForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_Qwen3_sequence_classification_model_for_single_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "single_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], 
self.model_tester.type_sequence_label_size) - model = Qwen3ForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_Qwen3_sequence_classification_model_for_multi_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "multi_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor( - [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size - ).to(torch.float) - model = Qwen3ForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Qwen3,llama->Qwen3 - def test_Qwen3_token_classification_model(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) - model = Qwen3ForTokenClassification(config=config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=token_labels) - self.assertEqual( - result.logits.shape, - (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), - ) - @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test diff --git a/tests/models/qwen3_moe/test_modeling_qwen3_moe.py b/tests/models/qwen3_moe/test_modeling_qwen3_moe.py index af3cf160322..0ffb74c6c24 100644 --- a/tests/models/qwen3_moe/test_modeling_qwen3_moe.py +++ b/tests/models/qwen3_moe/test_modeling_qwen3_moe.py @@ -30,185 +30,33 @@ from transformers.testing_utils import ( torch_device, ) -from ...generation.test_utils import GenerationTesterMixin -from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, ids_tensor -from ...test_pipeline_mixin import PipelineTesterMixin - if is_torch_available(): import torch from transformers import ( + Qwen3ForQuestionAnswering, Qwen3MoeForCausalLM, Qwen3MoeForQuestionAnswering, Qwen3MoeForSequenceClassification, Qwen3MoeForTokenClassification, Qwen3MoeModel, ) +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester -class Qwen3MoeModelTester: - def __init__( - self, - parent, - batch_size=13, - seq_length=7, - is_training=True, - use_input_mask=True, - use_token_type_ids=True, - use_labels=True, - vocab_size=99, - hidden_size=64, - num_hidden_layers=5, - max_window_layers=3, - use_sliding_window=True, - sliding_window=50, - num_attention_heads=4, - num_key_value_heads=2, - head_dim=16, - intermediate_size=37, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - expert_interval=1, - moe_intermediate_size=12, - num_experts_per_tok=2, - num_experts=8, - norm_topk_prob=False, - output_router_logits=False, - router_aux_loss_coef=0.001, - type_vocab_size=16, - 
type_sequence_label_size=2, - initializer_range=0.02, - num_labels=3, - num_choices=4, - pad_token_id=0, - bos_token_id=1, - scope=None, - ): - self.parent = parent - self.batch_size = batch_size - self.seq_length = seq_length - self.is_training = is_training - self.use_input_mask = use_input_mask - self.use_token_type_ids = use_token_type_ids - self.use_labels = use_labels - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.max_window_layers = max_window_layers - self.use_sliding_window = use_sliding_window - self.sliding_window = sliding_window - self.num_attention_heads = num_attention_heads - self.num_key_value_heads = num_key_value_heads - self.head_dim = head_dim - self.intermediate_size = intermediate_size - self.hidden_act = hidden_act - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.type_sequence_label_size = type_sequence_label_size - self.initializer_range = initializer_range - self.num_labels = num_labels - self.num_choices = num_choices - self.pad_token_id = pad_token_id - self.bos_token_id = bos_token_id - self.scope = scope - self.expert_interval = expert_interval - self.moe_intermediate_size = moe_intermediate_size - self.num_experts_per_tok = num_experts_per_tok - self.num_experts = num_experts - self.norm_topk_prob = norm_topk_prob - self.output_router_logits = output_router_logits - self.router_aux_loss_coef = router_aux_loss_coef - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) - - input_mask = None - if self.use_input_mask: - input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) - - token_type_ids = None - if self.use_token_type_ids: - token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) - - sequence_labels = None - token_labels = None - choice_labels = None - if self.use_labels: - sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) - token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) - choice_labels = ids_tensor([self.batch_size], self.num_choices) - - config = self.get_config() - - return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - - def get_config(self): - return Qwen3MoeConfig( - vocab_size=self.vocab_size, - hidden_size=self.hidden_size, - num_hidden_layers=self.num_hidden_layers, - max_window_layers=self.max_window_layers, - use_sliding_window=self.use_sliding_window, - sliding_window=self.sliding_window, - num_attention_heads=self.num_attention_heads, - num_key_value_heads=self.num_key_value_heads, - head_dim=self.head_dim, - intermediate_size=self.intermediate_size, - hidden_act=self.hidden_act, - hidden_dropout_prob=self.hidden_dropout_prob, - attention_probs_dropout_prob=self.attention_probs_dropout_prob, - max_position_embeddings=self.max_position_embeddings, - expert_interval=self.expert_interval, - moe_intermediate_size=self.moe_intermediate_size, - num_experts_per_tok=self.num_experts_per_tok, - num_experts=self.num_experts, - norm_topk_prob=self.norm_topk_prob, - output_router_logits=self.output_router_logits, - router_aux_loss_coef=self.router_aux_loss_coef, - type_vocab_size=self.type_vocab_size, - 
is_decoder=False, - initializer_range=self.initializer_range, - pad_token_id=self.pad_token_id, - bos_token_id=self.bos_token_id, - ) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Qwen3Moe - def create_and_check_model( - self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - ): - model = Qwen3MoeModel(config=config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=input_mask) - result = model(input_ids) - self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common - def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() - ( - config, - input_ids, - token_type_ids, - input_mask, - sequence_labels, - token_labels, - choice_labels, - ) = config_and_inputs - inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} - return config, inputs_dict +class Qwen3MoeModelTester(CausalLMModelTester): + config_class = Qwen3MoeConfig + if is_torch_available(): + base_model_class = Qwen3MoeModel + causal_lm_class = Qwen3MoeForCausalLM + sequence_class = Qwen3MoeForSequenceClassification + token_class = Qwen3MoeForTokenClassification + question_answering_class = Qwen3MoeForQuestionAnswering @require_torch -# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Qwen3Moe -class Qwen3MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): +class Qwen3MoeModelTest(CausalLMModelTest, unittest.TestCase): all_model_classes = ( ( Qwen3MoeModel, @@ -226,15 +74,15 @@ class Qwen3MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM "text-classification": Qwen3MoeForSequenceClassification, "token-classification": Qwen3MoeForTokenClassification, "text-generation": Qwen3MoeForCausalLM, - "zero-shot": Qwen3MoeForSequenceClassification, - "question-answering": Qwen3MoeForQuestionAnswering, + "question-answering": Qwen3ForQuestionAnswering, } if is_torch_available() else {} ) + test_headmasking = False test_pruning = False - fx_compatible = False # Broken by attention refactor cc @Cyrilvallez + model_tester_class = Qwen3MoeModelTester # TODO (ydshieh): Check this. 
See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146 def is_pipeline_test_to_skip( @@ -249,82 +97,6 @@ class Qwen3MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM ): return True - def setUp(self): - self.model_tester = Qwen3MoeModelTester(self) - self.config_tester = ConfigTester(self, config_class=Qwen3MoeConfig, hidden_size=37) - - def test_config(self): - self.config_tester.run_common_tests() - - def test_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model(*config_and_inputs) - - def test_model_various_embeddings(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - for type in ["absolute", "relative_key", "relative_key_query"]: - config_and_inputs[0].position_embedding_type = type - self.model_tester.create_and_check_model(*config_and_inputs) - - def test_torch_fx_output_loss(self): - super().test_torch_fx_output_loss() - - def test_Qwen3Moe_sequence_classification_model(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = Qwen3MoeForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_Qwen3Moe_sequence_classification_model_for_single_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "single_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - model = Qwen3MoeForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - def test_Qwen3Moe_sequence_classification_model_for_multi_label(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - config.problem_type = "multi_label_classification" - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - sequence_labels = ids_tensor( - [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size - ).to(torch.float) - model = Qwen3MoeForSequenceClassification(config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Qwen3Moe,llama->Qwen3Moe - def test_Qwen3Moe_token_classification_model(self): - config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 3 - input_ids = input_dict["input_ids"] - attention_mask = input_ids.ne(1).to(torch_device) - token_labels = 
ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels) - model = Qwen3MoeForTokenClassification(config=config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=attention_mask, labels=token_labels) - self.assertEqual( - result.logits.shape, - (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels), - ) - @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test diff --git a/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py b/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py index 640c8cfa804..f3d8b15dde7 100644 --- a/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py +++ b/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py @@ -16,6 +16,7 @@ import unittest import pytest +from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoTokenizer, RecurrentGemmaConfig, is_torch_available, set_seed from transformers.testing_utils import ( @@ -27,151 +28,26 @@ from transformers.testing_utils import ( torch_device, ) -from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, ids_tensor -from ...test_pipeline_mixin import PipelineTesterMixin - if is_torch_available(): import torch - from transformers import RecurrentGemmaForCausalLM, RecurrentGemmaModel + from transformers import RecurrentGemmaConfig, RecurrentGemmaForCausalLM, RecurrentGemmaModel -class RecurrentGemmaModelTester: - def __init__( - self, - parent, - batch_size=13, - seq_length=12, - is_training=True, - use_input_mask=True, - use_token_type_ids=False, - use_labels=True, - num_hidden_layers=3, - vocab_size=99, - hidden_size=32, - intermediate_size=3 * 32, - num_attention_heads=2, - lru_width=2 * 32, - embeddings_scale_by_sqrt_dim=True, - attention_window_size=16, - conv1d_width=4, - logits_soft_cap=30.0, - rms_norm_eps=1e-6, - use_cache=True, - rope_theta=10000.0, - type_vocab_size=16, - type_sequence_label_size=2, - num_labels=3, - num_choices=4, - pad_token_id=0, - scope=None, - ): - self.parent = parent - self.batch_size = batch_size - self.seq_length = seq_length - self.is_training = is_training - self.use_input_mask = use_input_mask - self.use_token_type_ids = use_token_type_ids - self.use_labels = use_labels +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester - self.num_hidden_layers = num_hidden_layers - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_attention_heads = num_attention_heads - self.lru_width = lru_width if lru_width is not None else hidden_size - self.embeddings_scale_by_sqrt_dim = embeddings_scale_by_sqrt_dim - self.attention_window_size = attention_window_size - self.conv1d_width = conv1d_width - self.logits_soft_cap = logits_soft_cap - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.rope_theta = rope_theta - self.type_vocab_size = type_vocab_size - self.type_sequence_label_size = type_sequence_label_size - self.num_labels = num_labels - self.num_choices = num_choices - self.pad_token_id = pad_token_id - self.scope = scope - - # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.prepare_config_and_inputs - def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) - - input_mask = None - if self.use_input_mask: - input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) - - 
token_type_ids = None - if self.use_token_type_ids: - token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) - - sequence_labels = None - token_labels = None - choice_labels = None - if self.use_labels: - sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) - token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) - choice_labels = ids_tensor([self.batch_size], self.num_choices) - - config = self.get_config() - - return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - - def get_config(self): - return RecurrentGemmaConfig( - num_hidden_layers=self.num_hidden_layers, - vocab_size=self.vocab_size, - hidden_size=self.hidden_size, - intermediate_size=self.intermediate_size, - num_attention_heads=self.num_attention_heads, - lru_width=self.lru_width, - embeddings_scale_by_sqrt_dim=self.embeddings_scale_by_sqrt_dim, - attention_window_size=self.attention_window_size, - conv1d_width=self.conv1d_width, - logits_soft_cap=self.logits_soft_cap, - rms_norm_eps=self.rms_norm_eps, - use_cache=self.use_cache, - rope_theta=self.rope_theta, - pad_token_id=self.pad_token_id, - output_attentions=False, - ) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->RecurrentGemma - def create_and_check_model( - self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - ): - model = RecurrentGemmaModel(config=config) - model.to(torch_device) - model.eval() - result = model(input_ids, attention_mask=input_mask) - result = model(input_ids) - self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) - - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common with Llama->RecurrentGemma - def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() - ( - config, - input_ids, - token_type_ids, - input_mask, - sequence_labels, - token_labels, - choice_labels, - ) = config_and_inputs - inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} - return config, inputs_dict +class RecurrentGemmaModelTester(CausalLMModelTester): + config_class = RecurrentGemmaConfig + if is_torch_available(): + base_model_class = RecurrentGemmaModel + causal_lm_class = RecurrentGemmaForCausalLM @require_torch -class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): - all_model_classes = (RecurrentGemmaForCausalLM,) if is_torch_available() else () - # Doesn't run generation tests. 
TODO @gante not fully supported - all_generative_model_classes = () +class RecurrentGemmaModelTest(CausalLMModelTest, unittest.TestCase): + all_model_classes = (RecurrentGemmaModel, RecurrentGemmaForCausalLM) if is_torch_available() else () pipeline_model_mapping = ( { "feature-extraction": RecurrentGemmaModel, @@ -180,48 +56,10 @@ class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te if is_torch_available() else {} ) - fx_compatible = False # FIXME let's try to support this @ArthurZucker - test_torchscript = False # FIXME let's try to support this @ArthurZucker - test_missing_keys = False - test_model_parallel = False + test_headmasking = False test_pruning = False - test_head_masking = False # RecurrentGemma does not have attention heads - - # Need to remove 0.9 in `test_cpu_offload` - # This is because we are hitting edge cases with the causal_mask buffer - model_split_percents = [0.5, 0.6] - - # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146 - def is_pipeline_test_to_skip( - self, - pipeline_test_case_name, - config_class, - model_architecture, - tokenizer_name, - image_processor_name, - feature_extractor_name, - processor_name, - ): - return True - - def setUp(self): - # We don't output attentions - self.has_attentions = False - self.model_tester = RecurrentGemmaModelTester(self) - self.config_tester = ConfigTester(self, config_class=RecurrentGemmaConfig, hidden_size=37) - - def test_config(self): - self.config_tester.run_common_tests() - - def test_model(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_model(*config_and_inputs) - - def test_model_various_embeddings(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - for type in ["absolute", "relative_key", "relative_key_query"]: - config_and_inputs[0].position_embedding_type = type - self.model_tester.create_and_check_model(*config_and_inputs) + has_attentions = False + model_tester_class = RecurrentGemmaModelTester @unittest.skip(reason="RecurrentGemma only supports sdpa") def test_eager_matches_sdpa_generate(self): @@ -255,6 +93,7 @@ class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te def test_model_parallel_beam_search(self): pass + @parameterized.expand([("random",), ("same",)]) @pytest.mark.generate @unittest.skip(reason="Rely on `past_key_values` to crop the assistant pkv. 
Not supported") def test_assisted_decoding_matches_greedy_search(self): @@ -273,6 +112,65 @@ class RecurrentGemmaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te def test_initialization(self): pass + @unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests") + @pytest.mark.generate + def test_beam_sample_generate_dict_output(self): + pass + + @unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests") + @pytest.mark.generate + def test_beam_search_generate_dict_output(self): + pass + + @unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests") + @pytest.mark.generate + def test_beam_search_generate_dict_outputs_use_cache(self): + pass + + @unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests") + @pytest.mark.generate + def test_constrained_beam_search_generate_dict_output(self): + pass + + @unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests") + @pytest.mark.generate + def test_dola_decoding_sample(self): + pass + + @unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests") + @pytest.mark.generate + def test_generate_without_input_ids(self): + pass + + @unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests") + @pytest.mark.generate + def test_group_beam_search_generate(self): + pass + + @unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests") + @pytest.mark.generate + def test_group_beam_search_generate_dict_output(self): + pass + + @unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests") + @pytest.mark.generate + def test_constrained_beam_search_generate(self): + pass + + @unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests") + @pytest.mark.generate + def test_greedy_generate_dict_outputs(self): + pass + + @unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests") + @pytest.mark.generate + def test_greedy_generate_dict_outputs_use_cache(self): + pass + + @unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests") + def test_model_outputs_equivalence(self): + pass + @require_torch_accelerator @slow diff --git a/tests/models/stablelm/test_modeling_stablelm.py b/tests/models/stablelm/test_modeling_stablelm.py index a984d908eab..87555a7d774 100644 --- a/tests/models/stablelm/test_modeling_stablelm.py +++ b/tests/models/stablelm/test_modeling_stablelm.py @@ -16,9 +16,8 @@ import unittest import pytest -from parameterized import parameterized -from transformers import StableLmConfig, is_torch_available, set_seed +from transformers import StableLmConfig, is_torch_available from transformers.testing_utils import ( require_bitsandbytes, require_flash_attn, @@ -27,11 +26,6 @@ from transformers.testing_utils import ( torch_device, ) -from ...generation.test_utils import GenerationTesterMixin -from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, ids_tensor -from ...test_pipeline_mixin import PipelineTesterMixin - if is_torch_available(): import torch @@ -45,133 +39,27 @@ if is_torch_available(): ) from transformers.models.stablelm.modeling_stablelm import StableLmRotaryEmbedding +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester -# Copied from transformers.tests.models.persimmon.test_modeling_persimmon.PersimmonModelTester with Persimmon -> StableLm -class 
StableLmModelTester:
-    # Ignore copy
-    def __init__(
-        self,
-        parent,
-        batch_size=13,
-        seq_length=7,
-        is_training=True,
-        use_input_mask=True,
-        use_token_type_ids=False,
-        use_labels=True,
-        vocab_size=99,
-        hidden_size=64,
-        num_hidden_layers=2,
-        num_attention_heads=4,
-        num_key_value_heads=4,
-        intermediate_size=37,
-        hidden_act="gelu",
-        hidden_dropout_prob=0.1,
-        attention_probs_dropout_prob=0.1,
-        max_position_embeddings=512,
-        type_vocab_size=16,
-        type_sequence_label_size=2,
-        initializer_range=0.02,
-        num_labels=3,
-        num_choices=4,
-        pad_token_id=0,
-        scope=None,
-    ):
-        self.parent = parent
-        self.batch_size = batch_size
-        self.seq_length = seq_length
-        self.is_training = is_training
-        self.use_input_mask = use_input_mask
-        self.use_token_type_ids = use_token_type_ids
-        self.use_labels = use_labels
-        self.vocab_size = vocab_size
-        self.hidden_size = hidden_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-        self.num_key_value_heads = num_key_value_heads
-        self.intermediate_size = intermediate_size
-        self.hidden_act = hidden_act
-        self.hidden_dropout_prob = hidden_dropout_prob
-        self.attention_probs_dropout_prob = attention_probs_dropout_prob
-        self.max_position_embeddings = max_position_embeddings
-        self.type_vocab_size = type_vocab_size
-        self.type_sequence_label_size = type_sequence_label_size
-        self.initializer_range = initializer_range
-        self.num_labels = num_labels
-        self.num_choices = num_choices
-        self.pad_token_id = pad_token_id
-        self.scope = scope
-    def prepare_config_and_inputs(self):
-        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-
-        input_mask = None
-        if self.use_input_mask:
-            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
-
-        token_type_ids = None
-        if self.use_token_type_ids:
-            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
-
-        sequence_labels = None
-        token_labels = None
-        choice_labels = None
-        if self.use_labels:
-            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
-            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
-            choice_labels = ids_tensor([self.batch_size], self.num_choices)
-
-        config = self.get_config()
-
-        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-
-    def get_config(self):
-        return StableLmConfig(
-            vocab_size=self.vocab_size,
-            hidden_size=self.hidden_size,
-            num_hidden_layers=self.num_hidden_layers,
-            num_attention_heads=self.num_attention_heads,
-            num_key_value_heads=self.num_key_value_heads,
-            intermediate_size=self.intermediate_size,
-            hidden_act=self.hidden_act,
-            hidden_dropout_prob=self.hidden_dropout_prob,
-            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
-            max_position_embeddings=self.max_position_embeddings,
-            type_vocab_size=self.type_vocab_size,
-            is_decoder=False,
-            initializer_range=self.initializer_range,
-            pad_token_id=self.pad_token_id,
-        )
-
-    def create_and_check_model(
-        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-    ):
-        model = StableLmModel(config=config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=input_mask)
-        result = model(input_ids)
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
-
-    def prepare_config_and_inputs_for_common(self):
-        config_and_inputs = self.prepare_config_and_inputs()
-        (
-            config,
-            input_ids,
-            token_type_ids,
-            input_mask,
-            sequence_labels,
-            token_labels,
-            choice_labels,
-        ) = config_and_inputs
-        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
-        return config, inputs_dict
+class StableLmModelTester(CausalLMModelTester):
+    if is_torch_available():
+        config_class = StableLmConfig
+        base_model_class = StableLmModel
+        causal_lm_class = StableLmForCausalLM
+        sequence_class = StableLmForSequenceClassification
+        token_class = StableLmForTokenClassification
 
 
 @require_torch
-# Copied from transformers.tests.persimmon.test_modeling_persimmon.PersimmonModelTest with Persimmon -> StableLm
-class StableLmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class StableLmModelTest(CausalLMModelTest, unittest.TestCase):
     all_model_classes = (
-        (StableLmModel, StableLmForCausalLM, StableLmForSequenceClassification, StableLmForTokenClassification)
+        (
+            StableLmModel,
+            StableLmForCausalLM,
+            StableLmForSequenceClassification,
+            StableLmForTokenClassification,
+        )
         if is_torch_available()
         else ()
     )
@@ -179,167 +67,18 @@ class StableLmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
         {
             "feature-extraction": StableLmModel,
             "text-classification": StableLmForSequenceClassification,
+            "text-generation": StableLmForCausalLM,
+            "zero-shot": StableLmForSequenceClassification,
             "token-classification": StableLmForTokenClassification,
-            # TODO (ydshieh): check why these two fail. Fix them or skip them in a better way.
-            # "text-generation": StableLmForCausalLM,
-            # "zero-shot": StableLmForSequenceClassification,
         }
         if is_torch_available()
         else {}
     )
-    test_headmasking = False
     test_pruning = False
-
-    def setUp(self):
-        self.model_tester = StableLmModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=StableLmConfig, hidden_size=37)
-
-    def test_config(self):
-        self.config_tester.run_common_tests()
-
-    def test_model(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_model(*config_and_inputs)
-
-    def test_stablelm_sequence_classification_model(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
-        model = StableLmForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    def test_stablelm_sequence_classification_model_for_single_label(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        config.problem_type = "single_label_classification"
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
-        model = StableLmForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    def test_stablelm_sequence_classification_model_for_multi_label(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        config.problem_type = "multi_label_classification"
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor(
-            [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
-        ).to(torch.float)
-        model = StableLmForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->StableLm,llama->stablelm
-    def test_stablelm_token_classification_model(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
-        model = StableLmForTokenClassification(config=config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
-        self.assertEqual(
-            result.logits.shape,
-            (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
-        )
-
-    @parameterized.expand([("linear",), ("dynamic",)])
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->StableLm
-    def test_model_rope_scaling_from_config(self, scaling_type):
-        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-        short_input = ids_tensor([1, 10], config.vocab_size)
-        long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
-
-        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
-        original_model = StableLmModel(config)
-        original_model.to(torch_device)
-        original_model.eval()
-        original_short_output = original_model(short_input).last_hidden_state
-        original_long_output = original_model(long_input).last_hidden_state
-
-        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
-        config.rope_scaling = {"type": scaling_type, "factor": 10.0}
-        scaled_model = StableLmModel(config)
-        scaled_model.to(torch_device)
-        scaled_model.eval()
-        scaled_short_output = scaled_model(short_input).last_hidden_state
-        scaled_long_output = scaled_model(long_input).last_hidden_state
-
-        # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
-        # maximum sequence length, so the outputs for the short input should match.
-        if scaling_type == "dynamic":
-            torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5)
-        else:
-            self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
-
-        # The output should be different for long inputs
-        self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
-
-    # Copied from tests.models.gpt_neox.test_modeling_gpt_neox.GPTNeoXModelTest.test_model_rope_scaling with GPTNeoX->StableLm
-    def test_model_rope_scaling(self):
-        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-        scaling_factor = 10
-        short_input_length = 10
-        long_input_length = int(config.max_position_embeddings * 1.5)
-
-        # Inputs
-        x = torch.randn(
-            1, dtype=torch.float32, device=torch_device
-        )  # used exclusively to get the dtype and the device
-        position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
-        position_ids_short = position_ids_short.unsqueeze(0)
-        position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
-        position_ids_long = position_ids_long.unsqueeze(0)
-
-        # Sanity check original RoPE
-        original_rope = StableLmRotaryEmbedding(config).to(torch_device)
-        original_cos_short, original_sin_short = original_rope(x, position_ids_short)
-        original_cos_long, original_sin_long = original_rope(x, position_ids_long)
-        torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
-        torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])
-
-        # Sanity check linear RoPE scaling
-        # New position "x" should match original position with index "x/scaling_factor"
-        config.rope_scaling = {"type": "linear", "factor": scaling_factor}
-        linear_scaling_rope = StableLmRotaryEmbedding(config).to(torch_device)
-        linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
-        linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
-        torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
-        torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :])
-        for new_position in range(0, long_input_length, scaling_factor):
-            original_position = int(new_position // scaling_factor)
-            torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :])
-            torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :])
-
-        # Sanity check Dynamic NTK RoPE scaling
-        # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
-        # with scaling_factor (or that `inv_freq` decreases)
-        config.rope_scaling = {"type": "dynamic", "factor": scaling_factor}
-        ntk_scaling_rope = StableLmRotaryEmbedding(config).to(torch_device)
-        ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
-        ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
-        torch.testing.assert_close(ntk_cos_short, original_cos_short)
-        torch.testing.assert_close(ntk_sin_short, original_sin_short)
-        with self.assertRaises(AssertionError):
-            torch.testing.assert_close(ntk_cos_long, original_cos_long)
-        with self.assertRaises(AssertionError):
-            torch.testing.assert_close(ntk_sin_long, original_sin_long)
-        self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
+    fx_compatible = False  # Broken by attention refactor cc @Cyrilvallez
+    model_tester_class = StableLmModelTester
+    rotary_embedding_layer = StableLmRotaryEmbedding  # Enables RoPE tests if set
 
 
 @require_torch
diff --git a/tests/models/starcoder2/test_modeling_starcoder2.py b/tests/models/starcoder2/test_modeling_starcoder2.py
index dbc3c0dc807..956b210bae4 100644
--- a/tests/models/starcoder2/test_modeling_starcoder2.py
+++ b/tests/models/starcoder2/test_modeling_starcoder2.py
@@ -28,11 +28,6 @@ from transformers.testing_utils import (
     torch_device,
 )
 
-from ...generation.test_utils import GenerationTesterMixin
-from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, ids_tensor
-from ...test_pipeline_mixin import PipelineTesterMixin
-
 
 if is_torch_available():
     import torch
@@ -45,241 +40,38 @@ if is_torch_available():
         Starcoder2Model,
     )
 
+from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
 
-# Copied from transformers.tests.models.mistral.test_modeling_mistral.Starcoder2ModelTester with Mistral->Starcoder2
-class Starcoder2ModelTester:
-    def __init__(
-        self,
-        parent,
-        batch_size=13,
-        seq_length=7,
-        is_training=True,
-        use_input_mask=True,
-        use_token_type_ids=False,
-        use_labels=True,
-        vocab_size=99,
-        hidden_size=32,
-        num_hidden_layers=2,
-        num_attention_heads=4,
-        num_key_value_heads=2,
-        intermediate_size=37,
-        hidden_act="gelu",
-        hidden_dropout_prob=0.1,
-        attention_probs_dropout_prob=0.1,
-        max_position_embeddings=512,
-        type_vocab_size=16,
-        type_sequence_label_size=2,
-        initializer_range=0.02,
-        num_labels=3,
-        num_choices=4,
-        pad_token_id=0,
-        scope=None,
-    ):
-        self.parent = parent
-        self.batch_size = batch_size
-        self.seq_length = seq_length
-        self.is_training = is_training
-        self.use_input_mask = use_input_mask
-        self.use_token_type_ids = use_token_type_ids
-        self.use_labels = use_labels
-        self.vocab_size = vocab_size
-        self.hidden_size = hidden_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-        self.num_key_value_heads = num_key_value_heads
-        self.intermediate_size = intermediate_size
-        self.hidden_act = hidden_act
-        self.hidden_dropout_prob = hidden_dropout_prob
-        self.attention_probs_dropout_prob = attention_probs_dropout_prob
-        self.max_position_embeddings = max_position_embeddings
-        self.type_vocab_size = type_vocab_size
-        self.type_sequence_label_size = type_sequence_label_size
-        self.initializer_range = initializer_range
-        self.num_labels = num_labels
-        self.num_choices = num_choices
-        self.pad_token_id = pad_token_id
-        self.scope = scope
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs
-    def prepare_config_and_inputs(self):
-        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-
-        input_mask = None
-        if self.use_input_mask:
-            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
-
-        token_type_ids = None
-        if self.use_token_type_ids:
-            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
-
-        sequence_labels = None
-        token_labels = None
-        choice_labels = None
-        if self.use_labels:
-            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
-            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
-            choice_labels = ids_tensor([self.batch_size], self.num_choices)
-
-        config = self.get_config()
-
-        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-
-    # Ignore copy
-    def get_config(self):
-        return Starcoder2Config(
-            vocab_size=self.vocab_size,
-            hidden_size=self.hidden_size,
-            num_hidden_layers=self.num_hidden_layers,
-            num_attention_heads=self.num_attention_heads,
-            num_key_value_heads=self.num_key_value_heads,
-            intermediate_size=self.intermediate_size,
-            hidden_act=self.hidden_act,
-            hidden_dropout_prob=self.hidden_dropout_prob,
-            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
-            max_position_embeddings=self.max_position_embeddings,
-            type_vocab_size=self.type_vocab_size,
-            is_decoder=False,
-            initializer_range=self.initializer_range,
-            pad_token_id=self.pad_token_id,
-            eos_token_id=self.pad_token_id,
-            bos_token_id=self.pad_token_id,
-        )
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Starcoder2
-    def create_and_check_model(
-        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-    ):
-        model = Starcoder2Model(config=config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=input_mask)
-        result = model(input_ids)
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common
-    def prepare_config_and_inputs_for_common(self):
-        config_and_inputs = self.prepare_config_and_inputs()
-        (
-            config,
-            input_ids,
-            token_type_ids,
-            input_mask,
-            sequence_labels,
-            token_labels,
-            choice_labels,
-        ) = config_and_inputs
-        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
-        return config, inputs_dict
+class Starcoder2ModelTester(CausalLMModelTester):
+    config_class = Starcoder2Config
+    if is_torch_available():
+        base_model_class = Starcoder2Model
+        causal_lm_class = Starcoder2ForCausalLM
+        sequence_class = Starcoder2ForSequenceClassification
+        token_class = Starcoder2ForTokenClassification
 
 
 @require_torch
-# Copied from transformers.tests.models.mistral.test_modeling_mistral.MistralModelTest with Mistral->Starcoder2
-class Starcoder2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class Starcoder2ModelTest(CausalLMModelTest, unittest.TestCase):
     all_model_classes = (
         (Starcoder2Model, Starcoder2ForCausalLM, Starcoder2ForSequenceClassification, Starcoder2ForTokenClassification)
         if is_torch_available()
         else ()
     )
+    test_headmasking = False
+    test_pruning = False
+    model_tester_class = Starcoder2ModelTester
     pipeline_model_mapping = (
         {
             "feature-extraction": Starcoder2Model,
             "text-classification": Starcoder2ForSequenceClassification,
             "token-classification": Starcoder2ForTokenClassification,
             "text-generation": Starcoder2ForCausalLM,
-            "zero-shot": Starcoder2ForSequenceClassification,
         }
         if is_torch_available()
         else {}
     )
-    test_headmasking = False
-    test_pruning = False
-
-    # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
-    def is_pipeline_test_to_skip(
-        self,
-        pipeline_test_case_name,
-        config_class,
-        model_architecture,
-        tokenizer_name,
-        image_processor_name,
-        feature_extractor_name,
-        processor_name,
-    ):
-        return True
-
-    def setUp(self):
-        self.model_tester = Starcoder2ModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=Starcoder2Config, hidden_size=37)
-
-    def test_config(self):
-        self.config_tester.run_common_tests()
-
-    def test_model(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_model(*config_and_inputs)
-
-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
-    def test_Starcoder2_sequence_classification_model(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        print(config)
-        config.num_labels = 3
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
-        model = Starcoder2ForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    def test_Starcoder2_sequence_classification_model_for_single_label(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        config.problem_type = "single_label_classification"
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
-        model = Starcoder2ForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    def test_Starcoder2_sequence_classification_model_for_multi_label(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        config.problem_type = "multi_label_classification"
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        sequence_labels = ids_tensor(
-            [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
-        ).to(torch.float)
-        model = Starcoder2ForSequenceClassification(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
-        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Starcoder2,llama->Starcoder2
-    def test_Starcoder2_token_classification_model(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.num_labels = 3
-        input_ids = input_dict["input_ids"]
-        attention_mask = input_ids.ne(1).to(torch_device)
-        token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
-        model = Starcoder2ForTokenClassification(config=config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
-        self.assertEqual(
-            result.logits.shape,
-            (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
-        )
 
     @require_flash_attn
     @require_torch_gpu
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 621ab67da03..7bd494e452b 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -4426,7 +4426,7 @@ class ModelTesterMixin:
         # comparing softmax-normalized logits:
         normalized_0 = F.softmax(out_last_tokens, dim=-1)
         normalized_1 = F.softmax(out_shared_prefix_last_tokens, dim=-1)
-        torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
+        torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-3)
 
     @slow
     @require_torch_accelerator
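
Note on the migration pattern: every per-model hunk above follows the same recipe. The hand-written ModelTester and its copy-pasted tests are deleted, the model-specific tester only declares its config and model classes, and the test class points at that tester through model_tester_class. Below is a minimal sketch of what a migrated test file looks like, using the Llama classes purely as an example; it mirrors the StableLm and Starcoder2 hunks above and is not the literal content of tests/models/llama/test_modeling_llama.py.

# Sketch only: follows the StableLm/Starcoder2 pattern from this patch, not the actual Llama test file.
import unittest

from transformers import LlamaConfig, is_torch_available
from transformers.testing_utils import require_torch

from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester


if is_torch_available():
    from transformers import (
        LlamaForCausalLM,
        LlamaForSequenceClassification,
        LlamaForTokenClassification,
        LlamaModel,
    )


class LlamaModelTester(CausalLMModelTester):
    # Declare only the model-specific classes; input preparation, config construction
    # and the common checks are inherited from CausalLMModelTester.
    config_class = LlamaConfig
    if is_torch_available():
        base_model_class = LlamaModel
        causal_lm_class = LlamaForCausalLM
        sequence_class = LlamaForSequenceClassification
        token_class = LlamaForTokenClassification


@require_torch
class LlamaModelTest(CausalLMModelTest, unittest.TestCase):
    all_model_classes = (
        (LlamaModel, LlamaForCausalLM, LlamaForSequenceClassification, LlamaForTokenClassification)
        if is_torch_available()
        else ()
    )
    pipeline_model_mapping = (
        {
            "feature-extraction": LlamaModel,
            "text-classification": LlamaForSequenceClassification,
            "text-generation": LlamaForCausalLM,
            "token-classification": LlamaForTokenClassification,
            "zero-shot": LlamaForSequenceClassification,
        }
        if is_torch_available()
        else {}
    )
    test_headmasking = False
    test_pruning = False
    model_tester_class = LlamaModelTester  # CausalLMModelTest builds configs and inputs through this

Models that expose a rotary embedding layer can additionally set rotary_embedding_layer, as the StableLm test does above, which opts them into the shared RoPE scaling tests ("Enables RoPE tests if set").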