From dd3933dd658b2c2e18ad316662a3dff09dcf98cb Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 21 Mar 2025 16:12:59 +0100 Subject: [PATCH] Simplify keep_in_fp32_modules logic (#36722) * better regex everywhere * fix * Update test_modeling_instructblip.py * BC with explanations this time otherwise it makes no sense at all * Update test_modeling_instructblip.py * style * CIs * update _keep_in_fp32_modules in blip2 * Update modeling_utils.py * Update modeling_utils.py * style * CIs * add check * trigger CIs * Update modeling_utils.py * trigger CIs --- src/transformers/modeling_utils.py | 72 ++++++++++++------- .../models/blip_2/modeling_blip_2.py | 3 +- .../test_modeling_instructblip.py | 3 - 3 files changed, 47 insertions(+), 31 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 2c28ec8f1e5..7747003aeb4 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -716,7 +716,7 @@ def _infer_parameter_dtype( model: "PreTrainedModel", param_name: str, empty_param: torch.Tensor, - keep_in_fp32_modules: Optional[List[str]] = None, + keep_in_fp32_regex: Optional[re.Pattern] = None, hf_quantizer: Optional[HfQuantizer] = None, ) -> Union[bool, Optional[torch.dtype]]: try: @@ -733,7 +733,7 @@ def _infer_parameter_dtype( is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn: # First fp32 if part of the exception list - if keep_in_fp32_modules is not None and keep_in_fp32_modules.search(param_name): + if keep_in_fp32_regex is not None and keep_in_fp32_regex.search(param_name): casting_dtype = torch.float32 # Then dtype that was instantiated in the meta model -- note that this respects subconfigs dtypes elif hf_quantizer is not None: @@ -757,7 +757,7 @@ def _load_state_dict_into_meta_model( cpu_offload_index: Optional[Dict] = None, hf_quantizer: Optional[HfQuantizer] = None, is_safetensors: bool = False, - keep_in_fp32_modules: Optional[List[str]] = None, + keep_in_fp32_regex: Optional[re.Pattern] = None, unexpected_keys: Optional[List[str]] = None, # passing `unexpected` for cleanup from quantization items device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None, ) -> Tuple[Optional[Dict], Optional[Dict]]: @@ -795,7 +795,7 @@ def _load_state_dict_into_meta_model( model, param_name, empty_param, - keep_in_fp32_modules, + keep_in_fp32_regex, hf_quantizer, ) @@ -1284,7 +1284,7 @@ def _get_device_map( max_memory: Optional[Dict], hf_quantizer: Optional[HfQuantizer], torch_dtype: Optional[torch.dtype], - keep_in_fp32_modules: Optional[List[str]], + keep_in_fp32_regex: Optional[re.Pattern], ) -> Dict: """Compute the final `device_map` to use if we passed a value in ['auto', 'balanced', 'balanced_low_0', 'sequential']. Otherwise, we check for any device inconsistencies in the device_map. @@ -1293,13 +1293,9 @@ def _get_device_map( special_dtypes = {} if hf_quantizer is not None: special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype)) - if keep_in_fp32_modules is not None: + if keep_in_fp32_regex is not None: special_dtypes.update( - { - name: torch.float32 - for name, _ in model.named_parameters() - if any(m in name for m in keep_in_fp32_modules) - } + {name: torch.float32 for name, _ in model.named_parameters() if keep_in_fp32_regex.search(name)} ) target_dtype = torch_dtype @@ -1911,6 +1907,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix self.init_weights() self._backward_compatibility_gradient_checkpointing() + # Make sure the modules correctly exist if the flag is active + if self._keep_in_fp32_modules is not None: + all_parameters = {name for name, _ in self.named_parameters() if len(name) > 0} + unique_module_names = set() + # Get all unique module names in the module graph, without the prefixes + for param in all_parameters: + unique_module_names.update( + [name for name in param.split(".") if not name.isnumeric() and name not in ["weight", "bias"]] + ) + # Check that every module in the keep_in_fp32 list is part of the module graph + for module in self._keep_in_fp32_modules: + if module not in unique_module_names: + raise ValueError( + f"{module} was specified in the `_keep_in_fp32_modules` list, but is not part of the modules in" + f" {self.__class__.__name__}" + ) + # If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config if self.base_model is self: self._pp_plan = ( @@ -4412,15 +4425,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix config = model.config # Find fp32 modules if needed - keep_in_fp32_modules = None - if model._keep_in_fp32_modules is not None: - if is_accelerate_available() and not is_deepspeed_zero3_enabled(): - low_cpu_mem_usage = True - keep_in_fp32_modules = model._keep_in_fp32_modules if len(model._keep_in_fp32_modules) > 0 else None + keep_in_fp32_regex = None + # The _keep_in_fp32_modules flag is only used to avoid bf16 -> fp16 casting precision issues. It was introduced + # in case of force loading a model that should stay bf16 in fp16 (which includes a few quantizers as this is a pre-processing + # step for e.g. bitsandbytes). See https://github.com/huggingface/transformers/issues/20287 for details. + if model._keep_in_fp32_modules is not None and ( + torch_dtype == torch.float16 or getattr(hf_quantizer, "use_keep_in_fp32_modules", False) + ): + # Only the path with `low_cpu_mem_usage` will check every param for the correct dtype + low_cpu_mem_usage = True + # We need to match exact layers, so we add either `.` on each side, or start/end of string + keep_in_fp32_regex = re.compile( + "|".join([rf"((^|\.){module}($|\.))" for module in model._keep_in_fp32_modules]) + ) if hf_quantizer is not None: hf_quantizer.preprocess_model( - model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules + model=model, device_map=device_map, keep_in_fp32_modules=model._keep_in_fp32_modules ) # We store the original dtype for quantized models as we cannot easily retrieve it @@ -4431,9 +4452,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Prepare the full device map if device_map is not None: - device_map = _get_device_map( - model, device_map, max_memory, hf_quantizer, torch_dtype, keep_in_fp32_modules - ) + device_map = _get_device_map(model, device_map, max_memory, hf_quantizer, torch_dtype, keep_in_fp32_regex) # Finalize model weight initialization if from_tf: @@ -4465,7 +4484,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix offload_state_dict=offload_state_dict, dtype=torch_dtype, hf_quantizer=hf_quantizer, - keep_in_fp32_modules=keep_in_fp32_modules, + keep_in_fp32_regex=keep_in_fp32_regex, device_mesh=device_mesh, key_mapping=key_mapping, weights_only=weights_only, @@ -4674,7 +4693,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix offload_state_dict: Optional[bool] = None, dtype: Optional[torch.dtype] = None, hf_quantizer: Optional[HfQuantizer] = None, - keep_in_fp32_modules: Optional[List[str]] = None, + keep_in_fp32_regex: Optional[re.Pattern] = None, device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None, key_mapping: Optional[Dict[str, str]] = None, weights_only: bool = True, @@ -4736,10 +4755,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix model._initialize_missing_keys(checkpoint_keys, ignore_mismatched_sizes, is_quantized) # Set some modules to fp32 if needed - if keep_in_fp32_modules is not None: - keep_in_fp32_modules = re.compile("|".join([re.escape(module) for module in keep_in_fp32_modules])) + if keep_in_fp32_regex is not None: for name, param in model.named_parameters(): - if keep_in_fp32_modules.search(name): + if keep_in_fp32_regex.search(name): # param = param.to(torch.float32) does not work here as only in the local scope. param.data = param.data.to(torch.float32) @@ -4894,7 +4912,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix cpu_offload_index=cpu_offload_index, hf_quantizer=hf_quantizer, is_safetensors=is_offloaded_safetensors, - keep_in_fp32_modules=keep_in_fp32_modules, + keep_in_fp32_regex=keep_in_fp32_regex, unexpected_keys=unexpected_keys, device_mesh=device_mesh, ) @@ -4951,7 +4969,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix } for name, param in parameters_to_initialize.items(): # First move data to correct - to_contiguous, casting_dtype = _infer_parameter_dtype(model, name, param, keep_in_fp32_modules) + to_contiguous, casting_dtype = _infer_parameter_dtype(model, name, param, keep_in_fp32_regex) shard_and_distribute_module( model, param.to(tp_device), diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 7a6475c152d..ab5a2a9abd6 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -419,7 +419,6 @@ class Blip2PreTrainedModel(PreTrainedModel): "OPTDecoderLayer", ] _skip_keys_device_placement = "past_key_values" - _keep_in_fp32_modules = ["query_tokens"] def _init_weights(self, module): """Initialize the weights""" @@ -1448,6 +1447,7 @@ class Blip2QFormerModel(Blip2PreTrainedModel): class Blip2Model(Blip2PreTrainedModel): config_class = Blip2Config main_input_name = "pixel_values" + _keep_in_fp32_modules = ["query_tokens"] def __init__(self, config: Blip2Config): super().__init__(config) @@ -2019,6 +2019,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): _supports_cache_class = True _supports_static_cache = True _supports_quantized_cache = False # not all LM bacbones support (e.g. T5) + _keep_in_fp32_modules = ["query_tokens"] def __init__(self, config: Blip2Config): super().__init__(config) diff --git a/tests/models/instructblip/test_modeling_instructblip.py b/tests/models/instructblip/test_modeling_instructblip.py index bccc8e230e7..e9d325460d5 100644 --- a/tests/models/instructblip/test_modeling_instructblip.py +++ b/tests/models/instructblip/test_modeling_instructblip.py @@ -791,15 +791,12 @@ class InstructBlipModelIntegrationTest(unittest.TestCase): num_beams=5, max_length=256, min_length=1, - top_p=0.9, repetition_penalty=1.5, length_penalty=1.0, temperature=1, ) generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0] - expected_outputs = [0, 37, 1023, 9850, 7, 3, 9, 388, 3575, 53, 4954, 30, 8, 223, 13, 3, 9, 4459, 4049, 16, 8, 2214, 13, 3, 9, 3164, 690, 2815, 5, 37, 388, 19, 5119, 3, 9, 4459, 8677, 28, 3, 9, 2756, 4459, 6177, 6, 11, 3, 88, 19, 338, 46, 3575, 53, 1476, 12, 743, 112, 2491, 5, 37, 1023, 19, 7225, 788, 12, 8, 685, 24, 34, 1267, 3, 9, 388, 3575, 53, 4954, 30, 8, 223, 13, 3, 9, 4049, 16, 8, 2214, 13, 3, 9, 3164, 690, 2815, 5, 94, 19, 487, 24, 8, 388, 19, 1119, 12, 1097, 540, 57, 692, 112, 10428, 30, 8, 223, 13, 8, 4049, 6, 68, 34, 19, 92, 487, 24, 3, 88, 19, 1119, 12, 1097, 97, 57, 692, 112, 10428, 30, 8, 223, 13, 8, 4049, 16, 8, 2214, 13, 3, 9, 3164, 690, 2815, 5, 3, 13865, 13, 8, 1053, 21, 8, 388, 31, 7, 2874, 6, 34, 19, 964, 24, 3, 88, 19, 1119, 12, 1097, 97, 57, 692, 112, 10428, 30, 8, 223, 13, 8, 4049, 16, 8, 2214, 13, 3, 9, 3164, 690, 2815, 5, 1] # fmt: skip - expected_outputs = [0, 37, 7225, 1023, 9850, 7, 3, 9, 388, 3575, 53, 4954, 30, 8, 223, 13, 3, 9, 4459, 4049, 16, 8, 2214, 13, 3, 9, 3164, 690, 2815, 5, 37, 388, 19, 5119, 3, 9, 4459, 8677, 28, 46, 3575, 53, 1476, 5223, 12, 34, 6, 15495, 24, 3, 88, 19, 692, 112, 293, 10428, 44, 234, 1066, 145, 338, 3, 9, 50, 1106, 3522, 144, 42, 2192, 7919, 31, 7, 5, 37, 1023, 92, 1267, 3, 9, 381, 13, 119, 3203, 16, 8, 2458, 6, 379, 14264, 6, 9256, 7, 6, 11, 11718, 7, 5, 1] # fmt: skip self.assertEqual(outputs[0].tolist(), expected_outputs)