Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 18:51:14 +06:00)
Use huggingface_hub helper function to split state dict (#31091)
* shard saving from hf hub
* index = None
* fix tests
* indent
Parent: 1c73d85b86
Commit: 254b25abd9
--- a/setup.py
+++ b/setup.py
@@ -117,7 +117,7 @@ _deps = [
     "fugashi>=1.0",
     "GitPython<3.1.19",
     "hf-doc-builder>=0.3.0",
-    "huggingface-hub>=0.23.0,<1.0",
+    "huggingface-hub>=0.23.2,<1.0",
     "importlib_metadata",
     "ipadic>=1.0.0,<2.0",
     "isort>=5.5.4",
--- a/src/transformers/dependency_versions_table.py
+++ b/src/transformers/dependency_versions_table.py
@@ -24,7 +24,7 @@ deps = {
     "fugashi": "fugashi>=1.0",
     "GitPython": "GitPython<3.1.19",
     "hf-doc-builder": "hf-doc-builder>=0.3.0",
-    "huggingface-hub": "huggingface-hub>=0.23.0,<1.0",
+    "huggingface-hub": "huggingface-hub>=0.23.2,<1.0",
     "importlib_metadata": "importlib_metadata",
     "ipadic": "ipadic>=1.0.0,<2.0",
     "isort": "isort>=5.5.4",
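The pin is changed twice because src/transformers/dependency_versions_table.py mirrors setup.py and backs the runtime dependency checks, so both must move together. A minimal illustrative check (the `deps` table is real; the assertion is only a sketch of what the bump guarantees):

import sys  # stdlib only; the table module imports cleanly on its own

from transformers.dependency_versions_table import deps

# Both pins now require the hub release that ships the new helper.
assert deps["huggingface-hub"] == "huggingface-hub>=0.23.2,<1.0"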
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -34,6 +34,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 from zipfile import is_zipfile
 
 import torch
+from huggingface_hub import split_torch_state_dict_into_shards
 from packaging import version
 from torch import Tensor, nn
 from torch.nn import CrossEntropyLoss, Identity
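This new import is the reason for the huggingface-hub>=0.23.2 pin above. A minimal standalone sketch of the helper (the tensor shapes and the 1MB limit are invented here to force sharding):

import torch
from huggingface_hub import split_torch_state_dict_into_shards

# Two ~2MB fp32 tensors with a 1MB shard limit, so each lands in its own shard.
state_dict = {
    "embed.weight": torch.zeros(1024, 512),
    "lm_head.weight": torch.zeros(1024, 512),
}
split = split_torch_state_dict_into_shards(
    state_dict, filename_pattern="pytorch_model{suffix}.bin", max_shard_size="1MB"
)
print(split.is_sharded)  # True
print(sorted(split.filename_to_tensors))
# ['pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin']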
@@ -362,6 +363,10 @@ def shard_checkpoint(
         weights_name (`str`, *optional*, defaults to `"pytorch_model.bin"`):
             The name of the model save file.
     """
+    logger.warning(
+        "Note that `shard_checkpoint` is deprecated and will be removed in v4.44. We recommend you using "
+        "split_torch_state_dict_into_shards from huggingface_hub library"
+    )
     max_shard_size = convert_file_size_to_int(max_shard_size)
 
     sharded_state_dicts = [{}]
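For callers migrating off the deprecated `shard_checkpoint`, here is a hedged sketch of an equivalent built on the hub helper. `split_like_shard_checkpoint` is a hypothetical name: the old function returned fully materialized shard dicts plus an index, while the helper only returns a filename-to-tensor-names plan.

from huggingface_hub import split_torch_state_dict_into_shards

def split_like_shard_checkpoint(state_dict, max_shard_size="10GB"):
    # Plan the split (no tensors are copied at this point).
    split = split_torch_state_dict_into_shards(
        state_dict, filename_pattern="pytorch_model{suffix}.bin", max_shard_size=max_shard_size
    )
    # Materialize shard dicts the way shard_checkpoint used to return them.
    shards = {
        filename: {name: state_dict[name] for name in names}
        for filename, names in split.filename_to_tensors.items()
    }
    # Like shard_checkpoint, keep index=None for a single-file save.
    index = (
        {"metadata": split.metadata, "weight_map": split.tensor_to_filename}
        if split.is_sharded
        else None
    )
    return shards, index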
@@ -2618,7 +2623,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             else:
                 weights_name = ADAPTER_SAFE_WEIGHTS_NAME if safe_serialization else ADAPTER_WEIGHTS_NAME
 
-        shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)
+        filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
+        state_dict_split = split_torch_state_dict_into_shards(
+            state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
+        )
+        # Save index if sharded
+        index = None
+        if state_dict_split.is_sharded:
+            index = {
+                "metadata": state_dict_split.metadata,
+                "weight_map": state_dict_split.tensor_to_filename,
+            }
 
         # Clean the folder from a previous save
         for filename in os.listdir(save_directory):
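The new `filename_pattern` line simply injects a `{suffix}` placeholder before the extension of either weights format, and the helper fills it in. A small sketch with the two standard file names ("pytorch_model.bin" and "model.safetensors" are the real constants' values):

def to_pattern(weights_name: str) -> str:
    # Same transformation as the new line in save_pretrained above.
    return weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")

assert to_pattern("pytorch_model.bin") == "pytorch_model{suffix}.bin"
assert to_pattern("model.safetensors") == "model{suffix}.safetensors"
# The helper substitutes "" for a single-file save, or "-00001-of-00003"-style
# suffixes when the state dict is sharded.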
@@ -2634,14 +2649,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             if (
                 filename.startswith(weights_no_suffix)
                 and os.path.isfile(full_filename)
-                and filename not in shards.keys()
+                and filename not in state_dict_split.filename_to_tensors.keys()
                 and is_main_process
                 and reg.fullmatch(filename_no_suffix) is not None
             ):
                 os.remove(full_filename)
 
         # Save the model
-        for shard_file, shard in shards.items():
+        for shard_file, tensors in state_dict_split.filename_to_tensors.items():
+            shard = {tensor: state_dict[tensor] for tensor in tensors}
             # remake shard with onloaded parameters if necessary
             if module_map:
                 if accelerate_version < version.parse("0.31"):
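Because `filename_to_tensors` maps file names to tensor *names*, the loop now gathers one shard dict at a time rather than holding every shard in memory at once. A self-contained sketch of that pattern (sizes invented so the split produces several shards):

import os
import tempfile

import torch
from huggingface_hub import split_torch_state_dict_into_shards

state_dict = {f"layer{i}.weight": torch.zeros(256, 256) for i in range(8)}
split = split_torch_state_dict_into_shards(
    state_dict, filename_pattern="pytorch_model{suffix}.bin", max_shard_size="1MB"
)

save_directory = tempfile.mkdtemp()
for shard_file, tensors in split.filename_to_tensors.items():
    # Only one shard's tensors are materialized in memory at a time.
    shard = {name: state_dict[name] for name in tensors}
    torch.save(shard, os.path.join(save_directory, shard_file))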
@@ -2680,7 +2696,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                 f.write(content)
             logger.info(
                 f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
-                f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
+                f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
                 f"index located at {save_index_file}."
             )
 
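For context on the unchanged `f.write(content)` line: when the split is sharded, the index dict built earlier is presumably serialized to a `*.index.json` file right here. A sketch of that step (the `index` contents and file name below are illustrative, not taken from the commit):

import json

index = {
    "metadata": {"total_size": 4194304},
    "weight_map": {"embed.weight": "pytorch_model-00001-of-00002.bin"},
}
save_index_file = "pytorch_model.bin.index.json"

content = json.dumps(index, indent=2, sort_keys=True) + "\n"
with open(save_index_file, "w", encoding="utf-8") as f:
    f.write(content)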
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -669,7 +669,7 @@ class ModelUtilsTest(TestCasePlus):
 
         with tempfile.TemporaryDirectory() as tmp_dir:
             # We use the same folder for various sizes to make sure a new save erases the old checkpoint.
-            for max_size in ["50kB", "50kiB", "100kB", "100kiB", "200kB", "200kiB"]:
+            for max_size in ["50kB", "100kB", "200kB"]:
                 model.save_pretrained(tmp_dir, max_shard_size=max_size, safe_serialization=False)
 
                 # Get each shard file and its size
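The binary-unit sizes ("50kiB" etc.) are dropped here, presumably because the hub helper's size parser accepts decimal units only. A plain-Python sketch of the parsing the test now relies on (not huggingface_hub's actual implementation):

def parse_decimal_size(size: str) -> int:
    # Decimal units only; a binary size like "50kiB" would be rejected here.
    units = {"KB": 10**3, "MB": 10**6, "GB": 10**9, "TB": 10**12}
    unit = size[-2:].upper()
    if unit not in units:
        raise ValueError(f"unsupported unit in {size!r}")
    return int(size[:-2]) * units[unit]

assert parse_decimal_size("50kB") == 50_000
assert parse_decimal_size("200kB") == 200_000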
@@ -686,10 +686,7 @@ class ModelUtilsTest(TestCasePlus):
 
             # Check a file is bigger than max_size only when it has a single weight
             for shard_file, size in shard_to_size.items():
-                if max_size.endswith("kiB"):
-                    max_size_int = int(max_size[:-3]) * 2**10
-                else:
-                    max_size_int = int(max_size[:-2]) * 10**3
+                max_size_int = int(max_size[:-2]) * 10**3
                 # Note: pickle adds some junk so the weight of the file can end up being slightly bigger than
                 # the size asked for (since we count parameters)
                 if size >= max_size_int + 50000:
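With only decimal sizes left, the bound in this assertion reduces to straightforward arithmetic; a worked example for the smallest case:

max_size = "50kB"
max_size_int = int(max_size[:-2]) * 10**3  # 50_000 bytes
# The test tolerates 50_000 extra bytes of pickle overhead before flagging a shard.
assert max_size_int + 50000 == 100_000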