Little cleanup: let huggingface_hub manage token retrieval (#21333)

* Let huggingface_hub manage token retrieval

* flake8

* code quality

* adapt in every PushToHubMixin child

* add explicit return type
This commit is contained in:
Lucain 2023-01-27 18:09:49 +01:00 committed by GitHub
parent 0dff407d71
commit 8f3b4a1d5b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 76 additions and 55 deletions

View File

@ -438,7 +438,7 @@ class PretrainedConfig(PushToHubMixin):
if push_to_hub: if push_to_hub:
commit_message = kwargs.pop("commit_message", None) commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id, token = self._create_repo(repo_id, **kwargs) repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory) files_timestamps = self._get_files_timestamps(save_directory)
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
@ -454,7 +454,11 @@ class PretrainedConfig(PushToHubMixin):
if push_to_hub: if push_to_hub:
self._upload_modified_files( self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("use_auth_token"),
) )
@classmethod @classmethod

View File

@ -22,7 +22,7 @@ import sys
from pathlib import Path from pathlib import Path
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
from huggingface_hub import HfFolder, model_info from huggingface_hub import model_info
from .utils import HF_MODULES_CACHE, TRANSFORMERS_DYNAMIC_MODULE_NAME, cached_file, is_offline_mode, logging from .utils import HF_MODULES_CACHE, TRANSFORMERS_DYNAMIC_MODULE_NAME, cached_file, is_offline_mode, logging
@ -251,14 +251,7 @@ def get_cached_module_file(
else: else:
# Get the commit hash # Get the commit hash
# TODO: we will get this info in the etag soon, so retrieve it from there and not here. # TODO: we will get this info in the etag soon, so retrieve it from there and not here.
if isinstance(use_auth_token, str): commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=use_auth_token).sha
token = use_auth_token
elif use_auth_token is True:
token = HfFolder.get_token()
else:
token = None
commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=token).sha
# The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
# benefit of versioning. # benefit of versioning.

View File

@ -353,7 +353,7 @@ class FeatureExtractionMixin(PushToHubMixin):
if push_to_hub: if push_to_hub:
commit_message = kwargs.pop("commit_message", None) commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id, token = self._create_repo(repo_id, **kwargs) repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory) files_timestamps = self._get_files_timestamps(save_directory)
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
@ -369,7 +369,11 @@ class FeatureExtractionMixin(PushToHubMixin):
if push_to_hub: if push_to_hub:
self._upload_modified_files( self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("use_auth_token"),
) )
return [output_feature_extractor_file] return [output_feature_extractor_file]

View File

@ -337,7 +337,7 @@ class GenerationConfig(PushToHubMixin):
if push_to_hub: if push_to_hub:
commit_message = kwargs.pop("commit_message", None) commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id, token = self._create_repo(repo_id, **kwargs) repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory) files_timestamps = self._get_files_timestamps(save_directory)
output_config_file = os.path.join(save_directory, config_file_name) output_config_file = os.path.join(save_directory, config_file_name)
@ -347,7 +347,11 @@ class GenerationConfig(PushToHubMixin):
if push_to_hub: if push_to_hub:
self._upload_modified_files( self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("use_auth_token"),
) )
@classmethod @classmethod

View File

@ -185,7 +185,7 @@ class ImageProcessingMixin(PushToHubMixin):
if push_to_hub: if push_to_hub:
commit_message = kwargs.pop("commit_message", None) commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id, token = self._create_repo(repo_id, **kwargs) repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory) files_timestamps = self._get_files_timestamps(save_directory)
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
@ -201,7 +201,11 @@ class ImageProcessingMixin(PushToHubMixin):
if push_to_hub: if push_to_hub:
self._upload_modified_files( self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("use_auth_token"),
) )
return [output_image_processor_file] return [output_image_processor_file]

View File

@ -1018,7 +1018,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
if push_to_hub: if push_to_hub:
commit_message = kwargs.pop("commit_message", None) commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id, token = self._create_repo(repo_id, **kwargs) repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory) files_timestamps = self._get_files_timestamps(save_directory)
# get abs dir # get abs dir
@ -1077,7 +1077,11 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
if push_to_hub: if push_to_hub:
self._upload_modified_files( self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("use_auth_token"),
) )
@classmethod @classmethod

View File

