mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
PoC to avoid using get_list_of_files
This commit is contained in:
parent
3d66f3b528
commit
cb93b7cae8
@ -21,7 +21,7 @@ import json
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
from packaging import version
|
||||
|
||||
@ -36,7 +36,6 @@ from .file_utils import (
|
||||
RevisionNotFoundError,
|
||||
cached_path,
|
||||
copy_func,
|
||||
get_list_of_files,
|
||||
hf_bucket_url,
|
||||
is_offline_mode,
|
||||
is_remote_url,
|
||||
@ -46,7 +45,6 @@ from .utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
FULL_CONFIGURATION_FILE = "config.json"
|
||||
_re_configuration_file = re.compile(r"config\.(.*)\.json")
|
||||
|
||||
|
||||
@ -509,6 +507,14 @@ class PretrainedConfig(PushToHubMixin):
|
||||
assert unused_kwargs == {"foo": False}
|
||||
```"""
|
||||
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||
if "configuration_files" in config_dict:
|
||||
# We may have to load another config
|
||||
configuration_file = get_configuration_file(config_dict["configuration_files"])
|
||||
if configuration_file != CONFIG_NAME:
|
||||
config_dict, kwargs = cls.get_config_dict(
|
||||
pretrained_model_name_or_path, _configuration_file=configuration_file, **kwargs
|
||||
)
|
||||
|
||||
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
||||
logger.warn(
|
||||
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
||||
@ -525,7 +531,12 @@ class PretrainedConfig(PushToHubMixin):
|
||||
From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
|
||||
[`PretrainedConfig`] using `from_dict`.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This method is for internal use only and will only load the base configuration of the model (when several are
|
||||
available). You should always use the [~PretrainedConfig.from_pretrained`] method.
|
||||
|
||||
</Tip>
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
||||
@ -557,13 +568,7 @@ class PretrainedConfig(PushToHubMixin):
|
||||
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
config_file = pretrained_model_name_or_path
|
||||
else:
|
||||
configuration_file = get_configuration_file(
|
||||
pretrained_model_name_or_path,
|
||||
revision=revision,
|
||||
use_auth_token=use_auth_token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
|
||||
configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
config_file = os.path.join(pretrained_model_name_or_path, configuration_file)
|
||||
else:
|
||||
@ -841,49 +846,26 @@ class PretrainedConfig(PushToHubMixin):
|
||||
d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
|
||||
|
||||
|
||||
def get_configuration_file(
|
||||
path_or_repo: Union[str, os.PathLike],
|
||||
revision: Optional[str] = None,
|
||||
use_auth_token: Optional[Union[bool, str]] = None,
|
||||
local_files_only: bool = False,
|
||||
) -> str:
|
||||
def get_configuration_file(configuration_files) -> str:
|
||||
"""
|
||||
Get the configuration file to use for this version of transformers.
|
||||
|
||||
Args:
|
||||
path_or_repo (`str` or `os.PathLike`):
|
||||
Can be either the id of a repo on huggingface.co or a path to a *directory*.
|
||||
revision(`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
||||
identifier allowed by git.
|
||||
use_auth_token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
||||
when running `transformers-cli login` (stored in `~/.huggingface`).
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to only rely on local files and not to attempt to download any files.
|
||||
configuration_files (`List[str]`): The list of configuration files to pick from.
|
||||
|
||||
Returns:
|
||||
`str`: The configuration file to use.
|
||||
"""
|
||||
# Inspect all files from the repo/folder.
|
||||
try:
|
||||
all_files = get_list_of_files(
|
||||
path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only
|
||||
)
|
||||
except Exception:
|
||||
return FULL_CONFIGURATION_FILE
|
||||
|
||||
configuration_files_map = {}
|
||||
for file_name in all_files:
|
||||
for file_name in configuration_files:
|
||||
search = _re_configuration_file.search(file_name)
|
||||
if search is not None:
|
||||
v = search.groups()[0]
|
||||
configuration_files_map[v] = file_name
|
||||
available_versions = sorted(configuration_files_map.keys())
|
||||
|
||||
# Defaults to FULL_CONFIGURATION_FILE and then try to look at some newer versions.
|
||||
configuration_file = FULL_CONFIGURATION_FILE
|
||||
# Defaults to CONFIG_NAME and then try to look at some newer versions.
|
||||
configuration_file = CONFIG_NAME
|
||||
transformers_version = version.parse(__version__)
|
||||
for v in available_versions:
|
||||
if version.parse(v) <= transformers_version:
|
||||
|
Loading…
Reference in New Issue
Block a user