Add support for weights_only flag when loading state_dict (#32481)

* Add support for `weights_only` flag when loading state_dict

Summary:
This enables loading a state_dict with wrapper tensor subclasses (used in torchao for
quantized weights).
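
For context, `torch.load(..., weights_only=True)` refuses to unpickle classes that are not on its allowlist, which is exactly what happens with torchao's wrapper tensor subclasses. Below is a minimal sketch of the failure and the two ways out, using a hypothetical `MyTensorSubclass` as a stand-in (the torchao classes behave analogously; `add_safe_globals` requires torch >= 2.4):
```
import torch

class MyTensorSubclass(torch.Tensor):
    pass  # stand-in for a torchao wrapper tensor subclass

torch.save({"w": torch.randn(2).as_subclass(MyTensorSubclass)}, "ckpt.pt")

try:
    torch.load("ckpt.pt", weights_only=True)  # rejected: class not allowlisted
except Exception as e:
    print(type(e).__name__)

# option 1 (torch >= 2.4): allowlist the subclass, keep the restricted unpickler
torch.serialization.add_safe_globals([MyTensorSubclass])
state_dict = torch.load("ckpt.pt", weights_only=True)

# option 2: opt out of the restriction, which is what this PR exposes
state_dict = torch.load("ckpt.pt", weights_only=False)
```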

Test Plan:
Tested locally with torchao weights; also requires https://github.com/huggingface/transformers/pull/32306:
```
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TorchAoConfig
from torchao.utils import benchmark_model
import torchao

DEVICE_TYPE = "cuda"

def init_model_and_benchmark(model_id, torch_dtype=torch.bfloat16, quantization_config=None):
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if quantization_config is not None:
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map=DEVICE_TYPE, torch_dtype=torch.bfloat16, quantization_config=quantization_config)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map=DEVICE_TYPE, torch_dtype=torch.bfloat16, weights_only=False)

    # sanity check: run the model
    input_text = "What are we having for dinner?"
    input_ids = tokenizer(input_text, return_tensors="pt").to(DEVICE_TYPE)
    output = model.generate(**input_ids, max_new_tokens=1000)
    print(tokenizer.decode(output[0], skip_special_tokens=True))

    NUM_WARMUP = 1
    NUM_RUNS = 5

    if quantization_config is not None:
        torchao.quantization.utils.recommended_inductor_config_setter()

    model = torch.compile(model, mode="max-autotune")

    # one warmup pass before the measured runs
    benchmark_model(model.generate, NUM_WARMUP, kwargs=input_ids, device_type=DEVICE_TYPE)
    print("running benchmark")
    results = benchmark_model(model.generate, NUM_RUNS, kwargs=input_ids, device_type=DEVICE_TYPE)
    return model, results

model_id = "jerryzh168/test-model"
torchao.quantization.utils.recommended_inductor_config_setter()
bf16_model, bf16_time = init_model_and_benchmark(model_id)
print(f"bf16: {bf16_time}")
```
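
For reference, a checkpoint like the one tested above can presumably be produced along these lines; the base model id is illustrative, and `safe_serialization=False` is needed because safetensors cannot represent wrapper tensor subclasses:
```
import torch
from transformers import AutoModelForCausalLM, TorchAoConfig

# quantize the weights into int4 wrapper tensor subclasses on load
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    device_map="cuda",
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
)

# pickle-based serialization keeps the subclass objects intact
model.save_pretrained("test-model", safe_serialization=False)
```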

* format
Jerry Zhang, 2024-10-03 08:03:42 -07:00, committed by GitHub
parent a220c5b99f
commit 15a4d24805

@@ -544,6 +544,7 @@ def load_state_dict(
checkpoint_file: Union[str, os.PathLike],
is_quantized: bool = False,
map_location: Optional[Union[str, torch.device]] = None,
weights_only: bool = True,
):
"""
Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
@@ -580,7 +581,7 @@ def load_state_dict(
and is_zipfile(checkpoint_file)
):
extra_args = {"mmap": True}
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
weights_only_kwarg = {"weights_only": weights_only} if is_torch_greater_or_equal_than_1_13 else {}
return torch.load(
checkpoint_file,
map_location=map_location,
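
(Aside: the `weights_only_kwarg` indirection above exists because `torch.load` only gained the `weights_only` parameter in PyTorch 1.13. A standalone sketch of the same version-gating pattern, with names simplified from the ones in `modeling_utils`:)
```
import torch
from packaging import version

# torch.load only accepts weights_only on PyTorch >= 1.13
_IS_TORCH_1_13_PLUS = version.parse(torch.__version__) >= version.parse("1.13")

def versioned_torch_load(checkpoint_file, map_location="cpu", weights_only=True):
    kwargs = {"weights_only": weights_only} if _IS_TORCH_1_13_PLUS else {}
    return torch.load(checkpoint_file, map_location=map_location, **kwargs)
```
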
@@ -3009,6 +3010,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
token: Optional[Union[str, bool]] = None,
revision: str = "main",
use_safetensors: bool = None,
weights_only: bool = True,
**kwargs,
) -> "PreTrainedModel":
r"""
@@ -3196,6 +3198,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors`
is not installed, it will be set to `False`.
weights_only (`bool`, *optional*, defaults to `True`):
Indicates whether the unpickler should be restricted to loading only tensors, primitive types,
dictionaries, and any types added via `torch.serialization.add_safe_globals()`.
When set to `False`, wrapper tensor subclass weights can also be loaded.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or
@@ -3831,7 +3838,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
if from_pt:
if not is_sharded and state_dict is None:
# Time to load the checkpoint
state_dict = load_state_dict(resolved_archive_file)
state_dict = load_state_dict(resolved_archive_file, weights_only=weights_only)
# set dtype to instantiate the model under:
# 1. If torch_dtype is not None, we use that dtype
@@ -3852,7 +3859,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
elif not is_sharded:
torch_dtype = get_state_dict_dtype(state_dict)
else:
one_state_dict = load_state_dict(resolved_archive_file[0])
one_state_dict = load_state_dict(resolved_archive_file[0], weights_only=weights_only)
torch_dtype = get_state_dict_dtype(one_state_dict)
del one_state_dict # free CPU memory
logger.info(
@@ -4052,6 +4059,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
gguf_path=gguf_path,
weights_only=weights_only,
)
# make sure token embedding weights are still tied if needed
@@ -4157,6 +4165,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
hf_quantizer=None,
keep_in_fp32_modules=None,
gguf_path=None,
weights_only=True,
):
is_safetensors = False
is_quantized = hf_quantizer is not None
@@ -4514,7 +4523,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
and hf_quantizer.quantization_config.quant_type == "int4_weight_only"
):
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
state_dict = load_state_dict(shard_file, is_quantized=is_quantized, map_location=map_location)
state_dict = load_state_dict(
shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
)
# Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model.
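
(Aside: the `map_location` selection above just picks the first accelerator entry out of the `device_map`, since torchao's packed `int4_weight_only` tensors cannot be materialized on CPU. A standalone sketch with a hypothetical `device_map`:)
```
import torch

# hypothetical device_map mixing a GPU with CPU offload
device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "lm_head": "cpu"}

# first entry that is neither "cpu" nor "disk" becomes the load target
accelerator = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]
map_location = torch.device(accelerator)
print(map_location)  # cuda:0, since an integer index resolves to a CUDA device
```
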
@@ -4667,6 +4678,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
start_prefix="",
hf_quantizer=None,
pretrained_model_name_or_path=None,
weights_only=True,
):
"""
This is an experimental function that loads the model using ~1.x model size CPU memory
@@ -4687,7 +4699,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
"""
_move_model_to_meta(model, loaded_state_dict_keys, start_prefix)
state_dict = load_state_dict(resolved_archive_file)
state_dict = load_state_dict(resolved_archive_file, weights_only=weights_only)
expected_keys = loaded_state_dict_keys # plug for missing expected_keys. TODO: replace with proper keys
error_msgs = _load_state_dict_into_meta_model(
model,
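
End to end, the new kwarg threads from `from_pretrained` down to every `load_state_dict` call site, so callers of the helper can opt out of the restricted unpickler directly as well (the checkpoint path here is illustrative):
```
from transformers.modeling_utils import load_state_dict

# pickle-based checkpoints holding wrapper tensor subclasses need
# the unrestricted unpickler, hence weights_only=False
state_dict = load_state_dict("pytorch_model.bin", weights_only=False)
```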