From 14b597f51837284f92c1753b2332e05d959bab1d Mon Sep 17 00:00:00 2001
From: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Date: Tue, 18 Mar 2025 18:46:03 +0100
Subject: [PATCH] Fix casting dtype for quantization (#36799)

* fix

* remove print
---
 src/transformers/modeling_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index d4c9815bd34..4158c82b409 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -732,6 +732,8 @@ def _infer_parameter_dtype(
         if keep_in_fp32_modules is not None and keep_in_fp32_modules.search(param_name):
             casting_dtype = torch.float32
         # Then dtype that was instantiated in the meta model -- note that this respects subconfigs dtypes
+        elif hf_quantizer is not None:
+            casting_dtype = model.config._pre_quantization_dtype
         else:
             casting_dtype = old_param.dtype
     return old_param is not None and old_param.is_contiguous(), casting_dtype
@@ -754,7 +756,6 @@ def _load_state_dict_into_meta_model(
     keep_in_fp32_modules: Optional[List[str]] = None,
     unexpected_keys: Optional[List[str]] = None,  # passing `unexpected` for cleanup from quantization items
     device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
-    weights_only=True,
 ) -> Tuple[Optional[Dict], Optional[Dict]]:
     """Load parameters from `meta_state_dict` into the model. The parameters of the `meta_state_dict` are on the meta
     device in order to easily infer the shapes and dtypes that they will have. Then proper parameters are then loaded
@@ -4883,7 +4884,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 keep_in_fp32_modules=keep_in_fp32_modules,
                 unexpected_keys=unexpected_keys,
                 device_mesh=device_mesh,
-                weights_only=weights_only,
             )
         else:
             assign_params = check_support_param_buffer_assignment(model_to_load, state_dict)
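
For context, the first hunk changes how `_infer_parameter_dtype` picks the dtype a
checkpoint parameter is cast to during loading: when a quantizer is attached, the
patch casts to the dtype the model had before quantization
(`model.config._pre_quantization_dtype`) rather than to the dtype of the meta-model
parameter, which may already reflect a quantized representation. Below is a minimal,
self-contained sketch of that selection logic. `infer_casting_dtype` and its
flattened arguments (`old_param_dtype`, `pre_quantization_dtype`, `has_quantizer`)
are hypothetical stand-ins: the internal function reads these values off the
`model`, `old_param`, and `hf_quantizer` objects directly, and only applies this
logic to floating-point parameters, which the sketch omits.

import re
from typing import Optional

import torch


def infer_casting_dtype(
    param_name: str,
    old_param_dtype: torch.dtype,
    pre_quantization_dtype: Optional[torch.dtype] = None,
    keep_in_fp32_modules: Optional[re.Pattern] = None,
    has_quantizer: bool = False,
) -> Optional[torch.dtype]:
    # First fp32, if the parameter is on the keep-in-fp32 exception list.
    if keep_in_fp32_modules is not None and keep_in_fp32_modules.search(param_name):
        return torch.float32
    # After the patch: with a quantizer present, cast to the dtype the weights
    # had before quantization (the quantizer records it on the model config),
    # not to the dtype of the meta-model parameter.
    if has_quantizer:
        return pre_quantization_dtype
    # Otherwise keep the dtype that was instantiated in the meta model.
    return old_param_dtype


# Usage sketch (hypothetical parameter names and dtypes):
fp32_regex = re.compile(r"layernorm")
print(infer_casting_dtype("decoder.layernorm.weight", torch.float16,
                          keep_in_fp32_modules=fp32_regex))
# -> torch.float32
print(infer_casting_dtype("decoder.q_proj.weight", torch.int8, torch.bfloat16,
                          fp32_regex, has_quantizer=True))
# -> torch.bfloat16 (before the patch, the fallback returned torch.int8 here)
print(infer_casting_dtype("decoder.k_proj.weight", torch.float16,
                          keep_in_fp32_modules=fp32_regex))
# -> torch.float16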