Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Refactor some core stuff (#36539)
* some config changes * update * current state * update * update * updates and cleanup * something that works * fixup * fixes * nits * nit * nits and fix * Update src/transformers/integrations/tensor_parallel.py Co-authored-by: Lysandre Debut <hi@lysand.re> * Update src/transformers/integrations/tensor_parallel.py Co-authored-by: Lysandre Debut <hi@lysand.re> * cleanup * style * safe import * fix * updates * rename stuff an clean * style * small updates * ups * oups * nit * protect imports * update tp * rodfl * arf * turbo nit on init * fix import error * frumble gumbgle * try to fix the import error * should fix the non model test * update keep in float32 * update * fix * nits * fix subvconfigs * test was weird * nit * fix failing test * fix instruct blip * fixes * style * x.com * fix overwrite * ok last bit of failing test --------- Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
parent e9756cdbc7
commit 1c4b62b219
@@ -824,25 +824,27 @@ class PretrainedConfig(PushToHubMixin):
        serializable_config_dict = {}

        # only serialize values that differ from the default config
        # Only serialize values that differ from the default config,
        # except always keep the 'config' attribute.
        for key, value in config_dict.items():
            if (
                isinstance(getattr(self, key, None), PretrainedConfig)
                and key in class_config_dict
                and isinstance(class_config_dict[key], dict)
                or key in self.sub_configs
            ):
                # For nested configs we need to clean the diff recursively
                diff = recursive_diff_dict(value, class_config_dict[key], config_obj=getattr(self, key, None))
                diff = recursive_diff_dict(value, default_config_dict, config_obj=getattr(self, key, None))
                if "model_type" in value:
                    # Needs to be set even if it's not in the diff
                    diff["model_type"] = value["model_type"]
                if len(diff) > 0:
                    serializable_config_dict[key] = diff
                serializable_config_dict[key] = diff
            elif (
                key not in default_config_dict
                or key == "transformers_version"
                or key == "vocab_file"
                or value != default_config_dict[key]
                or (key in class_config_dict and value != class_config_dict[key])
                or (key in default_config_dict and value != class_config_dict.get(key, value))
            ):
                serializable_config_dict[key] = value
@@ -867,6 +869,9 @@ class PretrainedConfig(PushToHubMixin):
        if "base_model_pp_plan" in serializable_config_dict:
            del serializable_config_dict["base_model_pp_plan"]

        if "_name_or_path" in serializable_config_dict:
            del serializable_config_dict["_name_or_path"]

        return serializable_config_dict

    def to_dict(self) -> Dict[str, Any]:
@@ -1178,6 +1183,8 @@ def recursive_diff_dict(dict_a, dict_b, config_obj=None):
    """
    Helper function to recursively take the diff between two nested dictionaries. The resulting diff only contains the
    values from `dict_a` that are different from values in `dict_b`.

    dict_b : the default config dictionnary. We want to remove values that are in this one
    """
    diff = {}
    default = config_obj.__class__().to_dict() if config_obj is not None else {}
@@ -1185,9 +1192,8 @@ def recursive_diff_dict(dict_a, dict_b, config_obj=None):
        obj_value = getattr(config_obj, str(key), None)
        if isinstance(obj_value, PretrainedConfig) and key in dict_b and isinstance(dict_b[key], dict):
            diff_value = recursive_diff_dict(value, dict_b[key], config_obj=obj_value)
            if len(diff_value) > 0:
                diff[key] = diff_value
        elif key not in dict_b or value != dict_b[key] or key not in default or value != default[key]:
            diff[key] = diff_value
        elif key not in dict_b or (value != default[key]):
            diff[key] = value
    return diff
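For context, `to_diff_dict` now recurses for every key listed in `self.sub_configs` and diffs the sub-config against the default config dict. A minimal sketch of the recursive-diff idea (illustrative only, not code from this commit; `nested_diff` and the sample config values are hypothetical):

```
def nested_diff(a: dict, b: dict) -> dict:
    """Keep only the keys of `a` whose values differ from `b`, recursing into nested dicts."""
    diff = {}
    for key, value in a.items():
        if isinstance(value, dict) and isinstance(b.get(key), dict):
            child = nested_diff(value, b[key])
            if child:
                diff[key] = child
        elif key not in b or value != b[key]:
            diff[key] = value
    return diff


full = {"model_type": "llava", "text_config": {"model_type": "llama", "hidden_size": 1024, "vocab_size": 32000}}
defaults = {"model_type": "llava", "text_config": {"model_type": "llama", "hidden_size": 4096, "vocab_size": 32000}}
print(nested_diff(full, defaults))  # {'text_config': {'hidden_size': 1024}}
```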
@@ -13,7 +13,7 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_torch_greater_or_equal


_import_structure = {
@@ -128,6 +128,18 @@ else:
        "convert_and_export_with_cache",
    ]

try:
    if not is_torch_greater_or_equal("2.3"):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["tensor_parallel"] = [
        "shard_and_distribute_module",
        "SUPPORTED_TP_STYLES",
        "translate_to_torch_parallel_style",
    ]

if TYPE_CHECKING:
    from .aqlm import replace_with_aqlm_linear
    from .awq import (
@@ -231,6 +243,18 @@ if TYPE_CHECKING:
    else:
        from .executorch import TorchExportableModuleWithStaticCache, convert_and_export_with_cache

    try:
        if not is_torch_greater_or_equal("2.3"):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .tensor_parallel import (
            SUPPORTED_TP_STYLES,
            shard_and_distribute_module,
            translate_to_torch_parallel_style,
        )

else:
    import sys
src/transformers/integrations/tensor_parallel.py (new file, 544 lines)
@@ -0,0 +1,544 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import re
from functools import lru_cache, partial
from typing import List, Optional, Tuple, Union

import torch
from torch import nn

from ..utils import is_torch_greater_or_equal, logging


ALL_LAYERNORM_LAYERS = [nn.LayerNorm]

logger = logging.get_logger(__name__)

# Cache this result has it's a C FFI call which can be pretty time-consuming
_torch_distributed_available = torch.distributed.is_available()


if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
    from torch.distributed.tensor import DTensor, Placement, Replicate, Shard


def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
    """
    Convert block count or proportions to block sizes.

    This function accepts

    - The number of blocks (int), in which case the block size is
      total_size//blocks; or
    - A list of block sizes (List[int]).

    In the second case, if sum(blocks) < total_size, the ratios between
    the block sizes will be preserved. For instance, if blocks is
    [2, 1, 1] and total_size is 1024, the returned block sizes are
    [512, 256, 256].
    """
    if isinstance(blocks, list):
        total_blocks = sum(blocks)
        assert total_size % total_blocks == 0, f"Cannot split {total_size} in proportional blocks: {blocks}"
        part_size = total_size // total_blocks
        return [part_size * block for block in blocks]
    else:
        assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
        single_size = total_size // blocks
        return [single_size] * blocks


def get_packed_weights(param, empty_param, device_mesh, rank, dim):
    """
    When weights are packed (gate_up_proj), we need to make sure each shard gets its correct share.
    So if you have: gate_proj ( 16, 5120, 8190)
    and up_proj ( 16, 5120, 8190)
    packed as gate_up_proj ( 16, 5120, 2 * 8190)
    And you shard along the last dimension, you need to interleave the gate and up values:

    Now, if we shard along the last dimension across TP_size (Tensor Parallelism size), we must interleave the values from gate and up projections correctly.

    Let's take TP_size = 4 for an example:

    Packed tensor `gate_up_proj`
    ---------------------------------------------------------------
    [ G0 G1 G2 G3 | G4 G5 G6 G7 | ... | U0 U1 U2 U3 | U4 U5 U6 U7 | ... ]
     ↑─────────────↑ ↑─────────────↑        ↑─────────────↑ ↑─────────────↑
       Gate Slice 0    Gate Slice 1           Up Slice 0      Up Slice 1

    Explanation:
    - The first half of the tensor (left of the center) holds the gate_proj values.
    - The second half (right of the center) holds the up_proj values.
    - For TP=4, we divide each half into 4 slices. In this example, we show two slices for brevity.
    - Each shard receives one slice from the gate part and the corresponding slice from the up part.

    For instance:
    • Shard 0 gets: [ Gate Slice 0, Up Slice 0 ] = [ G0, G1, G2, G3, U0, U1, U2, U3 ]
    • Shard 1 gets: [ Gate Slice 1, Up Slice 1 ] = [ G4, G5, G6, G7, U4, U5, U6, U7 ]
    • … and so on.

    This ensures that each shard receives an equal portion of both gate and up projections, maintaining consistency across tensor parallelism.
    """
    slice_ = param
    total_size = empty_param.shape[dim]
    world_size = device_mesh.size()
    block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=2)

    tensors_slices = []
    block_offset = 0
    for block_size in block_sizes:
        shard_block_size = block_size // world_size
        start = rank * shard_block_size
        stop = (rank + 1) * shard_block_size
        tensors_slices += range(block_offset + start, block_offset + stop)
        block_offset += block_size

    if dim == 0:
        tensor = slice_[tensors_slices, ...]
    elif dim == 1 or dim == -2:
        tensor = slice_[:, tensors_slices, ...]
    elif dim == 2 or dim == -1:
        tensor = slice_[..., tensors_slices]
    else:
        raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported")
    return tensor
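A small sketch of the interleaving described in the docstring above, on plain CPU tensors and a hypothetical TP size of 4 (illustrative only, not part of the new file; `packed_shard` mirrors the index arithmetic of `get_packed_weights` with `blocks=2`):

```
import torch

tp_size = 4                                   # hypothetical tensor-parallel world size
gate = torch.arange(8).repeat(3, 1)           # pretend gate_proj, shape (3, 8)
up = torch.arange(8, 16).repeat(3, 1)         # pretend up_proj,   shape (3, 8)
packed = torch.cat([gate, up], dim=-1)        # gate_up_proj, shape (3, 16)


def packed_shard(packed, rank, tp_size):
    # Each rank takes its slice of the gate half and the matching slice of the up half.
    total = packed.shape[-1]
    block = total // 2                        # size of the gate (and of the up) half
    shard = block // tp_size                  # columns each rank gets from each half
    cols = list(range(rank * shard, (rank + 1) * shard))                      # gate slice
    cols += list(range(block + rank * shard, block + (rank + 1) * shard))     # up slice
    return packed[..., cols]


print(packed_shard(packed, rank=0, tp_size=tp_size)[0])  # tensor([0, 1, 8, 9])
print(packed_shard(packed, rank=1, tp_size=tp_size)[0])  # tensor([ 2,  3, 10, 11])
```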
def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
    if dim == 0:
        size_ = empty_param.shape[0]
        param = param[rank * (size_ // device_mesh.size()) : (rank + 1) * (size_ // device_mesh.size()), ...]
    elif dim == 1 or dim == -2:
        size_ = empty_param.shape[-2]
        param = param[..., rank * (size_ // device_mesh.size()) : (rank + 1) * (size_ // device_mesh.size()), :]
    elif dim == 2 or dim == -1:
        size_ = empty_param.shape[-1]
        param = param[..., rank * (size_ // device_mesh.size()) : (rank + 1) * (size_ // device_mesh.size())]
    else:
        raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported")
    return param


def distribute_module(
    module: nn.Module,
    device_mesh=None,
    input_fn=None,
    output_fn=None,
) -> nn.Module:
    """
    Copy pasted from torch's function but we remove the communications (partitionning)
    as well as buffer registering that is similarly not efficient.
    """
    if len(module._forward_pre_hooks) == 0:
        if input_fn is not None:
            module.register_forward_pre_hook(lambda mod, inputs: input_fn(mod, inputs, device_mesh))
        if output_fn is not None:
            module.register_forward_hook(lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh))
    return module


class TensorParallelLayer:
    """
    General tensor parallel layer for transformers.
    """

    use_dtensor = True

    @staticmethod
    def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): ...

    @staticmethod
    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): ...

    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
        raise NotImplementedError

    def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
        if self.use_dtensor:
            distribute_module(
                module,
                device_mesh,
                partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts),
                partial(self._prepare_output_fn, self.output_layouts, self.use_local_output),
            )


# use_dtensor needs to be set to false for nn.Parameter when you want to view, chunk, slice
# you name it. Whatever you want to do that is a bit unconventional, you need local tensors
class GatherParallel(TensorParallelLayer):
    """
    Simple class used to define the hooks to add to a layer when we just want to gather the outputs
    """

    def __init__(
        self,
        *,
        input_layouts: Optional[Placement] = None,
        output_layouts: Optional[Placement] = None,
        use_local_output: bool = True,
    ):
        super().__init__()
        self.input_layouts = (input_layouts or Replicate(),)
        self.output_layouts = output_layouts
        self.desired_input_layouts = (Replicate(),)
        self.use_local_output = use_local_output

    @staticmethod
    def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
        if isinstance(inputs[0], DTensor):
            inputs[0] = inputs[0].to_local()
        return inputs

    @staticmethod
    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
        torch.distributed.all_reduce(outputs[0], op=torch.distributed.ReduceOp.SUM, async_op=False)
        return outputs


class IsolatedParallel(TensorParallelLayer):
    """
    This class is used to isolate computation in a TP layer from the rest of the world.
    Parameters need to be LOCAL, so not dtensors
    """

    @staticmethod
    def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh=None):
        # annotate module input placements/sharding with input_layouts
        input_tensor = inputs[0]
        if isinstance(input_tensor, DTensor):
            input_tensor = input_tensor.to_local()
        return input_tensor

    @staticmethod
    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh=None):
        # TODO: figure out dynamo support for instance method and switch this to instance method
        return outputs

    def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
        distribute_module(
            module,
            device_mesh,
            partial(self._prepare_input_fn),
            partial(self._prepare_output_fn),
        )


class ColwiseParallel(TensorParallelLayer):
    """
    General tensor parallel layer for transformers.
    """

    def __init__(
        self,
        *,
        input_layouts: Optional[Placement] = None,
        output_layouts: Optional[Placement] = None,
        use_local_output: bool = True,
        use_dtensor=True,
    ):
        super().__init__()
        self.input_layouts = (input_layouts or Replicate(),)
        self.output_layouts = (output_layouts or Shard(-1),)
        self.desired_input_layouts = (Replicate(),)
        self.use_local_output = use_local_output
        self.use_dtensor = use_dtensor

    @staticmethod
    def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
        # TODO: figure out dynamo support for instance method and switch this to instance method
        # annotate module input placements/sharding with input_layouts
        input_tensor = inputs[0]
        if not isinstance(input_tensor, DTensor):
            input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)

        # transform the input layouts to the desired layouts of ColwiseParallel
        if input_layouts != desired_input_layouts:
            input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
        return input_tensor

    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
        # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
        # means Colwise as Linear is input * weight^T + bias, where
        # weight would become Shard(1)
        if param_type == "bias":
            parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1)
            shard = [Shard(-1)]
        else:
            shard = [Shard(-2)]
            parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2)

        parameter = parameter.to(param_casting_dtype)
        if to_contiguous:
            parameter = parameter.contiguous()
        if self.use_dtensor:
            parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
        return nn.Parameter(parameter)

    @staticmethod
    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
        # outputs is a shard on last dimension DTensor, i.e. Shard(-1)
        if outputs.placements != output_layouts:
            outputs = outputs.redistribute(placements=output_layouts, async_op=True)
        # back to local tensor
        return outputs.to_local() if use_local_output else outputs


class PackedColwiseParallel(ColwiseParallel):
    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
        # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
        # means Colwise as Linear is input * weight^T + bias, where
        # weight would become Shard(1)
        parameter = get_packed_weights(param, empty_param, device_mesh, rank, -2)
        parameter = parameter.to(param_casting_dtype)
        if to_contiguous:
            parameter = parameter.contiguous()
        if self.use_dtensor:
            parameter = DTensor.from_local(parameter, device_mesh, [Shard(-2)], run_check=False)
        return nn.Parameter(parameter)


class RowwiseParallel(TensorParallelLayer):
    """
    Partition a compatible nn.Module in a row-wise fashion. Currently supports nn.Linear and nn.Embedding.
    Users can compose it with ColwiseParallel to achieve the sharding of more complicated modules.
    (i.e. MLP, Attention)

    Keyword Args:
        input_layouts (Placement, optional):
            The DTensor layout of input tensor for the nn.Module, this is used to annotate the input tensor to
            become a DTensor. If not specified, we assume the input tensor to be sharded on the last dimension.
        output_layouts (Placement, optional):
            The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module
            with the user desired layout. If not specified, the output tensor is replicated.
        use_local_output (bool, optional):
            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True.
    Returns:
        A :class:`ParallelStyle` object that represents Rowwise sharding of the nn.Module.
    """

    def __init__(
        self,
        *,
        input_layouts: Optional[Placement] = None,
        output_layouts: Optional[Placement] = None,
        use_local_output: bool = True,
        use_dtensor=True,
    ):
        super().__init__()
        self.input_layouts = (input_layouts or Shard(-1),)
        self.output_layouts = (output_layouts or Replicate(),)
        self.use_local_output = use_local_output
        self.use_dtensor = use_dtensor

    @staticmethod
    def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
        input_tensor = inputs[0]
        if not isinstance(input_tensor, DTensor):
            input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)

        if input_layouts != desired_input_layouts:
            input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
        return input_tensor

    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
        # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
        # means Rowwise as nn.Linear is input * weight^T + bias, where
        # weight would become Shard(0)
        if param_type != "bias":
            parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1)
            shard = [Shard(-1)]
        else:
            shard = [Replicate()]
            parameter = param[:]

        parameter = parameter.to(param_casting_dtype)
        if to_contiguous:
            parameter = parameter.contiguous()
        if self.use_dtensor:
            parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
        return nn.Parameter(parameter)

    @staticmethod
    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
        # Rowwise sharding produces partial output, depending on output layouts:
        # 1. to replicate -> allreduce
        # 2. to shard -> reduce_scatter
        if outputs.placements != output_layouts:
            outputs = outputs.redistribute(placements=output_layouts, async_op=True)
        # back to local tensor if use_local_output is True
        return outputs.to_local() if use_local_output else outputs

    def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
        module._distribute_module_applied = True
        if self.use_dtensor:
            if isinstance(module, nn.Linear):
                # rowwise linear runtime sharding requires input tensor shard on last dim
                self.desired_input_layouts: Tuple[Placement, ...] = (Shard(-1),)
            elif isinstance(module, nn.Embedding):
                # rowwise embedding runtime sharding requires input tensor replicated
                self.desired_input_layouts = (Replicate(),)
            elif isinstance(module, nn.Parameter):
                # rowwise embedding runtime sharding requires input tensor replicated
                self.desired_input_layouts = (Shard(-1),)
            else:
                raise NotImplementedError("RowwiseParallel currently only support nn.Linear and nn.Embedding!")

            distribute_module(
                module,
                device_mesh,
                partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts),
                partial(self._prepare_output_fn, self.output_layouts, self.use_local_output),
            )


class PackedRowwiseParallel(RowwiseParallel):
    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
        # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
        # means Colwise as Linear is input * weight^T + bias, where
        # weight would become Shard(1)
        parameter = get_packed_weights(param, empty_param, device_mesh, rank, -1)
        parameter = parameter.to(param_casting_dtype)
        if to_contiguous:
            parameter = parameter.contiguous()
        if self.use_dtensor:
            parameter = DTensor.from_local(parameter, device_mesh, [Shard(-1)], run_check=False)
        return nn.Parameter(parameter)
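Why the colwise/rowwise pair reproduces the dense result for a two-layer MLP can be checked on a single process. The sketch below is illustrative only (the real classes work on DTensors and rely on an all-reduce, as in `GatherParallel`/`RowwiseParallel` above), and all tensor shapes are hypothetical:

```
import torch

torch.manual_seed(0)
x = torch.randn(2, 8)            # (batch, hidden)
w1 = torch.randn(16, 8)          # up projection, y = relu(x @ w1.T)
w2 = torch.randn(8, 16)          # down projection, z = y @ w2.T
dense = torch.relu(x @ w1.T) @ w2.T

tp = 2                           # hypothetical world size
partials = []
for rank in range(tp):
    w1_shard = w1.chunk(tp, dim=0)[rank]   # colwise: shard output features (rows of w1)
    w2_shard = w2.chunk(tp, dim=1)[rank]   # rowwise: shard input features (columns of w2)
    y_local = torch.relu(x @ w1_shard.T)   # each rank holds a slice of the intermediate
    partials.append(y_local @ w2_shard.T)  # partial sums over the sharded dimension

tp_result = sum(partials)                  # what the all-reduce of partial outputs computes
print(torch.allclose(dense, tp_result, atol=1e-5))  # True
```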
SUPPORTED_TP_STYLES = {
    "colwise",
    "rowwise",
    "colwise_rep",
    "rowwise_rep",
    "local_colwise",
    "local_rowwise",
    "local",
    "gather",
    "local_packed_rowwise",
}


@lru_cache
def translate_to_torch_parallel_style(style: str):
    """
    In model configurations, we use a neutral type (string) to specify parallel
    styles, here we translate them into torch.distributed tensor-parallel
    types.
    """
    if not isinstance(style, str):
        raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")

    if style == "colwise":
        return ColwiseParallel()
    elif style == "rowwise":
        return RowwiseParallel()
    elif style == "colwise_rep":
        return ColwiseParallel(output_layouts=Replicate())
    elif style == "rowwise_rep":
        return RowwiseParallel(input_layouts=Replicate())
    elif style == "local_colwise":
        return ColwiseParallel(use_dtensor=False)
    elif style == "local_rowwise":
        return RowwiseParallel(use_dtensor=False)
    elif style == "local":
        return IsolatedParallel()
    elif style == "gather":
        return GatherParallel()
    elif style == "local_packed_rowwise":
        return PackedRowwiseParallel(use_dtensor=False)
    else:
        raise ValueError(f"Unsupported parallel style value: {style}")
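A hypothetical string-based plan and what each entry is expected to resolve to (illustrative only, not part of the new file; the module names are made up):

```
# "colwise" -> ColwiseParallel(), "rowwise" -> RowwiseParallel(),
# "local_rowwise" -> RowwiseParallel(use_dtensor=False), etc.
SUPPORTED = {"colwise", "rowwise", "colwise_rep", "rowwise_rep",
             "local_colwise", "local_rowwise", "local", "gather", "local_packed_rowwise"}

example_plan = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
    "layers.*.mlp.down_proj": "local_rowwise",
}
bad = {k: v for k, v in example_plan.items() if v not in SUPPORTED}
assert not bad, f"unsupported styles: {bad}"
```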
def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, current_module_plan, device_mesh):
    """
    Add hooks to the module holding the layer. Meaning:
    ```
    class MyModel(nn.Module):
        def __init__(self):
            self.layer = nn.Linear(10, 10)
    ```
    has state_dict like:
    ```
    {
        "layer.weight": torch.Tensor,
        "layer.bias": torch.Tensor
    }
    ```
    we add hooks to `MyModel` as well as `layer` to make sure that the tensors are correctly sharded and gathered.
    """

    # 1. We add hooks to the layer being loaded:
    if current_module_plan is not None:
        tp_layer = translate_to_torch_parallel_style(current_module_plan)
        tp_layer.prepare_module_tp(module, device_mesh)

    # 2. We add hooks to the parrent module if needed
    if "." in layer_name:
        parrent_layer_name = layer_name.rsplit(".", 1)[0]
        generic_name = re.sub(r"\d+", "*", parrent_layer_name)
        # The module itself needs hooks
        if module_plan := tp_plan.get(generic_name, False):
            tp_layer = translate_to_torch_parallel_style(module_plan)
            module_to_tp_ = model.get_submodule(parrent_layer_name)
            tp_layer.prepare_module_tp(module_to_tp_, device_mesh)


def shard_and_distribute_module(
    model, param, empty_param, parameter_name, param_casting_dtype, is_contiguous, rank, device_mesh
):
    r"""
    Main uses cases:
    - column / rowise parallelism, you just shard all the weights of the layer (weight and bias)
    - packed layers: you slice the weights, then shard like above
    - custom operation:
        - you want to add an all-gather at the end of a local layer.
        - you want to have a layer that is isolated from the rest of the world (because torch.DTensor does not work well with `.view` for instance)

    """
    param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
    tp_plan = model._tp_plan
    module_to_tp = model.get_submodule(param_name)
    current_module_plan = None

    generic_param_name = re.sub(r"\d+", "*", parameter_name)
    if generic_param_name in tp_plan:
        current_module_plan = tp_plan[generic_param_name]
    elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan:
        current_module_plan = tp_plan[generic_param_name.rsplit(".", 1)[0]]

    # Add hooks to the module if not done yet
    # add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh)
    if not getattr(module_to_tp, "_is_hooked", False):
        add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh)
        module_to_tp._is_hooked = True

    if current_module_plan is not None:
        tp_layer = translate_to_torch_parallel_style(current_module_plan)
        param = tp_layer.partition_tensor(
            param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
        )
    else:
        param = param[:]
        if is_contiguous:
            param = param.contiguous()

    # SUPER IMPORTANT we have to use setattr
    # otherwise loading is crazy slow
    if not isinstance(param, torch.nn.Parameter):
        param = torch.nn.Parameter(param)
    setattr(module_to_tp, param_type, param)
    # module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True)
    return param
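How a concrete parameter name is matched against a wildcarded tp_plan entry, as `shard_and_distribute_module` does with `re.sub(r"\d+", "*", ...)` (illustrative only; the plan and parameter name below are hypothetical):

```
import re

tp_plan = {"model.layers.*.self_attn.q_proj": "colwise"}
parameter_name = "model.layers.3.self_attn.q_proj.weight"

generic_param_name = re.sub(r"\d+", "*", parameter_name)   # model.layers.*.self_attn.q_proj.weight
module_name = generic_param_name.rsplit(".", 1)[0]         # strip the trailing ".weight"
style = tp_plan.get(generic_param_name) or tp_plan.get(module_name)
print(style)  # colwise
```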
@@ -54,17 +54,20 @@ from .integrations.deepspeed import _load_state_dict_into_zero3_model
from .integrations.flash_attention import flash_attention_forward
from .integrations.flex_attention import flex_attention_forward
from .integrations.sdpa_attention import sdpa_attention_forward
from .integrations.tensor_parallel import (
    SUPPORTED_TP_STYLES,
    shard_and_distribute_module,
    translate_to_torch_parallel_style,
)
from .loss.loss_utils import LOSS_MAPPING
from .pytorch_utils import (  # noqa: F401
    Conv1D,
    apply_chunking_to_forward,
    distribute_module,
    find_pruneable_heads_and_indices,
    id_tensor_storage,
    prune_conv1d_layer,
    prune_layer,
    prune_linear_layer,
    translate_to_torch_parallel_style,
)
from .quantizers import AutoHfQuantizer, HfQuantizer
from .quantizers.quantizers_utils import get_module_from_name
@@ -151,6 +154,7 @@ logger = logging.get_logger(__name__)
_init_weights = True
_is_quantized = False
_is_ds_init_called = False
_torch_distributed_available = torch.distributed.is_available()


def is_fsdp_enabled():
@@ -181,8 +185,6 @@ else:
if is_peft_available():
    from .utils import find_adapter_config_file

if is_torch_greater_or_equal("2.5"):
    from torch.distributed.tensor import DTensor, Shard

SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel")

@@ -756,6 +758,40 @@ def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
            setattr(submodule, param_name, new_val)


def fix_tensor_type_and_device(
    model, param_name, param, dtype=None, keep_in_fp32_modules=None
) -> Union[str, torch.dtype]:
    # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
    # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
    # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29

    old_param = model
    if "." in param_name:
        pre, _ = param_name.rsplit(".", 1)

        old_param = model.get_submodule(pre)
    if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
        old_param = None

    is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
    # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
    # in int/uint/bool and not cast them.
    param_casting_dtype = None
    is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn
    if param.dtype.is_floating_point and not is_param_float8_e4m3fn:
        if keep_in_fp32_modules is not None and keep_in_fp32_modules.search(param_name):
            param_casting_dtype = torch.float32
        elif dtype is not None:
            param_casting_dtype = dtype
        elif old_param is not None:
            param_casting_dtype = old_param.dtype
        return old_param is not None and old_param.is_contiguous(), param_casting_dtype
    else:
        return False, None

    return
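The dtype decision made by `fix_tensor_type_and_device` can be summarised as follows; the sketch is illustrative only, with `keep_in_fp32_modules` compiled from a hypothetical `["query_tokens", "wo"]` list (cf. the Blip2 changes further down):

```
import re
import torch

keep_in_fp32_modules = re.compile("|".join(["query_tokens", "wo"]))
requested_dtype = torch.float16


def pick_casting_dtype(param_name: str, param_dtype: torch.dtype):
    if not param_dtype.is_floating_point:
        return None                  # int/uint/bool buffers are never cast
    if keep_in_fp32_modules.search(param_name):
        return torch.float32         # forced fp32, e.g. query_tokens above
    return requested_dtype


print(pick_casting_dtype("query_tokens", torch.float32))                    # torch.float32
print(pick_casting_dtype("language_model.lm_head.weight", torch.float32))   # torch.float16
```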
@torch.no_grad()
def _load_state_dict_into_meta_model(
    model: torch.nn.Module,
@@ -787,18 +823,12 @@ def _load_state_dict_into_meta_model(
    It also initialize tensor parallelism for each module if needed.

    """
    tensor_device = None
    tensor_device = "cpu"
    if device_map is not None and device_map.get("", None) is not None:
        tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
    if device_map is not None:
        device_map_regex = "|".join(sorted(device_map.keys(), reverse=True))

    # we need this later to initialize tensor parallelism
    if device_mesh is not None:
        full_tp_plan = model.config.base_model_tp_plan
        for submodule in model.modules():
            full_tp_plan.update(getattr(submodule, "_tp_plan", {}))

    file_pointer = None
    bin_state_dict = None
    if shard_file.endswith(".safetensors"):
@@ -818,8 +848,6 @@ def _load_state_dict_into_meta_model(

    is_quantized = hf_quantizer is not None

    is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")

    for serialized_param_name, empty_param in state_dict.items():
        # serialized_param_name is the raw, serialized name
        # fixed_param_name is the model's equivalent
@@ -829,87 +857,37 @@ def _load_state_dict_into_meta_model(
            continue

        # we need to use serialized_param_name as file pointer is untouched
        param = (
            file_pointer.get_slice(serialized_param_name)
            if shard_file.endswith(".safetensors")
            else bin_state_dict[serialized_param_name]
        if shard_file.endswith(".safetensors"):
            param = file_pointer.get_slice(serialized_param_name)
        elif shard_file.endswith(".gguf"):
            param = empty_param  # For gguf the dict is actually not empty!
        else:
            param = bin_state_dict[serialized_param_name]

        to_contiguous, param_casting_dtype = fix_tensor_type_and_device(
            model,
            param_name=fixed_param_name,
            param=empty_param,
            dtype=dtype,
            keep_in_fp32_modules=keep_in_fp32_modules,
        )

        # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
        # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
        # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29

        old_param = model
        splits = fixed_param_name.split(".")
        for split in splits:
            # We shouldn't hit the default value unless for quant methods like hqq that modifies expected_keys.
            old_param = getattr(old_param, split, None)
            if old_param is None:
                break

        if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
            old_param = None

        # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
        # in int/uint/bool and not cast them.
        param_casting_dtype = None
        is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
        if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
            if keep_in_fp32_modules is not None and keep_in_fp32_modules.search(fixed_param_name):
                param_casting_dtype = torch.float32
            elif dtype is not None:
                param_casting_dtype = dtype
            elif old_param is not None:
                param_casting_dtype = old_param.dtype

        if device_mesh is not None:  # In this case, the param is already on the correct device!
            module_to_tp, param_type = get_module_from_name(model, fixed_param_name)
            current_module_plan = None
            full_tp_plan_ = "|".join(full_tp_plan.keys()).replace("*", "[0-9]+")
            if plan := re.search(full_tp_plan_, fixed_param_name):
                match = re.sub("[0-9]+", "*", plan[0])
                current_module_plan = full_tp_plan[match]

            if current_module_plan is not None:
                tp_layer = translate_to_torch_parallel_style(current_module_plan)
                rank = tensor_device
                row, col = empty_param.shape
                if "rowwise" == current_module_plan:
                    param = param[:, rank * (col // device_mesh.size()) : (rank + 1) * (col // device_mesh.size())]
                    shard = Shard(1)
                    tp_layer.desired_input_layouts = (Shard(-1),)
                elif "colwise" == current_module_plan:
                    param = param[rank * (row // device_mesh.size()) : (rank + 1) * (row // device_mesh.size()), :]
                    shard = Shard(0)
                else:
                    param = param[rank * (row // device_mesh.size()) : (rank + 1) * (row // device_mesh.size()), :]
                    shard = Shard(0)
                if param_casting_dtype is not None:
                    param = param.to(param_casting_dtype)
                if old_param.is_contiguous():
                    param = param.contiguous()
                local_parameter = DTensor.from_local(
                    param,
                    device_mesh=device_mesh,
                    placements=[shard] * device_mesh.ndim,
                )
                if isinstance(module_to_tp.weight, nn.Parameter):
                    local_parameter = torch.nn.Parameter(local_parameter)
                module_to_tp.weight = local_parameter
                input_fn = partial(tp_layer._prepare_input_fn, tp_layer.input_layouts, tp_layer.desired_input_layouts)
                output_fn = partial(tp_layer._prepare_output_fn, tp_layer.output_layouts, tp_layer.use_local_output)
                distribute_module(module_to_tp, device_mesh, None, input_fn, output_fn)
            else:
                param = param[:]
                if old_param is not None and old_param.is_contiguous():
                    param = param.contiguous()
                module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True)

            shard_and_distribute_module(
                model,
                param,
                empty_param,
                fixed_param_name,
                param_casting_dtype,
                to_contiguous,
                tensor_device,  # the rank
                device_mesh,
            )
        else:
            param = param[:]
            if param_casting_dtype is not None:
                param = param.to(param_casting_dtype)
            if old_param is not None and old_param.is_contiguous():
            if to_contiguous:
                param = param.contiguous()

        if device_map is None:
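The new `device_map_regex` built above is a longest-key-first alternation over the device map; a sketch of how such a regex could be used to resolve a parameter's device (illustrative only; the device map is hypothetical and the lookup shown is an assumption, not necessarily how the loader consumes the match):

```
import re

device_map = {"": "cpu", "model.layers.0": 0, "model.layers.1": 1}
device_map_regex = "|".join(sorted(device_map.keys(), reverse=True))


def resolve_device(param_name):
    # Longest keys are tried first, so the most specific prefix wins.
    match = re.search(device_map_regex, param_name)
    return device_map[match.group()] if match else device_map[""]


print(resolve_device("model.layers.1.self_attn.q_proj.weight"))  # 1
print(resolve_device("lm_head.weight"))                          # cpu
```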
@@ -966,6 +944,7 @@ def _load_state_dict_into_meta_model(
            val_kwargs["requires_grad"] = False
        value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__)
        setattr(module, param_type, value)

    if file_pointer is not None:
        file_pointer.__exit__(None, None, None)

@@ -1409,7 +1388,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
    # A tensor parallel plan to be applied to the model when TP is enabled. For
    # top-level models, this attribute is currently defined in respective model
    # code. For base models, this attribute comes from
    # `config.base_model_tp_plan` during `post_init`.
    # `config.base_model_tp_plan` during `__init__`.
    # It should identify the layers exactly: if you want to TP model.language_model.layers.fc1
    # by passing `tp_plan` to the init, it should be {"model.language_model.layers.fc1":"colwise"}
    # for example.
    _tp_plan = None

    # A pipeline parallel plan specifying the layers which may not be present
@@ -1475,6 +1457,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        # when a different component (e.g. language_model) is used.
        self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules)

        self._no_split_modules = self._no_split_modules or []

    def post_init(self):
        """
        A method executed at the end of each Transformer model initialization, to execute code that needs the model's
@@ -1482,11 +1466,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        """
        self.init_weights()
        self._backward_compatibility_gradient_checkpointing()

        # If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
        if self.base_model is self:
            self._tp_plan = self.config.base_model_tp_plan
            self._pp_plan = self.config.base_model_pp_plan

        self._tp_plan = self._tp_plan or self.config.base_model_tp_plan or {}
        for name, module in self.named_children():
            if plan := getattr(module, "_tp_plan", None):
                self._tp_plan.update({f"{name}.{k}": v for k, v in plan.items()})

        if self._tp_plan is not None and is_torch_greater_or_equal("2.3"):
            for _, v in self._tp_plan.items():
                if v not in SUPPORTED_TP_STYLES:
                    raise ValueError(
                        f"Unsupported tensor parallel style {v}. Supported styles are {SUPPORTED_TP_STYLES}"
                    )
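The `post_init` change above merges each child module's `_tp_plan` into the parent plan under the child's attribute name; a tiny sketch with hypothetical plans:

```
parent_plan = {"layers.*.self_attn.q_proj": "colwise"}
children = {"language_model": {"layers.*.mlp.down_proj": "rowwise"}}

merged = dict(parent_plan)
for name, child_plan in children.items():
    # Prefix every child entry with the attribute name, as post_init now does.
    merged.update({f"{name}.{k}": v for k, v in child_plan.items()})

print(merged)
# {'layers.*.self_attn.q_proj': 'colwise', 'language_model.layers.*.mlp.down_proj': 'rowwise'}
```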
    def dequantize(self):
        """
        Potentially dequantize the model in case it has been quantized by a quantization method that support
@@ -4315,7 +4311,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        model = cls(config, *model_args, **model_kwargs)

        if device_mesh is not None and not model.supports_tp_plan:
            raise NotImplementedError("This model does not have a tensor parallel plan.")
            if config.base_model_tp_plan is None and config.get_text_config().base_model_tp_plan is None:
                raise NotImplementedError("This model does not have a tensor parallel plan.")

        # make sure we use the model's config since the __init__ call might have copied it
        config = model.config
@@ -4453,7 +4450,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                model,
                state_dict,
                loaded_state_dict_keys,  # XXX: rename?
                resolved_archive_file,
                resolved_archive_file or gguf_file,
                pretrained_model_name_or_path,
                ignore_mismatched_sizes=ignore_mismatched_sizes,
                sharded_metadata=sharded_metadata,
@@ -4565,7 +4562,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
    @staticmethod
    def _fix_state_dict_key_on_load(key) -> Tuple[str, bool]:
        """Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight."""

        # Rename LayerNorm beta & gamma params for some early models ported from Tensorflow (e.g. Bert)
        # This rename is logged.
        if key.endswith("LayerNorm.beta"):
@@ -4590,6 +4586,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        return key, False

    def rename_key(self, key):
        """
        When we load a LlamaModel from a checkpoint made using LlamaForCausalLM, the keys have an extra
        prefix, which can be accessed in the `LlamaModel` via the `self.base_model_prefix` attribute.

        But, what if there is an extra layer on top of it? You load a MistralModel from a LlavaForConditionalGeneration?
        In that what you actually want is to cut whatever is left of the key.
        """
        new_key = key
        if len(self.base_model_prefix) > 0:
            if not hasattr(self, self.base_model_prefix) and key.startswith(self.base_model_prefix):
@@ -4940,7 +4943,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                    keep_in_fp32_modules=keep_in_fp32_modules,
                    unexpected_keys=unexpected_keys,
                    device_mesh=device_mesh,
                    resolved_archive_file=resolved_archive_file,
                    shard_file=resolved_archive_file,
                    weights_only=weights_only,
                )
            else:
@@ -5019,7 +5022,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
                    model_to_load,
                    state_dict,
                    start_prefix,
                    prefix,
                    expected_keys,
                    device_map=device_map,
                    offload_folder=offload_folder,
@@ -5898,10 +5901,21 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,

    for param_name, device in accelerator_device_map.items():
        try:
            param = model.get_parameter(param_name)
            param = getattr(model, param_name)
        except AttributeError:
            param = model.get_buffer(param_name)
        parameter_count[device] += int(math.prod(param.shape) * allocation_factor)
            if "." in param_name:
                param_name, param_type = param_name.rsplit(".", 1)
                param = getattr(model.get_submodule(param_name), param_type)
            else:
                param = model.get_buffer(param_name)

        param_size = int(math.prod(param.shape) * allocation_factor)

        if _torch_distributed_available and torch.distributed.is_initialized():
            generic_name = re.sub(r"\d+", "*", param_name)
            param_size //= torch.distributed.get_world_size() if not model._tp_plan.get(generic_name, False) else 1

        parameter_count[device] += param_size

    dtype = dtype if dtype is not None else torch.float32
@@ -419,7 +419,7 @@ class Blip2PreTrainedModel(PreTrainedModel):
        "OPTDecoderLayer",
    ]
    _skip_keys_device_placement = "past_key_values"
    _keep_in_fp32_modules = ["wo"]
    _keep_in_fp32_modules = ["query_tokens"]

    def _init_weights(self, module):
        """Initialize the weights"""
@@ -1799,7 +1799,7 @@ class Blip2Model(Blip2PreTrainedModel):
)
class Blip2TextModelWithProjection(Blip2PreTrainedModel):
    supports_gradient_checkpointing = False
    _keep_in_fp32_modules = []
    _keep_in_fp32_modules = ["query_tokens"]

    def __init__(self, config: Blip2Config):
        super().__init__(config)
@@ -1898,7 +1898,7 @@ class Blip2TextModelWithProjection(Blip2PreTrainedModel):
)
class Blip2VisionModelWithProjection(Blip2PreTrainedModel):
    main_input_name = "pixel_values"
    _keep_in_fp32_modules = []
    _keep_in_fp32_modules = ["query_tokens"]

    def __init__(self, config: Blip2Config):
        super().__init__(config)
@@ -2371,7 +2371,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
)
class Blip2ForImageTextRetrieval(Blip2PreTrainedModel):
    main_input_name = "pixel_values"
    _keep_in_fp32_modules = []
    _keep_in_fp32_modules = ["query_tokens"]

    def __init__(self, config: Blip2Config):
        super().__init__(config)
@@ -322,7 +322,6 @@ class InstructBlipPreTrainedModel(PreTrainedModel):
        "InstructBlipQFormerMultiHeadAttention",
        "InstructBlipQFormerSelfOutput",
    ]
    _keep_in_fp32_modules = []

    # Copied from transformers.models.blip_2.modeling_blip_2.Blip2PreTrainedModel._init_weights with Blip2->InstructBlip
    def _init_weights(self, module):
@@ -1293,6 +1292,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
    _supports_cache_class = True
    _supports_static_cache = True
    _supports_quantized_cache = False  # not all LM bacbones support (e.g. T5)
    _keep_in_fp32_modules = ["query_tokens"]  # TODO @ArthurZucker I don't know why this is required for FP8

    def __init__(self, config: InstructBlipConfig):
        super().__init__(config)
@@ -323,7 +323,6 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel):
        "InstructBlipVideoQFormerMultiHeadAttention",
        "InstructBlipVideoQFormerSelfOutput",
    ]
    _keep_in_fp32_modules = []

    def _init_weights(self, module):
        """Initialize the weights"""
@@ -1287,6 +1286,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
    _supports_cache_class = True
    _supports_static_cache = True
    _supports_quantized_cache = False  # not all LM bacbones support (e.g. T5)
    _keep_in_fp32_modules = ["query_tokens"]  # TODO @ArthurZucker I don't know why this is required for FP8

    def __init__(self, config: InstructBlipVideoConfig):
        super().__init__(config)
@@ -58,13 +58,13 @@ class LlavaConfigTest(unittest.TestCase):
        """
        Simple test for reloading arbirarily composed subconfigs
        """
        default_values = LlavaConfig().to_dict()
        default_values["vision_config"]["model_type"] = "qwen2_vl"
        default_values = LlavaConfig().to_diff_dict()
        default_values["vision_config"]["model_type"] = "pixtral"
        default_values["text_config"]["model_type"] = "opt"

        self.maxDiff = None
        with tempfile.TemporaryDirectory() as tmp_dir:
            config = LlavaConfig(**default_values)
            config.save_pretrained(tmp_dir)

            reloaded = LlavaConfig.from_pretrained(tmp_dir)
            assert config.to_dict() == reloaded.to_dict()
            self.assertDictEqual(config.to_dict(), reloaded.to_dict())