PoC to avoid using get_list_of_files

This commit is contained in:
Sylvain Gugger 2022-01-20 15:52:01 -05:00
parent 3d66f3b528
commit cb93b7cae8

View File

@ -21,7 +21,7 @@ import json
import os
import re
import warnings
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, Tuple, Union
from packaging import version
@ -36,7 +36,6 @@ from .file_utils import (
RevisionNotFoundError,
cached_path,
copy_func,
get_list_of_files,
hf_bucket_url,
is_offline_mode,
is_remote_url,
@ -46,7 +45,6 @@ from .utils import logging
logger = logging.get_logger(__name__)
FULL_CONFIGURATION_FILE = "config.json"
_re_configuration_file = re.compile(r"config\.(.*)\.json")
@ -509,6 +507,14 @@ class PretrainedConfig(PushToHubMixin):
assert unused_kwargs == {"foo": False}
```"""
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
if "configuration_files" in config_dict:
# We may have to load another config
configuration_file = get_configuration_file(config_dict["configuration_files"])
if configuration_file != CONFIG_NAME:
config_dict, kwargs = cls.get_config_dict(
pretrained_model_name_or_path, _configuration_file=configuration_file, **kwargs
)
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warn(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
@ -525,7 +531,12 @@ class PretrainedConfig(PushToHubMixin):
From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
[`PretrainedConfig`] using `from_dict`.
<Tip warning={true}>
This method is for internal use only and will only load the base configuration of the model (when several are
available). You should always use the [~PretrainedConfig.from_pretrained`] method.
</Tip>
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`):
@ -557,13 +568,7 @@ class PretrainedConfig(PushToHubMixin):
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
config_file = pretrained_model_name_or_path
else:
configuration_file = get_configuration_file(
pretrained_model_name_or_path,
revision=revision,
use_auth_token=use_auth_token,
local_files_only=local_files_only,
)
configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
if os.path.isdir(pretrained_model_name_or_path):
config_file = os.path.join(pretrained_model_name_or_path, configuration_file)
else:
@ -841,49 +846,26 @@ class PretrainedConfig(PushToHubMixin):
d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
def get_configuration_file(
path_or_repo: Union[str, os.PathLike],
revision: Optional[str] = None,
use_auth_token: Optional[Union[bool, str]] = None,
local_files_only: bool = False,
) -> str:
def get_configuration_file(configuration_files) -> str:
"""
Get the configuration file to use for this version of transformers.
Args:
path_or_repo (`str` or `os.PathLike`):
Can be either the id of a repo on huggingface.co or a path to a *directory*.
revision(`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `transformers-cli login` (stored in `~/.huggingface`).
local_files_only (`bool`, *optional*, defaults to `False`):
Whether or not to only rely on local files and not to attempt to download any files.
configuration_files (`List[str]`): The list of configuration files to pick from.
Returns:
`str`: The configuration file to use.
"""
# Inspect all files from the repo/folder.
try:
all_files = get_list_of_files(
path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only
)
except Exception:
return FULL_CONFIGURATION_FILE
configuration_files_map = {}
for file_name in all_files:
for file_name in configuration_files:
search = _re_configuration_file.search(file_name)
if search is not None:
v = search.groups()[0]
configuration_files_map[v] = file_name
available_versions = sorted(configuration_files_map.keys())
# Defaults to FULL_CONFIGURATION_FILE and then try to look at some newer versions.
configuration_file = FULL_CONFIGURATION_FILE
# Defaults to CONFIG_NAME and then try to look at some newer versions.
configuration_file = CONFIG_NAME
transformers_version = version.parse(__version__)
for v in available_versions:
if version.parse(v) <= transformers_version: