diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 9cc019b93b4..d8cd0fe3e9f 100644 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -198,6 +198,7 @@ class PretrainedConfig(object): force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) if pretrained_config_archive_map is None: pretrained_config_archive_map = cls.pretrained_config_archive_map @@ -219,6 +220,7 @@ class PretrainedConfig(object): force_download=force_download, proxies=proxies, resume_download=resume_download, + local_files_only=local_files_only, ) # Load config dict if resolved_config_file is None: diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index ddbac788c69..dfc6d1a8fef 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -214,6 +214,7 @@ def cached_path( user_agent=None, extract_compressed_file=False, force_extract=False, + local_files_only=False, ) -> Optional[str]: """ Given something that might be a URL (or might be a local path), @@ -250,6 +251,7 @@ def cached_path( proxies=proxies, resume_download=resume_download, user_agent=user_agent, + local_files_only=local_files_only, ) elif os.path.exists(url_or_filename): # File, and it exists. @@ -378,7 +380,14 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None): def get_from_cache( - url, cache_dir=None, force_download=False, proxies=None, etag_timeout=10, resume_download=False, user_agent=None + url, + cache_dir=None, + force_download=False, + proxies=None, + etag_timeout=10, + resume_download=False, + user_agent=None, + local_files_only=False, ) -> Optional[str]: """ Given a URL, look for the corresponding file in the local cache. @@ -395,18 +404,19 @@ def get_from_cache( os.makedirs(cache_dir, exist_ok=True) - # Get eTag to add to filename, if it exists. - if url.startswith("s3://"): - etag = s3_etag(url, proxies=proxies) - else: - try: - response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout) - if response.status_code != 200: - etag = None - else: - etag = response.headers.get("ETag") - except (EnvironmentError, requests.exceptions.Timeout): - etag = None + etag = None + if not local_files_only: + # Get eTag to add to filename, if it exists. + if url.startswith("s3://"): + etag = s3_etag(url, proxies=proxies) + else: + try: + response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout) + if response.status_code == 200: + etag = response.headers.get("ETag") + except (EnvironmentError, requests.exceptions.Timeout): + # etag is already None + pass filename = url_to_filename(url, etag) @@ -427,6 +437,15 @@ def get_from_cache( if len(matching_files) > 0: return os.path.join(cache_dir, matching_files[-1]) else: + # If files cannot be found and local_files_only=True, + # the models might've been found if local_files_only=False + # Notify the user about that + if local_files_only: + raise ValueError( + "Cannot find the requested files in the cached path and outgoing traffic has been" + " disabled. To enable model look-ups and downloads online, set 'local_files_only'" + " to False." + ) return None # From now on, etag is not None. diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 72c8195b0b0..c48bcec17de 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -376,6 +376,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) output_loading_info = kwargs.pop("output_loading_info", False) + local_files_only = kwargs.pop("local_files_only", False) # Load config if we don't provide a configuration if not isinstance(config, PretrainedConfig): @@ -388,6 +389,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): force_download=force_download, resume_download=resume_download, proxies=proxies, + local_files_only=local_files_only, **kwargs, ) else: @@ -435,6 +437,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): force_download=force_download, proxies=proxies, resume_download=resume_download, + local_files_only=local_files_only, ) except EnvironmentError: if pretrained_model_name_or_path in cls.pretrained_model_archive_map: diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py index 73fcb79c97f..90215778da9 100644 --- a/src/transformers/tokenization_utils.py +++ b/src/transformers/tokenization_utils.py @@ -395,6 +395,7 @@ class PreTrainedTokenizer(object): force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) s3_models = list(cls.max_model_input_sizes.keys()) vocab_files = {} @@ -462,6 +463,7 @@ class PreTrainedTokenizer(object): force_download=force_download, proxies=proxies, resume_download=resume_download, + local_files_only=local_files_only, ) except EnvironmentError: if pretrained_model_name_or_path in s3_models: