From b45192ec47b7aed6ba3e3c1b00d33e547e468e9b Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Tue, 14 Mar 2023 09:20:16 -0400
Subject: [PATCH] Fix big model inference for T5 models in float16 (#22095)

* Fix big model inference for T5 models in float16

* Apply suggestions from code review

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Style

* Trigger CI with latest release

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
---
 src/transformers/modeling_utils.py | 26 ++++++++++++++++++--------
 1 file changed, 18 insertions(+), 8 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 6a76accf569..697b3deb14d 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -2547,6 +2547,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             ) >= version.parse("0.37.0")
 
         if isinstance(device_map, str):
+            special_dtypes = {
+                name: torch.float32
+                for name, _ in model.named_parameters()
+                if any(m in name for m in keep_in_fp32_modules)
+            }
             if model._no_split_modules is None:
                 raise ValueError(f"{model.__class__.__name__} does not support `device_map='{device_map}'` yet.")
             no_split_modules = model._no_split_modules
@@ -2557,22 +2562,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 )
             elif device_map in ["balanced", "balanced_low_0"] and get_balanced_memory is None:
                 raise ValueError(f"`device_map={device_map}` requires a source install of Accelerate.")
+
+            kwargs = {"no_split_module_classes": no_split_modules}
+            if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
+                kwargs["special_dtypes"] = special_dtypes
+            elif len(special_dtypes) > 0:
+                logger.warn(
+                    "This model has some weights that should be kept in higher precision, you need to upgrade "
+                    "`accelerate` to properly deal with them (`pip install --upgrade accelerate`)."
+                )
             if device_map != "sequential" and get_balanced_memory is not None:
                 max_memory = get_balanced_memory(
                     model,
-                    max_memory=max_memory,
-                    no_split_module_classes=no_split_modules,
                     dtype=torch_dtype,
                     low_zero=(device_map == "balanced_low_0"),
+                    max_memory=max_memory,
+                    **kwargs,
                 )
+            kwargs["max_memory"] = max_memory
             # Make sure tied weights are tied before creating the device map.
             model.tie_weights()
-            device_map = infer_auto_device_map(
-                model,
-                no_split_module_classes=no_split_modules,
-                dtype=torch_dtype if not load_in_8bit else torch.int8,
-                max_memory=max_memory,
-            )
+            device_map = infer_auto_device_map(model, dtype=torch_dtype if not load_in_8bit else torch.int8, **kwargs)
 
             if load_in_8bit:
                 # The LM head / tied weights or any last module can stay on disk / CPU
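
Note: the patch wires the model's `_keep_in_fp32_modules` list (for T5, the feed-forward
output projection `wo`) into accelerate's `special_dtypes`, so those weights stay in
float32 while the rest of the checkpoint loads in float16. A quick way to exercise the
fix (a sketch, assuming accelerate is installed and the `t5-small` checkpoint is used;
the module path below is specific to T5's architecture):

    import torch
    from transformers import T5ForConditionalGeneration

    # Load with big model inference in float16; the patched code forwards
    # `special_dtypes` to accelerate so fp32-only modules are not downcast.
    model = T5ForConditionalGeneration.from_pretrained(
        "t5-small", torch_dtype=torch.float16, device_map="auto"
    )

    # The feed-forward output projection should remain float32 ...
    print(model.encoder.block[0].layer[1].DenseReluDense.wo.weight.dtype)  # torch.float32
    # ... while the rest of the model is float16.
    print(model.shared.weight.dtype)  # torch.float16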
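The backward-compatibility guard relies on runtime signature inspection rather than a
version check: `special_dtypes` is only forwarded when the installed accelerate's
`infer_auto_device_map` accepts it. A minimal standalone sketch of the same pattern
(`maybe_call` and `target` are illustrative names, not part of either library):

    import inspect

    def maybe_call(target, *args, **kwargs):
        # Forward only the keyword arguments the target actually accepts, so a
        # newer call site keeps working against an older library version.
        accepted = inspect.signature(target).parameters
        return target(*args, **{k: v for k, v in kwargs.items() if k in accepted})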