Support sharded safetensors in TF (#29350)
* Initial commit (still lots of unfinished bits)
* (Still untested) add safetensors sharding to save_pretrained
* Fix safetensors saving, update default shard size to match PT
* Add proper loading of TF-format safetensors
* Revert default size in case that changes things
* Fix incorrect index name
* Update loading priority
* Update tests
* Make the tests a little more stringent
* Expand tests
* Add sharded cross-test
* Fix argument name
* One more test fix
* Adding mlx to the list of allowed formats
* Remove irrelevant block for safetensors
* Refactor warning logging into a separate function
* Remove unused skip_logger_warnings arg
* Update src/transformers/modeling_tf_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Move function def

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent 870bbb4c6b
commit 11ef35e828
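For orientation, here is a minimal usage sketch of the behaviour this commit adds. The local directory and shard size below are illustrative; the tiny checkpoint name is the one the PR's tests use.

# Sketch: saving and reloading a TF model as sharded safetensors.
from transformers import TFBertModel

model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# With safe_serialization=True and a small max_shard_size, save_pretrained now
# writes *.safetensors shards plus a safetensors index file instead of one .h5.
model.save_pretrained("./tiny-bert-tf-sharded", safe_serialization=True, max_shard_size="150kB")
# from_pretrained detects the safetensors index and loads every shard.
reloaded = TFBertModel.from_pretrained("./tiny-bert-tf-sharded")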
@@ -21,10 +21,24 @@ import re

 import numpy

-from .utils import ExplicitEnum, expand_dims, is_numpy_array, is_torch_tensor, logging, reshape, squeeze, tensor_size
+from .utils import (
+    ExplicitEnum,
+    expand_dims,
+    is_numpy_array,
+    is_safetensors_available,
+    is_torch_tensor,
+    logging,
+    reshape,
+    squeeze,
+    tensor_size,
+)
 from .utils import transpose as transpose_func


+if is_safetensors_available():
+    from safetensors import safe_open
+
+
 logger = logging.get_logger(__name__)
@@ -247,6 +261,47 @@ def load_pytorch_weights_in_tf2_model(
     )


+def _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name):
+    if len(unexpected_keys) > 0:
+        logger.warning(
+            "Some weights of the PyTorch model were not used when initializing the TF 2.0 model"
+            f" {class_name}: {unexpected_keys}\n- This IS expected if you are initializing"
+            f" {class_name} from a PyTorch model trained on another task or with another architecture"
+            " (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n- This IS"
+            f" NOT expected if you are initializing {class_name} from a PyTorch model that you expect"
+            " to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a"
+            " BertForSequenceClassification model)."
+        )
+    else:
+        logger.warning(f"All PyTorch model weights were used when initializing {class_name}.\n")
+    if len(missing_keys) > 0:
+        logger.warning(
+            f"Some weights or buffers of the TF 2.0 model {class_name} were not initialized from the"
+            f" PyTorch model and are newly initialized: {missing_keys}\nYou should probably TRAIN this model on a"
+            " down-stream task to be able to use it for predictions and inference."
+        )
+    else:
+        logger.warning(
+            f"All the weights of {class_name} were initialized from the PyTorch model.\n"
+            "If your task is similar to the task the model of the checkpoint was trained on, "
+            f"you can already use {class_name} for predictions without further training."
+        )
+
+    if len(mismatched_keys) > 0:
+        mismatched_warning = "\n".join(
+            [
+                f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+                for key, shape1, shape2 in mismatched_keys
+            ]
+        )
+        logger.warning(
+            f"Some weights of {class_name} were not initialized from the model checkpoint"
+            f" are newly initialized because the shapes did not"
+            f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
+            " to use it for predictions and inference."
+        )
+
+
 def load_pytorch_state_dict_in_tf2_model(
     tf_model,
     pt_state_dict,
@@ -256,6 +311,7 @@ def load_pytorch_state_dict_in_tf2_model(
     _prefix=None,
     tf_to_pt_weight_rename=None,
     ignore_mismatched_sizes=False,
+    skip_logger_warnings=False,
 ):
     """Load a pytorch state_dict in a TF 2.0 model. pt_state_dict can be either an actual dict or a lazy-loading
     safetensors archive created with the safe_open() function."""
@@ -373,45 +429,53 @@ def load_pytorch_state_dict_in_tf2_model(
     if tf_model._keys_to_ignore_on_load_unexpected is not None:
         for pat in tf_model._keys_to_ignore_on_load_unexpected:
             unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

-    if len(unexpected_keys) > 0:
-        logger.warning(
-            "Some weights of the PyTorch model were not used when initializing the TF 2.0 model"
-            f" {tf_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
-            f" {tf_model.__class__.__name__} from a PyTorch model trained on another task or with another architecture"
-            " (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n- This IS"
-            f" NOT expected if you are initializing {tf_model.__class__.__name__} from a PyTorch model that you expect"
-            " to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a"
-            " BertForSequenceClassification model)."
-        )
-    else:
-        logger.warning(f"All PyTorch model weights were used when initializing {tf_model.__class__.__name__}.\n")
-    if len(missing_keys) > 0:
-        logger.warning(
-            f"Some weights or buffers of the TF 2.0 model {tf_model.__class__.__name__} were not initialized from the"
-            f" PyTorch model and are newly initialized: {missing_keys}\nYou should probably TRAIN this model on a"
-            " down-stream task to be able to use it for predictions and inference."
-        )
-    else:
-        logger.warning(
-            f"All the weights of {tf_model.__class__.__name__} were initialized from the PyTorch model.\n"
-            "If your task is similar to the task the model of the checkpoint was trained on, "
-            f"you can already use {tf_model.__class__.__name__} for predictions without further training."
-        )
-
-    if len(mismatched_keys) > 0:
-        mismatched_warning = "\n".join(
-            [
-                f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
-                for key, shape1, shape2 in mismatched_keys
-            ]
-        )
-        logger.warning(
-            f"Some weights of {tf_model.__class__.__name__} were not initialized from the model checkpoint"
-            f" are newly initialized because the shapes did not"
-            f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
-            " to use it for predictions and inference."
-        )
+    if not skip_logger_warnings:
+        _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name=tf_model.__class__.__name__)
+
+    if output_loading_info:
+        loading_info = {
+            "missing_keys": missing_keys,
+            "unexpected_keys": unexpected_keys,
+            "mismatched_keys": mismatched_keys,
+        }
+        return tf_model, loading_info
+
+    return tf_model
+
+
+def load_sharded_pytorch_safetensors_in_tf2_model(
+    tf_model,
+    safetensors_shards,
+    tf_inputs=None,
+    allow_missing_keys=False,
+    output_loading_info=False,
+    _prefix=None,
+    tf_to_pt_weight_rename=None,
+    ignore_mismatched_sizes=False,
+):
+    all_loading_infos = []
+    for shard in safetensors_shards:
+        with safe_open(shard, framework="tf") as safetensors_archive:
+            tf_model, loading_info = load_pytorch_state_dict_in_tf2_model(
+                tf_model,
+                safetensors_archive,
+                tf_inputs=tf_inputs,
+                allow_missing_keys=allow_missing_keys,
+                output_loading_info=True,
+                _prefix=_prefix,
+                tf_to_pt_weight_rename=tf_to_pt_weight_rename,
+                ignore_mismatched_sizes=ignore_mismatched_sizes,
+                skip_logger_warnings=True,  # We will emit merged warnings at the end
+            )
+        all_loading_infos.append(loading_info)
+    # Now we just need to merge the loading info
+    # Keys are missing only if they're missing in *every* shard
+    missing_keys = sorted(set.intersection(*[set(info["missing_keys"]) for info in all_loading_infos]))
+    # Keys are unexpected/mismatched if they're unexpected/mismatched in *any* shard
+    unexpected_keys = sum([info["unexpected_keys"] for info in all_loading_infos], [])
+    mismatched_keys = sum([info["mismatched_keys"] for info in all_loading_infos], [])
+
+    _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name=tf_model.__class__.__name__)

     if output_loading_info:
         loading_info = {
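A toy illustration of the merge rule used by load_sharded_pytorch_safetensors_in_tf2_model above; the key names are made up. A key only counts as missing when every shard reports it missing, while unexpected and mismatched keys accumulate across shards.

# Hypothetical per-shard loading infos, only to show the merge arithmetic.
all_loading_infos = [
    {"missing_keys": ["a", "b"], "unexpected_keys": ["x"], "mismatched_keys": []},
    {"missing_keys": ["b"], "unexpected_keys": [], "mismatched_keys": []},
]
missing_keys = sorted(set.intersection(*[set(info["missing_keys"]) for info in all_loading_infos]))
unexpected_keys = sum([info["unexpected_keys"] for info in all_loading_infos], [])
mismatched_keys = sum([info["mismatched_keys"] for info in all_loading_infos], [])
print(missing_keys, unexpected_keys, mismatched_keys)  # ['b'] ['x'] []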
@@ -647,7 +647,7 @@ def strip_model_name_and_prefix(name, _prefix=None):
     return name


-def tf_shard_checkpoint(weights, max_shard_size="10GB"):
+def tf_shard_checkpoint(weights, max_shard_size="10GB", weights_name: str = TF2_WEIGHTS_NAME):
     """
     Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
     given size.
@@ -695,13 +695,16 @@ def tf_shard_checkpoint(weights, max_shard_size="10GB"):

     # If we only have one shard, we return it
     if len(sharded_state_dicts) == 1:
-        return {TF2_WEIGHTS_NAME: sharded_state_dicts[0]}, None
+        return {weights_name: sharded_state_dicts[0]}, None

     # Otherwise, let's build the index
     weight_map = {}
     shards = {}
     for idx, shard in enumerate(sharded_state_dicts):
-        shard_file = TF2_WEIGHTS_NAME.replace(".h5", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.h5")
+        shard_file = weights_name.replace(".h5", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.h5")
+        shard_file = shard_file.replace(
+            ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
+        )
         shards[shard_file] = shard
         for weight in shard:
             weight_name = weight.name
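To show why the extra replace() call is enough to cover both formats, here is a small standalone sketch of the shard naming. In transformers, TF2_WEIGHTS_NAME is "tf_model.h5" and SAFE_WEIGHTS_NAME is "model.safetensors"; the shard count below is made up.

# Each weights_name only matches one of the two replace() calls, so the other
# call is a no-op; three shards are assumed purely for illustration.
idx, num_shards = 0, 3
for weights_name in ("tf_model.h5", "model.safetensors"):
    shard_file = weights_name.replace(".h5", f"-{idx+1:05d}-of-{num_shards:05d}.h5")
    shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}-of-{num_shards:05d}.safetensors")
    print(shard_file)
# tf_model-00001-of-00003.h5
# model-00001-of-00003.safetensors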
@@ -782,7 +785,8 @@ def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, s

 def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
     """
-    Loads a shard from a sharded checkpoint file. Handles the missing keys and unexpected keys.
+    Loads a shard from a sharded checkpoint file. Can be either H5 or Safetensors.
+    Handles missing keys and unexpected keys.

     Args:
         model (`keras.models.Model`): Model in which the weights are loaded
@@ -868,6 +872,61 @@ def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatch
     )


+def load_tf_sharded_weights_from_safetensors(
+    model, shard_files, ignore_mismatched_sizes=False, strict=False, _prefix=None
+):
+    """
+    This is the same as `load_tf_weights_from_safetensors` but for a sharded TF-format safetensors checkpoint.
+    Detect missing and unexpected layers and load the TF weights from the shard file accordingly to their names and
+    shapes.
+
+    This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
+    loaded in the model.
+
+    Args:
+        model (`keras.models.Model`): The model in which to load the checkpoint.
+        shard_files (`str` or `os.PathLike`): A list containing the sharded checkpoint names.
+        ignore_mismatched_sizes`bool`, *optional`, defaults to `True`):
+            Whether or not to ignore the mismatch between the sizes
+        strict (`bool`, *optional*, defaults to `True`):
+            Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
+
+    Returns:
+        Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the
+        mismatched layers.
+    """
+
+    # Load the index
+    unexpected_keys = set()
+    all_missing_keys = []
+    mismatched_keys = set()
+
+    for shard_file in shard_files:
+        missing_layers, unexpected_layers, mismatched_layers = load_tf_weights_from_safetensors(
+            model,
+            shard_file,
+            ignore_mismatched_sizes=ignore_mismatched_sizes,
+            _prefix=_prefix,
+        )
+        all_missing_keys.append(set(missing_layers))
+        unexpected_keys.update(unexpected_layers)
+        mismatched_keys.update(mismatched_layers)
+        gc.collect()
+    missing_keys = set.intersection(*all_missing_keys)
+
+    if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
+        error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
+        if len(missing_keys) > 0:
+            str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
+            error_message += f"\nMissing key(s): {str_missing_keys}."
+        if len(unexpected_keys) > 0:
+            str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
+            error_message += f"\nMissing key(s): {str_unexpected_keys}."
+        raise RuntimeError(error_message)
+
+    return missing_keys, unexpected_keys, mismatched_keys
+
+
 def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
     """
     Detect missing and unexpected layers and load the TF weights from the shard file accordingly to their names and
@@ -2303,7 +2362,7 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
         version=1,
         push_to_hub=False,
         signatures=None,
-        max_shard_size: Union[int, str] = "10GB",
+        max_shard_size: Union[int, str] = "5GB",
         create_pr: bool = False,
         safe_serialization: bool = False,
         token: Optional[Union[str, bool]] = None,
@@ -2415,7 +2474,7 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
         weights_name = SAFE_WEIGHTS_NAME if safe_serialization else TF2_WEIGHTS_NAME
         output_model_file = os.path.join(save_directory, weights_name)

-        shards, index = tf_shard_checkpoint(self.weights, max_shard_size)
+        shards, index = tf_shard_checkpoint(self.weights, max_shard_size, weights_name=weights_name)

         # Clean the folder from a previous save
         for filename in os.listdir(save_directory):
@@ -2438,7 +2497,8 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
             self.save_weights(output_model_file)
             logger.info(f"Model weights saved in {output_model_file}")
         else:
-            save_index_file = os.path.join(save_directory, TF2_WEIGHTS_INDEX_NAME)
+            save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else TF2_WEIGHTS_INDEX_NAME
+            save_index_file = os.path.join(save_directory, save_index_file)
             # Save the index as well
             with open(save_index_file, "w", encoding="utf-8") as index_file:
                 content = json.dumps(index, indent=2, sort_keys=True) + "\n"
@@ -2449,6 +2509,12 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
                 f"index located at {save_index_file}."
             )
             for shard_file, shard in shards.items():
+                if safe_serialization:
+                    shard_state_dict = {strip_model_name_and_prefix(w.name): w.value() for w in shard}
+                    safe_save_file(
+                        shard_state_dict, os.path.join(save_directory, shard_file), metadata={"format": "tf"}
+                    )
+                else:
                     with h5py.File(os.path.join(save_directory, shard_file), mode="w") as shard_file:
                         layers = []
                         for layer in sorted(shard, key=lambda x: x.name):
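For reference, a sketch of what the safe_serialization branch above writes per shard, using the safetensors TensorFlow API directly. The tensor names and file name are illustrative, and safe_save_file in the real code is presumably an alias of safetensors.tensorflow.save_file.

# Standalone sketch: writing one shard's tensors to a .safetensors file with
# TF-format metadata, mirroring the safe_serialization branch above.
import tensorflow as tf
from safetensors.tensorflow import save_file

shard_state_dict = {
    "bert/embeddings/word_embeddings/weight": tf.zeros((16, 8)),
    "bert/pooler/dense/bias": tf.zeros((8,)),
}
save_file(shard_state_dict, "model-00001-of-00002.safetensors", metadata={"format": "tf"})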
@@ -2698,6 +2764,12 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
                 ):
                     # Load from a safetensors checkpoint
                     archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
+                elif use_safetensors is not False and os.path.isfile(
+                    os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
+                ):
+                    # Load from a sharded safetensors checkpoint
+                    archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
+                    is_sharded = True
                 elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                     # Load from a TF 2.0 checkpoint
                     archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
@@ -2705,17 +2777,11 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
                     # Load from a sharded TF 2.0 checkpoint
                     archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)
                     is_sharded = True
-                elif use_safetensors is not False and os.path.isfile(
-                    os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
-                ):
-                    # Load from a sharded safetensors checkpoint
-                    archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
-                    is_sharded = True
-                    raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!")
                 # At this stage we don't have a weight file so we will raise an error.
                 elif use_safetensors:
                     raise EnvironmentError(
-                        f"Error no file named {SAFE_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}. "
+                        f"Error no file named {SAFE_WEIGHTS_NAME} or {SAFE_WEIGHTS_INDEX_NAME} found in directory {pretrained_model_name_or_path}. "
                         f"Please make sure that the model has been saved with `safe_serialization=True` or do not "
                         f"set `use_safetensors=True`."
                     )
@@ -2723,13 +2789,13 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
                     os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
                 ):
                     raise EnvironmentError(
-                        f"Error no file named {TF2_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
+                        f"Error no file named {TF2_WEIGHTS_NAME} or {SAFE_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
                         "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
                         "weights."
                     )
                 else:
                     raise EnvironmentError(
-                        f"Error no file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
+                        f"Error no file named {TF2_WEIGHTS_NAME}, {SAFE_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
                         f"{pretrained_model_name_or_path}."
                     )
             elif os.path.isfile(pretrained_model_name_or_path):
@@ -2801,9 +2867,6 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
                     }
                     if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs):
                         is_sharded = True
-                        raise NotImplementedError(
-                            "Support for sharded checkpoints using safetensors is coming soon!"
-                        )
                     elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
                         raise EnvironmentError(
                             f"{pretrained_model_name_or_path} does not appear to have a file named"
@@ -2841,7 +2904,7 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
         # We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
         if is_sharded:
             # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
-            resolved_archive_file, _ = get_checkpoint_shard_files(
+            resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
                 pretrained_model_name_or_path,
                 resolved_archive_file,
                 cache_dir=cache_dir,
@@ -2859,7 +2922,16 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
         if filename == SAFE_WEIGHTS_NAME:
             with safe_open(resolved_archive_file, framework="tf") as f:
                 safetensors_metadata = f.metadata()
-            if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax"]:
+            if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
+                raise OSError(
+                    f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata."
+                    " Make sure you save your model with the `save_pretrained` method."
+                )
+            safetensors_from_pt = safetensors_metadata.get("format") == "pt"
+        elif filename == SAFE_WEIGHTS_INDEX_NAME:
+            with safe_open(resolved_archive_file[0], framework="tf") as f:
+                safetensors_metadata = f.metadata()
+            if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
                 raise OSError(
                     f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata."
                     " Make sure you save your model with the `save_pretrained` method."
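The format check above can be reproduced in isolation; a sketch assuming a safetensors file written as in the earlier example (the file name is illustrative).

# Sketch: reading the header metadata that from_pretrained inspects to decide
# whether the checkpoint was written from PT, TF, Flax or MLX weights.
from safetensors import safe_open

with safe_open("model-00001-of-00002.safetensors", framework="tf") as f:
    safetensors_metadata = f.metadata()
print(safetensors_metadata)  # e.g. {'format': 'tf'}
safetensors_from_pt = safetensors_metadata.get("format") == "pt"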
@@ -2902,11 +2974,11 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
         else:
             model.build_in_name_scope()  # build the network with dummy inputs

-        if safetensors_from_pt:
+        if safetensors_from_pt and not is_sharded:
             from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model

             with safe_open(resolved_archive_file, framework="tf") as safetensors_archive:
-                # Load from a PyTorch checkpoint
+                # Load from a PyTorch safetensors checkpoint
                 # We load in TF format here because PT weights often need to be transposed, and this is much
                 # faster on GPU. Loading as numpy and transposing on CPU adds several seconds to load times.
                 return load_pytorch_state_dict_in_tf2_model(
@@ -2919,6 +2991,19 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
                     ignore_mismatched_sizes=ignore_mismatched_sizes,
                     tf_to_pt_weight_rename=tf_to_pt_weight_rename,
                 )
+        elif safetensors_from_pt:
+            from .modeling_tf_pytorch_utils import load_sharded_pytorch_safetensors_in_tf2_model
+
+            return load_sharded_pytorch_safetensors_in_tf2_model(
+                model,
+                resolved_archive_file,
+                tf_inputs=False,
+                allow_missing_keys=True,
+                output_loading_info=output_loading_info,
+                _prefix=load_weight_prefix,
+                ignore_mismatched_sizes=ignore_mismatched_sizes,
+                tf_to_pt_weight_rename=tf_to_pt_weight_rename,
+            )

         # 'by_name' allow us to do transfer learning by skipping/adding layers
         # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
@@ -2926,7 +3011,14 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
             if is_sharded:
                 for file in resolved_archive_file:
                     os.path.isfile(file), f"Error retrieving files {file}"
+                if filename == SAFE_WEIGHTS_INDEX_NAME:
+                    missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights_from_safetensors(
+                        model,
+                        resolved_archive_file,
+                        ignore_mismatched_sizes=ignore_mismatched_sizes,
+                        _prefix=load_weight_prefix,
+                    )
+                else:
                     missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights(
                         model,
                         resolved_archive_file,
@@ -2934,6 +3026,7 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT
                         _prefix=load_weight_prefix,
                     )
             else:
+                # Handles both H5 and safetensors
                 missing_keys, unexpected_keys, mismatched_keys = load_tf_weights(
                     model,
                     resolved_archive_file,
@@ -41,7 +41,13 @@ from transformers.testing_utils import (  # noqa: F401
     require_torch,
     slow,
 )
-from transformers.utils import SAFE_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging
+from transformers.utils import (
+    SAFE_WEIGHTS_INDEX_NAME,
+    SAFE_WEIGHTS_NAME,
+    TF2_WEIGHTS_INDEX_NAME,
+    TF2_WEIGHTS_NAME,
+    logging,
+)


 logger = logging.get_logger(__name__)
@@ -340,6 +346,7 @@ class TFModelUtilsTest(unittest.TestCase):
             for p1, p2 in zip(model.weights, ref_model.weights):
                 assert np.allclose(p1.numpy(), p2.numpy())

+    @require_safetensors
     def test_checkpoint_sharding_local(self):
         model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

@@ -389,6 +396,45 @@ class TFModelUtilsTest(unittest.TestCase):
             for p1, p2 in zip(model.weights, new_model.weights):
                 self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))

+    def test_safetensors_checkpoint_sharding_local(self):
+        model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # We use the same folder for various sizes to make sure a new save erases the old checkpoint.
+            for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
+                model.save_pretrained(tmp_dir, max_shard_size=max_size, safe_serialization=True)
+
+                # Get each shard file and its size
+                shard_to_size = {}
+                for shard in os.listdir(tmp_dir):
+                    if shard.endswith(".h5"):
+                        shard_file = os.path.join(tmp_dir, shard)
+                        shard_to_size[shard_file] = os.path.getsize(shard_file)
+
+                index_file = os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)
+                # Check there is an index but no regular weight file
+                self.assertTrue(os.path.isfile(index_file))
+                self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
+                self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
+                self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)))
+
+                # Check the index and the shard files found match
+                with open(index_file, "r", encoding="utf-8") as f:
+                    index = json.loads(f.read())
+
+                all_shards = set(index["weight_map"].values())
+                shards_found = {f for f in os.listdir(tmp_dir) if f.endswith(".safetensors")}
+                self.assertSetEqual(all_shards, shards_found)
+
+                # Finally, check the model can be reloaded
+                new_model = TFBertModel.from_pretrained(tmp_dir)
+
+                model.build_in_name_scope()
+                new_model.build_in_name_scope()
+
+                for p1, p2 in zip(model.weights, new_model.weights):
+                    self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
+
     @slow
     def test_save_pretrained_signatures(self):
         model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
@@ -437,7 +483,26 @@ class TFModelUtilsTest(unittest.TestCase):
             model.save_pretrained(tmp_dir, safe_serialization=True)
             # No tf_model.h5 file, only a model.safetensors
             self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
+            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
             self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
+            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)))
+
+            new_model = TFBertModel.from_pretrained(tmp_dir)
+
+            # Check models are equal
+            for p1, p2 in zip(model.weights, new_model.weights):
+                self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
+
+    @require_safetensors
+    def test_safetensors_sharded_save_and_load(self):
+        model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="150kB")
+            # No tf weights or index file, only a safetensors index
+            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
+            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
+            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
+            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)))
+
             new_model = TFBertModel.from_pretrained(tmp_dir)

@@ -460,6 +525,21 @@ class TFModelUtilsTest(unittest.TestCase):
             for p1, p2 in zip(model.weights, new_model.weights):
                 self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))

+    @is_pt_tf_cross_test
+    def test_sharded_safetensors_save_and_load_pt_to_tf(self):
+        model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+        pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            pt_model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="150kB")
+            # Check we have a safetensors shard index file
+            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
+
+            new_model = TFBertModel.from_pretrained(tmp_dir)
+
+            # Check models are equal
+            for p1, p2 in zip(model.weights, new_model.weights):
+                self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
+
     @require_safetensors
     def test_safetensors_load_from_hub(self):
         tf_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
@@ -512,9 +592,10 @@ class TFModelUtilsTest(unittest.TestCase):

     @require_safetensors
     def test_safetensors_tf_from_sharded_h5_with_sharded_safetensors_hub(self):
-        # This should not raise even if there are two types of sharded weights
-        # This should discard the safetensors weights in favor of the .h5 sharded weights
-        TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded")
+        # Confirm that we can correctly load the safetensors weights from a sharded hub repo even when TF weights present
+        TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded", use_safetensors=True)
+        # Confirm that we can access the TF weights too
+        TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded", use_safetensors=False)

     @require_safetensors
     def test_safetensors_load_from_local(self):