mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix GenerationMixin.generate
compatibility with pytorch profiler (#31935)
use torch.compiler.is_compiling() when possible
This commit is contained in:
parent
7f79a97399
commit
8480fda6ee
@ -642,12 +642,8 @@ def is_torch_mlu_available(check_device=False):
|
||||
def is_torchdynamo_available():
|
||||
if not is_torch_available():
|
||||
return False
|
||||
try:
|
||||
import torch._dynamo as dynamo # noqa: F401
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
return version.parse(_torch_version) >= version.parse("2.0.0")
|
||||
|
||||
|
||||
def is_torch_compile_available():
|
||||
@ -665,9 +661,15 @@ def is_torchdynamo_compiling():
|
||||
if not is_torch_available():
|
||||
return False
|
||||
try:
|
||||
import torch._dynamo as dynamo # noqa: F401
|
||||
# Importing torch._dynamo causes issues with PyTorch profiler (https://github.com/pytorch/pytorch/issues/130622) hence rather relying on `torch.compiler.is_compiling()` when possible.
|
||||
if version.parse(_torch_version) >= version.parse("2.3.0"):
|
||||
import torch
|
||||
|
||||
return dynamo.is_compiling()
|
||||
return torch.compiler.is_compiling()
|
||||
else:
|
||||
import torch._dynamo as dynamo # noqa: F401
|
||||
|
||||
return dynamo.is_compiling()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user