@ -2277,7 +2277,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
if push_to_hub: if push_to_hub:
commit_message = kwargs.pop("commit_message", None) commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id, token = self._create_repo(repo_id, **kwargs) repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory) files_timestamps = self._get_files_timestamps(save_directory)
if saved_model: if saved_model:
@ -2363,7 +2363,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
if push_to_hub: if push_to_hub:
self._upload_modified_files( self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("use_auth_token"),
) )
@classmethod @classmethod
@ -2946,7 +2950,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
else: else:
working_dir = repo_id.split("/")[-1] working_dir = repo_id.split("/")[-1]
repo_id, token = self._create_repo( repo_id = self._create_repo(
repo_id, private=private, use_auth_token=use_auth_token, repo_url=repo_url, organization=organization repo_id, private=private, use_auth_token=use_auth_token, repo_url=repo_url, organization=organization
) )
@ -2968,7 +2972,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
self.create_model_card(**base_model_card_args) self.create_model_card(**base_model_card_args)
self._upload_modified_files( self._upload_modified_files(
work_dir, repo_id, files_timestamps, commit_message=commit_message, token=token work_dir, repo_id, files_timestamps, commit_message=commit_message, token=use_auth_token
) )
@classmethod @classmethod

View File

@ -1633,7 +1633,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if push_to_hub: if push_to_hub:
commit_message = kwargs.pop("commit_message", None) commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id, token = self._create_repo(repo_id, **kwargs) repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory) files_timestamps = self._get_files_timestamps(save_directory)
# Only save the model itself if we are using distributed training # Only save the model itself if we are using distributed training
@ -1717,7 +1717,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if push_to_hub: if push_to_hub:
self._upload_modified_files( self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("use_auth_token"),
) )
def get_memory_footprint(self, return_buffers=True): def get_memory_footprint(self, return_buffers=True):

View File

@ -121,7 +121,7 @@ class ProcessorMixin(PushToHubMixin):
if push_to_hub: if push_to_hub:
commit_message = kwargs.pop("commit_message", None) commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id, token = self._create_repo(repo_id, **kwargs) repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory) files_timestamps = self._get_files_timestamps(save_directory)
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
# loaded from the Hub. # loaded from the Hub.
@ -147,7 +147,11 @@ class ProcessorMixin(PushToHubMixin):
if push_to_hub: if push_to_hub:
self._upload_modified_files( self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("use_auth_token"),
) )
@classmethod @classmethod

View File

