feat: support indivisible shards for TP model loading and TPlizing. (#37220)

* feat: support uneven loading and sharding
resolve merge conflicts
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* fix: allow for empty tensor computations

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* test: add llama1b test case

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* due to q_proj being colwise, it has to be a multiple of 2

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* refactor: use slice API

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* refactor: use slice API

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* refactor: use slice API

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* refactor: use slice API

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

---------

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
Mehant Kammakomati 2025-07-01 15:33:22 +05:30 committed by GitHub
parent 06c4a4d499
commit def9663239

@@ -13,6 +13,7 @@
# limitations under the License.
from __future__ import annotations
import math
import operator
import os
import re
@@ -280,7 +281,48 @@ def repack_weights(
def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
"""
Generalized tensor sharding across a multi-dimensional device mesh.
Extract only the fraction of the parameter owned by the given `rank` when the parameter is sharded along the provided `dim`.
Extraction follows the PyTorch `Shard` placement, so sharding and then materializing back to the full tensor obeys `Shard` semantics.
`Shard` uses torch.chunk-style sharding of the tensor. The cases below demonstrate how sharding happens, including edge cases
such as some ranks receiving an empty tensor as their shard. The implementation below is robust to all of these cases.
Case (1)
empty_param (16, 5120, 8190)
dim 0
device_mesh.size() 4
rank 0 gets (4, 5120, 8190) (0 ... 4, 5120, 8190)
rank 1 gets (4, 5120, 8190) (4 ... 8, 5120, 8190)
rank 2 gets (4, 5120, 8190) (8 ... 12, 5120, 8190)
rank 3 gets (4, 5120, 8190) (12 ... 16, 5120, 8190)
Case (2)
empty_param (16, 5120, 8190)
dim 0
device_mesh.size() 14
rank 0 gets (2, 5120, 8190) (0 ... 2, 5120, 8190)
rank 1 gets (2, 5120, 8190) (2 ... 4, 5120, 8190)
rank 2 gets (2, 5120, 8190) (4 ... 6, 5120, 8190)
rank 3 gets (2, 5120, 8190) (6 ... 8, 5120, 8190)
rank 4 gets (2, 5120, 8190) (8 ... 10, 5120, 8190)
rank 5 gets (2, 5120, 8190) (10 ... 12, 5120, 8190)
rank 6 gets (2, 5120, 8190) (12 ... 14, 5120, 8190)
rank 7 gets (2, 5120, 8190) (14 ... 16, 5120, 8190)
rank 8 gets (0, 5120, 8190)
rank 9 gets (0, 5120, 8190)
rank 10 gets (0, 5120, 8190)
rank 11 gets (0, 5120, 8190)
rank 12 gets (0, 5120, 8190)
rank 13 gets (0, 5120, 8190)
Case (3)
empty_param (16, 5120, 8190)
dim 0
device_mesh.size() 3
rank 0 gets (6, 5120, 8190) (0 ... 6, 5120, 8190)
rank 1 gets (6, 5120, 8190) (6 ... 12, 5120, 8190)
rank 2 gets (4, 5120, 8190) (12 ... 16, 5120, 8190)
In case (2), empty shards are returned with the appropriate dimensions so that downstream operations keep working smoothly; the standalone sketch after this hunk reproduces these shard shapes.
Args:
param (torch.Tensor): The tensor to shard.
empty_param (torch.Tensor): A tensor used for shape reference.
@@ -289,6 +331,7 @@ def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
dim (int): Dimension along which to shard the tensor.
"""
param_dim = empty_param.dim()
if dim < 0:
dim = param_dim + dim
if dim >= param_dim:
@@ -301,15 +344,18 @@ def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
if rank >= world_size:
raise ValueError(f"Rank {rank} is out of bounds for mesh size {world_size}")
shard_size = empty_param.shape[dim] // world_size
shard_size = math.ceil(empty_param.shape[dim] / world_size)
start = rank * shard_size
end = start + shard_size
# Construct slicing index dynamically
end = min(start + shard_size, empty_param.shape[dim])
slice_indices = [slice(None)] * param_dim
slice_indices[dim] = slice(start, end)
return param[tuple(slice_indices)]
if start < empty_param.shape[dim]:
slice_indices[dim] = slice(start, end)
return param[tuple(slice_indices)]
dimensions = list(param.shape)
dimensions[dim] = 0
return torch.empty(tuple(dimensions), dtype=torch.int64)
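
The shard shapes listed in Cases (1)-(3) can be checked without a GPU or a process group. The sketch below is not part of the patch; it only mirrors the ceil-division plus clamping arithmetic of the new `get_tensor_shard` body, and the helper name `shard_shape` is made up for illustration.

# Standalone sketch, not from the PR: reproduces the shard shapes of Cases (1)-(3)
# using the same ceil-division + clamping arithmetic as the patched get_tensor_shard.
import math

def shard_shape(full_shape, world_size, rank, dim=0):
    """Hypothetical helper: shape of the shard owned by `rank` along `dim`."""
    shard_size = math.ceil(full_shape[dim] / world_size)
    start = rank * shard_size
    end = min(start + shard_size, full_shape[dim])
    out = list(full_shape)
    out[dim] = max(end - start, 0)  # ranks past the end get an empty (0-sized) shard
    return tuple(out)

full = (16, 5120, 8190)
for world_size in (4, 14, 3):
    sizes = [shard_shape(full, world_size, rank)[0] for rank in range(world_size)]
    print(f"mesh size {world_size:2d}: dim-0 shard sizes {sizes}")
# mesh size  4: dim-0 shard sizes [4, 4, 4, 4]
# mesh size 14: dim-0 shard sizes [2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0]
# mesh size  3: dim-0 shard sizes [6, 6, 4]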
def distribute_module(
@@ -500,7 +546,9 @@ class ColwiseParallel(TensorParallelLayer):
if to_contiguous:
parameter = parameter.contiguous()
if self.use_dtensor:
parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
parameter = DTensor.from_local(
parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride()
)
return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
@staticmethod
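
The point of the `ColwiseParallel` change above (and of the identical `RowwiseParallel` hunk below) is that with uneven or empty shards, `DTensor.from_local` can no longer infer the global size from a local shard alone, so the patch passes the full parameter's `shape` and `stride` explicitly. The following is a minimal single-process sketch of that call, not from the PR; it assumes torch >= 2.4 import paths (`torch.distributed.tensor`) and a CPU "gloo" process group.

# Minimal sketch (assumptions: torch >= 2.4, single-process CPU "gloo" group).
# Shows DTensor.from_local with an explicit global shape/stride, as in the patch.
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

full = torch.randn(16, 8)    # stand-in for the full (empty_param-sized) parameter
local = full                 # this rank's shard along dim 0 (trivially the whole tensor here)
dtensor = DTensor.from_local(
    local, mesh, [Shard(0)], run_check=False,
    shape=full.size(), stride=full.stride(),  # global metadata, not inferred from the local shard
)
print(dtensor.shape)         # torch.Size([16, 8]) -- the global shape, independent of the local size
dist.destroy_process_group()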
@@ -574,7 +622,9 @@ class RowwiseParallel(TensorParallelLayer):
if to_contiguous:
parameter = parameter.contiguous()
if self.use_dtensor:
parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
parameter = DTensor.from_local(
parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride()
)
return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
@staticmethod