Update composition flag usage (#36263)

* update composition flag usage

* remove print

* fix tests

* actually fix

* oh c'mon

* now should be fixed right?

* fix copies
Authored by Raushan Turganbay on 2025-04-09 11:48:49 +02:00, committed by GitHub
parent 08e3217baf
commit 6f4058aee3
22 changed files with 26 additions and 39 deletions


@@ -1195,9 +1195,7 @@ class StaticCache(Cache):
         self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
         # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
-        self.head_dim = (
-            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
-        )
+        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
         self._dtype = dtype
         self.num_key_value_heads = (
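For illustration, a minimal sketch (not part of the diff) of why the fallback moves from getattr's default argument to an `or` expression: with the config changes below, `head_dim` can be present on the config but set to `None`, and only the `or` form still falls back to the derived value. `_DemoConfig` is a hypothetical stand-in.

class _DemoConfig:
    # head_dim exists but is unset, as MistralConfig/MixtralConfig now allow
    hidden_size = 4096
    num_attention_heads = 32
    head_dim = None

config = _DemoConfig()

# Old pattern: the default is used only when the attribute is missing,
# so an explicit None slips through.
old = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
assert old is None

# New pattern: None (or a missing attribute) falls back to the derived value.
new = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
assert new == 128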


@@ -61,9 +61,10 @@ class PretrainedConfig(PushToHubMixin):
     - **model_type** (`str`) -- An identifier for the model type, serialized into the JSON file, and used to recreate
       the correct object in [`~transformers.AutoConfig`].
-    - **is_composition** (`bool`) -- Whether the config class is composed of multiple sub-configs. In this case the
-      config has to be initialized from two or more configs of type [`~transformers.PretrainedConfig`] like:
-      [`~transformers.EncoderDecoderConfig`] or [`~RagConfig`].
+    - **has_no_defaults_at_init** (`bool`) -- Whether the config class can be initialized without providing input arguments.
+      Some configurations requires inputs to be defined at init and have no default values, usually these are composite configs,
+      (but not necessarily) such as [`~transformers.EncoderDecoderConfig`] or [`~RagConfig`]. They have to be initialized from
+      two or more configs of type [`~transformers.PretrainedConfig`].
     - **keys_to_ignore_at_inference** (`List[str]`) -- A list of keys to ignore by default when looking at dictionary
       outputs of the model during inference.
     - **attribute_map** (`Dict[str, str]`) -- A dict that maps model specific attribute names to the standardized
@@ -193,7 +194,7 @@ class PretrainedConfig(PushToHubMixin):
     model_type: str = ""
     base_config_key: str = ""
     sub_configs: dict[str, "PretrainedConfig"] = {}
-    is_composition: bool = False
+    has_no_defaults_at_init: bool = False
     attribute_map: dict[str, str] = {}
     base_model_tp_plan: Optional[dict[str, Any]] = None
     base_model_pp_plan: Optional[dict[str, tuple[list[str]]]] = None
@@ -813,8 +814,8 @@ class PretrainedConfig(PushToHubMixin):
         # Get the default config dict (from a fresh PreTrainedConfig instance)
         default_config_dict = PretrainedConfig().to_dict()
-        # Get class-specific config dict if not part of a composition
-        class_config_dict = self.__class__().to_dict() if not self.is_composition else {}
+        # get class specific config dict
+        class_config_dict = self.__class__().to_dict() if not self.has_no_defaults_at_init else {}
         serializable_config_dict = {}
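An illustrative sketch (not from the commit) of what the renamed flag signals: a config class marked `has_no_defaults_at_init = True` cannot be instantiated bare, which is why the serialization code above skips building a fresh `self.__class__()` for it. The `ToyCompositeConfig` class below is hypothetical.

from transformers import PretrainedConfig

class ToyCompositeConfig(PretrainedConfig):
    # hypothetical composite config used only to illustrate the renamed flag
    model_type = "toy-composite"
    has_no_defaults_at_init = True  # a bare ToyCompositeConfig() is not meaningful

    def __init__(self, encoder=None, decoder=None, **kwargs):
        if encoder is None or decoder is None:
            raise ValueError("ToyCompositeConfig must be built from an encoder and a decoder config")
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

Ordinary configs keep the inherited `has_no_defaults_at_init = False` and can still be created with no arguments, which is what lets the serialization branch above compare against a fresh `self.__class__()`.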


@@ -121,7 +121,7 @@ def _compute_default_rope_parameters(
     elif config is not None:
         base = config.rope_theta
         partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
         dim = int(head_dim * partial_rotary_factor)
     attention_factor = 1.0  # Unused in this type of RoPE


@@ -240,7 +240,6 @@ class BarkFineGenerationConfig(GenerationConfig):
 class BarkGenerationConfig(GenerationConfig):
     model_type = "bark"
-    is_composition = True
     # TODO (joao): nested from_dict


@@ -72,7 +72,7 @@ class EncoderDecoderConfig(PretrainedConfig):
     model_type = "encoder-decoder"
     sub_configs = {"encoder": AutoConfig, "decoder": AutoConfig}
-    is_composition = True
+    has_no_defaults_at_init = True
     def __init__(self, **kwargs):
         super().__init__(**kwargs)


@@ -76,7 +76,6 @@ class LlavaConfig(PretrainedConfig):
     model_type = "llava"
     sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
-    is_composition = True
     def __init__(
         self,


@@ -143,7 +143,7 @@ class MistralConfig(PretrainedConfig):
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
         self.sliding_window = sliding_window
-        self.head_dim = head_dim or hidden_size // num_attention_heads
+        self.head_dim = head_dim
         # for backward compatibility
         if num_key_value_heads is None:


@@ -139,7 +139,7 @@ class MistralAttention(nn.Module):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
-        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
         self.scaling = self.head_dim**-0.5
         self.attention_dropout = config.attention_dropout


@@ -42,6 +42,7 @@ class MistralMLP(LlamaMLP):
 class MistralAttention(LlamaAttention):
     def __init__(self, config: MistralConfig, layer_idx: int):
         super().__init__()
+        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
         self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
         self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
         self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
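Taken together, the Mistral (and Mixtral, below) changes stop deriving `head_dim` inside the config and let the attention layers and caches resolve the fallback themselves. A rough sketch of the resulting behavior, assuming a transformers build that includes this commit:

from transformers import MistralConfig

cfg = MistralConfig(hidden_size=4096, num_attention_heads=32, num_key_value_heads=8)

# The config now stores whatever was passed; with no explicit head_dim it should stay None.
print(cfg.head_dim)

# Consumers fall back to hidden_size // num_attention_heads themselves.
head_dim = getattr(cfg, "head_dim", None) or cfg.hidden_size // cfg.num_attention_heads
print(head_dim)  # 128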


@@ -172,7 +172,7 @@ class MixtralConfig(PretrainedConfig):
         self.use_cache = use_cache
         self.rope_theta = rope_theta
         self.attention_dropout = attention_dropout
-        self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
+        self.head_dim = head_dim
         self.num_experts_per_tok = num_experts_per_tok
         self.num_local_experts = num_local_experts


@@ -251,7 +251,7 @@ class MixtralAttention(nn.Module):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
-        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
         self.scaling = self.head_dim**-0.5
         self.attention_dropout = config.attention_dropout


@@ -195,7 +195,7 @@ class MusicgenConfig(PretrainedConfig):
         "audio_encoder": AutoConfig,
         "decoder": MusicgenDecoderConfig,
     }
-    is_composition = True
+    has_no_defaults_at_init = True
     def __init__(self, **kwargs):
         super().__init__(**kwargs)


@@ -201,7 +201,7 @@ class MusicgenMelodyConfig(PretrainedConfig):
         "audio_encoder": AutoConfig,
         "decoder": MusicgenMelodyDecoderConfig,
     }
-    is_composition = True
+    has_no_defaults_at_init = True
     def __init__(
         self,


@@ -158,7 +158,6 @@ def convert_mistral_model(input_dir, output_dir):
             hidden_act="silu",
             sliding_window=None,
             tie_word_embeddings=False,
-            is_composition=True,
             rms_norm_eps=1e-5,
         )
     else:


@@ -79,7 +79,7 @@ RAG_CONFIG_DOC = r"""
 @add_start_docstrings(RAG_CONFIG_DOC)
 class RagConfig(PretrainedConfig):
     model_type = "rag"
-    is_composition = True
+    has_no_defaults_at_init = True
     def __init__(
         self,


@@ -72,7 +72,7 @@ class SpeechEncoderDecoderConfig(PretrainedConfig):
     model_type = "speech-encoder-decoder"
     sub_configs = {"encoder": AutoConfig, "decoder": AutoConfig}
-    is_composition = True
+    has_no_defaults_at_init = True
     def __init__(self, **kwargs):
         super().__init__(**kwargs)


@@ -158,7 +158,7 @@ class Starcoder2Attention(nn.Module):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
-        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
         self.scaling = self.head_dim**-0.5
         self.attention_dropout = config.attention_dropout


@@ -79,7 +79,7 @@ class VisionEncoderDecoderConfig(PretrainedConfig):
     model_type = "vision-encoder-decoder"
     sub_configs = {"encoder": AutoConfig, "decoder": AutoConfig}
-    is_composition = True
+    has_no_defaults_at_init = True
     def __init__(self, **kwargs):
         super().__init__(**kwargs)


@@ -76,7 +76,7 @@ class VisionTextDualEncoderConfig(PretrainedConfig):
     model_type = "vision-text-dual-encoder"
     sub_configs = {"vision_config": AutoConfig, "text_config": AutoConfig}
-    is_composition = True
+    has_no_defaults_at_init = True
     def __init__(self, projection_dim=512, logit_scale_init_value=2.6592, **kwargs):
         super().__init__(**kwargs)


@@ -1755,9 +1755,7 @@ class GenerationTesterMixin:
         text_config = model.config.get_text_config()
         head_dim = (
-            text_config.head_dim
-            if hasattr(text_config, "head_dim")
-            else text_config.hidden_size // text_config.num_attention_heads
+            getattr(text_config, "head_dim", None) or text_config.hidden_size // text_config.num_attention_heads
         )
         num_key_value_heads = (
             text_config.num_attention_heads
@@ -2008,9 +2006,8 @@ class GenerationTesterMixin:
         max_cache_len = seq_length + max_new_tokens - 1  # cache len = gen len - 1, the last token has no cache
         text_config = config.text_config if hasattr(config, "text_config") else config
         head_dim = (
-            text_config.head_dim
-            if hasattr(text_config, "head_dim")
-            else text_config.hidden_size // text_config.num_attention_heads
+            getattr(text_config, "head_dim", None)
+            or text_config.hidden_size // text_config.num_attention_heads
         )
         num_key_value_heads = (
             text_config.num_attention_heads


@@ -184,13 +184,6 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM
         )
     def test_config(self):
-        # overwritten from `tests/test_configuration_common.py::ConfigTester` after #36077
-        # TODO: avoid overwritten once there is a better fix for #36077
-        def check_config_can_be_init_without_params():
-            config = self.config_tester.config_class()
-            self.config_tester.parent.assertIsNotNone(config)
-        self.config_tester.check_config_can_be_init_without_params = check_config_can_be_init_without_params
         self.config_tester.run_common_tests()
     # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs


@@ -163,7 +163,7 @@ class ConfigTester:
         self.parent.assertEqual(len(config.label2id), 3)
     def check_config_can_be_init_without_params(self):
-        if self.config_class.is_composition:
+        if self.config_class.has_no_defaults_at_init:
             with self.parent.assertRaises(ValueError):
                 config = self.config_class()
         else:
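A small hedged example of how the renamed check plays out, mirroring the test logic above; `EncoderDecoderConfig` is one of the classes flagged in this commit, and bare construction of a flagged config is expected to raise `ValueError`.

from transformers import EncoderDecoderConfig, MistralConfig

# Flagged composite config: building it without sub-configs should fail.
assert EncoderDecoderConfig.has_no_defaults_at_init is True
try:
    EncoderDecoderConfig()
except ValueError:
    print("raised as expected: encoder/decoder sub-configs are required")

# A regular config keeps the default flag and initializes without arguments.
assert MistralConfig.has_no_defaults_at_init is False
MistralConfig()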