Support sharded safetensors in TF (#29350)

* Initial commit (still lots of unfinished bits)

* (Still untested) add safetensors sharding to save_pretrained

* Fix safetensors saving, update default shard size to match PT

* Add proper loading of TF-format safetensors

* Revert default size in case that changes things

* Fix incorrect index name

* Update loading priority

* Update tests

* Make the tests a little more stringent

* Expand tests

* Add sharded cross-test

* Fix argument name

* One more test fix

* Adding mlx to the list of allowed formats

* Remove irrelevant block for safetensors

* Refactor warning logging into a separate function

* Remove unused skip_logger_warnings arg

* Update src/transformers/modeling_tf_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Move function def

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Author: Matt (committed via GitHub) — 2024-03-20 14:22:35 +00:00
Commit: 11ef35e828 (parent: 870bbb4c6b)
3 changed files with 324 additions and 86 deletions
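
For context, the change below lets a TF model be saved as a sharded safetensors checkpoint and reloaded transparently. A minimal usage sketch along the lines of the tests added in this commit (the tiny checkpoint name and shard size are only examples):

import os
import tempfile

from transformers import TFBertModel
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

with tempfile.TemporaryDirectory() as tmp_dir:
    # A tiny max_shard_size forces sharding even for a small model.
    model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="150kB")

    # The directory now holds model-XXXXX-of-XXXXX.safetensors shards plus an index file.
    assert os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))

    # from_pretrained detects the safetensors index and loads all shards.
    reloaded = TFBertModel.from_pretrained(tmp_dir)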

File: src/transformers/modeling_tf_pytorch_utils.py

@@ -21,10 +21,24 @@ import re

 import numpy

-from .utils import ExplicitEnum, expand_dims, is_numpy_array, is_torch_tensor, logging, reshape, squeeze, tensor_size
+from .utils import (
+    ExplicitEnum,
+    expand_dims,
+    is_numpy_array,
+    is_safetensors_available,
+    is_torch_tensor,
+    logging,
+    reshape,
+    squeeze,
+    tensor_size,
+)
 from .utils import transpose as transpose_func

+if is_safetensors_available():
+    from safetensors import safe_open
+

 logger = logging.get_logger(__name__)
@@ -247,6 +261,47 @@ def load_pytorch_weights_in_tf2_model(
     )


+def _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name):
+    if len(unexpected_keys) > 0:
+        logger.warning(
+            "Some weights of the PyTorch model were not used when initializing the TF 2.0 model"
+            f" {class_name}: {unexpected_keys}\n- This IS expected if you are initializing"
+            f" {class_name} from a PyTorch model trained on another task or with another architecture"
+            " (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n- This IS"
+            f" NOT expected if you are initializing {class_name} from a PyTorch model that you expect"
+            " to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a"
+            " BertForSequenceClassification model)."
+        )
+    else:
+        logger.warning(f"All PyTorch model weights were used when initializing {class_name}.\n")
+
+    if len(missing_keys) > 0:
+        logger.warning(
+            f"Some weights or buffers of the TF 2.0 model {class_name} were not initialized from the"
+            f" PyTorch model and are newly initialized: {missing_keys}\nYou should probably TRAIN this model on a"
+            " down-stream task to be able to use it for predictions and inference."
+        )
+    else:
+        logger.warning(
+            f"All the weights of {class_name} were initialized from the PyTorch model.\n"
+            "If your task is similar to the task the model of the checkpoint was trained on, "
+            f"you can already use {class_name} for predictions without further training."
+        )
+
+    if len(mismatched_keys) > 0:
+        mismatched_warning = "\n".join(
+            [
+                f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+                for key, shape1, shape2 in mismatched_keys
+            ]
+        )
+        logger.warning(
+            f"Some weights of {class_name} were not initialized from the model checkpoint"
+            f" and are newly initialized because the shapes did not"
+            f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
+            " to use it for predictions and inference."
+        )
+
+
 def load_pytorch_state_dict_in_tf2_model(
     tf_model,
     pt_state_dict,
@@ -256,6 +311,7 @@ def load_pytorch_state_dict_in_tf2_model(
     _prefix=None,
     tf_to_pt_weight_rename=None,
     ignore_mismatched_sizes=False,
+    skip_logger_warnings=False,
 ):
     """Load a pytorch state_dict in a TF 2.0 model. pt_state_dict can be either an actual dict or a lazy-loading
     safetensors archive created with the safe_open() function."""
@@ -373,45 +429,53 @@ def load_pytorch_state_dict_in_tf2_model
     if tf_model._keys_to_ignore_on_load_unexpected is not None:
         for pat in tf_model._keys_to_ignore_on_load_unexpected:
             unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+    if not skip_logger_warnings:
+        _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name=tf_model.__class__.__name__)

-    if len(unexpected_keys) > 0:
-        logger.warning(
-            "Some weights of the PyTorch model were not used when initializing the TF 2.0 model"
-            f" {tf_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
-            f" {tf_model.__class__.__name__} from a PyTorch model trained on another task or with another architecture"
-            " (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n- This IS"
-            f" NOT expected if you are initializing {tf_model.__class__.__name__} from a PyTorch model that you expect"
-            " to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a"
-            " BertForSequenceClassification model)."
-        )
-    else:
-        logger.warning(f"All PyTorch model weights were used when initializing {tf_model.__class__.__name__}.\n")
-
-    if len(missing_keys) > 0:
-        logger.warning(
-            f"Some weights or buffers of the TF 2.0 model {tf_model.__class__.__name__} were not initialized from the"
-            f" PyTorch model and are newly initialized: {missing_keys}\nYou should probably TRAIN this model on a"
-            " down-stream task to be able to use it for predictions and inference."
-        )
-    else:
-        logger.warning(
-            f"All the weights of {tf_model.__class__.__name__} were initialized from the PyTorch model.\n"
-            "If your task is similar to the task the model of the checkpoint was trained on, "
-            f"you can already use {tf_model.__class__.__name__} for predictions without further training."
-        )
-
-    if len(mismatched_keys) > 0:
-        mismatched_warning = "\n".join(
-            [
-                f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
-                for key, shape1, shape2 in mismatched_keys
-            ]
-        )
-        logger.warning(
-            f"Some weights of {tf_model.__class__.__name__} were not initialized from the model checkpoint"
-            f" are newly initialized because the shapes did not"
-            f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
-            " to use it for predictions and inference."
-        )
+    if output_loading_info:
+        loading_info = {
+            "missing_keys": missing_keys,
+            "unexpected_keys": unexpected_keys,
+            "mismatched_keys": mismatched_keys,
+        }
+        return tf_model, loading_info
+
+    return tf_model
+
+
+def load_sharded_pytorch_safetensors_in_tf2_model(
+    tf_model,
+    safetensors_shards,
+    tf_inputs=None,
+    allow_missing_keys=False,
+    output_loading_info=False,
+    _prefix=None,
+    tf_to_pt_weight_rename=None,
+    ignore_mismatched_sizes=False,
+):
+    all_loading_infos = []
+    for shard in safetensors_shards:
+        with safe_open(shard, framework="tf") as safetensors_archive:
+            tf_model, loading_info = load_pytorch_state_dict_in_tf2_model(
+                tf_model,
+                safetensors_archive,
+                tf_inputs=tf_inputs,
+                allow_missing_keys=allow_missing_keys,
+                output_loading_info=True,
+                _prefix=_prefix,
+                tf_to_pt_weight_rename=tf_to_pt_weight_rename,
+                ignore_mismatched_sizes=ignore_mismatched_sizes,
+                skip_logger_warnings=True,  # We will emit merged warnings at the end
+            )
+        all_loading_infos.append(loading_info)
+    # Now we just need to merge the loading info
+    # Keys are missing only if they're missing in *every* shard
+    missing_keys = sorted(set.intersection(*[set(info["missing_keys"]) for info in all_loading_infos]))
+    # Keys are unexpected/mismatched if they're unexpected/mismatched in *any* shard
+    unexpected_keys = sum([info["unexpected_keys"] for info in all_loading_infos], [])
+    mismatched_keys = sum([info["mismatched_keys"] for info in all_loading_infos], [])
+
+    _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name=tf_model.__class__.__name__)

     if output_loading_info:
         loading_info = {
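
A detail worth noting in load_sharded_pytorch_safetensors_in_tf2_model above: a key only counts as missing if no shard provided it, while unexpected and mismatched keys accumulate across shards. A small standalone sketch of that merge rule (the helper name and sample data are made up, not part of the commit):

def merge_loading_infos(all_loading_infos):
    # A weight is only truly missing if *every* shard reported it missing.
    missing = sorted(set.intersection(*[set(info["missing_keys"]) for info in all_loading_infos]))
    # A weight is unexpected or mismatched if *any* shard reported it as such.
    unexpected = sum([info["unexpected_keys"] for info in all_loading_infos], [])
    mismatched = sum([info["mismatched_keys"] for info in all_loading_infos], [])
    return missing, unexpected, mismatched


infos = [
    {"missing_keys": ["a", "b"], "unexpected_keys": ["x"], "mismatched_keys": []},
    {"missing_keys": ["b"], "unexpected_keys": [], "mismatched_keys": []},
]
print(merge_loading_infos(infos))  # (['b'], ['x'], [])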

File: src/transformers/modeling_tf_utils.py

@@ -647,7 +647,7 @@ def strip_model_name_and_prefix(name, _prefix=None):
     return name


-def tf_shard_checkpoint(weights, max_shard_size="10GB"):
+def tf_shard_checkpoint(weights, max_shard_size="10GB", weights_name: str = TF2_WEIGHTS_NAME):
     """
     Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
     given size.
@@ -695,13 +695,16 @@ def tf_shard_checkpoint(weights, max_shard_size="10GB"):

     # If we only have one shard, we return it
     if len(sharded_state_dicts) == 1:
-        return {TF2_WEIGHTS_NAME: sharded_state_dicts[0]}, None
+        return {weights_name: sharded_state_dicts[0]}, None

     # Otherwise, let's build the index
     weight_map = {}
     shards = {}
     for idx, shard in enumerate(sharded_state_dicts):
-        shard_file = TF2_WEIGHTS_NAME.replace(".h5", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.h5")
+        shard_file = weights_name.replace(".h5", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.h5")
+        shard_file = shard_file.replace(
+            ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
+        )
         shards[shard_file] = shard
         for weight in shard:
             weight_name = weight.name
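
Since weights_name is now parameterized, the same numeric suffix rule covers both the H5 and the safetensors base names. A minimal sketch of the resulting naming scheme (the helper name is made up):

def shard_file_name(weights_name: str, idx: int, num_shards: int) -> str:
    # Insert the "-00001-of-0000N" counter before the file extension, whichever format is used.
    name = weights_name.replace(".h5", f"-{idx + 1:05d}-of-{num_shards:05d}.h5")
    name = name.replace(".safetensors", f"-{idx + 1:05d}-of-{num_shards:05d}.safetensors")
    return name


print(shard_file_name("tf_model.h5", 0, 3))        # tf_model-00001-of-00003.h5
print(shard_file_name("model.safetensors", 0, 3))  # model-00001-of-00003.safetensors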
@@ -782,7 +785,8 @@ def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, s

 def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
     """
-    Loads a shard from a sharded checkpoint file. Handles the missing keys and unexpected keys.
+    Loads a shard from a sharded checkpoint file. Can be either H5 or Safetensors.
+    Handles missing keys and unexpected keys.

     Args:
         model (`keras.models.Model`): Model in which the weights are loaded
@@ -868,6 +872,61 @@ def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatch
     )


+def load_tf_sharded_weights_from_safetensors(
+    model, shard_files, ignore_mismatched_sizes=False, strict=False, _prefix=None
+):
+    """
+    This is the same as `load_tf_weights_from_safetensors` but for a sharded TF-format safetensors checkpoint.
+    Detect missing and unexpected layers and load the TF weights from the shard file accordingly to their names and
+    shapes.
+
+    This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
+    loaded in the model.
+
+    Args:
+        model (`keras.models.Model`): The model in which to load the checkpoint.
+        shard_files (`str` or `os.PathLike`): A list containing the sharded checkpoint names.
+        ignore_mismatched_sizes (`bool`, *optional*, defaults to `True`):
+            Whether or not to ignore the mismatch between the sizes.
+        strict (`bool`, *optional*, defaults to `True`):
+            Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
+
+    Returns:
+        Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the
+        mismatched layers.
+    """
+
+    # Load the index
+    unexpected_keys = set()
+    all_missing_keys = []
+    mismatched_keys = set()
+
+    for shard_file in shard_files:
+        missing_layers, unexpected_layers, mismatched_layers = load_tf_weights_from_safetensors(
+            model,
+            shard_file,
+            ignore_mismatched_sizes=ignore_mismatched_sizes,
+            _prefix=_prefix,
+        )
+        all_missing_keys.append(set(missing_layers))
+        unexpected_keys.update(unexpected_layers)
+        mismatched_keys.update(mismatched_layers)
+        gc.collect()
+    missing_keys = set.intersection(*all_missing_keys)
+
+    if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
+        error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
+        if len(missing_keys) > 0:
+            str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
+            error_message += f"\nMissing key(s): {str_missing_keys}."
+        if len(unexpected_keys) > 0:
+            str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
+            error_message += f"\nUnexpected key(s): {str_unexpected_keys}."
+        raise RuntimeError(error_message)
+
+    return missing_keys, unexpected_keys, mismatched_keys
+
+
 def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
     """
     Detect missing and unexpected layers and load the TF weights from the shard file accordingly to their names and
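
The shard_files argument above is a list of per-shard paths. When loading from a local save_pretrained directory they can be recovered from the index's weight_map, roughly as in this sketch (the helper name is made up; in the library this resolution goes through get_checkpoint_shard_files):

import json
import os


def shard_files_from_index(save_directory: str, index_name: str = "model.safetensors.index.json"):
    # The index maps each weight name to the shard that stores it; the unique values
    # of that map are the shard files that need to be loaded.
    with open(os.path.join(save_directory, index_name), "r", encoding="utf-8") as f:
        index = json.load(f)
    return sorted({os.path.join(save_directory, fname) for fname in index["weight_map"].values()})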
@@ -2303,7 +2362,7 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
         version=1,
         push_to_hub=False,
         signatures=None,
-        max_shard_size: Union[int, str] = "10GB",
+        max_shard_size: Union[int, str] = "5GB",
         create_pr: bool = False,
         safe_serialization: bool = False,
         token: Optional[Union[str, bool]] = None,
@@ -2415,7 +2474,7 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
         weights_name = SAFE_WEIGHTS_NAME if safe_serialization else TF2_WEIGHTS_NAME
         output_model_file = os.path.join(save_directory, weights_name)

-        shards, index = tf_shard_checkpoint(self.weights, max_shard_size)
+        shards, index = tf_shard_checkpoint(self.weights, max_shard_size, weights_name=weights_name)

         # Clean the folder from a previous save
         for filename in os.listdir(save_directory):
@@ -2438,7 +2497,8 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
             self.save_weights(output_model_file)
             logger.info(f"Model weights saved in {output_model_file}")
         else:
-            save_index_file = os.path.join(save_directory, TF2_WEIGHTS_INDEX_NAME)
+            save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else TF2_WEIGHTS_INDEX_NAME
+            save_index_file = os.path.join(save_directory, save_index_file)
             # Save the index as well
             with open(save_index_file, "w", encoding="utf-8") as index_file:
                 content = json.dumps(index, indent=2, sort_keys=True) + "\n"
@@ -2449,6 +2509,12 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
                 f"index located at {save_index_file}."
             )
             for shard_file, shard in shards.items():
+                if safe_serialization:
+                    shard_state_dict = {strip_model_name_and_prefix(w.name): w.value() for w in shard}
+                    safe_save_file(
+                        shard_state_dict, os.path.join(save_directory, shard_file), metadata={"format": "tf"}
+                    )
+                else:
                     with h5py.File(os.path.join(save_directory, shard_file), mode="w") as shard_file:
                         layers = []
                         for layer in sorted(shard, key=lambda x: x.name):
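
In the safe_serialization branch, each shard is written with the safetensors TensorFlow helper and tagged with {"format": "tf"} metadata so the loader later knows the weights are already in TF layout. A minimal standalone sketch of that call (the file and tensor names are made up):

import tensorflow as tf
from safetensors.tensorflow import save_file

# Write a dict of tf.Tensors to a .safetensors shard, tagged as TF-format.
shard_state_dict = {
    "bert/encoder/layer_0/attention/self/query/kernel": tf.zeros((4, 4)),
    "bert/encoder/layer_0/attention/self/query/bias": tf.zeros((4,)),
}
save_file(shard_state_dict, "model-00001-of-00002.safetensors", metadata={"format": "tf"})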
@@ -2698,6 +2764,12 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
             ):
                 # Load from a safetensors checkpoint
                 archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
+            elif use_safetensors is not False and os.path.isfile(
+                os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
+            ):
+                # Load from a sharded safetensors checkpoint
+                archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
+                is_sharded = True
             elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                 # Load from a TF 2.0 checkpoint
                 archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
@@ -2705,17 +2777,11 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
                 # Load from a sharded TF 2.0 checkpoint
                 archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)
                 is_sharded = True
-            elif use_safetensors is not False and os.path.isfile(
-                os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
-            ):
-                # Load from a sharded safetensors checkpoint
-                archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
-                is_sharded = True
-                raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!")
             # At this stage we don't have a weight file so we will raise an error.
             elif use_safetensors:
                 raise EnvironmentError(
-                    f"Error no file named {SAFE_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}. "
+                    f"Error no file named {SAFE_WEIGHTS_NAME} or {SAFE_WEIGHTS_INDEX_NAME} found in directory {pretrained_model_name_or_path}. "
                     f"Please make sure that the model has been saved with `safe_serialization=True` or do not "
                     f"set `use_safetensors=True`."
                 )
@@ -2723,13 +2789,13 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
                 os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
             ):
                 raise EnvironmentError(
-                    f"Error no file named {TF2_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
+                    f"Error no file named {TF2_WEIGHTS_NAME} or {SAFE_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
                     "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
                     "weights."
                 )
             else:
                 raise EnvironmentError(
-                    f"Error no file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
+                    f"Error no file named {TF2_WEIGHTS_NAME}, {SAFE_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
                     f"{pretrained_model_name_or_path}."
                 )
         elif os.path.isfile(pretrained_model_name_or_path):
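
After this change, local directories are resolved in the order: single safetensors file, safetensors shard index, H5 file, H5 shard index (the safetensors branches are gated on use_safetensors not being False). A rough sketch of just that ordering, not the library's full logic (from_pt handling and remote resolution are omitted; the constants mirror the usual transformers file names):

import os

CANDIDATES = [
    ("model.safetensors", False),            # single-file safetensors (SAFE_WEIGHTS_NAME)
    ("model.safetensors.index.json", True),  # sharded safetensors (SAFE_WEIGHTS_INDEX_NAME)
    ("tf_model.h5", False),                  # single-file TF H5 (TF2_WEIGHTS_NAME)
    ("tf_model.h5.index.json", True),        # sharded TF H5 (TF2_WEIGHTS_INDEX_NAME)
]


def resolve_local_checkpoint(directory: str):
    # Return the first checkpoint file found plus whether it is a shard index.
    for filename, is_sharded in CANDIDATES:
        path = os.path.join(directory, filename)
        if os.path.isfile(path):
            return path, is_sharded
    raise FileNotFoundError(f"No supported checkpoint found in {directory}")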
@@ -2801,9 +2867,6 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
                     }
                     if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs):
                         is_sharded = True
-                        raise NotImplementedError(
-                            "Support for sharded checkpoints using safetensors is coming soon!"
-                        )
                     elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
                         raise EnvironmentError(
                             f"{pretrained_model_name_or_path} does not appear to have a file named"
@@ -2841,7 +2904,7 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
         # We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
         if is_sharded:
             # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
-            resolved_archive_file, _ = get_checkpoint_shard_files(
+            resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
                 pretrained_model_name_or_path,
                 resolved_archive_file,
                 cache_dir=cache_dir,
@@ -2859,7 +2922,16 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
         if filename == SAFE_WEIGHTS_NAME:
             with safe_open(resolved_archive_file, framework="tf") as f:
                 safetensors_metadata = f.metadata()
-            if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax"]:
+            if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
+                raise OSError(
+                    f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata."
+                    " Make sure you save your model with the `save_pretrained` method."
+                )
+            safetensors_from_pt = safetensors_metadata.get("format") == "pt"
+        elif filename == SAFE_WEIGHTS_INDEX_NAME:
+            with safe_open(resolved_archive_file[0], framework="tf") as f:
+                safetensors_metadata = f.metadata()
+            if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
                 raise OSError(
                     f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata."
                     " Make sure you save your model with the `save_pretrained` method."
@@ -2902,11 +2974,11 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
         else:
             model.build_in_name_scope()  # build the network with dummy inputs

-        if safetensors_from_pt:
+        if safetensors_from_pt and not is_sharded:
             from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model

             with safe_open(resolved_archive_file, framework="tf") as safetensors_archive:
-                # Load from a PyTorch checkpoint
+                # Load from a PyTorch safetensors checkpoint
                 # We load in TF format here because PT weights often need to be transposed, and this is much
                 # faster on GPU. Loading as numpy and transposing on CPU adds several seconds to load times.
                 return load_pytorch_state_dict_in_tf2_model(
@@ -2919,6 +2991,19 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
                     ignore_mismatched_sizes=ignore_mismatched_sizes,
                     tf_to_pt_weight_rename=tf_to_pt_weight_rename,
                 )
+        elif safetensors_from_pt:
+            from .modeling_tf_pytorch_utils import load_sharded_pytorch_safetensors_in_tf2_model
+
+            return load_sharded_pytorch_safetensors_in_tf2_model(
+                model,
+                resolved_archive_file,
+                tf_inputs=False,
+                allow_missing_keys=True,
+                output_loading_info=output_loading_info,
+                _prefix=load_weight_prefix,
+                ignore_mismatched_sizes=ignore_mismatched_sizes,
+                tf_to_pt_weight_rename=tf_to_pt_weight_rename,
+            )

         # 'by_name' allow us to do transfer learning by skipping/adding layers
         # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
@@ -2926,7 +3011,14 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
         if is_sharded:
             for file in resolved_archive_file:
                 os.path.isfile(file), f"Error retrieving files {file}"
+            if filename == SAFE_WEIGHTS_INDEX_NAME:
+                missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights_from_safetensors(
+                    model,
+                    resolved_archive_file,
+                    ignore_mismatched_sizes=ignore_mismatched_sizes,
+                    _prefix=load_weight_prefix,
+                )
+            else:
                 missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights(
                     model,
                     resolved_archive_file,
@@ -2934,6 +3026,7 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
                     _prefix=load_weight_prefix,
                 )
         else:
+            # Handles both H5 and safetensors
             missing_keys, unexpected_keys, mismatched_keys = load_tf_weights(
                 model,
                 resolved_archive_file,

File: TF model utils test suite (TFModelUtilsTest)

@@ -41,7 +41,13 @@ from transformers.testing_utils import ( # noqa: F401
     require_torch,
     slow,
 )
-from transformers.utils import SAFE_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging
+from transformers.utils import (
+    SAFE_WEIGHTS_INDEX_NAME,
+    SAFE_WEIGHTS_NAME,
+    TF2_WEIGHTS_INDEX_NAME,
+    TF2_WEIGHTS_NAME,
+    logging,
+)


 logger = logging.get_logger(__name__)
@@ -340,6 +346,7 @@ class TFModelUtilsTest(unittest.TestCase):
         for p1, p2 in zip(model.weights, ref_model.weights):
             assert np.allclose(p1.numpy(), p2.numpy())

+    @require_safetensors
     def test_checkpoint_sharding_local(self):
         model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
@@ -389,6 +396,45 @@ class TFModelUtilsTest(unittest.TestCase):
         for p1, p2 in zip(model.weights, new_model.weights):
             self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))

+    def test_safetensors_checkpoint_sharding_local(self):
+        model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # We use the same folder for various sizes to make sure a new save erases the old checkpoint.
+            for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
+                model.save_pretrained(tmp_dir, max_shard_size=max_size, safe_serialization=True)
+
+                # Get each shard file and its size
+                shard_to_size = {}
+                for shard in os.listdir(tmp_dir):
+                    if shard.endswith(".h5"):
+                        shard_file = os.path.join(tmp_dir, shard)
+                        shard_to_size[shard_file] = os.path.getsize(shard_file)
+
+                index_file = os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)
+                # Check there is an index but no regular weight file
+                self.assertTrue(os.path.isfile(index_file))
+                self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
+                self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
+                self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)))
+
+                # Check the index and the shard files found match
+                with open(index_file, "r", encoding="utf-8") as f:
+                    index = json.loads(f.read())
+
+                all_shards = set(index["weight_map"].values())
+                shards_found = {f for f in os.listdir(tmp_dir) if f.endswith(".safetensors")}
+                self.assertSetEqual(all_shards, shards_found)
+
+                # Finally, check the model can be reloaded
+                new_model = TFBertModel.from_pretrained(tmp_dir)
+
+                model.build_in_name_scope()
+                new_model.build_in_name_scope()
+                for p1, p2 in zip(model.weights, new_model.weights):
+                    self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
+
     @slow
     def test_save_pretrained_signatures(self):
         model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
@@ -437,7 +483,26 @@ class TFModelUtilsTest(unittest.TestCase):
             model.save_pretrained(tmp_dir, safe_serialization=True)
             # No tf_model.h5 file, only a model.safetensors
             self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
+            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
             self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
+            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)))
+
+            new_model = TFBertModel.from_pretrained(tmp_dir)
+
+            # Check models are equal
+            for p1, p2 in zip(model.weights, new_model.weights):
+                self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
+
+    @require_safetensors
+    def test_safetensors_sharded_save_and_load(self):
+        model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="150kB")
+            # No tf weights or index file, only a safetensors index
+            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
+            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
+            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
+            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)))

             new_model = TFBertModel.from_pretrained(tmp_dir)
@@ -460,6 +525,21 @@ class TFModelUtilsTest(unittest.TestCase):
             for p1, p2 in zip(model.weights, new_model.weights):
                 self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))

+    @is_pt_tf_cross_test
+    def test_sharded_safetensors_save_and_load_pt_to_tf(self):
+        model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+        pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            pt_model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="150kB")
+            # Check we have a safetensors shard index file
+            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
+            new_model = TFBertModel.from_pretrained(tmp_dir)
+
+            # Check models are equal
+            for p1, p2 in zip(model.weights, new_model.weights):
+                self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
+
     @require_safetensors
     def test_safetensors_load_from_hub(self):
         tf_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
@@ -512,9 +592,10 @@ class TFModelUtilsTest(unittest.TestCase):

     @require_safetensors
     def test_safetensors_tf_from_sharded_h5_with_sharded_safetensors_hub(self):
-        # This should not raise even if there are two types of sharded weights
-        # This should discard the safetensors weights in favor of the .h5 sharded weights
-        TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded")
+        # Confirm that we can correctly load the safetensors weights from a sharded hub repo even when TF weights are present
+        TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded", use_safetensors=True)
+        # Confirm that we can access the TF weights too
+        TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded", use_safetensors=False)

     @require_safetensors
     def test_safetensors_load_from_local(self):