Support sharded safetensors in TF (#29350)

* Initial commit (still lots of unfinished bits)

* (Still untested) add safetensors sharding to save_pretrained

* Fix safetensors saving, update default shard size to match PT

* Add proper loading of TF-format safetensors

* Revert default size in case that changes things

* Fix incorrect index name

* Update loading priority

* Update tests

* Make the tests a little more stringent

* Expand tests

* Add sharded cross-test

* Fix argument name

* One more test fix

* Adding mlx to the list of allowed formats

* Remove irrelevant block for safetensors

* Refactor warning logging into a separate function

* Remove unused skip_logger_warnings arg

* Update src/transformers/modeling_tf_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Move function def

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
3 changed files with 324 additions and 86 deletions
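As a quick illustration of the feature (mirroring the new tests added below), a TF model saved with safe_serialization=True and a small max_shard_size now produces a sharded safetensors checkpoint that from_pretrained can load back:

import os
import tempfile

import numpy as np
from transformers import TFBertModel
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

with tempfile.TemporaryDirectory() as tmp_dir:
    # A shard size far smaller than the model forces sharding.
    model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="150kB")

    # The directory now holds model-XXXXX-of-YYYYY.safetensors shards plus an index file.
    assert os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))

    # Reloading resolves the index and loads every shard.
    reloaded = TFBertModel.from_pretrained(tmp_dir)
    for p1, p2 in zip(model.weights, reloaded.weights):
        assert np.allclose(p1.numpy(), p2.numpy())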

src/transformers/modeling_tf_pytorch_utils.py

@@ -21,10 +21,24 @@ import re
import numpy
from .utils import ExplicitEnum, expand_dims, is_numpy_array, is_torch_tensor, logging, reshape, squeeze, tensor_size
from .utils import (
ExplicitEnum,
expand_dims,
is_numpy_array,
is_safetensors_available,
is_torch_tensor,
logging,
reshape,
squeeze,
tensor_size,
)
from .utils import transpose as transpose_func
if is_safetensors_available():
from safetensors import safe_open
logger = logging.get_logger(__name__)
@@ -247,6 +261,47 @@ def load_pytorch_weights_in_tf2_model(
)
def _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name):
if len(unexpected_keys) > 0:
logger.warning(
"Some weights of the PyTorch model were not used when initializing the TF 2.0 model"
f" {class_name}: {unexpected_keys}\n- This IS expected if you are initializing"
f" {class_name} from a PyTorch model trained on another task or with another architecture"
" (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n- This IS"
f" NOT expected if you are initializing {class_name} from a PyTorch model that you expect"
" to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a"
" BertForSequenceClassification model)."
)
else:
logger.warning(f"All PyTorch model weights were used when initializing {class_name}.\n")
if len(missing_keys) > 0:
logger.warning(
f"Some weights or buffers of the TF 2.0 model {class_name} were not initialized from the"
f" PyTorch model and are newly initialized: {missing_keys}\nYou should probably TRAIN this model on a"
" down-stream task to be able to use it for predictions and inference."
)
else:
logger.warning(
f"All the weights of {class_name} were initialized from the PyTorch model.\n"
"If your task is similar to the task the model of the checkpoint was trained on, "
f"you can already use {class_name} for predictions without further training."
)
if len(mismatched_keys) > 0:
mismatched_warning = "\n".join(
[
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
for key, shape1, shape2 in mismatched_keys
]
)
logger.warning(
f"Some weights of {class_name} were not initialized from the model checkpoint"
f" are newly initialized because the shapes did not"
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
" to use it for predictions and inference."
)
def load_pytorch_state_dict_in_tf2_model(
tf_model,
pt_state_dict,
@@ -256,6 +311,7 @@ def load_pytorch_state_dict_in_tf2_model(
_prefix=None,
tf_to_pt_weight_rename=None,
ignore_mismatched_sizes=False,
skip_logger_warnings=False,
):
"""Load a pytorch state_dict in a TF 2.0 model. pt_state_dict can be either an actual dict or a lazy-loading
safetensors archive created with the safe_open() function."""
@@ -373,45 +429,53 @@ def load_pytorch_state_dict_in_tf2_model(
if tf_model._keys_to_ignore_on_load_unexpected is not None:
for pat in tf_model._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if not skip_logger_warnings:
_log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name=tf_model.__class__.__name__)
if len(unexpected_keys) > 0:
logger.warning(
"Some weights of the PyTorch model were not used when initializing the TF 2.0 model"
f" {tf_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
f" {tf_model.__class__.__name__} from a PyTorch model trained on another task or with another architecture"
" (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n- This IS"
f" NOT expected if you are initializing {tf_model.__class__.__name__} from a PyTorch model that you expect"
" to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a"
" BertForSequenceClassification model)."
)
else:
logger.warning(f"All PyTorch model weights were used when initializing {tf_model.__class__.__name__}.\n")
if len(missing_keys) > 0:
logger.warning(
f"Some weights or buffers of the TF 2.0 model {tf_model.__class__.__name__} were not initialized from the"
f" PyTorch model and are newly initialized: {missing_keys}\nYou should probably TRAIN this model on a"
" down-stream task to be able to use it for predictions and inference."
)
else:
logger.warning(
f"All the weights of {tf_model.__class__.__name__} were initialized from the PyTorch model.\n"
"If your task is similar to the task the model of the checkpoint was trained on, "
f"you can already use {tf_model.__class__.__name__} for predictions without further training."
)
if output_loading_info:
loading_info = {
"missing_keys": missing_keys,
"unexpected_keys": unexpected_keys,
"mismatched_keys": mismatched_keys,
}
return tf_model, loading_info
if len(mismatched_keys) > 0:
mismatched_warning = "\n".join(
[
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
for key, shape1, shape2 in mismatched_keys
]
)
logger.warning(
f"Some weights of {tf_model.__class__.__name__} were not initialized from the model checkpoint"
f" are newly initialized because the shapes did not"
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
" to use it for predictions and inference."
)
return tf_model
def load_sharded_pytorch_safetensors_in_tf2_model(
tf_model,
safetensors_shards,
tf_inputs=None,
allow_missing_keys=False,
output_loading_info=False,
_prefix=None,
tf_to_pt_weight_rename=None,
ignore_mismatched_sizes=False,
):
all_loading_infos = []
for shard in safetensors_shards:
with safe_open(shard, framework="tf") as safetensors_archive:
tf_model, loading_info = load_pytorch_state_dict_in_tf2_model(
tf_model,
safetensors_archive,
tf_inputs=tf_inputs,
allow_missing_keys=allow_missing_keys,
output_loading_info=True,
_prefix=_prefix,
tf_to_pt_weight_rename=tf_to_pt_weight_rename,
ignore_mismatched_sizes=ignore_mismatched_sizes,
skip_logger_warnings=True, # We will emit merged warnings at the end
)
all_loading_infos.append(loading_info)
# Now we just need to merge the loading info
# Keys are missing only if they're missing in *every* shard
missing_keys = sorted(set.intersection(*[set(info["missing_keys"]) for info in all_loading_infos]))
# Keys are unexpected/mismatched if they're unexpected/mismatched in *any* shard
unexpected_keys = sum([info["unexpected_keys"] for info in all_loading_infos], [])
mismatched_keys = sum([info["mismatched_keys"] for info in all_loading_infos], [])
_log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name=tf_model.__class__.__name__)
if output_loading_info:
loading_info = {
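For clarity, the merge rule used above can be sketched in isolation (illustrative key names): a key is missing only if no shard provided it, while unexpected and mismatched keys accumulate across shards.

# Standalone sketch of the shard-merging rule above, with made-up loading infos.
shard_infos = [
    {"missing_keys": ["a", "b"], "unexpected_keys": ["x"], "mismatched_keys": []},
    {"missing_keys": ["b", "c"], "unexpected_keys": [], "mismatched_keys": [("y", (2, 2), (4, 4))]},
]

# Missing only if missing in *every* shard.
missing_keys = sorted(set.intersection(*[set(info["missing_keys"]) for info in shard_infos]))
# Unexpected/mismatched if reported by *any* shard.
unexpected_keys = sum([info["unexpected_keys"] for info in shard_infos], [])
mismatched_keys = sum([info["mismatched_keys"] for info in shard_infos], [])

assert missing_keys == ["b"]
assert unexpected_keys == ["x"]
assert len(mismatched_keys) == 1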

src/transformers/modeling_tf_utils.py

@@ -647,7 +647,7 @@ def strip_model_name_and_prefix(name, _prefix=None):
return name
def tf_shard_checkpoint(weights, max_shard_size="10GB"):
def tf_shard_checkpoint(weights, max_shard_size="10GB", weights_name: str = TF2_WEIGHTS_NAME):
"""
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size.
@@ -695,13 +695,16 @@ def tf_shard_checkpoint(weights, max_shard_size="10GB"):
# If we only have one shard, we return it
if len(sharded_state_dicts) == 1:
return {TF2_WEIGHTS_NAME: sharded_state_dicts[0]}, None
return {weights_name: sharded_state_dicts[0]}, None
# Otherwise, let's build the index
weight_map = {}
shards = {}
for idx, shard in enumerate(sharded_state_dicts):
shard_file = TF2_WEIGHTS_NAME.replace(".h5", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.h5")
shard_file = weights_name.replace(".h5", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.h5")
shard_file = shard_file.replace(
".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
)
shards[shard_file] = shard
for weight in shard:
weight_name = weight.name
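With the new weights_name parameter, shard names follow the usual -NNNNN-of-NNNNN pattern for both formats. A quick sketch of the naming, using the default weight file names tf_model.h5 and model.safetensors:

# Illustrative shard-name derivation, mirroring the replace() calls above.
num_shards = 3
for weights_name in ("tf_model.h5", "model.safetensors"):
    for idx in range(num_shards):
        shard_file = weights_name.replace(".h5", f"-{idx + 1:05d}-of-{num_shards:05d}.h5")
        shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}-of-{num_shards:05d}.safetensors")
        print(shard_file)  # e.g. tf_model-00001-of-00003.h5, model-00001-of-00003.safetensors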
@@ -782,7 +785,8 @@ def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, s
def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
"""
Loads a shard from a sharded checkpoint file. Handles the missing keys and unexpected keys.
Loads a shard from a sharded checkpoint file. Can be either H5 or Safetensors.
Handles missing keys and unexpected keys.
Args:
model (`keras.models.Model`): Model in which the weights are loaded
@@ -868,6 +872,61 @@ def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatch
)
def load_tf_sharded_weights_from_safetensors(
model, shard_files, ignore_mismatched_sizes=False, strict=False, _prefix=None
):
"""
This is the same as `load_tf_weights_from_safetensors` but for a sharded TF-format safetensors checkpoint.
Detect missing and unexpected layers and load the TF weights from the shard files according to their names and
shapes.
This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
loaded in the model.
Args:
model (`keras.models.Model`): The model in which to load the checkpoint.
shard_files (`str` or `os.PathLike`): A list containing the sharded checkpoint names.
ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
Whether or not to ignore size mismatches between the checkpoint and the model weights.
strict (`bool`, *optional*, defaults to `False`):
Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
Returns:
Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the
mismatched layers.
"""
# Load the index
unexpected_keys = set()
all_missing_keys = []
mismatched_keys = set()
for shard_file in shard_files:
missing_layers, unexpected_layers, mismatched_layers = load_tf_weights_from_safetensors(
model,
shard_file,
ignore_mismatched_sizes=ignore_mismatched_sizes,
_prefix=_prefix,
)
all_missing_keys.append(set(missing_layers))
unexpected_keys.update(unexpected_layers)
mismatched_keys.update(mismatched_layers)
gc.collect()
missing_keys = set.intersection(*all_missing_keys)
if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
if len(missing_keys) > 0:
str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
error_message += f"\nMissing key(s): {str_missing_keys}."
if len(unexpected_keys) > 0:
str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
error_message += f"\nMissing key(s): {str_unexpected_keys}."
raise RuntimeError(error_message)
return missing_keys, unexpected_keys, mismatched_keys
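A minimal sketch of how the shard file list is typically obtained before calling this helper, assuming a local directory produced by save_pretrained (the directory path is hypothetical; the index filename corresponds to SAFE_WEIGHTS_INDEX_NAME):

import json
import os

checkpoint_dir = "path/to/checkpoint"  # hypothetical local directory
index_file = os.path.join(checkpoint_dir, "model.safetensors.index.json")
with open(index_file, "r", encoding="utf-8") as f:
    index = json.load(f)

# The index maps each weight name to the shard file that stores it.
shard_files = sorted({os.path.join(checkpoint_dir, fname) for fname in index["weight_map"].values()})
# missing, unexpected, mismatched = load_tf_sharded_weights_from_safetensors(model, shard_files)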
def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
"""
Detect missing and unexpected layers and load the TF weights from the shard file according to their names and
@@ -2303,7 +2362,7 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
version=1,
push_to_hub=False,
signatures=None,
max_shard_size: Union[int, str] = "10GB",
max_shard_size: Union[int, str] = "5GB",
create_pr: bool = False,
safe_serialization: bool = False,
token: Optional[Union[str, bool]] = None,
@@ -2415,7 +2474,7 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
weights_name = SAFE_WEIGHTS_NAME if safe_serialization else TF2_WEIGHTS_NAME
output_model_file = os.path.join(save_directory, weights_name)
shards, index = tf_shard_checkpoint(self.weights, max_shard_size)
shards, index = tf_shard_checkpoint(self.weights, max_shard_size, weights_name=weights_name)
# Clean the folder from a previous save
for filename in os.listdir(save_directory):
@@ -2438,7 +2497,8 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
self.save_weights(output_model_file)
logger.info(f"Model weights saved in {output_model_file}")
else:
save_index_file = os.path.join(save_directory, TF2_WEIGHTS_INDEX_NAME)
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else TF2_WEIGHTS_INDEX_NAME
save_index_file = os.path.join(save_directory, save_index_file)
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as index_file:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
@@ -2449,19 +2509,25 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
f"index located at {save_index_file}."
)
for shard_file, shard in shards.items():
with h5py.File(os.path.join(save_directory, shard_file), mode="w") as shard_file:
layers = []
for layer in sorted(shard, key=lambda x: x.name):
if "model." in layer.name or len(layer.name.split("/")) == 1:
layer_name = layer.name
else:
layer_name = "/".join(layer.name.split("/")[1:])
param_dset = shard_file.create_dataset(
layer_name, layer.numpy().shape, dtype=layer.numpy().dtype
)
param_dset[:] = layer.numpy()
layers.append(layer_name.encode("utf8"))
save_attributes_to_hdf5_group(shard_file, "layer_names", layers)
if safe_serialization:
shard_state_dict = {strip_model_name_and_prefix(w.name): w.value() for w in shard}
safe_save_file(
shard_state_dict, os.path.join(save_directory, shard_file), metadata={"format": "tf"}
)
else:
with h5py.File(os.path.join(save_directory, shard_file), mode="w") as shard_file:
layers = []
for layer in sorted(shard, key=lambda x: x.name):
if "model." in layer.name or len(layer.name.split("/")) == 1:
layer_name = layer.name
else:
layer_name = "/".join(layer.name.split("/")[1:])
param_dset = shard_file.create_dataset(
layer_name, layer.numpy().shape, dtype=layer.numpy().dtype
)
param_dset[:] = layer.numpy()
layers.append(layer_name.encode("utf8"))
save_attributes_to_hdf5_group(shard_file, "layer_names", layers)
if push_to_hub:
self._upload_modified_files(
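The new safetensors branch above delegates the per-shard write to safe_save_file with metadata={"format": "tf"}; presumably this is the save_file helper from safetensors' TensorFlow backend, which in isolation looks roughly like this (tensor name and file name are illustrative):

import tensorflow as tf
from safetensors.tensorflow import save_file  # assumed to be what safe_save_file aliases

# One shard's state dict: stripped weight names mapped to tf.Tensor values (illustrative).
shard_state_dict = {"bert/embeddings/word_embeddings/weight": tf.zeros((8, 4))}
save_file(shard_state_dict, "model-00001-of-00002.safetensors", metadata={"format": "tf"})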
@@ -2698,6 +2764,12 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
):
# Load from a safetensors checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
elif use_safetensors is not False and os.path.isfile(
os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
):
# Load from a sharded safetensors checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
is_sharded = True
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
# Load from a TF 2.0 checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
@@ -2705,17 +2777,11 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
# Load from a sharded TF 2.0 checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)
is_sharded = True
elif use_safetensors is not False and os.path.isfile(
os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
):
# Load from a sharded safetensors checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
is_sharded = True
raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!")
# At this stage we don't have a weight file so we will raise an error.
elif use_safetensors:
raise EnvironmentError(
f"Error no file named {SAFE_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}. "
f"Error no file named {SAFE_WEIGHTS_NAME} or {SAFE_WEIGHTS_INDEX_NAME} found in directory {pretrained_model_name_or_path}. "
f"Please make sure that the model has been saved with `safe_serialization=True` or do not "
f"set `use_safetensors=True`."
)
@@ -2723,13 +2789,13 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
):
raise EnvironmentError(
f"Error no file named {TF2_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
f"Error no file named {TF2_WEIGHTS_NAME} or {SAFE_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
"but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
"weights."
)
else:
raise EnvironmentError(
f"Error no file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
f"Error no file named {TF2_WEIGHTS_NAME}, {SAFE_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
f"{pretrained_model_name_or_path}."
)
elif os.path.isfile(pretrained_model_name_or_path):
@@ -2801,9 +2867,6 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
}
if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs):
is_sharded = True
raise NotImplementedError(
"Support for sharded checkpoints using safetensors is coming soon!"
)
elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named"
@@ -2841,7 +2904,7 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
# We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
if is_sharded:
# resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
resolved_archive_file, _ = get_checkpoint_shard_files(
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
pretrained_model_name_or_path,
resolved_archive_file,
cache_dir=cache_dir,
@@ -2859,7 +2922,16 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
if filename == SAFE_WEIGHTS_NAME:
with safe_open(resolved_archive_file, framework="tf") as f:
safetensors_metadata = f.metadata()
if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax"]:
if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
raise OSError(
f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata."
" Make sure you save your model with the `save_pretrained` method."
)
safetensors_from_pt = safetensors_metadata.get("format") == "pt"
elif filename == SAFE_WEIGHTS_INDEX_NAME:
with safe_open(resolved_archive_file[0], framework="tf") as f:
safetensors_metadata = f.metadata()
if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
raise OSError(
f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata."
" Make sure you save your model with the `save_pretrained` method."
@@ -2902,11 +2974,11 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
else:
model.build_in_name_scope() # build the network with dummy inputs
if safetensors_from_pt:
if safetensors_from_pt and not is_sharded:
from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model
with safe_open(resolved_archive_file, framework="tf") as safetensors_archive:
# Load from a PyTorch checkpoint
# Load from a PyTorch safetensors checkpoint
# We load in TF format here because PT weights often need to be transposed, and this is much
# faster on GPU. Loading as numpy and transposing on CPU adds several seconds to load times.
return load_pytorch_state_dict_in_tf2_model(
@@ -2919,6 +2991,19 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
ignore_mismatched_sizes=ignore_mismatched_sizes,
tf_to_pt_weight_rename=tf_to_pt_weight_rename,
)
elif safetensors_from_pt:
from .modeling_tf_pytorch_utils import load_sharded_pytorch_safetensors_in_tf2_model
return load_sharded_pytorch_safetensors_in_tf2_model(
model,
resolved_archive_file,
tf_inputs=False,
allow_missing_keys=True,
output_loading_info=output_loading_info,
_prefix=load_weight_prefix,
ignore_mismatched_sizes=ignore_mismatched_sizes,
tf_to_pt_weight_rename=tf_to_pt_weight_rename,
)
# 'by_name' allow us to do transfer learning by skipping/adding layers
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
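The comment above is the key design point: opening the archive with framework="tf" hands back tf.Tensor objects lazily, so any PT-to-TF transposes can run on the accelerator instead of on CPU numpy arrays. A minimal sketch of that access pattern (file name illustrative):

import tensorflow as tf
from safetensors import safe_open

with safe_open("model.safetensors", framework="tf") as archive:
    for key in archive.keys():
        tensor = archive.get_tensor(key)  # loaded on demand as a tf.Tensor
        assert isinstance(tensor, tf.Tensor)
        break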
@@ -2926,14 +3011,22 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
if is_sharded:
for file in resolved_archive_file:
assert os.path.isfile(file), f"Error retrieving files {file}"
missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights(
model,
resolved_archive_file,
ignore_mismatched_sizes=ignore_mismatched_sizes,
_prefix=load_weight_prefix,
)
if filename == SAFE_WEIGHTS_INDEX_NAME:
missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights_from_safetensors(
model,
resolved_archive_file,
ignore_mismatched_sizes=ignore_mismatched_sizes,
_prefix=load_weight_prefix,
)
else:
missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights(
model,
resolved_archive_file,
ignore_mismatched_sizes=ignore_mismatched_sizes,
_prefix=load_weight_prefix,
)
else:
# Handles both H5 and safetensors
missing_keys, unexpected_keys, mismatched_keys = load_tf_weights(
model,
resolved_archive_file,

tests/test_modeling_tf_utils.py

@@ -41,7 +41,13 @@ from transformers.testing_utils import ( # noqa: F401
require_torch,
slow,
)
from transformers.utils import SAFE_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging
from transformers.utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
TF2_WEIGHTS_INDEX_NAME,
TF2_WEIGHTS_NAME,
logging,
)
logger = logging.get_logger(__name__)
@@ -340,6 +346,7 @@ class TFModelUtilsTest(unittest.TestCase):
for p1, p2 in zip(model.weights, ref_model.weights):
assert np.allclose(p1.numpy(), p2.numpy())
@require_safetensors
def test_checkpoint_sharding_local(self):
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
@@ -389,6 +396,45 @@ class TFModelUtilsTest(unittest.TestCase):
for p1, p2 in zip(model.weights, new_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
def test_safetensors_checkpoint_sharding_local(self):
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with tempfile.TemporaryDirectory() as tmp_dir:
# We use the same folder for various sizes to make sure a new save erases the old checkpoint.
for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
model.save_pretrained(tmp_dir, max_shard_size=max_size, safe_serialization=True)
# Get each shard file and its size
shard_to_size = {}
for shard in os.listdir(tmp_dir):
if shard.endswith(".h5"):
shard_file = os.path.join(tmp_dir, shard)
shard_to_size[shard_file] = os.path.getsize(shard_file)
index_file = os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)
# Check there is an index but no regular weight file
self.assertTrue(os.path.isfile(index_file))
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)))
# Check the index and the shard files found match
with open(index_file, "r", encoding="utf-8") as f:
index = json.loads(f.read())
all_shards = set(index["weight_map"].values())
shards_found = {f for f in os.listdir(tmp_dir) if f.endswith(".safetensors")}
self.assertSetEqual(all_shards, shards_found)
# Finally, check the model can be reloaded
new_model = TFBertModel.from_pretrained(tmp_dir)
model.build_in_name_scope()
new_model.build_in_name_scope()
for p1, p2 in zip(model.weights, new_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
@slow
def test_save_pretrained_signatures(self):
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
@@ -437,7 +483,26 @@ class TFModelUtilsTest(unittest.TestCase):
model.save_pretrained(tmp_dir, safe_serialization=True)
# No tf_model.h5 file, only a model.safetensors
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)))
new_model = TFBertModel.from_pretrained(tmp_dir)
# Check models are equal
for p1, p2 in zip(model.weights, new_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
@require_safetensors
def test_safetensors_sharded_save_and_load(self):
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="150kB")
# No tf weights or index file, only a safetensors index
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)))
new_model = TFBertModel.from_pretrained(tmp_dir)
@@ -460,6 +525,21 @@ class TFModelUtilsTest(unittest.TestCase):
for p1, p2 in zip(model.weights, new_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
@is_pt_tf_cross_test
def test_sharded_safetensors_save_and_load_pt_to_tf(self):
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
with tempfile.TemporaryDirectory() as tmp_dir:
pt_model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="150kB")
# Check we have a safetensors shard index file
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
new_model = TFBertModel.from_pretrained(tmp_dir)
# Check models are equal
for p1, p2 in zip(model.weights, new_model.weights):
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
@require_safetensors
def test_safetensors_load_from_hub(self):
tf_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
@@ -512,9 +592,10 @@ class TFModelUtilsTest(unittest.TestCase):
@require_safetensors
def test_safetensors_tf_from_sharded_h5_with_sharded_safetensors_hub(self):
# This should not raise even if there are two types of sharded weights
# This should discard the safetensors weights in favor of the .h5 sharded weights
TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded")
# Confirm that we can correctly load the safetensors weights from a sharded hub repo even when TF weights are present
TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded", use_safetensors=True)
# Confirm that we can access the TF weights too
TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded", use_safetensors=False)
@require_safetensors
def test_safetensors_load_from_local(self):