Mirror of https://github.com/huggingface/transformers.git
Add support for weights_only flag when loading state_dict (#32481)

* Add support for `weights_only` flag when loading state_dict

Summary:
This is to enable loading a state_dict with wrapper tensor subclasses (used in torchao for quantized weights).

Test Plan:
Tested locally with torchao weights; also needs https://github.com/huggingface/transformers/pull/32306:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TorchAoConfig
from torchao.utils import benchmark_model
import torchao

DEVICE_TYPE = "cuda"

def init_model_and_benchmark(model_id, torch_dtype=torch.bfloat16, quantization_config=None):
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    if quantization_config is not None:
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map=DEVICE_TYPE, torch_dtype=torch.bfloat16, quantization_config=quantization_config)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map=DEVICE_TYPE, torch_dtype=torch.bfloat16, weights_only=False)

    # sanity check: run the model
    input_text = "What are we having for dinner?"
    input_ids = tokenizer(input_text, return_tensors="pt").to(DEVICE_TYPE)
    output = model.generate(**input_ids, max_new_tokens=1000)
    print(tokenizer.decode(output[0], skip_special_tokens=True))

    NUM_WARMUP = 1
    NUM_RUNS = 5

    if quantization_config is not None:
        torchao.quantization.utils.recommended_inductor_config_setter()
        model = torch.compile(model, mode="max-autotune")

    benchmark_model(model.generate, NUM_WARMUP, kwargs=input_ids, device_type=DEVICE_TYPE)
    print("running benchmark")
    results = benchmark_model(model.generate, NUM_RUNS, kwargs=input_ids, device_type=DEVICE_TYPE)
    return model, results

model_id = "jerryzh168/test-model"
torchao.quantization.utils.recommended_inductor_config_setter()
bf16_model, bf16_time = init_model_and_benchmark(model_id)
print(f"bf16: {bf16_time}")
```

Reviewers:

Subscribers:

Tasks:

Tags:

* format
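For quick orientation, here is a minimal sketch of the flag this commit adds. The model id is a placeholder; any PyTorch `.bin` checkpoint whose weights are wrapper tensor subclasses (e.g. torchao-quantized) would need `weights_only=False`:

```python
import torch
from transformers import AutoModelForCausalLM

# Placeholder model id; substitute a checkpoint saved with torchao-quantized
# weights. weights_only=False lifts torch.load's safe-unpickling restriction
# so wrapper tensor subclasses can be deserialized.
model = AutoModelForCausalLM.from_pretrained(
    "your-org/your-torchao-model",  # hypothetical id
    device_map="cuda",
    torch_dtype=torch.bfloat16,
    weights_only=False,
)
```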
This commit is contained in:
  parent a220c5b99f
  commit 15a4d24805
The affected file is `src/transformers/modeling_utils.py`.

```diff
@@ -544,6 +544,7 @@ def load_state_dict(
     checkpoint_file: Union[str, os.PathLike],
     is_quantized: bool = False,
     map_location: Optional[Union[str, torch.device]] = None,
+    weights_only: bool = True,
 ):
     """
     Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
```
```diff
@@ -580,7 +581,7 @@ def load_state_dict(
         and is_zipfile(checkpoint_file)
     ):
         extra_args = {"mmap": True}
-    weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
+    weights_only_kwarg = {"weights_only": weights_only} if is_torch_greater_or_equal_than_1_13 else {}
     return torch.load(
         checkpoint_file,
         map_location=map_location,
```
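This hunk threads the caller-supplied flag into `torch.load`, but only on PyTorch versions that accept the kwarg (`torch.load` gained `weights_only` in PyTorch 1.13). A standalone sketch of the same gating pattern, with an illustrative helper name and version check standing in for the library's internals:

```python
import torch
from packaging import version

# torch.load only accepts weights_only on PyTorch >= 1.13, so build the kwarg
# dict conditionally, mirroring the weights_only_kwarg logic in the diff.
_SUPPORTS_WEIGHTS_ONLY = version.parse(torch.__version__) >= version.parse("1.13")

def load_checkpoint(checkpoint_file, map_location="cpu", weights_only=True):
    extra_args = {"weights_only": weights_only} if _SUPPORTS_WEIGHTS_ONLY else {}
    return torch.load(checkpoint_file, map_location=map_location, **extra_args)
```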
```diff
@@ -3009,6 +3010,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         token: Optional[Union[str, bool]] = None,
         revision: str = "main",
         use_safetensors: bool = None,
+        weights_only: bool = True,
         **kwargs,
     ) -> "PreTrainedModel":
         r"""
```
```diff
@@ -3196,6 +3198,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors`
                 is not installed, it will be set to `False`.
+
+            weights_only (`bool`, *optional*, defaults to `True`):
+                Indicates whether unpickler should be restricted to loading only tensors, primitive types,
+                dictionaries and any types added via torch.serialization.add_safe_globals().
+                When set to False, we can load wrapper tensor subclass weights.
 
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                 `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
```
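The new docstring points at `torch.serialization.add_safe_globals()` as the way to keep `weights_only=True` while still loading extra types. A sketch of that safer alternative, assuming PyTorch >= 2.4 (where `add_safe_globals` was added); the tensor subclass and file path are stand-ins:

```python
import torch

# Stand-in for a real wrapper tensor subclass such as a torchao quantized
# tensor type; in practice you would import the actual class.
class MyQuantTensor(torch.Tensor):
    pass

# Save a checkpoint containing the subclass, then allowlist the class for the
# restricted unpickler so the default weights_only=True still works.
torch.save({"w": torch.randn(2).as_subclass(MyQuantTensor)}, "checkpoint.bin")
torch.serialization.add_safe_globals([MyQuantTensor])
state_dict = torch.load("checkpoint.bin", weights_only=True)
```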
```diff
@@ -3831,7 +3838,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if from_pt:
             if not is_sharded and state_dict is None:
                 # Time to load the checkpoint
-                state_dict = load_state_dict(resolved_archive_file)
+                state_dict = load_state_dict(resolved_archive_file, weights_only=weights_only)
 
             # set dtype to instantiate the model under:
             # 1. If torch_dtype is not None, we use that dtype
```
```diff
@@ -3852,7 +3859,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 elif not is_sharded:
                     torch_dtype = get_state_dict_dtype(state_dict)
                 else:
-                    one_state_dict = load_state_dict(resolved_archive_file[0])
+                    one_state_dict = load_state_dict(resolved_archive_file[0], weights_only=weights_only)
                     torch_dtype = get_state_dict_dtype(one_state_dict)
                     del one_state_dict  # free CPU memory
                 logger.info(
```
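For sharded checkpoints, the hunk above probes only the first shard to infer `torch_dtype`, then frees it immediately. A rough illustration of that probe; `get_state_dict_dtype_sketch` is a re-sketch, not the library's exact implementation:

```python
import torch

def get_state_dict_dtype_sketch(state_dict):
    # Return the first floating-point dtype found, falling back to the
    # dtype of the first tensor; roughly what the dtype probe relies on.
    for tensor in state_dict.values():
        if tensor.is_floating_point():
            return tensor.dtype
    return next(iter(state_dict.values())).dtype

# Probe one shard, record the dtype, and release the memory right away,
# as the else-branch does for resolved_archive_file[0].
one_state_dict = {"weight": torch.randn(2, 2, dtype=torch.bfloat16)}  # toy shard
torch_dtype = get_state_dict_dtype_sketch(one_state_dict)
del one_state_dict  # free CPU memory
```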
```diff
@@ -4052,6 +4059,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 hf_quantizer=hf_quantizer,
                 keep_in_fp32_modules=keep_in_fp32_modules,
                 gguf_path=gguf_path,
+                weights_only=weights_only,
             )
 
             # make sure token embedding weights are still tied if needed
```
```diff
@@ -4157,6 +4165,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         hf_quantizer=None,
         keep_in_fp32_modules=None,
         gguf_path=None,
+        weights_only=True,
     ):
         is_safetensors = False
         is_quantized = hf_quantizer is not None
```
```diff
@@ -4514,7 +4523,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                     and hf_quantizer.quantization_config.quant_type == "int4_weight_only"
                 ):
                     map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
-                state_dict = load_state_dict(shard_file, is_quantized=is_quantized, map_location=map_location)
+                state_dict = load_state_dict(
+                    shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
+                )
 
                 # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
                 # matching the weights in the model.
```
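The `map_location` line kept as context above picks the first accelerator entry out of the `device_map`, skipping CPU and disk-offloaded modules. A toy illustration of that selection; the `device_map` contents here are made up:

```python
import torch

# Example device_map as produced by accelerate-style dispatch; values may be
# device indices or the strings "cpu"/"disk" for offloaded modules.
device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "lm_head": "cpu"}

# Same expression as in the diff: first value that is neither "cpu" nor "disk".
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
print(map_location)  # cuda:0 on a CUDA build
```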
```diff
@@ -4667,6 +4678,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         start_prefix="",
         hf_quantizer=None,
         pretrained_model_name_or_path=None,
+        weights_only=True,
     ):
         """
         This is an experimental function that loads the model using ~1.x model size CPU memory
```
```diff
@@ -4687,7 +4699,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         """
 
         _move_model_to_meta(model, loaded_state_dict_keys, start_prefix)
-        state_dict = load_state_dict(resolved_archive_file)
+        state_dict = load_state_dict(resolved_archive_file, weights_only=weights_only)
         expected_keys = loaded_state_dict_keys  # plug for missing expected_keys. TODO: replace with proper keys
         error_msgs = _load_state_dict_into_meta_model(
             model,
```
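This last hunk sits inside `_load_pretrained_model_low_mem`, which moves the model to the meta device before reading the checkpoint so peak memory stays near one model's size. A self-contained sketch of that general pattern, using a toy module rather than the transformers internals:

```python
import torch
import torch.nn as nn

# Build the module on the meta device: parameters have shapes but no storage.
with torch.device("meta"):
    model = nn.Linear(1024, 1024)  # toy stand-in for the real architecture

# Load real tensors from a (toy) checkpoint; assign=True swaps the meta
# parameters for the loaded tensors instead of copying into them.
state_dict = {"weight": torch.randn(1024, 1024), "bias": torch.randn(1024)}
model.load_state_dict(state_dict, assign=True)
```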