Fix GenerationMixin.generate compatibility with pytorch profiler (#31935)

use torch.compiler.is_compiling() when possible
This commit is contained in:
fxmarty 2024-07-14 15:44:38 +02:00 committed by GitHub
parent 7f79a97399
commit 8480fda6ee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -642,12 +642,8 @@ def is_torch_mlu_available(check_device=False):
def is_torchdynamo_available():
if not is_torch_available():
return False
try:
import torch._dynamo as dynamo # noqa: F401
return True
except Exception:
return False
return version.parse(_torch_version) >= version.parse("2.0.0")
def is_torch_compile_available():
@ -665,9 +661,15 @@ def is_torchdynamo_compiling():
if not is_torch_available():
return False
try:
import torch._dynamo as dynamo # noqa: F401
# Importing torch._dynamo causes issues with PyTorch profiler (https://github.com/pytorch/pytorch/issues/130622) hence rather relying on `torch.compiler.is_compiling()` when possible.
if version.parse(_torch_version) >= version.parse("2.3.0"):
import torch
return dynamo.is_compiling()
return torch.compiler.is_compiling()
else:
import torch._dynamo as dynamo # noqa: F401
return dynamo.is_compiling()
except Exception:
return False