Add support for weights_only flag when loading state_dict (#32481)

* Add support for `weights_only` flag when loading state_dict

Summary:
This is to enable loading a state_dict with wrapper tensor subclasses (used in torchao for
quantized weights).
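
For context, a minimal sketch of what the flag controls at the `torch.load` level (the checkpoint path here is a placeholder, not from this PR):
```
import torch

# With weights_only=True, torch.load uses a restricted unpickler that only
# rebuilds tensors, primitive types, dicts, and types registered via
# torch.serialization.add_safe_globals(). A wrapper tensor subclass (e.g.
# torchao quantized weights) is rejected unless that restriction is lifted.
state_dict = torch.load("pytorch_model.bin", map_location="cpu", weights_only=False)
```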

Test Plan:
Tested locally with torchao weights; also needs https://github.com/huggingface/transformers/pull/32306:
```
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TorchAoConfig
from torchao.utils import benchmark_model
import torchao

DEVICE_TYPE = "cuda"

def init_model_and_benchmark(model_id, torch_dtype=torch.bfloat16, quantization_config=None):
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if quantization_config is not None:
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map=DEVICE_TYPE, torch_dtype=torch_dtype, quantization_config=quantization_config)
    else:
        # loading pre-quantized torchao weights requires weights_only=False
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map=DEVICE_TYPE, torch_dtype=torch_dtype, weights_only=False)

    # sanity check: run the model
    input_text = "What are we having for dinner?"
    input_ids = tokenizer(input_text, return_tensors="pt").to(DEVICE_TYPE)
    output = model.generate(**input_ids, max_new_tokens=1000)
    print(tokenizer.decode(output[0], skip_special_tokens=True))

    NUM_WARMUP = 1
    NUM_RUNS = 5

    if quantization_config is not None:
        torchao.quantization.utils.recommended_inductor_config_setter()

    model = torch.compile(model, mode="max-autotune")

    # warmup run(s) before the timed benchmark
    benchmark_model(model.generate, NUM_WARMUP, kwargs=input_ids, device_type=DEVICE_TYPE)
    print("running benchmark")
    results = benchmark_model(model.generate, NUM_RUNS, kwargs=input_ids, device_type=DEVICE_TYPE)
    return model, results

model_id = "jerryzh168/test-model"
torchao.quantization.utils.recommended_inductor_config_setter()
bf16_model, bf16_time = init_model_and_benchmark(model_id)
print(f"bf16: {bf16_time}")
```
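
As the new docstring notes, an alternative to `weights_only=False` is to allowlist the subclass via `torch.serialization.add_safe_globals()` (available in recent PyTorch releases). A rough sketch, where `MyWrapperTensor` and the checkpoint path are placeholders for the actual torchao subclass and file:
```
import torch

class MyWrapperTensor(torch.Tensor):  # placeholder for the real wrapper subclass
    pass

# Registering the subclass lets the restricted unpickler rebuild it, so
# weights_only=True can stay enabled.
torch.serialization.add_safe_globals([MyWrapperTensor])
state_dict = torch.load("quantized_checkpoint.bin", weights_only=True)
```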

* format

```
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -544,6 +544,7 @@ def load_state_dict(
     checkpoint_file: Union[str, os.PathLike],
     is_quantized: bool = False,
     map_location: Optional[Union[str, torch.device]] = None,
+    weights_only: bool = True,
 ):
     """
     Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
@@ -580,7 +581,7 @@ def load_state_dict(
             and is_zipfile(checkpoint_file)
         ):
             extra_args = {"mmap": True}
-        weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
+        weights_only_kwarg = {"weights_only": weights_only} if is_torch_greater_or_equal_than_1_13 else {}
         return torch.load(
             checkpoint_file,
             map_location=map_location,
@@ -3009,6 +3010,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         token: Optional[Union[str, bool]] = None,
         revision: str = "main",
         use_safetensors: bool = None,
+        weights_only: bool = True,
         **kwargs,
     ) -> "PreTrainedModel":
         r"""
@@ -3196,6 +3198,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors`
                 is not installed, it will be set to `False`.
+            weights_only (`bool`, *optional*, defaults to `True`):
+                Indicates whether unpickler should be restricted to loading only tensors, primitive types,
+                dictionaries and any types added via torch.serialization.add_safe_globals().
+                When set to False, we can load wrapper tensor subclass weights.
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                 `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
@@ -3831,7 +3838,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if from_pt:
             if not is_sharded and state_dict is None:
                 # Time to load the checkpoint
-                state_dict = load_state_dict(resolved_archive_file)
+                state_dict = load_state_dict(resolved_archive_file, weights_only=weights_only)
 
             # set dtype to instantiate the model under:
             # 1. If torch_dtype is not None, we use that dtype
@@ -3852,7 +3859,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 elif not is_sharded:
                     torch_dtype = get_state_dict_dtype(state_dict)
                 else:
-                    one_state_dict = load_state_dict(resolved_archive_file[0])
+                    one_state_dict = load_state_dict(resolved_archive_file[0], weights_only=weights_only)
                     torch_dtype = get_state_dict_dtype(one_state_dict)
                     del one_state_dict  # free CPU memory
                     logger.info(
@@ -4052,6 +4059,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 hf_quantizer=hf_quantizer,
                 keep_in_fp32_modules=keep_in_fp32_modules,
                 gguf_path=gguf_path,
+                weights_only=weights_only,
             )
 
         # make sure token embedding weights are still tied if needed
@@ -4157,6 +4165,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         hf_quantizer=None,
         keep_in_fp32_modules=None,
         gguf_path=None,
+        weights_only=True,
     ):
         is_safetensors = False
         is_quantized = hf_quantizer is not None
@@ -4514,7 +4523,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                     and hf_quantizer.quantization_config.quant_type == "int4_weight_only"
                 ):
                     map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
-                state_dict = load_state_dict(shard_file, is_quantized=is_quantized, map_location=map_location)
+                state_dict = load_state_dict(
+                    shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
+                )
 
             # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
             # matching the weights in the model.
@@ -4667,6 +4678,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         start_prefix="",
         hf_quantizer=None,
         pretrained_model_name_or_path=None,
+        weights_only=True,
     ):
         """
         This is an experimental function that loads the model using ~1.x model size CPU memory
@@ -4687,7 +4699,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         """
         _move_model_to_meta(model, loaded_state_dict_keys, start_prefix)
 
-        state_dict = load_state_dict(resolved_archive_file)
+        state_dict = load_state_dict(resolved_archive_file, weights_only=weights_only)
         expected_keys = loaded_state_dict_keys  # plug for missing expected_keys. TODO: replace with proper keys
         error_msgs = _load_state_dict_into_meta_model(
             model,
```
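
Taken together, the new flag threads from `from_pretrained` down to every `load_state_dict` call site. A minimal usage sketch (model id taken from the test plan above):
```
from transformers import AutoModelForCausalLM

# Pre-quantized torchao checkpoints serialize wrapper tensor subclasses, so the
# restricted unpickler must be relaxed explicitly at load time.
model = AutoModelForCausalLM.from_pretrained(
    "jerryzh168/test-model",
    device_map="cuda",
    weights_only=False,
)
```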