Disable the FA backend for SDPA on AMD GPUs (#30850)

* disable fa

* disable fa

* update warning

* update warning
Mohit Sharma 2024-05-16 17:01:14 +05:30 committed by GitHub
parent 9d889f870e
commit 0753134f4d


@@ -1479,6 +1479,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
                 config,
                 hard_check_only=False if requested_attn_implementation is None else True,
             )
+
+            if (
+                torch.version.hip is not None
+                and config._attn_implementation == "sdpa"
+                and torch.cuda.device_count() > 1
+            ):
+                logger.warning_once(
+                    "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
+                )
+                torch.backends.cuda.enable_flash_sdp(False)
         else:
             config._attn_implementation = "eager"
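
For readers reproducing the workaround outside the library: `torch.backends.cuda.enable_flash_sdp(False)` is a process-wide PyTorch switch that removes the flash backend from SDPA dispatch, leaving the memory-efficient and math backends available. The sketch below is an illustration rather than part of this commit (the helper name `maybe_disable_flash_sdp` is made up); it shows how the same guard could be applied manually on a multi-GPU ROCm machine and how to verify the toggle.

import torch

def maybe_disable_flash_sdp():
    # Hypothetical helper mirroring the guard added above.
    # torch.version.hip is non-None only on ROCm builds of PyTorch.
    if torch.version.hip is not None and torch.cuda.device_count() > 1:
        # Drop the flash backend from SDPA dispatch; the memory-efficient
        # and math backends remain enabled.
        torch.backends.cuda.enable_flash_sdp(False)

maybe_disable_flash_sdp()
# Check which SDPA backends are still enabled after the toggle.
print(torch.backends.cuda.flash_sdp_enabled())
print(torch.backends.cuda.mem_efficient_sdp_enabled())
print(torch.backends.cuda.math_sdp_enabled())

Note that the commit applies the toggle globally at model-load time rather than per call; PyTorch also offers a scoped SDP-kernel context manager, but a global switch is the simplest thing `PreTrainedModel` can do when it selects the attention implementation.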