Fix setting FLASH_ATTENTION_DETERMINISTIC after importing (#37185)

`transformers.enable_full_determinism` enables deterministic
flash attention via the `FLASH_ATTENTION_DETERMINISTIC` environment variable:
800510c67b/src/transformers/trainer_utils.py (L79)
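
For context, a minimal sketch of the expected effect (only the part relevant to this PR; `enable_full_determinism` also seeds RNGs and sets other variables, and the seed value below is arbitrary):

```python
import os
import transformers

# Illustrative only: among other things, enable_full_determinism sets the
# environment variable that flash attention later checks.
transformers.enable_full_determinism(seed=42)

assert os.environ.get("FLASH_ATTENTION_DETERMINISTIC") == "1"
```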

However, the current check uses a global variable `deterministic_g`,
which reads the environment variable at import time. This causes
problems because users may call `transformers.enable_full_determinism`
only after `transformers.modeling_flash_attention_utils` has already
been imported, in which case the cached value is stale and the setting
is silently ignored. The import-time check was introduced in
https://github.com/huggingface/transformers/pull/33932/files#r1806668579
to avoid a graph break.
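
A minimal standalone illustration of the ordering problem (plain Python, not the transformers code itself; it assumes `FLASH_ATTENTION_DETERMINISTIC` is unset when the script starts):

```python
import os

# Import-time read: the flag is frozen at whatever the variable held at import.
DETERMINISTIC_AT_IMPORT = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"

# Later, e.g. from enable_full_determinism, the variable is switched on.
os.environ["FLASH_ATTENTION_DETERMINISTIC"] = "1"

print(DETERMINISTIC_AT_IMPORT)  # False -- stale, the update is not seen
print(os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1")  # True -- a lazy read would see it
```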

This PR fixes the issue by delaying the environment variable check until
the first time `_flash_attention_forward` is executed. The flag is still
cached in a global after that first read, so the ordering problem is
resolved without reintroducing a graph break.

Signed-off-by: Hollow Man <hollowman@opensuse.org>


src/transformers/modeling_flash_attention_utils.py

```diff
@@ -279,7 +279,7 @@ def fa_peft_integration_check(
 flash_241 = is_flash_attn_greater_or_equal("2.4.1")
-deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
+deterministic_g = None

 def _flash_attention_forward(
@@ -342,6 +342,9 @@ def _flash_attention_forward(
     if flash_241:
         if deterministic is None:
+            global deterministic_g
+            if deterministic_g is None:
+                deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
             deterministic = deterministic_g
         flash_kwargs["deterministic"] = deterministic
```