protect dtensor import (#38496)

protect
Marc Sun 2025-05-30 17:36:00 +02:00 committed by GitHub
parent 051a8acc9a
commit c7f2b79dd8
2 changed files with 1 addition and 5 deletions


@@ -37,7 +37,6 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar
 from zipfile import is_zipfile

 import torch
-import torch.distributed.tensor
 from huggingface_hub import split_torch_state_dict_into_shards
 from packaging import version
 from torch import Tensor, nn

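The line removed above imported torch.distributed.tensor unconditionally whenever the module was loaded. On torch builds compiled without distributed support (where torch.distributed.is_available() is False), and on older torch versions that do not ship the public torch.distributed.tensor module, that eager import can fail and take the whole library import down with it, even for users who never touch DTensor. A small probe of that failure mode, using only plain PyTorch; the printed messages are illustrative, and the exact exception raised on unsupported builds varies by platform and torch version:

import torch

print(f"torch {torch.__version__}, distributed available: {torch.distributed.is_available()}")
try:
    import torch.distributed.tensor  # what the removed line did unconditionally at import time
except Exception as err:  # exception type varies by build and torch version
    print(f"eager dtensor import fails here: {err!r}")
else:
    print("eager dtensor import succeeds on this build")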

@@ -42,9 +42,6 @@ is_torch_greater_or_equal_than_1_12 = is_torch_greater_or_equal("1.12", accept_dev=True)
 # Cache this result has it's a C FFI call which can be pretty time-consuming
 _torch_distributed_available = torch.distributed.is_available()

-if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
-    pass
-

 def softmax_backward_data(parent, grad_output, output, dim, self):
     """
@@ -296,7 +293,7 @@ def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]:
     guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
     non-overlapping lifetimes may have the same id.
     """
-    if is_torch_greater_or_equal_than_2_0:
+    if _torch_distributed_available and is_torch_greater_or_equal("2.5"):
         from torch.distributed.tensor import DTensor

         if isinstance(tensor, DTensor):
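Together with the hunk just above it, this change replaces the module-level dtensor import with a guarded, function-local one: DTensor is imported only when the running torch both has distributed support and is recent enough (2.5+) to expose torch.distributed.tensor, and only at the point where the isinstance check actually needs it. A minimal, self-contained sketch of that pattern, assuming nothing beyond torch and packaging; is_dtensor is a hypothetical helper name used for illustration, and the packaging-based version test stands in for transformers' is_torch_greater_or_equal("2.5"):

import torch
from packaging import version

_torch_distributed_available = torch.distributed.is_available()
_torch_is_2_5_plus = version.parse(torch.__version__) >= version.parse("2.5")


def is_dtensor(tensor: torch.Tensor) -> bool:
    """Return True if `tensor` is a DTensor, without importing dtensor eagerly."""
    if _torch_distributed_available and _torch_is_2_5_plus:
        # Safe here: this torch build ships torch.distributed.tensor, and the
        # import cost is paid only when the check is actually needed.
        from torch.distributed.tensor import DTensor

        return isinstance(tensor, DTensor)
    return False

Because the import sits behind the guard, merely importing the module never touches torch.distributed.tensor on builds that cannot provide it.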