Fix tensor parallel with non-floating dtypes (#37790)

fix: wrap sharded tensors with requires_grad=parameter.is_floating_point() so tensor parallelism also works with non-floating dtypes (only floating-point and complex tensors can require gradients)
Cyril Vallez 2025-04-25 15:48:16 +02:00 committed by GitHub
parent 214062201e
commit eefc86aa31

@@ -307,8 +307,7 @@ class ColwiseParallel(TensorParallelLayer):
             parameter = parameter.contiguous()
         if self.use_dtensor:
             parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
-        requires_grad = True if parameter.is_floating_point() else False
-        return nn.Parameter(parameter, requires_grad=requires_grad)
+        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
 
     @staticmethod
     def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
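
For context: nn.Parameter defaults to requires_grad=True, and PyTorch only allows floating-point (and complex) tensors to require gradients, so wrapping an integer shard (e.g. a quantized weight) without the guard above raises a RuntimeError. A minimal standalone sketch of the failure mode and of the guard, in plain PyTorch rather than transformers code:

    import torch
    from torch import nn

    # Floating-point weights may require gradients (the nn.Parameter default).
    w = nn.Parameter(torch.randn(4, 4))
    assert w.requires_grad

    # Integer tensors may not: nn.Parameter(torch.zeros(4, dtype=torch.int64)) raises
    # "RuntimeError: Only Tensors of floating point and complex dtype can require gradients".
    idx = torch.zeros(4, dtype=torch.int64)

    # The guard used throughout this diff: request gradients only when the dtype supports them.
    p = nn.Parameter(idx, requires_grad=idx.is_floating_point())
    print(w.requires_grad, p.requires_grad)  # True False

The removed two-line form (requires_grad = True if parameter.is_floating_point() else False) computes the same boolean; passing parameter.is_floating_point() directly is equivalent and shorter.
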
@@ -330,8 +329,7 @@ class PackedColwiseParallel(ColwiseParallel):
             parameter = parameter.contiguous()
         if self.use_dtensor:
             parameter = DTensor.from_local(parameter, device_mesh, [Shard(-2)], run_check=False)
-        requires_grad = True if parameter.is_floating_point() else False
-        return nn.Parameter(parameter, requires_grad=requires_grad)
+        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
 
 
 class RowwiseParallel(TensorParallelLayer):
@@ -383,8 +381,7 @@ class RowwiseParallel(TensorParallelLayer):
             parameter = parameter.contiguous()
         if self.use_dtensor:
             parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
-        requires_grad = True if parameter.is_floating_point() else False
-        return nn.Parameter(parameter, requires_grad=requires_grad)
+        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
 
     @staticmethod
     def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
@@ -446,8 +443,7 @@ class PackedRowwiseParallel(RowwiseParallel):
             parameter = parameter.contiguous()
         if self.use_dtensor:
             parameter = DTensor.from_local(parameter, device_mesh, [Shard(-1)], run_check=False)
-        requires_grad = True if parameter.is_floating_point() else False
-        return nn.Parameter(parameter, requires_grad=requires_grad)
+        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
 
 
 class SequenceParallel(TensorParallelLayer):
@@ -531,8 +527,7 @@ class SequenceParallel(TensorParallelLayer):
             parameter = parameter.contiguous()
         if self.use_dtensor:
             parameter = DTensor.from_local(parameter, device_mesh, [Replicate()], run_check=False)
-        requires_grad = True if parameter.is_floating_point() else False
-        return nn.Parameter(parameter, requires_grad=requires_grad)
+        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
 
 
 SUPPORTED_TP_STYLES = {
@@ -671,7 +666,7 @@ def shard_and_distribute_module(
     # SUPER IMPORTANT we have to use setattr
     # otherwise loading is crazy slow
     if not isinstance(param, torch.nn.Parameter):
-        param = torch.nn.Parameter(param)
+        param = torch.nn.Parameter(param, requires_grad=param.is_floating_point())
     setattr(module_to_tp, param_type, param)
     # module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True)
     return param
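
The last hunk applies the same guard where the sharded tensor is wrapped and assigned back onto the module with setattr. A rough standalone sketch of that pattern under simplified assumptions (the helper name attach_sharded_tensor and the int8 weight below are invented for illustration; the real shard_and_distribute_module also handles casting, sharding and DTensor placement):

    import torch
    from torch import nn

    def attach_sharded_tensor(module: nn.Module, param_type: str, param: torch.Tensor) -> nn.Parameter:
        # Hypothetical helper mirroring the hunk above: wrap the local shard in a
        # Parameter, requesting gradients only when the dtype supports them, then
        # assign it with setattr (much faster during loading than load_state_dict,
        # per the comment in the diff).
        if not isinstance(param, nn.Parameter):
            param = nn.Parameter(param, requires_grad=param.is_floating_point())
        setattr(module, param_type, param)
        return param

    layer = nn.Linear(8, 8, bias=False)
    int8_shard = torch.randint(-128, 127, (8, 8), dtype=torch.int8)  # e.g. a quantized weight shard
    attach_sharded_tensor(layer, "weight", int8_shard)
    print(layer.weight.dtype, layer.weight.requires_grad)  # torch.int8 False
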