add pathlib support for file_utils.py on python 3.5

This commit is contained in:
hzhwcmhf 2018-12-11 22:49:19 +08:00
parent bc659f86ad
commit 485adde742

View File

@ -23,8 +23,8 @@ import requests
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
PYTORCH_PRETRAINED_BERT_CACHE = str(Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
Path.home() / '.pytorch_pretrained_bert')))
PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
Path.home() / '.pytorch_pretrained_bert'))
def url_to_filename(url: str, etag: str = None) -> str:
@ -45,13 +45,15 @@ def url_to_filename(url: str, etag: str = None) -> str:
return filename
def filename_to_url(filename: str, cache_dir: str = None) -> Tuple[str, str]:
def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]:
"""
Return the url and etag (which may be ``None``) stored for `filename`.
Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist.
"""
if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
cache_path = os.path.join(cache_dir, filename)
if not os.path.exists(cache_path):
@ -69,7 +71,7 @@ def filename_to_url(filename: str, cache_dir: str = None) -> Tuple[str, str]:
return url, etag
def cached_path(url_or_filename: Union[str, Path], cache_dir: str = None) -> str:
def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str:
"""
Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and
@ -80,6 +82,8 @@ def cached_path(url_or_filename: Union[str, Path], cache_dir: str = None) -> str
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if isinstance(url_or_filename, Path):
url_or_filename = str(url_or_filename)
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
parsed = urlparse(url_or_filename)
@ -158,13 +162,15 @@ def http_get(url: str, temp_file: IO) -> None:
progress.close()
def get_from_cache(url: str, cache_dir: str = None) -> str:
def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
"""
Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file.
"""
if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
os.makedirs(cache_dir, exist_ok=True)