diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index 4e46298e28a..6de0b3a2461 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -422,7 +422,14 @@ def _raise_for_status(response: Response): response.raise_for_status() -def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None): +def http_get( + url: str, + temp_file: BinaryIO, + proxies=None, + resume_size=0, + headers: Optional[Dict[str, str]] = None, + file_name: Optional[str] = None, +): """ Download remote file. Do not gobble up errors. """ @@ -441,7 +448,7 @@ def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers unit_divisor=1024, total=total, initial=resume_size, - desc="Downloading", + desc=f"Downloading {file_name}" if file_name is not None else "Downloading", ) for chunk in r.iter_content(chunk_size=1024): if chunk: # filter out keep-alive new chunks @@ -591,7 +598,16 @@ def get_from_cache( with temp_file_manager() as temp_file: logger.info(f"{url} not found in cache or force_download set to True, downloading to {temp_file.name}") - http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, headers=headers) + # The url_to_download might be messy, so we extract the file name from the original url. + file_name = url.split("/")[-1] + http_get( + url_to_download, + temp_file, + proxies=proxies, + resume_size=resume_size, + headers=headers, + file_name=file_name, + ) logger.info(f"storing {url} in cache at {cache_path}") os.replace(temp_file.name, cache_path)