mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-30 17:52:35 +06:00
Add filename to info diaplyed when downloading things in from_pretrained (#18099)
This commit is contained in:
parent
6c8017a5c8
commit
b1b8222d80
@ -422,7 +422,14 @@ def _raise_for_status(response: Response):
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None):
|
||||
def http_get(
|
||||
url: str,
|
||||
temp_file: BinaryIO,
|
||||
proxies=None,
|
||||
resume_size=0,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
file_name: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Download remote file. Do not gobble up errors.
|
||||
"""
|
||||
@ -441,7 +448,7 @@ def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers
|
||||
unit_divisor=1024,
|
||||
total=total,
|
||||
initial=resume_size,
|
||||
desc="Downloading",
|
||||
desc=f"Downloading {file_name}" if file_name is not None else "Downloading",
|
||||
)
|
||||
for chunk in r.iter_content(chunk_size=1024):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
@ -591,7 +598,16 @@ def get_from_cache(
|
||||
with temp_file_manager() as temp_file:
|
||||
logger.info(f"{url} not found in cache or force_download set to True, downloading to {temp_file.name}")
|
||||
|
||||
http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, headers=headers)
|
||||
# The url_to_download might be messy, so we extract the file name from the original url.
|
||||
file_name = url.split("/")[-1]
|
||||
http_get(
|
||||
url_to_download,
|
||||
temp_file,
|
||||
proxies=proxies,
|
||||
resume_size=resume_size,
|
||||
headers=headers,
|
||||
file_name=file_name,
|
||||
)
|
||||
|
||||
logger.info(f"storing {url} in cache at {cache_path}")
|
||||
os.replace(temp_file.name, cache_path)
|
||||
|
Loading…
Reference in New Issue
Block a user