mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
add force_download option to from_pretrained methods
This commit is contained in:
parent
c589862b78
commit
fecaed0ed4
@ -93,12 +93,15 @@ def filename_to_url(filename, cache_dir=None):
|
||||
return url, etag
|
||||
|
||||
|
||||
def cached_path(url_or_filename, cache_dir=None):
|
||||
def cached_path(url_or_filename, cache_dir=None, force_download=False):
|
||||
"""
|
||||
Given something that might be a URL (or might be a local path),
|
||||
determine which. If it's a URL, download the file and cache it, and
|
||||
return the path to the cached file. If it's already a local path,
|
||||
make sure the file exists and then return the path.
|
||||
Args:
|
||||
cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
|
||||
force_download: if True, re-dowload the file even if it's already cached in the cache dir.
|
||||
"""
|
||||
if cache_dir is None:
|
||||
cache_dir = PYTORCH_TRANSFORMERS_CACHE
|
||||
@ -111,7 +114,7 @@ def cached_path(url_or_filename, cache_dir=None):
|
||||
|
||||
if parsed.scheme in ('http', 'https', 's3'):
|
||||
# URL, so get it from the cache (downloading if necessary)
|
||||
return get_from_cache(url_or_filename, cache_dir)
|
||||
return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download)
|
||||
elif os.path.exists(url_or_filename):
|
||||
# File, and it exists.
|
||||
return url_or_filename
|
||||
@ -184,7 +187,7 @@ def http_get(url, temp_file):
|
||||
progress.close()
|
||||
|
||||
|
||||
def get_from_cache(url, cache_dir=None):
|
||||
def get_from_cache(url, cache_dir=None, force_download=False):
|
||||
"""
|
||||
Given a URL, look for the corresponding dataset in the local cache.
|
||||
If it's not there, download it. Then return the path to the cached file.
|
||||
@ -227,11 +230,11 @@ def get_from_cache(url, cache_dir=None):
|
||||
if matching_files:
|
||||
cache_path = os.path.join(cache_dir, matching_files[-1])
|
||||
|
||||
if not os.path.exists(cache_path):
|
||||
if not os.path.exists(cache_path) or force_download:
|
||||
# Download to temporary file, then copy to cache dir once finished.
|
||||
# Otherwise you get corrupt cache entries if the download gets interrupted.
|
||||
with tempfile.NamedTemporaryFile() as temp_file:
|
||||
logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
|
||||
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
|
||||
|
||||
# GET file object
|
||||
if url.startswith("s3://"):
|
||||
|
@ -125,6 +125,9 @@ class PretrainedConfig(object):
|
||||
- The values in kwargs of any keys which are configuration attributes will be used to override the loaded values.
|
||||
- Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled by the `return_unused_kwargs` keyword parameter.
|
||||
|
||||
force_download: (`optional`) boolean, default False:
|
||||
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
||||
|
||||
return_unused_kwargs: (`optional`) bool:
|
||||
|
||||
- If False, then this function returns just the final configuration object.
|
||||
@ -146,6 +149,7 @@ class PretrainedConfig(object):
|
||||
|
||||
"""
|
||||
cache_dir = kwargs.pop('cache_dir', None)
|
||||
force_download = kwargs.pop('force_download', False)
|
||||
return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
|
||||
|
||||
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
|
||||
@ -156,7 +160,7 @@ class PretrainedConfig(object):
|
||||
config_file = pretrained_model_name_or_path
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
|
||||
resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download)
|
||||
except EnvironmentError:
|
||||
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
|
||||
logger.error(
|
||||
@ -400,6 +404,9 @@ class PreTrainedModel(nn.Module):
|
||||
Path to a directory in which a downloaded pre-trained model
|
||||
configuration should be cached if the standard cache should not be used.
|
||||
|
||||
force_download: (`optional`) boolean, default False:
|
||||
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
||||
|
||||
output_loading_info: (`optional`) boolean:
|
||||
Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
|
||||
|
||||
@ -424,6 +431,7 @@ class PreTrainedModel(nn.Module):
|
||||
state_dict = kwargs.pop('state_dict', None)
|
||||
cache_dir = kwargs.pop('cache_dir', None)
|
||||
from_tf = kwargs.pop('from_tf', False)
|
||||
force_download = kwargs.pop('force_download', False)
|
||||
output_loading_info = kwargs.pop('output_loading_info', False)
|
||||
|
||||
# Load config
|
||||
@ -431,6 +439,7 @@ class PreTrainedModel(nn.Module):
|
||||
config, model_kwargs = cls.config_class.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args,
|
||||
cache_dir=cache_dir, return_unused_kwargs=True,
|
||||
force_download=force_download,
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
@ -453,7 +462,7 @@ class PreTrainedModel(nn.Module):
|
||||
archive_file = pretrained_model_name_or_path
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download)
|
||||
except EnvironmentError:
|
||||
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
|
||||
logger.error(
|
||||
|
@ -193,6 +193,9 @@ class PreTrainedTokenizer(object):
|
||||
cache_dir: (`optional`) string:
|
||||
Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used.
|
||||
|
||||
force_download: (`optional`) boolean, default False:
|
||||
Force to (re-)download the vocabulary files and override the cached versions if they exists.
|
||||
|
||||
inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method.
|
||||
|
||||
kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~pytorch_transformers.PreTrainedTokenizer` for details.
|
||||
@ -223,6 +226,7 @@ class PreTrainedTokenizer(object):
|
||||
@classmethod
|
||||
def _from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
|
||||
cache_dir = kwargs.pop('cache_dir', None)
|
||||
force_download = kwargs.pop('force_download', False)
|
||||
|
||||
s3_models = list(cls.max_model_input_sizes.keys())
|
||||
vocab_files = {}
|
||||
@ -283,7 +287,7 @@ class PreTrainedTokenizer(object):
|
||||
if file_path is None:
|
||||
resolved_vocab_files[file_id] = None
|
||||
else:
|
||||
resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir)
|
||||
resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download)
|
||||
except EnvironmentError:
|
||||
if pretrained_model_name_or_path in s3_models:
|
||||
logger.error("Couldn't reach server to download vocabulary.")
|
||||
|
Loading…
Reference in New Issue
Block a user