mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
relax network connection requirements
This commit is contained in:
parent
fa76520240
commit
265550ec34
@ -5,11 +5,13 @@ Copyright by the AllenNLP authors.
|
||||
"""
|
||||
from __future__ import (absolute_import, division, print_function, unicode_literals)
|
||||
|
||||
import sys
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import fnmatch
|
||||
from functools import wraps
|
||||
from hashlib import sha256
|
||||
import sys
|
||||
@ -191,17 +193,30 @@ def get_from_cache(url, cache_dir=None):
|
||||
if url.startswith("s3://"):
|
||||
etag = s3_etag(url)
|
||||
else:
|
||||
response = requests.head(url, allow_redirects=True)
|
||||
if response.status_code != 200:
|
||||
raise IOError("HEAD request failed for url {} with status code {}"
|
||||
.format(url, response.status_code))
|
||||
etag = response.headers.get("ETag")
|
||||
try:
|
||||
response = requests.head(url, allow_redirects=True)
|
||||
if response.status_code != 200:
|
||||
etag = None
|
||||
else:
|
||||
etag = response.headers.get("ETag")
|
||||
except EnvironmentError:
|
||||
etag = None
|
||||
|
||||
if sys.version_info[0] == 2 and etag is not None:
|
||||
etag = etag.decode('utf-8')
|
||||
filename = url_to_filename(url, etag)
|
||||
|
||||
# get cache path to put the file
|
||||
cache_path = os.path.join(cache_dir, filename)
|
||||
|
||||
# If we don't have a connection (etag is None) and can't identify the file
|
||||
# try to get the last downloaded one
|
||||
if not os.path.exists(cache_path) and etag is None:
|
||||
matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*')
|
||||
matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files))
|
||||
if matching_files:
|
||||
cache_path = os.path.join(cache_dir, matching_files[-1])
|
||||
|
||||
if not os.path.exists(cache_path):
|
||||
# Download to temporary file, then copy to cache dir once finished.
|
||||
# Otherwise you get corrupt cache entries if the download gets interrupted.
|
||||
@ -226,8 +241,8 @@ def get_from_cache(url, cache_dir=None):
|
||||
logger.info("creating metadata file for %s", cache_path)
|
||||
meta = {'url': url, 'etag': etag}
|
||||
meta_path = cache_path + '.json'
|
||||
with open(meta_path, 'w', encoding="utf-8") as meta_file:
|
||||
json.dump(meta, meta_file)
|
||||
with open(meta_path, 'w') as meta_file:
|
||||
meta_file.write(json.dumps(meta, indent=4))
|
||||
|
||||
logger.info("removing temp file %s", temp_file.name)
|
||||
|
||||
|
@ -66,7 +66,7 @@ class GPT2TokenizationTest(unittest.TestCase):
|
||||
[tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks,
|
||||
tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder])
|
||||
|
||||
@pytest.mark.slow
|
||||
# @pytest.mark.slow
|
||||
def test_tokenizer_from_pretrained(self):
|
||||
cache_dir = "/tmp/pytorch_pretrained_bert_test/"
|
||||
for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
|
||||
|
Loading…
Reference in New Issue
Block a user