Checkpoint sharding (#16343)

* Sharded checkpoint support

* Handle distant sharded checkpoints

* Add tests

* TODO is done

* Apply suggestions from code review

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Fix docstring

* Add example and format

* Address review comments

* More review comments

* End of merge

* Revert unintentional change

* VsCode what did you do?

* Style

* Changes

* Address final comments

* Quality

* Moar tests

* Move import beneath is_pt_available

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Sylvain Gugger 2022-03-25 11:59:25 -04:00 committed by GitHub
parent 7fa7408b26
commit b473617d63
4 changed files with 710 additions and 142 deletions
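From a user's point of view, the feature boils down to one new argument: `save_pretrained(..., max_shard_size=...)` writes several `pytorch_model-*.bin` files plus an index, and `from_pretrained` reassembles them transparently. A minimal sketch of that round trip, assuming a Transformers build that includes this commit (the tiny checkpoint is the one used by the new tests below):

```python
import os
import tempfile

from transformers import BertModel

# Tiny checkpoint used by the new tests, so the download stays small.
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

with tempfile.TemporaryDirectory() as tmp_dir:
    # Pick a max size smaller than the model to force sharding.
    model.save_pretrained(tmp_dir, max_shard_size="50kB")

    # The folder now contains pytorch_model-XXXXX-of-XXXXX.bin shards plus
    # pytorch_model.bin.index.json instead of a single pytorch_model.bin.
    print(sorted(os.listdir(tmp_dir)))

    # Loading is unchanged: from_pretrained detects the index and loads every shard.
    reloaded = BertModel.from_pretrained(tmp_dir)
```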


@ -52,6 +52,7 @@ from .utils import (
USE_JAX,
USE_TF,
USE_TORCH,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
ContextManagers,
DummyObject,


@ -15,11 +15,15 @@
# limitations under the License.
import inspect
import json
import os
import re
import shutil
import tempfile
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import torch
@ -38,6 +42,7 @@ from .utils import (
FLAX_WEIGHTS_NAME,
TF2_WEIGHTS_NAME,
TF_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
EntryNotFoundError,
ModelOutput,
@ -45,7 +50,6 @@ from .utils import (
RepositoryNotFoundError,
RevisionNotFoundError,
cached_path,
copy_func,
has_file,
hf_bucket_url,
is_offline_mode,
@ -148,6 +152,272 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil
return first_tuple[1].dtype
def convert_file_size_to_int(size: Union[int, str]):
"""
Converts a size expressed as a string with digits and a unit (like `"5MB"`) to an integer (in bytes).
Args:
size (`int` or `str`): The size to convert. Will be directly returned if an `int`.
Example:
```py
>>> convert_file_size_to_int("1MB")
1048576
```
"""
if isinstance(size, int):
return size
if size.upper().endswith("GIB"):
return int(size[:-3]) * (2**30)
if size.upper().endswith("MIB"):
return int(size[:-3]) * (2**20)
if size.upper().endswith("KIB"):
return int(size[:-3]) * (2**10)
if size.upper().endswith("GB"):
return int(size[:-2]) * (10**9)
if size.upper().endswith("MB"):
return int(size[:-2]) * (10**6)
if size.upper().endswith("KB"):
return int(size[:-2]) * (10**3)
raise ValueError("`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'.")
def dtype_byte_size(dtype):
"""
Returns the size (in bytes) occupied by one parameter of type `dtype`.
Example:
```py
>>> dtype_byte_size(torch.float32)
4
```
"""
if dtype == torch.bool:
return 1 / 8
bit_search = re.search(r"[^\d](\d+)$", str(dtype))
if bit_search is None:
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
bit_size = int(bit_search.groups()[0])
return bit_size // 8
def shard_checkpoint(state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB"):
"""
Splits a model state dictionary into sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size.
The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no
optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For example, if the
limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB],
[6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].
<Tip warning={true}>
If one of the model's weights is bigger than `max_shard_size`, it will end up in its own sub-checkpoint which will
have a size greater than `max_shard_size`.
</Tip>
Args:
state_dict (`Dict[str, torch.Tensor]`): The state dictionary of a model to save.
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
(like `"5MB"`).
"""
max_shard_size = convert_file_size_to_int(max_shard_size)
sharded_state_dicts = []
current_block = {}
current_block_size = 0
total_size = 0
for key, weight in state_dict.items():
weight_size = weight.numel() * dtype_byte_size(weight.dtype)
# If this weight is going to tip the current block over the maximal size, we split.
if current_block_size + weight_size > max_shard_size:
sharded_state_dicts.append(current_block)
current_block = {}
current_block_size = 0
current_block[key] = weight
current_block_size += weight_size
total_size += weight_size
# Add the last block
sharded_state_dicts.append(current_block)
# If we only have one shard, we return it
if len(sharded_state_dicts) == 1:
return {WEIGHTS_NAME: sharded_state_dicts[0]}, None
# Otherwise, let's build the index
weight_map = {}
shards = {}
for idx, shard in enumerate(sharded_state_dicts):
shard_file = WEIGHTS_NAME.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
shards[shard_file] = shard
for key in shard.keys():
weight_map[key] = shard_file
# Add the metadata
metadata = {"total_size": total_size}
index = {"metadata": metadata, "weight_map": weight_map}
return shards, index
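As a concrete illustration of the greedy key-order packing described in the docstring, here is a sketch mirroring the `test_shard_checkpoint` case added further down (same toy model, same 300kB limit):

```python
import torch

from transformers.modeling_utils import shard_checkpoint

# Same toy model as test_shard_checkpoint: 340,000 bytes of fp32 weights in total.
model = torch.nn.Sequential(
    torch.nn.Linear(100, 200, bias=False),  # 80,000 bytes
    torch.nn.Linear(200, 200, bias=False),  # 160,000 bytes
    torch.nn.Linear(200, 100, bias=False),  # 80,000 bytes
    torch.nn.Linear(100, 50, bias=False),   # 20,000 bytes
)

shards, index = shard_checkpoint(model.state_dict(), max_shard_size="300kB")
print(list(shards.keys()))
# ['pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin']
print(index["metadata"])
# {'total_size': 340000}
```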
def get_checkpoint_shard_files(
pretrained_model_name_or_path,
index_filename,
cache_dir=None,
force_download=False,
proxies=None,
resume_download=False,
local_files_only=False,
use_auth_token=None,
user_agent=None,
revision=None,
mirror=None,
):
"""
For a given model:
- download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on the
Hub
- return the list of paths to all the shards, as well as some metadata.
For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the
index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub).
"""
with open(index_filename, "r") as f:
index = json.loads(f.read())
shard_filenames = sorted(list(set(index["weight_map"].values())))
sharded_metadata = index["metadata"]
sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
# First, let's deal with local folder.
if os.path.isdir(pretrained_model_name_or_path):
shard_filenames = [os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames]
return shard_filenames, sharded_metadata
# At this stage pretrained_model_name_or_path is a model identifier on the Hub
cached_filenames = []
for shard_filename in shard_filenames:
shard_url = hf_bucket_url(
pretrained_model_name_or_path, filename=shard_filename, revision=revision, mirror=mirror
)
try:
# Load from URL
cached_filename = cached_path(
shard_url,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
)
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
# we don't have to catch them here.
except EntryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {shard_filename} which is "
"required according to the checkpoint index."
)
except HTTPError:
raise EnvironmentError(
f"We couldn't connect to 'https://huggingface.co/' to load {shard_filename}. You should try again "
"after checking your internet connection."
)
cached_filenames.append(cached_filename)
return cached_filenames, sharded_metadata
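The index that `get_checkpoint_shard_files` consumes is the JSON written by `save_pretrained` below: a `metadata` dict plus a `weight_map` from parameter name to shard file. A hedged sketch of inspecting one locally (the folder path is a placeholder):

```python
import json
import os

checkpoint_dir = "path/to/sharded/checkpoint"  # placeholder for a local sharded save

# WEIGHTS_INDEX_NAME is "pytorch_model.bin.index.json"
with open(os.path.join(checkpoint_dir, "pytorch_model.bin.index.json")) as f:
    index = json.load(f)

print(index["metadata"]["total_size"])  # total size of all weights, in bytes
# Deduplicated, sorted list of shard files, as the function above builds it.
print(sorted(set(index["weight_map"].values())))
```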
def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
"""
Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
"""
try:
return torch.load(checkpoint_file, map_location="cpu")
except Exception as e:
try:
with open(checkpoint_file) as f:
if f.read().startswith("version"):
raise OSError(
"You seem to have cloned a repository without having git-lfs installed. Please install "
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
"you cloned."
)
else:
raise ValueError(
f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
"model. Make sure you have saved the model properly."
) from e
except (UnicodeDecodeError, ValueError):
raise OSError(
f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' "
f"at '{checkpoint_file}'. "
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
)
def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
# Convert old format to new format if needed from a PyTorch state_dict
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if "gamma" in key:
new_key = key.replace("gamma", "weight")
if "beta" in key:
new_key = key.replace("beta", "bias")
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
error_msgs = []
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively.
def load(module: nn.Module, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
if is_deepspeed_zero3_enabled():
import deepspeed
# because zero3 puts placeholders in model params, this context
# manager gathers (unpartitions) the params of the current layer, then loads from
# the state dict and then re-partitions them again
with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
if torch.distributed.get_rank() == 0:
module._load_from_state_dict(*args)
else:
module._load_from_state_dict(*args)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + ".")
load(model_to_load, prefix=start_prefix)
return error_msgs
class ModuleUtilsMixin:
"""
A few utilities for `torch.nn.Modules`, to be used as a mixin.
@ -1004,6 +1274,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
state_dict: Optional[dict] = None,
save_function: Callable = torch.save,
push_to_hub: bool = False,
max_shard_size: Union[int, str] = "10GB",
**kwargs,
):
"""
@ -1035,6 +1306,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
</Tip>
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be smaller
than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
<Tip warning={true}>
If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
which will be bigger than `max_shard_size`.
</Tip>
kwargs:
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
@ -1078,11 +1360,32 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if ignore_key in state_dict.keys():
del state_dict[ignore_key]
# If we save using the predefined names, we can load using `from_pretrained`
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
save_function(state_dict, output_model_file)
# Shard the model if it is too big.
shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size)
logger.info(f"Model weights saved in {output_model_file}")
# Clean the folder from a previous save
for filename in os.listdir(save_directory):
full_filename = os.path.join(save_directory, filename)
if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename):
os.remove(full_filename)
# Save the model
for shard_file, shard in shards.items():
save_function(shard, os.path.join(save_directory, shard_file))
if index is None:
logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
else:
save_index_file = os.path.join(save_directory, WEIGHTS_INDEX_NAME)
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
logger.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
if push_to_hub:
url = self._push_to_hub(repo, commit_message=commit_message)
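The cleanup loop above is what makes re-saving into the same folder with a different `max_shard_size` safe: stale `pytorch_model*` files are deleted before the new shards are written. A small sketch of that behavior, mirroring the local sharding test added in this PR:

```python
import os
import tempfile

from transformers import BertModel

model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

with tempfile.TemporaryDirectory() as tmp_dir:
    # A small limit produces many shards...
    model.save_pretrained(tmp_dir, max_shard_size="50kB")
    n_small = len([f for f in os.listdir(tmp_dir) if f.endswith(".bin")])

    # ...and re-saving with a larger limit first removes the stale pytorch_model*
    # files, so no orphaned shards from the previous save are left behind.
    model.save_pretrained(tmp_dir, max_shard_size="200kB")
    n_large = len([f for f in os.listdir(tmp_dir) if f.endswith(".bin")])

    assert n_large <= n_small
```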
@ -1293,6 +1596,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
else:
model_kwargs = kwargs
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# index of the files.
is_sharded = False
sharded_metadata = None
# Load model
if pretrained_model_name_or_path is not None:
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
@ -1309,6 +1616,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
# Load from a sharded PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
is_sharded = True
# At this stage we don't have a weight file so we will raise an error.
elif os.path.isfile(
os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
@ -1382,29 +1693,51 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
except EntryNotFoundError:
if filename == WEIGHTS_NAME:
has_file_kwargs = {
"revision": revision,
"mirror": mirror,
"proxies": proxies,
"use_auth_token": use_auth_token,
}
if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs):
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME} but "
"there is a file for TensorFlow weights. Use `from_tf=True` to load this model from those "
"weights."
try:
# Maybe the checkpoint is sharded; we try to grab the index name in this case.
archive_file = hf_bucket_url(
pretrained_model_name_or_path,
filename=WEIGHTS_INDEX_NAME,
revision=revision,
mirror=mirror,
)
elif has_file(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **has_file_kwargs):
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME} but "
"there is a file for Flax weights. Use `from_flax=True` to load this model from those "
"weights."
)
else:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}, "
f"{TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
resolved_archive_file = cached_path(
archive_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
)
is_sharded = True
except EntryNotFoundError:
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
# message.
has_file_kwargs = {
"revision": revision,
"mirror": mirror,
"proxies": proxies,
"use_auth_token": use_auth_token,
}
if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs):
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME} but "
"there is a file for TensorFlow weights. Use `from_tf=True` to load this model from those "
"weights."
)
elif has_file(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **has_file_kwargs):
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME} but "
"there is a file for Flax weights. Use `from_flax=True` to load this model from those "
"weights."
)
else:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}, "
f"{TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
)
else:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
@ -1439,29 +1772,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
else:
resolved_archive_file = None
# We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
if is_sharded:
# resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
pretrained_model_name_or_path,
resolved_archive_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
mirror=mirror,
)
# load pt weights early so that we know which dtype to init the model under
if from_pt:
if state_dict is None:
try:
state_dict = torch.load(resolved_archive_file, map_location="cpu")
except Exception as e:
try:
with open(resolved_archive_file) as f:
if f.read().startswith("version"):
raise OSError(
"You seem to have cloned a repository without having git-lfs installed. Please install "
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
"you cloned."
)
else:
raise ValueError from e
except (UnicodeDecodeError, ValueError):
raise OSError(
f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
f"at '{resolved_archive_file}'. "
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
)
if not is_sharded:
# Time to load the checkpoint
state_dict = load_state_dict(resolved_archive_file)
# set dtype to instantiate the model under:
# 1. If torch_dtype is not None, we use that dtype
# 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
@ -1471,7 +1803,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if torch_dtype is not None:
if isinstance(torch_dtype, str):
if torch_dtype == "auto":
torch_dtype = next(iter(state_dict.values())).dtype
if is_sharded and "dtype" in sharded_metadata:
torch_dtype = sharded_metadata["dtype"]
elif not is_sharded:
torch_dtype = next(iter(state_dict.values())).dtype
else:
one_state_dict = load_state_dict(resolved_archive_file)
torch_dtype = next(iter(one_state_dict.values())).dtype
del one_state_dict # free CPU memory
else:
raise ValueError(
f"`torch_dtype` can be either a `torch.dtype` or `auto`, but received {torch_dtype}"
@ -1480,8 +1819,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if low_cpu_mem_usage:
# save the keys
loaded_state_dict_keys = [k for k in state_dict.keys()]
del state_dict # free CPU memory - will reload again later
if is_sharded:
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
else:
state_dict = load_state_dict(resolved_archive_file)
loaded_state_dict_keys = [k for k in state_dict.keys()]
del state_dict # free CPU memory - will reload again later
config.name_or_path = pretrained_model_name_or_path
@ -1534,13 +1877,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
elif from_pt:
if low_cpu_mem_usage:
cls._load_state_dict_into_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file)
cls._load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file)
else:
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_state_dict_into_model(
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
model,
state_dict,
resolved_archive_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=ignore_mismatched_sizes,
sharded_metadata=sharded_metadata,
_fast_init=_fast_init,
)
@ -1562,31 +1907,31 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return model
@classmethod
def _load_state_dict_into_model(
cls, model, state_dict, pretrained_model_name_or_path, ignore_mismatched_sizes=False, _fast_init=True
def _load_pretrained_model(
cls,
model,
state_dict,
resolved_archive_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=False,
sharded_metadata=None,
_fast_init=True,
):
# Convert old format to new format if needed from a PyTorch state_dict
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if "gamma" in key:
new_key = key.replace("gamma", "weight")
if "beta" in key:
new_key = key.replace("beta", "bias")
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
# Retrieve missing & unexpected_keys
model_state_dict = model.state_dict()
expected_keys = list(model_state_dict.keys())
loaded_keys = list(state_dict.keys())
loaded_keys = list(state_dict.keys()) if state_dict is not None else sharded_metadata["all_checkpoint_keys"]
prefix = model.base_model_prefix
def _fix_key(key):
if "beta" in key:
return key.replace("beta", "bias")
if "gamma" in key:
return key.replace("gamma", "weight")
return key
loaded_keys = [_fix_key(key) for key in loaded_keys]
if len(prefix) > 0:
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
@ -1608,28 +1953,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
missing_keys = list(set(expected_keys) - set(loaded_keys))
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
# Mismatched keys contain tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model.
mismatched_keys = []
if ignore_mismatched_sizes:
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
if remove_prefix_from_model:
# The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
model_key = f"{prefix}.{checkpoint_key}"
elif add_prefix_to_model:
# The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
model_key = ".".join(checkpoint_key.split(".")[1:])
if (
model_key in model_state_dict
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
):
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
# Some models may have keys that are not in the state by design, removing them before needlessly warning
# the user.
if cls._keys_to_ignore_on_load_missing is not None:
@ -1648,35 +1971,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
for module in uninitialized_modules:
model._init_weights(module)
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
error_msgs = []
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively.
def load(module: nn.Module, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
if is_deepspeed_zero3_enabled():
import deepspeed
# because zero3 puts placeholders in model params, this context
# manager gathers (unpartitions) the params of the current layer, then loads from
# the state dict and then re-partitions them again
with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
if torch.distributed.get_rank() == 0:
module._load_from_state_dict(*args)
else:
module._load_from_state_dict(*args)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + ".")
# Make sure we are able to load base models as well as derived models (with heads)
start_prefix = ""
model_to_load = model
@ -1690,7 +1984,61 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"properly saved?"
)
load(model_to_load, prefix=start_prefix)
if state_dict is not None:
# Whole checkpoint
mismatched_keys = []
if ignore_mismatched_sizes:
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
if remove_prefix_from_model:
# The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
model_key = f"{prefix}.{checkpoint_key}"
elif add_prefix_to_model:
# The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
model_key = ".".join(checkpoint_key.split(".")[1:])
if (
model_key in model_state_dict
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
):
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
else:
# Sharded checkpoint
# This should always be a list, but we check just to be sure.
if not isinstance(resolved_archive_file, list):
resolved_archive_file = [resolved_archive_file]
error_msgs = []
for shard_file in resolved_archive_file:
state_dict = load_state_dict(shard_file)
# Mismatched keys contain tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model.
mismatched_keys = []
if ignore_mismatched_sizes:
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
if remove_prefix_from_model:
# The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
model_key = f"{prefix}.{checkpoint_key}"
elif add_prefix_to_model:
# The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
model_key = ".".join(checkpoint_key.split(".")[1:])
if (
model_key in model_state_dict
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
):
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
if len(error_msgs) > 0:
error_msg = "\n\t".join(error_msgs)
@ -1755,7 +2103,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return retrieved_modules
@classmethod
def _load_state_dict_into_model_low_mem(cls, model, loaded_state_dict_keys, resolved_archive_file):
def _load_pretrained_model_low_mem(cls, model, loaded_state_dict_keys, resolved_archive_file):
"""
This is an experimental function that loads the model using ~1.x model size CPU memory
@ -1772,7 +2120,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
Currently, it doesn't handle missing_keys, unexpected_keys, mismatched_keys. It can't handle deepspeed.
"""
require_version_core("torch>=1.9")
if is_deepspeed_zero3_enabled():
raise ValueError("low_cpu_mem_usage arg cannot be used with DeepSpeed ZeRO-3")
@ -1806,19 +2153,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
new_val = new_val.to("meta")
setattr(submodule, param_name, new_val)
# only now can load state_dict
state_dict = torch.load(resolved_archive_file, map_location="cpu")
# only now can load state_dict(s)
if not isinstance(resolved_archive_file, list):
resolved_archive_file = [resolved_archive_file]
# materialize state_dict entries one by one on CPU
for k in loaded_state_dict_keys:
submodule, param_name = find_submodule_and_param_name(model, k)
if submodule is not None:
new_val = state_dict[k]
if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
new_val = torch.nn.Parameter(new_val)
setattr(submodule, param_name, new_val)
for archive_file in resolved_archive_file:
state_dict = torch.load(archive_file, map_location="cpu")
del state_dict
# materialize state_dict entries one by one on CPU
for k in loaded_state_dict_keys:
submodule, param_name = find_submodule_and_param_name(model, k)
if submodule is not None:
new_val = state_dict[k]
if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
new_val = torch.nn.Parameter(new_val)
setattr(submodule, param_name, new_val)
del state_dict
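Because `_load_pretrained_model_low_mem` now iterates over the list of shard files, `low_cpu_mem_usage=True` keeps working for sharded checkpoints. A hedged usage sketch (the sharded repo is the one used by the new tests; `torch>=1.9` is required, as enforced above):

```python
from transformers import BertModel

# Shards are loaded one at a time and parameters are materialized on CPU one by one,
# keeping peak memory at roughly 1x the final model size.
model = BertModel.from_pretrained(
    "hf-internal-testing/tiny-random-bert-sharded",
    low_cpu_mem_usage=True,
)
```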
@classmethod
def register_for_auto_class(cls, auto_class="AutoModel"):
@ -1846,12 +2197,109 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
cls._auto_class = auto_class
def push_to_hub(
self,
repo_path_or_name: Optional[str] = None,
repo_url: Optional[str] = None,
use_temp_dir: bool = False,
commit_message: str = "add model",
organization: Optional[str] = None,
private: Optional[bool] = None,
use_auth_token: Optional[Union[bool, str]] = None,
max_shard_size: Union[int, str] = "10GB",
**model_card_kwargs
) -> str:
"""
Upload the model files to the 🤗 Model Hub while synchronizing a local clone of the repo in `repo_path_or_name`.
# To update the docstring, we need to copy the method, otherwise we change the original docstring.
PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
PreTrainedModel.push_to_hub.__doc__ = PreTrainedModel.push_to_hub.__doc__.format(
object="model", object_class="AutoModel", object_files="model checkpoint"
)
Parameters:
repo_path_or_name (`str`, *optional*):
Can either be a repository name for your model in the Hub or a path to a local folder (in which case
the repository will have the name of that local folder). If not specified, will default to the name
given by `repo_url` and a local directory with that name will be created.
repo_url (`str`, *optional*):
Specify this in case you want to push to an existing repository in the hub. If unspecified, a new
repository will be created in your namespace (unless you specify an `organization`) with `repo_name`.
use_temp_dir (`bool`, *optional*, defaults to `False`):
Whether or not to clone the distant repo in a temporary directory or in `repo_path_or_name` inside the
current working directory. This will slow things down if you are making changes in an existing repo
since you will need to clone the repo before every push.
commit_message (`str`, *optional*, defaults to `"add model"`):
Message to commit while pushing.
organization (`str`, *optional*):
Organization in which you want to push your {object} (you must be a member of this organization).
private (`bool`, *optional*):
Whether or not the repository created should be private (requires a paying subscription).
use_auth_token (`bool` or `str`, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `transformers-cli login` (stored in `~/.huggingface`). Will default to `True` if
`repo_url` is not specified.
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be smaller
than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
<Tip warning={true}>
If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
which will be bigger than `max_shard_size`.
</Tip>
Returns:
`str`: The url of the commit of your {object} in the given repository.
Examples:
```python
from transformers import AutoModel
model = AutoModel.from_pretrained("bert-base-cased")
# Push the model to your namespace with the name "my-finetuned-bert" and have a local clone in the
# *my-finetuned-bert* folder.
model.push_to_hub("my-finetuned-bert")
# Push the model to your namespace with the name "my-finetuned-bert" with no local clone.
model.push_to_hub("my-finetuned-bert", use_temp_dir=True)
# Push the model to an organization with the name "my-finetuned-bert" and have a local clone in the
# *my-finetuned-bert* folder.
model.push_to_hub("my-finetuned-bert", organization="huggingface")
# Make a change to an existing repo that has been cloned locally in *my-finetuned-bert*.
model.push_to_hub("my-finetuned-bert", repo_url="https://huggingface.co/sgugger/my-finetuned-bert")
```
"""
if use_temp_dir:
# Make sure we use the right `repo_name` for the `repo_url` before replacing it.
if repo_url is None:
if use_auth_token is None:
use_auth_token = True
repo_name = Path(repo_path_or_name).name
repo_url = self._get_repo_url_from_name(
repo_name, organization=organization, private=private, use_auth_token=use_auth_token
)
repo_path_or_name = tempfile.mkdtemp()
# Create or clone the repo. If the repo is already cloned, this just retrieves the path to the repo.
repo = self._create_or_get_repo(
repo_path_or_name=repo_path_or_name,
repo_url=repo_url,
organization=organization,
private=private,
use_auth_token=use_auth_token,
)
# Save the files in the cloned repo
self.save_pretrained(repo_path_or_name, max_shard_size=max_shard_size)
# Commit and push!
url = self._push_to_hub(repo, commit_message=commit_message)
# Clean up! Clean up! Everybody everywhere!
if use_temp_dir:
shutil.rmtree(repo_path_or_name)
return url
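`push_to_hub` simply forwards the new `max_shard_size` argument to `save_pretrained` before committing, so large models land on the Hub as shards plus an index. A hedged sketch, assuming you are logged in with `transformers-cli login` (the repository name is a placeholder):

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-cased")

# "my-sharded-bert" is a placeholder repo name; any checkpoint bigger than 2GB
# is split into shards before being committed and pushed.
model.push_to_hub("my-sharded-bert", max_shard_size="2GB")
```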
class Conv1D(nn.Module):


@ -136,6 +136,7 @@ from .import_utils import (
WEIGHTS_NAME = "pytorch_model.bin"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
TF2_WEIGHTS_NAME = "tf_model.h5"
TF_WEIGHTS_NAME = "model.ckpt"
FLAX_WEIGHTS_NAME = "flax_model.msgpack"


@ -55,7 +55,7 @@ from transformers.testing_utils import (
slow,
torch_device,
)
from transformers.utils import WEIGHTS_NAME, is_flax_available, is_torch_fx_available
from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME, is_flax_available, is_torch_fx_available
sys.path.append(str(Path(__file__).parent.parent / "utils"))
@ -90,6 +90,7 @@ if is_torch_available():
T5Config,
T5ForConditionalGeneration,
)
from transformers.modeling_utils import shard_checkpoint
if is_flax_available():
import jax.numpy as jnp
@ -2352,6 +2353,123 @@ class ModelUtilsTest(TestCasePlus):
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
def test_shard_checkpoint(self):
# This is the model we will use, total size 340,000 bytes.
model = torch.nn.Sequential(
torch.nn.Linear(100, 200, bias=False), # size 80,000
torch.nn.Linear(200, 200, bias=False), # size 160,000
torch.nn.Linear(200, 100, bias=False), # size 80,000
torch.nn.Linear(100, 50, bias=False), # size 20,000
)
state_dict = model.state_dict()
with self.subTest("No shard when max size is bigger than model size"):
shards, index = shard_checkpoint(state_dict)
self.assertIsNone(index)
self.assertDictEqual(shards, {WEIGHTS_NAME: state_dict})
with self.subTest("Test sharding, no weights bigger than max size"):
shards, index = shard_checkpoint(state_dict, max_shard_size="300kB")
# Split is first two layers then last two.
self.assertDictEqual(
index,
{
"metadata": {"total_size": 340000},
"weight_map": {
"0.weight": "pytorch_model-00001-of-00002.bin",
"1.weight": "pytorch_model-00001-of-00002.bin",
"2.weight": "pytorch_model-00002-of-00002.bin",
"3.weight": "pytorch_model-00002-of-00002.bin",
},
},
)
shard1 = {"0.weight": state_dict["0.weight"], "1.weight": state_dict["1.weight"]}
shard2 = {"2.weight": state_dict["2.weight"], "3.weight": state_dict["3.weight"]}
self.assertDictEqual(
shards, {"pytorch_model-00001-of-00002.bin": shard1, "pytorch_model-00002-of-00002.bin": shard2}
)
with self.subTest("Test sharding with weights bigger than max size"):
shards, index = shard_checkpoint(state_dict, max_shard_size="100kB")
# Split is first layer, second layer then last 2.
self.assertDictEqual(
index,
{
"metadata": {"total_size": 340000},
"weight_map": {
"0.weight": "pytorch_model-00001-of-00003.bin",
"1.weight": "pytorch_model-00002-of-00003.bin",
"2.weight": "pytorch_model-00003-of-00003.bin",
"3.weight": "pytorch_model-00003-of-00003.bin",
},
},
)
shard1 = {"0.weight": state_dict["0.weight"]}
shard2 = {"1.weight": state_dict["1.weight"]}
shard3 = {"2.weight": state_dict["2.weight"], "3.weight": state_dict["3.weight"]}
self.assertDictEqual(
shards,
{
"pytorch_model-00001-of-00003.bin": shard1,
"pytorch_model-00002-of-00003.bin": shard2,
"pytorch_model-00003-of-00003.bin": shard3,
},
)
def test_checkpoint_sharding_local(self):
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with tempfile.TemporaryDirectory() as tmp_dir:
# We use the same folder for various sizes to make sure a new save erases the old checkpoint.
for max_size in ["50kB", "50kiB", "100kB", "100kiB", "200kB", "200kiB"]:
model.save_pretrained(tmp_dir, max_shard_size=max_size)
# Get each shard file and its size
shard_to_size = {}
for shard in os.listdir(tmp_dir):
if shard.endswith(".bin"):
shard_file = os.path.join(tmp_dir, shard)
shard_to_size[shard_file] = os.path.getsize(shard_file)
index_file = os.path.join(tmp_dir, WEIGHTS_INDEX_NAME)
# Check there is an index but no regular weight file
self.assertTrue(os.path.isfile(index_file))
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))
# Check a file is bigger than max_size only when it has a single weight
for shard_file, size in shard_to_size.items():
if max_size.endswith("kiB"):
max_size_int = int(max_size[:-3]) * 2**10
else:
max_size_int = int(max_size[:-2]) * 10**3
# Note: pickle adds some junk so the size of the file can end up being slightly bigger than
# the size asked for (since we count parameters)
if size >= max_size_int + 50000:
state_dict = torch.load(shard_file)
self.assertEqual(len(state_dict), 1)
# Check the index and the shard files found match
with open(index_file, "r", encoding="utf-8") as f:
index = json.loads(f.read())
all_shards = set(index["weight_map"].values())
shards_found = set(f for f in os.listdir(tmp_dir) if f.endswith(".bin"))
self.assertSetEqual(all_shards, shards_found)
# Finally, check the model can be reloaded
new_model = BertModel.from_pretrained(tmp_dir)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))
def test_checkpoint_sharding_from_hub(self):
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
# the model above is the same as the model below, just a sharded version.
ref_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
for p1, p2 in zip(model.parameters(), ref_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))
def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()