Make Transformers use cache files when hf.co is down (#16362)

* Make Transformers use cache files when hf.co is down

* Fix tests

* Was there a random circleCI failure?

* Isolate patches

* Style

* Comment out the failure since it doesn't fail anymore

* Better comment
This commit is contained in:
Sylvain Gugger 2022-03-23 15:56:49 -04:00 committed by GitHub
parent 8a69e023bf
commit c595b6e6a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 148 additions and 35 deletions

View File

@ -620,12 +620,16 @@ class PretrainedConfig(PushToHubMixin):
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {configuration_file}."
)
except HTTPError:
except HTTPError as err:
raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model and it looks like "
f"{pretrained_model_name_or_path} is not the path to a directory containing a {configuration_file} "
"file.\nCheckout your internet connection or see how to run the library in offline mode at "
"'https://huggingface.co/docs/transformers/installation#offline-mode'."
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model, couldn't find it in the cached "
f"files and it looks like {pretrained_model_name_or_path} is not the path to a directory containing a "
"{configuration_file} file.\nCheckout your internet connection or see how to run the library in "
"offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
except EnvironmentError:
raise EnvironmentError(

View File

@ -427,10 +427,14 @@ class FeatureExtractionMixin(PushToHubMixin):
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {FEATURE_EXTRACTOR_NAME}."
)
except HTTPError:
except HTTPError as err:
raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model and it looks like "
f"{pretrained_model_name_or_path} is not the path to a directory conaining a "
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model, couldn't find it in the cached "
f"files and it looks like {pretrained_model_name_or_path} is not the path to a directory containing a "
f"{FEATURE_EXTRACTOR_NAME} file.\nCheckout your internet connection or see how to run the library in "
"offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)

View File

@ -523,11 +523,16 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
)
except HTTPError:
except HTTPError as err:
raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model and it looks like "
f"{pretrained_model_name_or_path} is not the path to a directory conaining a a file named "
f"{FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\n"
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
f"{err}"
)
except ValueError:
raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model, couldn't find it in the cached "
f"files and it looks like {pretrained_model_name_or_path} is not the path to a directory "
f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\n"
"Checkout your internet connection or see how to run the library in offline mode at "
"'https://huggingface.co/docs/transformers/installation#offline-mode'."
)

View File

@ -1678,11 +1678,16 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
)
except HTTPError:
except HTTPError as err:
raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model and it looks like "
f"{pretrained_model_name_or_path} is not the path to a directory conaining a a file named "
f"{TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}.\n"
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
f"{err}"
)
except ValueError:
raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model, couldn't find it in the cached "
f"files and it looks like {pretrained_model_name_or_path} is not the path to a directory "
f"containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}.\n"
"Checkout your internet connection or see how to run the library in offline mode at "
"'https://huggingface.co/docs/transformers/installation#offline-mode'."
)

View File

@ -1409,11 +1409,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
)
except HTTPError:
except HTTPError as err:
raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model and it looks like "
f"{pretrained_model_name_or_path} is not the path to a directory conaining a a file named "
f"{WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}.\n"
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
f"{err}"
)
except ValueError:
raise EnvironmentError(
"We couldn't connect to 'https://huggingface.co/' to load this model, couldn't find it in the cached "
f"files and it looks like {pretrained_model_name_or_path} is not the path to a directory "
f"containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or "
f"{FLAX_WEIGHTS_NAME}.\n"
"Checkout your internet connection or see how to run the library in offline mode at "
"'https://huggingface.co/docs/transformers/installation#offline-mode'."
)

View File

