Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-02 03:01:07 +06:00
feat: support indivisible shards for TP model loading and TPlizing. (#37220)
* feat: support uneven loading and sharding; resolve merge conflicts
* fix: allow for empty tensor computations
* test: add llama1b test case
* due to q_proj being colwise, it has to be a multiple of 2
* refactor: use slice API
* refactor: use slice API
* refactor: use slice API
* refactor: use slice API

---------

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
This commit is contained in: parent 06c4a4d499, commit def9663239
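For context, a minimal sketch of the tensor-parallel loading path this change targets, assuming a recent transformers build with the `tp_plan="auto"` option and a multi-GPU host launched via torchrun; the checkpoint name is only an example.

```python
# Sketch: loading a model with tensor parallelism, the code path exercised by this commit.
# Launch with: torchrun --nproc_per_node=4 load_tp.py
import torch
from transformers import AutoModelForCausalLM

model_id = "meta-llama/Llama-3.2-1B"  # example checkpoint; any decoder-only model works

# Each rank materializes only its shard of every weight. With this change, weight dimensions
# that are not divisible by the number of ranks are split torch.chunk-style into uneven
# (possibly empty) shards instead of requiring even divisibility.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    tp_plan="auto",  # shard weights across all visible GPUs at load time
)
print(f"{model.__class__.__name__} loaded; rank-local device: {next(model.parameters()).device}")
```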
@@ -13,6 +13,7 @@
 # limitations under the License.
 from __future__ import annotations
 
+import math
 import operator
 import os
 import re
@@ -280,7 +281,48 @@ def repack_weights(
 
 def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
     """
     Generalized tensor sharding across a multi-dimensional device mesh.
+    Extract only the fraction of the parameter owned by the given `rank` when the parameter is sharded along the provided `dim`.
+    Extraction follows the PyTorch `Shard` placement, so that sharding and materializing back to the full tensor follow `Shard` semantics.
+    `Shard` uses torch.chunk-style sharding of the tensor. The cases below demonstrate how sharding happens, including edge cases
+    such as some ranks receiving an empty tensor as their shard. The implementation below is robust to all of these cases.
+
+    Case (1)
+    empty_param              (16, 5120, 8190)
+    dim                      0
+    device_mesh.size()       4
+    rank 0 gets              (4, 5120, 8190)      (0 ... 4, 5120, 8190)
+    rank 1 gets              (4, 5120, 8190)      (4 ... 8, 5120, 8190)
+    rank 2 gets              (4, 5120, 8190)      (8 ... 12, 5120, 8190)
+    rank 3 gets              (4, 5120, 8190)      (12 ... 16, 5120, 8190)
+
+    Case (2)
+    empty_param              (16, 5120, 8190)
+    dim                      0
+    device_mesh.size()       14
+    rank 0 gets              (2, 5120, 8190)      (0 ... 2, 5120, 8190)
+    rank 1 gets              (2, 5120, 8190)      (2 ... 4, 5120, 8190)
+    rank 2 gets              (2, 5120, 8190)      (4 ... 6, 5120, 8190)
+    rank 3 gets              (2, 5120, 8190)      (6 ... 8, 5120, 8190)
+    rank 4 gets              (2, 5120, 8190)      (8 ... 10, 5120, 8190)
+    rank 5 gets              (2, 5120, 8190)      (10 ... 12, 5120, 8190)
+    rank 6 gets              (2, 5120, 8190)      (12 ... 14, 5120, 8190)
+    rank 7 gets              (2, 5120, 8190)      (14 ... 16, 5120, 8190)
+    rank 8 gets              (0, 5120, 8190)
+    rank 9 gets              (0, 5120, 8190)
+    rank 10 gets             (0, 5120, 8190)
+    rank 11 gets             (0, 5120, 8190)
+    rank 12 gets             (0, 5120, 8190)
+    rank 13 gets             (0, 5120, 8190)
+
+    Case (3)
+    empty_param              (16, 5120, 8190)
+    dim                      0
+    device_mesh.size()       3
+    rank 0 gets              (6, 5120, 8190)      (0 ... 6, 5120, 8190)
+    rank 1 gets              (6, 5120, 8190)      (6 ... 12, 5120, 8190)
+    rank 2 gets              (4, 5120, 8190)      (12 ... 16, 5120, 8190)
+
+    In case (2), empty shards are returned with the appropriate dimensions so that downstream operations still work smoothly.
     Args:
         param (torch.Tensor): The tensor to shard.
         empty_param (torch.Tensor): A tensor used for shape reference.
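The cases above follow torch.chunk semantics. A minimal standalone sketch (plain PyTorch, using meta tensors so nothing is allocated) that reproduces the per-rank shapes listed in the docstring:

```python
# Standalone check of the torch.chunk-style split described above.
# Ranks beyond the last non-empty chunk receive an empty shard, as in Case (2).
import math
import torch

full = torch.empty(16, 5120, 8190, device="meta")  # shape only, no memory allocated

for world_size in (4, 14, 3):  # Cases (1), (2), (3)
    chunks = torch.chunk(full, world_size, dim=0)
    shard_size = math.ceil(full.shape[0] / world_size)
    print(f"world_size={world_size}, shard_size={shard_size}")
    for rank in range(world_size):
        # torch.chunk may return fewer chunks than world_size; the remaining ranks get empty shards
        shape = tuple(chunks[rank].shape) if rank < len(chunks) else (0, *full.shape[1:])
        print(f"  rank {rank:2d} gets {shape}")
```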
@@ -289,6 +331,7 @@ def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
         dim (int): Dimension along which to shard the tensor.
     """
     param_dim = empty_param.dim()
+
     if dim < 0:
         dim = param_dim + dim
     if dim >= param_dim:
@@ -301,15 +344,18 @@ def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
     if rank >= world_size:
         raise ValueError(f"Rank {rank} is out of bounds for mesh size {world_size}")
 
-    shard_size = empty_param.shape[dim] // world_size
+    shard_size = math.ceil(empty_param.shape[dim] / world_size)
     start = rank * shard_size
-    end = start + shard_size
 
     # Construct slicing index dynamically
+    end = min(start + shard_size, empty_param.shape[dim])
     slice_indices = [slice(None)] * param_dim
-    slice_indices[dim] = slice(start, end)
-    return param[tuple(slice_indices)]
+    if start < empty_param.shape[dim]:
+        slice_indices[dim] = slice(start, end)
+        return param[tuple(slice_indices)]
+    dimensions = list(param.shape)
+    dimensions[dim] = 0
+    return torch.empty(tuple(dimensions), dtype=torch.int64)
 
 
 def distribute_module(
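A standalone sketch of the new slicing logic in this hunk, with `world_size` passed directly instead of being read from a device mesh so it can run without torch.distributed:

```python
# Mirrors the hunk above: ceil-sized shards, a clamped end index, and an empty fallback shard.
import math
import torch


def shard_like_diff(param: torch.Tensor, rank: int, world_size: int, dim: int = 0) -> torch.Tensor:
    shard_size = math.ceil(param.shape[dim] / world_size)    # ceil, not floor: last shard may be smaller
    start = rank * shard_size
    end = min(start + shard_size, param.shape[dim])           # clamp so the last shard is not over-read
    slice_indices = [slice(None)] * param.dim()
    if start < param.shape[dim]:
        slice_indices[dim] = slice(start, end)
        return param[tuple(slice_indices)]
    # Ranks past the end of the tensor get an empty shard with the correct trailing dimensions
    # (the actual change uses dtype=torch.int64 for this placeholder; it holds no elements either way).
    dimensions = list(param.shape)
    dimensions[dim] = 0
    return torch.empty(tuple(dimensions), dtype=param.dtype)


weight = torch.randn(16, 8)  # 16 rows over 3 "ranks" -> shards of 6, 6, 4 rows
print([tuple(shard_like_diff(weight, r, 3).shape) for r in range(3)])
# [(6, 8), (6, 8), (4, 8)]; with world_size=14, ranks 8..13 would get (0, 8)
```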
@@ -500,7 +546,9 @@ class ColwiseParallel(TensorParallelLayer):
         if to_contiguous:
             parameter = parameter.contiguous()
         if self.use_dtensor:
-            parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
+            parameter = DTensor.from_local(
+                parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride()
+            )
         return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
 
     @staticmethod
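The explicit `shape`/`stride` matter because, when they are omitted, `DTensor.from_local` infers the global shape from the rank-local tensor, which only works when every rank holds an equally sized shard; with uneven shards that inference breaks. A quick numeric sketch (no distributed setup needed):

```python
# Why the global shape must be passed explicitly when shards are uneven.
import math

global_rows, world_size = 16, 3
shard_size = math.ceil(global_rows / world_size)
local_rows = [max(0, min(shard_size, global_rows - r * shard_size)) for r in range(world_size)]
print(local_rows)                            # [6, 6, 4] -- per-rank local sizes
print([r * world_size for r in local_rows])  # [18, 18, 12] -- naive local*world_size inference
# None of these equal 16, so the global shape/stride of `empty_param` is supplied instead.
```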
@@ -574,7 +622,9 @@ class RowwiseParallel(TensorParallelLayer):
         if to_contiguous:
             parameter = parameter.contiguous()
         if self.use_dtensor:
-            parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
+            parameter = DTensor.from_local(
+                parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride()
+            )
         return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
 
     @staticmethod