mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Add local_files_only parameter to pretrained items (#2930)
* Add disable_outgoing to pretrained items Setting disable_outgoing=True disables outgoing traffic: - etags are not looked up - models are not downloaded * parameter name change * Remove forgotten print
This commit is contained in:
parent
286d1ec746
commit
a143d9479e
@ -198,6 +198,7 @@ class PretrainedConfig(object):
|
|||||||
force_download = kwargs.pop("force_download", False)
|
force_download = kwargs.pop("force_download", False)
|
||||||
resume_download = kwargs.pop("resume_download", False)
|
resume_download = kwargs.pop("resume_download", False)
|
||||||
proxies = kwargs.pop("proxies", None)
|
proxies = kwargs.pop("proxies", None)
|
||||||
|
local_files_only = kwargs.pop("local_files_only", False)
|
||||||
|
|
||||||
if pretrained_config_archive_map is None:
|
if pretrained_config_archive_map is None:
|
||||||
pretrained_config_archive_map = cls.pretrained_config_archive_map
|
pretrained_config_archive_map = cls.pretrained_config_archive_map
|
||||||
@ -219,6 +220,7 @@ class PretrainedConfig(object):
|
|||||||
force_download=force_download,
|
force_download=force_download,
|
||||||
proxies=proxies,
|
proxies=proxies,
|
||||||
resume_download=resume_download,
|
resume_download=resume_download,
|
||||||
|
local_files_only=local_files_only,
|
||||||
)
|
)
|
||||||
# Load config dict
|
# Load config dict
|
||||||
if resolved_config_file is None:
|
if resolved_config_file is None:
|
||||||
|
@ -214,6 +214,7 @@ def cached_path(
|
|||||||
user_agent=None,
|
user_agent=None,
|
||||||
extract_compressed_file=False,
|
extract_compressed_file=False,
|
||||||
force_extract=False,
|
force_extract=False,
|
||||||
|
local_files_only=False,
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Given something that might be a URL (or might be a local path),
|
Given something that might be a URL (or might be a local path),
|
||||||
@ -250,6 +251,7 @@ def cached_path(
|
|||||||
proxies=proxies,
|
proxies=proxies,
|
||||||
resume_download=resume_download,
|
resume_download=resume_download,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
|
local_files_only=local_files_only,
|
||||||
)
|
)
|
||||||
elif os.path.exists(url_or_filename):
|
elif os.path.exists(url_or_filename):
|
||||||
# File, and it exists.
|
# File, and it exists.
|
||||||
@ -378,7 +380,14 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
|
|||||||
|
|
||||||
|
|
||||||
def get_from_cache(
|
def get_from_cache(
|
||||||
url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None
|
url,
|
||||||
|
cache_dir=None,
|
||||||
|
force_download=False,
|
||||||
|
proxies=None,
|
||||||
|
etag_timeout=10,
|
||||||
|
resume_download=False,
|
||||||
|
user_agent=None,
|
||||||
|
local_files_only=False,
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Given a URL, look for the corresponding file in the local cache.
|
Given a URL, look for the corresponding file in the local cache.
|
||||||
@ -395,18 +404,19 @@ def get_from_cache(
|
|||||||
|
|
||||||
os.makedirs(cache_dir, exist_ok=True)
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
|
|
||||||
# Get eTag to add to filename, if it exists.
|
etag = None
|
||||||
if url.startswith("s3://"):
|
if not local_files_only:
|
||||||
etag = s3_etag(url, proxies=proxies)
|
# Get eTag to add to filename, if it exists.
|
||||||
else:
|
if url.startswith("s3://"):
|
||||||
try:
|
etag = s3_etag(url, proxies=proxies)
|
||||||
response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
|
else:
|
||||||
if response.status_code != 200:
|
try:
|
||||||
etag = None
|
response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
|
||||||
else:
|
if response.status_code == 200:
|
||||||
etag = response.headers.get("ETag")
|
etag = response.headers.get("ETag")
|
||||||
except (EnvironmentError, requests.exceptions.Timeout):
|
except (EnvironmentError, requests.exceptions.Timeout):
|
||||||
etag = None
|
# etag is already None
|
||||||
|
pass
|
||||||
|
|
||||||
filename = url_to_filename(url, etag)
|
filename = url_to_filename(url, etag)
|
||||||
|
|
||||||
@ -427,6 +437,15 @@ def get_from_cache(
|
|||||||
if len(matching_files) > 0:
|
if len(matching_files) > 0:
|
||||||
return os.path.join(cache_dir, matching_files[-1])
|
return os.path.join(cache_dir, matching_files[-1])
|
||||||
else:
|
else:
|
||||||
|
# If files cannot be found and local_files_only=True,
|
||||||
|
# the models might've been found if local_files_only=False
|
||||||
|
# Notify the user about that
|
||||||
|
if local_files_only:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot find the requested files in the cached path and outgoing traffic has been"
|
||||||
|
" disabled. To enable model look-ups and downloads online, set 'local_files_only'"
|
||||||
|
" to False."
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# From now on, etag is not None.
|
# From now on, etag is not None.
|
||||||
|
@ -376,6 +376,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
resume_download = kwargs.pop("resume_download", False)
|
resume_download = kwargs.pop("resume_download", False)
|
||||||
proxies = kwargs.pop("proxies", None)
|
proxies = kwargs.pop("proxies", None)
|
||||||
output_loading_info = kwargs.pop("output_loading_info", False)
|
output_loading_info = kwargs.pop("output_loading_info", False)
|
||||||
|
local_files_only = kwargs.pop("local_files_only", False)
|
||||||
|
|
||||||
# Load config if we don't provide a configuration
|
# Load config if we don't provide a configuration
|
||||||
if not isinstance(config, PretrainedConfig):
|
if not isinstance(config, PretrainedConfig):
|
||||||
@ -388,6 +389,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
force_download=force_download,
|
force_download=force_download,
|
||||||
resume_download=resume_download,
|
resume_download=resume_download,
|
||||||
proxies=proxies,
|
proxies=proxies,
|
||||||
|
local_files_only=local_files_only,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -435,6 +437,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
|||||||
force_download=force_download,
|
force_download=force_download,
|
||||||
proxies=proxies,
|
proxies=proxies,
|
||||||
resume_download=resume_download,
|
resume_download=resume_download,
|
||||||
|
local_files_only=local_files_only,
|
||||||
)
|
)
|
||||||
except EnvironmentError:
|
except EnvironmentError:
|
||||||
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
|
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
|
||||||
|
@ -395,6 +395,7 @@ class PreTrainedTokenizer(object):
|
|||||||
force_download = kwargs.pop("force_download", False)
|
force_download = kwargs.pop("force_download", False)
|
||||||
resume_download = kwargs.pop("resume_download", False)
|
resume_download = kwargs.pop("resume_download", False)
|
||||||
proxies = kwargs.pop("proxies", None)
|
proxies = kwargs.pop("proxies", None)
|
||||||
|
local_files_only = kwargs.pop("local_files_only", False)
|
||||||
|
|
||||||
s3_models = list(cls.max_model_input_sizes.keys())
|
s3_models = list(cls.max_model_input_sizes.keys())
|
||||||
vocab_files = {}
|
vocab_files = {}
|
||||||
@ -462,6 +463,7 @@ class PreTrainedTokenizer(object):
|
|||||||
force_download=force_download,
|
force_download=force_download,
|
||||||
proxies=proxies,
|
proxies=proxies,
|
||||||
resume_download=resume_download,
|
resume_download=resume_download,
|
||||||
|
local_files_only=local_files_only,
|
||||||
)
|
)
|
||||||
except EnvironmentError:
|
except EnvironmentError:
|
||||||
if pretrained_model_name_or_path in s3_models:
|
if pretrained_model_name_or_path in s3_models:
|
||||||
|
Loading…
Reference in New Issue
Block a user