relax network connection requirements

This commit is contained in:
thomwolf 2019-04-17 14:22:35 +02:00
parent fa76520240
commit 265550ec34
2 changed files with 23 additions and 8 deletions

View File

@ -5,11 +5,13 @@ Copyright by the AllenNLP authors.
"""
from __future__ import (absolute_import, division, print_function, unicode_literals)
import sys
import json
import logging
import os
import shutil
import tempfile
import fnmatch
from functools import wraps
from hashlib import sha256
import sys
@ -191,17 +193,30 @@ def get_from_cache(url, cache_dir=None):
if url.startswith("s3://"):
etag = s3_etag(url)
else:
response = requests.head(url, allow_redirects=True)
if response.status_code != 200:
raise IOError("HEAD request failed for url {} with status code {}"
.format(url, response.status_code))
etag = response.headers.get("ETag")
try:
response = requests.head(url, allow_redirects=True)
if response.status_code != 200:
etag = None
else:
etag = response.headers.get("ETag")
except EnvironmentError:
etag = None
if sys.version_info[0] == 2 and etag is not None:
etag = etag.decode('utf-8')
filename = url_to_filename(url, etag)
# get cache path to put the file
cache_path = os.path.join(cache_dir, filename)
# If we don't have a connection (etag is None) and can't identify the file
# try to get the last downloaded one
if not os.path.exists(cache_path) and etag is None:
matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*')
matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files))
if matching_files:
cache_path = os.path.join(cache_dir, matching_files[-1])
if not os.path.exists(cache_path):
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
@ -226,8 +241,8 @@ def get_from_cache(url, cache_dir=None):
logger.info("creating metadata file for %s", cache_path)
meta = {'url': url, 'etag': etag}
meta_path = cache_path + '.json'
with open(meta_path, 'w', encoding="utf-8") as meta_file:
json.dump(meta, meta_file)
with open(meta_path, 'w') as meta_file:
meta_file.write(json.dumps(meta, indent=4))
logger.info("removing temp file %s", temp_file.name)

View File

@ -66,7 +66,7 @@ class GPT2TokenizationTest(unittest.TestCase):
[tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks,
tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder])
@pytest.mark.slow
# @pytest.mark.slow
def test_tokenizer_from_pretrained(self):
cache_dir = "/tmp/pytorch_pretrained_bert_test/"
for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]: