diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py index 99fa1066748..a1663796dec 100644 --- a/src/transformers/modeling_tf_pytorch_utils.py +++ b/src/transformers/modeling_tf_pytorch_utils.py @@ -21,10 +21,24 @@ import re import numpy -from .utils import ExplicitEnum, expand_dims, is_numpy_array, is_torch_tensor, logging, reshape, squeeze, tensor_size +from .utils import ( + ExplicitEnum, + expand_dims, + is_numpy_array, + is_safetensors_available, + is_torch_tensor, + logging, + reshape, + squeeze, + tensor_size, +) from .utils import transpose as transpose_func +if is_safetensors_available(): + from safetensors import safe_open + + logger = logging.get_logger(__name__) @@ -247,6 +261,47 @@ def load_pytorch_weights_in_tf2_model( ) +def _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name): + if len(unexpected_keys) > 0: + logger.warning( + "Some weights of the PyTorch model were not used when initializing the TF 2.0 model" + f" {class_name}: {unexpected_keys}\n- This IS expected if you are initializing" + f" {class_name} from a PyTorch model trained on another task or with another architecture" + " (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n- This IS" + f" NOT expected if you are initializing {class_name} from a PyTorch model that you expect" + " to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a" + " BertForSequenceClassification model)." + ) + else: + logger.warning(f"All PyTorch model weights were used when initializing {class_name}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights or buffers of the TF 2.0 model {class_name} were not initialized from the" + f" PyTorch model and are newly initialized: {missing_keys}\nYou should probably TRAIN this model on a" + " down-stream task to be able to use it for predictions and inference." + ) + else: + logger.warning( + f"All the weights of {class_name} were initialized from the PyTorch model.\n" + "If your task is similar to the task the model of the checkpoint was trained on, " + f"you can already use {class_name} for predictions without further training." + ) + + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {class_name} were not initialized from the model checkpoint" + f" are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able" + " to use it for predictions and inference." + ) + + def load_pytorch_state_dict_in_tf2_model( tf_model, pt_state_dict, @@ -256,6 +311,7 @@ def load_pytorch_state_dict_in_tf2_model( _prefix=None, tf_to_pt_weight_rename=None, ignore_mismatched_sizes=False, + skip_logger_warnings=False, ): """Load a pytorch state_dict in a TF 2.0 model. 
pt_state_dict can be either an actual dict or a lazy-loading safetensors archive created with the safe_open() function.""" @@ -373,45 +429,53 @@ def load_pytorch_state_dict_in_tf2_model( if tf_model._keys_to_ignore_on_load_unexpected is not None: for pat in tf_model._keys_to_ignore_on_load_unexpected: unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + if not skip_logger_warnings: + _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name=tf_model.__class__.__name__) - if len(unexpected_keys) > 0: - logger.warning( - "Some weights of the PyTorch model were not used when initializing the TF 2.0 model" - f" {tf_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing" - f" {tf_model.__class__.__name__} from a PyTorch model trained on another task or with another architecture" - " (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n- This IS" - f" NOT expected if you are initializing {tf_model.__class__.__name__} from a PyTorch model that you expect" - " to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a" - " BertForSequenceClassification model)." - ) - else: - logger.warning(f"All PyTorch model weights were used when initializing {tf_model.__class__.__name__}.\n") - if len(missing_keys) > 0: - logger.warning( - f"Some weights or buffers of the TF 2.0 model {tf_model.__class__.__name__} were not initialized from the" - f" PyTorch model and are newly initialized: {missing_keys}\nYou should probably TRAIN this model on a" - " down-stream task to be able to use it for predictions and inference." - ) - else: - logger.warning( - f"All the weights of {tf_model.__class__.__name__} were initialized from the PyTorch model.\n" - "If your task is similar to the task the model of the checkpoint was trained on, " - f"you can already use {tf_model.__class__.__name__} for predictions without further training." - ) + if output_loading_info: + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + } + return tf_model, loading_info - if len(mismatched_keys) > 0: - mismatched_warning = "\n".join( - [ - f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" - for key, shape1, shape2 in mismatched_keys - ] - ) - logger.warning( - f"Some weights of {tf_model.__class__.__name__} were not initialized from the model checkpoint" - f" are newly initialized because the shapes did not" - f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able" - " to use it for predictions and inference." 
- ) + return tf_model + + +def load_sharded_pytorch_safetensors_in_tf2_model( + tf_model, + safetensors_shards, + tf_inputs=None, + allow_missing_keys=False, + output_loading_info=False, + _prefix=None, + tf_to_pt_weight_rename=None, + ignore_mismatched_sizes=False, +): + all_loading_infos = [] + for shard in safetensors_shards: + with safe_open(shard, framework="tf") as safetensors_archive: + tf_model, loading_info = load_pytorch_state_dict_in_tf2_model( + tf_model, + safetensors_archive, + tf_inputs=tf_inputs, + allow_missing_keys=allow_missing_keys, + output_loading_info=True, + _prefix=_prefix, + tf_to_pt_weight_rename=tf_to_pt_weight_rename, + ignore_mismatched_sizes=ignore_mismatched_sizes, + skip_logger_warnings=True, # We will emit merged warnings at the end + ) + all_loading_infos.append(loading_info) + # Now we just need to merge the loading info + # Keys are missing only if they're missing in *every* shard + missing_keys = sorted(set.intersection(*[set(info["missing_keys"]) for info in all_loading_infos])) + # Keys are unexpected/mismatched if they're unexpected/mismatched in *any* shard + unexpected_keys = sum([info["unexpected_keys"] for info in all_loading_infos], []) + mismatched_keys = sum([info["mismatched_keys"] for info in all_loading_infos], []) + + _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name=tf_model.__class__.__name__) if output_loading_info: loading_info = { diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index a1de9a1cdb8..d5e17d25686 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -647,7 +647,7 @@ def strip_model_name_and_prefix(name, _prefix=None): return name -def tf_shard_checkpoint(weights, max_shard_size="10GB"): +def tf_shard_checkpoint(weights, max_shard_size="10GB", weights_name: str = TF2_WEIGHTS_NAME): """ Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a given size. @@ -695,13 +695,16 @@ def tf_shard_checkpoint(weights, max_shard_size="10GB"): # If we only have one shard, we return it if len(sharded_state_dicts) == 1: - return {TF2_WEIGHTS_NAME: sharded_state_dicts[0]}, None + return {weights_name: sharded_state_dicts[0]}, None # Otherwise, let's build the index weight_map = {} shards = {} for idx, shard in enumerate(sharded_state_dicts): - shard_file = TF2_WEIGHTS_NAME.replace(".h5", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.h5") + shard_file = weights_name.replace(".h5", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.h5") + shard_file = shard_file.replace( + ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors" + ) shards[shard_file] = shard for weight in shard: weight_name = weight.name @@ -782,7 +785,8 @@ def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, s def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None): """ - Loads a shard from a sharded checkpoint file. Handles the missing keys and unexpected keys. + Loads a shard from a sharded checkpoint file. Can be either H5 or Safetensors. + Handles missing keys and unexpected keys. 
     Args:
         model (`keras.models.Model`): Model in which the weights are loaded
@@ -868,6 +872,61 @@ def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatch
         )
+
+
+def load_tf_sharded_weights_from_safetensors(
+    model, shard_files, ignore_mismatched_sizes=False, strict=False, _prefix=None
+):
+    """
+    This is the same as `load_tf_weights_from_safetensors` but for a sharded TF-format safetensors checkpoint.
+    Detect missing and unexpected layers and load the TF weights from the shards according to their names and
+    shapes.
+
+    This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
+    loaded in the model.
+
+    Args:
+        model (`keras.models.Model`): The model in which to load the checkpoint.
+        shard_files (`str` or `os.PathLike`): A list containing the sharded checkpoint names.
+        ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
+            Whether or not to ignore weights whose shapes do not match between the checkpoint and the model.
+        strict (`bool`, *optional*, defaults to `False`):
+            Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
+
+    Returns:
+        Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the
+        mismatched layers.
+    """
+
+    # Accumulate loading results across all shards
+    unexpected_keys = set()
+    all_missing_keys = []
+    mismatched_keys = set()
+
+    for shard_file in shard_files:
+        missing_layers, unexpected_layers, mismatched_layers = load_tf_weights_from_safetensors(
+            model,
+            shard_file,
+            ignore_mismatched_sizes=ignore_mismatched_sizes,
+            _prefix=_prefix,
+        )
+        all_missing_keys.append(set(missing_layers))
+        unexpected_keys.update(unexpected_layers)
+        mismatched_keys.update(mismatched_layers)
+        gc.collect()
+    missing_keys = set.intersection(*all_missing_keys)
+
+    if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
+        error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
+        if len(missing_keys) > 0:
+            str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
+            error_message += f"\nMissing key(s): {str_missing_keys}."
+        if len(unexpected_keys) > 0:
+            str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
+            error_message += f"\nUnexpected key(s): {str_unexpected_keys}."
+ raise RuntimeError(error_message) + + return missing_keys, unexpected_keys, mismatched_keys + + def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None): """ Detect missing and unexpected layers and load the TF weights from the shard file accordingly to their names and @@ -2303,7 +2362,7 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT version=1, push_to_hub=False, signatures=None, - max_shard_size: Union[int, str] = "10GB", + max_shard_size: Union[int, str] = "5GB", create_pr: bool = False, safe_serialization: bool = False, token: Optional[Union[str, bool]] = None, @@ -2415,7 +2474,7 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT weights_name = SAFE_WEIGHTS_NAME if safe_serialization else TF2_WEIGHTS_NAME output_model_file = os.path.join(save_directory, weights_name) - shards, index = tf_shard_checkpoint(self.weights, max_shard_size) + shards, index = tf_shard_checkpoint(self.weights, max_shard_size, weights_name=weights_name) # Clean the folder from a previous save for filename in os.listdir(save_directory): @@ -2438,7 +2497,8 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT self.save_weights(output_model_file) logger.info(f"Model weights saved in {output_model_file}") else: - save_index_file = os.path.join(save_directory, TF2_WEIGHTS_INDEX_NAME) + save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else TF2_WEIGHTS_INDEX_NAME + save_index_file = os.path.join(save_directory, save_index_file) # Save the index as well with open(save_index_file, "w", encoding="utf-8") as index_file: content = json.dumps(index, indent=2, sort_keys=True) + "\n" @@ -2449,19 +2509,25 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT f"index located at {save_index_file}." ) for shard_file, shard in shards.items(): - with h5py.File(os.path.join(save_directory, shard_file), mode="w") as shard_file: - layers = [] - for layer in sorted(shard, key=lambda x: x.name): - if "model." in layer.name or len(layer.name.split("/")) == 1: - layer_name = layer.name - else: - layer_name = "/".join(layer.name.split("/")[1:]) - param_dset = shard_file.create_dataset( - layer_name, layer.numpy().shape, dtype=layer.numpy().dtype - ) - param_dset[:] = layer.numpy() - layers.append(layer_name.encode("utf8")) - save_attributes_to_hdf5_group(shard_file, "layer_names", layers) + if safe_serialization: + shard_state_dict = {strip_model_name_and_prefix(w.name): w.value() for w in shard} + safe_save_file( + shard_state_dict, os.path.join(save_directory, shard_file), metadata={"format": "tf"} + ) + else: + with h5py.File(os.path.join(save_directory, shard_file), mode="w") as shard_file: + layers = [] + for layer in sorted(shard, key=lambda x: x.name): + if "model." 
in layer.name or len(layer.name.split("/")) == 1: + layer_name = layer.name + else: + layer_name = "/".join(layer.name.split("/")[1:]) + param_dset = shard_file.create_dataset( + layer_name, layer.numpy().shape, dtype=layer.numpy().dtype + ) + param_dset[:] = layer.numpy() + layers.append(layer_name.encode("utf8")) + save_attributes_to_hdf5_group(shard_file, "layer_names", layers) if push_to_hub: self._upload_modified_files( @@ -2698,6 +2764,12 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT ): # Load from a safetensors checkpoint archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME) + elif use_safetensors is not False and os.path.isfile( + os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) + ): + # Load from a sharded safetensors checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) + is_sharded = True elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)): # Load from a TF 2.0 checkpoint archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME) @@ -2705,17 +2777,11 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT # Load from a sharded TF 2.0 checkpoint archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME) is_sharded = True - elif use_safetensors is not False and os.path.isfile( - os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) - ): - # Load from a sharded safetensors checkpoint - archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME) - is_sharded = True - raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!") + # At this stage we don't have a weight file so we will raise an error. elif use_safetensors: raise EnvironmentError( - f"Error no file named {SAFE_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}. " + f"Error no file named {SAFE_WEIGHTS_NAME} or {SAFE_WEIGHTS_INDEX_NAME} found in directory {pretrained_model_name_or_path}. " f"Please make sure that the model has been saved with `safe_serialization=True` or do not " f"set `use_safetensors=True`." ) @@ -2723,13 +2789,13 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME) ): raise EnvironmentError( - f"Error no file named {TF2_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} " + f"Error no file named {TF2_WEIGHTS_NAME} or {SAFE_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} " "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those " "weights." ) else: raise EnvironmentError( - f"Error no file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory " + f"Error no file named {TF2_WEIGHTS_NAME}, {SAFE_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory " f"{pretrained_model_name_or_path}." ) elif os.path.isfile(pretrained_model_name_or_path): @@ -2801,9 +2867,6 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT } if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs): is_sharded = True - raise NotImplementedError( - "Support for sharded checkpoints using safetensors is coming soon!" 
- ) elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs): raise EnvironmentError( f"{pretrained_model_name_or_path} does not appear to have a file named" @@ -2841,7 +2904,7 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT # We'll need to download and cache each checkpoint shard if the checkpoint is sharded. if is_sharded: # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case. - resolved_archive_file, _ = get_checkpoint_shard_files( + resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( pretrained_model_name_or_path, resolved_archive_file, cache_dir=cache_dir, @@ -2859,7 +2922,16 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT if filename == SAFE_WEIGHTS_NAME: with safe_open(resolved_archive_file, framework="tf") as f: safetensors_metadata = f.metadata() - if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax"]: + if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax", "mlx"]: + raise OSError( + f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata." + " Make sure you save your model with the `save_pretrained` method." + ) + safetensors_from_pt = safetensors_metadata.get("format") == "pt" + elif filename == SAFE_WEIGHTS_INDEX_NAME: + with safe_open(resolved_archive_file[0], framework="tf") as f: + safetensors_metadata = f.metadata() + if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax", "mlx"]: raise OSError( f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata." " Make sure you save your model with the `save_pretrained` method." @@ -2902,11 +2974,11 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT else: model.build_in_name_scope() # build the network with dummy inputs - if safetensors_from_pt: + if safetensors_from_pt and not is_sharded: from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model with safe_open(resolved_archive_file, framework="tf") as safetensors_archive: - # Load from a PyTorch checkpoint + # Load from a PyTorch safetensors checkpoint # We load in TF format here because PT weights often need to be transposed, and this is much # faster on GPU. Loading as numpy and transposing on CPU adds several seconds to load times. 
return load_pytorch_state_dict_in_tf2_model( @@ -2919,6 +2991,19 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT ignore_mismatched_sizes=ignore_mismatched_sizes, tf_to_pt_weight_rename=tf_to_pt_weight_rename, ) + elif safetensors_from_pt: + from .modeling_tf_pytorch_utils import load_sharded_pytorch_safetensors_in_tf2_model + + return load_sharded_pytorch_safetensors_in_tf2_model( + model, + resolved_archive_file, + tf_inputs=False, + allow_missing_keys=True, + output_loading_info=output_loading_info, + _prefix=load_weight_prefix, + ignore_mismatched_sizes=ignore_mismatched_sizes, + tf_to_pt_weight_rename=tf_to_pt_weight_rename, + ) # 'by_name' allow us to do transfer learning by skipping/adding layers # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357 @@ -2926,14 +3011,22 @@ class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushT if is_sharded: for file in resolved_archive_file: os.path.isfile(file), f"Error retrieving files {file}" - - missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights( - model, - resolved_archive_file, - ignore_mismatched_sizes=ignore_mismatched_sizes, - _prefix=load_weight_prefix, - ) + if filename == SAFE_WEIGHTS_INDEX_NAME: + missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights_from_safetensors( + model, + resolved_archive_file, + ignore_mismatched_sizes=ignore_mismatched_sizes, + _prefix=load_weight_prefix, + ) + else: + missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights( + model, + resolved_archive_file, + ignore_mismatched_sizes=ignore_mismatched_sizes, + _prefix=load_weight_prefix, + ) else: + # Handles both H5 and safetensors missing_keys, unexpected_keys, mismatched_keys = load_tf_weights( model, resolved_archive_file, diff --git a/tests/test_modeling_tf_utils.py b/tests/test_modeling_tf_utils.py index c2381b91a5d..ff11a8a556b 100644 --- a/tests/test_modeling_tf_utils.py +++ b/tests/test_modeling_tf_utils.py @@ -41,7 +41,13 @@ from transformers.testing_utils import ( # noqa: F401 require_torch, slow, ) -from transformers.utils import SAFE_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging +from transformers.utils import ( + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + TF2_WEIGHTS_INDEX_NAME, + TF2_WEIGHTS_NAME, + logging, +) logger = logging.get_logger(__name__) @@ -340,6 +346,7 @@ class TFModelUtilsTest(unittest.TestCase): for p1, p2 in zip(model.weights, ref_model.weights): assert np.allclose(p1.numpy(), p2.numpy()) + @require_safetensors def test_checkpoint_sharding_local(self): model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert") @@ -389,6 +396,45 @@ class TFModelUtilsTest(unittest.TestCase): for p1, p2 in zip(model.weights, new_model.weights): self.assertTrue(np.allclose(p1.numpy(), p2.numpy())) + def test_safetensors_checkpoint_sharding_local(self): + model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert") + + with tempfile.TemporaryDirectory() as tmp_dir: + # We use the same folder for various sizes to make sure a new save erases the old checkpoint. 
+ for max_size in ["150kB", "150kiB", "200kB", "200kiB"]: + model.save_pretrained(tmp_dir, max_shard_size=max_size, safe_serialization=True) + + # Get each shard file and its size + shard_to_size = {} + for shard in os.listdir(tmp_dir): + if shard.endswith(".h5"): + shard_file = os.path.join(tmp_dir, shard) + shard_to_size[shard_file] = os.path.getsize(shard_file) + + index_file = os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME) + # Check there is an index but no regular weight file + self.assertTrue(os.path.isfile(index_file)) + self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME))) + self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME))) + self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME))) + + # Check the index and the shard files found match + with open(index_file, "r", encoding="utf-8") as f: + index = json.loads(f.read()) + + all_shards = set(index["weight_map"].values()) + shards_found = {f for f in os.listdir(tmp_dir) if f.endswith(".safetensors")} + self.assertSetEqual(all_shards, shards_found) + + # Finally, check the model can be reloaded + new_model = TFBertModel.from_pretrained(tmp_dir) + + model.build_in_name_scope() + new_model.build_in_name_scope() + + for p1, p2 in zip(model.weights, new_model.weights): + self.assertTrue(np.allclose(p1.numpy(), p2.numpy())) + @slow def test_save_pretrained_signatures(self): model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert") @@ -437,7 +483,26 @@ class TFModelUtilsTest(unittest.TestCase): model.save_pretrained(tmp_dir, safe_serialization=True) # No tf_model.h5 file, only a model.safetensors self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME))) + self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))) self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME))) + self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME))) + + new_model = TFBertModel.from_pretrained(tmp_dir) + + # Check models are equal + for p1, p2 in zip(model.weights, new_model.weights): + self.assertTrue(np.allclose(p1.numpy(), p2.numpy())) + + @require_safetensors + def test_safetensors_sharded_save_and_load(self): + model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert") + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="150kB") + # No tf weights or index file, only a safetensors index + self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME))) + self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME))) + self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))) + self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME))) new_model = TFBertModel.from_pretrained(tmp_dir) @@ -460,6 +525,21 @@ class TFModelUtilsTest(unittest.TestCase): for p1, p2 in zip(model.weights, new_model.weights): self.assertTrue(np.allclose(p1.numpy(), p2.numpy())) + @is_pt_tf_cross_test + def test_sharded_safetensors_save_and_load_pt_to_tf(self): + model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert") + pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert") + with tempfile.TemporaryDirectory() as tmp_dir: + pt_model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="150kB") + # Check we have a safetensors shard index file + self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))) + + 
        new_model = TFBertModel.from_pretrained(tmp_dir)
+
+        # Check models are equal
+        for p1, p2 in zip(model.weights, new_model.weights):
+            self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
+
     @require_safetensors
     def test_safetensors_load_from_hub(self):
         tf_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
@@ -512,9 +592,10 @@ class TFModelUtilsTest(unittest.TestCase):
 
     @require_safetensors
     def test_safetensors_tf_from_sharded_h5_with_sharded_safetensors_hub(self):
-        # This should not raise even if there are two types of sharded weights
-        # This should discard the safetensors weights in favor of the .h5 sharded weights
-        TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded")
+        # Confirm that we can correctly load the safetensors weights from a sharded hub repo even when TF weights are present
+        TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded", use_safetensors=True)
+        # Confirm that we can access the TF weights too
+        TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded", use_safetensors=False)
 
     @require_safetensors
     def test_safetensors_load_from_local(self):
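
A brief illustration of what the changes above enable end to end. This is a minimal sketch of the round trip that `test_safetensors_sharded_save_and_load` exercises; the local directory name is arbitrary, and `safetensors` plus a TensorFlow backend are assumed to be installed.

```python
import os

import numpy as np

from transformers import TFBertModel
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

# A tiny max_shard_size forces several shards; safe_serialization=True now writes
# *.safetensors shards plus a model.safetensors.index.json index instead of H5 files.
model.save_pretrained("tiny-bert-sharded", safe_serialization=True, max_shard_size="150kB")
assert os.path.isfile(os.path.join("tiny-bert-sharded", SAFE_WEIGHTS_INDEX_NAME))

# from_pretrained detects the safetensors index file and loads every shard.
reloaded = TFBertModel.from_pretrained("tiny-bert-sharded")
for p1, p2 in zip(model.weights, reloaded.weights):
    assert np.allclose(p1.numpy(), p2.numpy())
```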
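The cross-framework path added in `load_sharded_pytorch_safetensors_in_tf2_model` can be exercised the same way `test_sharded_safetensors_save_and_load_pt_to_tf` does: save a PyTorch model as sharded safetensors and point the TF class at the resulting directory. A sketch assuming both `torch` and `safetensors` are available and using the same tiny test checkpoint:

```python
from transformers import BertModel, TFBertModel

pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# The PyTorch save_pretrained shards the safetensors checkpoint and writes the index file.
pt_model.save_pretrained("tiny-bert-pt-sharded", safe_serialization=True, max_shard_size="150kB")

# The shards carry {"format": "pt"} metadata, so TF from_pretrained routes them through
# load_sharded_pytorch_safetensors_in_tf2_model, transposing weights where needed.
tf_model = TFBertModel.from_pretrained("tiny-bert-pt-sharded")
```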
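`tf_shard_checkpoint` now takes the weights file name as a parameter, so the same sharding logic can emit either H5 or safetensors shard names. A small standalone sketch of the naming scheme, with the relevant constants inlined for clarity:

```python
TF2_WEIGHTS_NAME = "tf_model.h5"
SAFE_WEIGHTS_NAME = "model.safetensors"

def shard_file_name(weights_name, idx, total):
    # Only one of the two replacements matches, depending on the serialization format.
    name = weights_name.replace(".h5", f"-{idx + 1:05d}-of-{total:05d}.h5")
    name = name.replace(".safetensors", f"-{idx + 1:05d}-of-{total:05d}.safetensors")
    return name

print(shard_file_name(TF2_WEIGHTS_NAME, 0, 3))   # tf_model-00001-of-00003.h5
print(shard_file_name(SAFE_WEIGHTS_NAME, 0, 3))  # model-00001-of-00003.safetensors
```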
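Both sharded loaders merge per-shard results with the same rule, which is easy to miss in the diff: a weight counts as missing only if no shard provided it, whereas unexpected and mismatched entries are accumulated from every shard. A standalone restatement of that rule (hypothetical helper name, simplified inputs):

```python
def merge_shard_loading_infos(loading_infos):
    """Merge per-shard loading-info dicts the way the sharded loaders do."""
    # A key is missing only if *every* shard reported it missing ...
    missing = sorted(set.intersection(*[set(info["missing_keys"]) for info in loading_infos]))
    # ... but unexpected/mismatched entries are collected from *any* shard.
    unexpected = sum([info["unexpected_keys"] for info in loading_infos], [])
    mismatched = sum([info["mismatched_keys"] for info in loading_infos], [])
    return missing, unexpected, mismatched

# Example: the pooler weights only live in the second shard, so they are not missing overall.
infos = [
    {"missing_keys": ["bert/pooler/dense/kernel"], "unexpected_keys": [], "mismatched_keys": []},
    {"missing_keys": [], "unexpected_keys": ["cls/predictions/bias"], "mismatched_keys": []},
]
print(merge_shard_loading_infos(infos))
# -> ([], ['cls/predictions/bias'], [])
```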
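The reason `load_pytorch_state_dict_in_tf2_model` can accept a safetensors archive in place of a real state dict is that `safe_open` exposes a lazy, mapping-like view of a shard: tensor names are listed up front and each tensor is only materialized when requested. A rough sketch of reading one shard (the file name is a placeholder):

```python
from safetensors import safe_open

with safe_open("model-00001-of-00002.safetensors", framework="tf") as archive:
    print(archive.metadata())              # e.g. {"format": "pt"} for PyTorch-saved weights
    for name in archive.keys():            # tensor names stored in this shard only
        tensor = archive.get_tensor(name)  # loaded on demand, already as a TF tensor
        print(name, tensor.shape)
```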