Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 10:12:23 +06:00)

parent 214062201e
commit eefc86aa31
@@ -307,8 +307,7 @@ class ColwiseParallel(TensorParallelLayer):
         parameter = parameter.contiguous()
         if self.use_dtensor:
             parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
-        requires_grad = True if parameter.is_floating_point() else False
-        return nn.Parameter(parameter, requires_grad=requires_grad)
+        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
 
     @staticmethod
     def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
@@ -330,8 +329,7 @@ class PackedColwiseParallel(ColwiseParallel):
         parameter = parameter.contiguous()
         if self.use_dtensor:
             parameter = DTensor.from_local(parameter, device_mesh, [Shard(-2)], run_check=False)
-        requires_grad = True if parameter.is_floating_point() else False
-        return nn.Parameter(parameter, requires_grad=requires_grad)
+        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
 
 
 class RowwiseParallel(TensorParallelLayer):
@@ -383,8 +381,7 @@ class RowwiseParallel(TensorParallelLayer):
         parameter = parameter.contiguous()
         if self.use_dtensor:
             parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
-        requires_grad = True if parameter.is_floating_point() else False
-        return nn.Parameter(parameter, requires_grad=requires_grad)
+        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
 
     @staticmethod
     def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
@@ -446,8 +443,7 @@ class PackedRowwiseParallel(RowwiseParallel):
         parameter = parameter.contiguous()
         if self.use_dtensor:
             parameter = DTensor.from_local(parameter, device_mesh, [Shard(-1)], run_check=False)
-        requires_grad = True if parameter.is_floating_point() else False
-        return nn.Parameter(parameter, requires_grad=requires_grad)
+        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
 
 
 class SequenceParallel(TensorParallelLayer):
@@ -531,8 +527,7 @@ class SequenceParallel(TensorParallelLayer):
         parameter = parameter.contiguous()
         if self.use_dtensor:
             parameter = DTensor.from_local(parameter, device_mesh, [Replicate()], run_check=False)
-        requires_grad = True if parameter.is_floating_point() else False
-        return nn.Parameter(parameter, requires_grad=requires_grad)
+        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
 
 
 SUPPORTED_TP_STYLES = {
@@ -671,7 +666,7 @@ def shard_and_distribute_module(
     # SUPER IMPORTANT we have to use setattr
     # otherwise loading is crazy slow
     if not isinstance(param, torch.nn.Parameter):
-        param = torch.nn.Parameter(param)
+        param = torch.nn.Parameter(param, requires_grad=param.is_floating_point())
     setattr(module_to_tp, param_type, param)
     # module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True)
     return param
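
The change above replaces the unconditional nn.Parameter wrap with one whose
requires_grad flag is derived from the tensor dtype. A minimal sketch of the
plain PyTorch behaviour this relies on (illustrative code, not part of the
commit; the tensor names are made up): nn.Parameter defaults to
requires_grad=True, which autograd only allows for floating-point and complex
dtypes, so inferring the flag from is_floating_point() lets integer
(e.g. quantized) weights be wrapped without error.

# Hedged sketch, not from the commit: why requires_grad is inferred from dtype.
import torch
from torch import nn

int_weight = torch.zeros(4, 4, dtype=torch.int8)      # e.g. a quantized weight
float_weight = torch.zeros(4, 4, dtype=torch.float32)

try:
    nn.Parameter(int_weight)  # default requires_grad=True
except RuntimeError as err:
    # "Only Tensors of floating point and complex dtype can require gradients"
    print(err)

for tensor in (int_weight, float_weight):
    param = nn.Parameter(tensor, requires_grad=tensor.is_floating_point())
    print(param.dtype, param.requires_grad)  # torch.int8 False, torch.float32 True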