Make Transformers use cache files when hf.co is down (#16362)

* Make Transformers use cache files when hf.co is down

* Fix tests

* Was there a random circleCI failure?

* Isolate patches

* Style

* Comment out the failure since it doesn't fail anymore

* Better comment
This commit is contained in:
Sylvain Gugger 2022-03-23 15:56:49 -04:00 committed by GitHub
parent 8a69e023bf
commit c595b6e6a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 148 additions and 35 deletions

View File

@ -620,12 +620,16 @@ class PretrainedConfig(PushToHubMixin):
raise EnvironmentError( raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {configuration_file}." f"{pretrained_model_name_or_path} does not appear to have a file named {configuration_file}."
) )
except HTTPError: except HTTPError as err:
raise EnvironmentError( raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model and it looks like " f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
f"{pretrained_model_name_or_path} is not the path to a directory containing a {configuration_file} " )
"file.\nCheckout your internet connection or see how to run the library in offline mode at " except ValueError:
"'https://huggingface.co/docs/transformers/installation#offline-mode'." raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model, couldn't find it in the cached "
f"files and it looks like {pretrained_model_name_or_path} is not the path to a directory containing a "
"{configuration_file} file.\nCheckout your internet connection or see how to run the library in "
"offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
) )
except EnvironmentError: except EnvironmentError:
raise EnvironmentError( raise EnvironmentError(

View File

@ -427,10 +427,14 @@ class FeatureExtractionMixin(PushToHubMixin):
raise EnvironmentError( raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {FEATURE_EXTRACTOR_NAME}." f"{pretrained_model_name_or_path} does not appear to have a file named {FEATURE_EXTRACTOR_NAME}."
) )
except HTTPError: except HTTPError as err:
raise EnvironmentError( raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model and it looks like " f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
f"{pretrained_model_name_or_path} is not the path to a directory conaining a " )
except ValueError:
raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model, couldn't find it in the cached "
f"files and it looks like {pretrained_model_name_or_path} is not the path to a directory containing a "
f"{FEATURE_EXTRACTOR_NAME} file.\nCheckout your internet connection or see how to run the library in " f"{FEATURE_EXTRACTOR_NAME} file.\nCheckout your internet connection or see how to run the library in "
"offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'." "offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
) )

View File

@ -523,11 +523,16 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
raise EnvironmentError( raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}." f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
) )
except HTTPError: except HTTPError as err:
raise EnvironmentError( raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model and it looks like " f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
f"{pretrained_model_name_or_path} is not the path to a directory conaining a a file named " f"{err}"
f"{FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\n" )
except ValueError:
raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model, couldn't find it in the cached "
f"files and it looks like {pretrained_model_name_or_path} is not the path to a directory "
f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\n"
"Checkout your internet connection or see how to run the library in offline mode at " "Checkout your internet connection or see how to run the library in offline mode at "
"'https://huggingface.co/docs/transformers/installation#offline-mode'." "'https://huggingface.co/docs/transformers/installation#offline-mode'."
) )

View File

@ -1678,11 +1678,16 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
raise EnvironmentError( raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}." f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
) )
except HTTPError: except HTTPError as err:
raise EnvironmentError( raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model and it looks like " f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
f"{pretrained_model_name_or_path} is not the path to a directory conaining a a file named " f"{err}"
f"{TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}.\n" )
except ValueError:
raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model, couldn't find it in the cached "
f"files and it looks like {pretrained_model_name_or_path} is not the path to a directory "
f"containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}.\n"
"Checkout your internet connection or see how to run the library in offline mode at " "Checkout your internet connection or see how to run the library in offline mode at "
"'https://huggingface.co/docs/transformers/installation#offline-mode'." "'https://huggingface.co/docs/transformers/installation#offline-mode'."
) )

View File

