Poc to use safetensors (#19175)

* Poc to use safetensors

* Typo

* Final version

* Add tests

* Save with the right name!

* Update tests/test_modeling_common.py

Co-authored-by: Julien Chaumond <julien@huggingface.co>

* Support for sharded checkpoints

* Test from Hub part 1

* Test from hub part 2

* Fix regular checkpoint sharding

* Bump for fixes

Co-authored-by: Julien Chaumond <julien@huggingface.co>
This commit is contained in:
Sylvain Gugger 2022-09-30 10:58:04 -04:00 committed by GitHub
parent dad578e4c3
commit 3e2dd7f92d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 150 additions and 8 deletions

View File

@ -148,6 +148,7 @@ _deps = [
"rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1", "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1",
"sacrebleu>=1.4.12,<2.0.0", "sacrebleu>=1.4.12,<2.0.0",
"sacremoses", "sacremoses",
"safetensors>=0.2.1",
"sagemaker>=2.31.0", "sagemaker>=2.31.0",
"scikit-learn", "scikit-learn",
"sentencepiece>=0.1.91,!=0.1.92", "sentencepiece>=0.1.91,!=0.1.92",
@ -300,6 +301,7 @@ extras["testing"] = (
"protobuf", # Can be removed once we can unpin protobuf "protobuf", # Can be removed once we can unpin protobuf
"sacremoses", "sacremoses",
"rjieba", "rjieba",
"safetensors",
) )
+ extras["retrieval"] + extras["retrieval"]
+ extras["modelcreation"] + extras["modelcreation"]

View File

@ -478,6 +478,7 @@ _import_structure = {
"is_psutil_available", "is_psutil_available",
"is_py3nvml_available", "is_py3nvml_available",
"is_pyctcdecode_available", "is_pyctcdecode_available",
"is_safetensors_available",
"is_scipy_available", "is_scipy_available",
"is_sentencepiece_available", "is_sentencepiece_available",
"is_sklearn_available", "is_sklearn_available",
@ -3417,6 +3418,7 @@ if TYPE_CHECKING:
is_psutil_available, is_psutil_available,
is_py3nvml_available, is_py3nvml_available,
is_pyctcdecode_available, is_pyctcdecode_available,
is_safetensors_available,
is_scipy_available, is_scipy_available,
is_sentencepiece_available, is_sentencepiece_available,
is_sklearn_available, is_sklearn_available,

View File

@ -54,6 +54,7 @@ deps = {
"rouge-score": "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1", "rouge-score": "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1",
"sacrebleu": "sacrebleu>=1.4.12,<2.0.0", "sacrebleu": "sacrebleu>=1.4.12,<2.0.0",
"sacremoses": "sacremoses", "sacremoses": "sacremoses",
"safetensors": "safetensors>=0.2.1",
"sagemaker": "sagemaker>=2.31.0", "sagemaker": "sagemaker>=2.31.0",
"scikit-learn": "scikit-learn", "scikit-learn": "scikit-learn",
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92", "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",

View File

@ -50,6 +50,8 @@ from .pytorch_utils import ( # noqa: F401
from .utils import ( from .utils import (
DUMMY_INPUTS, DUMMY_INPUTS,
FLAX_WEIGHTS_NAME, FLAX_WEIGHTS_NAME,
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
TF2_WEIGHTS_NAME, TF2_WEIGHTS_NAME,
TF_WEIGHTS_NAME, TF_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME,
@ -65,6 +67,7 @@ from .utils import (
is_bitsandbytes_available, is_bitsandbytes_available,
is_offline_mode, is_offline_mode,
is_remote_url, is_remote_url,
is_safetensors_available,
logging, logging,
replace_return_docstrings, replace_return_docstrings,
) )
@ -86,6 +89,10 @@ if is_accelerate_available():
else: else:
get_balanced_memory = None get_balanced_memory = None
if is_safetensors_available():
from safetensors import safe_open
from safetensors.torch import load_file as safe_load_file
from safetensors.torch import save_file as safe_save_file
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -241,7 +248,9 @@ def dtype_byte_size(dtype):
return bit_size // 8 return bit_size // 8
def shard_checkpoint(state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB"): def shard_checkpoint(
state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB", weights_name: str = WEIGHTS_NAME
):
""" """
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size. given size.
@ -263,6 +272,8 @@ def shard_checkpoint(state_dict: Dict[str, torch.Tensor], max_shard_size: Union[
max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`): max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
(like `"5MB"`). (like `"5MB"`).
weights_name (`str`, *optional*, defaults to `"pytorch_model.bin"`):
The name of the model save file.
""" """
max_shard_size = convert_file_size_to_int(max_shard_size) max_shard_size = convert_file_size_to_int(max_shard_size)
@ -289,13 +300,16 @@ def shard_checkpoint(state_dict: Dict[str, torch.Tensor], max_shard_size: Union[
# If we only have one shard, we return it # If we only have one shard, we return it
if len(sharded_state_dicts) == 1: if len(sharded_state_dicts) == 1:
return {WEIGHTS_NAME: sharded_state_dicts[0]}, None return {weights_name: sharded_state_dicts[0]}, None
# Otherwise, let's build the index # Otherwise, let's build the index
weight_map = {} weight_map = {}
shards = {} shards = {}
for idx, shard in enumerate(sharded_state_dicts): for idx, shard in enumerate(sharded_state_dicts):
shard_file = WEIGHTS_NAME.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin") shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
shard_file = shard_file.replace(
".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
)
shards[shard_file] = shard shards[shard_file] = shard
for key in shard.keys(): for key in shard.keys():
weight_map[key] = shard_file weight_map[key] = shard_file
@ -367,6 +381,20 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
""" """
Reads a PyTorch checkpoint file, returning properly formatted errors if they arise. Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
""" """
if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
# Check format of the archive
with safe_open(checkpoint_file, framework="pt") as f:
metadata = f.metadata()
if metadata.get("format") not in ["pt", "tf", "flax"]:
raise OSError(
f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
"you save your model with the `save_pretrained` method."
)
elif metadata["format"] != "pt":
raise NotImplementedError(
f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
)
return safe_load_file(checkpoint_file)
try: try:
return torch.load(checkpoint_file, map_location="cpu") return torch.load(checkpoint_file, map_location="cpu")
except Exception as e: except Exception as e:
@ -1468,6 +1496,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
save_function: Callable = torch.save, save_function: Callable = torch.save,
push_to_hub: bool = False, push_to_hub: bool = False,
max_shard_size: Union[int, str] = "10GB", max_shard_size: Union[int, str] = "10GB",
safe_serialization: bool = False,
**kwargs, **kwargs,
): ):
""" """
@ -1503,6 +1532,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
</Tip> </Tip>
safe_serialization (`bool`, *optional*, defaults to `False`):
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
kwargs: kwargs:
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
""" """
@ -1511,6 +1543,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead." "`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead."
) )
is_main_process = kwargs.pop("save_config") is_main_process = kwargs.pop("save_config")
if safe_serialization and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
if os.path.isfile(save_directory): if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file") logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
@ -1560,15 +1594,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
del state_dict[ignore_key] del state_dict[ignore_key]
# Shard the model if it is too big. # Shard the model if it is too big.
shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size) weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)
# Clean the folder from a previous save # Clean the folder from a previous save
for filename in os.listdir(save_directory): for filename in os.listdir(save_directory):
full_filename = os.path.join(save_directory, filename) full_filename = os.path.join(save_directory, filename)
# If we have a shard file that is not going to be replaced, we delete it, but only from the main process # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
# in distributed settings to avoid race conditions. # in distributed settings to avoid race conditions.
weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
if ( if (
filename.startswith(WEIGHTS_NAME[:-4]) filename.startswith(weights_no_suffix)
and os.path.isfile(full_filename) and os.path.isfile(full_filename)
and filename not in shards.keys() and filename not in shards.keys()
and is_main_process and is_main_process
@ -1577,12 +1613,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Save the model # Save the model
for shard_file, shard in shards.items(): for shard_file, shard in shards.items():
if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this enough.
safe_save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})
else:
save_function(shard, os.path.join(save_directory, shard_file)) save_function(shard, os.path.join(save_directory, shard_file))
if index is None: if index is None:
logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}") logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
else: else:
save_index_file = os.path.join(save_directory, WEIGHTS_INDEX_NAME) save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
save_index_file = os.path.join(save_directory, save_index_file)
# Save the index as well # Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f: with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n" content = json.dumps(index, indent=2, sort_keys=True) + "\n"
@ -1966,6 +2008,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
): ):
# Load from a Flax checkpoint in priority if from_flax # Load from a Flax checkpoint in priority if from_flax
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
elif is_safetensors_available() and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME)
):
# Load from a safetensors checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME)
elif is_safetensors_available() and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_INDEX_NAME)
):
# Load from a sharded safetensors checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_INDEX_NAME)
is_sharded = True
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)): elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint # Load from a PyTorch checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME) archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
@ -2013,6 +2066,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
filename = TF2_WEIGHTS_NAME filename = TF2_WEIGHTS_NAME
elif from_flax: elif from_flax:
filename = FLAX_WEIGHTS_NAME filename = FLAX_WEIGHTS_NAME
elif is_safetensors_available():
filename = SAFE_WEIGHTS_NAME
else: else:
filename = WEIGHTS_NAME filename = WEIGHTS_NAME
@ -2033,8 +2088,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
) )
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an expection but a None # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
# result when internet is up, the repo and revision exist, but the file does not. # result when internet is up, the repo and revision exist, but the file does not.
if resolved_archive_file is None and filename == SAFE_WEIGHTS_NAME:
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
resolved_archive_file = cached_file(
pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **cached_file_kwargs
)
if resolved_archive_file is not None:
is_sharded = True
else:
# This repo has no safetensors file of any kind, we switch to PyTorch.
filename = WEIGHTS_NAME
resolved_archive_file = cached_file(
pretrained_model_name_or_path, WEIGHTS_NAME, **cached_file_kwargs
)
if resolved_archive_file is None and filename == WEIGHTS_NAME: if resolved_archive_file is None and filename == WEIGHTS_NAME:
# Maybe the checkpoint is sharded, we try to grab the index name in this case. # Maybe the checkpoint is sharded, we try to grab the index name in this case.
resolved_archive_file = cached_file( resolved_archive_file = cached_file(

View File

@ -60,6 +60,7 @@ from .utils import (
is_pytesseract_available, is_pytesseract_available,
is_pytorch_quantization_available, is_pytorch_quantization_available,
is_rjieba_available, is_rjieba_available,
is_safetensors_available,
is_scatter_available, is_scatter_available,
is_scipy_available, is_scipy_available,
is_sentencepiece_available, is_sentencepiece_available,
@ -264,6 +265,13 @@ def require_accelerate(test_case):
return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case) return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
def require_safetensors(test_case):
"""
Decorator marking a test that requires safetensors. These tests are skipped when safetensors isn't installed.
"""
return unittest.skipUnless(is_safetensors_available(), "test requires safetensors")(test_case)
def require_rjieba(test_case): def require_rjieba(test_case):
""" """
Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed. Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed.

View File

@ -111,6 +111,7 @@ from .import_utils import (
is_pytorch_quantization_available, is_pytorch_quantization_available,
is_rjieba_available, is_rjieba_available,
is_sacremoses_available, is_sacremoses_available,
is_safetensors_available,
is_sagemaker_dp_enabled, is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled, is_sagemaker_mp_enabled,
is_scatter_available, is_scatter_available,
@ -156,6 +157,8 @@ TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
TF_WEIGHTS_NAME = "model.ckpt" TF_WEIGHTS_NAME = "model.ckpt"
FLAX_WEIGHTS_NAME = "flax_model.msgpack" FLAX_WEIGHTS_NAME = "flax_model.msgpack"
FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json" FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json"
SAFE_WEIGHTS_NAME = "model.safetensors"
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
CONFIG_NAME = "config.json" CONFIG_NAME = "config.json"
FEATURE_EXTRACTOR_NAME = "preprocessor_config.json" FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
MODEL_CARD_NAME = "modelcard.json" MODEL_CARD_NAME = "modelcard.json"

View File

@ -533,6 +533,10 @@ def is_accelerate_available():
return importlib.util.find_spec("accelerate") is not None return importlib.util.find_spec("accelerate") is not None
def is_safetensors_available():
return importlib.util.find_spec("safetensors") is not None
def is_tokenizers_available(): def is_tokenizers_available():
return importlib.util.find_spec("tokenizers") is not None return importlib.util.find_spec("tokenizers") is not None

View File

@ -53,6 +53,7 @@ from transformers.testing_utils import (
is_pt_tf_cross_test, is_pt_tf_cross_test,
is_staging_test, is_staging_test,
require_accelerate, require_accelerate,
require_safetensors,
require_torch, require_torch,
require_torch_gpu, require_torch_gpu,
require_torch_multi_gpu, require_torch_multi_gpu,
@ -61,6 +62,8 @@ from transformers.testing_utils import (
torch_device, torch_device,
) )
from transformers.utils import ( from transformers.utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME,
WEIGHTS_NAME, WEIGHTS_NAME,
is_accelerate_available, is_accelerate_available,
@ -2980,6 +2983,57 @@ class ModelUtilsTest(TestCasePlus):
"https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/pytorch_model.bin", config=config "https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/pytorch_model.bin", config=config
) )
@require_safetensors
def test_safetensors_save_and_load(self):
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, safe_serialization=True)
# No pytorch_model.bin file, only a model.safetensors
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))
new_model = BertModel.from_pretrained(tmp_dir)
# Check models are equal
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))
@require_safetensors
def test_safetensors_load_from_hub(self):
safetensors_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors")
pytorch_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# Check models are equal
for p1, p2 in zip(safetensors_model.parameters(), pytorch_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))
@require_safetensors
def test_safetensors_save_and_load_sharded(self):
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="100kB")
# No pytorch_model.bin index file, only a model.safetensors index
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_INDEX_NAME)))
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
# No regular weights file
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
new_model = BertModel.from_pretrained(tmp_dir)
# Check models are equal
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))
@require_safetensors
def test_safetensors_load_from_hub_sharded(self):
safetensors_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded-safetensors")
pytorch_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
# Check models are equal
for p1, p2 in zip(safetensors_model.parameters(), pytorch_model.parameters()):
self.assertTrue(torch.allclose(p1, p2))
@require_torch @require_torch
@is_staging_test @is_staging_test