Add filename to info displayed when downloading things in from_pretrained (#18099)

This commit is contained in:
Sylvain Gugger 2022-07-11 12:45:06 -04:00 committed by GitHub
parent 6c8017a5c8
commit b1b8222d80
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -422,7 +422,14 @@ def _raise_for_status(response: Response):
response.raise_for_status()
def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None):
def http_get(
url: str,
temp_file: BinaryIO,
proxies=None,
resume_size=0,
headers: Optional[Dict[str, str]] = None,
file_name: Optional[str] = None,
):
"""
Download remote file. Do not gobble up errors.
"""
@ -441,7 +448,7 @@ def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers
unit_divisor=1024,
total=total,
initial=resume_size,
desc="Downloading",
desc=f"Downloading {file_name}" if file_name is not None else "Downloading",
)
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
@ -591,7 +598,16 @@ def get_from_cache(
with temp_file_manager() as temp_file:
logger.info(f"{url} not found in cache or force_download set to True, downloading to {temp_file.name}")
http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, headers=headers)
# The url_to_download might be messy, so we extract the file name from the original url.
file_name = url.split("/")[-1]
http_get(
url_to_download,
temp_file,
proxies=proxies,
resume_size=resume_size,
headers=headers,
file_name=file_name,
)
logger.info(f"storing {url} in cache at {cache_path}")
os.replace(temp_file.name, cache_path)