@ -31,8 +31,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequenc
import numpy as np
from packaging import version
from requests import HTTPError
from . import __version__
from .dynamic_module_utils import custom_object_save
from .utils import (
@ -1751,12 +1749,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
logger.debug(f"{pretrained_model_name_or_path} does not contain a file named {file_path}.")
resolved_vocab_files[file_id] = None
except HTTPError as err:
if "404 Client Error" in str(err):
logger.debug(f"Connection problem to access {file_path}.")
resolved_vocab_files[file_id] = None
else:
raise err
except ValueError:
logger.debug(f"Connection problem to access {file_path} and it wasn't found in the cache.")
resolved_vocab_files[file_id] = None
if len(unresolved_files) > 0:
logger.info(

View File

@ -498,10 +498,17 @@ def get_from_cache(
# between the HEAD and the GET (unlikely, but hey).
if 300 <= r.status_code <= 399:
url_to_download = r.headers["Location"]
except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
except (
requests.exceptions.SSLError,
requests.exceptions.ProxyError,
RepositoryNotFoundError,
EntryNotFoundError,
RevisionNotFoundError,
):
# Actually raise for those subclasses of ConnectionError
# Also raise the custom errors coming from a non existing repo/branch/file as they are caught later on.
raise
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
except (HTTPError, requests.exceptions.ConnectionError, requests.exceptions.Timeout):
# Otherwise, our Internet connection is down.
# etag is None
pass

View File

@ -20,7 +20,7 @@ import shutil
import sys
import tempfile
import unittest
import unittest.mock
import unittest.mock as mock
from pathlib import Path
from huggingface_hub import Repository, delete_repo, login
@ -304,6 +304,22 @@ class ConfigTestUtils(unittest.TestCase):
f"pick another value for them: {', '.join(keys_with_defaults)}."
)
def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = []
response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache.
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
# This check we did call the fake head request
mock_head.assert_called()
class ConfigurationVersioningTest(unittest.TestCase):
def test_local_versioning(self):

View File

@ -19,6 +19,7 @@ import os
import sys
import tempfile
import unittest
import unittest.mock as mock
from pathlib import Path
from huggingface_hub import Repository, delete_repo, login
@ -116,6 +117,23 @@ class FeatureExtractionSavingTestMixin:
self.assertIsNotNone(feat_extract)
class FeatureExtractorUtilTester(unittest.TestCase):
def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = []
response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache.
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
# This check we did call the fake head request
mock_head.assert_called()
@is_staging_test
class FeatureExtractorPushToHubTester(unittest.TestCase):
@classmethod

View File

@ -23,6 +23,7 @@ import random
import sys
import tempfile
import unittest
import unittest.mock as mock
import warnings
from pathlib import Path
from typing import Dict, List, Tuple
@ -2272,6 +2273,22 @@ class ModelUtilsTest(TestCasePlus):
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = []
response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache.
_ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
_ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# This check we did call the fake head request
mock_head.assert_called()
@require_torch
@is_staging_test

View File

@ -21,6 +21,7 @@ import os
import random
import tempfile
import unittest
import unittest.mock as mock
from importlib import import_module
from typing import List, Tuple
@ -1555,6 +1556,22 @@ class UtilsFunctionsTest(unittest.TestCase):
tf.debugging.assert_near(non_inf_output, non_inf_expected_output, rtol=1e-12)
tf.debugging.assert_equal(non_inf_idx, non_inf_expected_idx)
def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = []
response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache.
_ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
_ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# This check we did call the fake head request
mock_head.assert_called()
# tests whether the unpack_inputs function behaves as expected
def test_unpack_inputs(self):
class DummyModel:

View File

@ -24,6 +24,7 @@ import shutil
import sys
import tempfile
import unittest
import unittest.mock as mock
from collections import OrderedDict
from itertools import takewhile
from pathlib import Path
@ -3742,6 +3743,24 @@ class TokenizerTesterMixin:
self.rust_tokenizer_class.from_pretrained(tmp_dir_2)
class TokenizerUtilTester(unittest.TestCase):
def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = []
response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache.
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
# This check we did call the fake head request
mock_head.assert_called()
@is_staging_test
class TokenizerPushToHubTester(unittest.TestCase):
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"]

View File

@ -59,10 +59,10 @@ socket.socket = offline_socket
# next emulate no network
cmd = [sys.executable, "-c", "\n".join([load, mock, run])]
# should normally fail as it will fail to lookup the model files w/o the network
env["TRANSFORMERS_OFFLINE"] = "0"
result = subprocess.run(cmd, env=env, check=False, capture_output=True)
self.assertEqual(result.returncode, 1, result.stderr)
# Doesn't fail anymore since the model is in the cache due to other tests, so commenting this.
# env["TRANSFORMERS_OFFLINE"] = "0"
# result = subprocess.run(cmd, env=env, check=False, capture_output=True)
# self.assertEqual(result.returncode, 1, result.stderr)
# should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files
env["TRANSFORMERS_OFFLINE"] = "1"