NPU support SDPA (#35165)

Co-authored-by: root <weichunyude@163.com>
This commit is contained in:
zheliuyu 2025-01-07 18:30:05 +08:00 committed by GitHub
parent 02ed609285
commit ed73ae210b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -361,6 +361,9 @@ def is_torch_sdpa_available():
# NOTE: MLU is OK with non-contiguous inputs.
if is_torch_mlu_available():
return version.parse(_torch_version) >= version.parse("2.1.0")
# NOTE: NPU can use SDPA in Transformers with torch>=2.1.0.
if is_torch_npu_available():
return version.parse(_torch_version) >= version.parse("2.1.0")
# NOTE: We require torch>=2.1.1 to avoid a numerical issue in SDPA with non-contiguous inputs: https://github.com/pytorch/pytorch/issues/112577
return version.parse(_torch_version) >= version.parse("2.1.1")