@ -2098,7 +2098,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
if push_to_hub: if push_to_hub:
commit_message = kwargs.pop("commit_message", None) commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id, token = self._create_repo(repo_id, **kwargs) repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory) files_timestamps = self._get_files_timestamps(save_directory)
special_tokens_map_file = os.path.join( special_tokens_map_file = os.path.join(
@ -2177,7 +2177,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
if push_to_hub: if push_to_hub:
self._upload_modified_files( self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("use_auth_token"),
) )
return save_files return save_files

View File

@ -31,7 +31,6 @@ import huggingface_hub
import requests import requests
from huggingface_hub import ( from huggingface_hub import (
CommitOperationAdd, CommitOperationAdd,
HfFolder,
create_commit, create_commit,
create_repo, create_repo,
get_hf_file_metadata, get_hf_file_metadata,
@ -45,6 +44,7 @@ from huggingface_hub.utils import (
LocalEntryNotFoundError, LocalEntryNotFoundError,
RepositoryNotFoundError, RepositoryNotFoundError,
RevisionNotFoundError, RevisionNotFoundError,
build_hf_headers,
hf_raise_for_status, hf_raise_for_status,
) )
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
@ -583,7 +583,7 @@ def has_file(
use_auth_token: Optional[Union[bool, str]] = None, use_auth_token: Optional[Union[bool, str]] = None,
): ):
""" """
Checks if a repo contains a given file wihtout downloading it. Works for remote repos and local folders. Checks if a repo contains a given file without downloading it. Works for remote repos and local folders.
<Tip warning={false}> <Tip warning={false}>
@ -596,15 +596,7 @@ def has_file(
return os.path.isfile(os.path.join(path_or_repo, filename)) return os.path.isfile(os.path.join(path_or_repo, filename))
url = hf_hub_url(path_or_repo, filename=filename, revision=revision) url = hf_hub_url(path_or_repo, filename=filename, revision=revision)
headers = build_hf_headers(use_auth_token=use_auth_token, user_agent=http_user_agent())
headers = {"user-agent": http_user_agent()}
if isinstance(use_auth_token, str):
headers["authorization"] = f"Bearer {use_auth_token}"
elif use_auth_token:
token = HfFolder.get_token()
if token is None:
raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
headers["authorization"] = f"Bearer {token}"
r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=10) r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=10)
try: try:
@ -636,10 +628,10 @@ class PushToHubMixin:
use_auth_token: Optional[Union[bool, str]] = None, use_auth_token: Optional[Union[bool, str]] = None,
repo_url: Optional[str] = None, repo_url: Optional[str] = None,
organization: Optional[str] = None, organization: Optional[str] = None,
): ) -> str:
""" """
Create the repo if needed, cleans up repo_id with deprecated kwards `repo_url` and `organization`, retrives the Create the repo if needed, cleans up repo_id with deprecated kwargs `repo_url` and `organization`, retrieves
token. the token.
""" """
if repo_url is not None: if repo_url is not None:
warnings.warn( warnings.warn(
@ -657,13 +649,12 @@ class PushToHubMixin:
repo_id = repo_id.split("/")[-1] repo_id = repo_id.split("/")[-1]
repo_id = f"{organization}/{repo_id}" repo_id = f"{organization}/{repo_id}"
token = HfFolder.get_token() if use_auth_token is True else use_auth_token url = create_repo(repo_id=repo_id, token=use_auth_token, private=private, exist_ok=True)
url = create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True)
# If the namespace is not there, add it or `upload_file` will complain # If the namespace is not there, add it or `upload_file` will complain
if "/" not in repo_id and url != f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/{repo_id}": if "/" not in repo_id and url != f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/{repo_id}":
repo_id = get_full_repo_name(repo_id, token=token) repo_id = get_full_repo_name(repo_id, token=use_auth_token)
return repo_id, token return repo_id
def _get_files_timestamps(self, working_dir: Union[str, os.PathLike]): def _get_files_timestamps(self, working_dir: Union[str, os.PathLike]):
""" """
@ -677,7 +668,7 @@ class PushToHubMixin:
repo_id: str, repo_id: str,
files_timestamps: Dict[str, float], files_timestamps: Dict[str, float],
commit_message: Optional[str] = None, commit_message: Optional[str] = None,
token: Optional[str] = None, token: Optional[Union[bool, str]] = None,
create_pr: bool = False, create_pr: bool = False,
): ):
""" """
@ -776,7 +767,7 @@ class PushToHubMixin:
else: else:
working_dir = repo_id.split("/")[-1] working_dir = repo_id.split("/")[-1]
repo_id, token = self._create_repo( repo_id = self._create_repo(
repo_id, private=private, use_auth_token=use_auth_token, repo_url=repo_url, organization=organization repo_id, private=private, use_auth_token=use_auth_token, repo_url=repo_url, organization=organization
) )
@ -790,13 +781,16 @@ class PushToHubMixin:
self.save_pretrained(work_dir, max_shard_size=max_shard_size) self.save_pretrained(work_dir, max_shard_size=max_shard_size)
return self._upload_modified_files( return self._upload_modified_files(
work_dir, repo_id, files_timestamps, commit_message=commit_message, token=token, create_pr=create_pr work_dir,
repo_id,
files_timestamps,
commit_message=commit_message,
token=use_auth_token,
create_pr=create_pr,
) )
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
if token is None:
token = HfFolder.get_token()
if organization is None: if organization is None:
username = whoami(token)["name"] username = whoami(token)["name"]
return f"{username}/{model_id}" return f"{username}/{model_id}"
@ -1040,8 +1034,6 @@ def move_cache(cache_dir=None, new_cache_dir=None, token=None):
cache_dir = str(old_cache) cache_dir = str(old_cache)
else: else:
cache_dir = new_cache_dir cache_dir = new_cache_dir
if token is None:
token = HfFolder.get_token()
cached_files = get_all_cached_files(cache_dir=cache_dir) cached_files = get_all_cached_files(cache_dir=cache_dir)
logger.info(f"Moving {len(cached_files)} files to the new cache system") logger.info(f"Moving {len(cached_files)} files to the new cache system")
@ -1050,7 +1042,7 @@ def move_cache(cache_dir=None, new_cache_dir=None, token=None):
url = file_info.pop("url") url = file_info.pop("url")
if url not in hub_metadata: if url not in hub_metadata:
try: try:
hub_metadata[url] = get_hf_file_metadata(url, use_auth_token=token) hub_metadata[url] = get_hf_file_metadata(url, token=token)
except requests.HTTPError: except requests.HTTPError:
continue continue