fix: the default value of `deterministic` reflects the environment variables at call time rather than the ones present at import time. (#32153)

* fix: default value reflects the runtime environment variables rather than the ones present at import time.

* Fix: Change the default of `deterministic` to `None`; when it is `None`, read the `FLASH_ATTENTION_DETERMINISTIC` environment variable at call time
This commit is contained in:
조준래 2024-07-24 19:38:49 +09:00 committed by GitHub
parent 01be5b4879
commit 8678879f1d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -193,7 +193,7 @@ def _flash_attention_forward(
sliding_window: Optional[int] = None,
use_top_left_mask: bool = False,
softcap: Optional[float] = None,
-deterministic: bool = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1",
+deterministic: bool = None,
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
@@ -233,6 +233,8 @@ def _flash_attention_forward(
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
if is_flash_attn_greater_or_equal("2.4.1"):
+if deterministic is None:
+deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
flash_kwargs["deterministic"] = deterministic
if softcap is not None: