Checkpoint sharding (#16343)
* Sharded checkpoint support
* Handle distant sharded checkpoints
* Add tests
* TODO is done
* Apply suggestions from code review

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Fix docstring
* Add example and format
* Address review comments
* More review comments
* End of merge
* Revert unintentional change
* VsCode what did you do?
* Style
* Changes
* Address final comments
* Quality
* Moar tests
* Move import beneath is_pt_available

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
parent 7fa7408b26
commit b473617d63
@@ -52,6 +52,7 @@ from .utils import (
     USE_JAX,
     USE_TF,
     USE_TORCH,
+    WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
     ContextManagers,
     DummyObject,
@@ -15,11 +15,15 @@
 # limitations under the License.

 import inspect
+import json
 import os
 import re
+import shutil
+import tempfile
 from contextlib import contextmanager
 from dataclasses import dataclass
+from functools import partial
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

 import torch
@@ -38,6 +42,7 @@ from .utils import (
     FLAX_WEIGHTS_NAME,
     TF2_WEIGHTS_NAME,
     TF_WEIGHTS_NAME,
+    WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
     EntryNotFoundError,
     ModelOutput,
@@ -45,7 +50,6 @@ from .utils import (
    RepositoryNotFoundError,
    RevisionNotFoundError,
    cached_path,
    copy_func,
    has_file,
    hf_bucket_url,
    is_offline_mode,
@@ -148,6 +152,272 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil
     return first_tuple[1].dtype


+def convert_file_size_to_int(size: Union[int, str]):
+    """
+    Converts a size expressed as a string with digits and a unit (like `"5MB"`) to an integer (in bytes).
+
+    Args:
+        size (`int` or `str`): The size to convert. Will be directly returned if an `int`.
+
+    Example:
+
+    ```py
+    >>> convert_file_size_to_int("1MiB")
+    1048576
+    ```
+    """
+    if isinstance(size, int):
+        return size
+    if size.upper().endswith("GIB"):
+        return int(size[:-3]) * (2**30)
+    if size.upper().endswith("MIB"):
+        return int(size[:-3]) * (2**20)
+    if size.upper().endswith("KIB"):
+        return int(size[:-3]) * (2**10)
+    if size.upper().endswith("GB"):
+        return int(size[:-2]) * (10**9)
+    if size.upper().endswith("MB"):
+        return int(size[:-2]) * (10**6)
+    if size.upper().endswith("KB"):
+        return int(size[:-2]) * (10**3)
+    raise ValueError("`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'.")
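Not part of the diff, but worth spelling out: the helper treats decimal suffixes (`KB`, `MB`, `GB`) as powers of ten and binary suffixes (`KiB`, `MiB`, `GiB`) as powers of two. A quick illustration of the intended behavior:

```py
# Illustration only: decimal vs. binary units.
assert convert_file_size_to_int("5GB") == 5 * 10**9
assert convert_file_size_to_int("5GiB") == 5 * 2**30
# Integers pass through unchanged.
assert convert_file_size_to_int(1024) == 1024
```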
+def dtype_byte_size(dtype):
+    """
+    Returns the size (in bytes) occupied by one parameter of type `dtype`.
+
+    Example:
+
+    ```py
+    >>> dtype_byte_size(torch.float32)
+    4
+    ```
+    """
+    if dtype == torch.bool:
+        return 1 / 8
+    bit_search = re.search(r"[^\d](\d+)$", str(dtype))
+    if bit_search is None:
+        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
+    bit_size = int(bit_search.groups()[0])
+    return bit_size // 8
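A few more values for reference (illustration, not part of the diff), since these byte sizes drive the shard accounting in `shard_checkpoint` below:

```py
assert dtype_byte_size(torch.float32) == 4   # "torch.float32" -> 32 bits -> 4 bytes
assert dtype_byte_size(torch.float16) == 2   # 16 bits -> 2 bytes
assert dtype_byte_size(torch.int8) == 1      # 8 bits -> 1 byte
assert dtype_byte_size(torch.bool) == 1 / 8  # special-cased: one bit per parameter
```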
+def shard_checkpoint(state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB"):
+    """
+    Splits a model state dictionary into sub-checkpoints so that the final size of each sub-checkpoint does not
+    exceed a given size.
+
+    The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no
+    optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For example, if the
+    limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB],
+    [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].
+
+    <Tip warning={true}>
+
+    If one of the model's weights is bigger than `max_shard_size`, it will end up in its own sub-checkpoint which
+    will have a size greater than `max_shard_size`.
+
+    </Tip>
+
+    Args:
+        state_dict (`Dict[str, torch.Tensor]`): The state dictionary of a model to save.
+        max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
+            The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
+            (like `"5MB"`).
+    """
+    max_shard_size = convert_file_size_to_int(max_shard_size)
+
+    sharded_state_dicts = []
+    current_block = {}
+    current_block_size = 0
+    total_size = 0
+
+    for key, weight in state_dict.items():
+        weight_size = weight.numel() * dtype_byte_size(weight.dtype)
+
+        # If this weight is going to tip the current block over the maximal size, we split.
+        if current_block_size + weight_size > max_shard_size:
+            sharded_state_dicts.append(current_block)
+            current_block = {}
+            current_block_size = 0
+
+        current_block[key] = weight
+        current_block_size += weight_size
+        total_size += weight_size
+
+    # Add the last block
+    sharded_state_dicts.append(current_block)
+
+    # If we only have one shard, we return it
+    if len(sharded_state_dicts) == 1:
+        return {WEIGHTS_NAME: sharded_state_dicts[0]}, None
+
+    # Otherwise, let's build the index
+    weight_map = {}
+    shards = {}
+    for idx, shard in enumerate(sharded_state_dicts):
+        shard_file = WEIGHTS_NAME.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
+        shards[shard_file] = shard
+        for key in shard.keys():
+            weight_map[key] = shard_file
+
+    # Add the metadata
+    metadata = {"total_size": total_size}
+    index = {"metadata": metadata, "weight_map": weight_map}
+    return shards, index
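To make the greedy splitting concrete, here is a small worked example (illustration only, not part of the diff) showing both returned values, the shard dict and the index:

```py
import torch

sd = {
    "a": torch.zeros(10, dtype=torch.float32),  # 40 bytes
    "b": torch.zeros(20, dtype=torch.float32),  # 80 bytes
    "c": torch.zeros(5, dtype=torch.float32),   # 20 bytes
}
shards, index = shard_checkpoint(sd, max_shard_size=100)

# "a" fits alone (40 <= 100); "b" would tip the block over (40 + 80 > 100),
# so it starts a new shard; "c" then joins "b" (80 + 20 <= 100).
assert list(shards) == ["pytorch_model-00001-of-00002.bin", "pytorch_model-00002-of-00002.bin"]
assert index["metadata"]["total_size"] == 140
assert index["weight_map"]["c"] == "pytorch_model-00002-of-00002.bin"
```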
+def get_checkpoint_shard_files(
+    pretrained_model_name_or_path,
+    index_filename,
+    cache_dir=None,
+    force_download=False,
+    proxies=None,
+    resume_download=False,
+    local_files_only=False,
+    use_auth_token=None,
+    user_agent=None,
+    revision=None,
+    mirror=None,
+):
+    """
+    For a given model:
+
+    - download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on
+      the Hub
+    - returns the list of paths to all the shards, as well as some metadata.
+
+    For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the
+    index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub).
+    """
+    with open(index_filename, "r") as f:
+        index = json.loads(f.read())
+
+    shard_filenames = sorted(list(set(index["weight_map"].values())))
+    sharded_metadata = index["metadata"]
+    sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
+
+    # First, let's deal with local folder.
+    if os.path.isdir(pretrained_model_name_or_path):
+        shard_filenames = [os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames]
+        return shard_filenames, sharded_metadata
+
+    # At this stage pretrained_model_name_or_path is a model identifier on the Hub
+    cached_filenames = []
+    for shard_filename in shard_filenames:
+        shard_url = hf_bucket_url(
+            pretrained_model_name_or_path, filename=shard_filename, revision=revision, mirror=mirror
+        )
+
+        try:
+            # Load from URL
+            cached_filename = cached_path(
+                shard_url,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                proxies=proxies,
+                resume_download=resume_download,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                user_agent=user_agent,
+            )
+        # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
+        # we don't have to catch them here.
+        except EntryNotFoundError:
+            raise EnvironmentError(
+                f"{pretrained_model_name_or_path} does not appear to have a file named {shard_filename} which is "
+                "required according to the checkpoint index."
+            )
+        except HTTPError:
+            raise EnvironmentError(
+                f"We couldn't connect to 'https://huggingface.co/' to load {shard_filename}. You should try again "
+                "after checking your internet connection."
+            )
+
+        cached_filenames.append(cached_filename)
+
+    return cached_filenames, sharded_metadata
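The index consumed here is a plain JSON file mapping every parameter name to the shard that holds it. A minimal sketch (illustration only; `load_sharded_state_dict` is a hypothetical helper, not part of this commit) of reading a sharded local checkpoint by hand:

```py
import json
import os

import torch


def load_sharded_state_dict(folder):
    # Read pytorch_model.bin.index.json, then merge every listed shard.
    with open(os.path.join(folder, "pytorch_model.bin.index.json")) as f:
        index = json.load(f)
    state_dict = {}
    for shard in sorted(set(index["weight_map"].values())):
        state_dict.update(torch.load(os.path.join(folder, shard), map_location="cpu"))
    return state_dict
```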
+def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
+    """
+    Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
+    """
+    try:
+        return torch.load(checkpoint_file, map_location="cpu")
+    except Exception as e:
+        try:
+            with open(checkpoint_file) as f:
+                if f.read().startswith("version"):
+                    raise OSError(
+                        "You seem to have cloned a repository without having git-lfs installed. Please install "
+                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
+                        "you cloned."
+                    )
+                else:
+                    raise ValueError(
+                        f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
+                        "model. Make sure you have saved the model properly."
+                    ) from e
+        except (UnicodeDecodeError, ValueError):
+            raise OSError(
+                f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' "
+                f"at '{checkpoint_file}'. "
+                "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
+            )
+def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
+    # Convert old format to new format if needed from a PyTorch state_dict
+    old_keys = []
+    new_keys = []
+    for key in state_dict.keys():
+        new_key = None
+        if "gamma" in key:
+            new_key = key.replace("gamma", "weight")
+        if "beta" in key:
+            new_key = key.replace("beta", "bias")
+        if new_key:
+            old_keys.append(key)
+            new_keys.append(new_key)
+    for old_key, new_key in zip(old_keys, new_keys):
+        state_dict[new_key] = state_dict.pop(old_key)
+
+    # copy state_dict so _load_from_state_dict can modify it
+    metadata = getattr(state_dict, "_metadata", None)
+    state_dict = state_dict.copy()
+    if metadata is not None:
+        state_dict._metadata = metadata
+
+    error_msgs = []
+
+    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
+    # so we need to apply the function recursively.
+    def load(module: nn.Module, prefix=""):
+        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
+        if is_deepspeed_zero3_enabled():
+            import deepspeed
+
+            # because zero3 puts placeholders in model params, this context
+            # manager gathers (unpartitions) the params of the current layer, then loads from
+            # the state dict and then re-partitions them again
+            with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
+                if torch.distributed.get_rank() == 0:
+                    module._load_from_state_dict(*args)
+        else:
+            module._load_from_state_dict(*args)
+
+        for name, child in module._modules.items():
+            if child is not None:
+                load(child, prefix + name + ".")
+
+    load(model_to_load, prefix=start_prefix)
+
+    return error_msgs
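For orientation, this is how the new helper is meant to be called (a sketch, assuming DeepSpeed ZeRO-3 is disabled; the tiny `nn.Sequential` stands in for a real model):

```py
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4))
sd = {"0.weight": torch.ones(4, 4), "0.bias": torch.zeros(4)}

# Dispatches recursively to each submodule's `_load_from_state_dict`.
error_msgs = _load_state_dict_into_model(model, sd, start_prefix="")
assert error_msgs == []
```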
 class ModuleUtilsMixin:
     """
     A few utilities for `torch.nn.Modules`, to be used as a mixin.
@@ -1004,6 +1274,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         state_dict: Optional[dict] = None,
         save_function: Callable = torch.save,
         push_to_hub: bool = False,
+        max_shard_size: Union[int, str] = "10GB",
         **kwargs,
     ):
         """
@@ -1035,6 +1306,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix

             </Tip>

+            max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
+                The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be of a size
+                lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
+
+                <Tip warning={true}>
+
+                If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint
+                shard which will be bigger than `max_shard_size`.
+
+                </Tip>
+
             kwargs:
                 Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
             """
@@ -1078,11 +1360,32 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             if ignore_key in state_dict.keys():
                 del state_dict[ignore_key]

-        # If we save using the predefined names, we can load using `from_pretrained`
-        output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
-        save_function(state_dict, output_model_file)
-
-        logger.info(f"Model weights saved in {output_model_file}")
+        # Shard the model if it is too big.
+        shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size)
+
+        # Clean the folder from a previous save
+        for filename in os.listdir(save_directory):
+            full_filename = os.path.join(save_directory, filename)
+            if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename):
+                os.remove(full_filename)
+
+        # Save the model
+        for shard_file, shard in shards.items():
+            save_function(shard, os.path.join(save_directory, shard_file))
+
+        if index is None:
+            logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
+        else:
+            save_index_file = os.path.join(save_directory, WEIGHTS_INDEX_NAME)
+            # Save the index as well
+            with open(save_index_file, "w", encoding="utf-8") as f:
+                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+                f.write(content)
+            logger.info(
+                f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
+                f"split in {len(shards)} checkpoint shards. You can find where each parameter has been saved in the "
+                f"index located at {save_index_file}."
+            )

         if push_to_hub:
             url = self._push_to_hub(repo, commit_message=commit_message)
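What a sharded save looks like on disk (illustration only; the exact shard count depends on the model, hence the `e.g.`), using the same tiny test model as the new tests below:

```py
import os

from transformers import BertModel

model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
model.save_pretrained("tmp/sharded-bert", max_shard_size="50kB")
print(sorted(os.listdir("tmp/sharded-bert")))
# e.g. ['config.json', 'pytorch_model-00001-of-00004.bin', ...,
#       'pytorch_model.bin.index.json']
```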
@@ -1293,6 +1596,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         else:
             model_kwargs = kwargs

+        # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
+        # index of the files.
+        is_sharded = False
+        sharded_metadata = None
         # Load model
         if pretrained_model_name_or_path is not None:
             pretrained_model_name_or_path = str(pretrained_model_name_or_path)
@@ -1309,6 +1616,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                 # Load from a PyTorch checkpoint
                 archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+            elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
+                # Load from a sharded PyTorch checkpoint
+                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
+                is_sharded = True
             # At this stage we don't have a weight file so we will raise an error.
             elif os.path.isfile(
                 os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
@@ -1382,29 +1693,51 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 )
             except EntryNotFoundError:
                 if filename == WEIGHTS_NAME:
-                    has_file_kwargs = {
-                        "revision": revision,
-                        "mirror": mirror,
-                        "proxies": proxies,
-                        "use_auth_token": use_auth_token,
-                    }
-                    if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs):
-                        raise EnvironmentError(
-                            f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME} but "
-                            "there is a file for TensorFlow weights. Use `from_tf=True` to load this model from those "
-                            "weights."
-                        )
-                    elif has_file(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **has_file_kwargs):
-                        raise EnvironmentError(
-                            f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME} but "
-                            "there is a file for Flax weights. Use `from_flax=True` to load this model from those "
-                            "weights."
-                        )
-                    else:
-                        raise EnvironmentError(
-                            f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}, "
-                            f"{TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
-                        )
+                    try:
+                        # Maybe the checkpoint is sharded, we try to grab the index name in this case.
+                        archive_file = hf_bucket_url(
+                            pretrained_model_name_or_path,
+                            filename=WEIGHTS_INDEX_NAME,
+                            revision=revision,
+                            mirror=mirror,
+                        )
+                        resolved_archive_file = cached_path(
+                            archive_file,
+                            cache_dir=cache_dir,
+                            force_download=force_download,
+                            proxies=proxies,
+                            resume_download=resume_download,
+                            local_files_only=local_files_only,
+                            use_auth_token=use_auth_token,
+                            user_agent=user_agent,
+                        )
+                        is_sharded = True
+                    except EntryNotFoundError:
+                        # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
+                        # message.
+                        has_file_kwargs = {
+                            "revision": revision,
+                            "mirror": mirror,
+                            "proxies": proxies,
+                            "use_auth_token": use_auth_token,
+                        }
+                        if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs):
+                            raise EnvironmentError(
+                                f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME} but "
+                                "there is a file for TensorFlow weights. Use `from_tf=True` to load this model from those "
+                                "weights."
+                            )
+                        elif has_file(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **has_file_kwargs):
+                            raise EnvironmentError(
+                                f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME} but "
+                                "there is a file for Flax weights. Use `from_flax=True` to load this model from those "
+                                "weights."
+                            )
+                        else:
+                            raise EnvironmentError(
+                                f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}, "
+                                f"{TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
+                            )
                 else:
                     raise EnvironmentError(
                         f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
@@ -1439,29 +1772,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         else:
             resolved_archive_file = None

+        # We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
+        if is_sharded:
+            # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
+            resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
+                pretrained_model_name_or_path,
+                resolved_archive_file,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                proxies=proxies,
+                resume_download=resume_download,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                user_agent=user_agent,
+                revision=revision,
+                mirror=mirror,
+            )
+
         # load pt weights early so that we know which dtype to init the model under
         if from_pt:
-            if state_dict is None:
-                try:
-                    state_dict = torch.load(resolved_archive_file, map_location="cpu")
-                except Exception as e:
-                    try:
-                        with open(resolved_archive_file) as f:
-                            if f.read().startswith("version"):
-                                raise OSError(
-                                    "You seem to have cloned a repository without having git-lfs installed. Please install "
-                                    "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
-                                    "you cloned."
-                                )
-                            else:
-                                raise ValueError from e
-                    except (UnicodeDecodeError, ValueError):
-                        raise OSError(
-                            f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' "
-                            f"at '{resolved_archive_file}'. "
-                            "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
-                        )
+            if not is_sharded:
+                # Time to load the checkpoint
+                state_dict = load_state_dict(resolved_archive_file)
             # set dtype to instantiate the model under:
             # 1. If torch_dtype is not None, we use that dtype
             # 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
|
||||
if torch_dtype is not None:
|
||||
if isinstance(torch_dtype, str):
|
||||
if torch_dtype == "auto":
|
||||
torch_dtype = next(iter(state_dict.values())).dtype
|
||||
if is_sharded and "dtype" in sharded_metadata:
|
||||
torch_dtype = sharded_metadata["dtype"]
|
||||
elif not is_sharded:
|
||||
torch_dtype = next(iter(state_dict.values())).dtype
|
||||
else:
|
||||
one_state_dict = load_state_dict(resolved_archive_file)
|
||||
torch_dtype = next(iter(one_state_dict.values())).dtype
|
||||
del one_state_dict # free CPU memory
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`torch_dtype` can be either a `torch.dtype` or `auto`, but received {torch_dtype}"
|
||||
@ -1480,8 +1819,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
# save the keys
|
||||
loaded_state_dict_keys = [k for k in state_dict.keys()]
|
||||
del state_dict # free CPU memory - will reload again later
|
||||
if is_sharded:
|
||||
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
|
||||
else:
|
||||
state_dict = load_state_dict(resolved_archive_file)
|
||||
loaded_state_dict_keys = [k for k in state_dict.keys()]
|
||||
del state_dict # free CPU memory - will reload again later
|
||||
|
||||
config.name_or_path = pretrained_model_name_or_path
|
||||
|
||||
@@ -1534,13 +1877,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         elif from_pt:

             if low_cpu_mem_usage:
-                cls._load_state_dict_into_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file)
+                cls._load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file)
             else:
-                model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_state_dict_into_model(
+                model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
                     model,
                     state_dict,
+                    resolved_archive_file,
                     pretrained_model_name_or_path,
                     ignore_mismatched_sizes=ignore_mismatched_sizes,
+                    sharded_metadata=sharded_metadata,
                     _fast_init=_fast_init,
                 )
@@ -1562,31 +1907,31 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         return model

     @classmethod
-    def _load_state_dict_into_model(
-        cls, model, state_dict, pretrained_model_name_or_path, ignore_mismatched_sizes=False, _fast_init=True
+    def _load_pretrained_model(
+        cls,
+        model,
+        state_dict,
+        resolved_archive_file,
+        pretrained_model_name_or_path,
+        ignore_mismatched_sizes=False,
+        sharded_metadata=None,
+        _fast_init=True,
     ):
-
-        # Convert old format to new format if needed from a PyTorch state_dict
-        old_keys = []
-        new_keys = []
-        for key in state_dict.keys():
-            new_key = None
-            if "gamma" in key:
-                new_key = key.replace("gamma", "weight")
-            if "beta" in key:
-                new_key = key.replace("beta", "bias")
-            if new_key:
-                old_keys.append(key)
-                new_keys.append(new_key)
-        for old_key, new_key in zip(old_keys, new_keys):
-            state_dict[new_key] = state_dict.pop(old_key)
-
         # Retrieve missing & unexpected_keys
         model_state_dict = model.state_dict()
         expected_keys = list(model_state_dict.keys())
-        loaded_keys = list(state_dict.keys())
+        loaded_keys = list(state_dict.keys()) if state_dict is not None else sharded_metadata["all_checkpoint_keys"]
         prefix = model.base_model_prefix

+        def _fix_key(key):
+            if "beta" in key:
+                return key.replace("beta", "bias")
+            if "gamma" in key:
+                return key.replace("gamma", "weight")
+            return key
+
+        loaded_keys = [_fix_key(key) for key in loaded_keys]
+
         if len(prefix) > 0:
             has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
             expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
@@ -1608,28 +1953,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         missing_keys = list(set(expected_keys) - set(loaded_keys))
         unexpected_keys = list(set(loaded_keys) - set(expected_keys))

-        # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
-        # matching the weights in the model.
-        mismatched_keys = []
-        if ignore_mismatched_sizes:
-            for checkpoint_key in loaded_keys:
-                model_key = checkpoint_key
-                if remove_prefix_from_model:
-                    # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
-                    model_key = f"{prefix}.{checkpoint_key}"
-                elif add_prefix_to_model:
-                    # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
-                    model_key = ".".join(checkpoint_key.split(".")[1:])
-
-                if (
-                    model_key in model_state_dict
-                    and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
-                ):
-                    mismatched_keys.append(
-                        (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
-                    )
-                    del state_dict[checkpoint_key]
-
         # Some models may have keys that are not in the state by design, removing them before needlessly warning
         # the user.
         if cls._keys_to_ignore_on_load_missing is not None:
@@ -1648,35 +1971,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             for module in uninitialized_modules:
                 model._init_weights(module)

-        # copy state_dict so _load_from_state_dict can modify it
-        metadata = getattr(state_dict, "_metadata", None)
-        state_dict = state_dict.copy()
-        if metadata is not None:
-            state_dict._metadata = metadata
-
-        error_msgs = []
-
-        # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
-        # so we need to apply the function recursively.
-        def load(module: nn.Module, prefix=""):
-            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
-            args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
-            if is_deepspeed_zero3_enabled():
-                import deepspeed
-
-                # because zero3 puts placeholders in model params, this context
-                # manager gathers (unpartitions) the params of the current layer, then loads from
-                # the state dict and then re-partitions them again
-                with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
-                    if torch.distributed.get_rank() == 0:
-                        module._load_from_state_dict(*args)
-            else:
-                module._load_from_state_dict(*args)
-
-            for name, child in module._modules.items():
-                if child is not None:
-                    load(child, prefix + name + ".")
-
         # Make sure we are able to load base models as well as derived models (with heads)
         start_prefix = ""
         model_to_load = model
@@ -1690,7 +1984,61 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 "properly saved?"
             )

-        load(model_to_load, prefix=start_prefix)
+        if state_dict is not None:
+            # Whole checkpoint
+            mismatched_keys = []
+            if ignore_mismatched_sizes:
+                for checkpoint_key in loaded_keys:
+                    model_key = checkpoint_key
+                    if remove_prefix_from_model:
+                        # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
+                        model_key = f"{prefix}.{checkpoint_key}"
+                    elif add_prefix_to_model:
+                        # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
+                        model_key = ".".join(checkpoint_key.split(".")[1:])
+
+                    if (
+                        model_key in model_state_dict
+                        and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
+                    ):
+                        mismatched_keys.append(
+                            (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+                        )
+                        del state_dict[checkpoint_key]
+
+            error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
+        else:
+            # Sharded checkpoint
+            # This should always be a list but, just to be sure, we check.
+            if not isinstance(resolved_archive_file, list):
+                resolved_archive_file = [resolved_archive_file]
+
+            error_msgs = []
+            for shard_file in resolved_archive_file:
+                state_dict = load_state_dict(shard_file)
+                # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
+                # matching the weights in the model.
+                mismatched_keys = []
+                if ignore_mismatched_sizes:
+                    for checkpoint_key in loaded_keys:
+                        model_key = checkpoint_key
+                        if remove_prefix_from_model:
+                            # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it.
+                            model_key = f"{prefix}.{checkpoint_key}"
+                        elif add_prefix_to_model:
+                            # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it.
+                            model_key = ".".join(checkpoint_key.split(".")[1:])
+
+                        if (
+                            model_key in model_state_dict
+                            and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
+                        ):
+                            mismatched_keys.append(
+                                (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+                            )
+                            del state_dict[checkpoint_key]
+
+                error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)

         if len(error_msgs) > 0:
             error_msg = "\n\t".join(error_msgs)
@@ -1755,7 +2103,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         return retrieved_modules

     @classmethod
-    def _load_state_dict_into_model_low_mem(cls, model, loaded_state_dict_keys, resolved_archive_file):
+    def _load_pretrained_model_low_mem(cls, model, loaded_state_dict_keys, resolved_archive_file):
         """
         This is an experimental function that loads the model using ~1.x model size CPU memory
@@ -1772,7 +2120,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix

        Currently, it doesn't handle missing_keys, unexpected_keys, mismatched_keys. It can't handle deepspeed.
        """

        require_version_core("torch>=1.9")
        if is_deepspeed_zero3_enabled():
            raise ValueError("low_cpu_mem_usage arg cannot be used with DeepSpeed ZeRO-3")
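The low-memory path below leans on `find_submodule_and_param_name`, which is defined elsewhere in modeling_utils and not shown in this diff. Roughly (a sketch of the expected contract, not the actual implementation), it resolves a flat checkpoint key to the owning module and the final attribute name:

```py
def find_submodule_and_param_name_sketch(model, key):
    # "encoder.layer.0.output.dense.weight" -> (dense module, "weight")
    *module_path, param_name = key.split(".")
    submodule = model
    for part in module_path:
        submodule = getattr(submodule, part, None)
        if submodule is None:
            return None, None  # key not present in this model
    return submodule, param_name
```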
@@ -1806,19 +2153,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 new_val = new_val.to("meta")
             setattr(submodule, param_name, new_val)

-        # only now can load state_dict
-        state_dict = torch.load(resolved_archive_file, map_location="cpu")
+        # only now can load state_dict(s)
+        if not isinstance(resolved_archive_file, list):
+            resolved_archive_file = [resolved_archive_file]

-        # materialize state_dict entries one by one on CPU
-        for k in loaded_state_dict_keys:
-            submodule, param_name = find_submodule_and_param_name(model, k)
-            if submodule is not None:
-                new_val = state_dict[k]
-                if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
-                    new_val = torch.nn.Parameter(new_val)
-                setattr(submodule, param_name, new_val)
+        for archive_file in resolved_archive_file:
+            state_dict = torch.load(archive_file, map_location="cpu")

-        del state_dict
+            # materialize state_dict entries one by one on CPU
+            for k in loaded_state_dict_keys:
+                submodule, param_name = find_submodule_and_param_name(model, k)
+                if submodule is not None:
+                    new_val = state_dict[k]
+                    if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
+                        new_val = torch.nn.Parameter(new_val)
+                    setattr(submodule, param_name, new_val)
+
+            del state_dict

     @classmethod
     def register_for_auto_class(cls, auto_class="AutoModel"):
@@ -1846,12 +2197,109 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix

         cls._auto_class = auto_class

+    def push_to_hub(
+        self,
+        repo_path_or_name: Optional[str] = None,
+        repo_url: Optional[str] = None,
+        use_temp_dir: bool = False,
+        commit_message: str = "add model",
+        organization: Optional[str] = None,
+        private: Optional[bool] = None,
+        use_auth_token: Optional[Union[bool, str]] = None,
+        max_shard_size: Union[int, str] = "10GB",
+        **model_card_kwargs
+    ) -> str:
+        """
+        Upload the model files to the 🤗 Model Hub while synchronizing a local clone of the repo in
+        `repo_path_or_name`.
+
+        Parameters:
+            repo_path_or_name (`str`, *optional*):
+                Can either be a repository name for your model in the Hub or a path to a local folder (in which case
+                the repository will have the name of that local folder). If not specified, will default to the name
+                given by `repo_url` and a local directory with that name will be created.
+            repo_url (`str`, *optional*):
+                Specify this in case you want to push to an existing repository in the hub. If unspecified, a new
+                repository will be created in your namespace (unless you specify an `organization`) with `repo_name`.
+            use_temp_dir (`bool`, *optional*, defaults to `False`):
+                Whether or not to clone the distant repo in a temporary directory or in `repo_path_or_name` inside the
+                current working directory. This will slow things down if you are making changes in an existing repo
+                since you will need to clone the repo before every push.
+            commit_message (`str`, *optional*, defaults to `"add model"`):
+                Message to commit while pushing.
+            organization (`str`, *optional*):
+                Organization in which you want to push your {object} (you must be a member of this organization).
+            private (`bool`, *optional*):
+                Whether or not the repository created should be private (requires a paying subscription).
+            use_auth_token (`bool` or `str`, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token
+                generated when running `transformers-cli login` (stored in `~/.huggingface`). Will default to `True`
+                if `repo_url` is not specified.
+            max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
+                The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be of a size
+                lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
+
+                <Tip warning={true}>

+                If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint
+                shard which will be bigger than `max_shard_size`.
+
+                </Tip>
+
+        Returns:
+            `str`: The url of the commit of your {object} in the given repository.
+
+        Examples:
+
+        ```python
+        from transformers import AutoModel
+
+        model = AutoModel.from_pretrained("bert-base-cased")
+
+        # Push the model to your namespace with the name "my-finetuned-bert" and have a local clone in the
+        # *my-finetuned-bert* folder.
+        model.push_to_hub("my-finetuned-bert")
+
+        # Push the model to your namespace with the name "my-finetuned-bert" with no local clone.
+        model.push_to_hub("my-finetuned-bert", use_temp_dir=True)
+
+        # Push the model to an organization with the name "my-finetuned-bert" and have a local clone in the
+        # *my-finetuned-bert* folder.
+        model.push_to_hub("my-finetuned-bert", organization="huggingface")
+
+        # Make a change to an existing repo that has been cloned locally in *my-finetuned-bert*.
+        model.push_to_hub("my-finetuned-bert", repo_url="https://huggingface.co/sgugger/my-finetuned-bert")
+        ```
+        """
+        if use_temp_dir:
+            # Make sure we use the right `repo_name` for the `repo_url` before replacing it.
+            if repo_url is None:
+                if use_auth_token is None:
+                    use_auth_token = True
+                repo_name = Path(repo_path_or_name).name
+                repo_url = self._get_repo_url_from_name(
+                    repo_name, organization=organization, private=private, use_auth_token=use_auth_token
+                )
+            repo_path_or_name = tempfile.mkdtemp()
+
+        # Create or clone the repo. If the repo is already cloned, this just retrieves the path to the repo.
+        repo = self._create_or_get_repo(
+            repo_path_or_name=repo_path_or_name,
+            repo_url=repo_url,
+            organization=organization,
+            private=private,
+            use_auth_token=use_auth_token,
+        )
+        # Save the files in the cloned repo
+        self.save_pretrained(repo_path_or_name, max_shard_size=max_shard_size)
+
+        # Commit and push!
+        url = self._push_to_hub(repo, commit_message=commit_message)
+
+        # Clean up! Clean up! Everybody everywhere!
+        if use_temp_dir:
+            shutil.rmtree(repo_path_or_name)
+
+        return url
+

 # To update the docstring, we need to copy the method, otherwise we change the original docstring.
 PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
 PreTrainedModel.push_to_hub.__doc__ = PreTrainedModel.push_to_hub.__doc__.format(
     object="model", object_class="AutoModel", object_files="model checkpoint"
 )


 class Conv1D(nn.Module):
@@ -136,6 +136,7 @@ from .import_utils import (


 WEIGHTS_NAME = "pytorch_model.bin"
+WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
 TF2_WEIGHTS_NAME = "tf_model.h5"
 TF_WEIGHTS_NAME = "model.ckpt"
 FLAX_WEIGHTS_NAME = "flax_model.msgpack"
@@ -55,7 +55,7 @@ from transformers.testing_utils import (
     slow,
     torch_device,
 )
-from transformers.utils import WEIGHTS_NAME, is_flax_available, is_torch_fx_available
+from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME, is_flax_available, is_torch_fx_available


 sys.path.append(str(Path(__file__).parent.parent / "utils"))
@@ -90,6 +90,7 @@ if is_torch_available():
         T5Config,
         T5ForConditionalGeneration,
     )
+    from transformers.modeling_utils import shard_checkpoint

 if is_flax_available():
     import jax.numpy as jnp
@@ -2352,6 +2353,123 @@ class ModelUtilsTest(TestCasePlus):
         for p1, p2 in zip(model.parameters(), new_model.parameters()):
             self.assertTrue(torch.equal(p1, p2))

+    def test_shard_checkpoint(self):
+        # This is the model we will use, total size 340,000 bytes.
+        model = torch.nn.Sequential(
+            torch.nn.Linear(100, 200, bias=False),  # size 80,000
+            torch.nn.Linear(200, 200, bias=False),  # size 160,000
+            torch.nn.Linear(200, 100, bias=False),  # size 80,000
+            torch.nn.Linear(100, 50, bias=False),  # size 20,000
+        )
+        state_dict = model.state_dict()
+
+        with self.subTest("No shard when max size is bigger than model size"):
+            shards, index = shard_checkpoint(state_dict)
+            self.assertIsNone(index)
+            self.assertDictEqual(shards, {WEIGHTS_NAME: state_dict})
+
+        with self.subTest("Test sharding, no weights bigger than max size"):
+            shards, index = shard_checkpoint(state_dict, max_shard_size="300kB")
+            # Split is first two layers then last two.
+            self.assertDictEqual(
+                index,
+                {
+                    "metadata": {"total_size": 340000},
+                    "weight_map": {
+                        "0.weight": "pytorch_model-00001-of-00002.bin",
+                        "1.weight": "pytorch_model-00001-of-00002.bin",
+                        "2.weight": "pytorch_model-00002-of-00002.bin",
+                        "3.weight": "pytorch_model-00002-of-00002.bin",
+                    },
+                },
+            )
+
+            shard1 = {"0.weight": state_dict["0.weight"], "1.weight": state_dict["1.weight"]}
+            shard2 = {"2.weight": state_dict["2.weight"], "3.weight": state_dict["3.weight"]}
+            self.assertDictEqual(
+                shards, {"pytorch_model-00001-of-00002.bin": shard1, "pytorch_model-00002-of-00002.bin": shard2}
+            )
+
+        with self.subTest("Test sharding with weights bigger than max size"):
+            shards, index = shard_checkpoint(state_dict, max_shard_size="100kB")
+            # Split is first layer, second layer then last 2.
+            self.assertDictEqual(
+                index,
+                {
+                    "metadata": {"total_size": 340000},
+                    "weight_map": {
+                        "0.weight": "pytorch_model-00001-of-00003.bin",
+                        "1.weight": "pytorch_model-00002-of-00003.bin",
+                        "2.weight": "pytorch_model-00003-of-00003.bin",
+                        "3.weight": "pytorch_model-00003-of-00003.bin",
+                    },
+                },
+            )
+
+            shard1 = {"0.weight": state_dict["0.weight"]}
+            shard2 = {"1.weight": state_dict["1.weight"]}
+            shard3 = {"2.weight": state_dict["2.weight"], "3.weight": state_dict["3.weight"]}
+            self.assertDictEqual(
+                shards,
+                {
+                    "pytorch_model-00001-of-00003.bin": shard1,
+                    "pytorch_model-00002-of-00003.bin": shard2,
+                    "pytorch_model-00003-of-00003.bin": shard3,
+                },
+            )
+
+    def test_checkpoint_sharding_local(self):
+        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # We use the same folder for various sizes to make sure a new save erases the old checkpoint.
+            for max_size in ["50kB", "50kiB", "100kB", "100kiB", "200kB", "200kiB"]:
+                model.save_pretrained(tmp_dir, max_shard_size=max_size)
+
+                # Get each shard file and its size
+                shard_to_size = {}
+                for shard in os.listdir(tmp_dir):
+                    if shard.endswith(".bin"):
+                        shard_file = os.path.join(tmp_dir, shard)
+                        shard_to_size[shard_file] = os.path.getsize(shard_file)
+
+                index_file = os.path.join(tmp_dir, WEIGHTS_INDEX_NAME)
+                # Check there is an index but no regular weight file
+                self.assertTrue(os.path.isfile(index_file))
+                self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))
+
+                # Check a file is bigger than max_size only when it has a single weight
+                for shard_file, size in shard_to_size.items():
+                    if max_size.endswith("kiB"):
+                        max_size_int = int(max_size[:-3]) * 2**10
+                    else:
+                        max_size_int = int(max_size[:-2]) * 10**3
+                    # Note: pickle adds some junk so the size of the file can end up being slightly bigger than
+                    # the size asked for (since we count parameters)
+                    if size >= max_size_int + 50000:
+                        state_dict = torch.load(shard_file)
+                        self.assertEqual(len(state_dict), 1)
+
+                # Check the index and the shard files found match
+                with open(index_file, "r", encoding="utf-8") as f:
+                    index = json.loads(f.read())
+
+                all_shards = set(index["weight_map"].values())
+                shards_found = set(f for f in os.listdir(tmp_dir) if f.endswith(".bin"))
+                self.assertSetEqual(all_shards, shards_found)
+
+            # Finally, check the model can be reloaded
+            new_model = BertModel.from_pretrained(tmp_dir)
+            for p1, p2 in zip(model.parameters(), new_model.parameters()):
+                self.assertTrue(torch.allclose(p1, p2))
+
+    def test_checkpoint_sharding_from_hub(self):
+        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
+        # the model above is the same as the model below, just a sharded version.
+        ref_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+        for p1, p2 in zip(model.parameters(), ref_model.parameters()):
+            self.assertTrue(torch.allclose(p1, p2))
+
     def test_cached_files_are_used_when_internet_is_down(self):
         # A mock response for an HTTP head request to emulate server down
         response_mock = mock.Mock()