From 64d14ef28dac918e13d833d5a76280ae6c998d35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=84=8D=F0=9D=95=A0=F0=9D=95=9D=F0=9D=95=9D=F0=9D=95=A0?= =?UTF-8?q?=F0=9D=95=A8=20=F0=9D=95=84=F0=9D=95=92=F0=9D=95=9F?= Date: Mon, 2 Jun 2025 17:08:20 +0800 Subject: [PATCH] Fix setting FLASH_ATTENTION_DETERMINISTIC after importing (#37185) `transformers.enable_full_determinism` enables deterministic flash attention via the `FLASH_ATTENTION_DETERMINISTIC` environment variable: https://github.com/huggingface/transformers/blob/800510c67bfc5cedd0bb7635648a07f39719be43/src/transformers/trainer_utils.py#L79 However, the current check uses a global variable `deterministic_g`, which reads the environment variable at import time. This causes issues because users can call `transformers.enable_full_determinism` after `transformers.modeling_flash_attention_utils` has already been imported, in which case the new value is never picked up. This behavior was introduced in https://github.com/huggingface/transformers/pull/33932/files#r1806668579 to fix a graph break. This PR therefore delays the environment variable check until the first time `_flash_attention_forward` is executed, which fixes the issue without reintroducing the graph break. 
Signed-off-by: Hollow Man --- src/transformers/modeling_flash_attention_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 2f00d9b6c0e..648995eed7b 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -279,7 +279,7 @@ def fa_peft_integration_check( flash_241 = is_flash_attn_greater_or_equal("2.4.1") -deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" +deterministic_g = None def _flash_attention_forward( @@ -342,6 +342,9 @@ def _flash_attention_forward( if flash_241: if deterministic is None: + global deterministic_g + if deterministic_g is None: + deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" deterministic = deterministic_g flash_kwargs["deterministic"] = deterministic