From cf7bed98325a0be9d195cb6b66c6a0bef9fccbc8 Mon Sep 17 00:00:00 2001 From: David Xue Date: Tue, 7 May 2024 12:55:27 -0400 Subject: [PATCH] Add safetensors to model not found error msg for default use_safetensors value (#30602) * add safetensors to model not found error for default use_safetensors=None case * format code w/ ruff * fix assert true typo --- src/transformers/modeling_utils.py | 8 ++++---- tests/test_modeling_utils.py | 20 ++++++++++++++++++++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 7d60380e0b7..08c2442f87c 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3270,8 +3270,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ) else: raise EnvironmentError( - f"Error no file named {_add_variant(WEIGHTS_NAME, variant)}, {TF2_WEIGHTS_NAME}," - f" {TF_WEIGHTS_NAME + '.index'} or {FLAX_WEIGHTS_NAME} found in directory" + f"Error no file named {_add_variant(WEIGHTS_NAME, variant)}, {_add_variant(SAFE_WEIGHTS_NAME, variant)}," + f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or {FLAX_WEIGHTS_NAME} found in directory" f" {pretrained_model_name_or_path}." ) elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)): @@ -3417,8 +3417,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix else: raise EnvironmentError( f"{pretrained_model_name_or_path} does not appear to have a file named" - f" {_add_variant(WEIGHTS_NAME, variant)}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or" - f" {FLAX_WEIGHTS_NAME}." + f" {_add_variant(WEIGHTS_NAME, variant)}, {_add_variant(SAFE_WEIGHTS_NAME, variant)}," + f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}." ) except EnvironmentError: # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 16d8e9e1293..f98e1a2a239 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -1001,6 +1001,26 @@ class ModelUtilsTest(TestCasePlus): self.assertTrue(any(f.endswith("safetensors") for f in all_downloaded_files)) self.assertFalse(any(f.endswith("bin") for f in all_downloaded_files)) + # test no model file found when use_safetensors=None (default when safetensors package available) + with self.assertRaises(OSError) as missing_model_file_error: + BertModel.from_pretrained("hf-internal-testing/config-no-model") + + self.assertTrue( + "does not appear to have a file named pytorch_model.bin, model.safetensors," + in str(missing_model_file_error.exception) + ) + + with self.assertRaises(OSError) as missing_model_file_error: + with tempfile.TemporaryDirectory() as tmp_dir: + with open(os.path.join(tmp_dir, "config.json"), "w") as f: + f.write("{}") + f.close() + BertModel.from_pretrained(tmp_dir) + + self.assertTrue( + "Error no file named pytorch_model.bin, model.safetensors" in str(missing_model_file_error.exception) + ) + @require_safetensors def test_safetensors_save_and_load(self): model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")