mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[core] Large/full refactor of from_pretrained
(#36033)
* squash everything together start to simplify inner logic Update modeling_utils.py Update modeling_utils.py Update modeling_utils.py Update modeling_utils.py continue refactor fix small fixes add type hints/docstring Update modeling_utils.py remove _fast_init keep improving Update modeling_utils.py Update modeling_utils.py new first tp loading version style fix weird in-place op trigger CIs Update modeling_utils.py much clearer renaming of keys fix update Update test_modeling_common.py trigger CIs update update style Update modeling_utils.py Update modeling_utils.py Update modeling_utils.py fix fast download first prototype remove old function remove old functions Remove unused function and move back _get_tp_registry fix tp plan registry simplify CIs Update hub.py Update modeling_utils.py simplify simplify renaming logic remove unused check add sanity check back (a test depends on it) Update modeling_utils.py finalize sound renaming logic style add forgotten check Update modeling_utils.py add key_mapping keyword style Update modeling_utils.py add comment minor updates minor change for clarity fix small prefix issue and simplify style trigger CIs typo fix Post rebase fix post rebase cleanup simplify tp typo oupsi typo correctly escape improvements based on Marc's review finalize Marc's review comments squash everything
* improve
* Update modeling_utils.py
* Update modeling_utils.py
* fix
* Update modeling_utils.py
* Update modeling_utils.py
* style
* Update modeling_utils.py
* simplify
* style
* Update modeling_utils.py
* Update modeling_utils.py
* Update modeling_utils.py
* Update modeling_utils.py
* Update modeling_utils.py
* Update modeling_utils.py
* fix dtype issue
* Update modeling_utils.py
* style
* remove test that does not make sense
* style
* small fixes
* style
* fix
* cleanup after rebase
* style
* typo
* escape
* tp for task specific top modules
* Update modeling_utils.py
* Update modeling_utils.py
* fix allocation
* CIs
* CIs
* CIs
* improve docstring
* CIs
* Update modeling_utils.py
* fix
This commit is contained in:
parent
7652804d23
commit
071a161d3e
@@ -71,7 +71,6 @@ from .utils import (
     copy_func,
     default_cache_path,
     define_sagemaker_information,
-    get_file_from_repo,
     get_torch_version,
     has_file,
     http_user_agent,
@@ -306,7 +306,7 @@ def deepspeed_config():
         return None


-def _load_state_dict_into_zero3_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
+def _load_state_dict_into_zero3_model(model_to_load, state_dict, assign_to_params_buffers=False):
     """
     Loads state dict into a model specifically for Zero3, since DeepSpeed does not support the `transformers`
     tensor parallelism API.
@@ -349,7 +349,7 @@ def _load_state_dict_into_zero3_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
         if child is not None:
             load(child, state_dict, prefix + name + ".", assign_to_params_buffers)

-    load(model_to_load, state_dict, prefix=start_prefix, assign_to_params_buffers=assign_to_params_buffers)
+    load(model_to_load, state_dict, assign_to_params_buffers=assign_to_params_buffers)
     # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
     # it's safe to delete it.
     del state_dict
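The only caller-visible change in this helper is the dropped `start_prefix` argument; prefix handling now happens before the state dict reaches it. A minimal sketch of the new call shape, assuming a DeepSpeed ZeRO-3 environment (the model and state dict below are placeholders, not part of the commit):

```python
import torch
import torch.nn as nn

from transformers.integrations.deepspeed import _load_state_dict_into_zero3_model

model = nn.Linear(4, 4)  # placeholder; in practice a transformers PreTrainedModel
state_dict = {"weight": torch.zeros(4, 4), "bias": torch.zeros(4)}

# Before: _load_state_dict_into_zero3_model(model, state_dict, start_prefix="", ...)
# After: keys must already be aligned with the model; no prefix stripping here.
_load_state_dict_into_zero3_model(model, state_dict, assign_to_params_buffers=False)
```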
(File diff suppressed because it is too large)
@@ -25,7 +25,7 @@ from typing import Dict, Optional, Union
 from ...configuration_utils import PretrainedConfig
 from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
 from ...feature_extraction_utils import FeatureExtractionMixin
-from ...utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, get_file_from_repo, logging
+from ...utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, cached_file, logging
 from .auto_factory import _LazyAutoMapping
 from .configuration_auto import (
     CONFIG_MAPPING_NAMES,
@@ -220,7 +220,7 @@ def get_feature_extractor_config(
             raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
         token = use_auth_token

-    resolved_config_file = get_file_from_repo(
+    resolved_config_file = cached_file(
         pretrained_model_name_or_path,
         FEATURE_EXTRACTOR_NAME,
         cache_dir=cache_dir,
@@ -230,6 +230,9 @@ def get_feature_extractor_config(
         token=token,
         revision=revision,
         local_files_only=local_files_only,
+        _raise_exceptions_for_gated_repo=False,
+        _raise_exceptions_for_missing_entries=False,
+        _raise_exceptions_for_connection_errors=False,
     )
     if resolved_config_file is None:
         logger.info(
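The three private flags make this lookup best-effort: instead of raising on gated repos, missing entries, or connection errors, `cached_file` now returns `None` and the caller falls back gracefully. A short sketch of that behavior (the repo id is a real public model; the filename is deliberately made up):

```python
from transformers.utils import cached_file

# With the non-raising flags set, a missing file yields None instead of an EnvironmentError.
path = cached_file(
    "google-bert/bert-base-uncased",
    "this_file_does_not_exist.json",  # hypothetical filename
    _raise_exceptions_for_gated_repo=False,
    _raise_exceptions_for_missing_entries=False,
    _raise_exceptions_for_connection_errors=False,
)
assert path is None
```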
@@ -29,7 +29,7 @@ from ...image_processing_utils_fast import BaseImageProcessorFast
 from ...utils import (
     CONFIG_NAME,
     IMAGE_PROCESSOR_NAME,
-    get_file_from_repo,
+    cached_file,
     is_timm_config_dict,
     is_timm_local_checkpoint,
     is_torchvision_available,
@@ -288,7 +288,7 @@ def get_image_processor_config(
             raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
         token = use_auth_token

-    resolved_config_file = get_file_from_repo(
+    resolved_config_file = cached_file(
         pretrained_model_name_or_path,
         IMAGE_PROCESSOR_NAME,
         cache_dir=cache_dir,
@@ -298,6 +298,9 @@ def get_image_processor_config(
         token=token,
         revision=revision,
         local_files_only=local_files_only,
+        _raise_exceptions_for_gated_repo=False,
+        _raise_exceptions_for_missing_entries=False,
+        _raise_exceptions_for_connection_errors=False,
     )
     if resolved_config_file is None:
         logger.info(
@@ -28,7 +28,7 @@ from ...feature_extraction_utils import FeatureExtractionMixin
 from ...image_processing_utils import ImageProcessingMixin
 from ...processing_utils import ProcessorMixin
 from ...tokenization_utils import TOKENIZER_CONFIG_FILE
-from ...utils import FEATURE_EXTRACTOR_NAME, PROCESSOR_NAME, get_file_from_repo, logging
+from ...utils import FEATURE_EXTRACTOR_NAME, PROCESSOR_NAME, cached_file, logging
 from .auto_factory import _LazyAutoMapping
 from .configuration_auto import (
     CONFIG_MAPPING_NAMES,
@@ -254,15 +254,21 @@ class AutoProcessor:
         processor_auto_map = None

         # First, let's see if we have a processor or preprocessor config.
-        # Filter the kwargs for `get_file_from_repo`.
-        get_file_from_repo_kwargs = {
-            key: kwargs[key] for key in inspect.signature(get_file_from_repo).parameters.keys() if key in kwargs
+        # Filter the kwargs for `cached_file`.
+        cached_file_kwargs = {
+            key: kwargs[key] for key in inspect.signature(cached_file).parameters.keys() if key in kwargs
         }
+        # We don't want to raise
+        cached_file_kwargs.update(
+            {
+                "_raise_exceptions_for_gated_repo": False,
+                "_raise_exceptions_for_missing_entries": False,
+                "_raise_exceptions_for_connection_errors": False,
+            }
+        )

         # Let's start by checking whether the processor class is saved in a processor config
-        processor_config_file = get_file_from_repo(
-            pretrained_model_name_or_path, PROCESSOR_NAME, **get_file_from_repo_kwargs
-        )
+        processor_config_file = cached_file(pretrained_model_name_or_path, PROCESSOR_NAME, **cached_file_kwargs)
         if processor_config_file is not None:
             config_dict, _ = ProcessorMixin.get_processor_dict(pretrained_model_name_or_path, **kwargs)
             processor_class = config_dict.get("processor_class", None)
@@ -271,8 +277,8 @@ class AutoProcessor:

         if processor_class is None:
             # If not found, let's check whether the processor class is saved in an image processor config
-            preprocessor_config_file = get_file_from_repo(
-                pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, **get_file_from_repo_kwargs
+            preprocessor_config_file = cached_file(
+                pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, **cached_file_kwargs
             )
             if preprocessor_config_file is not None:
                 config_dict, _ = ImageProcessingMixin.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
@@ -291,8 +297,8 @@ class AutoProcessor:

         if processor_class is None:
             # Next, let's check whether the processor class is saved in a tokenizer
-            tokenizer_config_file = get_file_from_repo(
-                pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE, **get_file_from_repo_kwargs
+            tokenizer_config_file = cached_file(
+                pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE, **cached_file_kwargs
             )
             if tokenizer_config_file is not None:
                 with open(tokenizer_config_file, encoding="utf-8") as reader:
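The kwargs filtering above relies on `inspect.signature` to forward only the arguments the target function accepts. A standalone sketch of the idiom (`filter_kwargs_for` and `example` are illustrative names, not library code):

```python
import inspect

def filter_kwargs_for(func, kwargs):
    # Keep only the entries of `kwargs` that `func` actually accepts.
    accepted = inspect.signature(func).parameters.keys()
    return {key: kwargs[key] for key in accepted if key in kwargs}

def example(a, b=1):
    return a + b

print(filter_kwargs_for(example, {"a": 2, "b": 3, "unrelated": 4}))  # {'a': 2, 'b': 3}
```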
@@ -25,7 +25,7 @@ import numpy as np
 from ...feature_extraction_utils import BatchFeature
 from ...processing_utils import ProcessorMixin
 from ...utils import logging
-from ...utils.hub import get_file_from_repo
+from ...utils.hub import cached_file
 from ..auto import AutoTokenizer


@@ -86,7 +86,7 @@ class BarkProcessor(ProcessorMixin):
         """

         if speaker_embeddings_dict_path is not None:
-            speaker_embeddings_path = get_file_from_repo(
+            speaker_embeddings_path = cached_file(
                 pretrained_processor_name_or_path,
                 speaker_embeddings_dict_path,
                 subfolder=kwargs.pop("subfolder", None),
@@ -97,6 +97,9 @@ class BarkProcessor(ProcessorMixin):
                 local_files_only=kwargs.pop("local_files_only", False),
                 token=kwargs.pop("use_auth_token", None),
                 revision=kwargs.pop("revision", None),
+                _raise_exceptions_for_gated_repo=False,
+                _raise_exceptions_for_missing_entries=False,
+                _raise_exceptions_for_connection_errors=False,
             )
             if speaker_embeddings_path is None:
                 logger.warning(
@@ -182,7 +185,7 @@ class BarkProcessor(ProcessorMixin):
                     f"Voice preset unrecognized, missing {key} as a key in self.speaker_embeddings[{voice_preset}]."
                 )

-            path = get_file_from_repo(
+            path = cached_file(
                 self.speaker_embeddings.get("repo_or_path", "/"),
                 voice_preset_paths[key],
                 subfolder=kwargs.pop("subfolder", None),
@@ -193,6 +196,9 @@ class BarkProcessor(ProcessorMixin):
                 local_files_only=kwargs.pop("local_files_only", False),
                 token=kwargs.pop("use_auth_token", None),
                 revision=kwargs.pop("revision", None),
+                _raise_exceptions_for_gated_repo=False,
+                _raise_exceptions_for_missing_entries=False,
+                _raise_exceptions_for_connection_errors=False,
             )
             if path is None:
                 raise ValueError(
@@ -544,7 +544,7 @@ class CvtPreTrainedModel(PreTrainedModel):
         elif isinstance(module, CvtStage):
             if self.config.cls_token[module.stage]:
                 module.cls_token.data = nn.init.trunc_normal_(
-                    torch.zeros(1, 1, self.config.embed_dim[-1]), mean=0.0, std=self.config.initializer_range
+                    module.cls_token.data, mean=0.0, std=self.config.initializer_range
                 )


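The fix initializes the existing `cls_token` tensor in place instead of a freshly allocated `torch.zeros(...)` whose shape comes from config and can disagree with the parameter (here, `embed_dim[-1]` versus the stage's own embedding size). A small sketch of the in-place pattern (the 384 dimension and 0.02 std are illustrative values):

```python
import torch
import torch.nn as nn

cls_token = nn.Parameter(torch.zeros(1, 1, 384))

# trunc_normal_ fills its first argument in place and returns it, so the
# parameter keeps its own shape and storage.
cls_token.data = nn.init.trunc_normal_(cls_token.data, mean=0.0, std=0.02)
print(cls_token.shape)  # torch.Size([1, 1, 384])
```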
@@ -35,7 +35,7 @@ from torch import Tensor
 from vissl.models.model_helpers import get_trunk_forward_outputs

 from transformers import AutoImageProcessor, RegNetConfig, RegNetForImageClassification, RegNetModel
-from transformers.modeling_utils import PreTrainedModel
+from transformers.modeling_utils import _load_state_dict_into_meta_model, load_state_dict
 from transformers.utils import logging


@@ -244,14 +244,18 @@ def convert_weights_and_push(save_directory: Path, model_name: str = None, push_to_hub: bool = True):
     our_model_func = RegNetModel
     if "in1k" in model_name:
         our_model_func = RegNetForImageClassification
-    our_model = our_model_func(our_config)
-    # place our model to the meta device (so remove all the weights)
-    our_model.to(torch.device("meta"))
+    with torch.device("meta"):
+        our_model = our_model_func(our_config)
     logger.info("Loading state_dict in our model.")
     # load state dict
     state_dict_keys = our_model.state_dict().keys()
-    PreTrainedModel._load_pretrained_model_low_mem(
-        our_model, state_dict_keys, [save_directory / f"{model_name}.pth"]
+    state_dict = load_state_dict(save_directory / f"{model_name}.pth", weights_only=True)
+    fixed_state_dict = {our_model._fix_state_dict_key_on_load(k)[0]: v for k, v in state_dict.items()}
+    _load_state_dict_into_meta_model(
+        our_model,
+        fixed_state_dict,
+        start_prefix="",
+        expected_keys=state_dict_keys,
     )
     logger.info("Finally, pushing!")
     # push it to hub
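The conversion script now builds the model directly under the meta device (no weight memory is allocated) and only then materializes real tensors from the checkpoint. The general pattern can be sketched in isolation; `TinyModel` is a stand-in, and `assign=True` requires a reasonably recent PyTorch (2.1+):

```python
import torch
import torch.nn as nn

class TinyModel(nn.Module):  # stand-in for the converted model
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

# Parameters created under the meta device have no storage...
with torch.device("meta"):
    model = TinyModel()
assert model.linear.weight.is_meta

# ...so loading must assign real tensors rather than copy into meta ones.
state_dict = {"linear.weight": torch.randn(8, 8), "linear.bias": torch.randn(8)}
model.load_state_dict(state_dict, assign=True)
assert not model.linear.weight.is_meta
```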
@@ -113,7 +113,7 @@ class TimmWrapperPreTrainedModel(PreTrainedModel):
         Override original method to fix state_dict keys on load for cases when weights are loaded
         without using the `from_pretrained` method (e.g., in Trainer to resume from checkpoint).
         """
-        state_dict = self._fix_state_dict_keys_on_load(state_dict)
+        state_dict = {self._fix_state_dict_key_on_load(k)[0]: v for k, v in state_dict.items()}
         return super().load_state_dict(state_dict, *args, **kwargs)

     def _init_weights(self, module):
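The singular `_fix_state_dict_key_on_load` hook maps one key at a time and, as used above, returns a tuple whose first element is the new key, hence the `[0]` in the dict comprehension. An illustrative sketch of that contract (the rename rule below is made up for the example):

```python
def _fix_state_dict_key_on_load(key):  # illustrative stand-in for the model hook
    new_key = key.replace("module.", "")  # hypothetical rename rule
    return new_key, new_key != key  # (new key, whether a rename happened)

state_dict = {"module.head.weight": 0, "head.bias": 1}
fixed = {_fix_state_dict_key_on_load(k)[0]: v for k, v in state_dict.items()}
print(fixed)  # {'head.weight': 0, 'head.bias': 1}
```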
@@ -91,7 +91,6 @@ from .hub import (
     define_sagemaker_information,
     download_url,
     extract_commit_hash,
-    get_file_from_repo,
     has_file,
     http_user_agent,
     is_offline_mode,
@@ -40,6 +40,7 @@ from huggingface_hub import (
     create_repo,
     hf_hub_download,
     hf_hub_url,
+    snapshot_download,
     try_to_load_from_cache,
 )
 from huggingface_hub.file_download import REGEX_COMMIT_HASH, http_get
@@ -47,7 +48,6 @@ from huggingface_hub.utils import (
     EntryNotFoundError,
     GatedRepoError,
     HfHubHTTPError,
-    HFValidationError,
     LocalEntryNotFoundError,
     OfflineModeIsEnabled,
     RepositoryNotFoundError,
@@ -69,7 +69,6 @@ from .import_utils import (
     is_torch_available,
     is_training_run_on_sagemaker,
 )
-from .logging import tqdm


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -209,21 +208,7 @@ def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]):
 def cached_file(
     path_or_repo_id: Union[str, os.PathLike],
     filename: str,
-    cache_dir: Optional[Union[str, os.PathLike]] = None,
-    force_download: bool = False,
-    resume_download: Optional[bool] = None,
-    proxies: Optional[Dict[str, str]] = None,
-    token: Optional[Union[bool, str]] = None,
-    revision: Optional[str] = None,
-    local_files_only: bool = False,
-    subfolder: str = "",
-    repo_type: Optional[str] = None,
-    user_agent: Optional[Union[str, Dict[str, str]]] = None,
-    _raise_exceptions_for_gated_repo: bool = True,
-    _raise_exceptions_for_missing_entries: bool = True,
-    _raise_exceptions_for_connection_errors: bool = True,
-    _commit_hash: Optional[str] = None,
-    **deprecated_kwargs,
+    **kwargs,
 ) -> Optional[str]:
     """
     Tries to locate a file in a local folder and repo, downloads and cache it if necessary.
@@ -231,7 +216,6 @@ def cached_file(
     Args:
         path_or_repo_id (`str` or `os.PathLike`):
             This can be either:
-
             - a string, the *model id* of a model repo on huggingface.co.
             - a path to a *directory* potentially containing the file.
         filename (`str`):
@@ -274,6 +258,94 @@ def cached_file(
     Examples:

     ```python
     # Download a model weight from the Hub and cache it.
     model_weights_file = cached_file("google-bert/bert-base-uncased", "pytorch_model.bin")
     ```
+    """
+    file = cached_files(path_or_repo_id=path_or_repo_id, filenames=[filename], **kwargs)
+    file = file[0] if file is not None else file
+    return file
+
+
+def cached_files(
+    path_or_repo_id: Union[str, os.PathLike],
+    filenames: List[str],
+    cache_dir: Optional[Union[str, os.PathLike]] = None,
+    force_download: bool = False,
+    resume_download: Optional[bool] = None,
+    proxies: Optional[Dict[str, str]] = None,
+    token: Optional[Union[bool, str]] = None,
+    revision: Optional[str] = None,
+    local_files_only: bool = False,
+    subfolder: str = "",
+    repo_type: Optional[str] = None,
+    user_agent: Optional[Union[str, Dict[str, str]]] = None,
+    _raise_exceptions_for_gated_repo: bool = True,
+    _raise_exceptions_for_missing_entries: bool = True,
+    _raise_exceptions_for_connection_errors: bool = True,
+    _commit_hash: Optional[str] = None,
+    **deprecated_kwargs,
+) -> Optional[str]:
+    """
+    Tries to locate several files in a local folder and repo, downloads and cache them if necessary.
+
+    Args:
+        path_or_repo_id (`str` or `os.PathLike`):
+            This can be either:
+            - a string, the *model id* of a model repo on huggingface.co.
+            - a path to a *directory* potentially containing the file.
+        filenames (`List[str]`):
+            The name of all the files to locate in `path_or_repo`.
+        cache_dir (`str` or `os.PathLike`, *optional*):
+            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
+            cache should not be used.
+        force_download (`bool`, *optional*, defaults to `False`):
+            Whether or not to force to (re-)download the configuration files and override the cached versions if they
+            exist.
+        resume_download:
+            Deprecated and ignored. All downloads are now resumed by default when possible.
+            Will be removed in v5 of Transformers.
+        proxies (`Dict[str, str]`, *optional*):
+            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+        token (`str` or *bool*, *optional*):
+            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+            when running `huggingface-cli login` (stored in `~/.huggingface`).
+        revision (`str`, *optional*, defaults to `"main"`):
+            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+            identifier allowed by git.
+        local_files_only (`bool`, *optional*, defaults to `False`):
+            If `True`, will only try to load the tokenizer configuration from local files.
+        subfolder (`str`, *optional*, defaults to `""`):
+            In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
+            specify the folder name here.
+        repo_type (`str`, *optional*):
+            Specify the repo type (useful when downloading from a space for instance).
+
+    Private args:
+        _raise_exceptions_for_gated_repo (`bool`):
+            if False, do not raise an exception for gated repo error but return None.
+        _raise_exceptions_for_missing_entries (`bool`):
+            if False, do not raise an exception for missing entries but return None.
+        _raise_exceptions_for_connection_errors (`bool`):
+            if False, do not raise an exception for connection errors but return None.
+        _commit_hash (`str`, *optional*):
+            passed when we are chaining several calls to various files (e.g. when loading a tokenizer or
+            a pipeline). If files are cached for this commit hash, avoid calls to head and get from the cache.
+
+    <Tip>
+
+    Passing `token=True` is required when you want to use a private model.
+
+    </Tip>
+
+    Returns:
+        `Optional[str]`: Returns the resolved file (to the cache folder if downloaded from a repo).
+
+    Examples:
+
+    ```python
+    # Download a model weight from the Hub and cache it.
+    model_weights_file = cached_file("google-bert/bert-base-uncased", "pytorch_model.bin")
+    ```
+    """
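Since `cached_files` is the new plural entry point, a short usage sketch (a real public repo; the resolved paths depend on the local cache):

```python
from transformers.utils.hub import cached_files

# Resolve several files from one repo in a single call; multi-file requests
# are served by `snapshot_download` in the implementation below.
files = cached_files("google-bert/bert-base-uncased", ["config.json", "tokenizer.json"])
print(files)  # ['.../config.json', '.../tokenizer.json'] inside the HF cache
```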
@@ -289,144 +361,176 @@ def cached_file(
             raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
         token = use_auth_token

-    # Private arguments
-    #     _raise_exceptions_for_gated_repo: if False, do not raise an exception for gated repo error but return
-    #     None.
-    #     _raise_exceptions_for_missing_entries: if False, do not raise an exception for missing entries but return
-    #     None.
-    #     _raise_exceptions_for_connection_errors: if False, do not raise an exception for connection errors but return
-    #     None.
-    #     _commit_hash: passed when we are chaining several calls to various files (e.g. when loading a tokenizer or
-    #     a pipeline). If files are cached for this commit hash, avoid calls to head and get from the cache.
     if is_offline_mode() and not local_files_only:
         logger.info("Offline mode: forcing local_files_only=True")
         local_files_only = True
     if subfolder is None:
         subfolder = ""

+    # Add folder to filenames
+    full_filenames = [os.path.join(subfolder, file) for file in filenames]
+
     path_or_repo_id = str(path_or_repo_id)
-    full_filename = os.path.join(subfolder, filename)
-    if os.path.isdir(path_or_repo_id):
-        resolved_file = os.path.join(os.path.join(path_or_repo_id, subfolder), filename)
-        if not os.path.isfile(resolved_file):
-            if _raise_exceptions_for_missing_entries and filename not in ["config.json", f"{subfolder}/config.json"]:
-                raise EnvironmentError(
-                    f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout "
-                    f"'https://huggingface.co/{path_or_repo_id}/tree/{revision}' for available files."
-                )
-            else:
-                return None
-        return resolved_file
+    existing_files = []
+    for filename in full_filenames:
+        if os.path.isdir(path_or_repo_id):
+            resolved_file = os.path.join(path_or_repo_id, filename)
+            if not os.path.isfile(resolved_file):
+                if _raise_exceptions_for_missing_entries and filename != os.path.join(subfolder, "config.json"):
+                    revision_ = "main" if revision is None else revision
+                    raise EnvironmentError(
+                        f"{path_or_repo_id} does not appear to have a file named {filename}. Checkout "
+                        f"'https://huggingface.co/{path_or_repo_id}/tree/{revision_}' for available files."
+                    )
+                else:
+                    return None
+            existing_files.append(resolved_file)
+
+    # All files exist
+    if len(existing_files) == len(full_filenames):
+        return existing_files

     if cache_dir is None:
         cache_dir = TRANSFORMERS_CACHE
     if isinstance(cache_dir, Path):
         cache_dir = str(cache_dir)

+    existing_files = []
+    file_counter = 0
     if _commit_hash is not None and not force_download:
-        # If the file is cached under that commit hash, we return it directly.
-        resolved_file = try_to_load_from_cache(
-            path_or_repo_id, full_filename, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
-        )
-        if resolved_file is not None:
-            if resolved_file is not _CACHED_NO_EXIST:
-                return resolved_file
-            elif not _raise_exceptions_for_missing_entries:
-                return None
-            else:
-                raise EnvironmentError(f"Could not locate {full_filename} inside {path_or_repo_id}.")
+        for filename in full_filenames:
+            # If the file is cached under that commit hash, we return it directly.
+            resolved_file = try_to_load_from_cache(
+                path_or_repo_id, filename, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
+            )
+            if resolved_file is not None:
+                if resolved_file is not _CACHED_NO_EXIST:
+                    file_counter += 1
+                    existing_files.append(resolved_file)
+                elif not _raise_exceptions_for_missing_entries:
+                    file_counter += 1
+                else:
+                    raise EnvironmentError(f"Could not locate {filename} inside {path_or_repo_id}.")
+
+    # Either all the files were found, or some were _CACHED_NO_EXIST but we do not raise for missing entries
+    if file_counter == len(full_filenames):
+        return existing_files if len(existing_files) > 0 else None

     user_agent = http_user_agent(user_agent)
+    # download the files if needed
     try:
-        # Load from URL or cache if already cached
-        resolved_file = hf_hub_download(
-            path_or_repo_id,
-            filename,
-            subfolder=None if len(subfolder) == 0 else subfolder,
-            repo_type=repo_type,
-            revision=revision,
-            cache_dir=cache_dir,
-            user_agent=user_agent,
-            force_download=force_download,
-            proxies=proxies,
-            resume_download=resume_download,
-            token=token,
-            local_files_only=local_files_only,
-        )
+        if len(full_filenames) == 1:
+            # This is slightly better for only 1 file
+            hf_hub_download(
+                path_or_repo_id,
+                filenames[0],
+                subfolder=None if len(subfolder) == 0 else subfolder,
+                repo_type=repo_type,
+                revision=revision,
+                cache_dir=cache_dir,
+                user_agent=user_agent,
+                force_download=force_download,
+                proxies=proxies,
+                resume_download=resume_download,
+                token=token,
+                local_files_only=local_files_only,
+            )
+        else:
+            snapshot_download(
+                path_or_repo_id,
+                allow_patterns=full_filenames,
+                repo_type=repo_type,
+                revision=revision,
+                cache_dir=cache_dir,
+                user_agent=user_agent,
+                force_download=force_download,
+                proxies=proxies,
+                resume_download=resume_download,
+                token=token,
+                local_files_only=local_files_only,
+            )
+
+    except Exception as e:
+        # We cannot recover from them
+        if isinstance(e, RepositoryNotFoundError) and not isinstance(e, GatedRepoError):
+            raise EnvironmentError(
+                f"{path_or_repo_id} is not a local folder and is not a valid model identifier "
+                "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token "
+                "having permission to this repo either by logging in with `huggingface-cli login` or by passing "
+                "`token=<your_token>`"
+            ) from e
+        elif isinstance(e, RevisionNotFoundError):
+            raise EnvironmentError(
+                f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
+                "for this model name. Check the model page at "
+                f"'https://huggingface.co/{path_or_repo_id}' for available revisions."
+            ) from e
+
+        # Now we try to recover if we can find all files correctly in the cache
+        resolved_files = [
+            _get_cache_file_to_return(path_or_repo_id, filename, cache_dir, revision) for filename in full_filenames
+        ]
+        if all(file is not None for file in resolved_files):
+            return resolved_files
+
+        # Raise based on the flags. Note that we will raise for missing entries at the very end, even when
+        # not entering this Except block, as it may also happen when `snapshot_download` does not raise
+        if isinstance(e, GatedRepoError):
+            if not _raise_exceptions_for_gated_repo:
+                return None
+            raise EnvironmentError(
+                "You are trying to access a gated repo.\nMake sure to have access to it at "
+                f"https://huggingface.co/{path_or_repo_id}.\n{str(e)}"
+            ) from e
+        elif isinstance(e, LocalEntryNotFoundError):
+            if not _raise_exceptions_for_connection_errors:
+                return None
+            # Here we only raise if both flags for missing entry and connection errors are True (because it can be raised
+            # even when `local_files_only` is True, in which case raising for connections errors only would not make sense)
+            elif _raise_exceptions_for_missing_entries:
+                raise EnvironmentError(
+                    f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load the files, and couldn't find them in the"
+                    f" cached files.\nCheckout your internet connection or see how to run the library in offline mode at"
+                    " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
+                ) from e
+        # snapshot_download will not raise EntryNotFoundError, but hf_hub_download can. If this is the case, it will be treated
+        # later on anyway and re-raised if needed
+        elif isinstance(e, HTTPError) and not isinstance(e, EntryNotFoundError):
+            if not _raise_exceptions_for_connection_errors:
+                return None
+            raise EnvironmentError(
+                f"There was a specific connection error when trying to load {path_or_repo_id}:\n{e}"
+            )
+
+    resolved_files = [
+        _get_cache_file_to_return(path_or_repo_id, filename, cache_dir, revision) for filename in full_filenames
+    ]
+    # If there are any missing file and the flag is active, raise
+    if any(file is None for file in resolved_files) and _raise_exceptions_for_missing_entries:
+        missing_entries = [original for original, resolved in zip(full_filenames, resolved_files) if resolved is None]
+        # Last escape
+        if len(resolved_files) == 1 and missing_entries[0] == os.path.join(subfolder, "config.json"):
+            return None
+        # Now we raise for missing entries
+        revision_ = "main" if revision is None else revision
+        msg = f"a file named {missing_entries[0]}" if len(missing_entries) == 1 else f"files named {*missing_entries,}"
+        raise EnvironmentError(
+            f"{path_or_repo_id} does not appear to have {msg}. Checkout 'https://huggingface.co/{path_or_repo_id}/tree/{revision_}' "
+            "for available files."
+        )
-    except GatedRepoError as e:
-        resolved_file = _get_cache_file_to_return(path_or_repo_id, full_filename, cache_dir, revision)
-        if resolved_file is not None or not _raise_exceptions_for_gated_repo:
-            return resolved_file
-        raise EnvironmentError(
-            "You are trying to access a gated repo.\nMake sure to have access to it at "
-            f"https://huggingface.co/{path_or_repo_id}.\n{str(e)}"
-        ) from e
-    except RepositoryNotFoundError as e:
-        raise EnvironmentError(
-            f"{path_or_repo_id} is not a local folder and is not a valid model identifier "
-            "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token "
-            "having permission to this repo either by logging in with `huggingface-cli login` or by passing "
-            "`token=<your_token>`"
-        ) from e
-    except RevisionNotFoundError as e:
-        raise EnvironmentError(
-            f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
-            "for this model name. Check the model page at "
-            f"'https://huggingface.co/{path_or_repo_id}' for available revisions."
-        ) from e
-    except LocalEntryNotFoundError as e:
-        resolved_file = _get_cache_file_to_return(path_or_repo_id, full_filename, cache_dir, revision)
-        if (
-            resolved_file is not None
-            or not _raise_exceptions_for_missing_entries
-            or not _raise_exceptions_for_connection_errors
-        ):
-            return resolved_file
-        raise EnvironmentError(
-            f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this file, couldn't find it in the"
-            f" cached files and it looks like {path_or_repo_id} is not the path to a directory containing a file named"
-            f" {full_filename}.\nCheckout your internet connection or see how to run the library in offline mode at"
-            " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
-        ) from e
-    except EntryNotFoundError as e:
-        if not _raise_exceptions_for_missing_entries:
-            return None
-        if revision is None:
-            revision = "main"
-        if filename in ["config.json", f"{subfolder}/config.json"]:
-            return None
-        raise EnvironmentError(
-            f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout "
-            f"'https://huggingface.co/{path_or_repo_id}/tree/{revision}' for available files."
-        ) from e
-    except HTTPError as err:
-        resolved_file = _get_cache_file_to_return(path_or_repo_id, full_filename, cache_dir, revision)
-        if resolved_file is not None or not _raise_exceptions_for_connection_errors:
-            return resolved_file
-        raise EnvironmentError(f"There was a specific connection error when trying to load {path_or_repo_id}:\n{err}")
-    except HFValidationError as e:
-        raise EnvironmentError(
-            f"Incorrect path_or_model_id: '{path_or_repo_id}'. Please provide either the path to a local folder or the repo_id of a model on the Hub."
-        ) from e
-    return resolved_file
+
+    # Remove potential missing entries (we can silently remove them at this point based on the flags)
+    resolved_files = [file for file in resolved_files if file is not None]
+    # Return `None` if the list is empty, coherent with other Exception when the flag is not active
+    resolved_files = None if len(resolved_files) == 0 else resolved_files
+
+    return resolved_files
-# TODO: deprecate `get_file_from_repo` or document it differently?
-# Docstring is exactly the same as `cached_repo` but behavior is slightly different. If file is missing or if
-# there is a connection error, `cached_repo` will return None while `get_file_from_repo` will raise an error.
-# IMO we should keep only 1 method and have a single `raise_error` argument (to be discussed).
+# TODO cyril: Deprecated and should be removed in 4.51
 def get_file_from_repo(
     path_or_repo: Union[str, os.PathLike],
     filename: str,
-    cache_dir: Optional[Union[str, os.PathLike]] = None,
-    force_download: bool = False,
-    resume_download: Optional[bool] = None,
-    proxies: Optional[Dict[str, str]] = None,
-    token: Optional[Union[bool, str]] = None,
-    revision: Optional[str] = None,
-    local_files_only: bool = False,
-    subfolder: str = "",
-    **deprecated_kwargs,
+    *args,
+    **kwargs,
 ):
     """
     Tries to locate a file in a local folder and repo, downloads and cache it if necessary.
@@ -483,30 +587,15 @@ def get_file_from_repo(
     tokenizer_config = get_file_from_repo("FacebookAI/xlm-roberta-base", "tokenizer_config.json")
     ```
     """
-    use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
-    if use_auth_token is not None:
-        warnings.warn(
-            "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
-            FutureWarning,
-        )
-        if token is not None:
-            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
-        token = use_auth_token
-
+    logger.warning(
+        "`get_file_from_repo` is deprecated and will be removed in version 4.51. Use `cached_file` instead."
+    )
     return cached_file(
         path_or_repo_id=path_or_repo,
         filename=filename,
-        cache_dir=cache_dir,
-        force_download=force_download,
-        resume_download=resume_download,
-        proxies=proxies,
-        token=token,
-        revision=revision,
-        local_files_only=local_files_only,
-        subfolder=subfolder,
+        *args,
         _raise_exceptions_for_gated_repo=False,
         _raise_exceptions_for_missing_entries=False,
         _raise_exceptions_for_connection_errors=False,
+        **kwargs,
     )
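For code still calling the deprecated helper, the equivalent `cached_file` call is exactly what the shim above forwards to:

```python
from transformers.utils import cached_file

# Old (now warns): get_file_from_repo("FacebookAI/xlm-roberta-base", "tokenizer_config.json")
# New equivalent: returns None instead of raising when the file cannot be fetched.
tokenizer_config = cached_file(
    "FacebookAI/xlm-roberta-base",
    "tokenizer_config.json",
    _raise_exceptions_for_gated_repo=False,
    _raise_exceptions_for_missing_entries=False,
    _raise_exceptions_for_connection_errors=False,
)
```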
@@ -1023,45 +1112,22 @@ def get_checkpoint_shard_files(
         shard_filenames = [os.path.join(pretrained_model_name_or_path, subfolder, f) for f in shard_filenames]
         return shard_filenames, sharded_metadata

-    # At this stage pretrained_model_name_or_path is a model identifier on the Hub
-    cached_filenames = []
-    # Check if the model is already cached or not. We only try the last checkpoint, this should cover most cases of
-    # downloaded (if interrupted).
-    last_shard = try_to_load_from_cache(
-        pretrained_model_name_or_path, shard_filenames[-1], cache_dir=cache_dir, revision=_commit_hash
+    # At this stage pretrained_model_name_or_path is a model identifier on the Hub. Try to get everything from cache,
+    # or download the files
+    cached_filenames = cached_files(
+        pretrained_model_name_or_path,
+        shard_filenames,
+        cache_dir=cache_dir,
+        force_download=force_download,
+        proxies=proxies,
+        resume_download=resume_download,
+        local_files_only=local_files_only,
+        token=token,
+        user_agent=user_agent,
+        revision=revision,
+        subfolder=subfolder,
+        _commit_hash=_commit_hash,
     )
-    show_progress_bar = last_shard is None or force_download
-    for shard_filename in tqdm(shard_filenames, desc="Downloading shards", disable=not show_progress_bar):
-        try:
-            # Load from URL
-            cached_filename = cached_file(
-                pretrained_model_name_or_path,
-                shard_filename,
-                cache_dir=cache_dir,
-                force_download=force_download,
-                proxies=proxies,
-                resume_download=resume_download,
-                local_files_only=local_files_only,
-                token=token,
-                user_agent=user_agent,
-                revision=revision,
-                subfolder=subfolder,
-                _commit_hash=_commit_hash,
-            )
-            # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
-            # we don't have to catch them here.
-        except EntryNotFoundError:
-            raise EnvironmentError(
-                f"{pretrained_model_name_or_path} does not appear to have a file named {shard_filename} which is "
-                "required according to the checkpoint index."
-            )
-        except HTTPError:
-            raise EnvironmentError(
-                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {shard_filename}. You should try"
-                " again after checking your internet connection."
-            )
-
-        cached_filenames.append(cached_filename)

     return cached_filenames, sharded_metadata
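Fetching all shards in one `cached_files` call replaces the per-shard loop (and its tqdm bar). A sketch, with a hypothetical repo id and illustrative shard names as they would come out of `model.safetensors.index.json`:

```python
from transformers.utils.hub import cached_files

shard_filenames = [  # illustrative; read from the checkpoint index in practice
    "model-00001-of-00002.safetensors",
    "model-00002-of-00002.safetensors",
]
resolved = cached_files("some-org/some-sharded-model", shard_filenames)  # hypothetical repo
```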
@@ -2368,10 +2368,9 @@ class ModelTesterMixin:
             safe_save_file(placeholder_dict, os.path.join(tmp_dir, "model.safetensors"), metadata={"format": "pt"})
             model_reloaded, infos = model_class.from_pretrained(tmp_dir, output_loading_info=True)

-            prefix = f"{model_reloaded.base_model_prefix}."
             params = dict(model_reloaded.named_parameters())
             params.update(dict(model_reloaded.named_buffers()))
-            param_names = {k[len(prefix) :] if k.startswith(prefix) else k for k in params.keys()}
+            param_names = set(params.keys())

             missing_keys = set(infos["missing_keys"])

@@ -2383,9 +2382,8 @@ class ModelTesterMixin:
                 ptrs[id_tensor_storage(tensor)].append(name)
             tied_params = [names for _, names in ptrs.items() if len(names) > 1]
             for group in tied_params:
-                group = {k[len(prefix) :] if k.startswith(prefix) else k for k in group}
                 # We remove the group from extra_missing if not all weights from group are in it
-                if len(group - extra_missing) > 0:
+                if len(set(group) - extra_missing) > 0:
                     extra_missing = extra_missing - set(group)

             self.assertEqual(
@@ -2399,15 +2397,14 @@ class ModelTesterMixin:
             # Remove nonpersistent buffers from missed_missing
             buffers = [n for n, _ in model_reloaded.named_buffers()]
             nonpersistent_buffers = {n for n in buffers if n not in model_reloaded.state_dict()}
-            nonpersistent_buffers = {
-                k[len(prefix) :] if k.startswith(prefix) else k for k in nonpersistent_buffers
-            }
             missed_missing = missed_missing - nonpersistent_buffers

-            if model_reloaded._keys_to_ignore_on_load_missing is None:
-                expected_missing = set()
-            else:
-                expected_missing = set(model_reloaded._keys_to_ignore_on_load_missing)
+            expected_missing = set()
+            for pattern in model_reloaded._keys_to_ignore_on_load_missing:
+                expected_missing.update({k for k in param_names if re.search(pattern, k) is not None})
             self.assertEqual(
                 missed_missing,
                 expected_missing,
@@ -28,7 +28,6 @@ from transformers.utils import (
     TRANSFORMERS_CACHE,
     WEIGHTS_NAME,
     cached_file,
-    get_file_from_repo,
     has_file,
 )

@@ -87,14 +86,8 @@ class GetFromCacheTests(unittest.TestCase):
         path = cached_file(RANDOM_BERT, "conf", local_files_only=True, _raise_exceptions_for_missing_entries=False)
         self.assertIsNone(path)

-        response_mock = mock.Mock()
-        response_mock.status_code = 500
-        response_mock.headers = {}
-        response_mock.raise_for_status.side_effect = HTTPError
-        response_mock.json.return_value = {}
-
-        # Under the mock environment we get a 500 error when trying to reach the tokenizer.
-        with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
+        # Under the mock environment, hf_hub_download will always raise an HTTPError
+        with mock.patch("transformers.utils.hub.hf_hub_download", side_effect=HTTPError) as mock_head:
             path = cached_file(RANDOM_BERT, "conf", _raise_exceptions_for_connection_errors=False)
             self.assertIsNone(path)
         # This check we did call the fake head request
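With downloads delegated to `hf_hub_download`/`snapshot_download`, patching the former is enough to simulate a connection failure, whatever HTTP layer runs underneath. A sketch of the strategy outside the test class (assumes a clean cache, so no fallback file is found):

```python
import unittest.mock as mock

from requests.exceptions import HTTPError

from transformers.utils import cached_file

with mock.patch("transformers.utils.hub.hf_hub_download", side_effect=HTTPError):
    path = cached_file(
        "google-bert/bert-base-uncased",
        "config.json",
        _raise_exceptions_for_connection_errors=False,
    )
assert path is None  # degraded to None instead of raising
```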
@@ -117,18 +110,45 @@ class GetFromCacheTests(unittest.TestCase):
         assert has_file(TINY_BERT_PT_ONLY, WEIGHTS_NAME, local_files_only=True, cache_dir=tmp_dir)

     def test_get_file_from_repo_distant(self):
-        # `get_file_from_repo` returns None if the file does not exist
-        self.assertIsNone(get_file_from_repo("google-bert/bert-base-cased", "ahah.txt"))
+        # should return None if the file does not exist
+        self.assertIsNone(
+            cached_file(
+                "google-bert/bert-base-cased",
+                "ahah.txt",
+                _raise_exceptions_for_gated_repo=False,
+                _raise_exceptions_for_missing_entries=False,
+                _raise_exceptions_for_connection_errors=False,
+            )
+        )

         # The function raises if the repository does not exist.
         with self.assertRaisesRegex(EnvironmentError, "is not a valid model identifier"):
-            get_file_from_repo("bert-base-case", CONFIG_NAME)
+            cached_file(
+                "bert-base-case",
+                CONFIG_NAME,
+                _raise_exceptions_for_gated_repo=False,
+                _raise_exceptions_for_missing_entries=False,
+                _raise_exceptions_for_connection_errors=False,
+            )

         # The function raises if the revision does not exist.
         with self.assertRaisesRegex(EnvironmentError, "is not a valid git identifier"):
-            get_file_from_repo("google-bert/bert-base-cased", CONFIG_NAME, revision="ahaha")
+            cached_file(
+                "google-bert/bert-base-cased",
+                CONFIG_NAME,
+                revision="ahaha",
+                _raise_exceptions_for_gated_repo=False,
+                _raise_exceptions_for_missing_entries=False,
+                _raise_exceptions_for_connection_errors=False,
+            )

-        resolved_file = get_file_from_repo("google-bert/bert-base-cased", CONFIG_NAME)
+        resolved_file = cached_file(
+            "google-bert/bert-base-cased",
+            CONFIG_NAME,
+            _raise_exceptions_for_gated_repo=False,
+            _raise_exceptions_for_missing_entries=False,
+            _raise_exceptions_for_connection_errors=False,
+        )
         # The name is the cached name which is not very easy to test, so instead we load the content.
         config = json.loads(open(resolved_file, "r").read())
         self.assertEqual(config["hidden_size"], 768)
@@ -137,9 +157,26 @@ class GetFromCacheTests(unittest.TestCase):
         with tempfile.TemporaryDirectory() as tmp_dir:
             filename = Path(tmp_dir) / "a.txt"
             filename.touch()
-            self.assertEqual(get_file_from_repo(tmp_dir, "a.txt"), str(filename))
+            self.assertEqual(
+                cached_file(
+                    tmp_dir,
+                    "a.txt",
+                    _raise_exceptions_for_gated_repo=False,
+                    _raise_exceptions_for_missing_entries=False,
+                    _raise_exceptions_for_connection_errors=False,
+                ),
+                str(filename),
+            )

-            self.assertIsNone(get_file_from_repo(tmp_dir, "b.txt"))
+            self.assertIsNone(
+                cached_file(
+                    tmp_dir,
+                    "b.txt",
+                    _raise_exceptions_for_gated_repo=False,
+                    _raise_exceptions_for_missing_entries=False,
+                    _raise_exceptions_for_connection_errors=False,
+                )
+            )

     def test_get_file_gated_repo(self):
         """Test download file from a gated repo fails with correct message when not authenticated."""
@@ -14,7 +14,6 @@
 # limitations under the License.
 import copy
 import glob
-import itertools
 import json
 import os
 import os.path
@@ -525,13 +524,12 @@ class ModelUtilsTest(TestCasePlus):
         self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
         self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float16)

-        # TODO @ARTHURZUCKER FIX THIS
         # but if the model has `_keep_in_fp32_modules` then those modules should be in fp32 no matter what
-        # LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"]
-        # model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto")
-        # self.assertEqual(model.language_model.dtype, torch.float32)
-        # self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
-        # self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float32)
+        LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"]
+        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto")
+        self.assertEqual(model.language_model.dtype, torch.float32)
+        self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
+        self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float32)

         # torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
         with self.assertRaises(ValueError):
@@ -540,20 +538,6 @@ class ModelUtilsTest(TestCasePlus):
                     TINY_LLAVA, torch_dtype={"text_config": "float32", "vision_config": "int64", "": "float16"}
                 )

-    @require_torch
-    @unittest.skip("Broken by @arthurzucker because the fix was not correct. Knowing the context is super hard")
-    def test_model_from_pretrained_meta_device(self):
-        def is_on_meta(model_id, dtype):
-            with torch.device("meta"):
-                model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype)
-            return all(value.device.type == "meta" for value in model.state_dict().values())
-
-        model_ids = ("fxmarty/tiny-llama-fast-tokenizer", "fxmarty/small-llama-testing")
-        dtypes = (None, "auto", torch.float16)
-
-        for model_id, dtype in itertools.product(model_ids, dtypes):
-            self.assertTrue(is_on_meta(model_id, dtype))
-
     def test_model_from_pretrained_torch_dtype(self):
         # test that the model can be instantiated with dtype of either
         # 1. explicit from_pretrained's torch_dtype argument
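The surrounding test exercises per-sub-config dtypes: the dict form of `torch_dtype` keys off sub-config names, with `""` as the default for everything else. A hedged sketch against a public multimodal checkpoint (any model with `text_config`/`vision_config` sub-configs behaves the same way):

```python
import torch

from transformers import LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",  # illustrative checkpoint
    torch_dtype={"text_config": torch.float32, "vision_config": torch.bfloat16, "": torch.float16},
)
```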