protect tensor parallel usage (#34800)

protect
Arthur 2024-11-19 09:54:11 +01:00 committed by GitHub
parent eed11f34ab
commit dadb286f06
2 changed files with 11 additions and 5 deletions

src/transformers/modeling_utils.py

@@ -52,6 +52,7 @@ from .pytorch_utils import (  # noqa: F401
     find_pruneable_heads_and_indices,
     id_tensor_storage,
     is_torch_greater_or_equal_than_1_13,
+    is_torch_greater_or_equal_than_2_4,
     prune_conv1d_layer,
     prune_layer,
     prune_linear_layer,
@@ -5005,6 +5006,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             device_mesh (`torch.distributed.DeviceMesh`):
                 The device mesh to use for tensor parallelism.
         """
+        if not is_torch_greater_or_equal_than_2_4:
+            raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")
         # Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module.
         # No op if `_tp_plan` attribute does not exist under the module.
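
For context, a minimal self-contained sketch of the pattern this hunk introduces: compute the torch version flag once (in the same style as pytorch_utils) and fail fast with a readable error before any tensor-parallel work starts. The helper name apply_tensor_parallel is illustrative, not the transformers method.

import torch
from packaging import version

# Same style of flag as pytorch_utils: compare against the base torch version.
parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
is_torch_greater_or_equal_than_2_4 = parsed_torch_version_base >= version.parse("2.4")


def apply_tensor_parallel(model, device_mesh):
    """Illustrative entry point guarded the same way as the hunk above; body elided."""
    if not is_torch_greater_or_equal_than_2_4:
        raise EnvironmentError("tensor parallel requires a recent torch release.")
    # ... shard `model` over `device_mesh` according to its tp plan ...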

src/transformers/pytorch_utils.py

@@ -20,11 +20,6 @@ import torch
 from packaging import version
 from safetensors.torch import storage_ptr, storage_size
 from torch import nn
-from torch.distributed.tensor import Replicate
-from torch.distributed.tensor.parallel import (
-    ColwiseParallel,
-    RowwiseParallel,
-)

 from .utils import is_torch_xla_available, logging
@@ -44,6 +39,14 @@ is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse
 is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12")
+if is_torch_greater_or_equal_than_2_4:
+    from torch.distributed.tensor import Replicate
+    from torch.distributed.tensor.parallel import (
+        ColwiseParallel,
+        RowwiseParallel,
+    )

 def softmax_backward_data(parent, grad_output, output, dim, self):
     """
     A function that calls the internal `_softmax_backward_data` PyTorch method and that adjusts the arguments according
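
To illustrate what the now-guarded imports are used for, here is a hedged sketch of applying a column-wise/row-wise tensor-parallel plan to a toy module with parallelize_module. It assumes a torch recent enough to provide torch.distributed.tensor (the flag above checks 2.4), a CUDA machine, and a `torchrun --nproc-per-node=N` launch; the toy module and its plan are illustrative, not the `_tp_plan` machinery in transformers.

import torch
from torch import nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module


class ToyMLP(nn.Module):
    def __init__(self, dim=128):
        super().__init__()
        self.up = nn.Linear(dim, 4 * dim)
        self.down = nn.Linear(4 * dim, dim)

    def forward(self, x):
        return self.down(torch.relu(self.up(x)))


if __name__ == "__main__":
    # One-dimensional mesh over all local GPUs; init_device_mesh sets up the
    # default process group from the torchrun environment if needed.
    mesh = init_device_mesh("cuda", (torch.cuda.device_count(),))
    model = ToyMLP().cuda()
    # Shard the first projection column-wise and the second row-wise across the mesh.
    parallelize_module(model, mesh, {"up": ColwiseParallel(), "down": RowwiseParallel()})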