Fix casting dtype for qunatization (#36799)

* fix

* remove print
This commit is contained in:
Marc Sun 2025-03-18 18:46:03 +01:00 committed by GitHub
parent 30580f035b
commit 14b597f518
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -732,6 +732,8 @@ def _infer_parameter_dtype(
if keep_in_fp32_modules is not None and keep_in_fp32_modules.search(param_name):
casting_dtype = torch.float32
# Then dtype that was instantiated in the meta model -- note that this respects subconfigs dtypes
elif hf_quantizer is not None:
casting_dtype = model.config._pre_quantization_dtype
else:
casting_dtype = old_param.dtype
return old_param is not None and old_param.is_contiguous(), casting_dtype
@ -754,7 +756,6 @@ def _load_state_dict_into_meta_model(
keep_in_fp32_modules: Optional[List[str]] = None,
unexpected_keys: Optional[List[str]] = None, # passing `unexpected` for cleanup from quantization items
device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
weights_only=True,
) -> Tuple[Optional[Dict], Optional[Dict]]:
"""Load parameters from `meta_state_dict` into the model. The parameters of the `meta_state_dict` are on the meta
device in order to easily infer the shapes and dtypes that they will have. Then proper parameters are then loaded
@ -4883,7 +4884,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
device_mesh=device_mesh,
weights_only=weights_only,
)
else:
assign_params = check_support_param_buffer_assignment(model_to_load, state_dict)