From 08f36771b33d246986d9338a729fc4ef258b999d Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Mon, 7 Apr 2025 11:37:29 +0200
Subject: [PATCH] Fix `init empty weights` without accelerate (#37337)

* add the integration

* Update accelerate.py

* Update accelerate.py

* add find_tied_params as well

* Update accelerate.py

* add where copied from

* simplify

* add error
---
 src/transformers/integrations/accelerate.py | 196 ++++++++++++++++++++
 src/transformers/modeling_utils.py          |   8 +-
 2 files changed, 202 insertions(+), 2 deletions(-)
 create mode 100644 src/transformers/integrations/accelerate.py

diff --git a/src/transformers/integrations/accelerate.py b/src/transformers/integrations/accelerate.py
new file mode 100644
index 00000000000..83efac9661a
--- /dev/null
+++ b/src/transformers/integrations/accelerate.py
@@ -0,0 +1,196 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Since https://github.com/huggingface/transformers/pull/36963, loading is always performed with models on the meta
+device. However, the `init_empty_weights` and `find_tied_parameters` functions come from accelerate, which is still
+only a soft dependency, so we copy them here to be used natively in Transformers.
+
+The `init_empty_weights` and `init_on_device` functions were copied from `accelerate.big_modeling.py`, and
+`find_tied_parameters` was copied from `accelerate.utils.modeling.py`.
+"""
+
+from contextlib import contextmanager
+
+from ..utils import is_torch_available, logging
+
+
+if is_torch_available():
+    import torch
+    import torch.nn as nn
+
+
+logger = logging.get_logger(__name__)
+
+
+@contextmanager
+def init_empty_weights(include_buffers: bool = False):
+    """
+    A context manager under which models are initialized with all parameters on the meta device, therefore creating an
+    empty model. Useful when just initializing the model would blow the available RAM.
+
+    Args:
+        include_buffers (`bool`, *optional*):
+            Whether or not to also put all buffers on the meta device while initializing.
+
+    Example:
+
+    ```python
+    import torch.nn as nn
+    from accelerate import init_empty_weights
+
+    # Initialize a model with 100 billion parameters in no time and without using any RAM.
+    with init_empty_weights():
+        tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
+    ```
+
+    <Tip warning={true}>
+
+    Any model created under this context manager has no weights. As such you can't do something like
+    `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
+    Make sure to overwrite the default device_map param for [`load_checkpoint_and_dispatch`], otherwise dispatch is not
+    called.
+
+    </Tip>
+    """
+    with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
+        yield f
+
+
+@contextmanager
+def init_on_device(device: "torch.device", include_buffers: bool = False):
+    """
+    A context manager under which models are initialized with all parameters on the specified device.
+
+    Args:
+        device (`torch.device`):
+            Device to initialize all parameters on.
+        include_buffers (`bool`, *optional*):
+            Whether or not to also put all buffers on the specified device while initializing.
+
+    Example:
+
+    ```python
+    import torch.nn as nn
+    from accelerate import init_on_device
+
+    with init_on_device(device=torch.device("cuda")):
+        tst = nn.Linear(100, 100)  # on `cuda` device
+    ```
+    """
+    if include_buffers:
+        with device:
+            yield
+        return
+
+    old_register_parameter = nn.Module.register_parameter
+    if include_buffers:
+        old_register_buffer = nn.Module.register_buffer
+
+    def register_empty_parameter(module, name, param):
+        old_register_parameter(module, name, param)
+        if param is not None:
+            param_cls = type(module._parameters[name])
+            kwargs = module._parameters[name].__dict__
+            kwargs["requires_grad"] = param.requires_grad
+            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
+
+    def register_empty_buffer(module, name, buffer, persistent=True):
+        old_register_buffer(module, name, buffer, persistent=persistent)
+        if buffer is not None:
+            module._buffers[name] = module._buffers[name].to(device)
+
+    # Patch tensor creation
+    if include_buffers:
+        tensor_constructors_to_patch = {
+            torch_function_name: getattr(torch, torch_function_name)
+            for torch_function_name in ["empty", "zeros", "ones", "full"]
+        }
+    else:
+        tensor_constructors_to_patch = {}
+
+    def patch_tensor_constructor(fn):
+        def wrapper(*args, **kwargs):
+            kwargs["device"] = device
+            return fn(*args, **kwargs)
+
+        return wrapper
+
+    try:
+        nn.Module.register_parameter = register_empty_parameter
+        if include_buffers:
+            nn.Module.register_buffer = register_empty_buffer
+        for torch_function_name in tensor_constructors_to_patch.keys():
+            setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
+        yield
+    finally:
+        nn.Module.register_parameter = old_register_parameter
+        if include_buffers:
+            nn.Module.register_buffer = old_register_buffer
+        for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
+            setattr(torch, torch_function_name, old_torch_function)
+
+
+def find_tied_parameters(model: "nn.Module", **kwargs):
+    """
+    Find the tied parameters in a given model.
+
+    <Tip warning={true}>
+
+    The signature accepts keyword arguments, but they are for the recursive part of this function and you should ignore
+    them.
+
+    </Tip>
+
+    Args:
+        model (`torch.nn.Module`): The model to inspect.
+
+    Returns:
+        List[List[str]]: A list of lists of parameter names being all tied together.
+
+    Example:
+
+    ```py
+    >>> from collections import OrderedDict
+    >>> import torch.nn as nn
+
+    >>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))]))
+    >>> model.linear2.weight = model.linear1.weight
+    >>> find_tied_parameters(model)
+    [['linear1.weight', 'linear2.weight']]
+    ```
+    """
+
+    # get ALL model parameters and their names
+    all_named_parameters = dict(model.named_parameters(remove_duplicate=False))
+
+    # get ONLY unique named parameters,
+    # if a parameter is tied and has multiple names, it will be included only once
+    no_duplicate_named_parameters = dict(model.named_parameters(remove_duplicate=True))
+
+    # the difference of the two sets will give us the tied parameters
+    tied_param_names = set(all_named_parameters.keys()) - set(no_duplicate_named_parameters.keys())
+
+    # 'tied_param_names' contains the names of parameters that are tied in the model, but we do not know
+    # which names refer to the same parameter. To identify this, we need to group them together.
+    tied_param_groups = {}
+    for tied_param_name in tied_param_names:
+        tied_param = all_named_parameters[tied_param_name]
+        for param_name, param in no_duplicate_named_parameters.items():
+            # compare if parameters are the same, if so, group their names together
+            if param is tied_param:
+                if param_name not in tied_param_groups:
+                    tied_param_groups[param_name] = []
+                tied_param_groups[param_name].append(tied_param_name)
+
+    return [sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()]
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 67266d558dd..a462d1e348e 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -57,6 +57,7 @@ from .configuration_utils import PretrainedConfig
 from .dynamic_module_utils import custom_object_save
 from .generation import CompileConfig, GenerationConfig, GenerationMixin
 from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
+from .integrations.accelerate import find_tied_parameters, init_empty_weights
 from .integrations.deepspeed import _load_state_dict_into_zero3_model, is_deepspeed_available
 from .integrations.flash_attention import flash_attention_forward
 from .integrations.flex_attention import flex_attention_forward
@@ -131,12 +132,11 @@ XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
 
 
 if is_accelerate_available():
-    from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
+    from accelerate import dispatch_model, infer_auto_device_map
     from accelerate.hooks import add_hook_to_module
     from accelerate.utils import (
         check_tied_parameters_on_same_device,
         extract_model_from_parallel,
-        find_tied_parameters,
         get_balanced_memory,
         get_max_memory,
         load_offloaded_weights,
@@ -4135,6 +4135,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if device_map is not None:
             if is_deepspeed_zero3_enabled():
                 raise ValueError("DeepSpeed Zero-3 is not compatible with passing a `device_map`.")
+            if not is_accelerate_available():
+                raise ValueError(
+                    "Using a `device_map` or `tp_plan` requires `accelerate`. You can install it with `pip install accelerate`"
+                )
 
         # handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation.
         if load_in_4bit or load_in_8bit:
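
For reference (not part of the diff above), here is a minimal usage sketch of the two helpers this patch vendors into `src/transformers/integrations/accelerate.py`. It mirrors the docstring examples; the toy `nn.Sequential` model and the direct `transformers.integrations.accelerate` import path are illustrative assumptions, not code added by the patch.

```python
# Illustrative sketch only -- assumes the module path created by this patch.
from collections import OrderedDict

import torch.nn as nn

from transformers.integrations.accelerate import find_tied_parameters, init_empty_weights

# Parameters are registered on the meta device, so no real memory is allocated.
with init_empty_weights():
    model = nn.Sequential(
        OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))])
    )
print(model.linear1.weight.device)  # meta

# Tie the two weights; the helper groups the names that point to the same tensor.
model.linear2.weight = model.linear1.weight
print(find_tied_parameters(model))  # [['linear1.weight', 'linear2.weight']]
```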