Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-03 21:00:08 +06:00
Refactor DBRX tests to use CausalLMModelTest base classes (#38475)
* Refactor DBRX tests to use CausalLMModelTest base classes

  - Changed DbrxModelTester to inherit from CausalLMModelTester
  - Changed DbrxModelTest to inherit from CausalLMModelTest
  - Removed duplicate methods that are already in the base classes
  - Added the required class attributes for the model classes
  - Updated pipeline_model_mapping to include feature-extraction
  - Kept the DBRX-specific configuration and test methods
  - Disabled RoPE tests, as DBRX's rotary embedding doesn't accept a config parameter

  This refactoring reduces code duplication and follows the pattern established in other causal LM model tests, such as Gemma.

* Apply style fixes

* Trigger tests

* Refactor DBRX test

* Make sure the DBRX-specific settings are handled

* Use the attribute_map

* Fix attribute map

---------

Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent 64041694a8
commit b82a45b3b4
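The pattern described in the commit message, reduced to its skeleton: a model's tester and test class shrink to a few class attributes on top of the shared base classes. A minimal sketch of that shape, using hypothetical MyConfig / MyModel / MyForCausalLM placeholders rather than names from this commit:

import unittest

from transformers import is_torch_available
from transformers.testing_utils import require_torch

from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester

# Hypothetical classes standing in for a real architecture.
from transformers import MyConfig, MyForCausalLM, MyModel


class MyModelTester(CausalLMModelTester):
    # The base tester reads these attributes to build configs and instantiate models.
    config_class = MyConfig
    if is_torch_available():
        base_model_class = MyModel
        causal_lm_class = MyForCausalLM


@require_torch
class MyModelTest(CausalLMModelTest, unittest.TestCase):
    all_model_classes = (MyModel, MyForCausalLM) if is_torch_available() else ()
    pipeline_model_mapping = (
        {"feature-extraction": MyModel, "text-generation": MyForCausalLM} if is_torch_available() else {}
    )
    model_tester_class = MyModelTester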
tests/causal_lm_tester.py
@@ -181,11 +181,18 @@ class CausalLMModelTester:
         return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
 
+    @property
+    def config_args(self):
+        return list(signature(self.config_class.__init__).parameters.keys())
+
     def get_config(self):
-        kwarg_names = list(signature(self.config_class.__init__).parameters.keys())
-        kwargs = {
-            k: getattr(self, k) for k in kwarg_names + self.forced_config_args if hasattr(self, k) and k != "self"
-        }
+        kwargs = {}
+        model_name_to_common_name = {v: k for k, v in self.config_class.attribute_map.items()}
+        for k in self.config_args + self.forced_config_args:
+            if hasattr(self, k) and k != "self":
+                kwargs[k] = getattr(self, k)
+            elif k in model_name_to_common_name and hasattr(self, model_name_to_common_name[k]):
+                kwargs[k] = getattr(self, model_name_to_common_name[k])
         return self.config_class(**kwargs)
 
     def create_and_check_model(
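The hunk above is the heart of the "Use the attribute_map" commits. Hugging Face configs can declare an attribute_map that aliases common attribute names to model-specific ones; for DbrxConfig it ties hidden_size, num_hidden_layers, num_attention_heads, and max_position_embeddings to d_model, n_layers, n_heads, and max_seq_len (the comments in the old DBRX get_config, removed further down, spell out the same correspondence). Inverting the map lets get_config fill a model-specific constructor argument from a tester attribute stored under the common name. A self-contained toy reproduction of just that lookup:

# Toy version of the resolution loop above; the names mimic DbrxConfig,
# but nothing here depends on transformers.
attribute_map = {"hidden_size": "d_model", "num_hidden_layers": "n_layers"}
model_name_to_common_name = {v: k for k, v in attribute_map.items()}


class ToyTester:
    hidden_size = 32       # stored under the common name...
    num_hidden_layers = 5  # ...so "d_model" and "n_layers" need the fallback


tester = ToyTester()
config_args = ["d_model", "n_layers"]  # what inspecting the config __init__ would yield

kwargs = {}
for k in config_args:
    if hasattr(tester, k) and k != "self":
        kwargs[k] = getattr(tester, k)
    elif k in model_name_to_common_name and hasattr(tester, model_name_to_common_name[k]):
        # "d_model" is not on the tester, but its common alias "hidden_size" is
        kwargs[k] = getattr(tester, model_name_to_common_name[k])

print(kwargs)  # {'d_model': 32, 'n_layers': 5}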
tests/models/dbrx/test_modeling_dbrx.py
@@ -16,12 +16,9 @@
 import unittest
 
 from transformers import DbrxConfig, is_torch_available
-from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.testing_utils import require_torch, slow
 
-from ...generation.test_utils import GenerationTesterMixin
-from ...test_configuration_common import ConfigTester
-from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
-from ...test_pipeline_mixin import PipelineTesterMixin
+from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester
 
 
 if is_torch_available():
@@ -30,197 +27,74 @@ if is_torch_available():
     from transformers import DbrxForCausalLM, DbrxModel
 
 
-class DbrxModelTester:
+class DbrxModelTester(CausalLMModelTester):
+    config_class = DbrxConfig
+    if is_torch_available():
+        base_model_class = DbrxModel
+        causal_lm_class = DbrxForCausalLM
+
     def __init__(
         self,
         parent,
-        hidden_size=32,
-        ffn_hidden_size=32,
-        num_attention_heads=4,
-        kv_n_heads=4,
-        num_hidden_layers=5,
-        max_position_embeddings=512,
-        type_vocab_size=16,
-        batch_size=13,
-        seq_length=7,
-        is_training=True,
-        use_input_mask=True,
-        use_token_type_ids=False,
-        use_labels=True,
-        use_cache=True,
-        type_sequence_label_size=2,
-        num_labels=3,
-        num_choices=4,
-        scope=None,
         clip_qkv=8,
         rope_theta=500000,
         attn_config_model_type="",
-        emb_pdrop=0.0,
         moe_jitter_eps=0,
         moe_loss_weight=0.05,
-        moe_num_experts=16,
+        moe_num_experts=8,
         moe_top_k=4,
         ffn_config_model_type="",
-        ffn_act_fn_name="gelu",
         initializer_range=0.02,
-        output_router_logits=False,
         resid_pdrop=0.0,
-        tie_word_embeddings=False,
-        torch_dtype="bfloat16",
-        vocab_size=99,
         is_decoder=True,
         pad_token_id=0,
     ):
-        # Parameters unique to testing
-        self.batch_size = batch_size
-        self.seq_length = seq_length
-        self.type_vocab_size = type_vocab_size
-        self.type_sequence_label_size = type_sequence_label_size
-        self.num_labels = num_labels
-        self.num_choices = num_choices
-        self.scope = scope
-        self.parent = parent
-        self.is_training = is_training
-        self.use_input_mask = use_input_mask
-        self.use_token_type_ids = use_token_type_ids
-        self.use_labels = use_labels
+        # Call parent init
+        super().__init__(
+            parent=parent,
+            hidden_dropout_prob=resid_pdrop,
+            attention_probs_dropout_prob=resid_pdrop,
+            initializer_range=initializer_range,
+            pad_token_id=pad_token_id,
+            is_decoder=is_decoder,
+        )
 
-        # attn_config params
+        # Set DBRX's unusual params
         self.clip_qkv = clip_qkv
-        self.kv_n_heads = kv_n_heads
-        self.rope_theta = rope_theta
-        self.attn_config_model_type = attn_config_model_type
 
-        # ffn_config params
-        self.ffn_hidden_size = ffn_hidden_size
-        self.moe_jitter_eps = moe_jitter_eps
-        self.moe_loss_weight = moe_loss_weight
-        self.moe_num_experts = moe_num_experts
-        self.moe_top_k = moe_top_k
-        self.ffn_config_model_type = ffn_config_model_type
-        self.ffn_act_fn_name = ffn_act_fn_name
-
-        # Other model params
-        self.hidden_size = hidden_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-        self.max_position_embeddings = max_position_embeddings
-        self.vocab_size = vocab_size
-        self.use_cache = use_cache
-        self.initializer_range = initializer_range
-        self.emb_pdrop = emb_pdrop
-        self.output_router_logits = output_router_logits
-        self.resid_pdrop = resid_pdrop
-        self.tie_word_embeddings = tie_word_embeddings
-        self.torch_dtype = torch_dtype
-        self.is_decoder = is_decoder
-        self.pad_token_id = pad_token_id
-
-        # Make the dictionaries
+        # DBRX takes sub-configurations for the FFN and attention layers, so we need to set that correctly here
         self.ffn_config = {
-            "ffn_hidden_size": self.ffn_hidden_size,
-            "moe_jitter_eps": self.moe_jitter_eps,
-            "moe_loss_weight": self.moe_loss_weight,
-            "moe_num_experts": self.moe_num_experts,
-            "moe_top_k": self.moe_top_k,
-            "model_type": self.ffn_config_model_type,
-            "ffn_act_fn": {"name": self.ffn_act_fn_name},
+            "ffn_hidden_size": self.hidden_size,
+            "moe_jitter_eps": moe_jitter_eps,
+            "moe_loss_weight": moe_loss_weight,
+            "moe_num_experts": moe_num_experts,
+            "moe_top_k": moe_top_k,
+            "model_type": ffn_config_model_type,
+            "ffn_act_fn": {"name": self.hidden_act},
         }
         self.attn_config = {
-            "clip_qkv": self.clip_qkv,
-            "kv_n_heads": self.kv_n_heads,
-            "model_type": self.attn_config_model_type,
-            "rope_theta": self.rope_theta,
+            "clip_qkv": clip_qkv,
+            "model_type": attn_config_model_type,
+            "rope_theta": rope_theta,
         }
 
-    def prepare_config_and_inputs(self):
-        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
-
-        input_mask = None
-        if self.use_input_mask:
-            input_mask = random_attention_mask([self.batch_size, self.seq_length])
-
-        token_type_ids = None
-        if self.use_token_type_ids:
-            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
-
-        sequence_labels = None
-        token_labels = None
-        choice_labels = None
-        if self.use_labels:
-            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
-            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
-            choice_labels = ids_tensor([self.batch_size], self.num_choices)
-
-        config = self.get_config()
-
-        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-
-    def get_config(self):
-        # Behind the scenes, `DbrxConfig` maps the parameters `hidden_size`, `num_hidden_layers`,
-        # `num_attention_heads`, `max_position_embeddings` to the parameters `d_model`, `n_layers`,
-        # `n_heads`, `max_seq_len` respectively. We use the first group of parameters because
-        # other tests expect every model to have these parameters with these specific names.
-        config = DbrxConfig(
-            vocab_size=self.vocab_size,
-            hidden_size=self.hidden_size,  # mapped to `d_model`
-            num_hidden_layers=self.num_hidden_layers,  # mapped to `n_layers`
-            num_attention_heads=self.num_attention_heads,  # mapped to `n_heads`
-            max_position_embeddings=self.max_position_embeddings,  # mapped to `max_seq_len`
-            attn_config=self.attn_config,
-            ffn_config=self.ffn_config,
-            resid_pdrop=self.resid_pdrop,
-            emb_pdrop=self.emb_pdrop,
-            use_cache=self.use_cache,
-            initializer_range=self.initializer_range,
-            output_router_logits=self.output_router_logits,
-            is_decoder=self.is_decoder,
-            pad_token_id=self.pad_token_id,
-        )
-        return config
-
-    def create_and_check_model(
-        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-    ):
-        model = DbrxModel(config=config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, attention_mask=input_mask)
-        result = model(input_ids)
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
-
-    def prepare_config_and_inputs_for_common(self):
-        config_and_inputs = self.prepare_config_and_inputs()
-        (
-            config,
-            input_ids,
-            token_type_ids,
-            input_mask,
-            sequence_labels,
-            token_labels,
-            choice_labels,
-        ) = config_and_inputs
-        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
-        return config, inputs_dict
+    @property
+    def config_args(self):
+        return super().config_args + ["ffn_config", "attn_config"]
 
 
 @require_torch
-class DbrxModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class DbrxModelTest(CausalLMModelTest, unittest.TestCase):
     all_model_classes = (DbrxModel, DbrxForCausalLM) if is_torch_available() else ()
-    pipeline_model_mapping = {"text-generation": DbrxForCausalLM} if is_torch_available() else {}
-    test_headmasking = False
-    test_pruning = False
-
-    def setUp(self):
-        self.model_tester = DbrxModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=DbrxConfig, d_model=37)
-
-    def test_config(self):
-        self.config_tester.run_common_tests()
-
-    def test_model(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_model(*config_and_inputs)
+    pipeline_model_mapping = (
+        {
+            "feature-extraction": DbrxModel,
+            "text-generation": DbrxForCausalLM,
+        }
+        if is_torch_available()
+        else {}
+    )
+    model_tester_class = DbrxModelTester
 
     def test_model_various_embeddings(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
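Taken together, the DBRX tester no longer hand-builds its config: the inherited get_config collects every constructor argument it can resolve, and the overridden config_args property makes sure the nested ffn_config and attn_config dicts are forwarded too. Roughly how that plays out (an illustrative sketch under the assumption that the base tester only needs parent at construction time, which the super().__init__ call above suggests; not a test from this commit):

tester = DbrxModelTester(parent=None)
config = tester.get_config()

# The sub-configuration dicts built in __init__ flow through because
# "ffn_config" and "attn_config" were appended to config_args.
print(config.ffn_config)   # reflects moe_num_experts=8, moe_top_k=4, ...
print(config.attn_config)  # reflects clip_qkv=8, rope_theta=500000

# Common names resolve through DbrxConfig.attribute_map,
# e.g. hidden_size is an alias for d_model.
assert config.hidden_size == config.d_model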