[Regression] Fix Quark quantized model loading after refactorization (#37407)

This commit is contained in:
Bowen Bao 2025-04-11 04:43:36 -07:00 committed by GitHub
parent a563999a02
commit 6cef03ba66
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 18 additions and 9 deletions

View File

@ -649,7 +649,10 @@ def _infer_parameter_dtype(
try:
old_param = model.get_parameter_or_buffer(param_name)
except Exception as e:
if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ:
if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method in {
QuantizationMethod.HQQ,
QuantizationMethod.QUARK,
}:
return True, None
else:
raise e
@ -708,11 +711,12 @@ def _load_state_dict_into_meta_model(
device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])
is_quantized = hf_quantizer is not None
is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [
is_hqq_or_bnb_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in {
QuantizationMethod.HQQ,
QuantizationMethod.BITS_AND_BYTES,
]
is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb
QuantizationMethod.QUARK,
}
is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb_or_quark
file_pointer = None
if is_meta_state_dict:
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
@ -4632,11 +4636,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
):
# Useful flags
is_quantized = hf_quantizer is not None
is_hqq = is_quantized and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ
is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [
is_hqq_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in {
QuantizationMethod.HQQ,
QuantizationMethod.QUARK,
}
is_hqq_or_bnb_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in {
QuantizationMethod.HQQ,
QuantizationMethod.BITS_AND_BYTES,
]
QuantizationMethod.QUARK,
}
# Get all the keys of the state dicts that we have to initialize the model
if sharded_metadata is not None:
@ -4798,7 +4806,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
expected_keys = hf_quantizer.update_expected_keys(model_to_load, expected_keys, checkpoint_keys)
# Warmup cuda to load the weights much faster on devices
if device_map is not None and not is_hqq:
if device_map is not None and not is_hqq_or_quark:
expanded_device_map = expand_device_map(device_map, expected_keys)
caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4)
@ -4812,7 +4820,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
map_location = "cpu"
if (
shard_file.endswith(".safetensors")
and not is_hqq_or_bnb
and not is_hqq_or_bnb_or_quark
and not (is_deepspeed_zero3_enabled() and not is_quantized)
):
map_location = "meta"

View File

@ -53,6 +53,7 @@ class QuarkTest(unittest.TestCase):
EXPECTED_OUTPUTS.add("Today I am in Paris and I am not in Paris, France\nToday I am in Paris, Illinois")
EXPECTED_OUTPUTS.add("Today I am in Paris and I am enjoying the city of light. I am not just any ordinary Paris")
EXPECTED_OUTPUTS.add("Today I am in Paris and I am enjoying my day off! The sun is shining, the birds are")
EXPECTED_OUTPUTS.add("Today I am in Paris and I'm here to tell you about it. It's a beautiful day,")
EXPECTED_RELATIVE_DIFFERENCE = 1.66
device_map = None