Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)

Parent: eed11f34ab
Commit: dadb286f06
@@ -52,6 +52,7 @@ from .pytorch_utils import (  # noqa: F401
     find_pruneable_heads_and_indices,
     id_tensor_storage,
     is_torch_greater_or_equal_than_1_13,
+    is_torch_greater_or_equal_than_2_4,
     prune_conv1d_layer,
     prune_layer,
     prune_linear_layer,
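For downstream code, the `# noqa: F401` re-export means the new flag is importable from `modeling_utils` as well as from its home in `pytorch_utils`. A small illustrative check using the flag (the surrounding usage is hypothetical, not part of this diff):

# Hypothetical caller-side check using the flag added by this commit.
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4

if not is_torch_greater_or_equal_than_2_4:
    # Fall back to a non-tensor-parallel code path on older torch builds.
    ...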
@@ -5005,6 +5006,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
             device_mesh (`torch.distributed.DeviceMesh`):
                 The device mesh to use for tensor parallelism.
         """
+        if not is_torch_greater_or_equal_than_2_4:
+            raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")

         # Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module.
         # No op if `_tp_plan` attribute does not exist under the module.
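The guard above sits inside the `PreTrainedModel` method that applies tensor parallelism from a `device_mesh` argument. A minimal usage sketch under assumptions: the method name `tensor_parallel`, the checkpoint, and the torchrun launch are illustrative, not taken from this diff; only the `device_mesh` parameter is documented above.

# Hedged sketch: building a DeviceMesh and handing it to the guarded method.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# Launch with e.g. `torchrun --nproc-per-node 4 script.py` so rank/world-size env vars exist.
dist.init_process_group("nccl")

# A 1-D mesh over all ranks: each GPU receives one shard of every tensor-parallel layer.
device_mesh = init_device_mesh("cuda", (dist.get_world_size(),))

# model = AutoModelForCausalLM.from_pretrained("some/checkpoint")  # a model defining a `_tp_plan`
# model.tensor_parallel(device_mesh)  # no-op when the model has no `_tp_plan` attribute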
@@ -20,11 +20,6 @@ import torch
 from packaging import version
 from safetensors.torch import storage_ptr, storage_size
 from torch import nn
-from torch.distributed.tensor import Replicate
-from torch.distributed.tensor.parallel import (
-    ColwiseParallel,
-    RowwiseParallel,
-)

 from .utils import is_torch_xla_available, logging

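The definition of `is_torch_greater_or_equal_than_2_4` itself is not in the loaded part of this diff; presumably it mirrors the 1.12/1.13 flags visible in the next hunk. A hedged sketch of that pattern:

# Assumed definition, following the `is_torch_greater_or_equal_than_1_12` / `_1_13` pattern below.
import torch
from packaging import version

parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
is_torch_greater_or_equal_than_2_4 = parsed_torch_version_base >= version.parse("2.4")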
@@ -44,6 +39,14 @@ is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13")
 is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12")


+if is_torch_greater_or_equal_than_2_4:
+    from torch.distributed.tensor import Replicate
+    from torch.distributed.tensor.parallel import (
+        ColwiseParallel,
+        RowwiseParallel,
+    )
+
+
 def softmax_backward_data(parent, grad_output, output, dim, self):
     """
     A function that calls the internal `_softmax_backward_data` PyTorch method and that adjusts the arguments according
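For context on what the now-guarded imports are used for: `ColwiseParallel` and `RowwiseParallel` shard `nn.Linear` weights across a `DeviceMesh`. An illustrative sketch with a toy module (not from this diff), assuming a torchrun launch:

# Toy example of applying the parallel styles via torch.distributed.tensor.parallel.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module


class ToyMLP(nn.Module):
    def __init__(self, dim: int = 1024):
        super().__init__()
        self.up_proj = nn.Linear(dim, 4 * dim, bias=False)
        self.down_proj = nn.Linear(4 * dim, dim, bias=False)

    def forward(self, x):
        return self.down_proj(torch.relu(self.up_proj(x)))


# Run under torchrun so the process group and device assignment are set up.
dist.init_process_group("nccl")
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

mlp = ToyMLP().cuda()
# Column-shard the first projection and row-shard the second: each rank holds 1/world_size
# of each weight matrix, and only the second matmul's output needs an all-reduce.
parallelize_module(mlp, mesh, {"up_proj": ColwiseParallel(), "down_proj": RowwiseParallel()})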