diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 12673a8d41a..651c4e7dcaf 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -1195,9 +1195,7 @@ class StaticCache(Cache):
         self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
 
         # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
-        self.head_dim = (
-            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
-        )
+        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
 
         self._dtype = dtype
         self.num_key_value_heads = (
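The hunk above, together with the matching changes in `modeling_rope_utils.py`, `modeling_mistral.py`, `modeling_mixtral.py`, and `modeling_starcoder2.py` below, replaces the `hasattr`/`getattr`-with-computed-default lookup with `getattr(config, "head_dim", None) or ...`. The difference matters because after this patch the configs keep `head_dim` as an attribute that may be `None`: the old form would return that `None`, while the new form falls back to `hidden_size // num_attention_heads`. A minimal sketch of the two behaviours, using a hypothetical `SimpleNamespace` stand-in rather than a real config class:

```python
from types import SimpleNamespace

# Hypothetical config stand-in: `head_dim` exists but is left as None,
# which is how MistralConfig/MixtralConfig store it after this patch.
cfg = SimpleNamespace(hidden_size=4096, num_attention_heads=32, head_dim=None)

# Old pattern: the attribute exists, so the computed default is never used -> result is None.
old = getattr(cfg, "head_dim", cfg.hidden_size // cfg.num_attention_heads)
assert old is None

# New pattern: `None or <fallback>` evaluates the fallback -> 4096 // 32 == 128.
new = getattr(cfg, "head_dim", None) or cfg.hidden_size // cfg.num_attention_heads
assert new == 128
```

One side effect of using `or`: an explicit `head_dim=0` would also be treated as unset and fall back to the computed value, since `0` is falsy; that is presumably never a valid head dimension.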
diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py
index d5991cae8f9..8fa8dc46c3e 100755
--- a/src/transformers/configuration_utils.py
+++ b/src/transformers/configuration_utils.py
@@ -61,9 +61,10 @@ class PretrainedConfig(PushToHubMixin):
     - **model_type** (`str`) -- An identifier for the model type, serialized into the JSON file, and used to recreate
       the correct object in [`~transformers.AutoConfig`].
-    - **is_composition** (`bool`) -- Whether the config class is composed of multiple sub-configs. In this case the
-      config has to be initialized from two or more configs of type [`~transformers.PretrainedConfig`] like:
-      [`~transformers.EncoderDecoderConfig`] or [`~RagConfig`].
+    - **has_no_defaults_at_init** (`bool`) -- Whether the config class can be initialized without providing input arguments.
+      Some configurations require inputs to be defined at init and have no default values; these are usually composite
+      configs (but not necessarily), such as [`~transformers.EncoderDecoderConfig`] or [`~RagConfig`]. They have to be
+      initialized from two or more configs of type [`~transformers.PretrainedConfig`].
     - **keys_to_ignore_at_inference** (`List[str]`) -- A list of keys to ignore by default when looking at dictionary
       outputs of the model during inference.
     - **attribute_map** (`Dict[str, str]`) -- A dict that maps model specific attribute names to the standardized
@@ -193,7 +194,7 @@ class PretrainedConfig(PushToHubMixin):
     model_type: str = ""
     base_config_key: str = ""
     sub_configs: dict[str, "PretrainedConfig"] = {}
-    is_composition: bool = False
+    has_no_defaults_at_init: bool = False
     attribute_map: dict[str, str] = {}
     base_model_tp_plan: Optional[dict[str, Any]] = None
    base_model_pp_plan: Optional[dict[str, tuple[list[str]]]] = None
@@ -813,8 +814,8 @@ class PretrainedConfig(PushToHubMixin):
         # Get the default config dict (from a fresh PreTrainedConfig instance)
         default_config_dict = PretrainedConfig().to_dict()
 
-        # Get class-specific config dict if not part of a composition
-        class_config_dict = self.__class__().to_dict() if not self.is_composition else {}
+        # Get class-specific config dict
+        class_config_dict = self.__class__().to_dict() if not self.has_no_defaults_at_init else {}
 
         serializable_config_dict = {}
diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py
index 9a9cfd072b2..ccb961edea9 100644
--- a/src/transformers/modeling_rope_utils.py
+++ b/src/transformers/modeling_rope_utils.py
@@ -121,7 +121,7 @@ def _compute_default_rope_parameters(
     elif config is not None:
         base = config.rope_theta
         partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
         dim = int(head_dim * partial_rotary_factor)
 
     attention_factor = 1.0  # Unused in this type of RoPE
diff --git a/src/transformers/models/bark/generation_configuration_bark.py b/src/transformers/models/bark/generation_configuration_bark.py
index 036c9caa83b..00ff22c8b89 100644
--- a/src/transformers/models/bark/generation_configuration_bark.py
+++ b/src/transformers/models/bark/generation_configuration_bark.py
@@ -240,7 +240,6 @@ class BarkFineGenerationConfig(GenerationConfig):
 
 
 class BarkGenerationConfig(GenerationConfig):
     model_type = "bark"
-    is_composition = True
 
     # TODO (joao): nested from_dict
diff --git a/src/transformers/models/encoder_decoder/configuration_encoder_decoder.py b/src/transformers/models/encoder_decoder/configuration_encoder_decoder.py
index 8b5c62363f6..a5eff83e558 100644
--- a/src/transformers/models/encoder_decoder/configuration_encoder_decoder.py
+++ b/src/transformers/models/encoder_decoder/configuration_encoder_decoder.py
@@ -72,7 +72,7 @@ class EncoderDecoderConfig(PretrainedConfig):
 
     model_type = "encoder-decoder"
     sub_configs = {"encoder": AutoConfig, "decoder": AutoConfig}
-    is_composition = True
+    has_no_defaults_at_init = True
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
diff --git a/src/transformers/models/llava/configuration_llava.py b/src/transformers/models/llava/configuration_llava.py
index 74fa20cc760..f476591b2eb 100644
--- a/src/transformers/models/llava/configuration_llava.py
+++ b/src/transformers/models/llava/configuration_llava.py
@@ -76,7 +76,6 @@ class LlavaConfig(PretrainedConfig):
 
     model_type = "llava"
     sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
-    is_composition = True
 
     def __init__(
         self,
diff --git a/src/transformers/models/mistral/configuration_mistral.py b/src/transformers/models/mistral/configuration_mistral.py
index 3a237bc7343..8a5055be658 100644
--- a/src/transformers/models/mistral/configuration_mistral.py
+++ b/src/transformers/models/mistral/configuration_mistral.py
@@ -143,7 +143,7 @@ class MistralConfig(PretrainedConfig):
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
         self.sliding_window = sliding_window
-        self.head_dim = head_dim or hidden_size // num_attention_heads
+        self.head_dim = head_dim
 
         # for backward compatibility
         if num_key_value_heads is None:
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index 494f94dd348..5639c3bbb6d 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -139,7 +139,7 @@ class MistralAttention(nn.Module):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
-        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
         self.scaling = self.head_dim**-0.5
         self.attention_dropout = config.attention_dropout
diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py
index 6026f5a7e07..4f36181cd72 100644
--- a/src/transformers/models/mistral/modular_mistral.py
+++ b/src/transformers/models/mistral/modular_mistral.py
@@ -42,6 +42,7 @@ class MistralMLP(LlamaMLP):
 class MistralAttention(LlamaAttention):
     def __init__(self, config: MistralConfig, layer_idx: int):
         super().__init__()
+        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
         self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
         self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
         self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
diff --git a/src/transformers/models/mixtral/configuration_mixtral.py b/src/transformers/models/mixtral/configuration_mixtral.py
index d9b02e10fc4..4f11077c194 100644
--- a/src/transformers/models/mixtral/configuration_mixtral.py
+++ b/src/transformers/models/mixtral/configuration_mixtral.py
@@ -172,7 +172,7 @@ class MixtralConfig(PretrainedConfig):
         self.use_cache = use_cache
         self.rope_theta = rope_theta
         self.attention_dropout = attention_dropout
-        self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
+        self.head_dim = head_dim
 
         self.num_experts_per_tok = num_experts_per_tok
         self.num_local_experts = num_local_experts
diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index 9f839536b7f..51604dec3f9 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -251,7 +251,7 @@ class MixtralAttention(nn.Module):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
-        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
         self.scaling = self.head_dim**-0.5
         self.attention_dropout = config.attention_dropout
diff --git a/src/transformers/models/musicgen/configuration_musicgen.py b/src/transformers/models/musicgen/configuration_musicgen.py
index 6c38caf20dc..e5079eb1edb 100644
--- a/src/transformers/models/musicgen/configuration_musicgen.py
+++ b/src/transformers/models/musicgen/configuration_musicgen.py
@@ -195,7 +195,7 @@ class MusicgenConfig(PretrainedConfig):
         "audio_encoder": AutoConfig,
         "decoder": MusicgenDecoderConfig,
     }
-    is_composition = True
+    has_no_defaults_at_init = True
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
diff --git a/src/transformers/models/musicgen_melody/configuration_musicgen_melody.py b/src/transformers/models/musicgen_melody/configuration_musicgen_melody.py
index e65ad50021c..e35c7bd3c8a 100644
--- a/src/transformers/models/musicgen_melody/configuration_musicgen_melody.py
+++ b/src/transformers/models/musicgen_melody/configuration_musicgen_melody.py
@@ -201,7 +201,7 @@ class MusicgenMelodyConfig(PretrainedConfig):
         "audio_encoder": AutoConfig,
         "decoder": MusicgenMelodyDecoderConfig,
     }
-    is_composition = True
+    has_no_defaults_at_init = True
 
     def __init__(
         self,
diff --git a/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py b/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py
index ee1f1e9eb5e..da5af9fb344 100644
--- a/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py
+++ b/src/transformers/models/pixtral/convert_pixtral_weights_to_hf.py
@@ -158,7 +158,6 @@ def convert_mistral_model(input_dir, output_dir):
             hidden_act="silu",
             sliding_window=None,
             tie_word_embeddings=False,
-            is_composition=True,
             rms_norm_eps=1e-5,
         )
     else:
diff --git a/src/transformers/models/rag/configuration_rag.py b/src/transformers/models/rag/configuration_rag.py
index c76926f2187..fd0c9bb9cce 100644
--- a/src/transformers/models/rag/configuration_rag.py
+++ b/src/transformers/models/rag/configuration_rag.py
@@ -79,7 +79,7 @@ RAG_CONFIG_DOC = r"""
 @add_start_docstrings(RAG_CONFIG_DOC)
 class RagConfig(PretrainedConfig):
     model_type = "rag"
-    is_composition = True
+    has_no_defaults_at_init = True
 
     def __init__(
         self,
diff --git a/src/transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py
index 47312df27ea..14dfab7eaa6 100644
--- a/src/transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py
+++ b/src/transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py
@@ -72,7 +72,7 @@ class SpeechEncoderDecoderConfig(PretrainedConfig):
 
     model_type = "speech-encoder-decoder"
     sub_configs = {"encoder": AutoConfig, "decoder": AutoConfig}
-    is_composition = True
+    has_no_defaults_at_init = True
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py
index ba4ff8098bd..089cb00dceb 100644
--- a/src/transformers/models/starcoder2/modeling_starcoder2.py
+++ b/src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -158,7 +158,7 @@ class Starcoder2Attention(nn.Module):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
-        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
         self.scaling = self.head_dim**-0.5
         self.attention_dropout = config.attention_dropout
diff --git a/src/transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py
index 235069ea5a8..09b324af247 100644
--- a/src/transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py
+++ b/src/transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py
@@ -79,7 +79,7 @@ class VisionEncoderDecoderConfig(PretrainedConfig):
 
     model_type = "vision-encoder-decoder"
     sub_configs = {"encoder": AutoConfig, "decoder": AutoConfig}
-    is_composition = True
+    has_no_defaults_at_init = True
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
diff --git a/src/transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py b/src/transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py
index 908e1bf1843..3f544e9eaf0 100644
--- a/src/transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py
+++ b/src/transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py
@@ -76,7 +76,7 @@ class VisionTextDualEncoderConfig(PretrainedConfig):
 
     model_type = "vision-text-dual-encoder"
     sub_configs = {"vision_config": AutoConfig, "text_config": AutoConfig}
-    is_composition = True
+    has_no_defaults_at_init = True
 
     def __init__(self, projection_dim=512, logit_scale_init_value=2.6592, **kwargs):
         super().__init__(**kwargs)
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index e16d477335d..68539b71e16 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -1755,9 +1755,7 @@ class GenerationTesterMixin:
         text_config = model.config.get_text_config()
         head_dim = (
-            text_config.head_dim
-            if hasattr(text_config, "head_dim")
-            else text_config.hidden_size // text_config.num_attention_heads
+            getattr(text_config, "head_dim", None) or text_config.hidden_size // text_config.num_attention_heads
         )
         num_key_value_heads = (
             text_config.num_attention_heads
@@ -2008,9 +2006,8 @@
         max_cache_len = seq_length + max_new_tokens - 1  # cache len = gen len - 1, the last token has no cache
         text_config = config.text_config if hasattr(config, "text_config") else config
         head_dim = (
-            text_config.head_dim
-            if hasattr(text_config, "head_dim")
-            else text_config.hidden_size // text_config.num_attention_heads
+            getattr(text_config, "head_dim", None)
+            or text_config.hidden_size // text_config.num_attention_heads
         )
         num_key_value_heads = (
             text_config.num_attention_heads
diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py
index 4747acf7519..f329d1b2110 100644
--- a/tests/models/llava/test_modeling_llava.py
+++ b/tests/models/llava/test_modeling_llava.py
@@ -184,13 +184,6 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM
         )
 
     def test_config(self):
-        # overwritten from `tests/test_configuration_common.py::ConfigTester` after #36077
-        # TODO: avoid overwritten once there is a better fix for #36077
-        def check_config_can_be_init_without_params():
-            config = self.config_tester.config_class()
-            self.config_tester.parent.assertIsNotNone(config)
-
-        self.config_tester.check_config_can_be_init_without_params = check_config_can_be_init_without_params
         self.config_tester.run_common_tests()
 
     # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
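Taken together, the Mistral/Mixtral changes move the `head_dim` fallback out of the config and into the modules that consume it: the config now stores exactly what was passed (possibly `None`), and the attention layer derives the effective value. A rough sketch of the resulting behaviour, assuming the patched `MistralConfig` and `MistralAttention` from the hunks above (toy sizes, not a real checkpoint):

```python
from transformers import MistralConfig
from transformers.models.mistral.modeling_mistral import MistralAttention

# head_dim is not passed, so with this patch the config keeps it as None
# instead of eagerly computing hidden_size // num_attention_heads.
config = MistralConfig(hidden_size=1024, num_attention_heads=8, num_key_value_heads=8)
print(config.head_dim)  # None

# The attention layer derives the effective value via `getattr(..., None) or ...`.
attn = MistralAttention(config, layer_idx=0)
print(attn.head_dim)  # 128 == 1024 // 8

# An explicitly configured head_dim still takes precedence over the derived one.
custom = MistralConfig(hidden_size=1024, num_attention_heads=8, num_key_value_heads=8, head_dim=96)
print(MistralAttention(custom, layer_idx=0).head_dim)  # 96
```

This is also why `StaticCache`, the RoPE helpers, and the generation tests are switched to the `or` fallback in the same patch: they can no longer assume `config.head_dim` is always a filled-in integer.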
diff --git a/tests/test_configuration_common.py b/tests/test_configuration_common.py
index 4a3871a5c96..4d4ce3a3f16 100644
--- a/tests/test_configuration_common.py
+++ b/tests/test_configuration_common.py
@@ -163,7 +163,7 @@ class ConfigTester:
         self.parent.assertEqual(len(config.label2id), 3)
 
     def check_config_can_be_init_without_params(self):
-        if self.config_class.is_composition:
+        if self.config_class.has_no_defaults_at_init:
             with self.parent.assertRaises(ValueError):
                 config = self.config_class()
         else:
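For reference, a rough sketch of what the renamed `ConfigTester` check exercises, assuming the flag values set in this patch (`EncoderDecoderConfig` has `has_no_defaults_at_init = True`; the `else` branch of the check, not shown in the hunk, is assumed to simply assert that a bare init succeeds):

```python
from transformers import BertConfig, EncoderDecoderConfig

# A regular config has defaults for every argument, so a bare init works
# (the `else` branch of check_config_can_be_init_without_params).
assert BertConfig() is not None

# A config flagged with has_no_defaults_at_init = True cannot be built without
# its sub-configs; the test asserts that this raises ValueError.
try:
    EncoderDecoderConfig()
except ValueError as err:
    print(f"Raised as expected: {err}")

# Supplying the required sub-configs works, e.g. via the dedicated helper.
config = EncoderDecoderConfig.from_encoder_decoder_configs(BertConfig(), BertConfig())
print(type(config.encoder).__name__, type(config.decoder).__name__)  # BertConfig BertConfig
```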