[PyTorch/XLA] Fix extra TPU compilations introduced by recent changes (#29158)

* tmp

* Remove debug step

* Fix a typo

* Move to is_torch_xla_available
Jiewen Tan 2024-03-13 08:30:32 -07:00 committed by GitHub
parent 1e21c4fbe0
commit b340d90738


@@ -1364,7 +1364,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 hard_check_only=False,
                 check_device_map=check_device_map,
             )
-        elif requested_attn_implementation in [None, "sdpa"]:
+        elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available():
             # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
             config = cls._check_and_enable_sdpa(
                 config,
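
For context, a minimal sketch of the selection order this change produces, assuming the real `is_torch_xla_available` helper from `transformers.utils`; the function name and simplified control flow below are illustrative stand-ins, not the actual `_autoset_attn_implementation` method:

```python
from transformers.utils import is_torch_xla_available


def pick_attn_implementation(requested=None):
    """Illustrative stand-in for the attention-backend priority in the diff above."""
    # use_flash_attention_2 takes priority over SDPA, so it is handled first.
    if requested == "flash_attention_2":
        return "flash_attention_2"
    # SDPA is now only auto-selected when XLA is NOT available: per the commit
    # title, enabling SDPA under XLA introduced extra TPU compilations.
    if requested in (None, "sdpa") and not is_torch_xla_available():
        return "sdpa"
    # On an XLA backend (e.g. TPU), fall back to the eager implementation.
    return "eager"


# On a host with torch_xla active this returns "eager" instead of "sdpa",
# which is the behavior change the guard in this commit introduces.
print(pick_attn_implementation())
```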