mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Align try_to_load_from_cache with huggingface_hub (#18966)
* Align try_to_load_from_cache with huggingface_hub * Fix tests
This commit is contained in:
parent
cf450b776f
commit
f7ceda345d
@ -222,18 +222,27 @@ def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]
|
||||
return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None
|
||||
|
||||
|
||||
def try_to_load_from_cache(
    repo_id: str,
    filename: str,
    cache_dir: Union[str, Path, None] = None,
    revision: Optional[str] = None,
) -> Optional[str]:
    """
    Explores the cache to return the latest cached file for a given revision if found.

    This function will not raise any exception if the file is not cached.

    Args:
        repo_id (`str`):
            The ID of the repo on huggingface.co.
        filename (`str`):
            The filename to look for inside `repo_id`.
        cache_dir (`str` or `os.PathLike`):
            The folder where the cached files lie.
        revision (`str`, *optional*):
            The specific model version to use. Will default to `"main"` if it's not provided.
            May also be a full commit hash, which is looked up directly in `snapshots`.

    Returns:
        `Optional[str]` or `_CACHED_NO_EXIST`:
            Will return `None` if the file was not cached. Otherwise:
            - The exact path to the cached file if it's found in the cache
            - A special value `_CACHED_NO_EXIST` if the file does not exist at the given commit hash and this fact was
              cached.
    """
    if revision is None:
        revision = "main"

    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE

    # Cache folders follow the huggingface_hub layout: models--{org}--{name}
    object_id = repo_id.replace("/", "--")
    repo_cache = os.path.join(cache_dir, f"models--{object_id}")
    if not os.path.isdir(repo_cache):
        # No cache for this model
        return None
    for subfolder in ["refs", "snapshots"]:
        if not os.path.isdir(os.path.join(repo_cache, subfolder)):
            return None

    # Resolve refs (for instance to convert main to the associated commit sha)
    cached_refs = os.listdir(os.path.join(repo_cache, "refs"))
    if revision in cached_refs:
        with open(os.path.join(repo_cache, "refs", revision)) as f:
            revision = f.read()

    # A cached marker that the file is known NOT to exist at this revision
    if os.path.isfile(os.path.join(repo_cache, ".no_exist", revision, filename)):
        return _CACHED_NO_EXIST

    cached_shas = os.listdir(os.path.join(repo_cache, "snapshots"))
    if revision not in cached_shas:
        # No cache for this revision and we won't try to return a random revision
        return None

    cached_file = os.path.join(repo_cache, "snapshots", revision, filename)
    return cached_file if os.path.isfile(cached_file) else None
|
||||
|
||||
|
||||
@ -375,7 +384,9 @@ def cached_file(
|
||||
|
||||
if _commit_hash is not None:
|
||||
# If the file is cached under that commit hash, we return it directly.
|
||||
resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, commit_hash=_commit_hash)
|
||||
resolved_file = try_to_load_from_cache(
|
||||
path_or_repo_id, full_filename, cache_dir=cache_dir, revision=_commit_hash
|
||||
)
|
||||
if resolved_file is not None:
|
||||
if resolved_file is not _CACHED_NO_EXIST:
|
||||
return resolved_file
|
||||
@ -416,7 +427,7 @@ def cached_file(
|
||||
)
|
||||
except LocalEntryNotFoundError:
|
||||
# We try to see if we have a cached version (not up to date):
|
||||
resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, revision=revision)
|
||||
resolved_file = try_to_load_from_cache(path_or_repo_id, full_filename, cache_dir=cache_dir, revision=revision)
|
||||
if resolved_file is not None:
|
||||
return resolved_file
|
||||
if not _raise_exceptions_for_missing_entries or not _raise_exceptions_for_connection_errors:
|
||||
@ -438,7 +449,7 @@ def cached_file(
|
||||
)
|
||||
except HTTPError as err:
|
||||
# First we try to see if we have a cached version (not up to date):
|
||||
resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, revision=revision)
|
||||
resolved_file = try_to_load_from_cache(path_or_repo_id, full_filename, cache_dir=cache_dir, revision=revision)
|
||||
if resolved_file is not None:
|
||||
return resolved_file
|
||||
if not _raise_exceptions_for_connection_errors:
|
||||
|
Loading…
Reference in New Issue
Block a user