diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 480e6f3f3f3..6e2d7a8102f 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -2277,8 +2277,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
                 if not isinstance(requested_attn_implementation, dict)
                 else requested_attn_implementation.get(key, None)
             )
-            # For models with backbone sub-config might be not initialized
-            if sub_config is not None:
+            # For models with a backbone, the sub-config might not be initialized. Set the requested attn
+            # only if the sub-config has no attn pre-set and the requested attn is not `None` (i.e. not the default attn)
+            if (
+                sub_config is not None
+                and sub_config._attn_implementation_internal is None
+                and curr_attn_implementation is not None
+            ):
                 sub_config._attn_implementation_internal = curr_attn_implementation
 
         if config._attn_implementation == "flash_attention_2":
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 441c99267b6..446610db966 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -3425,6 +3425,13 @@ class ModelTesterMixin:
                     f"The eager model should not have SDPA/FA2 attention layers but got `{class_name}.config._attn_implementation={submodule.config._attn_implementation}`"
                 )
 
+        # Set the attention implementation to the default `None`, but pre-set the text config to `eager`
+        # The model should load the encoders with SDPA, but keep eager for the text attention
+        config._attn_implementation = None
+        config.get_text_config(decoder=True)._attn_implementation = "eager"
+        model = model_class(config)
+        self.assertTrue(model.config.get_text_config(decoder=True)._attn_implementation == "eager")
+
     @require_torch_sdpa
     def test_sdpa_can_dispatch_non_composite_models(self):
         """
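
As a rough usage sketch of the behavior the new test checks: for a composite model, leaving the global `attn_implementation` at its default (`None`) should no longer overwrite an attention implementation that was already pre-set on a sub-config. The snippet below mirrors the added test with a concrete model; the choice of Llava, the tiny sub-config sizes, and the availability of SDPA in the installed torch build are illustrative assumptions, not part of the PR.

```python
from transformers import CLIPVisionConfig, LlamaConfig, LlavaConfig, LlavaForConditionalGeneration

# Tiny sub-configs so the model is cheap to instantiate (sizes are arbitrary, for illustration only).
text_config = LlamaConfig(
    vocab_size=64, hidden_size=16, intermediate_size=32,
    num_hidden_layers=1, num_attention_heads=2, num_key_value_heads=2,
)
vision_config = CLIPVisionConfig(
    hidden_size=16, intermediate_size=32, num_hidden_layers=1,
    num_attention_heads=2, image_size=32, patch_size=8,
)
config = LlavaConfig(vision_config=vision_config, text_config=text_config)

# No global request (the default), but the text sub-config is pre-set to eager.
config._attn_implementation = None
config.get_text_config(decoder=True)._attn_implementation = "eager"

model = LlavaForConditionalGeneration(config)

# With this fix, the pre-set value on the text sub-config survives model init;
# other sub-modules (e.g. the vision tower) may still dispatch to SDPA if available.
assert model.config.get_text_config(decoder=True)._attn_implementation == "eager"
```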