Poc to use safetensors (#19175)

* Poc to use safetensors
* Typo
* Final version
* Add tests
* Save with the right name!
* Update tests/test_modeling_common.py
* Support for sharded checkpoints
* Test from Hub part 1
* Test from hub part 2
* Fix regular checkpoint sharding
* Bump for fixes

Co-authored-by: Julien Chaumond <julien@huggingface.co>
parent dad578e4c3
commit 3e2dd7f92d

setup.py
@@ -148,6 +148,7 @@ _deps = [
     "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1",
     "sacrebleu>=1.4.12,<2.0.0",
     "sacremoses",
+    "safetensors>=0.2.1",
     "sagemaker>=2.31.0",
     "scikit-learn",
     "sentencepiece>=0.1.91,!=0.1.92",
@@ -300,6 +301,7 @@ extras["testing"] = (
         "protobuf",  # Can be removed once we can unpin protobuf
         "sacremoses",
         "rjieba",
+        "safetensors",
     )
     + extras["retrieval"]
     + extras["modelcreation"]
@@ -478,6 +478,7 @@ _import_structure = {
         "is_psutil_available",
         "is_py3nvml_available",
         "is_pyctcdecode_available",
+        "is_safetensors_available",
         "is_scipy_available",
         "is_sentencepiece_available",
         "is_sklearn_available",
@@ -3417,6 +3418,7 @@ if TYPE_CHECKING:
         is_psutil_available,
         is_py3nvml_available,
         is_pyctcdecode_available,
+        is_safetensors_available,
         is_scipy_available,
         is_sentencepiece_available,
         is_sklearn_available,
@@ -54,6 +54,7 @@ deps = {
     "rouge-score": "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1",
     "sacrebleu": "sacrebleu>=1.4.12,<2.0.0",
     "sacremoses": "sacremoses",
+    "safetensors": "safetensors>=0.2.1",
     "sagemaker": "sagemaker>=2.31.0",
     "scikit-learn": "scikit-learn",
     "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
@@ -50,6 +50,8 @@ from .pytorch_utils import (  # noqa: F401
 from .utils import (
     DUMMY_INPUTS,
     FLAX_WEIGHTS_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
+    SAFE_WEIGHTS_NAME,
     TF2_WEIGHTS_NAME,
     TF_WEIGHTS_NAME,
     WEIGHTS_INDEX_NAME,
@@ -65,6 +67,7 @@ from .utils import (
     is_bitsandbytes_available,
     is_offline_mode,
     is_remote_url,
+    is_safetensors_available,
     logging,
     replace_return_docstrings,
 )
@@ -86,6 +89,10 @@ if is_accelerate_available():
     else:
         get_balanced_memory = None
 
+if is_safetensors_available():
+    from safetensors import safe_open
+    from safetensors.torch import load_file as safe_load_file
+    from safetensors.torch import save_file as safe_save_file
+
 logger = logging.get_logger(__name__)
 
@@ -241,7 +248,9 @@ def dtype_byte_size(dtype):
     return bit_size // 8
 
 
-def shard_checkpoint(state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB"):
+def shard_checkpoint(
+    state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB", weights_name: str = WEIGHTS_NAME
+):
     """
     Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
     given size.
@@ -263,6 +272,8 @@ def shard_checkpoint(state_dict: Dict[str, torch.Tensor], max_shard_size: Union[
         max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
             The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
             (like `"5MB"`).
+        weights_name (`str`, *optional*, defaults to `"pytorch_model.bin"`):
+            The name of the model save file.
     """
     max_shard_size = convert_file_size_to_int(max_shard_size)
 
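To illustrate the extended signature, here is a minimal sketch of calling `shard_checkpoint` directly with a safetensors weights name (default is `"pytorch_model.bin"`); the import path is `transformers.modeling_utils`, and the tensor sizes are illustrative:

import torch
from transformers.modeling_utils import shard_checkpoint

# Two ~640kB fp32 tensors with a 1MB cap force two shards (illustrative values).
state_dict = {"a": torch.zeros(400, 400), "b": torch.zeros(400, 400)}
shards, index = shard_checkpoint(state_dict, max_shard_size="1MB", weights_name="model.safetensors")
print(sorted(shards))  # ['model-00001-of-00002.safetensors', 'model-00002-of-00002.safetensors']
print(index["weight_map"])  # maps each tensor name to its shard file; index is None for a single shard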
@@ -289,13 +300,16 @@ def shard_checkpoint(state_dict: Dict[str, torch.Tensor], max_shard_size: Union[
 
     # If we only have one shard, we return it
     if len(sharded_state_dicts) == 1:
-        return {WEIGHTS_NAME: sharded_state_dicts[0]}, None
+        return {weights_name: sharded_state_dicts[0]}, None
 
     # Otherwise, let's build the index
     weight_map = {}
     shards = {}
     for idx, shard in enumerate(sharded_state_dicts):
-        shard_file = WEIGHTS_NAME.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
+        shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
+        shard_file = shard_file.replace(
+            ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
+        )
         shards[shard_file] = shard
         for key in shard.keys():
             weight_map[key] = shard_file
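The replace-based naming above yields the standard indexed shard names for both formats; a standalone sketch of the same scheme:

def shard_file_name(weights_name: str, idx: int, total: int) -> str:
    # Mirrors the two .replace() calls above; only one suffix ever matches.
    name = weights_name.replace(".bin", f"-{idx + 1:05d}-of-{total:05d}.bin")
    return name.replace(".safetensors", f"-{idx + 1:05d}-of-{total:05d}.safetensors")

assert shard_file_name("pytorch_model.bin", 0, 3) == "pytorch_model-00001-of-00003.bin"
assert shard_file_name("model.safetensors", 2, 3) == "model-00003-of-00003.safetensors"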
@@ -367,6 +381,20 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
     """
     Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
     """
+    if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
+        # Check format of the archive
+        with safe_open(checkpoint_file, framework="pt") as f:
+            metadata = f.metadata()
+        if metadata.get("format") not in ["pt", "tf", "flax"]:
+            raise OSError(
+                f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
+                "you save your model with the `save_pretrained` method."
+            )
+        elif metadata["format"] != "pt":
+            raise NotImplementedError(
+                f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
+            )
+        return safe_load_file(checkpoint_file)
     try:
         return torch.load(checkpoint_file, map_location="cpu")
     except Exception as e:
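For readers unfamiliar with the `safetensors` API used here, a minimal sketch of the same check-then-load flow outside Transformers (the file path is hypothetical):

from safetensors import safe_open
from safetensors.torch import load_file

checkpoint_file = "model.safetensors"  # hypothetical local path
with safe_open(checkpoint_file, framework="pt") as f:
    metadata = f.metadata()  # {"format": "pt"} for files written by save_pretrained
# metadata may be None for files written by other tools, hence the guard.
if metadata and metadata.get("format") == "pt":
    state_dict = load_file(checkpoint_file)  # dict of tensor name -> torch.Tensor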
@@ -1468,6 +1496,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         save_function: Callable = torch.save,
         push_to_hub: bool = False,
         max_shard_size: Union[int, str] = "10GB",
+        safe_serialization: bool = False,
         **kwargs,
     ):
         """
@@ -1503,6 +1532,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
 
             </Tip>
 
+            safe_serialization (`bool`, *optional*, defaults to `False`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+
             kwargs:
                 Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
@@ -1511,6 +1543,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 "`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead."
             )
             is_main_process = kwargs.pop("save_config")
+        if safe_serialization and not is_safetensors_available():
+            raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
 
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
@@ -1560,15 +1594,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 del state_dict[ignore_key]
 
         # Shard the model if it is too big.
-        shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size)
+        weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+        shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)
 
         # Clean the folder from a previous save
         for filename in os.listdir(save_directory):
             full_filename = os.path.join(save_directory, filename)
             # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
             # in distributed settings to avoid race conditions.
+            weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
             if (
-                filename.startswith(WEIGHTS_NAME[:-4])
+                filename.startswith(weights_no_suffix)
                 and os.path.isfile(full_filename)
                 and filename not in shards.keys()
                 and is_main_process
@@ -1577,12 +1613,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
 
         # Save the model
         for shard_file, shard in shards.items():
-            save_function(shard, os.path.join(save_directory, shard_file))
+            if safe_serialization:
+                # At some point we will need to deal better with save_function (used for TPU and other distributed
+                # joyfulness), but for now this is enough.
+                safe_save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})
+            else:
+                save_function(shard, os.path.join(save_directory, shard_file))
 
         if index is None:
             logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
         else:
-            save_index_file = os.path.join(save_directory, WEIGHTS_INDEX_NAME)
+            save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
+            save_index_file = os.path.join(save_directory, save_index_file)
             # Save the index as well
             with open(save_index_file, "w", encoding="utf-8") as f:
                 content = json.dumps(index, indent=2, sort_keys=True) + "\n"
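Putting the save path together, a hedged end-to-end sketch of the new flag; the model id is one of the tiny test fixtures used elsewhere in this PR and the directory name is illustrative:

from transformers import BertModel

model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
model.save_pretrained("saved-bert", safe_serialization=True)
# -> saved-bert/model.safetensors, or model-0000k-of-0000n.safetensors shards
#    plus model.safetensors.index.json when the 10GB default shard size is exceeded
reloaded = BertModel.from_pretrained("saved-bert")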
@@ -1966,6 +2008,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             ):
                 # Load from a Flax checkpoint in priority if from_flax
                 archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
+            elif is_safetensors_available() and os.path.isfile(
+                os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME)
+            ):
+                # Load from a safetensors checkpoint
+                archive_file = os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME)
+            elif is_safetensors_available() and os.path.isfile(
+                os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_INDEX_NAME)
+            ):
+                # Load from a sharded safetensors checkpoint
+                archive_file = os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_INDEX_NAME)
+                is_sharded = True
             elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
                 # Load from a PyTorch checkpoint
                 archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
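The branch order above gives local files a clear priority; a condensed sketch of the same ordering for the excerpt shown (the helper name `pick_local_weights` is hypothetical, the constants come from `transformers.utils`):

import os
from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from transformers.utils import is_safetensors_available

def pick_local_weights(folder: str, from_flax: bool = False):
    # Flax first when requested, then safetensors (single, then sharded), then PyTorch .bin.
    candidates = [FLAX_WEIGHTS_NAME] if from_flax else []
    if is_safetensors_available():
        candidates += [SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME]
    candidates.append(WEIGHTS_NAME)
    for name in candidates:
        if os.path.isfile(os.path.join(folder, name)):
            return name
    return None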
@@ -2013,6 +2066,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 filename = TF2_WEIGHTS_NAME
             elif from_flax:
                 filename = FLAX_WEIGHTS_NAME
+            elif is_safetensors_available():
+                filename = SAFE_WEIGHTS_NAME
             else:
                 filename = WEIGHTS_NAME
 
||||
@ -2033,8 +2088,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
)
|
||||
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
|
||||
|
||||
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an expection but a None
|
||||
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
|
||||
# result when internet is up, the repo and revision exist, but the file does not.
|
||||
if resolved_archive_file is None and filename == SAFE_WEIGHTS_NAME:
|
||||
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
||||
resolved_archive_file = cached_file(
|
||||
pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **cached_file_kwargs
|
||||
)
|
||||
if resolved_archive_file is not None:
|
||||
is_sharded = True
|
||||
else:
|
||||
# This repo has no safetensors file of any kind, we switch to PyTorch.
|
||||
filename = WEIGHTS_NAME
|
||||
resolved_archive_file = cached_file(
|
||||
pretrained_model_name_or_path, WEIGHTS_NAME, **cached_file_kwargs
|
||||
)
|
||||
if resolved_archive_file is None and filename == WEIGHTS_NAME:
|
||||
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
|
||||
resolved_archive_file = cached_file(
|
||||
|
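On the Hub side, the fallback chain tries the safetensors names before the PyTorch ones; a sketch of the resolution order (`try_get` stands in for `cached_file` with `_raise_exceptions_for_missing_entries=False`, which returns `None` for a missing file):

def resolve_hub_weights(try_get):
    # Returns (filename, is_sharded) for the first name that exists in the repo.
    for name, sharded in [
        ("model.safetensors", False),
        ("model.safetensors.index.json", True),
        ("pytorch_model.bin", False),
        ("pytorch_model.bin.index.json", True),
    ]:
        if try_get(name) is not None:
            return name, sharded
    return None, False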
@@ -60,6 +60,7 @@ from .utils import (
     is_pytesseract_available,
     is_pytorch_quantization_available,
     is_rjieba_available,
+    is_safetensors_available,
     is_scatter_available,
     is_scipy_available,
     is_sentencepiece_available,
@@ -264,6 +265,13 @@ def require_accelerate(test_case):
     return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
 
 
+def require_safetensors(test_case):
+    """
+    Decorator marking a test that requires safetensors. These tests are skipped when safetensors isn't installed.
+    """
+    return unittest.skipUnless(is_safetensors_available(), "test requires safetensors")(test_case)
+
+
 def require_rjieba(test_case):
     """
     Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed.
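A minimal usage sketch of the new decorator (the test class and body are illustrative):

import unittest
from transformers.testing_utils import require_safetensors

class ExampleTest(unittest.TestCase):
    @require_safetensors
    def test_needs_safetensors(self):
        import safetensors  # safe: the decorator skips this test when the package is absent
        self.assertTrue(hasattr(safetensors, "safe_open"))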
@@ -111,6 +111,7 @@ from .import_utils import (
     is_pytorch_quantization_available,
     is_rjieba_available,
     is_sacremoses_available,
+    is_safetensors_available,
     is_sagemaker_dp_enabled,
     is_sagemaker_mp_enabled,
     is_scatter_available,
@@ -156,6 +157,8 @@ TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
 TF_WEIGHTS_NAME = "model.ckpt"
 FLAX_WEIGHTS_NAME = "flax_model.msgpack"
 FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json"
+SAFE_WEIGHTS_NAME = "model.safetensors"
+SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
 CONFIG_NAME = "config.json"
 FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
 MODEL_CARD_NAME = "modelcard.json"
@@ -533,6 +533,10 @@ def is_accelerate_available():
     return importlib.util.find_spec("accelerate") is not None
 
 
+def is_safetensors_available():
+    return importlib.util.find_spec("safetensors") is not None
+
+
 def is_tokenizers_available():
     return importlib.util.find_spec("tokenizers") is not None
 
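The availability check follows the same `find_spec` pattern as the other `is_*_available` helpers; shown standalone (the generic helper name is illustrative):

import importlib.util

def is_package_available(name: str) -> bool:
    # True when the package can be found, without actually importing it.
    return importlib.util.find_spec(name) is not None

assert isinstance(is_package_available("safetensors"), bool)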
@@ -53,6 +53,7 @@ from transformers.testing_utils import (
     is_pt_tf_cross_test,
     is_staging_test,
     require_accelerate,
+    require_safetensors,
     require_torch,
     require_torch_gpu,
     require_torch_multi_gpu,
@@ -61,6 +62,8 @@ from transformers.testing_utils import (
     torch_device,
 )
 from transformers.utils import (
+    SAFE_WEIGHTS_INDEX_NAME,
+    SAFE_WEIGHTS_NAME,
     WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
     is_accelerate_available,
@@ -2980,6 +2983,57 @@ class ModelUtilsTest(TestCasePlus):
             "https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/pytorch_model.bin", config=config
         )
 
+    @require_safetensors
+    def test_safetensors_save_and_load(self):
+        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir, safe_serialization=True)
+            # No pytorch_model.bin file, only a model.safetensors
+            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
+            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))
+
+            new_model = BertModel.from_pretrained(tmp_dir)
+
+            # Check models are equal
+            for p1, p2 in zip(model.parameters(), new_model.parameters()):
+                self.assertTrue(torch.allclose(p1, p2))
+
+    @require_safetensors
+    def test_safetensors_load_from_hub(self):
+        safetensors_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors")
+        pytorch_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+
+        # Check models are equal
+        for p1, p2 in zip(safetensors_model.parameters(), pytorch_model.parameters()):
+            self.assertTrue(torch.allclose(p1, p2))
+
+    @require_safetensors
+    def test_safetensors_save_and_load_sharded(self):
+        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="100kB")
+            # No pytorch_model.bin index file, only a model.safetensors index
+            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_INDEX_NAME)))
+            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
+            # No regular weights file
+            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))
+            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
+
+            new_model = BertModel.from_pretrained(tmp_dir)
+
+            # Check models are equal
+            for p1, p2 in zip(model.parameters(), new_model.parameters()):
+                self.assertTrue(torch.allclose(p1, p2))
+
+    @require_safetensors
+    def test_safetensors_load_from_hub_sharded(self):
+        safetensors_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded-safetensors")
+        pytorch_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
+
+        # Check models are equal
+        for p1, p2 in zip(safetensors_model.parameters(), pytorch_model.parameters()):
+            self.assertTrue(torch.allclose(p1, p2))
+
 
 @require_torch
 @is_staging_test