Add local_files_only parameter to pretrained items (#2930)

* Add disable_outgoing to pretrained items

Setting disable_outgoing=True disables outgoing traffic:
- etags are not looked up
- models are not downloaded

* Rename the parameter from disable_outgoing to local_files_only

* Remove forgotten print
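
For context, a minimal usage sketch of the new flag under its final name, local_files_only (the checkpoint name bert-base-uncased is only illustrative; any previously cached model behaves the same way):

from transformers import BertModel, BertTokenizer

# With local_files_only=True, from_pretrained resolves all files from the
# local cache: no etag lookups are made and nothing is downloaded.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", local_files_only=True)
model = BertModel.from_pretrained("bert-base-uncased", local_files_only=True)

If the requested files are not already cached, the lookup fails (see the ValueError added to get_from_cache below) rather than silently going online.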
Bram Vanroy 2020-02-24 20:58:15 +01:00 committed by GitHub
parent 286d1ec746
commit a143d9479e
4 changed files with 39 additions and 13 deletions


@@ -198,6 +198,7 @@ class PretrainedConfig(object):
         force_download = kwargs.pop("force_download", False)
         resume_download = kwargs.pop("resume_download", False)
         proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", False)
         if pretrained_config_archive_map is None:
             pretrained_config_archive_map = cls.pretrained_config_archive_map
@@ -219,6 +220,7 @@ class PretrainedConfig(object):
                 force_download=force_download,
                 proxies=proxies,
                 resume_download=resume_download,
+                local_files_only=local_files_only,
             )
             # Load config dict
             if resolved_config_file is None:
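
A hedged sketch of how the change surfaces at the config level; BertConfig is just one concrete PretrainedConfig subclass, picked for illustration:

from transformers import BertConfig

# Resolves config.json from the local cache only; no etag lookup, no download.
config = BertConfig.from_pretrained("bert-base-uncased", local_files_only=True)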


@@ -214,6 +214,7 @@ def cached_path(
     user_agent=None,
     extract_compressed_file=False,
     force_extract=False,
+    local_files_only=False,
 ) -> Optional[str]:
     """
     Given something that might be a URL (or might be a local path),
@@ -250,6 +251,7 @@ def cached_path(
             proxies=proxies,
             resume_download=resume_download,
             user_agent=user_agent,
+            local_files_only=local_files_only,
         )
     elif os.path.exists(url_or_filename):
         # File, and it exists.
@@ -378,7 +380,14 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
 def get_from_cache(
-    url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None
+    url,
+    cache_dir=None,
+    force_download=False,
+    proxies=None,
+    etag_timeout=10,
+    resume_download=False,
+    user_agent=None,
+    local_files_only=False,
 ) -> Optional[str]:
     """
     Given a URL, look for the corresponding file in the local cache.
@@ -395,18 +404,19 @@ def get_from_cache(
     os.makedirs(cache_dir, exist_ok=True)
-    # Get eTag to add to filename, if it exists.
-    if url.startswith("s3://"):
-        etag = s3_etag(url, proxies=proxies)
-    else:
-        try:
-            response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
-            if response.status_code != 200:
-                etag = None
-            else:
-                etag = response.headers.get("ETag")
-        except (EnvironmentError, requests.exceptions.Timeout):
-            etag = None
+    etag = None
+    if not local_files_only:
+        # Get eTag to add to filename, if it exists.
+        if url.startswith("s3://"):
+            etag = s3_etag(url, proxies=proxies)
+        else:
+            try:
+                response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
+                if response.status_code == 200:
+                    etag = response.headers.get("ETag")
+            except (EnvironmentError, requests.exceptions.Timeout):
+                # etag is already None
+                pass
     filename = url_to_filename(url, etag)
@@ -427,6 +437,15 @@ def get_from_cache(
         if len(matching_files) > 0:
             return os.path.join(cache_dir, matching_files[-1])
         else:
+            # If files cannot be found and local_files_only=True,
+            # the models might've been found if local_files_only=False
+            # Notify the user about that
+            if local_files_only:
+                raise ValueError(
+                    "Cannot find the requested files in the cached path and outgoing traffic has been"
+                    " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
+                    " to False."
+                )
             return None
     # From now on, etag is not None.
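
To illustrate the effect at the file_utils level, a hedged sketch of calling cached_path with the new flag (the URL is purely illustrative):

from transformers.file_utils import cached_path

try:
    # With local_files_only=True no HEAD request or S3 etag lookup is made,
    # so only an already-cached copy of this (hypothetical) file can be returned.
    path = cached_path(
        "https://example.com/models/config.json",
        local_files_only=True,
    )
except ValueError:
    # Raised by get_from_cache when nothing matching is found in the cache
    # and outgoing traffic is disabled.
    path = None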


@@ -376,6 +376,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         resume_download = kwargs.pop("resume_download", False)
         proxies = kwargs.pop("proxies", None)
         output_loading_info = kwargs.pop("output_loading_info", False)
+        local_files_only = kwargs.pop("local_files_only", False)
         # Load config if we don't provide a configuration
         if not isinstance(config, PretrainedConfig):
@@ -388,6 +389,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 force_download=force_download,
                 resume_download=resume_download,
                 proxies=proxies,
+                local_files_only=local_files_only,
                 **kwargs,
             )
         else:
@@ -435,6 +437,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                     force_download=force_download,
                     proxies=proxies,
                     resume_download=resume_download,
+                    local_files_only=local_files_only,
                 )
             except EnvironmentError:
                 if pretrained_model_name_or_path in cls.pretrained_model_archive_map:


@@ -395,6 +395,7 @@ class PreTrainedTokenizer(object):
         force_download = kwargs.pop("force_download", False)
         resume_download = kwargs.pop("resume_download", False)
         proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", False)
         s3_models = list(cls.max_model_input_sizes.keys())
         vocab_files = {}
@@ -462,6 +463,7 @@ class PreTrainedTokenizer(object):
                         force_download=force_download,
                         proxies=proxies,
                         resume_download=resume_download,
+                        local_files_only=local_files_only,
                     )
         except EnvironmentError:
             if pretrained_model_name_or_path in s3_models: