mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[Regression] Fix Quark quantized model loading after refactorization (#37407)
This commit is contained in:
parent
a563999a02
commit
6cef03ba66
@ -649,7 +649,10 @@ def _infer_parameter_dtype(
|
||||
try:
|
||||
old_param = model.get_parameter_or_buffer(param_name)
|
||||
except Exception as e:
|
||||
if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ:
|
||||
if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method in {
|
||||
QuantizationMethod.HQQ,
|
||||
QuantizationMethod.QUARK,
|
||||
}:
|
||||
return True, None
|
||||
else:
|
||||
raise e
|
||||
@ -708,11 +711,12 @@ def _load_state_dict_into_meta_model(
|
||||
device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])
|
||||
|
||||
is_quantized = hf_quantizer is not None
|
||||
is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [
|
||||
is_hqq_or_bnb_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in {
|
||||
QuantizationMethod.HQQ,
|
||||
QuantizationMethod.BITS_AND_BYTES,
|
||||
]
|
||||
is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb
|
||||
QuantizationMethod.QUARK,
|
||||
}
|
||||
is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb_or_quark
|
||||
file_pointer = None
|
||||
if is_meta_state_dict:
|
||||
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
|
||||
@ -4632,11 +4636,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
):
|
||||
# Useful flags
|
||||
is_quantized = hf_quantizer is not None
|
||||
is_hqq = is_quantized and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ
|
||||
is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in [
|
||||
is_hqq_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in {
|
||||
QuantizationMethod.HQQ,
|
||||
QuantizationMethod.QUARK,
|
||||
}
|
||||
is_hqq_or_bnb_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in {
|
||||
QuantizationMethod.HQQ,
|
||||
QuantizationMethod.BITS_AND_BYTES,
|
||||
]
|
||||
QuantizationMethod.QUARK,
|
||||
}
|
||||
|
||||
# Get all the keys of the state dicts that we have to initialize the model
|
||||
if sharded_metadata is not None:
|
||||
@ -4798,7 +4806,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
expected_keys = hf_quantizer.update_expected_keys(model_to_load, expected_keys, checkpoint_keys)
|
||||
|
||||
# Warmup cuda to load the weights much faster on devices
|
||||
if device_map is not None and not is_hqq:
|
||||
if device_map is not None and not is_hqq_or_quark:
|
||||
expanded_device_map = expand_device_map(device_map, expected_keys)
|
||||
caching_allocator_warmup(model_to_load, expanded_device_map, factor=2 if hf_quantizer is None else 4)
|
||||
|
||||
@ -4812,7 +4820,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
map_location = "cpu"
|
||||
if (
|
||||
shard_file.endswith(".safetensors")
|
||||
and not is_hqq_or_bnb
|
||||
and not is_hqq_or_bnb_or_quark
|
||||
and not (is_deepspeed_zero3_enabled() and not is_quantized)
|
||||
):
|
||||
map_location = "meta"
|
||||
|
@ -53,6 +53,7 @@ class QuarkTest(unittest.TestCase):
|
||||
EXPECTED_OUTPUTS.add("Today I am in Paris and I am not in Paris, France\nToday I am in Paris, Illinois")
|
||||
EXPECTED_OUTPUTS.add("Today I am in Paris and I am enjoying the city of light. I am not just any ordinary Paris")
|
||||
EXPECTED_OUTPUTS.add("Today I am in Paris and I am enjoying my day off! The sun is shining, the birds are")
|
||||
EXPECTED_OUTPUTS.add("Today I am in Paris and I'm here to tell you about it. It's a beautiful day,")
|
||||
|
||||
EXPECTED_RELATIVE_DIFFERENCE = 1.66
|
||||
device_map = None
|
||||
|
Loading…
Reference in New Issue
Block a user