Align try_to_load_from_cache with huggingface_hub (#18966)

* Align try_to_load_from_cache with huggingface_hub

* Fix tests
Sylvain Gugger 2022-09-12 12:09:37 -04:00 committed by GitHub
parent cf450b776f
commit f7ceda345d

@@ -222,18 +222,27 @@ def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]
     return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None
 
 
-def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None, commit_hash=None):
+def try_to_load_from_cache(
+    repo_id: str,
+    filename: str,
+    cache_dir: Union[str, Path, None] = None,
+    revision: Optional[str] = None,
+) -> Optional[str]:
     """
-    Explores the cache to return the latest cached file for a given revision.
+    Explores the cache to return the latest cached file for a given revision if found.
 
+    This function will not raise any exception if the file in not cached.
+
     Args:
-        cache_dir (`str` or `os.PathLike`): The folder where the cached files lie.
-        repo_id (`str`): The ID of the repo on huggingface.co.
-        filename (`str`): The filename to look for inside `repo_id`.
+        cache_dir (`str` or `os.PathLike`):
+            The folder where the cached files lie.
+        repo_id (`str`):
+            The ID of the repo on huggingface.co.
+        filename (`str`):
+            The filename to look for inside `repo_id`.
         revision (`str`, *optional*):
             The specific model version to use. Will default to `"main"` if it's not provided and no `commit_hash` is
             provided either.
-        commit_hash (`str`, *optional*): The (full) commit hash to look for inside the cache.
 
     Returns:
         `Optional[str]` or `_CACHED_NO_EXIST`:
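
With the new signature, `repo_id` and `filename` come first and `cache_dir` is optional (falling back to `TRANSFORMERS_CACHE`), matching the `huggingface_hub` version of this helper. A minimal usage sketch, assuming the helper lives in `transformers.utils.hub` and that `bert-base-uncased` has already been cached locally (both are illustrative assumptions, not part of the diff):

from transformers.utils.hub import try_to_load_from_cache  # assumed module path

# Illustrative call: positional repo_id/filename, optional keyword revision.
resolved = try_to_load_from_cache("bert-base-uncased", "config.json", revision="main")
print(resolved)  # e.g. ~/.cache/huggingface/hub/models--bert-base-uncased/snapshots/<sha>/config.json,
                 # or None if nothing is cached for that revision
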
@@ -242,36 +251,36 @@ def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None, commit_h
         - A special value `_CACHED_NO_EXIST` if the file does not exist at the given commit hash and this fact was
           cached.
     """
-    if commit_hash is not None and revision is not None:
-        raise ValueError("`commit_hash` and `revision` are mutually exclusive, pick one only.")
-    if revision is None and commit_hash is None:
+    if revision is None:
         revision = "main"
 
-    model_id = repo_id.replace("/", "--")
-    model_cache = os.path.join(cache_dir, f"models--{model_id}")
-    if not os.path.isdir(model_cache):
+    if cache_dir is None:
+        cache_dir = TRANSFORMERS_CACHE
+
+    object_id = repo_id.replace("/", "--")
+    repo_cache = os.path.join(cache_dir, f"models--{object_id}")
+    if not os.path.isdir(repo_cache):
         # No cache for this model
         return None
 
     for subfolder in ["refs", "snapshots"]:
-        if not os.path.isdir(os.path.join(model_cache, subfolder)):
+        if not os.path.isdir(os.path.join(repo_cache, subfolder)):
             return None
 
-    if commit_hash is None:
-        # Resolve refs (for instance to convert main to the associated commit sha)
-        cached_refs = os.listdir(os.path.join(model_cache, "refs"))
-        if revision in cached_refs:
-            with open(os.path.join(model_cache, "refs", revision)) as f:
-                commit_hash = f.read()
+    # Resolve refs (for instance to convert main to the associated commit sha)
+    cached_refs = os.listdir(os.path.join(repo_cache, "refs"))
+    if revision in cached_refs:
+        with open(os.path.join(repo_cache, "refs", revision)) as f:
+            revision = f.read()
 
-    if os.path.isfile(os.path.join(model_cache, ".no_exist", commit_hash, filename)):
+    if os.path.isfile(os.path.join(repo_cache, ".no_exist", revision, filename)):
         return _CACHED_NO_EXIST
 
-    cached_shas = os.listdir(os.path.join(model_cache, "snapshots"))
-    if commit_hash not in cached_shas:
+    cached_shas = os.listdir(os.path.join(repo_cache, "snapshots"))
+    if revision not in cached_shas:
         # No cache for this revision and we won't try to return a random revision
         return None
 
-    cached_file = os.path.join(model_cache, "snapshots", commit_hash, filename)
+    cached_file = os.path.join(repo_cache, "snapshots", revision, filename)
     return cached_file if os.path.isfile(cached_file) else None
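
The rewritten body resolves everything through `revision`: a branch or tag name is turned into a commit sha via `refs/`, and the result is looked up under `snapshots/` (or `.no_exist/`). A rough sketch of that ref-to-snapshot step over the shared Hub cache layout, with illustrative paths and a hypothetical helper name:

import os
from typing import Optional

# Hypothetical walk of the cache layout used above (paths are illustrative):
#   <cache_dir>/models--<org>--<name>/refs/main        -> text file containing a commit sha
#   <cache_dir>/models--<org>--<name>/snapshots/<sha>/ -> cached files for that commit
#   <cache_dir>/models--<org>--<name>/.no_exist/<sha>/ -> markers for files known to be absent

def resolve_snapshot_dir(cache_dir: str, repo_id: str, revision: str = "main") -> Optional[str]:
    """Toy re-implementation of the ref -> snapshot resolution, for illustration only."""
    repo_cache = os.path.join(cache_dir, f"models--{repo_id.replace('/', '--')}")
    ref_path = os.path.join(repo_cache, "refs", revision)
    if os.path.isfile(ref_path):
        with open(ref_path) as f:
            revision = f.read()  # e.g. "main" becomes the full commit sha
    snapshot_dir = os.path.join(repo_cache, "snapshots", revision)
    return snapshot_dir if os.path.isdir(snapshot_dir) else None
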
@@ -375,7 +384,9 @@ def cached_file(
 
     if _commit_hash is not None:
         # If the file is cached under that commit hash, we return it directly.
-        resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, commit_hash=_commit_hash)
+        resolved_file = try_to_load_from_cache(
+            path_or_repo_id, full_filename, cache_dir=cache_dir, revision=_commit_hash
+        )
         if resolved_file is not None:
             if resolved_file is not _CACHED_NO_EXIST:
                 return resolved_file
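
This fast path has three possible outcomes from `try_to_load_from_cache`: a usable path, the `_CACHED_NO_EXIST` sentinel, or `None`. A hedged sketch of the same three-way handling in caller code (the wrapper name and import path are assumptions for illustration):

from typing import Optional

from transformers.utils.hub import _CACHED_NO_EXIST, try_to_load_from_cache  # assumed module path

def cached_or_none(repo_id: str, filename: str, commit_hash: str) -> Optional[str]:
    """Illustrative wrapper: consult the cache before deciding whether to download."""
    resolved = try_to_load_from_cache(repo_id, filename, revision=commit_hash)
    if resolved is _CACHED_NO_EXIST:
        return None  # the cache records that this file does not exist at that commit
    if resolved is not None:
        return resolved  # cache hit: skip the network entirely
    return None  # nothing cached yet; a real caller would download here
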
@@ -416,7 +427,7 @@ def cached_file(
         )
     except LocalEntryNotFoundError:
         # We try to see if we have a cached version (not up to date):
-        resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, revision=revision)
+        resolved_file = try_to_load_from_cache(path_or_repo_id, full_filename, cache_dir=cache_dir, revision=revision)
         if resolved_file is not None:
             return resolved_file
         if not _raise_exceptions_for_missing_entries or not _raise_exceptions_for_connection_errors:
@@ -438,7 +449,7 @@ def cached_file(
         )
     except HTTPError as err:
         # First we try to see if we have a cached version (not up to date):
-        resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, revision=revision)
+        resolved_file = try_to_load_from_cache(path_or_repo_id, full_filename, cache_dir=cache_dir, revision=revision)
         if resolved_file is not None:
             return resolved_file
         if not _raise_exceptions_for_connection_errors:
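
Both fallback sites now pass the arguments in the new order, so a stale-but-present cached file can still be served when the Hub is unreachable. A rough sketch of the same "download, else fall back to the cache" pattern in user code, under the same module-path assumption as above:

from typing import Optional

from requests.exceptions import HTTPError

from transformers.utils.hub import _CACHED_NO_EXIST, cached_file, try_to_load_from_cache  # assumed path

def config_path_resilient(repo_id: str) -> Optional[str]:
    """Illustrative: prefer a fresh download, fall back to a possibly out-of-date cached copy."""
    try:
        return cached_file(repo_id, "config.json")
    except (HTTPError, OSError):
        resolved = try_to_load_from_cache(repo_id, "config.json", revision="main")
        if resolved is not None and resolved is not _CACHED_NO_EXIST:
            return resolved
        return None
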