Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-03 12:50:06 +06:00
Support sharded safetensors in TF (#29350)
* Initial commit (still lots of unfinished bits)
* (Still untested) add safetensors sharding to save_pretrained
* Fix safetensors saving, update default shard size to match PT
* Add proper loading of TF-format safetensors
* Revert default size in case that changes things
* Fix incorrect index name
* Update loading priority
* Update tests
* Make the tests a little more stringent
* Expand tests
* Add sharded cross-test
* Fix argument name
* One more test fix
* Adding mlx to the list of allowed formats
* Remove irrelevant block for safetensors
* Refactor warning logging into a separate function
* Remove unused skip_logger_warnings arg
* Update src/transformers/modeling_tf_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Move function def

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent 870bbb4c6b
commit 11ef35e828
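In short, TF models can now be saved to and loaded from sharded safetensors checkpoints. A minimal usage sketch of what this commit enables (the tiny test model and shard size are illustrative, taken from the tests further down):

from transformers import TFBertModel

model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# Save as sharded safetensors: produces model-XXXXX-of-XXXXX.safetensors shards
# plus a model.safetensors.index.json index file.
model.save_pretrained("tiny-bert-tf", safe_serialization=True, max_shard_size="150kB")
# Reloading picks the sharded safetensors checkpoint up automatically.
reloaded = TFBertModel.from_pretrained("tiny-bert-tf")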
@@ -21,10 +21,24 @@ import re

import numpy

from .utils import ExplicitEnum, expand_dims, is_numpy_array, is_torch_tensor, logging, reshape, squeeze, tensor_size
from .utils import (
    ExplicitEnum,
    expand_dims,
    is_numpy_array,
    is_safetensors_available,
    is_torch_tensor,
    logging,
    reshape,
    squeeze,
    tensor_size,
)
from .utils import transpose as transpose_func


if is_safetensors_available():
    from safetensors import safe_open


logger = logging.get_logger(__name__)
@@ -247,6 +261,47 @@ def load_pytorch_weights_in_tf2_model(
    )


def _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name):
    if len(unexpected_keys) > 0:
        logger.warning(
            "Some weights of the PyTorch model were not used when initializing the TF 2.0 model"
            f" {class_name}: {unexpected_keys}\n- This IS expected if you are initializing"
            f" {class_name} from a PyTorch model trained on another task or with another architecture"
            " (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n- This IS"
            f" NOT expected if you are initializing {class_name} from a PyTorch model that you expect"
            " to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a"
            " BertForSequenceClassification model)."
        )
    else:
        logger.warning(f"All PyTorch model weights were used when initializing {class_name}.\n")
    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights or buffers of the TF 2.0 model {class_name} were not initialized from the"
            f" PyTorch model and are newly initialized: {missing_keys}\nYou should probably TRAIN this model on a"
            " down-stream task to be able to use it for predictions and inference."
        )
    else:
        logger.warning(
            f"All the weights of {class_name} were initialized from the PyTorch model.\n"
            "If your task is similar to the task the model of the checkpoint was trained on, "
            f"you can already use {class_name} for predictions without further training."
        )

    if len(mismatched_keys) > 0:
        mismatched_warning = "\n".join(
            [
                f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                for key, shape1, shape2 in mismatched_keys
            ]
        )
        logger.warning(
            f"Some weights of {class_name} were not initialized from the model checkpoint"
            f" are newly initialized because the shapes did not"
            f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
            " to use it for predictions and inference."
        )


def load_pytorch_state_dict_in_tf2_model(
    tf_model,
    pt_state_dict,
@@ -256,6 +311,7 @@ def load_pytorch_state_dict_in_tf2_model(
    _prefix=None,
    tf_to_pt_weight_rename=None,
    ignore_mismatched_sizes=False,
    skip_logger_warnings=False,
):
    """Load a pytorch state_dict in a TF 2.0 model. pt_state_dict can be either an actual dict or a lazy-loading
    safetensors archive created with the safe_open() function."""
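As the docstring above notes, `pt_state_dict` may be a lazy-loading safetensors archive rather than an in-memory dict. A minimal sketch of that calling pattern, assuming a local PyTorch-format file (the file name is illustrative), mirroring how `from_pretrained()` uses this function later in the diff:

from safetensors import safe_open

# Hypothetical: stream a PyTorch-format safetensors file directly into a TF model
with safe_open("pytorch_model.safetensors", framework="tf") as archive:
    tf_model = load_pytorch_state_dict_in_tf2_model(tf_model, archive, allow_missing_keys=True)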
@@ -373,45 +429,53 @@ def load_pytorch_state_dict_in_tf2_model(
    if tf_model._keys_to_ignore_on_load_unexpected is not None:
        for pat in tf_model._keys_to_ignore_on_load_unexpected:
            unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
    if not skip_logger_warnings:
        _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name=tf_model.__class__.__name__)

    if len(unexpected_keys) > 0:
        logger.warning(
            "Some weights of the PyTorch model were not used when initializing the TF 2.0 model"
            f" {tf_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
            f" {tf_model.__class__.__name__} from a PyTorch model trained on another task or with another architecture"
            " (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n- This IS"
            f" NOT expected if you are initializing {tf_model.__class__.__name__} from a PyTorch model that you expect"
            " to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a"
            " BertForSequenceClassification model)."
        )
    else:
        logger.warning(f"All PyTorch model weights were used when initializing {tf_model.__class__.__name__}.\n")
    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights or buffers of the TF 2.0 model {tf_model.__class__.__name__} were not initialized from the"
            f" PyTorch model and are newly initialized: {missing_keys}\nYou should probably TRAIN this model on a"
            " down-stream task to be able to use it for predictions and inference."
        )
    else:
        logger.warning(
            f"All the weights of {tf_model.__class__.__name__} were initialized from the PyTorch model.\n"
            "If your task is similar to the task the model of the checkpoint was trained on, "
            f"you can already use {tf_model.__class__.__name__} for predictions without further training."
        )
    if output_loading_info:
        loading_info = {
            "missing_keys": missing_keys,
            "unexpected_keys": unexpected_keys,
            "mismatched_keys": mismatched_keys,
        }
        return tf_model, loading_info

    if len(mismatched_keys) > 0:
        mismatched_warning = "\n".join(
            [
                f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                for key, shape1, shape2 in mismatched_keys
            ]
        )
        logger.warning(
            f"Some weights of {tf_model.__class__.__name__} were not initialized from the model checkpoint"
            f" are newly initialized because the shapes did not"
            f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
            " to use it for predictions and inference."
        )
    return tf_model


def load_sharded_pytorch_safetensors_in_tf2_model(
    tf_model,
    safetensors_shards,
    tf_inputs=None,
    allow_missing_keys=False,
    output_loading_info=False,
    _prefix=None,
    tf_to_pt_weight_rename=None,
    ignore_mismatched_sizes=False,
):
    all_loading_infos = []
    for shard in safetensors_shards:
        with safe_open(shard, framework="tf") as safetensors_archive:
            tf_model, loading_info = load_pytorch_state_dict_in_tf2_model(
                tf_model,
                safetensors_archive,
                tf_inputs=tf_inputs,
                allow_missing_keys=allow_missing_keys,
                output_loading_info=True,
                _prefix=_prefix,
                tf_to_pt_weight_rename=tf_to_pt_weight_rename,
                ignore_mismatched_sizes=ignore_mismatched_sizes,
                skip_logger_warnings=True,  # We will emit merged warnings at the end
            )
        all_loading_infos.append(loading_info)
    # Now we just need to merge the loading info
    # Keys are missing only if they're missing in *every* shard
    missing_keys = sorted(set.intersection(*[set(info["missing_keys"]) for info in all_loading_infos]))
    # Keys are unexpected/mismatched if they're unexpected/mismatched in *any* shard
    unexpected_keys = sum([info["unexpected_keys"] for info in all_loading_infos], [])
    mismatched_keys = sum([info["mismatched_keys"] for info in all_loading_infos], [])

    _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name=tf_model.__class__.__name__)

    if output_loading_info:
        loading_info = {
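The merge rule above (missing only if missing in every shard; unexpected/mismatched if seen in any shard) can be illustrated with a small standalone sketch (the keys are made up):

# Hypothetical per-shard loading infos, as returned by the per-shard calls above
infos = [
    {"missing_keys": ["a", "b"], "unexpected_keys": [], "mismatched_keys": []},
    {"missing_keys": ["b"], "unexpected_keys": ["x"], "mismatched_keys": []},
]
missing = sorted(set.intersection(*[set(i["missing_keys"]) for i in infos]))  # ["b"]
unexpected = sum([i["unexpected_keys"] for i in infos], [])  # ["x"]
mismatched = sum([i["mismatched_keys"] for i in infos], [])  # []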
@@ -647,7 +647,7 @@ def strip_model_name_and_prefix(name, _prefix=None):
    return name


def tf_shard_checkpoint(weights, max_shard_size="10GB"):
def tf_shard_checkpoint(weights, max_shard_size="10GB", weights_name: str = TF2_WEIGHTS_NAME):
    """
    Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
    given size.
@@ -695,13 +695,16 @@ def tf_shard_checkpoint(weights, max_shard_size="10GB"):

    # If we only have one shard, we return it
    if len(sharded_state_dicts) == 1:
        return {TF2_WEIGHTS_NAME: sharded_state_dicts[0]}, None
        return {weights_name: sharded_state_dicts[0]}, None

    # Otherwise, let's build the index
    weight_map = {}
    shards = {}
    for idx, shard in enumerate(sharded_state_dicts):
        shard_file = TF2_WEIGHTS_NAME.replace(".h5", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.h5")
        shard_file = weights_name.replace(".h5", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.h5")
        shard_file = shard_file.replace(
            ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
        )
        shards[shard_file] = shard
        for weight in shard:
            weight_name = weight.name
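With the `weights_name` parameter, shard names now follow the checkpoint's own extension; a quick sketch of the resulting names for a hypothetical three-shard safetensors save:

# Illustrative only: reproduce the replace() logic above for "model.safetensors"
weights_name, num_shards = "model.safetensors", 3
shard_names = [
    weights_name.replace(".safetensors", f"-{idx + 1:05d}-of-{num_shards:05d}.safetensors")
    for idx in range(num_shards)
]
# ['model-00001-of-00003.safetensors', 'model-00002-of-00003.safetensors', 'model-00003-of-00003.safetensors']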
@@ -782,7 +785,8 @@ def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, s

def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
    """
    Loads a shard from a sharded checkpoint file. Handles the missing keys and unexpected keys.
    Loads a shard from a sharded checkpoint file. Can be either H5 or Safetensors.
    Handles missing keys and unexpected keys.

    Args:
        model (`keras.models.Model`): Model in which the weights are loaded
@@ -868,6 +872,61 @@ def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatch
            )


def load_tf_sharded_weights_from_safetensors(
    model, shard_files, ignore_mismatched_sizes=False, strict=False, _prefix=None
):
    """
    This is the same as `load_tf_weights_from_safetensors` but for a sharded TF-format safetensors checkpoint.
    Detect missing and unexpected layers and load the TF weights from the shard file accordingly to their names and
    shapes.

    This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
    loaded in the model.

    Args:
        model (`keras.models.Model`): The model in which to load the checkpoint.
        shard_files (`str` or `os.PathLike`): A list containing the sharded checkpoint names.
        ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
            Whether or not to ignore the mismatch between the sizes.
        strict (`bool`, *optional*, defaults to `False`):
            Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.

    Returns:
        Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the
        mismatched layers.
    """

    # Load the index
    unexpected_keys = set()
    all_missing_keys = []
    mismatched_keys = set()

    for shard_file in shard_files:
        missing_layers, unexpected_layers, mismatched_layers = load_tf_weights_from_safetensors(
            model,
            shard_file,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            _prefix=_prefix,
        )
        all_missing_keys.append(set(missing_layers))
        unexpected_keys.update(unexpected_layers)
        mismatched_keys.update(mismatched_layers)
        gc.collect()
    missing_keys = set.intersection(*all_missing_keys)

    if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
        error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
        if len(missing_keys) > 0:
            str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
            error_message += f"\nMissing key(s): {str_missing_keys}."
        if len(unexpected_keys) > 0:
            str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
            error_message += f"\nUnexpected key(s): {str_unexpected_keys}."
        raise RuntimeError(error_message)

    return missing_keys, unexpected_keys, mismatched_keys


def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
    """
    Detect missing and unexpected layers and load the TF weights from the shard file accordingly to their names and
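A hedged usage sketch for the new `load_tf_sharded_weights_from_safetensors` helper above (the shard file names are illustrative, and the Keras model is assumed to be already built):

# Hypothetical call on an already-built model and two TF-format safetensors shards
missing, unexpected, mismatched = load_tf_sharded_weights_from_safetensors(
    model,
    ["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"],
    ignore_mismatched_sizes=False,
    strict=False,
)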
@@ -2303,7 +2362,7 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
        version=1,
        push_to_hub=False,
        signatures=None,
        max_shard_size: Union[int, str] = "10GB",
        max_shard_size: Union[int, str] = "5GB",
        create_pr: bool = False,
        safe_serialization: bool = False,
        token: Optional[Union[str, bool]] = None,
@@ -2415,7 +2474,7 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
        weights_name = SAFE_WEIGHTS_NAME if safe_serialization else TF2_WEIGHTS_NAME
        output_model_file = os.path.join(save_directory, weights_name)

        shards, index = tf_shard_checkpoint(self.weights, max_shard_size)
        shards, index = tf_shard_checkpoint(self.weights, max_shard_size, weights_name=weights_name)

        # Clean the folder from a previous save
        for filename in os.listdir(save_directory):
@@ -2438,7 +2497,8 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
            self.save_weights(output_model_file)
            logger.info(f"Model weights saved in {output_model_file}")
        else:
            save_index_file = os.path.join(save_directory, TF2_WEIGHTS_INDEX_NAME)
            save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else TF2_WEIGHTS_INDEX_NAME
            save_index_file = os.path.join(save_directory, save_index_file)
            # Save the index as well
            with open(save_index_file, "w", encoding="utf-8") as index_file:
                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
@@ -2449,19 +2509,25 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
                f"index located at {save_index_file}."
            )
            for shard_file, shard in shards.items():
                with h5py.File(os.path.join(save_directory, shard_file), mode="w") as shard_file:
                    layers = []
                    for layer in sorted(shard, key=lambda x: x.name):
                        if "model." in layer.name or len(layer.name.split("/")) == 1:
                            layer_name = layer.name
                        else:
                            layer_name = "/".join(layer.name.split("/")[1:])
                        param_dset = shard_file.create_dataset(
                            layer_name, layer.numpy().shape, dtype=layer.numpy().dtype
                        )
                        param_dset[:] = layer.numpy()
                        layers.append(layer_name.encode("utf8"))
                    save_attributes_to_hdf5_group(shard_file, "layer_names", layers)
                if safe_serialization:
                    shard_state_dict = {strip_model_name_and_prefix(w.name): w.value() for w in shard}
                    safe_save_file(
                        shard_state_dict, os.path.join(save_directory, shard_file), metadata={"format": "tf"}
                    )
                else:
                    with h5py.File(os.path.join(save_directory, shard_file), mode="w") as shard_file:
                        layers = []
                        for layer in sorted(shard, key=lambda x: x.name):
                            if "model." in layer.name or len(layer.name.split("/")) == 1:
                                layer_name = layer.name
                            else:
                                layer_name = "/".join(layer.name.split("/")[1:])
                            param_dset = shard_file.create_dataset(
                                layer_name, layer.numpy().shape, dtype=layer.numpy().dtype
                            )
                            param_dset[:] = layer.numpy()
                            layers.append(layer_name.encode("utf8"))
                        save_attributes_to_hdf5_group(shard_file, "layer_names", layers)

        if push_to_hub:
            self._upload_modified_files(
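For the safetensors branch above, each shard is just a flat name-to-tensor dict written with a TF format tag; a rough standalone sketch of that write, assuming the same `safe_save_file` alias (`safetensors.tensorflow.save_file`) used in the file (the tensor name and shard name are made up):

import tensorflow as tf
from safetensors.tensorflow import save_file as safe_save_file

# Hypothetical: write one shard the way the branch above does, with a dummy tensor.
# The {"format": "tf"} metadata is what from_pretrained() later checks to pick a loading path.
shard_state_dict = {"bert/pooler/dense/kernel": tf.zeros((4, 4))}
safe_save_file(shard_state_dict, "model-00001-of-00002.safetensors", metadata={"format": "tf"})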
@@ -2698,6 +2764,12 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
                ):
                    # Load from a safetensors checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
                elif use_safetensors is not False and os.path.isfile(
                    os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
                ):
                    # Load from a sharded safetensors checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
                    is_sharded = True
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                    # Load from a TF 2.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
@@ -2705,17 +2777,11 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
                    # Load from a sharded TF 2.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)
                    is_sharded = True
                elif use_safetensors is not False and os.path.isfile(
                    os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
                ):
                    # Load from a sharded safetensors checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
                    is_sharded = True
                    raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!")

                # At this stage we don't have a weight file so we will raise an error.
                elif use_safetensors:
                    raise EnvironmentError(
                        f"Error no file named {SAFE_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}. "
                        f"Error no file named {SAFE_WEIGHTS_NAME} or {SAFE_WEIGHTS_INDEX_NAME} found in directory {pretrained_model_name_or_path}. "
                        f"Please make sure that the model has been saved with `safe_serialization=True` or do not "
                        f"set `use_safetensors=True`."
                    )
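Taken together with the previous hunk, the TF-side resolution order for a local directory now reads roughly as follows (a hedged summary; the constants conventionally resolve to the file names shown):

# Hypothetical summary of the branch order above for a local save directory:
# 1. SAFE_WEIGHTS_NAME        ("model.safetensors")             unless use_safetensors is False
# 2. SAFE_WEIGHTS_INDEX_NAME  ("model.safetensors.index.json")  sharded, unless use_safetensors is False
# 3. TF2_WEIGHTS_NAME         ("tf_model.h5")
# 4. TF2_WEIGHTS_INDEX_NAME   ("tf_model.h5.index.json")        sharded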
@@ -2723,13 +2789,13 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
                    os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
                ):
                    raise EnvironmentError(
                        f"Error no file named {TF2_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
                        f"Error no file named {TF2_WEIGHTS_NAME} or {SAFE_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
                        "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
                        "weights."
                    )
                else:
                    raise EnvironmentError(
                        f"Error no file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
                        f"Error no file named {TF2_WEIGHTS_NAME}, {SAFE_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
                        f"{pretrained_model_name_or_path}."
                    )
            elif os.path.isfile(pretrained_model_name_or_path):
@@ -2801,9 +2867,6 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
                }
                if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs):
                    is_sharded = True
                    raise NotImplementedError(
                        "Support for sharded checkpoints using safetensors is coming soon!"
                    )
                elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
                    raise EnvironmentError(
                        f"{pretrained_model_name_or_path} does not appear to have a file named"
@@ -2841,7 +2904,7 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
        # We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
        if is_sharded:
            # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
            resolved_archive_file, _ = get_checkpoint_shard_files(
            resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
                pretrained_model_name_or_path,
                resolved_archive_file,
                cache_dir=cache_dir,
@@ -2859,7 +2922,16 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
        if filename == SAFE_WEIGHTS_NAME:
            with safe_open(resolved_archive_file, framework="tf") as f:
                safetensors_metadata = f.metadata()
            if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax"]:
            if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
                raise OSError(
                    f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata."
                    " Make sure you save your model with the `save_pretrained` method."
                )
            safetensors_from_pt = safetensors_metadata.get("format") == "pt"
        elif filename == SAFE_WEIGHTS_INDEX_NAME:
            with safe_open(resolved_archive_file[0], framework="tf") as f:
                safetensors_metadata = f.metadata()
            if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
                raise OSError(
                    f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata."
                    " Make sure you save your model with the `save_pretrained` method."
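The metadata gate above can be reproduced by hand when debugging a checkpoint; a small sketch (the path is illustrative):

from safetensors import safe_open

# Hypothetical: inspect which framework wrote a shard before loading it
with safe_open("model-00001-of-00002.safetensors", framework="tf") as f:
    fmt = (f.metadata() or {}).get("format")  # expected to be "pt", "tf", "flax" or "mlx"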
@@ -2902,11 +2974,11 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
        else:
            model.build_in_name_scope()  # build the network with dummy inputs

        if safetensors_from_pt:
        if safetensors_from_pt and not is_sharded:
            from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model

            with safe_open(resolved_archive_file, framework="tf") as safetensors_archive:
                # Load from a PyTorch checkpoint
                # Load from a PyTorch safetensors checkpoint
                # We load in TF format here because PT weights often need to be transposed, and this is much
                # faster on GPU. Loading as numpy and transposing on CPU adds several seconds to load times.
                return load_pytorch_state_dict_in_tf2_model(
@@ -2919,6 +2991,19 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
                    ignore_mismatched_sizes=ignore_mismatched_sizes,
                    tf_to_pt_weight_rename=tf_to_pt_weight_rename,
                )
        elif safetensors_from_pt:
            from .modeling_tf_pytorch_utils import load_sharded_pytorch_safetensors_in_tf2_model

            return load_sharded_pytorch_safetensors_in_tf2_model(
                model,
                resolved_archive_file,
                tf_inputs=False,
                allow_missing_keys=True,
                output_loading_info=output_loading_info,
                _prefix=load_weight_prefix,
                ignore_mismatched_sizes=ignore_mismatched_sizes,
                tf_to_pt_weight_rename=tf_to_pt_weight_rename,
            )

        # 'by_name' allow us to do transfer learning by skipping/adding layers
        # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
@@ -2926,14 +3011,22 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
        if is_sharded:
            for file in resolved_archive_file:
                os.path.isfile(file), f"Error retrieving files {file}"

            missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights(
                model,
                resolved_archive_file,
                ignore_mismatched_sizes=ignore_mismatched_sizes,
                _prefix=load_weight_prefix,
            )
            if filename == SAFE_WEIGHTS_INDEX_NAME:
                missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights_from_safetensors(
                    model,
                    resolved_archive_file,
                    ignore_mismatched_sizes=ignore_mismatched_sizes,
                    _prefix=load_weight_prefix,
                )
            else:
                missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights(
                    model,
                    resolved_archive_file,
                    ignore_mismatched_sizes=ignore_mismatched_sizes,
                    _prefix=load_weight_prefix,
                )
        else:
            # Handles both H5 and safetensors
            missing_keys, unexpected_keys, mismatched_keys = load_tf_weights(
                model,
                resolved_archive_file,
@@ -41,7 +41,13 @@ from transformers.testing_utils import ( # noqa: F401
    require_torch,
    slow,
)
from transformers.utils import SAFE_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging
from transformers.utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    TF2_WEIGHTS_INDEX_NAME,
    TF2_WEIGHTS_NAME,
    logging,
)


logger = logging.get_logger(__name__)
@@ -340,6 +346,7 @@ class TFModelUtilsTest(unittest.TestCase):
        for p1, p2 in zip(model.weights, ref_model.weights):
            assert np.allclose(p1.numpy(), p2.numpy())

    @require_safetensors
    def test_checkpoint_sharding_local(self):
        model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
@@ -389,6 +396,45 @@ class TFModelUtilsTest(unittest.TestCase):
        for p1, p2 in zip(model.weights, new_model.weights):
            self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))

    def test_safetensors_checkpoint_sharding_local(self):
        model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

        with tempfile.TemporaryDirectory() as tmp_dir:
            # We use the same folder for various sizes to make sure a new save erases the old checkpoint.
            for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
                model.save_pretrained(tmp_dir, max_shard_size=max_size, safe_serialization=True)

                # Get each shard file and its size
                shard_to_size = {}
                for shard in os.listdir(tmp_dir):
                    if shard.endswith(".h5"):
                        shard_file = os.path.join(tmp_dir, shard)
                        shard_to_size[shard_file] = os.path.getsize(shard_file)

                index_file = os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)
                # Check there is an index but no regular weight file
                self.assertTrue(os.path.isfile(index_file))
                self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
                self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
                self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)))

                # Check the index and the shard files found match
                with open(index_file, "r", encoding="utf-8") as f:
                    index = json.loads(f.read())

                all_shards = set(index["weight_map"].values())
                shards_found = {f for f in os.listdir(tmp_dir) if f.endswith(".safetensors")}
                self.assertSetEqual(all_shards, shards_found)

                # Finally, check the model can be reloaded
                new_model = TFBertModel.from_pretrained(tmp_dir)

                model.build_in_name_scope()
                new_model.build_in_name_scope()

                for p1, p2 in zip(model.weights, new_model.weights):
                    self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))

    @slow
    def test_save_pretrained_signatures(self):
        model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
@@ -437,7 +483,26 @@ class TFModelUtilsTest(unittest.TestCase):
            model.save_pretrained(tmp_dir, safe_serialization=True)
            # No tf_model.h5 file, only a model.safetensors
            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)))

            new_model = TFBertModel.from_pretrained(tmp_dir)

            # Check models are equal
            for p1, p2 in zip(model.weights, new_model.weights):
                self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))

    @require_safetensors
    def test_safetensors_sharded_save_and_load(self):
        model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="150kB")
            # No tf weights or index file, only a safetensors index
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)))

            new_model = TFBertModel.from_pretrained(tmp_dir)
@@ -460,6 +525,21 @@ class TFModelUtilsTest(unittest.TestCase):
            for p1, p2 in zip(model.weights, new_model.weights):
                self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))

    @is_pt_tf_cross_test
    def test_sharded_safetensors_save_and_load_pt_to_tf(self):
        model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
        pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
        with tempfile.TemporaryDirectory() as tmp_dir:
            pt_model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="150kB")
            # Check we have a safetensors shard index file
            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))

            new_model = TFBertModel.from_pretrained(tmp_dir)

            # Check models are equal
            for p1, p2 in zip(model.weights, new_model.weights):
                self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))

    @require_safetensors
    def test_safetensors_load_from_hub(self):
        tf_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
@@ -512,9 +592,10 @@ class TFModelUtilsTest(unittest.TestCase):

    @require_safetensors
    def test_safetensors_tf_from_sharded_h5_with_sharded_safetensors_hub(self):
        # This should not raise even if there are two types of sharded weights
        # This should discard the safetensors weights in favor of the .h5 sharded weights
        TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded")
        # Confirm that we can correctly load the safetensors weights from a sharded hub repo even when TF weights present
        TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded", use_safetensors=True)
        # Confirm that we can access the TF weights too
        TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded", use_safetensors=False)

    @require_safetensors
    def test_safetensors_load_from_local(self):