[core] Large/full refactor of from_pretrained (#36033)

* squash everything together
start to simplify inner logic

Update modeling_utils.py

Update modeling_utils.py

Update modeling_utils.py

Update modeling_utils.py

continue refactor

fix

small fixes

add type hints/docstring

Update modeling_utils.py

remove _fast_init

keep improving

Update modeling_utils.py

Update modeling_utils.py

new first tp loading version

style

fix weird in-place op

trigger CIs

Update modeling_utils.py

much clearer renaming of keys

fix

update

Update test_modeling_common.py

trigger CIs

update

update

style

Update modeling_utils.py

Update modeling_utils.py

Update modeling_utils.py

fix

fast download first prototype

remove old function

remove old functions

Remove unused function and move back _get_tp_registry

fix tp plan registry

simplify

CIs

Update hub.py

Update modeling_utils.py

simplify

simplify renaming logic

remove unused check

add sanity check back (a test depends on it)

Update modeling_utils.py

finalize sound renaming logic

style

add forgotten check

Update modeling_utils.py

add key_mapping keyword

style

Update modeling_utils.py

add comment

minor updates

minor change for clarity

fix small prefix issue and simplify

style

trigger CIs

typo fix

Post rebase fix

post rebase cleanup

simplify tp

typo

oops

typo

correctly escape

improvements based on Marc's review

finalize Marc's review comments

* squash everything

* improve

* Update modeling_utils.py

* Update modeling_utils.py

* fix

* Update modeling_utils.py

* Update modeling_utils.py

* style

* Update modeling_utils.py

* simplify

* style

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* fix dtype issue

* Update modeling_utils.py

* style

* remove test that does not make sense

* style

* small fixes

* style

* fix

* cleanup after rebase

* style

* typo

* escape

* tp for task specific top modules

* Update modeling_utils.py

* Update modeling_utils.py

* fix allocation

* CIs

* CIs

* CIs

* improve docstring

* CIs

* Update modeling_utils.py

* fix
Cyril Vallez 2025-03-12 13:39:25 +01:00 committed by GitHub
parent 7652804d23
commit 071a161d3e
15 changed files with 1525 additions and 1542 deletions

View File

@ -71,7 +71,6 @@ from .utils import (
copy_func,
default_cache_path,
define_sagemaker_information,
get_file_from_repo,
get_torch_version,
has_file,
http_user_agent,

View File

@ -306,7 +306,7 @@ def deepspeed_config():
return None
def _load_state_dict_into_zero3_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
def _load_state_dict_into_zero3_model(model_to_load, state_dict, assign_to_params_buffers=False):
"""
Loads state dict into a model specifically for Zero3, since DeepSpeed does not support the `transformers`
tensor parallelism API.
@ -349,7 +349,7 @@ def _load_state_dict_into_zero3_model(model_to_load, state_dict, start_prefix, a
if child is not None:
load(child, state_dict, prefix + name + ".", assign_to_params_buffers)
load(model_to_load, state_dict, prefix=start_prefix, assign_to_params_buffers=assign_to_params_buffers)
load(model_to_load, state_dict, assign_to_params_buffers=assign_to_params_buffers)
# Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
# it's safe to delete it.
del state_dict
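
With `start_prefix` removed, the recursion above always starts from the empty prefix. For readers following along, here is a minimal sketch of the same prefix-walking pattern in plain PyTorch (no DeepSpeed; the `load` helper below is an illustrative stand-in, not the library function):

```python
import torch
from torch import nn

def load(module: nn.Module, state_dict: dict, prefix: str = ""):
    # Copy matching entries for this module's direct parameters, then recurse
    # into children with an extended prefix -- mirroring the pattern above.
    for name, param in module.named_parameters(recurse=False):
        key = prefix + name
        if key in state_dict:
            with torch.no_grad():
                param.copy_(state_dict[key])
    for name, child in module.named_children():
        load(child, state_dict, prefix + name + ".")

model = nn.Sequential(nn.Linear(2, 2))
load(model, {"0.weight": torch.ones(2, 2), "0.bias": torch.zeros(2)})
```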

File diff suppressed because it is too large

View File

@ -25,7 +25,7 @@ from typing import Dict, Optional, Union
from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...feature_extraction_utils import FeatureExtractionMixin
from ...utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, get_file_from_repo, logging
from ...utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, cached_file, logging
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
CONFIG_MAPPING_NAMES,
@ -220,7 +220,7 @@ def get_feature_extractor_config(
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
token = use_auth_token
resolved_config_file = get_file_from_repo(
resolved_config_file = cached_file(
pretrained_model_name_or_path,
FEATURE_EXTRACTOR_NAME,
cache_dir=cache_dir,
@ -230,6 +230,9 @@ def get_feature_extractor_config(
token=token,
revision=revision,
local_files_only=local_files_only,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
if resolved_config_file is None:
logger.info(

View File

@ -29,7 +29,7 @@ from ...image_processing_utils_fast import BaseImageProcessorFast
from ...utils import (
CONFIG_NAME,
IMAGE_PROCESSOR_NAME,
get_file_from_repo,
cached_file,
is_timm_config_dict,
is_timm_local_checkpoint,
is_torchvision_available,
@ -288,7 +288,7 @@ def get_image_processor_config(
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
token = use_auth_token
resolved_config_file = get_file_from_repo(
resolved_config_file = cached_file(
pretrained_model_name_or_path,
IMAGE_PROCESSOR_NAME,
cache_dir=cache_dir,
@ -298,6 +298,9 @@ def get_image_processor_config(
token=token,
revision=revision,
local_files_only=local_files_only,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
if resolved_config_file is None:
logger.info(

View File

@ -28,7 +28,7 @@ from ...feature_extraction_utils import FeatureExtractionMixin
from ...image_processing_utils import ImageProcessingMixin
from ...processing_utils import ProcessorMixin
from ...tokenization_utils import TOKENIZER_CONFIG_FILE
from ...utils import FEATURE_EXTRACTOR_NAME, PROCESSOR_NAME, get_file_from_repo, logging
from ...utils import FEATURE_EXTRACTOR_NAME, PROCESSOR_NAME, cached_file, logging
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
CONFIG_MAPPING_NAMES,
@ -254,15 +254,21 @@ class AutoProcessor:
processor_auto_map = None
# First, let's see if we have a processor or preprocessor config.
# Filter the kwargs for `get_file_from_repo`.
get_file_from_repo_kwargs = {
key: kwargs[key] for key in inspect.signature(get_file_from_repo).parameters.keys() if key in kwargs
# Filter the kwargs for `cached_file`.
cached_file_kwargs = {
key: kwargs[key] for key in inspect.signature(cached_file).parameters.keys() if key in kwargs
}
# We don't want to raise
cached_file_kwargs.update(
{
"_raise_exceptions_for_gated_repo": False,
"_raise_exceptions_for_missing_entries": False,
"_raise_exceptions_for_connection_errors": False,
}
)
# Let's start by checking whether the processor class is saved in a processor config
processor_config_file = get_file_from_repo(
pretrained_model_name_or_path, PROCESSOR_NAME, **get_file_from_repo_kwargs
)
processor_config_file = cached_file(pretrained_model_name_or_path, PROCESSOR_NAME, **cached_file_kwargs)
if processor_config_file is not None:
config_dict, _ = ProcessorMixin.get_processor_dict(pretrained_model_name_or_path, **kwargs)
processor_class = config_dict.get("processor_class", None)
@ -271,8 +277,8 @@ class AutoProcessor:
if processor_class is None:
# If not found, let's check whether the processor class is saved in an image processor config
preprocessor_config_file = get_file_from_repo(
pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, **get_file_from_repo_kwargs
preprocessor_config_file = cached_file(
pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, **cached_file_kwargs
)
if preprocessor_config_file is not None:
config_dict, _ = ImageProcessingMixin.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
@ -291,8 +297,8 @@ class AutoProcessor:
if processor_class is None:
# Next, let's check whether the processor class is saved in a tokenizer
tokenizer_config_file = get_file_from_repo(
pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE, **get_file_from_repo_kwargs
tokenizer_config_file = cached_file(
pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE, **cached_file_kwargs
)
if tokenizer_config_file is not None:
with open(tokenizer_config_file, encoding="utf-8") as reader:
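
The signature-based kwarg filtering used above is a small generic pattern; a sketch under the assumption one wanted it as a standalone helper (`filter_kwargs_for` is hypothetical, not part of this PR):

```python
import inspect

def filter_kwargs_for(func, kwargs):
    # Keep only the keyword arguments that `func` actually accepts.
    accepted = inspect.signature(func).parameters.keys()
    return {k: v for k, v in kwargs.items() if k in accepted}

# e.g. the diff above builds `cached_file_kwargs` exactly this way before
# adding the three "do not raise" flags.
```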

View File

@ -25,7 +25,7 @@ import numpy as np
from ...feature_extraction_utils import BatchFeature
from ...processing_utils import ProcessorMixin
from ...utils import logging
from ...utils.hub import get_file_from_repo
from ...utils.hub import cached_file
from ..auto import AutoTokenizer
@ -86,7 +86,7 @@ class BarkProcessor(ProcessorMixin):
"""
if speaker_embeddings_dict_path is not None:
speaker_embeddings_path = get_file_from_repo(
speaker_embeddings_path = cached_file(
pretrained_processor_name_or_path,
speaker_embeddings_dict_path,
subfolder=kwargs.pop("subfolder", None),
@ -97,6 +97,9 @@ class BarkProcessor(ProcessorMixin):
local_files_only=kwargs.pop("local_files_only", False),
token=kwargs.pop("use_auth_token", None),
revision=kwargs.pop("revision", None),
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
if speaker_embeddings_path is None:
logger.warning(
@ -182,7 +185,7 @@ class BarkProcessor(ProcessorMixin):
f"Voice preset unrecognized, missing {key} as a key in self.speaker_embeddings[{voice_preset}]."
)
path = get_file_from_repo(
path = cached_file(
self.speaker_embeddings.get("repo_or_path", "/"),
voice_preset_paths[key],
subfolder=kwargs.pop("subfolder", None),
@ -193,6 +196,9 @@ class BarkProcessor(ProcessorMixin):
local_files_only=kwargs.pop("local_files_only", False),
token=kwargs.pop("use_auth_token", None),
revision=kwargs.pop("revision", None),
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
if path is None:
raise ValueError(

View File

@ -544,7 +544,7 @@ class CvtPreTrainedModel(PreTrainedModel):
elif isinstance(module, CvtStage):
if self.config.cls_token[module.stage]:
module.cls_token.data = nn.init.trunc_normal_(
torch.zeros(1, 1, self.config.embed_dim[-1]), mean=0.0, std=self.config.initializer_range
module.cls_token.data, mean=0.0, std=self.config.initializer_range
)
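
Context for this fix: `nn.init.trunc_normal_` fills its argument in place, so initializing the existing `cls_token.data` (rather than a freshly created `torch.zeros` tensor) keeps the parameter's device and dtype intact. A minimal sketch with illustrative shape and std:

```python
import torch
from torch import nn

cls_token = nn.Parameter(torch.zeros(1, 1, 384))
# trunc_normal_ mutates the tensor in place and returns it; no reassignment
# from a brand-new tensor is needed.
nn.init.trunc_normal_(cls_token.data, mean=0.0, std=0.02)
```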

View File

@ -35,7 +35,7 @@ from torch import Tensor
from vissl.models.model_helpers import get_trunk_forward_outputs
from transformers import AutoImageProcessor, RegNetConfig, RegNetForImageClassification, RegNetModel
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_utils import _load_state_dict_into_meta_model, load_state_dict
from transformers.utils import logging
@ -244,14 +244,18 @@ def convert_weights_and_push(save_directory: Path, model_name: str = None, push_
our_model_func = RegNetModel
if "in1k" in model_name:
our_model_func = RegNetForImageClassification
our_model = our_model_func(our_config)
# place our model to the meta device (so remove all the weights)
our_model.to(torch.device("meta"))
with torch.device("meta"):
our_model = our_model_func(our_config)
logger.info("Loading state_dict in our model.")
# load state dict
state_dict_keys = our_model.state_dict().keys()
PreTrainedModel._load_pretrained_model_low_mem(
our_model, state_dict_keys, [save_directory / f"{model_name}.pth"]
state_dict = load_state_dict(save_directory / f"{model_name}.pth", weights_only=True)
fixed_state_dict = {our_model._fix_state_dict_key_on_load(k)[0]: v for k, v in state_dict.items()}
_load_state_dict_into_meta_model(
our_model,
fixed_state_dict,
start_prefix="",
expected_keys=state_dict_keys,
)
logger.info("Finally, pushing!")
# push it to hub
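
The `with torch.device("meta")` construction above creates the model without allocating real storage for its weights; they are materialized only when the state dict is loaded. A small self-contained sketch (the `nn.Linear` is a stand-in):

```python
import torch
from torch import nn

with torch.device("meta"):
    layer = nn.Linear(4, 4)  # parameters exist, but no memory is allocated

assert layer.weight.device.type == "meta"
```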

View File

@ -113,7 +113,7 @@ class TimmWrapperPreTrainedModel(PreTrainedModel):
Override original method to fix state_dict keys on load for cases when weights are loaded
without using the `from_pretrained` method (e.g., in Trainer to resume from checkpoint).
"""
state_dict = self._fix_state_dict_keys_on_load(state_dict)
state_dict = {self._fix_state_dict_key_on_load(k)[0]: v for k, v in state_dict.items()}
return super().load_state_dict(state_dict, *args, **kwargs)
def _init_weights(self, module):

View File

@ -91,7 +91,6 @@ from .hub import (
define_sagemaker_information,
download_url,
extract_commit_hash,
get_file_from_repo,
has_file,
http_user_agent,
is_offline_mode,

View File

@ -40,6 +40,7 @@ from huggingface_hub import (
create_repo,
hf_hub_download,
hf_hub_url,
snapshot_download,
try_to_load_from_cache,
)
from huggingface_hub.file_download import REGEX_COMMIT_HASH, http_get
@ -47,7 +48,6 @@ from huggingface_hub.utils import (
EntryNotFoundError,
GatedRepoError,
HfHubHTTPError,
HFValidationError,
LocalEntryNotFoundError,
OfflineModeIsEnabled,
RepositoryNotFoundError,
@ -69,7 +69,6 @@ from .import_utils import (
is_torch_available,
is_training_run_on_sagemaker,
)
from .logging import tqdm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@ -209,21 +208,7 @@ def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]
def cached_file(
path_or_repo_id: Union[str, os.PathLike],
filename: str,
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
resume_download: Optional[bool] = None,
proxies: Optional[Dict[str, str]] = None,
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
subfolder: str = "",
repo_type: Optional[str] = None,
user_agent: Optional[Union[str, Dict[str, str]]] = None,
_raise_exceptions_for_gated_repo: bool = True,
_raise_exceptions_for_missing_entries: bool = True,
_raise_exceptions_for_connection_errors: bool = True,
_commit_hash: Optional[str] = None,
**deprecated_kwargs,
**kwargs,
) -> Optional[str]:
"""
Tries to locate a file in a local folder and repo, downloads and caches it if necessary.
@ -231,7 +216,6 @@ def cached_file(
Args:
path_or_repo_id (`str` or `os.PathLike`):
This can be either:
- a string, the *model id* of a model repo on huggingface.co.
- a path to a *directory* potentially containing the file.
filename (`str`):
@ -274,6 +258,94 @@ def cached_file(
Examples:
```python
# Download a model weight from the Hub and cache it.
model_weights_file = cached_file("google-bert/bert-base-uncased", "pytorch_model.bin")
```
"""
file = cached_files(path_or_repo_id=path_or_repo_id, filenames=[filename], **kwargs)
file = file[0] if file is not None else file
return file
def cached_files(
path_or_repo_id: Union[str, os.PathLike],
filenames: List[str],
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
resume_download: Optional[bool] = None,
proxies: Optional[Dict[str, str]] = None,
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
subfolder: str = "",
repo_type: Optional[str] = None,
user_agent: Optional[Union[str, Dict[str, str]]] = None,
_raise_exceptions_for_gated_repo: bool = True,
_raise_exceptions_for_missing_entries: bool = True,
_raise_exceptions_for_connection_errors: bool = True,
_commit_hash: Optional[str] = None,
**deprecated_kwargs,
) -> Optional[str]:
"""
Tries to locate several files in a local folder and repo, downloads and caches them if necessary.
Args:
path_or_repo_id (`str` or `os.PathLike`):
This can be either:
- a string, the *model id* of a model repo on huggingface.co.
- a path to a *directory* potentially containing the file.
filenames (`List[str]`):
The names of all the files to locate in `path_or_repo`.
cache_dir (`str` or `os.PathLike`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
cache should not be used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force (re-)downloading the configuration files, overriding the cached versions if they
exist.
resume_download:
Deprecated and ignored. All downloads are now resumed by default when possible.
Will be removed in v5 of Transformers.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `huggingface-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, will only try to load the files from the local cache or folder.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
repo_type (`str`, *optional*):
Specify the repo type (useful when downloading from a space for instance).
Private args:
_raise_exceptions_for_gated_repo (`bool`):
if False, do not raise an exception for gated repo error but return None.
_raise_exceptions_for_missing_entries (`bool`):
if False, do not raise an exception for missing entries but return None.
_raise_exceptions_for_connection_errors (`bool`):
if False, do not raise an exception for connection errors but return None.
_commit_hash (`str`, *optional*):
passed when we are chaining several calls to various files (e.g. when loading a tokenizer or
a pipeline). If files are cached for this commit hash, avoid calls to head and get from the cache.
<Tip>
Passing `token=True` is required when you want to use a private model.
</Tip>
Returns:
`Optional[List[str]]`: The resolved files (pointing to the cache folder if downloaded from a repo).
Examples:
```python
# Download model weights from the Hub and cache them.
model_weights_files = cached_files("google-bert/bert-base-uncased", ["pytorch_model.bin"])
@ -289,144 +361,176 @@ def cached_file(
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
token = use_auth_token
# Private arguments
# _raise_exceptions_for_gated_repo: if False, do not raise an exception for gated repo error but return
# None.
# _raise_exceptions_for_missing_entries: if False, do not raise an exception for missing entries but return
# None.
# _raise_exceptions_for_connection_errors: if False, do not raise an exception for connection errors but return
# None.
# _commit_hash: passed when we are chaining several calls to various files (e.g. when loading a tokenizer or
# a pipeline). If files are cached for this commit hash, avoid calls to head and get from the cache.
if is_offline_mode() and not local_files_only:
logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True
if subfolder is None:
subfolder = ""
# Add folder to filenames
full_filenames = [os.path.join(subfolder, file) for file in filenames]
path_or_repo_id = str(path_or_repo_id)
full_filename = os.path.join(subfolder, filename)
if os.path.isdir(path_or_repo_id):
resolved_file = os.path.join(os.path.join(path_or_repo_id, subfolder), filename)
if not os.path.isfile(resolved_file):
if _raise_exceptions_for_missing_entries and filename not in ["config.json", f"{subfolder}/config.json"]:
raise EnvironmentError(
f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout "
f"'https://huggingface.co/{path_or_repo_id}/tree/{revision}' for available files."
)
else:
return None
return resolved_file
existing_files = []
for filename in full_filenames:
if os.path.isdir(path_or_repo_id):
resolved_file = os.path.join(path_or_repo_id, filename)
if not os.path.isfile(resolved_file):
if _raise_exceptions_for_missing_entries and filename != os.path.join(subfolder, "config.json"):
revision_ = "main" if revision is None else revision
raise EnvironmentError(
f"{path_or_repo_id} does not appear to have a file named {filename}. Checkout "
f"'https://huggingface.co/{path_or_repo_id}/tree/{revision_}' for available files."
)
else:
return None
existing_files.append(resolved_file)
# All files exist
if len(existing_files) == len(full_filenames):
return existing_files
if cache_dir is None:
cache_dir = TRANSFORMERS_CACHE
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
existing_files = []
file_counter = 0
if _commit_hash is not None and not force_download:
# If the file is cached under that commit hash, we return it directly.
resolved_file = try_to_load_from_cache(
path_or_repo_id, full_filename, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
)
if resolved_file is not None:
if resolved_file is not _CACHED_NO_EXIST:
return resolved_file
elif not _raise_exceptions_for_missing_entries:
return None
else:
raise EnvironmentError(f"Could not locate {full_filename} inside {path_or_repo_id}.")
for filename in full_filenames:
# If the file is cached under that commit hash, we return it directly.
resolved_file = try_to_load_from_cache(
path_or_repo_id, filename, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
)
if resolved_file is not None:
if resolved_file is not _CACHED_NO_EXIST:
file_counter += 1
existing_files.append(resolved_file)
elif not _raise_exceptions_for_missing_entries:
file_counter += 1
else:
raise EnvironmentError(f"Could not locate {filename} inside {path_or_repo_id}.")
# Either all the files were found, or some were _CACHED_NO_EXIST but we do not raise for missing entries
if file_counter == len(full_filenames):
return existing_files if len(existing_files) > 0 else None
user_agent = http_user_agent(user_agent)
# download the files if needed
try:
# Load from URL or cache if already cached
resolved_file = hf_hub_download(
path_or_repo_id,
filename,
subfolder=None if len(subfolder) == 0 else subfolder,
repo_type=repo_type,
revision=revision,
cache_dir=cache_dir,
user_agent=user_agent,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
if len(full_filenames) == 1:
# For a single file, hf_hub_download is slightly more efficient than snapshot_download
hf_hub_download(
path_or_repo_id,
filenames[0],
subfolder=None if len(subfolder) == 0 else subfolder,
repo_type=repo_type,
revision=revision,
cache_dir=cache_dir,
user_agent=user_agent,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
else:
snapshot_download(
path_or_repo_id,
allow_patterns=full_filenames,
repo_type=repo_type,
revision=revision,
cache_dir=cache_dir,
user_agent=user_agent,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
except Exception as e:
# We cannot recover from them
if isinstance(e, RepositoryNotFoundError) and not isinstance(e, GatedRepoError):
raise EnvironmentError(
f"{path_or_repo_id} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token "
"having permission to this repo either by logging in with `huggingface-cli login` or by passing "
"`token=<your_token>`"
) from e
elif isinstance(e, RevisionNotFoundError):
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
"for this model name. Check the model page at "
f"'https://huggingface.co/{path_or_repo_id}' for available revisions."
) from e
# Now we try to recover if we can find all files correctly in the cache
resolved_files = [
_get_cache_file_to_return(path_or_repo_id, filename, cache_dir, revision) for filename in full_filenames
]
if all(file is not None for file in resolved_files):
return resolved_files
# Raise based on the flags. Note that we will raise for missing entries at the very end, even when
# not entering this Except block, as it may also happen when `snapshot_download` does not raise
if isinstance(e, GatedRepoError):
if not _raise_exceptions_for_gated_repo:
return None
raise EnvironmentError(
"You are trying to access a gated repo.\nMake sure to have access to it at "
f"https://huggingface.co/{path_or_repo_id}.\n{str(e)}"
) from e
elif isinstance(e, LocalEntryNotFoundError):
if not _raise_exceptions_for_connection_errors:
return None
# Here we only raise if both flags for missing entry and connection errors are True (because it can be raised
# even when `local_files_only` is True, in which case raising for connection errors only would not make sense)
elif _raise_exceptions_for_missing_entries:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load the files, and couldn't find them in the"
f" cached files.\nCheckout your internet connection or see how to run the library in offline mode at"
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
) from e
# snapshot_download will not raise EntryNotFoundError, but hf_hub_download can. If this is the case, it will be treated
# later on anyway and re-raised if needed
elif isinstance(e, HTTPError) and not isinstance(e, EntryNotFoundError):
if not _raise_exceptions_for_connection_errors:
return None
raise EnvironmentError(
f"There was a specific connection error when trying to load {path_or_repo_id}:\n{e}"
)
resolved_files = [
_get_cache_file_to_return(path_or_repo_id, filename, cache_dir, revision) for filename in full_filenames
]
# If there are any missing file and the flag is active, raise
if any(file is None for file in resolved_files) and _raise_exceptions_for_missing_entries:
missing_entries = [original for original, resolved in zip(full_filenames, resolved_files) if resolved is None]
# Last escape hatch: a missing config.json alone is tolerated and simply returns None
if len(resolved_files) == 1 and missing_entries[0] == os.path.join(subfolder, "config.json"):
return None
# Now we raise for missing entries
revision_ = "main" if revision is None else revision
msg = f"a file named {missing_entries[0]}" if len(missing_entries) == 1 else f"files named {*missing_entries,}"
raise EnvironmentError(
f"{path_or_repo_id} does not appear to have {msg}. Checkout 'https://huggingface.co/{path_or_repo_id}/tree/{revision_}'"
"for available files."
)
except GatedRepoError as e:
resolved_file = _get_cache_file_to_return(path_or_repo_id, full_filename, cache_dir, revision)
if resolved_file is not None or not _raise_exceptions_for_gated_repo:
return resolved_file
raise EnvironmentError(
"You are trying to access a gated repo.\nMake sure to have access to it at "
f"https://huggingface.co/{path_or_repo_id}.\n{str(e)}"
) from e
except RepositoryNotFoundError as e:
raise EnvironmentError(
f"{path_or_repo_id} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token "
"having permission to this repo either by logging in with `huggingface-cli login` or by passing "
"`token=<your_token>`"
) from e
except RevisionNotFoundError as e:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
"for this model name. Check the model page at "
f"'https://huggingface.co/{path_or_repo_id}' for available revisions."
) from e
except LocalEntryNotFoundError as e:
resolved_file = _get_cache_file_to_return(path_or_repo_id, full_filename, cache_dir, revision)
if (
resolved_file is not None
or not _raise_exceptions_for_missing_entries
or not _raise_exceptions_for_connection_errors
):
return resolved_file
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this file, couldn't find it in the"
f" cached files and it looks like {path_or_repo_id} is not the path to a directory containing a file named"
f" {full_filename}.\nCheckout your internet connection or see how to run the library in offline mode at"
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
) from e
except EntryNotFoundError as e:
if not _raise_exceptions_for_missing_entries:
return None
if revision is None:
revision = "main"
if filename in ["config.json", f"{subfolder}/config.json"]:
return None
raise EnvironmentError(
f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout "
f"'https://huggingface.co/{path_or_repo_id}/tree/{revision}' for available files."
) from e
except HTTPError as err:
resolved_file = _get_cache_file_to_return(path_or_repo_id, full_filename, cache_dir, revision)
if resolved_file is not None or not _raise_exceptions_for_connection_errors:
return resolved_file
raise EnvironmentError(f"There was a specific connection error when trying to load {path_or_repo_id}:\n{err}")
except HFValidationError as e:
raise EnvironmentError(
f"Incorrect path_or_model_id: '{path_or_repo_id}'. Please provide either the path to a local folder or the repo_id of a model on the Hub."
) from e
return resolved_file
# Remove potential missing entries (we can silently remove them at this point based on the flags)
resolved_files = [file for file in resolved_files if file is not None]
# Return `None` if the list is empty, consistent with the exception branches above when the flags are not active
resolved_files = None if len(resolved_files) == 0 else resolved_files
return resolved_files
# TODO: deprecate `get_file_from_repo` or document it differently?
# Docstring is exactly the same as `cached_repo` but behavior is slightly different. If file is missing or if
# there is a connection error, `cached_repo` will return None while `get_file_from_repo` will raise an error.
# IMO we should keep only 1 method and have a single `raise_error` argument (to be discussed).
# TODO cyril: Deprecated and should be removed in 4.51
def get_file_from_repo(
path_or_repo: Union[str, os.PathLike],
filename: str,
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
resume_download: Optional[bool] = None,
proxies: Optional[Dict[str, str]] = None,
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
subfolder: str = "",
**deprecated_kwargs,
*args,
**kwargs,
):
"""
Tries to locate a file in a local folder and repo, downloads and cache it if necessary.
@ -483,30 +587,15 @@ def get_file_from_repo(
tokenizer_config = get_file_from_repo("FacebookAI/xlm-roberta-base", "tokenizer_config.json")
```
"""
use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
FutureWarning,
)
if token is not None:
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
token = use_auth_token
logger.warning(
"`get_file_from_repo` is deprecated and will be removed in version 4.51. Use `cached_file` instead."
)
return cached_file(
path_or_repo_id=path_or_repo,
filename=filename,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
subfolder=subfolder,
*args,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
**kwargs,
)
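
Migration note: since the deprecated helper now simply forwards to `cached_file` with the three "do not raise" flags, a call site can be updated as below (repo and filename borrowed from the docstring example above; this should keep the old return-`None` behavior):

```python
from transformers.utils.hub import cached_file

tokenizer_config = cached_file(
    "FacebookAI/xlm-roberta-base",
    "tokenizer_config.json",
    _raise_exceptions_for_gated_repo=False,
    _raise_exceptions_for_missing_entries=False,
    _raise_exceptions_for_connection_errors=False,
)
```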
@ -1023,45 +1112,22 @@ def get_checkpoint_shard_files(
shard_filenames = [os.path.join(pretrained_model_name_or_path, subfolder, f) for f in shard_filenames]
return shard_filenames, sharded_metadata
# At this stage pretrained_model_name_or_path is a model identifier on the Hub
cached_filenames = []
# Check if the model is already cached or not. We only try the last checkpoint, this should cover most cases of
# downloaded (if interrupted).
last_shard = try_to_load_from_cache(
pretrained_model_name_or_path, shard_filenames[-1], cache_dir=cache_dir, revision=_commit_hash
# At this stage pretrained_model_name_or_path is a model identifier on the Hub. Try to get everything from cache,
# or download the files
cached_filenames = cached_files(
pretrained_model_name_or_path,
shard_filenames,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
token=token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_commit_hash=_commit_hash,
)
show_progress_bar = last_shard is None or force_download
for shard_filename in tqdm(shard_filenames, desc="Downloading shards", disable=not show_progress_bar):
try:
# Load from URL
cached_filename = cached_file(
pretrained_model_name_or_path,
shard_filename,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
token=token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_commit_hash=_commit_hash,
)
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
# we don't have to catch them here.
except EntryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {shard_filename} which is "
"required according to the checkpoint index."
)
except HTTPError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {shard_filename}. You should try"
" again after checking your internet connection."
)
cached_filenames.append(cached_filename)
return cached_filenames, sharded_metadata
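
A usage sketch of the new multi-file path: one `cached_files` call resolves several files at once (a single `snapshot_download` with `allow_patterns` under the hood) instead of looping `cached_file` per shard. The repo id and filenames are illustrative:

```python
from transformers.utils.hub import cached_files

resolved = cached_files(
    "google-bert/bert-base-uncased",
    ["config.json", "tokenizer_config.json"],
)
# `resolved` is a list of local paths to the requested files (cached or freshly downloaded)
```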

View File

@ -2368,10 +2368,9 @@ class ModelTesterMixin:
safe_save_file(placeholder_dict, os.path.join(tmp_dir, "model.safetensors"), metadata={"format": "pt"})
model_reloaded, infos = model_class.from_pretrained(tmp_dir, output_loading_info=True)
prefix = f"{model_reloaded.base_model_prefix}."
params = dict(model_reloaded.named_parameters())
params.update(dict(model_reloaded.named_buffers()))
param_names = {k[len(prefix) :] if k.startswith(prefix) else k for k in params.keys()}
param_names = set(params.keys())
missing_keys = set(infos["missing_keys"])
@ -2383,9 +2382,8 @@ class ModelTesterMixin:
ptrs[id_tensor_storage(tensor)].append(name)
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
for group in tied_params:
group = {k[len(prefix) :] if k.startswith(prefix) else k for k in group}
# We remove the group from extra_missing if not all weights from group are in it
if len(group - extra_missing) > 0:
if len(set(group) - extra_missing) > 0:
extra_missing = extra_missing - set(group)
self.assertEqual(
@ -2399,15 +2397,14 @@ class ModelTesterMixin:
# Remove nonpersistent buffers from missed_missing
buffers = [n for n, _ in model_reloaded.named_buffers()]
nonpersistent_buffers = {n for n in buffers if n not in model_reloaded.state_dict()}
nonpersistent_buffers = {
k[len(prefix) :] if k.startswith(prefix) else k for k in nonpersistent_buffers
}
missed_missing = missed_missing - nonpersistent_buffers
if model_reloaded._keys_to_ignore_on_load_missing is None:
expected_missing = set()
else:
expected_missing = set(model_reloaded._keys_to_ignore_on_load_missing)
expected_missing = set()
for pattern in model_reloaded._keys_to_ignore_on_load_missing:
expected_missing.update({k for k in param_names if re.search(pattern, k) is not None})
self.assertEqual(
missed_missing,
expected_missing,

View File

@ -28,7 +28,6 @@ from transformers.utils import (
TRANSFORMERS_CACHE,
WEIGHTS_NAME,
cached_file,
get_file_from_repo,
has_file,
)
@ -87,14 +86,8 @@ class GetFromCacheTests(unittest.TestCase):
path = cached_file(RANDOM_BERT, "conf", local_files_only=True, _raise_exceptions_for_missing_entries=False)
self.assertIsNone(path)
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError
response_mock.json.return_value = {}
# Under the mock environment we get a 500 error when trying to reach the tokenizer.
with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
# Under the mock environment, hf_hub_download will always raise an HTTPError
with mock.patch("transformers.utils.hub.hf_hub_download", side_effect=HTTPError) as mock_head:
path = cached_file(RANDOM_BERT, "conf", _raise_exceptions_for_connection_errors=False)
self.assertIsNone(path)
# This checks that we did call the fake head request
@ -117,18 +110,45 @@ class GetFromCacheTests(unittest.TestCase):
assert has_file(TINY_BERT_PT_ONLY, WEIGHTS_NAME, local_files_only=True, cache_dir=tmp_dir)
def test_get_file_from_repo_distant(self):
# `get_file_from_repo` returns None if the file does not exist
self.assertIsNone(get_file_from_repo("google-bert/bert-base-cased", "ahah.txt"))
# should return None if the file does not exist
self.assertIsNone(
cached_file(
"google-bert/bert-base-cased",
"ahah.txt",
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
)
# The function raises if the repository does not exist.
with self.assertRaisesRegex(EnvironmentError, "is not a valid model identifier"):
get_file_from_repo("bert-base-case", CONFIG_NAME)
cached_file(
"bert-base-case",
CONFIG_NAME,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
# The function raises if the revision does not exist.
with self.assertRaisesRegex(EnvironmentError, "is not a valid git identifier"):
get_file_from_repo("google-bert/bert-base-cased", CONFIG_NAME, revision="ahaha")
cached_file(
"google-bert/bert-base-cased",
CONFIG_NAME,
revision="ahaha",
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
resolved_file = get_file_from_repo("google-bert/bert-base-cased", CONFIG_NAME)
resolved_file = cached_file(
"google-bert/bert-base-cased",
CONFIG_NAME,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
# The name is the cached name which is not very easy to test, so instead we load the content.
config = json.loads(open(resolved_file, "r").read())
self.assertEqual(config["hidden_size"], 768)
@ -137,9 +157,26 @@ class GetFromCacheTests(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmp_dir:
filename = Path(tmp_dir) / "a.txt"
filename.touch()
self.assertEqual(get_file_from_repo(tmp_dir, "a.txt"), str(filename))
self.assertEqual(
cached_file(
tmp_dir,
"a.txt",
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
),
str(filename),
)
self.assertIsNone(get_file_from_repo(tmp_dir, "b.txt"))
self.assertIsNone(
cached_file(
tmp_dir,
"b.txt",
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
)
def test_get_file_gated_repo(self):
"""Test download file from a gated repo fails with correct message when not authenticated."""

View File

@ -14,7 +14,6 @@
# limitations under the License.
import copy
import glob
import itertools
import json
import os
import os.path
@ -525,13 +524,12 @@ class ModelUtilsTest(TestCasePlus):
self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float16)
# TODO @ARTHURZUCKER FIX THIS
# but if the model has `_keep_in_fp32_modules` then those modules should be in fp32 no matter what
# LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"]
# model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto")
# self.assertEqual(model.language_model.dtype, torch.float32)
# self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
# self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float32)
LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"]
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto")
self.assertEqual(model.language_model.dtype, torch.float32)
self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float32)
# torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
with self.assertRaises(ValueError):
@ -540,20 +538,6 @@ class ModelUtilsTest(TestCasePlus):
TINY_LLAVA, torch_dtype={"text_config": "float32", "vision_config": "int64", "": "float16"}
)
@require_torch
@unittest.skip("Broken by @arthurzucker because the fix was not correct. Knowing the context is super hard")
def test_model_from_pretrained_meta_device(self):
def is_on_meta(model_id, dtype):
with torch.device("meta"):
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype)
return all(value.device.type == "meta" for value in model.state_dict().values())
model_ids = ("fxmarty/tiny-llama-fast-tokenizer", "fxmarty/small-llama-testing")
dtypes = (None, "auto", torch.float16)
for model_id, dtype in itertools.product(model_ids, dtypes):
self.assertTrue(is_on_meta(model_id, dtype))
def test_model_from_pretrained_torch_dtype(self):
# test that the model can be instantiated with dtype of either
# 1. explicit from_pretrained's torch_dtype argument