add force_download option to from_pretrained methods

commit fecaed0ed4 (parent c589862b78)
Author: thomwolf
Date: 2019-08-20 10:56:12 +02:00

3 changed files with 24 additions and 8 deletions

pytorch_transformers/file_utils.py

@@ -93,12 +93,15 @@ def filename_to_url(filename, cache_dir=None):
     return url, etag


-def cached_path(url_or_filename, cache_dir=None):
+def cached_path(url_or_filename, cache_dir=None, force_download=False):
     """
     Given something that might be a URL (or might be a local path),
     determine which. If it's a URL, download the file and cache it, and
     return the path to the cached file. If it's already a local path,
     make sure the file exists and then return the path.
+    Args:
+        cache_dir: specify a cache directory to save the file to (overrides the default cache dir).
+        force_download: if True, re-download the file even if it's already cached in the cache dir.
     """
     if cache_dir is None:
         cache_dir = PYTORCH_TRANSFORMERS_CACHE
@@ -111,7 +114,7 @@ def cached_path(url_or_filename, cache_dir=None):
     if parsed.scheme in ('http', 'https', 's3'):
         # URL, so get it from the cache (downloading if necessary)
-        return get_from_cache(url_or_filename, cache_dir)
+        return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download)
     elif os.path.exists(url_or_filename):
         # File, and it exists.
         return url_or_filename
@@ -184,7 +187,7 @@ def http_get(url, temp_file):
     progress.close()


-def get_from_cache(url, cache_dir=None):
+def get_from_cache(url, cache_dir=None, force_download=False):
     """
     Given a URL, look for the corresponding dataset in the local cache.
     If it's not there, download it. Then return the path to the cached file.
@@ -227,11 +230,11 @@ def get_from_cache(url, cache_dir=None):
         if matching_files:
             cache_path = os.path.join(cache_dir, matching_files[-1])

-    if not os.path.exists(cache_path):
+    if not os.path.exists(cache_path) or force_download:
         # Download to temporary file, then copy to cache dir once finished.
         # Otherwise you get corrupt cache entries if the download gets interrupted.
         with tempfile.NamedTemporaryFile() as temp_file:
-            logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
+            logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)

             # GET file object
             if url.startswith("s3://"):

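Taken together, the file_utils.py hunks thread force_download from cached_path down into get_from_cache, where it defeats the cache-hit check so the file is re-fetched even when a cached copy exists. The sketch below is a simplified reading of that flow, not the real implementation: url_to_filename stands in for the real cache-key helper, and the ETag and S3 branches are elided.

import os
import shutil
import tempfile
from urllib.parse import urlparse

def cached_path(url_or_filename, cache_dir, force_download=False):
    # Route URLs through the cache; hand local paths straight back.
    parsed = urlparse(url_or_filename)
    if parsed.scheme in ('http', 'https', 's3'):
        return get_from_cache(url_or_filename, cache_dir=cache_dir,
                              force_download=force_download)
    elif os.path.exists(url_or_filename):
        return url_or_filename
    raise EnvironmentError("file {} not found".format(url_or_filename))

def get_from_cache(url, cache_dir, force_download=False):
    # url_to_filename: hypothetical stand-in for the real cache-key helper.
    cache_path = os.path.join(cache_dir, url_to_filename(url))
    # force_download turns a cache hit into a miss: the file is downloaded again.
    if not os.path.exists(cache_path) or force_download:
        # Download to a temporary file first, then copy into the cache,
        # so an interrupted download never leaves a corrupt cache entry.
        with tempfile.NamedTemporaryFile() as temp_file:
            http_get(url, temp_file)  # streaming download, as in the hunk above
            temp_file.flush()
            temp_file.seek(0)
            with open(cache_path, 'wb') as cache_file:
                shutil.copyfileobj(temp_file, cache_file)
    return cache_path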
pytorch_transformers/modeling_utils.py

@@ -125,6 +125,9 @@ class PretrainedConfig(object):
             - The values in kwargs of any keys which are configuration attributes will be used to override the loaded values.
             - Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled by the `return_unused_kwargs` keyword parameter.

+        force_download: (`optional`) boolean, default False:
+            Force a (re-)download of the model weights and configuration files, overriding cached versions if they exist.
+
         return_unused_kwargs: (`optional`) bool:
             - If False, then this function returns just the final configuration object.
@@ -146,6 +149,7 @@ class PretrainedConfig(object):
         """
         cache_dir = kwargs.pop('cache_dir', None)
+        force_download = kwargs.pop('force_download', False)
         return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)

         if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
@@ -156,7 +160,7 @@ class PretrainedConfig(object):
             config_file = pretrained_model_name_or_path
         # redirect to the cache, if necessary
         try:
-            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
+            resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download)
         except EnvironmentError:
             if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
                 logger.error(
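On the configuration side the flag is popped from kwargs and simply forwarded to cached_path, so a stale or corrupted cached config can be replaced in one call. A minimal usage sketch, assuming the BertConfig class and the 'bert-base-uncased' shortcut name:

from pytorch_transformers import BertConfig

# Ignore any cached copy and re-fetch the configuration file.
config = BertConfig.from_pretrained('bert-base-uncased', force_download=True)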
@@ -400,6 +404,9 @@ class PreTrainedModel(nn.Module):
                 Path to a directory in which a downloaded pre-trained model
                 configuration should be cached if the standard cache should not be used.

+            force_download: (`optional`) boolean, default False:
+                Force a (re-)download of the model weights and configuration files, overriding cached versions if they exist.
+
             output_loading_info: (`optional`) boolean:
                 Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
@@ -424,6 +431,7 @@ class PreTrainedModel(nn.Module):
         state_dict = kwargs.pop('state_dict', None)
         cache_dir = kwargs.pop('cache_dir', None)
         from_tf = kwargs.pop('from_tf', False)
+        force_download = kwargs.pop('force_download', False)
         output_loading_info = kwargs.pop('output_loading_info', False)

         # Load config
@@ -431,6 +439,7 @@ class PreTrainedModel(nn.Module):
             config, model_kwargs = cls.config_class.from_pretrained(
                 pretrained_model_name_or_path, *model_args,
                 cache_dir=cache_dir, return_unused_kwargs=True,
+                force_download=force_download,
                 **kwargs
             )
         else:
@@ -453,7 +462,7 @@ class PreTrainedModel(nn.Module):
             archive_file = pretrained_model_name_or_path
         # redirect to the cache, if necessary
         try:
-            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
+            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download)
         except EnvironmentError:
             if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                 logger.error(

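The model path mirrors the config path: from_pretrained pops force_download, passes it to the config loader (so the configuration is refreshed as well), and then to cached_path for the weights archive. A usage sketch under the same assumptions (BertModel, 'bert-base-uncased'):

from pytorch_transformers import BertModel

# Re-download both the configuration and the weights, replacing cached copies.
model = BertModel.from_pretrained('bert-base-uncased', force_download=True)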
pytorch_transformers/tokenization_utils.py

@@ -193,6 +193,9 @@ class PreTrainedTokenizer(object):
             cache_dir: (`optional`) string:
                 Path to a directory in which downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used.

+            force_download: (`optional`) boolean, default False:
+                Force a (re-)download of the vocabulary files, overriding cached versions if they exist.
+
             inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method.

             kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~pytorch_transformers.PreTrainedTokenizer` for details.
@@ -223,6 +226,7 @@ class PreTrainedTokenizer(object):
     @classmethod
     def _from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
         cache_dir = kwargs.pop('cache_dir', None)
+        force_download = kwargs.pop('force_download', False)

         s3_models = list(cls.max_model_input_sizes.keys())
         vocab_files = {}
@@ -283,7 +287,7 @@ class PreTrainedTokenizer(object):
                 if file_path is None:
                     resolved_vocab_files[file_id] = None
                 else:
-                    resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir)
+                    resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download)
         except EnvironmentError:
             if pretrained_model_name_or_path in s3_models:
                 logger.error("Couldn't reach server to download vocabulary.")
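Tokenizers get the same treatment: _from_pretrained pops the flag once and applies it to every vocabulary file it resolves. A usage sketch, again assuming BertTokenizer and the 'bert-base-uncased' shortcut:

from pytorch_transformers import BertTokenizer

# Bypass the cache for every vocabulary file of this tokenizer.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', force_download=True)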