mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
Fix init empty weights
without accelerate (#37337)
* add the integration * Update accelerate.py * Update accelerate.py * add find_tied_params as well * Update accelerate.py * add where copied from * simplify * add error
This commit is contained in:
parent
9db31ea585
commit
08f36771b3
196
src/transformers/integrations/accelerate.py
Normal file
196
src/transformers/integrations/accelerate.py
Normal file
@ -0,0 +1,196 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Since, https://github.com/huggingface/transformers/pull/36963, loading is always performed with models on meta
|
||||||
|
device. But since the `init_empty_weights` and `find_tied_parameters` functions are from accelerate, and accelerate is
|
||||||
|
somewhat still a soft dependency, we copy the functions here to be used natively in Transformers.
|
||||||
|
|
||||||
|
The `init_empty_weights` and `init_on_device` functions were copied from `accelerate.big_modeling.py`, and the
|
||||||
|
`find_tied_parameters` was copied from `accelerate.utils.modeling.py`
|
||||||
|
"""
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
from ..utils import is_torch_available, logging
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def init_empty_weights(include_buffers: bool = False):
|
||||||
|
"""
|
||||||
|
A context manager under which models are initialized with all parameters on the meta device, therefore creating an
|
||||||
|
empty model. Useful when just initializing the model would blow the available RAM.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
include_buffers (`bool`, *optional*):
|
||||||
|
Whether or not to also put all buffers on the meta device while initializing.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch.nn as nn
|
||||||
|
from accelerate import init_empty_weights
|
||||||
|
|
||||||
|
# Initialize a model with 100 billions parameters in no time and without using any RAM.
|
||||||
|
with init_empty_weights():
|
||||||
|
tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
|
||||||
|
```
|
||||||
|
|
||||||
|
<Tip warning={true}>
|
||||||
|
|
||||||
|
Any model created under this context manager has no weights. As such you can't do something like
|
||||||
|
`model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
|
||||||
|
Make sure to overwrite the default device_map param for [`load_checkpoint_and_dispatch`], otherwise dispatch is not
|
||||||
|
called.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
"""
|
||||||
|
with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
|
||||||
|
yield f
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def init_on_device(device: "torch.device", include_buffers: bool = False):
|
||||||
|
"""
|
||||||
|
A context manager under which models are initialized with all parameters on the specified device.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device (`torch.device`):
|
||||||
|
Device to initialize all parameters on.
|
||||||
|
include_buffers (`bool`, *optional*):
|
||||||
|
Whether or not to also put all buffers on the meta device while initializing.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch.nn as nn
|
||||||
|
from accelerate import init_on_device
|
||||||
|
|
||||||
|
with init_on_device(device=torch.device("cuda")):
|
||||||
|
tst = nn.Linear(100, 100) # on `cuda` device
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if include_buffers:
|
||||||
|
with device:
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
old_register_parameter = nn.Module.register_parameter
|
||||||
|
if include_buffers:
|
||||||
|
old_register_buffer = nn.Module.register_buffer
|
||||||
|
|
||||||
|
def register_empty_parameter(module, name, param):
|
||||||
|
old_register_parameter(module, name, param)
|
||||||
|
if param is not None:
|
||||||
|
param_cls = type(module._parameters[name])
|
||||||
|
kwargs = module._parameters[name].__dict__
|
||||||
|
kwargs["requires_grad"] = param.requires_grad
|
||||||
|
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
|
||||||
|
|
||||||
|
def register_empty_buffer(module, name, buffer, persistent=True):
|
||||||
|
old_register_buffer(module, name, buffer, persistent=persistent)
|
||||||
|
if buffer is not None:
|
||||||
|
module._buffers[name] = module._buffers[name].to(device)
|
||||||
|
|
||||||
|
# Patch tensor creation
|
||||||
|
if include_buffers:
|
||||||
|
tensor_constructors_to_patch = {
|
||||||
|
torch_function_name: getattr(torch, torch_function_name)
|
||||||
|
for torch_function_name in ["empty", "zeros", "ones", "full"]
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
tensor_constructors_to_patch = {}
|
||||||
|
|
||||||
|
def patch_tensor_constructor(fn):
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
kwargs["device"] = device
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
try:
|
||||||
|
nn.Module.register_parameter = register_empty_parameter
|
||||||
|
if include_buffers:
|
||||||
|
nn.Module.register_buffer = register_empty_buffer
|
||||||
|
for torch_function_name in tensor_constructors_to_patch.keys():
|
||||||
|
setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
nn.Module.register_parameter = old_register_parameter
|
||||||
|
if include_buffers:
|
||||||
|
nn.Module.register_buffer = old_register_buffer
|
||||||
|
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
|
||||||
|
setattr(torch, torch_function_name, old_torch_function)
|
||||||
|
|
||||||
|
|
||||||
|
def find_tied_parameters(model: "nn.Module", **kwargs):
|
||||||
|
"""
|
||||||
|
Find the tied parameters in a given model.
|
||||||
|
|
||||||
|
<Tip warning={true}>
|
||||||
|
|
||||||
|
The signature accepts keyword arguments, but they are for the recursive part of this function and you should ignore
|
||||||
|
them.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (`torch.nn.Module`): The model to inspect.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[List[str]]: A list of lists of parameter names being all tied together.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```py
|
||||||
|
>>> from collections import OrderedDict
|
||||||
|
>>> import torch.nn as nn
|
||||||
|
|
||||||
|
>>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))]))
|
||||||
|
>>> model.linear2.weight = model.linear1.weight
|
||||||
|
>>> find_tied_parameters(model)
|
||||||
|
[['linear1.weight', 'linear2.weight']]
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
# get ALL model parameters and thier names
|
||||||
|
all_named_parameters = dict(model.named_parameters(remove_duplicate=False))
|
||||||
|
|
||||||
|
# get ONLY unique named parameters,
|
||||||
|
# if parameter is tied and have multiple names, it will be included only once
|
||||||
|
no_duplicate_named_parameters = dict(model.named_parameters(remove_duplicate=True))
|
||||||
|
|
||||||
|
# the difference of the two sets will give us the tied parameters
|
||||||
|
tied_param_names = set(all_named_parameters.keys()) - set(no_duplicate_named_parameters.keys())
|
||||||
|
|
||||||
|
# 'tied_param_names' contains the names of parameters that are tied in the model, but we do not know
|
||||||
|
# which names refer to the same parameter. To identify this, we need to group them together.
|
||||||
|
tied_param_groups = {}
|
||||||
|
for tied_param_name in tied_param_names:
|
||||||
|
tied_param = all_named_parameters[tied_param_name]
|
||||||
|
for param_name, param in no_duplicate_named_parameters.items():
|
||||||
|
# compare if parameters are the same, if so, group thier names together
|
||||||
|
if param is tied_param:
|
||||||
|
if param_name not in tied_param_groups:
|
||||||
|
tied_param_groups[param_name] = []
|
||||||
|
tied_param_groups[param_name].append(tied_param_name)
|
||||||
|
|
||||||
|
return [sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()]
|
@ -57,6 +57,7 @@ from .configuration_utils import PretrainedConfig
|
|||||||
from .dynamic_module_utils import custom_object_save
|
from .dynamic_module_utils import custom_object_save
|
||||||
from .generation import CompileConfig, GenerationConfig, GenerationMixin
|
from .generation import CompileConfig, GenerationConfig, GenerationMixin
|
||||||
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
|
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
|
||||||
|
from .integrations.accelerate import find_tied_parameters, init_empty_weights
|
||||||
from .integrations.deepspeed import _load_state_dict_into_zero3_model, is_deepspeed_available
|
from .integrations.deepspeed import _load_state_dict_into_zero3_model, is_deepspeed_available
|
||||||
from .integrations.flash_attention import flash_attention_forward
|
from .integrations.flash_attention import flash_attention_forward
|
||||||
from .integrations.flex_attention import flex_attention_forward
|
from .integrations.flex_attention import flex_attention_forward
|
||||||
@ -131,12 +132,11 @@ XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
|
|||||||
|
|
||||||
|
|
||||||
if is_accelerate_available():
|
if is_accelerate_available():
|
||||||
from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
|
from accelerate import dispatch_model, infer_auto_device_map
|
||||||
from accelerate.hooks import add_hook_to_module
|
from accelerate.hooks import add_hook_to_module
|
||||||
from accelerate.utils import (
|
from accelerate.utils import (
|
||||||
check_tied_parameters_on_same_device,
|
check_tied_parameters_on_same_device,
|
||||||
extract_model_from_parallel,
|
extract_model_from_parallel,
|
||||||
find_tied_parameters,
|
|
||||||
get_balanced_memory,
|
get_balanced_memory,
|
||||||
get_max_memory,
|
get_max_memory,
|
||||||
load_offloaded_weights,
|
load_offloaded_weights,
|
||||||
@ -4135,6 +4135,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
if device_map is not None:
|
if device_map is not None:
|
||||||
if is_deepspeed_zero3_enabled():
|
if is_deepspeed_zero3_enabled():
|
||||||
raise ValueError("DeepSpeed Zero-3 is not compatible with passing a `device_map`.")
|
raise ValueError("DeepSpeed Zero-3 is not compatible with passing a `device_map`.")
|
||||||
|
if not is_accelerate_available():
|
||||||
|
raise ValueError(
|
||||||
|
"Using a `device_map` or `tp_plan` requires `accelerate`. You can install it with `pip install accelerate`"
|
||||||
|
)
|
||||||
|
|
||||||
# handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation.
|
# handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation.
|
||||||
if load_in_4bit or load_in_8bit:
|
if load_in_4bit or load_in_8bit:
|
||||||
|
Loading…
Reference in New Issue
Block a user