diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 6a76accf569..697b3deb14d 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2547,6 +2547,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ) >= version.parse("0.37.0") if isinstance(device_map, str): + special_dtypes = { + name: torch.float32 + for name, _ in model.named_parameters() + if any(m in name for m in keep_in_fp32_modules) + } if model._no_split_modules is None: raise ValueError(f"{model.__class__.__name__} does not support `device_map='{device_map}'` yet.") no_split_modules = model._no_split_modules @@ -2557,22 +2562,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ) elif device_map in ["balanced", "balanced_low_0"] and get_balanced_memory is None: raise ValueError(f"`device_map={device_map}` requires a source install of Accelerate.") + + kwargs = {"no_split_module_classes": no_split_modules, "max_memory": max_memory} + if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters: + kwargs["special_dtypes"] = special_dtypes + elif len(special_dtypes) > 0: + logger.warn( + "This model has some weights that should be kept in higher precision, you need to upgrade " + "`accelerate` to properly deal with them (`pip install --upgrade accelerate`)." + ) if device_map != "sequential" and get_balanced_memory is not None: max_memory = get_balanced_memory( model, - max_memory=max_memory, - no_split_module_classes=no_split_modules, dtype=torch_dtype, low_zero=(device_map == "balanced_low_0"), + **kwargs, ) # Make sure tied weights are tied before creating the device map. model.tie_weights() - device_map = infer_auto_device_map( - model, - no_split_module_classes=no_split_modules, - dtype=torch_dtype if not load_in_8bit else torch.int8, - max_memory=max_memory, - ) + device_map = infer_auto_device_map(model, dtype=torch_dtype if not load_in_8bit else torch.int8, **kwargs) if load_in_8bit: # The LM head / tied weights or any last module can stay on disk / CPU