Mirror of https://github.com/huggingface/transformers.git
Poc to use safetensors (#19175)
* Poc to use safetensors
* Typo
* Final version
* Add tests
* Save with the right name!
* Update tests/test_modeling_common.py
* Support for sharded checkpoints
* Test from Hub part 1
* Test from hub part 2
* Fix regular checkpoint sharding
* Bump for fixes

Co-authored-by: Julien Chaumond <julien@huggingface.co>
parent dad578e4c3
commit 3e2dd7f92d
setup.py
@@ -148,6 +148,7 @@ _deps = [
     "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1",
     "sacrebleu>=1.4.12,<2.0.0",
     "sacremoses",
+    "safetensors>=0.2.1",
     "sagemaker>=2.31.0",
     "scikit-learn",
     "sentencepiece>=0.1.91,!=0.1.92",
@@ -300,6 +301,7 @@ extras["testing"] = (
         "protobuf",  # Can be removed once we can unpin protobuf
         "sacremoses",
         "rjieba",
+        "safetensors",
     )
     + extras["retrieval"]
     + extras["modelcreation"]
src/transformers/__init__.py
@@ -478,6 +478,7 @@ _import_structure = {
         "is_psutil_available",
         "is_py3nvml_available",
         "is_pyctcdecode_available",
+        "is_safetensors_available",
         "is_scipy_available",
         "is_sentencepiece_available",
         "is_sklearn_available",
@@ -3417,6 +3418,7 @@ if TYPE_CHECKING:
         is_psutil_available,
         is_py3nvml_available,
         is_pyctcdecode_available,
+        is_safetensors_available,
         is_scipy_available,
         is_sentencepiece_available,
         is_sklearn_available,
src/transformers/dependency_versions_table.py
@@ -54,6 +54,7 @@ deps = {
     "rouge-score": "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1",
     "sacrebleu": "sacrebleu>=1.4.12,<2.0.0",
     "sacremoses": "sacremoses",
+    "safetensors": "safetensors>=0.2.1",
     "sagemaker": "sagemaker>=2.31.0",
     "scikit-learn": "scikit-learn",
     "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
src/transformers/modeling_utils.py
@@ -50,6 +50,8 @@ from .pytorch_utils import ( # noqa: F401
 from .utils import (
     DUMMY_INPUTS,
     FLAX_WEIGHTS_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
+    SAFE_WEIGHTS_NAME,
     TF2_WEIGHTS_NAME,
     TF_WEIGHTS_NAME,
     WEIGHTS_INDEX_NAME,
@@ -65,6 +67,7 @@ from .utils import (
     is_bitsandbytes_available,
     is_offline_mode,
     is_remote_url,
+    is_safetensors_available,
     logging,
     replace_return_docstrings,
 )
@@ -86,6 +89,10 @@ if is_accelerate_available():
 else:
     get_balanced_memory = None

+if is_safetensors_available():
+    from safetensors import safe_open
+    from safetensors.torch import load_file as safe_load_file
+    from safetensors.torch import save_file as safe_save_file

 logger = logging.get_logger(__name__)

@@ -241,7 +248,9 @@ def dtype_byte_size(dtype):
     return bit_size // 8


-def shard_checkpoint(state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB"):
+def shard_checkpoint(
+    state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB", weights_name: str = WEIGHTS_NAME
+):
     """
     Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
     given size.
@@ -263,6 +272,8 @@ def shard_checkpoint(state_dict: Dict[str, torch.Tensor], max_shard_size: Union[
         max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
             The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
             (like `"5MB"`).
+        weights_name (`str`, *optional*, defaults to `"pytorch_model.bin"`):
+            The name of the model save file.
     """
     max_shard_size = convert_file_size_to_int(max_shard_size)

@@ -289,13 +300,16 @@ def shard_checkpoint(state_dict: Dict[str, torch.Tensor], max_shard_size: Union[

     # If we only have one shard, we return it
     if len(sharded_state_dicts) == 1:
-        return {WEIGHTS_NAME: sharded_state_dicts[0]}, None
+        return {weights_name: sharded_state_dicts[0]}, None

     # Otherwise, let's build the index
     weight_map = {}
     shards = {}
     for idx, shard in enumerate(sharded_state_dicts):
-        shard_file = WEIGHTS_NAME.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
+        shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
+        shard_file = shard_file.replace(
+            ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
+        )
         shards[shard_file] = shard
         for key in shard.keys():
             weight_map[key] = shard_file
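An editorial aside, not part of the diff: a standalone sketch of the shard-naming scheme the hunk above implements. `shard_file_name` is a hypothetical helper; since a weights name carries only one of the two suffixes, exactly one of the chained `replace` calls fires.

# Hypothetical helper mirroring the naming logic in shard_checkpoint above.
def shard_file_name(weights_name: str, idx: int, total: int) -> str:
    name = weights_name.replace(".bin", f"-{idx + 1:05d}-of-{total:05d}.bin")
    name = name.replace(".safetensors", f"-{idx + 1:05d}-of-{total:05d}.safetensors")
    return name

# Only the suffix matching the serialization format is rewritten.
assert shard_file_name("pytorch_model.bin", 0, 2) == "pytorch_model-00001-of-00002.bin"
assert shard_file_name("model.safetensors", 1, 2) == "model-00002-of-00002.safetensors"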
@@ -367,6 +381,20 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
     """
     Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
     """
+    if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
+        # Check format of the archive
+        with safe_open(checkpoint_file, framework="pt") as f:
+            metadata = f.metadata()
+        if metadata.get("format") not in ["pt", "tf", "flax"]:
+            raise OSError(
+                f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
+                "you save your model with the `save_pretrained` method."
+            )
+        elif metadata["format"] != "pt":
+            raise NotImplementedError(
+                f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
+            )
+        return safe_load_file(checkpoint_file)
     try:
         return torch.load(checkpoint_file, map_location="cpu")
     except Exception as e:
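Not part of the diff: a rough sketch of how an archive carrying the `format` metadata that `load_state_dict` checks for can be written and inspected with the safetensors API (the file name is illustrative).

import torch
from safetensors import safe_open
from safetensors.torch import save_file

# Save a tiny state dict with the "format" metadata the loader expects.
state_dict = {"weight": torch.ones(2, 2)}
save_file(state_dict, "model.safetensors", metadata={"format": "pt"})

# Inspect the metadata without loading any tensors.
with safe_open("model.safetensors", framework="pt") as f:
    print(f.metadata())  # {'format': 'pt'}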
@@ -1468,6 +1496,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         save_function: Callable = torch.save,
         push_to_hub: bool = False,
         max_shard_size: Union[int, str] = "10GB",
+        safe_serialization: bool = False,
         **kwargs,
     ):
         """
@@ -1503,6 +1532,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix

             </Tip>

+            safe_serialization (`bool`, *optional*, defaults to `False`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+
             kwargs:
                 Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
@@ -1511,6 +1543,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 "`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead."
             )
             is_main_process = kwargs.pop("save_config")
+        if safe_serialization and not is_safetensors_available():
+            raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")

         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
@@ -1560,15 +1594,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                     del state_dict[ignore_key]

         # Shard the model if it is too big.
-        shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size)
+        weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+        shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)

         # Clean the folder from a previous save
         for filename in os.listdir(save_directory):
             full_filename = os.path.join(save_directory, filename)
             # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
             # in distributed settings to avoid race conditions.
+            weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
             if (
-                filename.startswith(WEIGHTS_NAME[:-4])
+                filename.startswith(weights_no_suffix)
                 and os.path.isfile(full_filename)
                 and filename not in shards.keys()
                 and is_main_process
@@ -1577,12 +1613,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix

         # Save the model
         for shard_file, shard in shards.items():
-            save_function(shard, os.path.join(save_directory, shard_file))
+            if safe_serialization:
+                # At some point we will need to deal better with save_function (used for TPU and other distributed
+                # joyfulness), but for now this enough.
+                safe_save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})
+            else:
+                save_function(shard, os.path.join(save_directory, shard_file))

         if index is None:
             logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
         else:
-            save_index_file = os.path.join(save_directory, WEIGHTS_INDEX_NAME)
+            save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
+            save_index_file = os.path.join(save_directory, save_index_file)
             # Save the index as well
             with open(save_index_file, "w", encoding="utf-8") as f:
                 content = json.dumps(index, indent=2, sort_keys=True) + "\n"
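Taken together, the save-path changes above are exercised like this (a sketch using the tiny test model from the tests further down; local paths are illustrative):

from transformers import BertModel

model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

# Single-file save: writes model.safetensors instead of pytorch_model.bin.
model.save_pretrained("./saved", safe_serialization=True)

# Sharded save: writes model-0000i-of-0000n.safetensors shards plus model.safetensors.index.json.
model.save_pretrained("./saved-sharded", safe_serialization=True, max_shard_size="100kB")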
@@ -1966,6 +2008,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             ):
                 # Load from a Flax checkpoint in priority if from_flax
                 archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
+            elif is_safetensors_available() and os.path.isfile(
+                os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME)
+            ):
+                # Load from a safetensors checkpoint
+                archive_file = os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME)
+            elif is_safetensors_available() and os.path.isfile(
+                os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_INDEX_NAME)
+            ):
+                # Load from a sharded safetensors checkpoint
+                archive_file = os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_INDEX_NAME)
+                is_sharded = True
             elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
                 # Load from a PyTorch checkpoint
                 archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
@@ -2013,6 +2066,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 filename = TF2_WEIGHTS_NAME
             elif from_flax:
                 filename = FLAX_WEIGHTS_NAME
+            elif is_safetensors_available():
+                filename = SAFE_WEIGHTS_NAME
             else:
                 filename = WEIGHTS_NAME

@@ -2033,8 +2088,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 )
                 resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)

-                # Since we set _raise_exceptions_for_missing_entries=False, we don't get an expection but a None
+                # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
                 # result when internet is up, the repo and revision exist, but the file does not.
+                if resolved_archive_file is None and filename == SAFE_WEIGHTS_NAME:
+                    # Maybe the checkpoint is sharded, we try to grab the index name in this case.
+                    resolved_archive_file = cached_file(
+                        pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **cached_file_kwargs
+                    )
+                    if resolved_archive_file is not None:
+                        is_sharded = True
+                    else:
+                        # This repo has no safetensors file of any kind, we switch to PyTorch.
+                        filename = WEIGHTS_NAME
+                        resolved_archive_file = cached_file(
+                            pretrained_model_name_or_path, WEIGHTS_NAME, **cached_file_kwargs
+                        )
                 if resolved_archive_file is None and filename == WEIGHTS_NAME:
                     # Maybe the checkpoint is sharded, we try to grab the index name in this case.
                     resolved_archive_file = cached_file(
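To summarize the resolution order the hunks above establish, here is an illustrative sketch (`pick_checkpoint` is hypothetical, not an actual helper in the codebase):

# Illustrative pseudocode only; the real logic lives in from_pretrained above.
def pick_checkpoint(available_files, safetensors_installed):
    order = []
    if safetensors_installed:
        order += ["model.safetensors", "model.safetensors.index.json"]
    order += ["pytorch_model.bin", "pytorch_model.bin.index.json"]
    for name in order:
        if name in available_files:
            return name  # index files additionally flag the checkpoint as sharded
    raise FileNotFoundError("no usable checkpoint found")

# A repo that only ships pytorch_model.bin still loads fine with safetensors installed.
assert pick_checkpoint({"pytorch_model.bin"}, safetensors_installed=True) == "pytorch_model.bin"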
src/transformers/testing_utils.py
@@ -60,6 +60,7 @@ from .utils import (
     is_pytesseract_available,
     is_pytorch_quantization_available,
     is_rjieba_available,
+    is_safetensors_available,
     is_scatter_available,
     is_scipy_available,
     is_sentencepiece_available,
@@ -264,6 +265,13 @@ def require_accelerate(test_case):
     return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)


+def require_safetensors(test_case):
+    """
+    Decorator marking a test that requires safetensors. These tests are skipped when safetensors isn't installed.
+    """
+    return unittest.skipUnless(is_safetensors_available(), "test requires safetensors")(test_case)
+
+
 def require_rjieba(test_case):
     """
     Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed.
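A hypothetical test using the new decorator (class and test names are illustrative):

import unittest

from transformers.testing_utils import require_safetensors


class ExampleSafetensorsTest(unittest.TestCase):
    @require_safetensors
    def test_needs_safetensors(self):
        # Runs only when safetensors is installed; skipped otherwise.
        from safetensors.torch import save_file  # noqa: F401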
src/transformers/utils/__init__.py
@@ -111,6 +111,7 @@ from .import_utils import (
     is_pytorch_quantization_available,
     is_rjieba_available,
     is_sacremoses_available,
+    is_safetensors_available,
     is_sagemaker_dp_enabled,
     is_sagemaker_mp_enabled,
     is_scatter_available,
@@ -156,6 +157,8 @@ TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
 TF_WEIGHTS_NAME = "model.ckpt"
 FLAX_WEIGHTS_NAME = "flax_model.msgpack"
 FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json"
+SAFE_WEIGHTS_NAME = "model.safetensors"
+SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
 CONFIG_NAME = "config.json"
 FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
 MODEL_CARD_NAME = "modelcard.json"
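For context on how the two new constants work together: `SAFE_WEIGHTS_INDEX_NAME` points at a JSON index mapping each weight to its shard, mirroring the existing `pytorch_model.bin.index.json` layout. A sketch of its shape as a Python dict (key names and sizes are illustrative):

# Illustrative contents only; real keys and sizes depend on the model.
index = {
    "metadata": {"total_size": 28000000},  # total bytes across all shards
    "weight_map": {
        "embeddings.word_embeddings.weight": "model-00001-of-00002.safetensors",
        "encoder.layer.0.attention.self.query.weight": "model-00002-of-00002.safetensors",
    },
}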
|
@ -533,6 +533,10 @@ def is_accelerate_available():
|
|||||||
return importlib.util.find_spec("accelerate") is not None
|
return importlib.util.find_spec("accelerate") is not None
|
||||||
|
|
||||||
|
|
||||||
|
def is_safetensors_available():
|
||||||
|
return importlib.util.find_spec("safetensors") is not None
|
||||||
|
|
||||||
|
|
||||||
def is_tokenizers_available():
|
def is_tokenizers_available():
|
||||||
return importlib.util.find_spec("tokenizers") is not None
|
return importlib.util.find_spec("tokenizers") is not None
|
||||||
|
|
||||||
|
tests/test_modeling_common.py
@@ -53,6 +53,7 @@ from transformers.testing_utils import (
     is_pt_tf_cross_test,
     is_staging_test,
     require_accelerate,
+    require_safetensors,
     require_torch,
     require_torch_gpu,
     require_torch_multi_gpu,
@@ -61,6 +62,8 @@ from transformers.testing_utils import (
     torch_device,
 )
 from transformers.utils import (
+    SAFE_WEIGHTS_INDEX_NAME,
+    SAFE_WEIGHTS_NAME,
     WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
     is_accelerate_available,
@@ -2980,6 +2983,57 @@ class ModelUtilsTest(TestCasePlus):
             "https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/pytorch_model.bin", config=config
         )

+    @require_safetensors
+    def test_safetensors_save_and_load(self):
+        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir, safe_serialization=True)
+            # No pytorch_model.bin file, only a model.safetensors
+            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
+            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))
+
+            new_model = BertModel.from_pretrained(tmp_dir)
+
+            # Check models are equal
+            for p1, p2 in zip(model.parameters(), new_model.parameters()):
+                self.assertTrue(torch.allclose(p1, p2))
+
+    @require_safetensors
+    def test_safetensors_load_from_hub(self):
+        safetensors_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors")
+        pytorch_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+
+        # Check models are equal
+        for p1, p2 in zip(safetensors_model.parameters(), pytorch_model.parameters()):
+            self.assertTrue(torch.allclose(p1, p2))
+
+    @require_safetensors
+    def test_safetensors_save_and_load_sharded(self):
+        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="100kB")
+            # No pytorch_model.bin index file, only a model.safetensors index
+            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_INDEX_NAME)))
+            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
+            # No regular weights file
+            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))
+            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
+
+            new_model = BertModel.from_pretrained(tmp_dir)
+
+            # Check models are equal
+            for p1, p2 in zip(model.parameters(), new_model.parameters()):
+                self.assertTrue(torch.allclose(p1, p2))
+
+    @require_safetensors
+    def test_safetensors_load_from_hub_sharded(self):
+        safetensors_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded-safetensors")
+        pytorch_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
+
+        # Check models are equal
+        for p1, p2 in zip(safetensors_model.parameters(), pytorch_model.parameters()):
+            self.assertTrue(torch.allclose(p1, p2))
+

 @require_torch
 @is_staging_test