Fix setting FLASH_ATTENTION_DETERMINISTIC after importing (#37185)
`transformers.enable_full_determinism` enables deterministic flash attention by setting the `FLASH_ATTENTION_DETERMINISTIC` environment variable:
800510c67b/src/transformers/trainer_utils.py (L79)
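For reference, a minimal usage sketch (seed value arbitrary; the env var is only set when torch is installed):

import os
import transformers

# Seeds everything and, among other settings, exports
# FLASH_ATTENTION_DETERMINISTIC=1 for the flash-attention path.
transformers.enable_full_determinism(seed=42)
assert os.environ["FLASH_ATTENTION_DETERMINISTIC"] == "1"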
However, the current check uses a global variable, `deterministic_g`, which reads the environment variable as soon as the module is imported. This breaks when users call `transformers.enable_full_determinism` after `transformers.modeling_flash_attention_utils` has already been imported: the setting is silently ignored. The import-time read was introduced in
https://github.com/huggingface/transformers/pull/33932/files#r1806668579
to avoid a graph break.

This PR therefore delays the environment variable check until the first time `_flash_attention_forward` is executed, which fixes the issue without reintroducing a graph break.
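A minimal sketch of the ordering that used to break (hypothetical repro; seed value arbitrary):

import transformers
import transformers.modeling_flash_attention_utils  # pre-fix: env var read and frozen here

transformers.enable_full_determinism(seed=0)  # sets FLASH_ATTENTION_DETERMINISTIC=1, but too late
# Before this PR, flash attention stayed non-deterministic because the
# module-level flag had already been evaluated at import time; now the
# variable is read on the first _flash_attention_forward call instead.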
Signed-off-by: Hollow Man <hollowman@opensuse.org>
This commit is contained in:
parent: fde1120b6c
commit: 64d14ef28d
src/transformers/modeling_flash_attention_utils.py

@@ -279,7 +279,7 @@ def fa_peft_integration_check(
 flash_241 = is_flash_attn_greater_or_equal("2.4.1")
-deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
+deterministic_g = None


 def _flash_attention_forward(
@@ -342,6 +342,9 @@ def _flash_attention_forward(
     if flash_241:
         if deterministic is None:
+            global deterministic_g
+            if deterministic_g is None:
+                deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
             deterministic = deterministic_g
         flash_kwargs["deterministic"] = deterministic
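With this change, os.environ is consulted at most once, on the first call: afterwards `deterministic_g` is a plain bool, so a compiled `_flash_attention_forward` never has to trace an environment lookup, which is what originally motivated the import-time read in #33932. Determinism toggled via `enable_full_determinism` any time before the first forward pass is therefore picked up correctly.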