Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Disable the FA backend for SDPA on AMD GPUs (#30850)
* disable fa
* disable fa
* update warning
* update warning
This commit is contained in:
parent 9d889f870e
commit 0753134f4d
@@ -1479,6 +1479,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 config,
                 hard_check_only=False if requested_attn_implementation is None else True,
             )
+
+            if (
+                torch.version.hip is not None
+                and config._attn_implementation == "sdpa"
+                and torch.cuda.device_count() > 1
+            ):
+                logger.warning_once(
+                    "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
+                )
+                torch.backends.cuda.enable_flash_sdp(False)
         else:
             config._attn_implementation = "eager"
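For reference, torch.backends.cuda.enable_flash_sdp is a process-wide PyTorch switch, not a per-model setting, so the guard above affects every subsequent call to torch.nn.functional.scaled_dot_product_attention in the process. The sketch below is not part of the commit; it reproduces the guard standalone and uses PyTorch's matching getters to show which SDPA backends remain enabled. The final re-enable call is an illustrative assumption for users who know the flash backend performs well on their setup.

import torch

# Sketch only: mirrors the guard added by this commit, outside of transformers.
# torch.version.hip is non-None on ROCm builds of PyTorch.
if torch.version.hip is not None and torch.cuda.device_count() > 1:
    # Globally disable the flash-attention backend of
    # torch.nn.functional.scaled_dot_product_attention for this process.
    torch.backends.cuda.enable_flash_sdp(False)

# Getters for the three SDPA backends (all enabled by default):
print(torch.backends.cuda.flash_sdp_enabled())          # False once the guard fires
print(torch.backends.cuda.mem_efficient_sdp_enabled())  # True
print(torch.backends.cuda.math_sdp_enabled())           # True

# Assumption: opting back into the flash backend after model loading,
# for setups where it is known not to regress performance.
torch.backends.cuda.enable_flash_sdp(True)

With the flash backend off, SDPA falls back to the memory-efficient or math backends, which is exactly the behavior the warning in the diff announces for ROCm multi-GPU setups.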