@ -1409,11 +1409,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
raise EnvironmentError( raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}." f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
) )
except HTTPError: except HTTPError as err:
raise EnvironmentError( raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model and it looks like " f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
f"{pretrained_model_name_or_path} is not the path to a directory conaining a a file named " f"{err}"
f"{WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}.\n" )
except ValueError:
raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model, couldn't find it in the cached "
f"files and it looks like {pretrained_model_name_or_path} is not the path to a directory "
f"containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or "
f"{FLAX_WEIGHTS_NAME}.\n"
"Checkout your internet connection or see how to run the library in offline mode at " "Checkout your internet connection or see how to run the library in offline mode at "
"'https://huggingface.co/docs/transformers/installation#offline-mode'." "'https://huggingface.co/docs/transformers/installation#offline-mode'."
) )

View File

@ -31,8 +31,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequenc
import numpy as np import numpy as np
from packaging import version from packaging import version
from requests import HTTPError
from . import __version__ from . import __version__
from .dynamic_module_utils import custom_object_save from .dynamic_module_utils import custom_object_save
from .utils import ( from .utils import (
@ -1751,12 +1749,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
logger.debug(f"{pretrained_model_name_or_path} does not contain a file named {file_path}.") logger.debug(f"{pretrained_model_name_or_path} does not contain a file named {file_path}.")
resolved_vocab_files[file_id] = None resolved_vocab_files[file_id] = None
except HTTPError as err: except ValueError:
if "404 Client Error" in str(err): logger.debug(f"Connection problem to access {file_path} and it wasn't found in the cache.")
logger.debug(f"Connection problem to access {file_path}.") resolved_vocab_files[file_id] = None
resolved_vocab_files[file_id] = None
else:
raise err
if len(unresolved_files) > 0: if len(unresolved_files) > 0:
logger.info( logger.info(

View File

@ -498,10 +498,17 @@ def get_from_cache(
# between the HEAD and the GET (unlikely, but hey). # between the HEAD and the GET (unlikely, but hey).
if 300 <= r.status_code <= 399: if 300 <= r.status_code <= 399:
url_to_download = r.headers["Location"] url_to_download = r.headers["Location"]
except (requests.exceptions.SSLError, requests.exceptions.ProxyError): except (
requests.exceptions.SSLError,
requests.exceptions.ProxyError,
RepositoryNotFoundError,
EntryNotFoundError,
RevisionNotFoundError,
):
# Actually raise for those subclasses of ConnectionError # Actually raise for those subclasses of ConnectionError
# Also raise the custom errors coming from a non existing repo/branch/file as they are caught later on.
raise raise
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout): except (HTTPError, requests.exceptions.ConnectionError, requests.exceptions.Timeout):
# Otherwise, our Internet connection is down. # Otherwise, our Internet connection is down.
# etag is None # etag is None
pass pass

View File

@ -20,7 +20,7 @@ import shutil
import sys import sys
import tempfile import tempfile
import unittest import unittest
import unittest.mock import unittest.mock as mock
from pathlib import Path from pathlib import Path
from huggingface_hub import Repository, delete_repo, login from huggingface_hub import Repository, delete_repo, login
@ -304,6 +304,22 @@ class ConfigTestUtils(unittest.TestCase):
f"pick another value for them: {', '.join(keys_with_defaults)}." f"pick another value for them: {', '.join(keys_with_defaults)}."
) )
def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = []
response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache.
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
# This checks that we did call the fake head request
mock_head.assert_called()
class ConfigurationVersioningTest(unittest.TestCase): class ConfigurationVersioningTest(unittest.TestCase):
def test_local_versioning(self): def test_local_versioning(self):

View File

@ -19,6 +19,7 @@ import os
import sys import sys
import tempfile import tempfile
import unittest import unittest
import unittest.mock as mock
from pathlib import Path from pathlib import Path
from huggingface_hub import Repository, delete_repo, login from huggingface_hub import Repository, delete_repo, login
@ -116,6 +117,23 @@ class FeatureExtractionSavingTestMixin:
self.assertIsNotNone(feat_extract) self.assertIsNotNone(feat_extract)
class FeatureExtractorUtilTester(unittest.TestCase):
def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = []
response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache.
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
# This checks that we did call the fake head request
mock_head.assert_called()
@is_staging_test @is_staging_test
class FeatureExtractorPushToHubTester(unittest.TestCase): class FeatureExtractorPushToHubTester(unittest.TestCase):
@classmethod @classmethod

View File

@ -23,6 +23,7 @@ import random
import sys import sys
import tempfile import tempfile
import unittest import unittest
import unittest.mock as mock
import warnings import warnings
from pathlib import Path from pathlib import Path
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
@ -2272,6 +2273,22 @@ class ModelUtilsTest(TestCasePlus):
for p1, p2 in zip(model.parameters(), new_model.parameters()): for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2)) self.assertTrue(torch.equal(p1, p2))
def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = []
response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache.
_ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
_ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# This checks that we did call the fake head request
mock_head.assert_called()
@require_torch @require_torch
@is_staging_test @is_staging_test

