diff --git a/setup.py b/setup.py
index f2e533ce987..0de934233dd 100644
--- a/setup.py
+++ b/setup.py
@@ -148,6 +148,7 @@ _deps = [
     "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1",
     "sacrebleu>=1.4.12,<2.0.0",
     "sacremoses",
+    "safetensors>=0.2.1",
     "sagemaker>=2.31.0",
     "scikit-learn",
     "sentencepiece>=0.1.91,!=0.1.92",
@@ -300,6 +301,7 @@ extras["testing"] = (
         "protobuf",  # Can be removed once we can unpin protobuf
         "sacremoses",
         "rjieba",
+        "safetensors",
     )
     + extras["retrieval"]
     + extras["modelcreation"]
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 93c3118f691..e8c4a8939f4 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -478,6 +478,7 @@ _import_structure = {
         "is_psutil_available",
         "is_py3nvml_available",
         "is_pyctcdecode_available",
+        "is_safetensors_available",
         "is_scipy_available",
         "is_sentencepiece_available",
         "is_sklearn_available",
@@ -3417,6 +3418,7 @@ if TYPE_CHECKING:
         is_psutil_available,
         is_py3nvml_available,
         is_pyctcdecode_available,
+        is_safetensors_available,
         is_scipy_available,
         is_sentencepiece_available,
         is_sklearn_available,
diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py
index fae26de7bba..4b3c79e65bb 100644
--- a/src/transformers/dependency_versions_table.py
+++ b/src/transformers/dependency_versions_table.py
@@ -54,6 +54,7 @@ deps = {
     "rouge-score": "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1",
     "sacrebleu": "sacrebleu>=1.4.12,<2.0.0",
     "sacremoses": "sacremoses",
+    "safetensors": "safetensors>=0.2.1",
     "sagemaker": "sagemaker>=2.31.0",
     "scikit-learn": "scikit-learn",
     "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index bb35bf7c803..ec876af9e55 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -50,6 +50,8 @@ from .pytorch_utils import (  # noqa: F401
 from .utils import (
     DUMMY_INPUTS,
     FLAX_WEIGHTS_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
+    SAFE_WEIGHTS_NAME,
     TF2_WEIGHTS_NAME,
     TF_WEIGHTS_NAME,
     WEIGHTS_INDEX_NAME,
@@ -65,6 +67,7 @@ from .utils import (
     is_bitsandbytes_available,
     is_offline_mode,
     is_remote_url,
+    is_safetensors_available,
     logging,
     replace_return_docstrings,
 )
@@ -86,6 +89,10 @@ if is_accelerate_available():
 else:
     get_balanced_memory = None
 
+if is_safetensors_available():
+    from safetensors import safe_open
+    from safetensors.torch import load_file as safe_load_file
+    from safetensors.torch import save_file as safe_save_file
 
 logger = logging.get_logger(__name__)
 
@@ -241,7 +248,9 @@ def dtype_byte_size(dtype):
     return bit_size // 8
 
 
-def shard_checkpoint(state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB"):
+def shard_checkpoint(
+    state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB", weights_name: str = WEIGHTS_NAME
+):
     """
     Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
     given size.
@@ -263,6 +272,8 @@ def shard_checkpoint(state_dict: Dict[str, torch.Tensor], max_shard_size: Union[
         max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
             The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
             (like `"5MB"`).
+        weights_name (`str`, *optional*, defaults to `"pytorch_model.bin"`):
+            The name of the model save file.
     """
     max_shard_size = convert_file_size_to_int(max_shard_size)
 
@@ -289,13 +300,16 @@ def shard_checkpoint(state_dict: Dict[str, torch.Tensor], max_shard_size: Union[
 
     # If we only have one shard, we return it
     if len(sharded_state_dicts) == 1:
-        return {WEIGHTS_NAME: sharded_state_dicts[0]}, None
+        return {weights_name: sharded_state_dicts[0]}, None
 
     # Otherwise, let's build the index
     weight_map = {}
     shards = {}
     for idx, shard in enumerate(sharded_state_dicts):
-        shard_file = WEIGHTS_NAME.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
+        shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
+        shard_file = shard_file.replace(
+            ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
+        )
         shards[shard_file] = shard
         for key in shard.keys():
            weight_map[key] = shard_file
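The chained `replace` calls above make the shard suffix scheme work for both weight formats: exactly one of the two replacements matches, depending on the file extension of `weights_name`. A standalone sketch of the naming behavior (the `name_shards` helper is hypothetical, for illustration only):

```python
# Minimal sketch of the shard-naming scheme above (hypothetical helper, not part of the PR).
def name_shards(weights_name: str, num_shards: int) -> list:
    names = []
    for idx in range(num_shards):
        suffix = f"-{idx + 1:05d}-of-{num_shards:05d}"
        # Only one of the two replacements matches, depending on the serialization format.
        shard_file = weights_name.replace(".bin", f"{suffix}.bin")
        shard_file = shard_file.replace(".safetensors", f"{suffix}.safetensors")
        names.append(shard_file)
    return names

print(name_shards("pytorch_model.bin", 2))
# ['pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin']
print(name_shards("model.safetensors", 2))
# ['model-00001-of-00002.safetensors', 'model-00002-of-00002.safetensors']
```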
""" max_shard_size = convert_file_size_to_int(max_shard_size) @@ -289,13 +300,16 @@ def shard_checkpoint(state_dict: Dict[str, torch.Tensor], max_shard_size: Union[ # If we only have one shard, we return it if len(sharded_state_dicts) == 1: - return {WEIGHTS_NAME: sharded_state_dicts[0]}, None + return {weights_name: sharded_state_dicts[0]}, None # Otherwise, let's build the index weight_map = {} shards = {} for idx, shard in enumerate(sharded_state_dicts): - shard_file = WEIGHTS_NAME.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin") + shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin") + shard_file = shard_file.replace( + ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors" + ) shards[shard_file] = shard for key in shard.keys(): weight_map[key] = shard_file @@ -367,6 +381,20 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]): """ Reads a PyTorch checkpoint file, returning properly formatted errors if they arise. """ + if checkpoint_file.endswith(".safetensors") and is_safetensors_available(): + # Check format of the archive + with safe_open(checkpoint_file, framework="pt") as f: + metadata = f.metadata() + if metadata.get("format") not in ["pt", "tf", "flax"]: + raise OSError( + f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure " + "you save your model with the `save_pretrained` method." + ) + elif metadata["format"] != "pt": + raise NotImplementedError( + f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet." + ) + return safe_load_file(checkpoint_file) try: return torch.load(checkpoint_file, map_location="cpu") except Exception as e: @@ -1468,6 +1496,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix save_function: Callable = torch.save, push_to_hub: bool = False, max_shard_size: Union[int, str] = "10GB", + safe_serialization: bool = False, **kwargs, ): """ @@ -1503,6 +1532,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix + safe_serialization (`bool`, *optional*, defaults to `False`): + Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). + kwargs: Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. """ @@ -1511,6 +1543,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix "`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead." ) is_main_process = kwargs.pop("save_config") + if safe_serialization and not is_safetensors_available(): + raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.") if os.path.isfile(save_directory): logger.error(f"Provided path ({save_directory}) should be a directory, not a file") @@ -1560,15 +1594,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix del state_dict[ignore_key] # Shard the model if it is too big. 
@@ -1468,6 +1496,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         save_function: Callable = torch.save,
         push_to_hub: bool = False,
         max_shard_size: Union[int, str] = "10GB",
+        safe_serialization: bool = False,
         **kwargs,
     ):
         """
@@ -1503,6 +1532,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
 
+            safe_serialization (`bool`, *optional*, defaults to `False`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+
             kwargs:
                 Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
@@ -1511,6 +1543,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 "`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead."
             )
             is_main_process = kwargs.pop("save_config")
+        if safe_serialization and not is_safetensors_available():
+            raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
 
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
@@ -1560,15 +1594,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 del state_dict[ignore_key]
 
         # Shard the model if it is too big.
-        shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size)
+        weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+        shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)
 
         # Clean the folder from a previous save
         for filename in os.listdir(save_directory):
             full_filename = os.path.join(save_directory, filename)
             # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
             # in distributed settings to avoid race conditions.
+            weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
             if (
-                filename.startswith(WEIGHTS_NAME[:-4])
+                filename.startswith(weights_no_suffix)
                 and os.path.isfile(full_filename)
                 and filename not in shards.keys()
                 and is_main_process
@@ -1577,12 +1613,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
 
         # Save the model
         for shard_file, shard in shards.items():
-            save_function(shard, os.path.join(save_directory, shard_file))
+            if safe_serialization:
+                # At some point we will need to deal better with save_function (used for TPU and other distributed
+                # joyfulness), but for now this is enough.
+                safe_save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})
+            else:
+                save_function(shard, os.path.join(save_directory, shard_file))
 
         if index is None:
             logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
         else:
-            save_index_file = os.path.join(save_directory, WEIGHTS_INDEX_NAME)
+            save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
+            save_index_file = os.path.join(save_directory, save_index_file)
             # Save the index as well
             with open(save_index_file, "w", encoding="utf-8") as f:
                 content = json.dumps(index, indent=2, sort_keys=True) + "\n"
@@ -1966,6 +2008,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 ):
                     # Load from a Flax checkpoint in priority if from_flax
                     archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
+                elif is_safetensors_available() and os.path.isfile(
+                    os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME)
+                ):
+                    # Load from a safetensors checkpoint
+                    archive_file = os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME)
+                elif is_safetensors_available() and os.path.isfile(
+                    os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_INDEX_NAME)
+                ):
+                    # Load from a sharded safetensors checkpoint
+                    archive_file = os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_INDEX_NAME)
+                    is_sharded = True
                 elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
                     # Load from a PyTorch checkpoint
                     archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
@@ -2013,6 +2066,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                     filename = TF2_WEIGHTS_NAME
                 elif from_flax:
                     filename = FLAX_WEIGHTS_NAME
+                elif is_safetensors_available():
+                    filename = SAFE_WEIGHTS_NAME
                 else:
                     filename = WEIGHTS_NAME
 
@@ -2033,8 +2088,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                     )
                     resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
 
-                    # Since we set _raise_exceptions_for_missing_entries=False, we don't get an expection but a None
+                    # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
                     # result when internet is up, the repo and revision exist, but the file does not.
+                    if resolved_archive_file is None and filename == SAFE_WEIGHTS_NAME:
+                        # Maybe the checkpoint is sharded, we try to grab the index name in this case.
+                        resolved_archive_file = cached_file(
+                            pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **cached_file_kwargs
+                        )
+                        if resolved_archive_file is not None:
+                            is_sharded = True
+                        else:
+                            # This repo has no safetensors file of any kind, we switch to PyTorch.
+                            filename = WEIGHTS_NAME
+                            resolved_archive_file = cached_file(
+                                pretrained_model_name_or_path, WEIGHTS_NAME, **cached_file_kwargs
+                            )
                     if resolved_archive_file is None and filename == WEIGHTS_NAME:
                         # Maybe the checkpoint is sharded, we try to grab the index name in this case.
                         resolved_archive_file = cached_file(
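From the user's side, the pieces above combine into an opt-in flag on save and transparent format selection on load: `from_pretrained` prefers `model.safetensors` when `safetensors` is installed, and falls back to `pytorch_model.bin` otherwise. A hedged end-to-end sketch (checkpoint names are illustrative; assumes this branch with `safetensors` installed):

```python
# Round-trip sketch for the new flag (checkpoint names are illustrative).
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-cased")
model.save_pretrained("local-checkpoint", safe_serialization=True)  # writes model.safetensors

# from_pretrained now picks up model.safetensors in priority over pytorch_model.bin
# when the safetensors library is available.
reloaded = AutoModel.from_pretrained("local-checkpoint")
```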
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 65c15fbd967..7e3242e94c9 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -60,6 +60,7 @@ from .utils import (
     is_pytesseract_available,
     is_pytorch_quantization_available,
     is_rjieba_available,
+    is_safetensors_available,
     is_scatter_available,
     is_scipy_available,
     is_sentencepiece_available,
@@ -264,6 +265,13 @@ def require_accelerate(test_case):
     return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
 
 
+def require_safetensors(test_case):
+    """
+    Decorator marking a test that requires safetensors. These tests are skipped when safetensors isn't installed.
+    """
+    return unittest.skipUnless(is_safetensors_available(), "test requires safetensors")(test_case)
+
+
 def require_rjieba(test_case):
     """
     Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed.
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 7f3f704ac4a..fdd1c376dab 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -111,6 +111,7 @@ from .import_utils import (
     is_pytorch_quantization_available,
     is_rjieba_available,
     is_sacremoses_available,
+    is_safetensors_available,
     is_sagemaker_dp_enabled,
     is_sagemaker_mp_enabled,
     is_scatter_available,
@@ -156,6 +157,8 @@ TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
 TF_WEIGHTS_NAME = "model.ckpt"
 FLAX_WEIGHTS_NAME = "flax_model.msgpack"
 FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json"
+SAFE_WEIGHTS_NAME = "model.safetensors"
+SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
 CONFIG_NAME = "config.json"
 FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
 MODEL_CARD_NAME = "modelcard.json"
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 16616e0772d..81b7c478c1b 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -533,6 +533,10 @@ def is_accelerate_available():
     return importlib.util.find_spec("accelerate") is not None
 
 
+def is_safetensors_available():
+    return importlib.util.find_spec("safetensors") is not None
+
+
 def is_tokenizers_available():
     return importlib.util.find_spec("tokenizers") is not None
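`is_safetensors_available` and `require_safetensors` follow the library's usual optional-dependency pattern: probe for the package with `importlib.util.find_spec`, and skip tests rather than fail them when it is absent. A self-contained sketch of the same pattern outside the library (`ExampleTest` is illustrative; the two helpers mirror the additions above):

```python
# Self-contained sketch of the guard + skip-decorator pattern added above.
import importlib.util
import unittest


def is_safetensors_available():
    # find_spec returns None when the package cannot be imported.
    return importlib.util.find_spec("safetensors") is not None


def require_safetensors(test_case):
    # Skip the test (rather than fail) when the optional dependency is missing.
    return unittest.skipUnless(is_safetensors_available(), "test requires safetensors")(test_case)


class ExampleTest(unittest.TestCase):
    @require_safetensors
    def test_optional_feature(self):
        import safetensors  # guaranteed importable under the decorator

        self.assertTrue(hasattr(safetensors, "safe_open"))
```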
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 42ecad03c6a..0a55f1d11c3 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -53,6 +53,7 @@ from transformers.testing_utils import (
     is_pt_tf_cross_test,
     is_staging_test,
     require_accelerate,
+    require_safetensors,
     require_torch,
     require_torch_gpu,
     require_torch_multi_gpu,
@@ -61,6 +62,8 @@ from transformers.testing_utils import (
     torch_device,
 )
 from transformers.utils import (
+    SAFE_WEIGHTS_INDEX_NAME,
+    SAFE_WEIGHTS_NAME,
     WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
     is_accelerate_available,
@@ -2980,6 +2983,57 @@ class ModelUtilsTest(TestCasePlus):
             "https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/pytorch_model.bin", config=config
         )
 
+    @require_safetensors
+    def test_safetensors_save_and_load(self):
+        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir, safe_serialization=True)
+            # No pytorch_model.bin file, only a model.safetensors
+            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
+            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))
+
+            new_model = BertModel.from_pretrained(tmp_dir)
+
+            # Check models are equal
+            for p1, p2 in zip(model.parameters(), new_model.parameters()):
+                self.assertTrue(torch.allclose(p1, p2))
+
+    @require_safetensors
+    def test_safetensors_load_from_hub(self):
+        safetensors_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors")
+        pytorch_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+
+        # Check models are equal
+        for p1, p2 in zip(safetensors_model.parameters(), pytorch_model.parameters()):
+            self.assertTrue(torch.allclose(p1, p2))
+
+    @require_safetensors
+    def test_safetensors_save_and_load_sharded(self):
+        model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="100kB")
+            # No pytorch_model.bin index file, only a model.safetensors index
+            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_INDEX_NAME)))
+            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
+            # No regular weights file
+            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))
+            self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
+
+            new_model = BertModel.from_pretrained(tmp_dir)
+
+            # Check models are equal
+            for p1, p2 in zip(model.parameters(), new_model.parameters()):
+                self.assertTrue(torch.allclose(p1, p2))
+
+    @require_safetensors
+    def test_safetensors_load_from_hub_sharded(self):
+        safetensors_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded-safetensors")
+        pytorch_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
+
+        # Check models are equal
+        for p1, p2 in zip(safetensors_model.parameters(), pytorch_model.parameters()):
+            self.assertTrue(torch.allclose(p1, p2))
+
 
 @require_torch
 @is_staging_test