mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
feat: support indivisible shards for TP model loading and TPlizing. (#37220)
* feat: support uneven loading and sharding; resolve merge conflicts
* fix: allow for empty tensor computations
* test: add llama1b test case
* due to q_proj being colwise it has to be a multiple of 2
* refactor: use slice API
* refactor: use slice API
* refactor: use slice API
* refactor: use slice API

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
This commit is contained in:
parent
06c4a4d499
commit
def9663239
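In user-facing terms, the change lets tensor-parallel loading proceed when a weight's sharded dimension is not an exact multiple of the world size. Below is a rough sketch of the loading path this code serves, assuming the `tp_plan="auto"` entry point of `from_pretrained` and a torchrun launch; the checkpoint id and process count are placeholders, and model-specific constraints still apply (the commit notes that q_proj, being column-wise sharded, must stay a multiple of 2).

```python
# torchrun --nproc-per-node <N> load_tp.py
import torch
from transformers import AutoModelForCausalLM

# With this change, parameters whose sharded dimension does not divide evenly by the
# world size are loaded as uneven shards (some ranks may even hold empty shards)
# instead of failing at load time.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",       # placeholder checkpoint (the PR's test case uses a Llama 1B)
    torch_dtype=torch.bfloat16,
    tp_plan="auto",                  # shard supported layers across the torchrun processes
)
print(model.config.num_attention_heads)
```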
@@ -13,6 +13,7 @@
 # limitations under the License.
 from __future__ import annotations
 
+import math
 import operator
 import os
 import re
@@ -280,7 +281,48 @@ def repack_weights(
 def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
     """
     Generalized tensor sharding across a multi-dimensional device mesh.
+    Extract only the fraction of the parameter owned by the given `rank` when the parameter is sharded along the provided `dim`.
+    Extraction follows the PyTorch `Shard` placement, so sharding and materializing back to the full tensor follow `Shard` semantics.
+    `Shard` uses torch.chunk-style sharding of the tensor. The cases below demonstrate how sharding happens, including edge cases
+    such as some ranks receiving an empty tensor as their shard. The implementation below is robust to all of these cases.
 
+    Case (1)
+    empty_param          (16, 5120, 8190)
+    dim                  0
+    device_mesh.size()   4
+    rank 0 gets (4, 5120, 8190)    (0 ... 4, 5120, 8190)
+    rank 1 gets (4, 5120, 8190)    (4 ... 8, 5120, 8190)
+    rank 2 gets (4, 5120, 8190)    (8 ... 12, 5120, 8190)
+    rank 3 gets (4, 5120, 8190)    (12 ... 16, 5120, 8190)
+
+    Case (2)
+    empty_param          (16, 5120, 8190)
+    dim                  0
+    device_mesh.size()   14
+    rank 0 gets (2, 5120, 8190)    (0 ... 2, 5120, 8190)
+    rank 1 gets (2, 5120, 8190)    (2 ... 4, 5120, 8190)
+    rank 2 gets (2, 5120, 8190)    (4 ... 6, 5120, 8190)
+    rank 3 gets (2, 5120, 8190)    (6 ... 8, 5120, 8190)
+    rank 4 gets (2, 5120, 8190)    (8 ... 10, 5120, 8190)
+    rank 5 gets (2, 5120, 8190)    (10 ... 12, 5120, 8190)
+    rank 6 gets (2, 5120, 8190)    (12 ... 14, 5120, 8190)
+    rank 7 gets (2, 5120, 8190)    (14 ... 16, 5120, 8190)
+    rank 8 gets (0, 5120, 8190)
+    rank 9 gets (0, 5120, 8190)
+    rank 10 gets (0, 5120, 8190)
+    rank 11 gets (0, 5120, 8190)
+    rank 12 gets (0, 5120, 8190)
+    rank 13 gets (0, 5120, 8190)
+
+    Case (3)
+    empty_param          (16, 5120, 8190)
+    dim                  0
+    device_mesh.size()   3
+    rank 0 gets (6, 5120, 8190)    (0 ... 6, 5120, 8190)
+    rank 1 gets (6, 5120, 8190)    (6 ... 12, 5120, 8190)
+    rank 2 gets (4, 5120, 8190)    (12 ... 16, 5120, 8190)
+
+    In case (2), empty shards are returned with the appropriate dimensions so that subsequent operations work smoothly.
     Args:
         param (torch.Tensor): The tensor to shard.
         empty_param (torch.Tensor): A tensor used for shape reference.
@@ -289,6 +331,7 @@ def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
         dim (int): Dimension along which to shard the tensor.
     """
     param_dim = empty_param.dim()
+
     if dim < 0:
         dim = param_dim + dim
     if dim >= param_dim:
@@ -301,15 +344,18 @@ def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
     if rank >= world_size:
         raise ValueError(f"Rank {rank} is out of bounds for mesh size {world_size}")
 
-    shard_size = empty_param.shape[dim] // world_size
+    shard_size = math.ceil(empty_param.shape[dim] / world_size)
     start = rank * shard_size
-    end = start + shard_size
 
     # Construct slicing index dynamically
+    end = min(start + shard_size, empty_param.shape[dim])
     slice_indices = [slice(None)] * param_dim
-    slice_indices[dim] = slice(start, end)
-
-    return param[tuple(slice_indices)]
+    if start < empty_param.shape[dim]:
+        slice_indices[dim] = slice(start, end)
+        return param[tuple(slice_indices)]
+    dimensions = list(param.shape)
+    dimensions[dim] = 0
+    return torch.empty(tuple(dimensions), dtype=torch.int64)
 
 
 def distribute_module(
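The new body matches the docstring's torch.chunk semantics: each rank takes a ceil-sized block, the last in-range rank may get a shorter one, and ranks past the end of the dimension get an empty shard so later code needs no special cases. A standalone mirror of the logic (illustrative only; the real helper also takes a device mesh and a reference `empty_param`) makes this easy to check:

```python
import math
import torch

def slice_shard(param: torch.Tensor, world_size: int, rank: int, dim: int) -> torch.Tensor:
    # Mirrors the updated get_tensor_shard body: ceil-sized blocks, a clamped end index,
    # and an empty fallback shard (the diff hardcodes dtype=torch.int64 for that fallback).
    shard_size = math.ceil(param.shape[dim] / world_size)
    start = rank * shard_size
    end = min(start + shard_size, param.shape[dim])
    if start < param.shape[dim]:
        idx = [slice(None)] * param.dim()
        idx[dim] = slice(start, end)
        return param[tuple(idx)]
    dims = list(param.shape)
    dims[dim] = 0
    return torch.empty(tuple(dims), dtype=torch.int64)

full = torch.arange(16 * 4).reshape(16, 4)   # int64, so it matches the empty-fallback dtype
shards = [slice_shard(full, world_size=14, rank=r, dim=0) for r in range(14)]

# Case (2) from the docstring: eight ranks own 2 rows each, the remaining six own empty shards.
assert [s.shape[0] for s in shards] == [2] * 8 + [0] * 6
# The non-empty shards are exactly what torch.chunk hands out.
assert all(torch.equal(a, b) for a, b in zip(torch.chunk(full, 14, dim=0), shards[:8]))
# Concatenating every shard (empty ones included) materializes the full tensor again.
assert torch.equal(torch.cat(shards, dim=0), full)
```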
@@ -500,7 +546,9 @@ class ColwiseParallel(TensorParallelLayer):
         if to_contiguous:
             parameter = parameter.contiguous()
         if self.use_dtensor:
-            parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
+            parameter = DTensor.from_local(
+                parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride()
+            )
         return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
 
     @staticmethod
@@ -574,7 +622,9 @@ class RowwiseParallel(TensorParallelLayer):
         if to_contiguous:
             parameter = parameter.contiguous()
         if self.use_dtensor:
-            parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
+            parameter = DTensor.from_local(
+                parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride()
+            )
         return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
 
     @staticmethod
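In both ColwiseParallel and RowwiseParallel the local shard is now wrapped with an explicit global shape and stride. With uneven shards, `DTensor.from_local` can no longer infer the global size as `local_shape * mesh_size`, so the reference `empty_param` supplies it. Below is a hedged sketch of that difference, assuming a recent PyTorch that exposes DTensor under `torch.distributed.tensor` and a 2-rank CPU run launched with torchrun; the (5, 4) weight and uneven 3/2 split are illustrative.

```python
# torchrun --nproc-per-node 2 uneven_dtensor.py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (2,))
rank = dist.get_rank()

# A global (5, 4) weight sharded unevenly along dim 0: rank 0 holds 3 rows, rank 1 holds 2.
full = torch.arange(20, dtype=torch.float32).reshape(5, 4)
local = full[:3] if rank == 0 else full[3:]

# Without shape/stride, from_local assumes even sharding and would report a wrong global
# shape; passing the reference tensor's size and stride (as the diff does with empty_param)
# keeps the DTensor metadata correct.
dt = DTensor.from_local(
    local, mesh, [Shard(0)], run_check=False, shape=full.size(), stride=full.stride()
)
print(f"rank {rank}: local {tuple(dt.to_local().shape)}, global {tuple(dt.shape)}")  # global is (5, 4) on both ranks
dist.destroy_process_group()
```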