diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 76d01e0a9eb..f274d7c636d 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -620,12 +620,16 @@ class PretrainedConfig(PushToHubMixin): raise EnvironmentError( f"{pretrained_model_name_or_path} does not appear to have a file named {configuration_file}." ) - except HTTPError: + except HTTPError as err: raise EnvironmentError( - "We couldn't connect to 'https://huggingface.co/' to load this model and it looks like " - f"{pretrained_model_name_or_path} is not the path to a directory containing a {configuration_file} " - "file.\nCheckout your internet connection or see how to run the library in offline mode at " - "'https://huggingface.co/docs/transformers/installation#offline-mode'." + f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}" + ) + except ValueError: + raise EnvironmentError( + "We couldn't connect to 'https://huggingface.co/' to load this model, couldn't find it in the cached " + f"files and it looks like {pretrained_model_name_or_path} is not the path to a directory containing a " + "{configuration_file} file.\nCheckout your internet connection or see how to run the library in " + "offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'." ) except EnvironmentError: raise EnvironmentError( diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index 4ce20671bfb..dcdb6fa01dc 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -427,10 +427,14 @@ class FeatureExtractionMixin(PushToHubMixin): raise EnvironmentError( f"{pretrained_model_name_or_path} does not appear to have a file named {FEATURE_EXTRACTOR_NAME}." ) - except HTTPError: + except HTTPError as err: raise EnvironmentError( - "We couldn't connect to 'https://huggingface.co/' to load this model and it looks like " - f"{pretrained_model_name_or_path} is not the path to a directory conaining a " + f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}" + ) + except ValueError: + raise EnvironmentError( + "We couldn't connect to 'https://huggingface.co/' to load this model, couldn't find it in the cached " + f"files and it looks like {pretrained_model_name_or_path} is not the path to a directory containing a " f"{FEATURE_EXTRACTOR_NAME} file.\nCheckout your internet connection or see how to run the library in " "offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'." ) diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index ee802b550d9..c298d6726b3 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -523,11 +523,16 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): raise EnvironmentError( f"{pretrained_model_name_or_path} does not appear to have a file named {filename}." ) - except HTTPError: + except HTTPError as err: raise EnvironmentError( - "We couldn't connect to 'https://huggingface.co/' to load this model and it looks like " - f"{pretrained_model_name_or_path} is not the path to a directory conaining a a file named " - f"{FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\n" + f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n" + f"{err}" + ) + except ValueError: + raise EnvironmentError( + "We couldn't connect to 'https://huggingface.co/' to load this model, couldn't find it in the cached " + f"files and it looks like {pretrained_model_name_or_path} is not the path to a directory " + f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\n" "Checkout your internet connection or see how to run the library in offline mode at " "'https://huggingface.co/docs/transformers/installation#offline-mode'." ) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 5d8c03fd1e9..dfa341d853d 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -1678,11 +1678,16 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu raise EnvironmentError( f"{pretrained_model_name_or_path} does not appear to have a file named {filename}." ) - except HTTPError: + except HTTPError as err: raise EnvironmentError( - "We couldn't connect to 'https://huggingface.co/' to load this model and it looks like " - f"{pretrained_model_name_or_path} is not the path to a directory conaining a a file named " - f"{TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}.\n" + f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n" + f"{err}" + ) + except ValueError: + raise EnvironmentError( + "We couldn't connect to 'https://huggingface.co/' to load this model, couldn't find it in the cached " + f"files and it looks like {pretrained_model_name_or_path} is not the path to a directory " + f"containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}.\n" "Checkout your internet connection or see how to run the library in offline mode at " "'https://huggingface.co/docs/transformers/installation#offline-mode'." ) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 6d5e4a2946b..660310f2746 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1409,11 +1409,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix raise EnvironmentError( f"{pretrained_model_name_or_path} does not appear to have a file named {filename}." ) - except HTTPError: + except HTTPError as err: raise EnvironmentError( - "We couldn't connect to 'https://huggingface.co/' to load this model and it looks like " - f"{pretrained_model_name_or_path} is not the path to a directory conaining a a file named " - f"{WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}.\n" + f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n" + f"{err}" + ) + except ValueError: + raise EnvironmentError( + "We couldn't connect to 'https://huggingface.co/' to load this model, couldn't find it in the cached " + f"files and it looks like {pretrained_model_name_or_path} is not the path to a directory " + f"containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or " + f"{FLAX_WEIGHTS_NAME}.\n" "Checkout your internet connection or see how to run the library in offline mode at " "'https://huggingface.co/docs/transformers/installation#offline-mode'." ) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index c2a2b7bb6ca..cbf03dc9c15 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -31,8 +31,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequenc import numpy as np from packaging import version -from requests import HTTPError - from . import __version__ from .dynamic_module_utils import custom_object_save from .utils import ( @@ -1751,12 +1749,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): logger.debug(f"{pretrained_model_name_or_path} does not contain a file named {file_path}.") resolved_vocab_files[file_id] = None - except HTTPError as err: - if "404 Client Error" in str(err): - logger.debug(f"Connection problem to access {file_path}.") - resolved_vocab_files[file_id] = None - else: - raise err + except ValueError: + logger.debug(f"Connection problem to access {file_path} and it wasn't found in the cache.") + resolved_vocab_files[file_id] = None if len(unresolved_files) > 0: logger.info( diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index 0dfd6ef96b0..a60e430b126 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -498,10 +498,17 @@ def get_from_cache( # between the HEAD and the GET (unlikely, but hey). if 300 <= r.status_code <= 399: url_to_download = r.headers["Location"] - except (requests.exceptions.SSLError, requests.exceptions.ProxyError): + except ( + requests.exceptions.SSLError, + requests.exceptions.ProxyError, + RepositoryNotFoundError, + EntryNotFoundError, + RevisionNotFoundError, + ): # Actually raise for those subclasses of ConnectionError + # Also raise the custom errors coming from a non existing repo/branch/file as they are caught later on. raise - except (requests.exceptions.ConnectionError, requests.exceptions.Timeout): + except (HTTPError, requests.exceptions.ConnectionError, requests.exceptions.Timeout): # Otherwise, our Internet connection is down. # etag is None pass diff --git a/tests/test_configuration_common.py b/tests/test_configuration_common.py index 08523de9e34..d17ff540679 100644 --- a/tests/test_configuration_common.py +++ b/tests/test_configuration_common.py @@ -20,7 +20,7 @@ import shutil import sys import tempfile import unittest -import unittest.mock +import unittest.mock as mock from pathlib import Path from huggingface_hub import Repository, delete_repo, login @@ -304,6 +304,22 @@ class ConfigTestUtils(unittest.TestCase): f"pick another value for them: {', '.join(keys_with_defaults)}." ) + def test_cached_files_are_used_when_internet_is_down(self): + # A mock response for an HTTP head request to emulate server down + response_mock = mock.Mock() + response_mock.status_code = 500 + response_mock.headers = [] + response_mock.raise_for_status.side_effect = HTTPError + + # Download this model to make sure it's in the cache. + _ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert") + + # Under the mock environment we get a 500 error when trying to reach the model. + with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head: + _ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert") + # This check we did call the fake head request + mock_head.assert_called() + class ConfigurationVersioningTest(unittest.TestCase): def test_local_versioning(self): diff --git a/tests/test_feature_extraction_common.py b/tests/test_feature_extraction_common.py index 3551d34d612..3f7abcaa70c 100644 --- a/tests/test_feature_extraction_common.py +++ b/tests/test_feature_extraction_common.py @@ -19,6 +19,7 @@ import os import sys import tempfile import unittest +import unittest.mock as mock from pathlib import Path from huggingface_hub import Repository, delete_repo, login @@ -116,6 +117,23 @@ class FeatureExtractionSavingTestMixin: self.assertIsNotNone(feat_extract) +class FeatureExtractorUtilTester(unittest.TestCase): + def test_cached_files_are_used_when_internet_is_down(self): + # A mock response for an HTTP head request to emulate server down + response_mock = mock.Mock() + response_mock.status_code = 500 + response_mock.headers = [] + response_mock.raise_for_status.side_effect = HTTPError + + # Download this model to make sure it's in the cache. + _ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2") + # Under the mock environment we get a 500 error when trying to reach the model. + with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head: + _ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2") + # This check we did call the fake head request + mock_head.assert_called() + + @is_staging_test class FeatureExtractorPushToHubTester(unittest.TestCase): @classmethod diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 7cc7114e797..07cf086fae3 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -23,6 +23,7 @@ import random import sys import tempfile import unittest +import unittest.mock as mock import warnings from pathlib import Path from typing import Dict, List, Tuple @@ -2272,6 +2273,22 @@ class ModelUtilsTest(TestCasePlus): for p1, p2 in zip(model.parameters(), new_model.parameters()): self.assertTrue(torch.equal(p1, p2)) + def test_cached_files_are_used_when_internet_is_down(self): + # A mock response for an HTTP head request to emulate server down + response_mock = mock.Mock() + response_mock.status_code = 500 + response_mock.headers = [] + response_mock.raise_for_status.side_effect = HTTPError + + # Download this model to make sure it's in the cache. + _ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert") + + # Under the mock environment we get a 500 error when trying to reach the model. + with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head: + _ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert") + # This check we did call the fake head request + mock_head.assert_called() + @require_torch @is_staging_test diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 48d5b3885b2..3d2f7976cf6 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -21,6 +21,7 @@ import os import random import tempfile import unittest +import unittest.mock as mock from importlib import import_module from typing import List, Tuple @@ -1555,6 +1556,22 @@ class UtilsFunctionsTest(unittest.TestCase): tf.debugging.assert_near(non_inf_output, non_inf_expected_output, rtol=1e-12) tf.debugging.assert_equal(non_inf_idx, non_inf_expected_idx) + def test_cached_files_are_used_when_internet_is_down(self): + # A mock response for an HTTP head request to emulate server down + response_mock = mock.Mock() + response_mock.status_code = 500 + response_mock.headers = [] + response_mock.raise_for_status.side_effect = HTTPError + + # Download this model to make sure it's in the cache. + _ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert") + + # Under the mock environment we get a 500 error when trying to reach the model. + with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head: + _ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert") + # This check we did call the fake head request + mock_head.assert_called() + # tests whether the unpack_inputs function behaves as expected def test_unpack_inputs(self): class DummyModel: diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index e58ab9a816a..f260fa71fff 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -24,6 +24,7 @@ import shutil import sys import tempfile import unittest +import unittest.mock as mock from collections import OrderedDict from itertools import takewhile from pathlib import Path @@ -3742,6 +3743,24 @@ class TokenizerTesterMixin: self.rust_tokenizer_class.from_pretrained(tmp_dir_2) +class TokenizerUtilTester(unittest.TestCase): + def test_cached_files_are_used_when_internet_is_down(self): + # A mock response for an HTTP head request to emulate server down + response_mock = mock.Mock() + response_mock.status_code = 500 + response_mock.headers = [] + response_mock.raise_for_status.side_effect = HTTPError + + # Download this model to make sure it's in the cache. + _ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert") + + # Under the mock environment we get a 500 error when trying to reach the model. + with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head: + _ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert") + # This check we did call the fake head request + mock_head.assert_called() + + @is_staging_test class TokenizerPushToHubTester(unittest.TestCase): vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"] diff --git a/tests/utils/test_offline.py b/tests/utils/test_offline.py index 45a12a1f2b9..33f5d4bd0a8 100644 --- a/tests/utils/test_offline.py +++ b/tests/utils/test_offline.py @@ -59,10 +59,10 @@ socket.socket = offline_socket # next emulate no network cmd = [sys.executable, "-c", "\n".join([load, mock, run])] - # should normally fail as it will fail to lookup the model files w/o the network - env["TRANSFORMERS_OFFLINE"] = "0" - result = subprocess.run(cmd, env=env, check=False, capture_output=True) - self.assertEqual(result.returncode, 1, result.stderr) + # Doesn't fail anymore since the model is in the cache due to other tests, so commenting this. + # env["TRANSFORMERS_OFFLINE"] = "0" + # result = subprocess.run(cmd, env=env, check=False, capture_output=True) + # self.assertEqual(result.returncode, 1, result.stderr) # should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files env["TRANSFORMERS_OFFLINE"] = "1"