diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py
index 2f00d9b6c0e..648995eed7b 100644
--- a/src/transformers/modeling_flash_attention_utils.py
+++ b/src/transformers/modeling_flash_attention_utils.py
@@ -279,7 +279,7 @@ def fa_peft_integration_check(


 flash_241 = is_flash_attn_greater_or_equal("2.4.1")
-deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
+deterministic_g = None


 def _flash_attention_forward(
@@ -342,6 +342,9 @@ def _flash_attention_forward(

     if flash_241:
         if deterministic is None:
+            global deterministic_g
+            if deterministic_g is None:
+                deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
             deterministic = deterministic_g
         flash_kwargs["deterministic"] = deterministic
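
For context, the patch replaces an import-time read of FLASH_ATTENTION_DETERMINISTIC with a lazy, read-once initialization: the module-level deterministic_g starts as None and is populated from the environment on the first call into _flash_attention_forward, so setting the variable after importing transformers still takes effect. Below is a minimal standalone sketch of that pattern under the same assumptions; the _get_deterministic helper name is hypothetical and not part of the patch, which inlines this logic in the forward function.

    import os

    # Module-level cache; None means "the environment has not been read yet".
    # Deferring the read from import time to first use lets callers export
    # FLASH_ATTENTION_DETERMINISTIC after the module is already imported.
    _deterministic_g = None


    def _get_deterministic() -> bool:
        """Read FLASH_ATTENTION_DETERMINISTIC once and cache the result."""
        global _deterministic_g
        if _deterministic_g is None:
            _deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
        return _deterministic_g

Note one consequence of the cache: once the first forward call has populated deterministic_g, later changes to the environment variable are ignored, which matches the old behavior after the value was fixed at import time. Callers can still override per call by passing deterministic explicitly, since the cached value is only consulted when deterministic is None.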