Refactor some core stuff (#36539)

* some config changes

* update

* current state

* update

* update

* updates and cleanup

* something that works

* fixup

* fixes

* nits

* nit

* nits and fix

* Update src/transformers/integrations/tensor_parallel.py

Co-authored-by: Lysandre Debut <hi@lysand.re>

* Update src/transformers/integrations/tensor_parallel.py

Co-authored-by: Lysandre Debut <hi@lysand.re>

* cleanup

* style

* safe import

* fix

* updates

* rename stuff and clean

* style

* small updates

* ups

* oups

* nit

* protect imports

* update tp

* rodfl

* arf

* turbo nit on init

* fix import error

* frumble gumbgle

* try to fix the import error

* should fix the non model test

* update keep in float32

* update

* fix

* nits

* fix subconfigs

* test was weird

* nit

* fix failing test

* fix instruct blip

* fixes

* style

* x.com

* fix overwrite

* ok last bit of failing test

---------

Co-authored-by: Lysandre Debut <hi@lysand.re>
Arthur 2025-03-11 09:26:28 +01:00 committed by GitHub
parent e9756cdbc7
commit 1c4b62b219
9 changed files with 704 additions and 116 deletions


@ -824,25 +824,27 @@ class PretrainedConfig(PushToHubMixin):
serializable_config_dict = {}
# only serialize values that differ from the default config
# Only serialize values that differ from the default config,
# except always keep the 'config' attribute.
for key, value in config_dict.items():
if (
isinstance(getattr(self, key, None), PretrainedConfig)
and key in class_config_dict
and isinstance(class_config_dict[key], dict)
or key in self.sub_configs
):
# For nested configs we need to clean the diff recursively
diff = recursive_diff_dict(value, class_config_dict[key], config_obj=getattr(self, key, None))
diff = recursive_diff_dict(value, default_config_dict, config_obj=getattr(self, key, None))
if "model_type" in value:
# Needs to be set even if it's not in the diff
diff["model_type"] = value["model_type"]
if len(diff) > 0:
serializable_config_dict[key] = diff
serializable_config_dict[key] = diff
elif (
key not in default_config_dict
or key == "transformers_version"
or key == "vocab_file"
or value != default_config_dict[key]
or (key in class_config_dict and value != class_config_dict[key])
or (key in default_config_dict and value != class_config_dict.get(key, value))
):
serializable_config_dict[key] = value
@ -867,6 +869,9 @@ class PretrainedConfig(PushToHubMixin):
if "base_model_pp_plan" in serializable_config_dict:
del serializable_config_dict["base_model_pp_plan"]
if "_name_or_path" in serializable_config_dict:
del serializable_config_dict["_name_or_path"]
return serializable_config_dict
def to_dict(self) -> Dict[str, Any]:
@ -1178,6 +1183,8 @@ def recursive_diff_dict(dict_a, dict_b, config_obj=None):
"""
Helper function to recursively take the diff between two nested dictionaries. The resulting diff only contains the
values from `dict_a` that are different from values in `dict_b`.
dict_b: the default config dictionary. We want to remove values that are also in this one
"""
diff = {}
default = config_obj.__class__().to_dict() if config_obj is not None else {}
@ -1185,9 +1192,8 @@ def recursive_diff_dict(dict_a, dict_b, config_obj=None):
obj_value = getattr(config_obj, str(key), None)
if isinstance(obj_value, PretrainedConfig) and key in dict_b and isinstance(dict_b[key], dict):
diff_value = recursive_diff_dict(value, dict_b[key], config_obj=obj_value)
if len(diff_value) > 0:
diff[key] = diff_value
elif key not in dict_b or value != dict_b[key] or key not in default or value != default[key]:
diff[key] = diff_value
elif key not in dict_b or (value != default[key]):
diff[key] = value
return diff
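
To make the diff-based serialization above concrete, here is a small, hedged sketch. It assumes a transformers install exposing `LlavaConfig` and `to_diff_dict`; the exact keys in the output may vary between versions:

```
from transformers import LlavaConfig

config = LlavaConfig()
config.text_config.hidden_size = 1234   # deviate from the text sub-config default

diff = config.to_diff_dict()
# Only values that differ from the defaults survive, plus "model_type" which is always kept,
# so the nested diff stays small instead of dumping every sub-config attribute.
print(diff["text_config"])
```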


@ -13,7 +13,7 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_torch_greater_or_equal
_import_structure = {
@ -128,6 +128,18 @@ else:
"convert_and_export_with_cache",
]
try:
if not is_torch_greater_or_equal("2.3"):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tensor_parallel"] = [
"shard_and_distribute_module",
"SUPPORTED_TP_STYLES",
"translate_to_torch_parallel_style",
]
if TYPE_CHECKING:
from .aqlm import replace_with_aqlm_linear
from .awq import (
@ -231,6 +243,18 @@ if TYPE_CHECKING:
else:
from .executorch import TorchExportableModuleWithStaticCache, convert_and_export_with_cache
try:
if not is_torch_greater_or_equal("2.3"):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tensor_parallel import (
SUPPORTED_TP_STYLES,
shard_and_distribute_module,
translate_to_torch_parallel_style,
)
else:
import sys


@ -0,0 +1,544 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import re
from functools import lru_cache, partial
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from ..utils import is_torch_greater_or_equal, logging
ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
logger = logging.get_logger(__name__)
# Cache this result as it's a C FFI call which can be pretty time-consuming
_torch_distributed_available = torch.distributed.is_available()
if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
from torch.distributed.tensor import DTensor, Placement, Replicate, Shard
def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
"""
Convert block count or proportions to block sizes.
This function accepts
- The number of blocks (int), in which case the block size is
total_size//blocks; or
- A list of block sizes (List[int]).
In the second case, if sum(blocks) < total_size, the ratios between
the block sizes will be preserved. For instance, if blocks is
[2, 1, 1] and total_size is 1024, the returned block sizes are
[512, 256, 256].
"""
if isinstance(blocks, list):
total_blocks = sum(blocks)
assert total_size % total_blocks == 0, f"Cannot split {total_size} in proportional blocks: {blocks}"
part_size = total_size // total_blocks
return [part_size * block for block in blocks]
else:
assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
single_size = total_size // blocks
return [single_size] * blocks
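
As a quick sanity check of the splitting rules described in the docstring (assuming the `_blocks_to_block_sizes` defined above is in scope):

```
# Proportional split: ratios [2, 1, 1] of 1024 -> [512, 256, 256]
assert _blocks_to_block_sizes(1024, [2, 1, 1]) == [512, 256, 256]
# Count split: 2 equal blocks of 1024 -> [512, 512]
assert _blocks_to_block_sizes(1024, 2) == [512, 512]
```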
def get_packed_weights(param, empty_param, device_mesh, rank, dim):
"""
When weights are packed (gate_up_proj), we need to make sure each shard gets its correct share.
So if you have: gate_proj ( 16, 5120, 8190)
and up_proj ( 16, 5120, 8190)
packed as gate_up_proj ( 16, 5120, 2 * 8190)
If you shard along the last dimension across TP_size (the tensor parallelism size), you need to interleave the gate and up values correctly.
Let's take TP_size = 4 as an example:
Packed tensor `gate_up_proj`
---------------------------------------------------------------
[ G0 G1 G2 G3 | G4 G5 G6 G7 | ... | U0 U1 U2 U3 | U4 U5 U6 U7 | ... ]
Gate Slice 0 Gate Slice 1 Up Slice 0 Up Slice 1
Explanation:
- The first half of the tensor (left of the center) holds the gate_proj values.
- The second half (right of the center) holds the up_proj values.
- For TP=4, we divide each half into 4 slices. In this example, we show two slices for brevity.
- Each shard receives one slice from the gate part and the corresponding slice from the up part.
For instance:
Shard 0 gets: [ Gate Slice 0, Up Slice 0 ] = [ G0, G1, G2, G3, U0, U1, U2, U3 ]
Shard 1 gets: [ Gate Slice 1, Up Slice 1 ] = [ G4, G5, G6, G7, U4, U5, U6, U7 ]
and so on.
This ensures that each shard receives an equal portion of both gate and up projections, maintaining consistency across tensor parallelism.
"""
slice_ = param
total_size = empty_param.shape[dim]
world_size = device_mesh.size()
block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=2)
tensors_slices = []
block_offset = 0
for block_size in block_sizes:
shard_block_size = block_size // world_size
start = rank * shard_block_size
stop = (rank + 1) * shard_block_size
tensors_slices += range(block_offset + start, block_offset + stop)
block_offset += block_size
if dim == 0:
tensor = slice_[tensors_slices, ...]
elif dim == 1 or dim == -2:
tensor = slice_[:, tensors_slices, ...]
elif dim == 2 or dim == -1:
tensor = slice_[..., tensors_slices]
else:
raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported")
return tensor
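
The interleaving can be reproduced on a tiny tensor without any distributed setup; the following is a standalone sketch of the slicing logic above (shapes, world size and rank are made up for illustration):

```
import torch

gate_up = torch.arange(16).reshape(1, 1, 16)   # last dim packs [gate | up], 8 values each
world_size, rank = 4, 1                        # pretend TP group of 4, looking at shard 1

block_sizes = [8, 8]                           # gate block, then up block
indices, offset = [], 0
for block_size in block_sizes:
    shard_size = block_size // world_size
    indices += range(offset + rank * shard_size, offset + (rank + 1) * shard_size)
    offset += block_size

print(gate_up[..., indices])                   # tensor([[[ 2,  3, 10, 11]]]) -> [G2, G3, U2, U3]
```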
def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
if dim == 0:
size_ = empty_param.shape[0]
param = param[rank * (size_ // device_mesh.size()) : (rank + 1) * (size_ // device_mesh.size()), ...]
elif dim == 1 or dim == -2:
size_ = empty_param.shape[-2]
param = param[..., rank * (size_ // device_mesh.size()) : (rank + 1) * (size_ // device_mesh.size()), :]
elif dim == 2 or dim == -1:
size_ = empty_param.shape[-1]
param = param[..., rank * (size_ // device_mesh.size()) : (rank + 1) * (size_ // device_mesh.size())]
else:
raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported")
return param
def distribute_module(
module: nn.Module,
device_mesh=None,
input_fn=None,
output_fn=None,
) -> nn.Module:
"""
Copy-pasted from torch's `distribute_module`, but we remove the communications (partitioning)
as well as the buffer registering, which is similarly inefficient.
"""
if len(module._forward_pre_hooks) == 0:
if input_fn is not None:
module.register_forward_pre_hook(lambda mod, inputs: input_fn(mod, inputs, device_mesh))
if output_fn is not None:
module.register_forward_hook(lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh))
return module
class TensorParallelLayer:
"""
General tensor parallel layer for transformers.
"""
use_dtensor = True
@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): ...
@staticmethod
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): ...
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
raise NotImplementedError
def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
if self.use_dtensor:
distribute_module(
module,
device_mesh,
partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts),
partial(self._prepare_output_fn, self.output_layouts, self.use_local_output),
)
# use_dtensor needs to be set to False for nn.Parameter when you want to view, chunk, or slice it
# (or do anything else a bit unconventional): those operations need local tensors
class GatherParallel(TensorParallelLayer):
"""
Simple class used to define the hooks to add to a layer when we just want to gather the outputs
"""
def __init__(
self,
*,
input_layouts: Optional[Placement] = None,
output_layouts: Optional[Placement] = None,
use_local_output: bool = True,
):
super().__init__()
self.input_layouts = (input_layouts or Replicate(),)
self.output_layouts = output_layouts
self.desired_input_layouts = (Replicate(),)
self.use_local_output = use_local_output
@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
if isinstance(inputs[0], DTensor):
inputs[0] = inputs[0].to_local()
return inputs
@staticmethod
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
torch.distributed.all_reduce(outputs[0], op=torch.distributed.ReduceOp.SUM, async_op=False)
return outputs
class IsolatedParallel(TensorParallelLayer):
"""
This class is used to isolate computation in a TP layer from the rest of the world.
Parameters need to be LOCAL, so not dtensors
"""
@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh=None):
# annotate module input placements/sharding with input_layouts
input_tensor = inputs[0]
if isinstance(input_tensor, DTensor):
input_tensor = input_tensor.to_local()
return input_tensor
@staticmethod
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh=None):
# TODO: figure out dynamo support for instance method and switch this to instance method
return outputs
def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
distribute_module(
module,
device_mesh,
partial(self._prepare_input_fn),
partial(self._prepare_output_fn),
)
class ColwiseParallel(TensorParallelLayer):
"""
General tensor parallel layer for transformers.
"""
def __init__(
self,
*,
input_layouts: Optional[Placement] = None,
output_layouts: Optional[Placement] = None,
use_local_output: bool = True,
use_dtensor=True,
):
super().__init__()
self.input_layouts = (input_layouts or Replicate(),)
self.output_layouts = (output_layouts or Shard(-1),)
self.desired_input_layouts = (Replicate(),)
self.use_local_output = use_local_output
self.use_dtensor = use_dtensor
@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
# TODO: figure out dynamo support for instance method and switch this to instance method
# annotate module input placements/sharding with input_layouts
input_tensor = inputs[0]
if not isinstance(input_tensor, DTensor):
input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
# transform the input layouts to the desired layouts of ColwiseParallel
if input_layouts != desired_input_layouts:
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
return input_tensor
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
# means Colwise as Linear is input * weight^T + bias, where
# weight would become Shard(1)
if param_type == "bias":
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1)
shard = [Shard(-1)]
else:
shard = [Shard(-2)]
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2)
parameter = parameter.to(param_casting_dtype)
if to_contiguous:
parameter = parameter.contiguous()
if self.use_dtensor:
parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
return nn.Parameter(parameter)
@staticmethod
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
# outputs is a shard on last dimension DTensor, i.e. Shard(-1)
if outputs.placements != output_layouts:
outputs = outputs.redistribute(placements=output_layouts, async_op=True)
# back to local tensor
return outputs.to_local() if use_local_output else outputs
class PackedColwiseParallel(ColwiseParallel):
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
# means Colwise as Linear is input * weight^T + bias, where
# weight would become Shard(1)
parameter = get_packed_weights(param, empty_param, device_mesh, rank, -2)
parameter = parameter.to(param_casting_dtype)
if to_contiguous:
parameter = parameter.contiguous()
if self.use_dtensor:
parameter = DTensor.from_local(parameter, device_mesh, [Shard(-2)], run_check=False)
return nn.Parameter(parameter)
class RowwiseParallel(TensorParallelLayer):
"""
Partition a compatible nn.Module in a row-wise fashion. Currently supports nn.Linear and nn.Embedding.
Users can compose it with ColwiseParallel to achieve the sharding of more complicated modules.
(i.e. MLP, Attention)
Keyword Args:
input_layouts (Placement, optional):
The DTensor layout of input tensor for the nn.Module, this is used to annotate the input tensor to
become a DTensor. If not specified, we assume the input tensor to be sharded on the last dimension.
output_layouts (Placement, optional):
The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module
with the user desired layout. If not specified, the output tensor is replicated.
use_local_output (bool, optional):
Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True.
Returns:
A :class:`ParallelStyle` object that represents Rowwise sharding of the nn.Module.
"""
def __init__(
self,
*,
input_layouts: Optional[Placement] = None,
output_layouts: Optional[Placement] = None,
use_local_output: bool = True,
use_dtensor=True,
):
super().__init__()
self.input_layouts = (input_layouts or Shard(-1),)
self.output_layouts = (output_layouts or Replicate(),)
self.use_local_output = use_local_output
self.use_dtensor = use_dtensor
@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
input_tensor = inputs[0]
if not isinstance(input_tensor, DTensor):
input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
if input_layouts != desired_input_layouts:
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
return input_tensor
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
# means Rowwise as nn.Linear is input * weight^T + bias, where
# weight would become Shard(0)
if param_type != "bias":
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1)
shard = [Shard(-1)]
else:
shard = [Replicate()]
parameter = param[:]
parameter = parameter.to(param_casting_dtype)
if to_contiguous:
parameter = parameter.contiguous()
if self.use_dtensor:
parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
return nn.Parameter(parameter)
@staticmethod
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
# Rowwise sharding produces partial output, depending on output layouts:
# 1. to replicate -> allreduce
# 2. to shard -> reduce_scatter
if outputs.placements != output_layouts:
outputs = outputs.redistribute(placements=output_layouts, async_op=True)
# back to local tensor if use_local_output is True
return outputs.to_local() if use_local_output else outputs
def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
module._distribute_module_applied = True
if self.use_dtensor:
if isinstance(module, nn.Linear):
# rowwise linear runtime sharding requires input tensor shard on last dim
self.desired_input_layouts: Tuple[Placement, ...] = (Shard(-1),)
elif isinstance(module, nn.Embedding):
# rowwise embedding runtime sharding requires input tensor replicated
self.desired_input_layouts = (Replicate(),)
elif isinstance(module, nn.Parameter):
# rowwise parameter runtime sharding requires input tensor sharded on the last dim
self.desired_input_layouts = (Shard(-1),)
else:
raise NotImplementedError("RowwiseParallel currently only support nn.Linear and nn.Embedding!")
distribute_module(
module,
device_mesh,
partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts),
partial(self._prepare_output_fn, self.output_layouts, self.use_local_output),
)
class PackedRowwiseParallel(RowwiseParallel):
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
# means Colwise as Linear is input * weight^T + bias, where
# weight would become Shard(1)
parameter = get_packed_weights(param, empty_param, device_mesh, rank, -1)
parameter = parameter.to(param_casting_dtype)
if to_contiguous:
parameter = parameter.contiguous()
if self.use_dtensor:
parameter = DTensor.from_local(parameter, device_mesh, [Shard(-1)], run_check=False)
return nn.Parameter(parameter)
SUPPORTED_TP_STYLES = {
"colwise",
"rowwise",
"colwise_rep",
"rowwise_rep",
"local_colwise",
"local_rowwise",
"local",
"gather",
"local_packed_rowwise",
}
@lru_cache
def translate_to_torch_parallel_style(style: str):
"""
In model configurations, we use a neutral type (string) to specify parallel
styles; here we translate them into torch.distributed tensor-parallel
types.
"""
if not isinstance(style, str):
raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")
if style == "colwise":
return ColwiseParallel()
elif style == "rowwise":
return RowwiseParallel()
elif style == "colwise_rep":
return ColwiseParallel(output_layouts=Replicate())
elif style == "rowwise_rep":
return RowwiseParallel(input_layouts=Replicate())
elif style == "local_colwise":
return ColwiseParallel(use_dtensor=False)
elif style == "local_rowwise":
return RowwiseParallel(use_dtensor=False)
elif style == "local":
return IsolatedParallel()
elif style == "gather":
return GatherParallel()
elif style == "local_packed_rowwise":
return PackedRowwiseParallel(use_dtensor=False)
else:
raise ValueError(f"Unsupported parallel style value: {style}")
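
For example, a string-based plan from a model config can be translated entry by entry. The plan below is hypothetical, and the snippet assumes the guarded torch>=2.5 / torch.distributed imports above succeeded:

```
tp_plan = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
    "layers.*.mlp.down_proj": "rowwise",
}
for pattern, style in tp_plan.items():
    layer = translate_to_torch_parallel_style(style)
    print(f"{pattern:30} -> {type(layer).__name__}")
# layers.*.self_attn.q_proj      -> ColwiseParallel
# layers.*.self_attn.o_proj      -> RowwiseParallel
# layers.*.mlp.down_proj         -> RowwiseParallel
```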
def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, current_module_plan, device_mesh):
"""
Add hooks to the module holding the layer. Meaning:
```
class MyModel(nn.Module):
def __init__(self):
self.layer = nn.Linear(10, 10)
```
has state_dict like:
```
{
"layer.weight": torch.Tensor,
"layer.bias": torch.Tensor
}
```
we add hooks to `MyModel` as well as `layer` to make sure that the tensors are correctly sharded and gathered.
"""
# 1. We add hooks to the layer being loaded:
if current_module_plan is not None:
tp_layer = translate_to_torch_parallel_style(current_module_plan)
tp_layer.prepare_module_tp(module, device_mesh)
# 2. We add hooks to the parent module if needed
if "." in layer_name:
parrent_layer_name = layer_name.rsplit(".", 1)[0]
generic_name = re.sub(r"\d+", "*", parrent_layer_name)
# The module itself needs hooks
if module_plan := tp_plan.get(generic_name, False):
tp_layer = translate_to_torch_parallel_style(module_plan)
module_to_tp_ = model.get_submodule(parrent_layer_name)
tp_layer.prepare_module_tp(module_to_tp_, device_mesh)
def shard_and_distribute_module(
model, param, empty_param, parameter_name, param_casting_dtype, is_contiguous, rank, device_mesh
):
r"""
Main use cases:
- column / row-wise parallelism: you just shard all the weights of the layer (weight and bias)
- packed layers: you slice the weights, then shard like above
- custom operation:
- you want to add an all-gather at the end of a local layer.
- you want to have a layer that is isolated from the rest of the world (because torch.DTensor does not work well with `.view` for instance)
"""
param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
tp_plan = model._tp_plan
module_to_tp = model.get_submodule(param_name)
current_module_plan = None
generic_param_name = re.sub(r"\d+", "*", parameter_name)
if generic_param_name in tp_plan:
current_module_plan = tp_plan[generic_param_name]
elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan:
current_module_plan = tp_plan[generic_param_name.rsplit(".", 1)[0]]
# Add hooks to the module if not done yet
# add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh)
if not getattr(module_to_tp, "_is_hooked", False):
add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh)
module_to_tp._is_hooked = True
if current_module_plan is not None:
tp_layer = translate_to_torch_parallel_style(current_module_plan)
param = tp_layer.partition_tensor(
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
)
else:
param = param[:]
if is_contiguous:
param = param.contiguous()
# SUPER IMPORTANT we have to use setattr
# otherwise loading is crazy slow
if not isinstance(param, torch.nn.Parameter):
param = torch.nn.Parameter(param)
setattr(module_to_tp, param_type, param)
# module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True)
return param
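
The plan lookup inside `shard_and_distribute_module` boils down to replacing layer indices with `*` before matching against the plan keys; a runnable sketch (the plan and parameter name are hypothetical):

```
import re

tp_plan = {"layers.*.self_attn.q_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise"}
parameter_name = "layers.7.self_attn.q_proj.weight"

generic_param_name = re.sub(r"\d+", "*", parameter_name)   # "layers.*.self_attn.q_proj.weight"
module_name = generic_param_name.rsplit(".", 1)[0]         # drop the trailing ".weight"
current_module_plan = tp_plan.get(generic_param_name) or tp_plan.get(module_name)
print(current_module_plan)                                 # "colwise"
```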


@ -54,17 +54,20 @@ from .integrations.deepspeed import _load_state_dict_into_zero3_model
from .integrations.flash_attention import flash_attention_forward
from .integrations.flex_attention import flex_attention_forward
from .integrations.sdpa_attention import sdpa_attention_forward
from .integrations.tensor_parallel import (
SUPPORTED_TP_STYLES,
shard_and_distribute_module,
translate_to_torch_parallel_style,
)
from .loss.loss_utils import LOSS_MAPPING
from .pytorch_utils import ( # noqa: F401
Conv1D,
apply_chunking_to_forward,
distribute_module,
find_pruneable_heads_and_indices,
id_tensor_storage,
prune_conv1d_layer,
prune_layer,
prune_linear_layer,
translate_to_torch_parallel_style,
)
from .quantizers import AutoHfQuantizer, HfQuantizer
from .quantizers.quantizers_utils import get_module_from_name
@ -151,6 +154,7 @@ logger = logging.get_logger(__name__)
_init_weights = True
_is_quantized = False
_is_ds_init_called = False
_torch_distributed_available = torch.distributed.is_available()
def is_fsdp_enabled():
@ -181,8 +185,6 @@ else:
if is_peft_available():
from .utils import find_adapter_config_file
if is_torch_greater_or_equal("2.5"):
from torch.distributed.tensor import DTensor, Shard
SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel")
@ -756,6 +758,40 @@ def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
setattr(submodule, param_name, new_val)
def fix_tensor_type_and_device(
model, param_name, param, dtype=None, keep_in_fp32_modules=None
) -> Union[str, torch.dtype]:
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
old_param = model
if "." in param_name:
pre, _ = param_name.rsplit(".", 1)
old_param = model.get_submodule(pre)
if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
old_param = None
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
# in int/uint/bool and not cast them.
param_casting_dtype = None
is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn
if param.dtype.is_floating_point and not is_param_float8_e4m3fn:
if keep_in_fp32_modules is not None and keep_in_fp32_modules.search(param_name):
param_casting_dtype = torch.float32
elif dtype is not None:
param_casting_dtype = dtype
elif old_param is not None:
param_casting_dtype = old_param.dtype
return old_param is not None and old_param.is_contiguous(), param_casting_dtype
else:
return False, None
return
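
The dtype decision above hinges on a compiled regex over the parameter name; here is a minimal, hedged sketch of that part of the logic (the pattern and parameter names are hypothetical, and the real code builds the regex from `_keep_in_fp32_modules` elsewhere):

```
import re
import torch

keep_in_fp32_modules = re.compile("|".join(re.escape(m) for m in ["wo", "query_tokens"]))
requested_dtype = torch.bfloat16

def casting_dtype(param_name, param_dtype):
    if not param_dtype.is_floating_point:
        return None                      # ints/bools/buffers are never cast
    if keep_in_fp32_modules.search(param_name):
        return torch.float32             # forced-fp32 modules win over the requested dtype
    return requested_dtype

print(casting_dtype("qformer.query_tokens", torch.float16))  # torch.float32
print(casting_dtype("lm_head.weight", torch.float16))        # torch.bfloat16
print(casting_dtype("model.position_ids", torch.int64))      # None (left uncast)
```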
@torch.no_grad()
def _load_state_dict_into_meta_model(
model: torch.nn.Module,
@ -787,18 +823,12 @@ def _load_state_dict_into_meta_model(
It also initializes tensor parallelism for each module if needed.
"""
tensor_device = None
tensor_device = "cpu"
if device_map is not None and device_map.get("", None) is not None:
tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
if device_map is not None:
device_map_regex = "|".join(sorted(device_map.keys(), reverse=True))
# we need this later to initialize tensor parallelism
if device_mesh is not None:
full_tp_plan = model.config.base_model_tp_plan
for submodule in model.modules():
full_tp_plan.update(getattr(submodule, "_tp_plan", {}))
file_pointer = None
bin_state_dict = None
if shard_file.endswith(".safetensors"):
@ -818,8 +848,6 @@ def _load_state_dict_into_meta_model(
is_quantized = hf_quantizer is not None
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
for serialized_param_name, empty_param in state_dict.items():
# serialized_param_name is the raw, serialized name
# fixed_param_name is the model's equivalent
@ -829,87 +857,37 @@ def _load_state_dict_into_meta_model(
continue
# we need to use serialized_param_name as file pointer is untouched
param = (
file_pointer.get_slice(serialized_param_name)
if shard_file.endswith(".safetensors")
else bin_state_dict[serialized_param_name]
if shard_file.endswith(".safetensors"):
param = file_pointer.get_slice(serialized_param_name)
elif shard_file.endswith(".gguf"):
param = empty_param # For gguf the dict is actually not empty!
else:
param = bin_state_dict[serialized_param_name]
to_contiguous, param_casting_dtype = fix_tensor_type_and_device(
model,
param_name=fixed_param_name,
param=empty_param,
dtype=dtype,
keep_in_fp32_modules=keep_in_fp32_modules,
)
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
old_param = model
splits = fixed_param_name.split(".")
for split in splits:
# We shouldn't hit the default value unless for quant methods like hqq that modifies expected_keys.
old_param = getattr(old_param, split, None)
if old_param is None:
break
if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
old_param = None
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
# in int/uint/bool and not cast them.
param_casting_dtype = None
is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
if keep_in_fp32_modules is not None and keep_in_fp32_modules.search(fixed_param_name):
param_casting_dtype = torch.float32
elif dtype is not None:
param_casting_dtype = dtype
elif old_param is not None:
param_casting_dtype = old_param.dtype
if device_mesh is not None: # In this case, the param is already on the correct device!
module_to_tp, param_type = get_module_from_name(model, fixed_param_name)
current_module_plan = None
full_tp_plan_ = "|".join(full_tp_plan.keys()).replace("*", "[0-9]+")
if plan := re.search(full_tp_plan_, fixed_param_name):
match = re.sub("[0-9]+", "*", plan[0])
current_module_plan = full_tp_plan[match]
if current_module_plan is not None:
tp_layer = translate_to_torch_parallel_style(current_module_plan)
rank = tensor_device
row, col = empty_param.shape
if "rowwise" == current_module_plan:
param = param[:, rank * (col // device_mesh.size()) : (rank + 1) * (col // device_mesh.size())]
shard = Shard(1)
tp_layer.desired_input_layouts = (Shard(-1),)
elif "colwise" == current_module_plan:
param = param[rank * (row // device_mesh.size()) : (rank + 1) * (row // device_mesh.size()), :]
shard = Shard(0)
else:
param = param[rank * (row // device_mesh.size()) : (rank + 1) * (row // device_mesh.size()), :]
shard = Shard(0)
if param_casting_dtype is not None:
param = param.to(param_casting_dtype)
if old_param.is_contiguous():
param = param.contiguous()
local_parameter = DTensor.from_local(
param,
device_mesh=device_mesh,
placements=[shard] * device_mesh.ndim,
)
if isinstance(module_to_tp.weight, nn.Parameter):
local_parameter = torch.nn.Parameter(local_parameter)
module_to_tp.weight = local_parameter
input_fn = partial(tp_layer._prepare_input_fn, tp_layer.input_layouts, tp_layer.desired_input_layouts)
output_fn = partial(tp_layer._prepare_output_fn, tp_layer.output_layouts, tp_layer.use_local_output)
distribute_module(module_to_tp, device_mesh, None, input_fn, output_fn)
else:
param = param[:]
if old_param is not None and old_param.is_contiguous():
param = param.contiguous()
module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True)
shard_and_distribute_module(
model,
param,
empty_param,
fixed_param_name,
param_casting_dtype,
to_contiguous,
tensor_device, # the rank
device_mesh,
)
else:
param = param[:]
if param_casting_dtype is not None:
param = param.to(param_casting_dtype)
if old_param is not None and old_param.is_contiguous():
if to_contiguous:
param = param.contiguous()
if device_map is None:
@ -966,6 +944,7 @@ def _load_state_dict_into_meta_model(
val_kwargs["requires_grad"] = False
value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__)
setattr(module, param_type, value)
if file_pointer is not None:
file_pointer.__exit__(None, None, None)
@ -1409,7 +1388,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# A tensor parallel plan to be applied to the model when TP is enabled. For
# top-level models, this attribute is currently defined in respective model
# code. For base models, this attribute comes from
# `config.base_model_tp_plan` during `post_init`.
# `config.base_model_tp_plan` during `__init__`.
# It should identify the layers exactly: if you want to TP model.language_model.layers.fc1
# by passing `tp_plan` to the init, it should be {"model.language_model.layers.fc1":"colwise"}
# for example.
_tp_plan = None
# A pipeline parallel plan specifying the layers which may not be present
@ -1475,6 +1457,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# when a different component (e.g. language_model) is used.
self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules)
self._no_split_modules = self._no_split_modules or []
def post_init(self):
"""
A method executed at the end of each Transformer model initialization, to execute code that needs the model's
@ -1482,11 +1466,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"""
self.init_weights()
self._backward_compatibility_gradient_checkpointing()
# If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
if self.base_model is self:
self._tp_plan = self.config.base_model_tp_plan
self._pp_plan = self.config.base_model_pp_plan
self._tp_plan = self._tp_plan or self.config.base_model_tp_plan or {}
for name, module in self.named_children():
if plan := getattr(module, "_tp_plan", None):
self._tp_plan.update({f"{name}.{k}": v for k, v in plan.items()})
if self._tp_plan is not None and is_torch_greater_or_equal("2.3"):
for _, v in self._tp_plan.items():
if v not in SUPPORTED_TP_STYLES:
raise ValueError(
f"Unsupported tensor parallel style {v}. Supported styles are {SUPPORTED_TP_STYLES}"
)
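
The child-plan merging performed in `post_init` above simply prefixes each sub-module's plan with the attribute name it is registered under; a hedged, runnable illustration (module and layer names are hypothetical):

```
# A child module exposing its own _tp_plan under the attribute name "language_model"
child_plan = {"layers.*.mlp.up_proj": "colwise"}
name = "language_model"

parent_plan = {}
parent_plan.update({f"{name}.{k}": v for k, v in child_plan.items()})
assert parent_plan == {"language_model.layers.*.mlp.up_proj": "colwise"}
```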
def dequantize(self):
"""
Potentially dequantize the model in case it has been quantized by a quantization method that support
@ -4315,7 +4311,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
model = cls(config, *model_args, **model_kwargs)
if device_mesh is not None and not model.supports_tp_plan:
raise NotImplementedError("This model does not have a tensor parallel plan.")
if config.base_model_tp_plan is None and config.get_text_config().base_model_tp_plan is None:
raise NotImplementedError("This model does not have a tensor parallel plan.")
# make sure we use the model's config since the __init__ call might have copied it
config = model.config
@ -4453,7 +4450,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
model,
state_dict,
loaded_state_dict_keys, # XXX: rename?
resolved_archive_file,
resolved_archive_file or gguf_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=ignore_mismatched_sizes,
sharded_metadata=sharded_metadata,
@ -4565,7 +4562,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
@staticmethod
def _fix_state_dict_key_on_load(key) -> Tuple[str, bool]:
"""Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight."""
# Rename LayerNorm beta & gamma params for some early models ported from Tensorflow (e.g. Bert)
# This rename is logged.
if key.endswith("LayerNorm.beta"):
@ -4590,6 +4586,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return key, False
def rename_key(self, key):
"""
When we load a LlamaModel from a checkpoint made using LlamaForCausalLM, the keys have an extra
prefix, which can be accessed in the `LlamaModel` via the `self.base_model_prefix` attribute.
But what if there is an extra layer on top of it, e.g. you load a MistralModel from a LlavaForConditionalGeneration?
In that case, what you actually want is to cut whatever is left of the key.
"""
new_key = key
if len(self.base_model_prefix) > 0:
if not hasattr(self, self.base_model_prefix) and key.startswith(self.base_model_prefix):
@ -4940,7 +4943,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
device_mesh=device_mesh,
resolved_archive_file=resolved_archive_file,
shard_file=resolved_archive_file,
weights_only=weights_only,
)
else:
@ -5019,7 +5022,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
start_prefix,
prefix,
expected_keys,
device_map=device_map,
offload_folder=offload_folder,
@ -5898,10 +5901,21 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
for param_name, device in accelerator_device_map.items():
try:
param = model.get_parameter(param_name)
param = getattr(model, param_name)
except AttributeError:
param = model.get_buffer(param_name)
parameter_count[device] += int(math.prod(param.shape) * allocation_factor)
if "." in param_name:
param_name, param_type = param_name.rsplit(".", 1)
param = getattr(model.get_submodule(param_name), param_type)
else:
param = model.get_buffer(param_name)
param_size = int(math.prod(param.shape) * allocation_factor)
if _torch_distributed_available and torch.distributed.is_initialized():
generic_name = re.sub(r"\d+", "*", param_name)
param_size //= torch.distributed.get_world_size() if not model._tp_plan.get(generic_name, False) else 1
parameter_count[device] += param_size
dtype = dtype if dtype is not None else torch.float32


@ -419,7 +419,7 @@ class Blip2PreTrainedModel(PreTrainedModel):
"OPTDecoderLayer",
]
_skip_keys_device_placement = "past_key_values"
_keep_in_fp32_modules = ["wo"]
_keep_in_fp32_modules = ["query_tokens"]
def _init_weights(self, module):
"""Initialize the weights"""
@ -1799,7 +1799,7 @@ class Blip2Model(Blip2PreTrainedModel):
)
class Blip2TextModelWithProjection(Blip2PreTrainedModel):
supports_gradient_checkpointing = False
_keep_in_fp32_modules = []
_keep_in_fp32_modules = ["query_tokens"]
def __init__(self, config: Blip2Config):
super().__init__(config)
@ -1898,7 +1898,7 @@ class Blip2TextModelWithProjection(Blip2PreTrainedModel):
)
class Blip2VisionModelWithProjection(Blip2PreTrainedModel):
main_input_name = "pixel_values"
_keep_in_fp32_modules = []
_keep_in_fp32_modules = ["query_tokens"]
def __init__(self, config: Blip2Config):
super().__init__(config)
@ -2371,7 +2371,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
)
class Blip2ForImageTextRetrieval(Blip2PreTrainedModel):
main_input_name = "pixel_values"
_keep_in_fp32_modules = []
_keep_in_fp32_modules = ["query_tokens"]
def __init__(self, config: Blip2Config):
super().__init__(config)


@ -322,7 +322,6 @@ class InstructBlipPreTrainedModel(PreTrainedModel):
"InstructBlipQFormerMultiHeadAttention",
"InstructBlipQFormerSelfOutput",
]
_keep_in_fp32_modules = []
# Copied from transformers.models.blip_2.modeling_blip_2.Blip2PreTrainedModel._init_weights with Blip2->InstructBlip
def _init_weights(self, module):
@ -1293,6 +1292,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
_supports_cache_class = True
_supports_static_cache = True
_supports_quantized_cache = False  # not all LM backbones support it (e.g. T5)
_keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8
def __init__(self, config: InstructBlipConfig):
super().__init__(config)


@ -323,7 +323,6 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel):
"InstructBlipVideoQFormerMultiHeadAttention",
"InstructBlipVideoQFormerSelfOutput",
]
_keep_in_fp32_modules = []
def _init_weights(self, module):
"""Initialize the weights"""
@ -1287,6 +1286,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
_supports_cache_class = True
_supports_static_cache = True
_supports_quantized_cache = False  # not all LM backbones support it (e.g. T5)
_keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8
def __init__(self, config: InstructBlipVideoConfig):
super().__init__(config)


@ -58,13 +58,13 @@ class LlavaConfigTest(unittest.TestCase):
"""
Simple test for reloading arbitrarily composed subconfigs
"""
default_values = LlavaConfig().to_dict()
default_values["vision_config"]["model_type"] = "qwen2_vl"
default_values = LlavaConfig().to_diff_dict()
default_values["vision_config"]["model_type"] = "pixtral"
default_values["text_config"]["model_type"] = "opt"
self.maxDiff = None
with tempfile.TemporaryDirectory() as tmp_dir:
config = LlavaConfig(**default_values)
config.save_pretrained(tmp_dir)
reloaded = LlavaConfig.from_pretrained(tmp_dir)
assert config.to_dict() == reloaded.to_dict()
self.assertDictEqual(config.to_dict(), reloaded.to_dict())