View File

@ -21,6 +21,7 @@ import os
import random import random
import tempfile import tempfile
import unittest import unittest
import unittest.mock as mock
from importlib import import_module from importlib import import_module
from typing import List, Tuple from typing import List, Tuple
@ -1555,6 +1556,22 @@ class UtilsFunctionsTest(unittest.TestCase):
tf.debugging.assert_near(non_inf_output, non_inf_expected_output, rtol=1e-12) tf.debugging.assert_near(non_inf_output, non_inf_expected_output, rtol=1e-12)
tf.debugging.assert_equal(non_inf_idx, non_inf_expected_idx) tf.debugging.assert_equal(non_inf_idx, non_inf_expected_idx)
def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = []
response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache.
_ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
_ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# This checks that we did call the fake head request
mock_head.assert_called()
# tests whether the unpack_inputs function behaves as expected # tests whether the unpack_inputs function behaves as expected
def test_unpack_inputs(self): def test_unpack_inputs(self):
class DummyModel: class DummyModel:

View File

@ -24,6 +24,7 @@ import shutil
import sys import sys
import tempfile import tempfile
import unittest import unittest
import unittest.mock as mock
from collections import OrderedDict from collections import OrderedDict
from itertools import takewhile from itertools import takewhile
from pathlib import Path from pathlib import Path
@ -3742,6 +3743,24 @@ class TokenizerTesterMixin:
self.rust_tokenizer_class.from_pretrained(tmp_dir_2) self.rust_tokenizer_class.from_pretrained(tmp_dir_2)
class TokenizerUtilTester(unittest.TestCase):
def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = []
response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache.
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
# This checks that we did call the fake head request
mock_head.assert_called()
@is_staging_test @is_staging_test
class TokenizerPushToHubTester(unittest.TestCase): class TokenizerPushToHubTester(unittest.TestCase):
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"] vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"]

View File

@ -59,10 +59,10 @@ socket.socket = offline_socket
# next emulate no network # next emulate no network
cmd = [sys.executable, "-c", "\n".join([load, mock, run])] cmd = [sys.executable, "-c", "\n".join([load, mock, run])]
# should normally fail as it will fail to lookup the model files w/o the network # Doesn't fail anymore since the model is in the cache due to other tests, so commenting this.
env["TRANSFORMERS_OFFLINE"] = "0" # env["TRANSFORMERS_OFFLINE"] = "0"
result = subprocess.run(cmd, env=env, check=False, capture_output=True) # result = subprocess.run(cmd, env=env, check=False, capture_output=True)
self.assertEqual(result.returncode, 1, result.stderr) # self.assertEqual(result.returncode, 1, result.stderr)
# should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files # should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files
env["TRANSFORMERS_OFFLINE"] = "1" env["TRANSFORMERS_OFFLINE"] = "1"