mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 04:40:06 +06:00
Refactor DBRX tests to use CausalLMModelTest base classes (#38475)
* Refactor DBRX tests to use CausalLMModelTest base classes - Changed DbrxModelTester to inherit from CausalLMModelTester - Changed DbrxModelTest to inherit from CausalLMModelTest - Removed duplicate methods that are already in base classes - Added required class attributes for model classes - Updated pipeline_model_mapping to include feature-extraction - Kept DBRX-specific configuration and test methods - Disabled RoPE tests as DBRX's rotary embedding doesn't accept config parameter This refactoring reduces code duplication and follows the pattern established in other causal LM model tests like Gemma. * Apply style fixes * Trigger tests * Refactor DBRX test * Make sure the DBRX-specific settings are handled * Use the attribute_map * Fix attribute map --------- Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
64041694a8
commit
b82a45b3b4
@ -181,11 +181,18 @@ class CausalLMModelTester:
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
||||
@property
|
||||
def config_args(self):
|
||||
return list(signature(self.config_class.__init__).parameters.keys())
|
||||
|
||||
def get_config(self):
|
||||
kwarg_names = list(signature(self.config_class.__init__).parameters.keys())
|
||||
kwargs = {
|
||||
k: getattr(self, k) for k in kwarg_names + self.forced_config_args if hasattr(self, k) and k != "self"
|
||||
}
|
||||
kwargs = {}
|
||||
model_name_to_common_name = {v: k for k, v in self.config_class.attribute_map.items()}
|
||||
for k in self.config_args + self.forced_config_args:
|
||||
if hasattr(self, k) and k != "self":
|
||||
kwargs[k] = getattr(self, k)
|
||||
elif k in model_name_to_common_name and hasattr(self, model_name_to_common_name[k]):
|
||||
kwargs[k] = getattr(self, model_name_to_common_name[k])
|
||||
return self.config_class(**kwargs)
|
||||
|
||||
def create_and_check_model(
|
||||
|
@ -16,12 +16,9 @@
|
||||
import unittest
|
||||
|
||||
from transformers import DbrxConfig, is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
from transformers.testing_utils import require_torch, slow
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@ -30,197 +27,74 @@ if is_torch_available():
|
||||
from transformers import DbrxForCausalLM, DbrxModel
|
||||
|
||||
|
||||
class DbrxModelTester:
|
||||
class DbrxModelTester(CausalLMModelTester):
|
||||
config_class = DbrxConfig
|
||||
if is_torch_available():
|
||||
base_model_class = DbrxModel
|
||||
causal_lm_class = DbrxForCausalLM
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
hidden_size=32,
|
||||
ffn_hidden_size=32,
|
||||
num_attention_heads=4,
|
||||
kv_n_heads=4,
|
||||
num_hidden_layers=5,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=False,
|
||||
use_labels=True,
|
||||
use_cache=True,
|
||||
type_sequence_label_size=2,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
scope=None,
|
||||
clip_qkv=8,
|
||||
rope_theta=500000,
|
||||
attn_config_model_type="",
|
||||
emb_pdrop=0.0,
|
||||
moe_jitter_eps=0,
|
||||
moe_loss_weight=0.05,
|
||||
moe_num_experts=16,
|
||||
moe_num_experts=8,
|
||||
moe_top_k=4,
|
||||
ffn_config_model_type="",
|
||||
ffn_act_fn_name="gelu",
|
||||
initializer_range=0.02,
|
||||
output_router_logits=False,
|
||||
resid_pdrop=0.0,
|
||||
tie_word_embeddings=False,
|
||||
torch_dtype="bfloat16",
|
||||
vocab_size=99,
|
||||
is_decoder=True,
|
||||
pad_token_id=0,
|
||||
):
|
||||
# Parameters unique to testing
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.scope = scope
|
||||
self.parent = parent
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.use_labels = use_labels
|
||||
# Call parent init
|
||||
super().__init__(
|
||||
parent=parent,
|
||||
hidden_dropout_prob=resid_pdrop,
|
||||
attention_probs_dropout_prob=resid_pdrop,
|
||||
initializer_range=initializer_range,
|
||||
pad_token_id=pad_token_id,
|
||||
is_decoder=is_decoder,
|
||||
)
|
||||
|
||||
# attn_config params
|
||||
# Set DBRX's unusual params
|
||||
self.clip_qkv = clip_qkv
|
||||
self.kv_n_heads = kv_n_heads
|
||||
self.rope_theta = rope_theta
|
||||
self.attn_config_model_type = attn_config_model_type
|
||||
|
||||
# ffn_config params
|
||||
self.ffn_hidden_size = ffn_hidden_size
|
||||
self.moe_jitter_eps = moe_jitter_eps
|
||||
self.moe_loss_weight = moe_loss_weight
|
||||
self.moe_num_experts = moe_num_experts
|
||||
self.moe_top_k = moe_top_k
|
||||
self.ffn_config_model_type = ffn_config_model_type
|
||||
self.ffn_act_fn_name = ffn_act_fn_name
|
||||
|
||||
# Other model params
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.vocab_size = vocab_size
|
||||
self.use_cache = use_cache
|
||||
self.initializer_range = initializer_range
|
||||
self.emb_pdrop = emb_pdrop
|
||||
self.output_router_logits = output_router_logits
|
||||
self.resid_pdrop = resid_pdrop
|
||||
self.tie_word_embeddings = tie_word_embeddings
|
||||
self.torch_dtype = torch_dtype
|
||||
self.is_decoder = is_decoder
|
||||
self.pad_token_id = pad_token_id
|
||||
|
||||
# Make the dictionaries
|
||||
# DBRX takes sub-configurations for the FFN and attention layers, so we need to set that correctly here
|
||||
self.ffn_config = {
|
||||
"ffn_hidden_size": self.ffn_hidden_size,
|
||||
"moe_jitter_eps": self.moe_jitter_eps,
|
||||
"moe_loss_weight": self.moe_loss_weight,
|
||||
"moe_num_experts": self.moe_num_experts,
|
||||
"moe_top_k": self.moe_top_k,
|
||||
"model_type": self.ffn_config_model_type,
|
||||
"ffn_act_fn": {"name": self.ffn_act_fn_name},
|
||||
"ffn_hidden_size": self.hidden_size,
|
||||
"moe_jitter_eps": moe_jitter_eps,
|
||||
"moe_loss_weight": moe_loss_weight,
|
||||
"moe_num_experts": moe_num_experts,
|
||||
"moe_top_k": moe_top_k,
|
||||
"model_type": ffn_config_model_type,
|
||||
"ffn_act_fn": {"name": self.hidden_act},
|
||||
}
|
||||
self.attn_config = {
|
||||
"clip_qkv": self.clip_qkv,
|
||||
"kv_n_heads": self.kv_n_heads,
|
||||
"model_type": self.attn_config_model_type,
|
||||
"rope_theta": self.rope_theta,
|
||||
"clip_qkv": clip_qkv,
|
||||
"model_type": attn_config_model_type,
|
||||
"rope_theta": rope_theta,
|
||||
}
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||
|
||||
sequence_labels = None
|
||||
token_labels = None
|
||||
choice_labels = None
|
||||
if self.use_labels:
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
||||
def get_config(self):
|
||||
# Behind the scenes, `DbrxConfig` maps the parameters `hidden_size`, `num_hidden_layers`,
|
||||
# `num_attention_heads`, `max_position_embeddings` to the parameters `d_model`, `n_layers`,
|
||||
# `n_heads`, `max_seq_len` respectively. We use the first group of parameters because
|
||||
# other tests expect every model to have these parameters with these specific names.
|
||||
config = DbrxConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size, # mapped to `d_model`
|
||||
num_hidden_layers=self.num_hidden_layers, # mapped to `n_layers`
|
||||
num_attention_heads=self.num_attention_heads, # mapped to `n_heads`
|
||||
max_position_embeddings=self.max_position_embeddings, # mapped to `max_seq_len`
|
||||
attn_config=self.attn_config,
|
||||
ffn_config=self.ffn_config,
|
||||
resid_pdrop=self.resid_pdrop,
|
||||
emb_pdrop=self.emb_pdrop,
|
||||
use_cache=self.use_cache,
|
||||
initializer_range=self.initializer_range,
|
||||
output_router_logits=self.output_router_logits,
|
||||
is_decoder=self.is_decoder,
|
||||
pad_token_id=self.pad_token_id,
|
||||
)
|
||||
return config
|
||||
|
||||
def create_and_check_model(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = DbrxModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask)
|
||||
result = model(input_ids)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
|
||||
return config, inputs_dict
|
||||
@property
|
||||
def config_args(self):
|
||||
return super().config_args + ["ffn_config", "attn_config"]
|
||||
|
||||
|
||||
@require_torch
|
||||
class DbrxModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
class DbrxModelTest(CausalLMModelTest, unittest.TestCase):
|
||||
all_model_classes = (DbrxModel, DbrxForCausalLM) if is_torch_available() else ()
|
||||
pipeline_model_mapping = {"text-generation": DbrxForCausalLM} if is_torch_available() else {}
|
||||
test_headmasking = False
|
||||
test_pruning = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = DbrxModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=DbrxConfig, d_model=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": DbrxModel,
|
||||
"text-generation": DbrxForCausalLM,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
model_tester_class = DbrxModelTester
|
||||
|
||||
def test_model_various_embeddings(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
|
Loading…
Reference in New Issue
Block a user