diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 9971ec81faf..cb7a2c2e289 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -51,6 +51,7 @@ from .pytorch_utils import (  # noqa: F401
 from .utils import (
     ADAPTER_SAFE_WEIGHTS_NAME,
     ADAPTER_WEIGHTS_NAME,
+    CONFIG_NAME,
     DUMMY_INPUTS,
     FLAX_WEIGHTS_NAME,
     SAFE_WEIGHTS_INDEX_NAME,
@@ -65,6 +66,7 @@ from .utils import (
     cached_file,
     copy_func,
     download_url,
+    extract_commit_hash,
     has_file,
     is_accelerate_available,
     is_auto_gptq_available,
@@ -2368,13 +2370,39 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                     " ignored."
                 )
 
+        if commit_hash is None:
+            if not isinstance(config, PretrainedConfig):
+                # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
+                resolved_config_file = cached_file(
+                    pretrained_model_name_or_path,
+                    CONFIG_NAME,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    _raise_exceptions_for_missing_entries=False,
+                    _raise_exceptions_for_connection_errors=False,
+                )
+                commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
+            else:
+                commit_hash = getattr(config, "_commit_hash", None)
+
         if is_peft_available() and _adapter_model_path is None:
             maybe_adapter_model_path = find_adapter_config_file(
                 pretrained_model_name_or_path,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                token=token,
                 revision=revision,
                 subfolder=subfolder,
-                token=token,
-                commit_hash=commit_hash,
+                _commit_hash=commit_hash,
             )
         elif is_peft_available() and _adapter_model_path is not None:
             maybe_adapter_model_path = _adapter_model_path
@@ -2622,9 +2650,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                     " `pip install --upgrade bitsandbytes`."
                 )
 
-        if commit_hash is None:
-            commit_hash = getattr(config, "_commit_hash", None)
-
         # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
         # index of the files.
         is_sharded = False
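The hunks above front-load commit-hash resolution in `PreTrainedModel.from_pretrained`: the config file is fetched first, its resolved cache path yields the commit hash, and every later download (adapter config, weights) is pinned to that snapshot via `_commit_hash`. A minimal sketch of that resolution step, using the same public helpers the hunk imports and a hypothetical repo id:

```python
# Sketch only: resolve config.json once, derive the commit hash from the
# resolved cache path, then reuse it to pin subsequent downloads.
from transformers.utils import CONFIG_NAME, cached_file, extract_commit_hash

repo_id = "some-org/some-model"  # hypothetical repo id

resolved_config_file = cached_file(
    repo_id,
    CONFIG_NAME,
    _raise_exceptions_for_missing_entries=False,  # the repo may lack config.json
    _raise_exceptions_for_connection_errors=False,
)
# Returns the sha embedded in the cache path, or None if nothing was resolved.
commit_hash = extract_commit_hash(resolved_config_file, None)
```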
diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py
index 3e7b89f68ff..daca460ebbc 100644
--- a/src/transformers/models/auto/auto_factory.py
+++ b/src/transformers/models/auto/auto_factory.py
@@ -22,7 +22,16 @@ from collections import OrderedDict
 
 from ...configuration_utils import PretrainedConfig
 from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
-from ...utils import copy_func, find_adapter_config_file, is_peft_available, logging, requires_backends
+from ...utils import (
+    CONFIG_NAME,
+    cached_file,
+    copy_func,
+    extract_commit_hash,
+    find_adapter_config_file,
+    is_peft_available,
+    logging,
+    requires_backends,
+)
 from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings
 
 
@@ -443,7 +452,6 @@ class _BaseAutoModelClass:
         kwargs["_from_auto"] = True
         hub_kwargs_names = [
             "cache_dir",
-            "code_revision",
             "force_download",
             "local_files_only",
             "proxies",
@@ -454,6 +462,8 @@ class _BaseAutoModelClass:
             "token",
         ]
         hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
+        code_revision = kwargs.pop("code_revision", None)
+        commit_hash = kwargs.pop("_commit_hash", None)
 
         token = hub_kwargs.pop("token", None)
         use_auth_token = hub_kwargs.pop("use_auth_token", None)
@@ -470,12 +480,23 @@ class _BaseAutoModelClass:
         if token is not None:
             hub_kwargs["token"] = token
 
-        if is_peft_available():
-            revision = kwargs.get("revision", None)
-            subfolder = kwargs.get("subfolder", None)
+        if commit_hash is None:
+            if not isinstance(config, PretrainedConfig):
+                # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
+                resolved_config_file = cached_file(
+                    pretrained_model_name_or_path,
+                    CONFIG_NAME,
+                    _raise_exceptions_for_missing_entries=False,
+                    _raise_exceptions_for_connection_errors=False,
+                    **hub_kwargs,
+                )
+                commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
+            else:
+                commit_hash = getattr(config, "_commit_hash", None)
 
+        if is_peft_available():
             maybe_adapter_path = find_adapter_config_file(
-                pretrained_model_name_or_path, revision=revision, token=token, subfolder=subfolder
+                pretrained_model_name_or_path, _commit_hash=commit_hash, **hub_kwargs
             )
 
             if maybe_adapter_path is not None:
@@ -499,6 +520,8 @@ class _BaseAutoModelClass:
                 pretrained_model_name_or_path,
                 return_unused_kwargs=True,
                 trust_remote_code=trust_remote_code,
+                code_revision=code_revision,
+                _commit_hash=commit_hash,
                 **hub_kwargs,
                 **kwargs,
             )
@@ -517,7 +540,7 @@ class _BaseAutoModelClass:
         if has_remote_code and trust_remote_code:
             class_ref = config.auto_map[cls.__name__]
             model_class = get_class_from_dynamic_module(
-                class_ref, pretrained_model_name_or_path, **hub_kwargs, **kwargs
+                class_ref, pretrained_model_name_or_path, code_revision=code_revision, **hub_kwargs, **kwargs
             )
             _ = hub_kwargs.pop("code_revision", None)
             if os.path.isdir(pretrained_model_name_or_path):
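In `_BaseAutoModelClass.from_pretrained`, `code_revision` is no longer treated as a generic hub kwarg: it is popped explicitly and routed only to `get_class_from_dynamic_module`, while the freshly resolved `_commit_hash` travels with the remaining `hub_kwargs`. A hedged usage sketch (the repo id is hypothetical) of the two revision knobs this separates:

```python
# `revision` pins the repo the weights/config come from; `code_revision` pins
# the git ref the custom modeling code is fetched from with trust_remote_code.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "some-org/model-with-remote-code",  # hypothetical repo using auto_map
    trust_remote_code=True,
    revision="main",
    code_revision="main",
)
```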
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index a345235951d..ba6041f44f6 100755
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -1007,6 +1007,8 @@ class AutoConfig:
         kwargs["_from_auto"] = True
         kwargs["name_or_path"] = pretrained_model_name_or_path
         trust_remote_code = kwargs.pop("trust_remote_code", None)
+        code_revision = kwargs.pop("code_revision", None)
+
         config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
         has_remote_code = "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]
         has_local_code = "model_type" in config_dict and config_dict["model_type"] in CONFIG_MAPPING
@@ -1016,10 +1018,11 @@ class AutoConfig:
 
         if has_remote_code and trust_remote_code:
             class_ref = config_dict["auto_map"]["AutoConfig"]
-            config_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
+            config_class = get_class_from_dynamic_module(
+                class_ref, pretrained_model_name_or_path, code_revision=code_revision, **kwargs
+            )
             if os.path.isdir(pretrained_model_name_or_path):
                 config_class.register_for_auto_class()
-            _ = kwargs.pop("code_revision", None)
             return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
         elif "model_type" in config_dict:
             config_class = CONFIG_MAPPING[config_dict["model_type"]]
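`AutoConfig.from_pretrained` gets the same treatment: `code_revision` is popped before `get_config_dict` sees it and is handed to `get_class_from_dynamic_module` directly, instead of being silently discarded after the dynamic class is built. A short sketch, assuming a hypothetical repo whose `auto_map` exposes a custom `AutoConfig` entry:

```python
from transformers import AutoConfig

# With the patch, `code_revision` no longer leaks into get_config_dict.
config = AutoConfig.from_pretrained(
    "some-org/model-with-remote-code",  # hypothetical repo using auto_map
    trust_remote_code=True,
    code_revision="main",
)
```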
diff --git a/src/transformers/utils/peft_utils.py b/src/transformers/utils/peft_utils.py
index 5d68f6539e3..0e20db8ea06 100644
--- a/src/transformers/utils/peft_utils.py
+++ b/src/transformers/utils/peft_utils.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import importlib
 import os
-from typing import Optional
+from typing import Dict, Optional, Union
 
 from packaging import version
 
@@ -28,10 +28,15 @@ ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"
 
 def find_adapter_config_file(
     model_id: str,
-    revision: str = None,
-    subfolder: str = None,
-    token: Optional[str] = None,
-    commit_hash: Optional[str] = None,
+    cache_dir: Optional[Union[str, os.PathLike]] = None,
+    force_download: bool = False,
+    resume_download: bool = False,
+    proxies: Optional[Dict[str, str]] = None,
+    token: Optional[Union[bool, str]] = None,
+    revision: Optional[str] = None,
+    local_files_only: bool = False,
+    subfolder: str = "",
+    _commit_hash: Optional[str] = None,
 ) -> Optional[str]:
     r"""
     Simply checks if the model stored on the Hub or locally is an adapter model or not, return the path the the adapter
@@ -40,6 +45,20 @@ def find_adapter_config_file(
     Args:
         model_id (`str`):
             The identifier of the model to look for, can be either a local path or an id to the repository on the Hub.
+        cache_dir (`str` or `os.PathLike`, *optional*):
+            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
+            cache should not be used.
+        force_download (`bool`, *optional*, defaults to `False`):
+            Whether or not to force to (re-)download the configuration files and override the cached versions if they
+            exist.
+        resume_download (`bool`, *optional*, defaults to `False`):
+            Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
+        proxies (`Dict[str, str]`, *optional*):
+            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+        token (`str` or *bool*, *optional*):
+            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+            when running `huggingface-cli login` (stored in `~/.huggingface`).
         revision (`str`, *optional*, defaults to `"main"`):
             The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
             git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
@@ -51,12 +70,11 @@ def find_adapter_config_file(
 
             </Tip>
 
+        local_files_only (`bool`, *optional*, defaults to `False`):
+            If `True`, will only try to load the tokenizer configuration from local files.
         subfolder (`str`, *optional*, defaults to `""`):
             In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
             specify the folder name here.
-        token (`str`, `optional`):
-            Whether to use authentication token to load the remote folder. Userful to load private repositories that
-            are on HuggingFace Hub. You might need to call `huggingface-cli login` and paste your tokens to cache it.
     """
     adapter_cached_filename = None
     if model_id is None:
@@ -69,10 +87,15 @@ def find_adapter_config_file(
         adapter_cached_filename = cached_file(
             model_id,
             ADAPTER_CONFIG_NAME,
-            revision=revision,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            resume_download=resume_download,
+            proxies=proxies,
             token=token,
-            _commit_hash=commit_hash,
+            revision=revision,
+            local_files_only=local_files_only,
             subfolder=subfolder,
+            _commit_hash=_commit_hash,
             _raise_exceptions_for_missing_entries=False,
             _raise_exceptions_for_connection_errors=False,
         )
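`find_adapter_config_file` now accepts the full set of hub kwargs that `cached_file` understands, and the pinning argument is renamed from `commit_hash` to `_commit_hash` to match `cached_file`'s private parameter. A hedged sketch of a call against the widened signature (the repo id is illustrative):

```python
from transformers.utils import find_adapter_config_file

# Returns the local path to adapter_config.json if `model_id` is an adapter
# repo, or None for a plain model; all hub kwargs are forwarded to cached_file.
adapter_config_path = find_adapter_config_file(
    "some-org/some-lora-adapter",  # illustrative repo id
    local_files_only=False,
    _commit_hash=None,  # pass an already-resolved sha to pin the snapshot
)
print(adapter_config_path)
```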