diff --git a/tests/causal_lm_tester.py b/tests/causal_lm_tester.py
index f41d3ab6e32..9807c885605 100644
--- a/tests/causal_lm_tester.py
+++ b/tests/causal_lm_tester.py
@@ -181,11 +181,18 @@ class CausalLMModelTester:
 
         return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
 
+    @property
+    def config_args(self):
+        return list(signature(self.config_class.__init__).parameters.keys())
+
     def get_config(self):
-        kwarg_names = list(signature(self.config_class.__init__).parameters.keys())
-        kwargs = {
-            k: getattr(self, k) for k in kwarg_names + self.forced_config_args if hasattr(self, k) and k != "self"
-        }
+        kwargs = {}
+        model_name_to_common_name = {v: k for k, v in self.config_class.attribute_map.items()}
+        for k in self.config_args + self.forced_config_args:
+            if hasattr(self, k) and k != "self":
+                kwargs[k] = getattr(self, k)
+            elif k in model_name_to_common_name and hasattr(self, model_name_to_common_name[k]):
+                kwargs[k] = getattr(self, model_name_to_common_name[k])
         return self.config_class(**kwargs)
 
     def create_and_check_model(
diff --git a/tests/models/dbrx/test_modeling_dbrx.py b/tests/models/dbrx/test_modeling_dbrx.py
index 7b6cfc5b081..e89740db616 100644
--- a/tests/models/dbrx/test_modeling_dbrx.py
+++ b/tests/models/dbrx/test_modeling_dbrx.py
@@ -16,12 +16,9 @@
 import unittest
 
 from transformers import DbrxConfig, is_torch_available
-from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.testing_utils import require_torch, slow
 
-from ...generation.test_utils import GenerationTesterMixin
-from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
-from ...test_pipeline_mixin import PipelineTesterMixin
+from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
 
 
 if is_torch_available():
@@ -30,197 +27,74 @@ if is_torch_available():
     from transformers import DbrxForCausalLM, DbrxModel
 
 
-class DbrxModelTester:
+class DbrxModelTester(CausalLMModelTester):
+    config_class = DbrxConfig
+    if is_torch_available():
+        base_model_class = DbrxModel
+        causal_lm_class = DbrxForCausalLM
+
     def __init__(
         self,
         parent,
-        hidden_size=32,
-        ffn_hidden_size=32,
-        num_attention_heads=4,
-        kv_n_heads=4,
-        num_hidden_layers=5,
-        max_position_embeddings=512,
-        type_vocab_size=16,
-        batch_size=13,
-        seq_length=7,
-        is_training=True,
-        use_input_mask=True,
-        use_token_type_ids=False,
-        use_labels=True,
-        use_cache=True,
-        type_sequence_label_size=2,
-        num_labels=3,
-        num_choices=4,
-        scope=None,
         clip_qkv=8,
         rope_theta=500000,
         attn_config_model_type="",
-        emb_pdrop=0.0,
         moe_jitter_eps=0,
         moe_loss_weight=0.05,
-        moe_num_experts=16,
+        moe_num_experts=8,
         moe_top_k=4,
         ffn_config_model_type="",
-        ffn_act_fn_name="gelu",
         initializer_range=0.02,
-        output_router_logits=False,
         resid_pdrop=0.0,
-        tie_word_embeddings=False,
-        torch_dtype="bfloat16",
-        vocab_size=99,
         is_decoder=True,
         pad_token_id=0,
     ):
-        # Parameters unique to testing
-        self.batch_size = batch_size
-        self.seq_length = seq_length
-        self.type_vocab_size = type_vocab_size
-        self.type_sequence_label_size = type_sequence_label_size
-        self.num_labels = num_labels
-        self.num_choices = num_choices
-        self.scope = scope
-        self.parent = parent
-        self.is_training = is_training
-        self.use_input_mask = use_input_mask
-        self.use_token_type_ids = use_token_type_ids
-        self.use_labels = use_labels
+        # Call parent init
+        super().__init__(
+            parent=parent,
+            hidden_dropout_prob=resid_pdrop,
+            attention_probs_dropout_prob=resid_pdrop,
+            initializer_range=initializer_range,
+            pad_token_id=pad_token_id,
+            is_decoder=is_decoder,
+        )
 
-        # attn_config params
+        # Set DBRX's unusual params
         self.clip_qkv = clip_qkv
-        self.kv_n_heads = kv_n_heads
-        self.rope_theta = rope_theta
-        self.attn_config_model_type = attn_config_model_type
 
-        # ffn_config params
-        self.ffn_hidden_size = ffn_hidden_size
-        self.moe_jitter_eps = moe_jitter_eps
-        self.moe_loss_weight = moe_loss_weight
-        self.moe_num_experts = moe_num_experts
-        self.moe_top_k = moe_top_k
-        self.ffn_config_model_type = ffn_config_model_type
-        self.ffn_act_fn_name = ffn_act_fn_name
-
-        # Other model params
-        self.hidden_size = hidden_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-        self.max_position_embeddings = max_position_embeddings
-        self.vocab_size = vocab_size
-        self.use_cache = use_cache
-        self.initializer_range = initializer_range
-        self.emb_pdrop = emb_pdrop
-        self.output_router_logits = output_router_logits
-        self.resid_pdrop = resid_pdrop
-        self.tie_word_embeddings = tie_word_embeddings
-        self.torch_dtype = torch_dtype
-        self.is_decoder = is_decoder
-        self.pad_token_id = pad_token_id
-
-        # Make the dictionaries
+        # DBRX takes sub-configurations for the FFN and attention layers, so we need to set that correctly here
         self.ffn_config = {
-            "ffn_hidden_size": self.ffn_hidden_size,
-            "moe_jitter_eps": self.moe_jitter_eps,
-            "moe_loss_weight": self.moe_loss_weight,
-            "moe_num_experts": self.moe_num_experts,
-            "moe_top_k": self.moe_top_k,
-            "model_type": self.ffn_config_model_type,
-            "ffn_act_fn": {"name": self.ffn_act_fn_name},
+            "ffn_hidden_size": self.hidden_size,
+            "moe_jitter_eps": moe_jitter_eps,
+            "moe_loss_weight": moe_loss_weight,
+            "moe_num_experts": moe_num_experts,
+            "moe_top_k": moe_top_k,
+            "model_type": ffn_config_model_type,
+            "ffn_act_fn": {"name": self.hidden_act},
         }
         self.attn_config = {
-            "clip_qkv": self.clip_qkv,
-            "kv_n_heads": self.kv_n_heads,
-            "model_type": self.attn_config_model_type,
-            "rope_theta": self.rope_theta,
+            "clip_qkv": clip_qkv,
+            "model_type": attn_config_model_type,
+            "rope_theta": rope_theta,
         }
 
-    def prepare_config_and_inputs(self):
-        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-
-        input_mask = None
-        if self.use_input_mask:
-            input_mask = random_attention_mask([self.batch_size, self.seq_length])
-
-        token_type_ids = None
-        if self.use_token_type_ids:
-            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
-
-        sequence_labels = None
-        token_labels = None
-        choice_labels = None
-        if self.use_labels:
-            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
-            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
-            choice_labels = ids_tensor([self.batch_size], self.num_choices)
-
-        config = self.get_config()
-
-        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-
-    def get_config(self):
-        # Behind the scenes, `DbrxConfig` maps the parameters `hidden_size`, `num_hidden_layers`,
-        # `num_attention_heads`, `max_position_embeddings` to the parameters `d_model`, `n_layers`,
-        # `n_heads`, `max_seq_len` respectively. We use the first group of parameters because
-        # other tests expect every model to have these parameters with these specific names.
-        config = DbrxConfig(
-            vocab_size=self.vocab_size,
-            hidden_size=self.hidden_size,  # mapped to `d_model`
-            num_hidden_layers=self.num_hidden_layers,  # mapped to `n_layers`
-            num_attention_heads=self.num_attention_heads,  # mapped to `n_heads`
-            max_position_embeddings=self.max_position_embeddings,  # mapped to `max_seq_len`
-            attn_config=self.attn_config,
-            ffn_config=self.ffn_config,
-            resid_pdrop=self.resid_pdrop,
-            emb_pdrop=self.emb_pdrop,
-            use_cache=self.use_cache,
-            initializer_range=self.initializer_range,
-            output_router_logits=self.output_router_logits,
-            is_decoder=self.is_decoder,
-            pad_token_id=self.pad_token_id,
-        )
-        return config
-
-    def create_and_check_model(
-        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-    ):
-        model = DbrxModel(config=config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=input_mask)
-        result = model(input_ids)
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
-
-    def prepare_config_and_inputs_for_common(self):
-        config_and_inputs = self.prepare_config_and_inputs()
-        (
-            config,
-            input_ids,
-            token_type_ids,
-            input_mask,
-            sequence_labels,
-            token_labels,
-            choice_labels,
-        ) = config_and_inputs
-        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
-        return config, inputs_dict
+    @property
+    def config_args(self):
+        return super().config_args + ["ffn_config", "attn_config"]
 
 
 @require_torch
-class DbrxModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class DbrxModelTest(CausalLMModelTest, unittest.TestCase):
     all_model_classes = (DbrxModel, DbrxForCausalLM) if is_torch_available() else ()
-    pipeline_model_mapping = {"text-generation": DbrxForCausalLM} if is_torch_available() else {}
-    test_headmasking = False
-    test_pruning = False
-
-    def setUp(self):
-        self.model_tester = DbrxModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=DbrxConfig, d_model=37)
-
-    def test_config(self):
-        self.config_tester.run_common_tests()
-
-    def test_model(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_model(*config_and_inputs)
+    pipeline_model_mapping = (
+        {
+            "feature-extraction": DbrxModel,
+            "text-generation": DbrxForCausalLM,
+        }
+        if is_torch_available()
+        else {}
+    )
+    model_tester_class = DbrxModelTester
 
     def test_model_various_embeddings(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
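
For reviewers, a minimal sketch of the lookup the reworked `get_config()` performs. It is illustrative only, not part of the patch, and it assumes `DbrxConfig.attribute_map` carries the `hidden_size` -> `d_model` style mapping described in the comment removed from the old DBRX `get_config()`:

    from inspect import signature

    from transformers import DbrxConfig

    # attribute_map maps common attribute names to model-specific ones
    # (assumed here, e.g. "hidden_size" -> "d_model"); reversing it tells the
    # tester which of its own attributes to fall back to when the config's
    # __init__ only accepts the model-specific keyword.
    model_name_to_common_name = {v: k for k, v in DbrxConfig.attribute_map.items()}
    config_args = list(signature(DbrxConfig.__init__).parameters.keys())

    # Expected (under the assumed mapping): "d_model" is a config kwarg, and its
    # common-name fallback resolves to "hidden_size", which the tester does define.
    print("d_model" in config_args)
    print(model_name_to_common_name.get("d_model"))

This is why DBRX only has to expose `ffn_config` and `attn_config` as extra `config_args`; the common-name attributes set by `CausalLMModelTester` cover the rest.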