Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-01 02:31:11 +06:00
Add support for `weights_only` flag when loading state_dict (#32481)
* Add support for `weights_only` flag when loading state_dict

Summary: This is to enable loading a state_dict with wrapper tensor subclasses (used in torchao for quantized weights).

Test Plan: tested locally with torchao weights; also needs https://github.com/huggingface/transformers/pull/32306:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TorchAoConfig
from torchao.utils import benchmark_model
import torchao

DEVICE_TYPE = "cuda"

def init_model_and_benchmark(model_id, torch_dtype=torch.bfloat16, quantization_config=None):
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    if quantization_config is not None:
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map=DEVICE_TYPE, torch_dtype=torch.bfloat16, quantization_config=quantization_config)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map=DEVICE_TYPE, torch_dtype=torch.bfloat16, weights_only=False)

    # sanity check: run the model
    input_text = "What are we having for dinner?"
    input_ids = tokenizer(input_text, return_tensors="pt").to(DEVICE_TYPE)
    output = model.generate(**input_ids, max_new_tokens=1000)
    print(tokenizer.decode(output[0], skip_special_tokens=True))

    NUM_WARMUP = 1
    NUM_RUNS = 5

    if quantization_config is not None:
        torchao.quantization.utils.recommended_inductor_config_setter()
        model = torch.compile(model, mode="max-autotune")

    benchmark_model(model.generate, NUM_WARMUP, kwargs=input_ids, device_type=DEVICE_TYPE)
    print("running benchmark")
    results = benchmark_model(model.generate, NUM_RUNS, kwargs=input_ids, device_type=DEVICE_TYPE)
    return model, results

model_id = "jerryzh168/test-model"
torchao.quantization.utils.recommended_inductor_config_setter()
bf16_model, bf16_time = init_model_and_benchmark(model_id)
print(f"bf16: {bf16_time}")
```

Reviewers:

Subscribers:

Tasks:

Tags:

* format
This commit is contained in:
parent: a220c5b99f
commit: 15a4d24805
```diff
@@ -544,6 +544,7 @@ def load_state_dict(
     checkpoint_file: Union[str, os.PathLike],
     is_quantized: bool = False,
     map_location: Optional[Union[str, torch.device]] = None,
+    weights_only: bool = True,
 ):
     """
     Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
@@ -580,7 +581,7 @@ def load_state_dict(
             and is_zipfile(checkpoint_file)
         ):
             extra_args = {"mmap": True}
-        weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
+        weights_only_kwarg = {"weights_only": weights_only} if is_torch_greater_or_equal_than_1_13 else {}
         return torch.load(
             checkpoint_file,
             map_location=map_location,
```
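The core of the change is the second hunk: the hard-coded `weights_only=True` becomes a pass-through of the new parameter, still gated on the PyTorch version that introduced the kwarg. A minimal standalone sketch of that pattern; the helper name and version check are illustrative, not the library's:

```python
import torch
from packaging import version

# torch.load only accepts `weights_only` from PyTorch 1.13 onward.
_TORCH_GE_1_13 = version.parse(torch.__version__.split("+")[0]) >= version.parse("1.13")

def load_checkpoint(path, map_location="cpu", weights_only=True):
    # On older torch, drop the kwarg entirely instead of passing an unknown argument.
    weights_only_kwarg = {"weights_only": weights_only} if _TORCH_GE_1_13 else {}
    return torch.load(path, map_location=map_location, **weights_only_kwarg)

torch.save({"w": torch.zeros(2)}, "ckpt.pt")
print(load_checkpoint("ckpt.pt"))  # {'w': tensor([0., 0.])}
```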
```diff
@@ -3009,6 +3010,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         token: Optional[Union[str, bool]] = None,
         revision: str = "main",
         use_safetensors: bool = None,
+        weights_only: bool = True,
         **kwargs,
     ) -> "PreTrainedModel":
         r"""
```
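With the new keyword on `from_pretrained`, callers can opt out of the restricted unpickler. A usage sketch based on the PR's test plan; the checkpoint id comes from the test plan above, and any checkpoint containing wrapper tensor subclass weights would do:

```python
from transformers import AutoModelForCausalLM

# weights_only=False lets torch.load unpickle torchao wrapper tensor subclasses.
# Only do this for checkpoints you trust, since it re-enables full pickle.
model = AutoModelForCausalLM.from_pretrained(
    "jerryzh168/test-model",
    device_map="cuda",
    weights_only=False,
)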
```diff
@@ -3196,6 +3198,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors`
                 is not installed, it will be set to `False`.
 
+            weights_only (`bool`, *optional*, defaults to `True`):
+                Indicates whether unpickler should be restricted to loading only tensors, primitive types,
+                dictionaries and any types added via torch.serialization.add_safe_globals().
+                When set to False, we can load wrapper tensor subclass weights.
+
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                 `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
```
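For reference, here is what the two settings mean at the `torch.load` level. A hedged sketch: the file name is a toy checkpoint created in place, and `MyWrapperTensor` is a hypothetical stand-in for a torchao wrapper tensor subclass:

```python
import torch

# Create a toy checkpoint so the sketch is self-contained.
torch.save({"weight": torch.ones(2, 2)}, "checkpoint.bin")

# weights_only=True (the default being added here): the unpickler only accepts
# tensors, primitive containers, and types explicitly allowlisted via
# torch.serialization.add_safe_globals().
state = torch.load("checkpoint.bin", weights_only=True)

# To stay in safe mode with a custom wrapper tensor subclass, allowlist it first
# (available on recent PyTorch releases), e.g.:
#   torch.serialization.add_safe_globals([MyWrapperTensor])

# weights_only=False re-enables full pickle, which wrapper tensor subclass
# checkpoints needed before such allowlisting; only use it on trusted files.
state = torch.load("checkpoint.bin", weights_only=False)
```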
```diff
@@ -3831,7 +3838,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if from_pt:
             if not is_sharded and state_dict is None:
                 # Time to load the checkpoint
-                state_dict = load_state_dict(resolved_archive_file)
+                state_dict = load_state_dict(resolved_archive_file, weights_only=weights_only)
 
             # set dtype to instantiate the model under:
             # 1. If torch_dtype is not None, we use that dtype
```
```diff
@@ -3852,7 +3859,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                             elif not is_sharded:
                                 torch_dtype = get_state_dict_dtype(state_dict)
                             else:
-                                one_state_dict = load_state_dict(resolved_archive_file[0])
+                                one_state_dict = load_state_dict(resolved_archive_file[0], weights_only=weights_only)
                                 torch_dtype = get_state_dict_dtype(one_state_dict)
                                 del one_state_dict  # free CPU memory
                                 logger.info(
```
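The hunk above loads only the first shard to auto-detect `torch_dtype`. `get_state_dict_dtype` is the library's helper; a hedged standalone sketch of its first-floating-dtype rule, for illustration:

```python
import torch

def infer_state_dict_dtype(state_dict):
    # Return the first floating-point dtype found; integer/quantized tensors are
    # skipped so they don't masquerade as the model's compute dtype.
    for tensor in state_dict.values():
        if tensor.is_floating_point():
            return tensor.dtype
    # No floating tensor at all: fall back to the first tensor's dtype.
    return next(iter(state_dict.values())).dtype

print(infer_state_dict_dtype({"w": torch.ones(2, dtype=torch.int8), "b": torch.ones(2)}))  # torch.float32
```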
```diff
@@ -4052,6 +4059,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 hf_quantizer=hf_quantizer,
                 keep_in_fp32_modules=keep_in_fp32_modules,
                 gguf_path=gguf_path,
+                weights_only=weights_only,
             )
 
         # make sure token embedding weights are still tied if needed
```
```diff
@@ -4157,6 +4165,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         hf_quantizer=None,
         keep_in_fp32_modules=None,
         gguf_path=None,
+        weights_only=True,
     ):
         is_safetensors = False
         is_quantized = hf_quantizer is not None
```
```diff
@@ -4514,7 +4523,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                     and hf_quantizer.quantization_config.quant_type == "int4_weight_only"
                 ):
                     map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
-                state_dict = load_state_dict(shard_file, is_quantized=is_quantized, map_location=map_location)
+                state_dict = load_state_dict(
+                    shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
+                )
 
                 # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
                 # matching the weights in the model.
```
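The `int4_weight_only` branch picks a real accelerator out of the `device_map` so shards deserialize directly onto it. A small sketch of that selection; the `device_map` literal is illustrative:

```python
import torch

device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "lm_head": "cpu"}

# Filter out offload targets; torch.device(0) resolves to cuda:0, so the int4
# shard is unpickled straight onto the GPU instead of bouncing through CPU RAM.
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
print(map_location)  # cuda:0
```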
```diff
@@ -4667,6 +4678,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         start_prefix="",
         hf_quantizer=None,
         pretrained_model_name_or_path=None,
+        weights_only=True,
     ):
         """
         This is an experimental function that loads the model using ~1.x model size CPU memory
@@ -4687,7 +4699,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         """
 
         _move_model_to_meta(model, loaded_state_dict_keys, start_prefix)
-        state_dict = load_state_dict(resolved_archive_file)
+        state_dict = load_state_dict(resolved_archive_file, weights_only=weights_only)
         expected_keys = loaded_state_dict_keys  # plug for missing expected_keys. TODO: replace with proper keys
         error_msgs = _load_state_dict_into_meta_model(
             model,
```
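The low-memory path that this last hunk threads `weights_only` into relies on the meta-device pattern: build the module skeleton without allocating storage (`_move_model_to_meta`), then swap real tensors in from the checkpoint. A minimal sketch of that idea, not the library's exact mechanism:

```python
import torch
import torch.nn as nn

# Parameters created under the meta device have shapes but own no memory.
with torch.device("meta"):
    model = nn.Linear(4096, 4096)

# Loading with assign=True (torch >= 2.1) replaces the meta parameters with the
# checkpoint tensors, so peak CPU usage stays near one copy of the weights.
state_dict = {"weight": torch.randn(4096, 4096), "bias": torch.randn(4096)}
model.load_state_dict(state_dict, assign=True)
print(model.weight.device)  # cpu
```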