file_cache has options to extract archives

thomwolf 2020-02-07 00:03:12 +01:00
parent 2c12464a20
commit 7d99e05f76


@@ -8,13 +8,16 @@ import fnmatch
 import json
 import logging
 import os
+import shutil
 import sys
+import tarfile
 import tempfile
 from contextlib import contextmanager
 from functools import partial, wraps
 from hashlib import sha256
 from typing import Optional
 from urllib.parse import urlparse
+from zipfile import ZipFile, is_zipfile

 import boto3
 import requests
@@ -203,7 +206,14 @@ def filename_to_url(filename, cache_dir=None):
 def cached_path(
-    url_or_filename, cache_dir=None, force_download=False, proxies=None, resume_download=False, user_agent=None
+    url_or_filename,
+    cache_dir=None,
+    force_download=False,
+    proxies=None,
+    resume_download=False,
+    user_agent=None,
+    extract_compressed_file=False,
+    force_extract=False,
 ) -> Optional[str]:
     """
     Given something that might be a URL (or might be a local path),
@@ -215,6 +225,10 @@ def cached_path(
         force_download: if True, re-download the file even if it's already cached in the cache dir.
         resume_download: if True, resume the download if an incompletely received file is found.
         user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
+        extract_compressed_file: if True and the path points to a zip or tar file, extract the compressed
+            file in a folder alongside the archive.
+        force_extract: if True, when extract_compressed_file is True and the archive was already extracted,
+            re-extract the archive and override the folder where it was extracted.

     Return:
         None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
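
For illustration, a minimal usage sketch of the two new flags, assuming this is the cached_path helper in transformers' file_utils module (the archive URL is hypothetical):

    from transformers.file_utils import cached_path

    # Download (or reuse the cached copy of) an archive, then extract it into a
    # sibling "-extracted" folder; the extraction folder path is returned.
    extracted = cached_path(
        "https://example.com/model.tar.gz",  # hypothetical archive URL
        extract_compressed_file=True,
    )

    # A non-empty extraction folder is reused as-is, unless force_extract=True,
    # which wipes the folder and re-extracts the archive.
    extracted = cached_path(
        "https://example.com/model.tar.gz",
        extract_compressed_file=True,
        force_extract=True,
    )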
@@ -229,7 +243,7 @@ def cached_path(
     if is_remote_url(url_or_filename):
         # URL, so get it from the cache (downloading if necessary)
-        return get_from_cache(
+        output_path = get_from_cache(
             url_or_filename,
             cache_dir=cache_dir,
             force_download=force_download,
@@ -239,7 +253,7 @@ def cached_path(
         )
     elif os.path.exists(url_or_filename):
         # File, and it exists.
-        return url_or_filename
+        output_path = url_or_filename
     elif urlparse(url_or_filename).scheme == "":
         # File, but it doesn't exist.
         raise EnvironmentError("file {} not found".format(url_or_filename))
@@ -247,6 +261,37 @@ def cached_path(
         # Something unknown
         raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))

+    if extract_compressed_file:
+        if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
+            return output_path
+
+        # Path where we extract compressed archives
+        # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
+        output_dir, output_file = os.path.split(output_path)
+        output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
+        output_path_extracted = os.path.join(output_dir, output_extract_dir_name)
+
+        if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
+            return output_path_extracted
+
+        # Prevent parallel extractions
+        lock_path = output_path + ".lock"
+        with FileLock(lock_path):
+            shutil.rmtree(output_path_extracted, ignore_errors=True)
+            os.makedirs(output_path_extracted)
+            if is_zipfile(output_path):
+                with ZipFile(output_path, "r") as zip_file:
+                    zip_file.extractall(output_path_extracted)
+                    zip_file.close()
+            elif tarfile.is_tarfile(output_path):
+                tar_file = tarfile.open(output_path)
+                tar_file.extractall(output_path_extracted)
+                tar_file.close()
+
+        return output_path_extracted
+
+    return output_path


 def split_s3_path(url):
     """Split a full s3 path into the bucket name and path."""