Don't use default attn if pre-set in sub-config (#38526)

* don't use default attn if pre-set in sub-config

* style

* add a test maybe
Raushan Turganbay 2025-06-03 09:53:07 +02:00 committed by GitHub
parent bf68dd9e6e
commit 55ec319de6
2 changed files with 14 additions and 2 deletions
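
For illustration, here is a minimal, self-contained sketch of the rule this commit introduces. It is not the actual transformers code; the class and function names below are invented for the example. The idea: a sub-config that already has an attention implementation keeps it, and the requested implementation is only filled in when nothing was pre-set and the request is not the default `None`.

# Simplified mock of the dispatch rule (not the real transformers implementation).
from dataclasses import dataclass
from typing import Optional


@dataclass
class SubConfig:
    # mirrors `PretrainedConfig._attn_implementation_internal`
    _attn_implementation_internal: Optional[str] = None


def dispatch_attn(sub_config: Optional[SubConfig], curr_attn_implementation: Optional[str]) -> None:
    # Only fill in the sub-config when it exists, has no pre-set attention,
    # and the requested implementation is not the default `None`.
    if (
        sub_config is not None
        and sub_config._attn_implementation_internal is None
        and curr_attn_implementation is not None
    ):
        sub_config._attn_implementation_internal = curr_attn_implementation


text = SubConfig(_attn_implementation_internal="eager")  # pre-set by the user
vision = SubConfig()                                      # nothing pre-set

dispatch_attn(text, "sdpa")
dispatch_attn(vision, "sdpa")

assert text._attn_implementation_internal == "eager"   # pre-set value wins
assert vision._attn_implementation_internal == "sdpa"  # falls back to the requested value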


@@ -2277,8 +2277,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
             if not isinstance(requested_attn_implementation, dict)
             else requested_attn_implementation.get(key, None)
         )
-        # For models with backbone sub-config might be not initialized
-        if sub_config is not None:
+        # For models with a backbone, the sub-config might not be initialized. Only set the requested attn
+        # if the sub-config has no attn pre-set and the requested attn is not `None` (i.e. not the default attn).
+        if (
+            sub_config is not None
+            and sub_config._attn_implementation_internal is None
+            and curr_attn_implementation is not None
+        ):
             sub_config._attn_implementation_internal = curr_attn_implementation
     if config._attn_implementation == "flash_attention_2":
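
A usage sketch of what this change enables, assuming a composite vision-text checkpoint whose config exposes a text sub-config. The checkpoint id below is only an illustrative tiny test repo, and building the model directly from the config (random weights) keeps the example cheap; both are assumptions, not taken from this diff.

# Hedged usage sketch: pre-set attention on a sub-config is preserved.
from transformers import AutoConfig, LlavaForConditionalGeneration

config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-LlavaForConditionalGeneration")

# Pre-set the text sub-config only; the top-level request stays at the default `None`.
config.get_text_config(decoder=True)._attn_implementation = "eager"

model = LlavaForConditionalGeneration(config)

# With this commit, the default/requested attention no longer overwrites the pre-set value.
assert model.config.get_text_config(decoder=True)._attn_implementation == "eager"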


@@ -3425,6 +3425,13 @@ class ModelTesterMixin:
f"The eager model should not have SDPA/FA2 attention layers but got `{class_name}.config._attn_implementation={submodule.config._attn_implementation}`"
)
# Set the attention to default `None` but the text config to `eager`
# The model should load encoders in SDPA but not the text attention
config._attn_implementation = None
config.get_text_config(decoder=True)._attn_implementation = "eager"
model = model_class(config)
self.assertTrue(model.config.get_text_config(decoder=True)._attn_implementation == "eager")
@require_torch_sdpa
def test_sdpa_can_dispatch_non_composite_models(self):
"""