Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 18:22:34 +06:00)
Parent: eed11f34ab
Commit: dadb286f06
src/transformers/modeling_utils.py
@@ -52,6 +52,7 @@ from .pytorch_utils import ( # noqa: F401
     find_pruneable_heads_and_indices,
     id_tensor_storage,
     is_torch_greater_or_equal_than_1_13,
+    is_torch_greater_or_equal_than_2_4,
     prune_conv1d_layer,
     prune_layer,
     prune_linear_layer,
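For context, a minimal sketch (not part of the diff) of how a torch version flag such as is_torch_greater_or_equal_than_2_4 is typically derived in pytorch_utils, assuming the same packaging-based comparison used for the existing 1.12/1.13 flags visible in the hunks below; the exact definition is an assumption, not quoted from this commit.

# Illustrative sketch: deriving the version flag that this hunk re-exports.
import torch
from packaging import version

# Strip any local/dev suffix (e.g. "+cu121") before comparing, as the
# parsed_torch_version_base name in pytorch_utils suggests.
parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
is_torch_greater_or_equal_than_2_4 = parsed_torch_version_base >= version.parse("2.4")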
@@ -5005,6 +5006,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
             device_mesh (`torch.distributed.DeviceMesh`):
                 The device mesh to use for tensor parallelism.
         """
+        if not is_torch_greater_or_equal_than_2_4:
+            raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")
 
         # Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module.
         # No op if `_tp_plan` attribute does not exist under the module.
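The guard above fails fast with a clear error instead of an ImportError from deep inside torch.distributed. A minimal sketch of the same pattern in caller code, assuming only the flag added by this commit; require_torch_2_4 is a hypothetical helper name, not a transformers API.

# Illustrative sketch of the fail-fast version guard (hypothetical helper).
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4

def require_torch_2_4():
    # Raise early, before any torch.distributed.tensor symbol is touched,
    # mirroring the check added to PreTrainedModel.tensor_parallel.
    if not is_torch_greater_or_equal_than_2_4:
        raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")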
src/transformers/pytorch_utils.py
@@ -20,11 +20,6 @@ import torch
 from packaging import version
 from safetensors.torch import storage_ptr, storage_size
 from torch import nn
-from torch.distributed.tensor import Replicate
-from torch.distributed.tensor.parallel import (
-    ColwiseParallel,
-    RowwiseParallel,
-)
 
 from .utils import is_torch_xla_available, logging
 
@@ -44,6 +39,14 @@ is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13")
 is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12")
 
 
+if is_torch_greater_or_equal_than_2_4:
+    from torch.distributed.tensor import Replicate
+    from torch.distributed.tensor.parallel import (
+        ColwiseParallel,
+        RowwiseParallel,
+    )
+
+
 def softmax_backward_data(parent, grad_output, output, dim, self):
     """
     A function that calls the internal `_softmax_backward_data` PyTorch method and that adjusts the arguments according
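Taken together, the two pytorch_utils hunks move the torch.distributed.tensor imports behind the 2.4 version gate so that merely importing transformers no longer requires a recent torch. A minimal sketch of how downstream code can follow the same pattern; pick_parallel_style is a hypothetical helper, not part of the library, and the layer-name heuristic only stands in for a real `_tp_plan` entry.

# Illustrative sketch of the gated-import pattern adopted by this commit.
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4

if is_torch_greater_or_equal_than_2_4:
    # Only available (and only needed) on sufficiently recent torch builds.
    from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

def pick_parallel_style(layer_name: str):
    # Hypothetical helper: choose a parallel style for a layer, the way a
    # `_tp_plan` entry would, but only after the version gate has passed.
    if not is_torch_greater_or_equal_than_2_4:
        raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")
    return ColwiseParallel() if layer_name.endswith("up_proj") else RowwiseParallel()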