mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Refine errors for pretrained objects (#15261)
* Refine errors for pretrained objects * PoC to avoid using get_list_of_files * Adapt tests to use new errors * Quality + Fix PoC * Revert "PoC to avoid using get_list_of_files" This reverts commitcb93b7cae8
. * Revert "Quality + Fix PoC" This reverts commit3ba6d0d4ca
. * Fix doc * Revert PoC * Add feature extractors * More tests and PT model * Adapt error message * Feature extractor tests * TF model * Flax model and test * Merge flax auto tests * Add tokenization * Fix test
This commit is contained in:
parent
80af1048cf
commit
6ac77534bf
@ -25,10 +25,15 @@ from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
from packaging import version
|
||||
|
||||
from requests import HTTPError
|
||||
|
||||
from . import __version__
|
||||
from .file_utils import (
|
||||
CONFIG_NAME,
|
||||
EntryNotFoundError,
|
||||
PushToHubMixin,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
cached_path,
|
||||
copy_func,
|
||||
get_list_of_files,
|
||||
@ -520,8 +525,6 @@ class PretrainedConfig(PushToHubMixin):
|
||||
From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
|
||||
[`PretrainedConfig`] using `from_dict`.
|
||||
|
||||
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
||||
The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
|
||||
@ -578,30 +581,51 @@ class PretrainedConfig(PushToHubMixin):
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
# Load config dict
|
||||
config_dict = cls._dict_from_json_file(resolved_config_file)
|
||||
|
||||
except RepositoryNotFoundError as err:
|
||||
logger.error(err)
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
|
||||
"'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
|
||||
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
|
||||
"`use_auth_token=True`."
|
||||
)
|
||||
except RevisionNotFoundError as err:
|
||||
logger.error(err)
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
|
||||
f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
|
||||
"available revisions."
|
||||
)
|
||||
except EntryNotFoundError as err:
|
||||
logger.error(err)
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {configuration_file}."
|
||||
)
|
||||
except HTTPError as err:
|
||||
logger.error(err)
|
||||
raise EnvironmentError(
|
||||
"We couldn't connect to 'https://huggingface.co/' to load this model and it looks like "
|
||||
f"{pretrained_model_name_or_path} is not the path to a directory conaining a {configuration_file} "
|
||||
"file.\nCheckout your internet connection or see how to run the library in offline mode at "
|
||||
"'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
||||
)
|
||||
except EnvironmentError as err:
|
||||
logger.error(err)
|
||||
msg = (
|
||||
f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
|
||||
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n"
|
||||
f" (make sure '{pretrained_model_name_or_path}' is not a path to a local directory with something else, in that case)\n\n"
|
||||
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n"
|
||||
raise EnvironmentError(
|
||||
f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||
f"containing a {configuration_file} file"
|
||||
)
|
||||
|
||||
if revision is not None:
|
||||
msg += f"- or '{revision}' is a valid git identifier (branch name, a tag name, or a commit id) that exists for this model name as listed on its model page on 'https://huggingface.co/models'\n\n"
|
||||
|
||||
raise EnvironmentError(msg)
|
||||
|
||||
try:
|
||||
# Load config dict
|
||||
config_dict = cls._dict_from_json_file(resolved_config_file)
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
msg = (
|
||||
f"Couldn't reach server at '{config_file}' to download configuration file or "
|
||||
"configuration file is not a valid JSON file. "
|
||||
f"Please check network or file content here: {resolved_config_file}."
|
||||
raise EnvironmentError(
|
||||
f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
|
||||
)
|
||||
raise EnvironmentError(msg)
|
||||
|
||||
if resolved_config_file == config_file:
|
||||
logger.info(f"loading configuration file {config_file}")
|
||||
@ -842,9 +866,13 @@ def get_configuration_file(
|
||||
`str`: The configuration file to use.
|
||||
"""
|
||||
# Inspect all files from the repo/folder.
|
||||
all_files = get_list_of_files(
|
||||
path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only
|
||||
)
|
||||
try:
|
||||
all_files = get_list_of_files(
|
||||
path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only
|
||||
)
|
||||
except Exception:
|
||||
return FULL_CONFIGURATION_FILE
|
||||
|
||||
configuration_files_map = {}
|
||||
for file_name in all_files:
|
||||
search = _re_configuration_file.search(file_name)
|
||||
|
@ -24,8 +24,13 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from requests import HTTPError
|
||||
|
||||
from .file_utils import (
|
||||
FEATURE_EXTRACTOR_NAME,
|
||||
EntryNotFoundError,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
TensorType,
|
||||
_is_jax,
|
||||
_is_numpy,
|
||||
@ -374,28 +379,54 @@ class FeatureExtractionMixin:
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
except RepositoryNotFoundError as err:
|
||||
logger.error(err)
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
|
||||
"'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
|
||||
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
|
||||
"`use_auth_token=True`."
|
||||
)
|
||||
except RevisionNotFoundError as err:
|
||||
logger.error(err)
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
|
||||
f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
|
||||
"available revisions."
|
||||
)
|
||||
except EntryNotFoundError as err:
|
||||
logger.error(err)
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {FEATURE_EXTRACTOR_NAME}."
|
||||
)
|
||||
except HTTPError as err:
|
||||
logger.error(err)
|
||||
raise EnvironmentError(
|
||||
"We couldn't connect to 'https://huggingface.co/' to load this model and it looks like "
|
||||
f"{pretrained_model_name_or_path} is not the path to a directory conaining a "
|
||||
f"{FEATURE_EXTRACTOR_NAME} file.\nCheckout your internet connection or see how to run the library in "
|
||||
"offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
||||
)
|
||||
except EnvironmentError as err:
|
||||
logger.error(err)
|
||||
raise EnvironmentError(
|
||||
f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load it "
|
||||
"from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||
f"containing a {FEATURE_EXTRACTOR_NAME} file"
|
||||
)
|
||||
|
||||
try:
|
||||
# Load feature_extractor dict
|
||||
with open(resolved_feature_extractor_file, "r", encoding="utf-8") as reader:
|
||||
text = reader.read()
|
||||
feature_extractor_dict = json.loads(text)
|
||||
|
||||
except EnvironmentError as err:
|
||||
logger.error(err)
|
||||
msg = (
|
||||
f"Can't load feature extractor for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
|
||||
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n"
|
||||
f" (make sure '{pretrained_model_name_or_path}' is not a path to a local directory with something else, in that case)\n\n"
|
||||
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {FEATURE_EXTRACTOR_NAME} file\n\n"
|
||||
)
|
||||
raise EnvironmentError(msg)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
msg = (
|
||||
f"Couldn't reach server at '{feature_extractor_file}' to download feature extractor configuration file or "
|
||||
"feature extractor configuration file is not a valid JSON file. "
|
||||
f"Please check network or file content here: {resolved_feature_extractor_file}."
|
||||
raise EnvironmentError(
|
||||
f"It looks like the config file at '{resolved_feature_extractor_file}' is not a valid JSON file."
|
||||
)
|
||||
raise EnvironmentError(msg)
|
||||
|
||||
if resolved_feature_extractor_file == feature_extractor_file:
|
||||
logger.info(f"loading feature extractor configuration file {feature_extractor_file}")
|
||||
|
@ -1900,6 +1900,37 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
|
||||
return ua
|
||||
|
||||
|
||||
class RepositoryNotFoundError(HTTPError):
|
||||
"""
|
||||
Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
|
||||
not have access to.
|
||||
"""
|
||||
|
||||
|
||||
class EntryNotFoundError(HTTPError):
|
||||
"""Raised when trying to access a hf.co URL with a valid repository and revision but an invalid filename."""
|
||||
|
||||
|
||||
class RevisionNotFoundError(HTTPError):
|
||||
"""Raised when trying to access a hf.co URL with a valid repository but an invalid revision."""
|
||||
|
||||
|
||||
def _raise_for_status(request):
|
||||
"""
|
||||
Internal version of `request.raise_for_status()` that will refine a potential HTTPError.
|
||||
"""
|
||||
if "X-Error-Code" in request.headers:
|
||||
error_code = request.headers["X-Error-Code"]
|
||||
if error_code == "RepoNotFound":
|
||||
raise RepositoryNotFoundError(f"404 Client Error: Repository Not Found for url: {request.url}")
|
||||
elif error_code == "EntryNotFound":
|
||||
raise EntryNotFoundError(f"404 Client Error: Entry Not Found for url: {request.url}")
|
||||
elif error_code == "RevisionNotFound":
|
||||
raise RevisionNotFoundError((f"404 Client Error: Revision Not Found for url: {request.url}"))
|
||||
|
||||
request.raise_for_status()
|
||||
|
||||
|
||||
def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None):
|
||||
"""
|
||||
Download remote file. Do not gobble up errors.
|
||||
@ -1908,7 +1939,7 @@ def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers
|
||||
if resume_size > 0:
|
||||
headers["Range"] = f"bytes={resume_size}-"
|
||||
r = requests.get(url, stream=True, proxies=proxies, headers=headers)
|
||||
r.raise_for_status()
|
||||
_raise_for_status(r)
|
||||
content_length = r.headers.get("Content-Length")
|
||||
total = resume_size + int(content_length) if content_length is not None else None
|
||||
# `tqdm` behavior is determined by `utils.logging.is_progress_bar_enabled()`
|
||||
@ -1970,7 +2001,7 @@ def get_from_cache(
|
||||
if not local_files_only:
|
||||
try:
|
||||
r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
|
||||
r.raise_for_status()
|
||||
_raise_for_status(r)
|
||||
etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
|
||||
# We favor a custom header indicating the etag of the linked resource, and
|
||||
# we fallback to the regular etag header.
|
||||
@ -2081,6 +2112,56 @@ def get_from_cache(
|
||||
return cache_path
|
||||
|
||||
|
||||
def has_file(
|
||||
path_or_repo: Union[str, os.PathLike],
|
||||
filename: str,
|
||||
revision: Optional[str] = None,
|
||||
mirror: Optional[str] = None,
|
||||
proxies: Optional[Dict[str, str]] = None,
|
||||
use_auth_token: Optional[Union[bool, str]] = None,
|
||||
):
|
||||
"""
|
||||
Checks if a repo contains a given file wihtout downloading it. Works for remote repos and local folders.
|
||||
|
||||
<Tip warning={false}>
|
||||
|
||||
This function will raise an error if the repository `path_or_repo` is not valid or if `revision` does not exist for
|
||||
this repo, but will return False for regular connection errors.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
if os.path.isdir(path_or_repo):
|
||||
return os.path.isfile(os.path.join(path_or_repo, filename))
|
||||
|
||||
url = hf_bucket_url(path_or_repo, filename=filename, revision=revision, mirror=mirror)
|
||||
|
||||
headers = {"user-agent": http_user_agent()}
|
||||
if isinstance(use_auth_token, str):
|
||||
headers["authorization"] = f"Bearer {use_auth_token}"
|
||||
elif use_auth_token:
|
||||
token = HfFolder.get_token()
|
||||
if token is None:
|
||||
raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
|
||||
headers["authorization"] = f"Bearer {token}"
|
||||
|
||||
r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=10)
|
||||
try:
|
||||
_raise_for_status(r)
|
||||
return True
|
||||
except RepositoryNotFoundError as e:
|
||||
logger.error(e)
|
||||
raise EnvironmentError(f"{path_or_repo} is not a local folder or a valid repository name on 'https://hf.co'.")
|
||||
except RevisionNotFoundError as e:
|
||||
logger.error(e)
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
|
||||
"model name. Check the model page at 'https://huggingface.co/{path_or_repo}' for available revisions."
|
||||
)
|
||||
except requests.HTTPError:
|
||||
# We return false for EntryNotFoundError (logical) as well as any connection error.
|
||||
return False
|
||||
|
||||
|
||||
def get_list_of_files(
|
||||
path_or_repo: Union[str, os.PathLike],
|
||||
revision: Optional[str] = None,
|
||||
|
@ -26,16 +26,21 @@ from flax.core.frozen_dict import FrozenDict, unfreeze
|
||||
from flax.serialization import from_bytes, to_bytes
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from jax.random import PRNGKey
|
||||
from requests import HTTPError
|
||||
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .file_utils import (
|
||||
FLAX_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
EntryNotFoundError,
|
||||
PushToHubMixin,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
cached_path,
|
||||
copy_func,
|
||||
has_file,
|
||||
hf_bucket_url,
|
||||
is_offline_mode,
|
||||
is_remote_url,
|
||||
@ -450,17 +455,25 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
|
||||
# Load from a Flax checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
|
||||
# At this stage we don't have a weight file so we will raise an error.
|
||||
elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
|
||||
raise EnvironmentError(
|
||||
f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
|
||||
"but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
|
||||
"weights."
|
||||
)
|
||||
else:
|
||||
raise EnvironmentError(
|
||||
f"Error no file named {[FLAX_WEIGHTS_NAME, WEIGHTS_NAME]} found in directory "
|
||||
f"{pretrained_model_name_or_path} or `from_pt` set to False"
|
||||
f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
|
||||
f"{pretrained_model_name_or_path}."
|
||||
)
|
||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
archive_file = pretrained_model_name_or_path
|
||||
else:
|
||||
filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME
|
||||
archive_file = hf_bucket_url(
|
||||
pretrained_model_name_or_path,
|
||||
filename=WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME,
|
||||
filename=filename,
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
@ -476,15 +489,59 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
except RepositoryNotFoundError as err:
|
||||
logger.error(err)
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
||||
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
|
||||
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
|
||||
"login` and pass `use_auth_token=True`."
|
||||
)
|
||||
except RevisionNotFoundError as err:
|
||||
logger.error(err)
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
|
||||
"this model name. Check the model page at "
|
||||
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
||||
)
|
||||
except EntryNotFoundError as err:
|
||||
logger.error(err)
|
||||
if filename == FLAX_WEIGHTS_NAME:
|
||||
has_file_kwargs = {"revision": revision, "proxies": proxies, "use_auth_token": use_auth_token}
|
||||
if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME} "
|
||||
"but there is a file for PyTorch weights. Use `from_pt=True` to load this model from "
|
||||
"those weights."
|
||||
)
|
||||
else:
|
||||
logger.error(err)
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME} "
|
||||
f"or {WEIGHTS_NAME}."
|
||||
)
|
||||
else:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
|
||||
)
|
||||
except HTTPError as err:
|
||||
logger.error(err)
|
||||
raise EnvironmentError(
|
||||
"We couldn't connect to 'https://huggingface.co/' to load this model and it looks like "
|
||||
f"{pretrained_model_name_or_path} is not the path to a directory conaining a a file named "
|
||||
f"{FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\n"
|
||||
"Checkout your internet connection or see how to run the library in offline mode at "
|
||||
"'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
||||
)
|
||||
except EnvironmentError as err:
|
||||
logger.error(err)
|
||||
msg = (
|
||||
f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
|
||||
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n"
|
||||
f" (make sure '{pretrained_model_name_or_path}' is not a path to a local directory with something else, in that case)\n\n"
|
||||
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named {WEIGHTS_NAME}.\n\n"
|
||||
raise EnvironmentError(
|
||||
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||
f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
|
||||
)
|
||||
raise EnvironmentError(msg)
|
||||
|
||||
if resolved_archive_file == archive_file:
|
||||
logger.info(f"loading weights file {archive_file}")
|
||||
|
@ -32,16 +32,21 @@ from tensorflow.python.keras.engine.keras_tensor import KerasTensor
|
||||
from tensorflow.python.keras.saving import hdf5_format
|
||||
|
||||
from huggingface_hub import Repository, list_repo_files
|
||||
from requests import HTTPError
|
||||
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .file_utils import (
|
||||
DUMMY_INPUTS,
|
||||
TF2_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
EntryNotFoundError,
|
||||
ModelOutput,
|
||||
PushToHubMixin,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
cached_path,
|
||||
copy_func,
|
||||
has_file,
|
||||
hf_bucket_url,
|
||||
is_offline_mode,
|
||||
is_remote_url,
|
||||
@ -1542,19 +1547,27 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
|
||||
# Load from a TF 2.0 checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
|
||||
# At this stage we don't have a weight file so we will raise an error.
|
||||
elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
|
||||
raise EnvironmentError(
|
||||
f"Error no file named {TF2_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
|
||||
"but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
|
||||
"weights."
|
||||
)
|
||||
else:
|
||||
raise EnvironmentError(
|
||||
f"Error no file named {[WEIGHTS_NAME, TF2_WEIGHTS_NAME]} found in directory "
|
||||
f"{pretrained_model_name_or_path} or `from_pt` set to False"
|
||||
f"Error no file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
|
||||
f"{pretrained_model_name_or_path}."
|
||||
)
|
||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
archive_file = pretrained_model_name_or_path
|
||||
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
|
||||
archive_file = pretrained_model_name_or_path + ".index"
|
||||
else:
|
||||
filename = WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME
|
||||
archive_file = hf_bucket_url(
|
||||
pretrained_model_name_or_path,
|
||||
filename=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME),
|
||||
filename=filename,
|
||||
revision=revision,
|
||||
mirror=mirror,
|
||||
)
|
||||
@ -1571,15 +1584,65 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
except RepositoryNotFoundError as err:
|
||||
logger.error(err)
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
||||
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
|
||||
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
|
||||
"login` and pass `use_auth_token=True`."
|
||||
)
|
||||
except RevisionNotFoundError as err:
|
||||
logger.error(err)
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
|
||||
"this model name. Check the model page at "
|
||||
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
||||
)
|
||||
except EntryNotFoundError as err:
|
||||
logger.error(err)
|
||||
if filename == TF2_WEIGHTS_NAME:
|
||||
has_file_kwargs = {
|
||||
"revision": revision,
|
||||
"mirror": mirror,
|
||||
"proxies": proxies,
|
||||
"use_auth_token": use_auth_token,
|
||||
}
|
||||
if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {TF2_WEIGHTS_NAME} "
|
||||
"but there is a file for PyTorch weights. Use `from_pt=True` to load this model from "
|
||||
"those weights."
|
||||
)
|
||||
else:
|
||||
logger.error(err)
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {TF2_WEIGHTS_NAME} "
|
||||
f"or {WEIGHTS_NAME}."
|
||||
)
|
||||
else:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
|
||||
)
|
||||
except HTTPError as err:
|
||||
logger.error(err)
|
||||
raise EnvironmentError(
|
||||
"We couldn't connect to 'https://huggingface.co/' to load this model and it looks like "
|
||||
f"{pretrained_model_name_or_path} is not the path to a directory conaining a a file named "
|
||||
f"{TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}.\n"
|
||||
"Checkout your internet connection or see how to run the library in offline mode at "
|
||||
"'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
||||
)
|
||||
except EnvironmentError as err:
|
||||
logger.error(err)
|
||||
msg = (
|
||||
f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
|
||||
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n"
|
||||
f" (make sure '{pretrained_model_name_or_path}' is not a path to a local directory with something else, in that case)\n\n"
|
||||
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {TF2_WEIGHTS_NAME}, {WEIGHTS_NAME}.\n\n"
|
||||
raise EnvironmentError(
|
||||
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||
f"containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
|
||||
)
|
||||
raise EnvironmentError(msg)
|
||||
|
||||
if resolved_archive_file == archive_file:
|
||||
logger.info(f"loading weights file {archive_file}")
|
||||
else:
|
||||
|
@ -27,6 +27,8 @@ from packaging import version
|
||||
from torch import Tensor, device, nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from requests import HTTPError
|
||||
|
||||
from .activations import get_activation
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
|
||||
@ -36,10 +38,14 @@ from .file_utils import (
|
||||
TF2_WEIGHTS_NAME,
|
||||
TF_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
EntryNotFoundError,
|
||||
ModelOutput,
|
||||
PushToHubMixin,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
cached_path,
|
||||
copy_func,
|
||||
has_file,
|
||||
hf_bucket_url,
|
||||
is_offline_mode,
|
||||
is_remote_url,
|
||||
@ -1292,10 +1298,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
||||
# Load from a PyTorch checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||
# At this stage we don't have a weight file so we will raise an error.
|
||||
elif os.path.isfile(
|
||||
os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
|
||||
) or os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
|
||||
raise EnvironmentError(
|
||||
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} but "
|
||||
"there is a file for TensorFlow weights. Use `from_tf=True` to load this model from those "
|
||||
"weights."
|
||||
)
|
||||
elif os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME):
|
||||
raise EnvironmentError(
|
||||
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} but "
|
||||
"there is a file for Flax weights. Use `from_flax=True` to load this model from those "
|
||||
"weights."
|
||||
)
|
||||
else:
|
||||
raise EnvironmentError(
|
||||
f"Error no file named {[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + '.index', FLAX_WEIGHTS_NAME]} found in "
|
||||
f"directory {pretrained_model_name_or_path} or `from_tf` and `from_flax` set to False."
|
||||
f"Error no file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or "
|
||||
f"{FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
|
||||
)
|
||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
archive_file = pretrained_model_name_or_path
|
||||
@ -1334,20 +1355,72 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
except RepositoryNotFoundError as err:
|
||||
logger.error(err)
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
||||
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
|
||||
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
|
||||
"login` and pass `use_auth_token=True`."
|
||||
)
|
||||
except RevisionNotFoundError as err:
|
||||
logger.error(err)
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
|
||||
"this model name. Check the model page at "
|
||||
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
||||
)
|
||||
except EntryNotFoundError as err:
|
||||
logger.error(err)
|
||||
if filename == WEIGHTS_NAME:
|
||||
has_file_kwargs = {
|
||||
"revision": revision,
|
||||
"mirror": mirror,
|
||||
"proxies": proxies,
|
||||
"use_auth_token": use_auth_token,
|
||||
}
|
||||
if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs):
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME} but "
|
||||
"there is a file for TensorFlow weights. Use `from_tf=True` to load this model from those "
|
||||
"weights."
|
||||
)
|
||||
elif has_file(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **has_file_kwargs):
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME} but "
|
||||
"there is a file for Flax weights. Use `from_flax=True` to load this model from those "
|
||||
"weights."
|
||||
)
|
||||
else:
|
||||
logger.error(err)
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}, "
|
||||
f"{TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
|
||||
)
|
||||
else:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
|
||||
)
|
||||
except HTTPError as err:
|
||||
logger.error(err)
|
||||
raise EnvironmentError(
|
||||
"We couldn't connect to 'https://huggingface.co/' to load this model and it looks like "
|
||||
f"{pretrained_model_name_or_path} is not the path to a directory conaining a a file named "
|
||||
f"{WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}.\n"
|
||||
"Checkout your internet connection or see how to run the library in offline mode at "
|
||||
"'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
||||
)
|
||||
except EnvironmentError as err:
|
||||
logger.error(err)
|
||||
msg = (
|
||||
f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
|
||||
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n"
|
||||
f" (make sure '{pretrained_model_name_or_path}' is not a path to a local directory with something else, in that case)\n\n"
|
||||
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}\n\n"
|
||||
raise EnvironmentError(
|
||||
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||
f"containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or "
|
||||
f"{FLAX_WEIGHTS_NAME}."
|
||||
)
|
||||
|
||||
if revision is not None:
|
||||
msg += f"- or '{revision}' is a valid git identifier (branch name, a tag name, or a commit id) that exists for this model name as listed on its model page on 'https://huggingface.co/models'\n\n"
|
||||
|
||||
raise EnvironmentError(msg)
|
||||
|
||||
if resolved_archive_file == archive_file:
|
||||
logger.info(f"loading weights file {archive_file}")
|
||||
else:
|
||||
|
@ -18,13 +18,13 @@ import importlib
|
||||
import json
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...file_utils import (
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
cached_path,
|
||||
get_list_of_files,
|
||||
hf_bucket_url,
|
||||
is_offline_mode,
|
||||
is_sentencepiece_available,
|
||||
@ -333,16 +333,6 @@ def get_tokenizer_config(
|
||||
logger.info("Offline mode: forcing local_files_only=True")
|
||||
local_files_only = True
|
||||
|
||||
# Will raise a ValueError if `pretrained_model_name_or_path` is not a valid path or model identifier
|
||||
repo_files = get_list_of_files(
|
||||
pretrained_model_name_or_path,
|
||||
revision=revision,
|
||||
use_auth_token=use_auth_token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
if TOKENIZER_CONFIG_FILE not in [Path(f).name for f in repo_files]:
|
||||
return {}
|
||||
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
config_file = os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE)
|
||||
@ -363,6 +353,21 @@ def get_tokenizer_config(
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
|
||||
except RepositoryNotFoundError as err:
|
||||
logger.error(err)
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
||||
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to "
|
||||
"pass a token having permission to this repo with `use_auth_token` or log in with "
|
||||
"`huggingface-cli login` and pass `use_auth_token=True`."
|
||||
)
|
||||
except RevisionNotFoundError as err:
|
||||
logger.error(err)
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
|
||||
"for this model name. Check the model page at "
|
||||
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
||||
)
|
||||
except EnvironmentError:
|
||||
logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
|
||||
return {}
|
||||
|
@ -31,13 +31,16 @@ from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequenc
|
||||
import numpy as np
|
||||
from packaging import version
|
||||
|
||||
import requests
|
||||
from requests import HTTPError
|
||||
|
||||
from . import __version__
|
||||
from .file_utils import (
|
||||
EntryNotFoundError,
|
||||
ExplicitEnum,
|
||||
PaddingStrategy,
|
||||
PushToHubMixin,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
TensorType,
|
||||
_is_jax,
|
||||
_is_numpy,
|
||||
@ -1704,9 +1707,28 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
else:
|
||||
raise error
|
||||
|
||||
except requests.exceptions.HTTPError as err:
|
||||
except RepositoryNotFoundError as err:
|
||||
logger.error(err)
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
||||
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to "
|
||||
"pass a token having permission to this repo with `use_auth_token` or log in with "
|
||||
"`huggingface-cli login` and pass `use_auth_token=True`."
|
||||
)
|
||||
except RevisionNotFoundError as err:
|
||||
logger.error(err)
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
|
||||
"for this model name. Check the model page at "
|
||||
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
||||
)
|
||||
except EntryNotFoundError:
|
||||
logger.debug(f"{pretrained_model_name_or_path} does not contain a file named {file_path}.")
|
||||
resolved_vocab_files[file_id] = None
|
||||
|
||||
except HTTPError as err:
|
||||
if "404 Client Error" in str(err):
|
||||
logger.debug(err)
|
||||
logger.debug(f"Connection problem to access {file_path}.")
|
||||
resolved_vocab_files[file_id] = None
|
||||
else:
|
||||
raise err
|
||||
@ -1718,18 +1740,13 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
)
|
||||
|
||||
if all(full_file_name is None for full_file_name in resolved_vocab_files.values()):
|
||||
msg = (
|
||||
f"Can't load tokenizer for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
|
||||
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n"
|
||||
f" (make sure '{pretrained_model_name_or_path}' is not a path to a local directory with something else, in that case)\n\n"
|
||||
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing relevant tokenizer files\n\n"
|
||||
raise EnvironmentError(
|
||||
f"Can't load tokenizer for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||
f"containing all relevant tokenizer files."
|
||||
)
|
||||
|
||||
if revision is not None:
|
||||
msg += f"- or '{revision}' is a valid git identifier (branch name, a tag name, or a commit id) that exists for this model name as listed on its model page on 'https://huggingface.co/models'\n\n"
|
||||
|
||||
raise EnvironmentError(msg)
|
||||
|
||||
for file_id, file_path in vocab_files.items():
|
||||
if file_id not in resolved_vocab_files:
|
||||
continue
|
||||
@ -3504,9 +3521,13 @@ def get_fast_tokenizer_file(
|
||||
`str`: The tokenizer file to use.
|
||||
"""
|
||||
# Inspect all files from the repo/folder.
|
||||
all_files = get_list_of_files(
|
||||
path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only
|
||||
)
|
||||
try:
|
||||
all_files = get_list_of_files(
|
||||
path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only
|
||||
)
|
||||
except Exception:
|
||||
return FULL_TOKENIZER_FILE
|
||||
|
||||
tokenizer_files_map = {}
|
||||
for file_name in all_files:
|
||||
search = _re_tokenizer_file.search(file_name)
|
||||
|
@ -83,3 +83,22 @@ class AutoConfigTest(unittest.TestCase):
|
||||
finally:
|
||||
if "new-model" in CONFIG_MAPPING._extra_content:
|
||||
del CONFIG_MAPPING._extra_content["new-model"]
|
||||
|
||||
def test_repo_not_found(self):
|
||||
with self.assertRaisesRegex(
|
||||
EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
|
||||
):
|
||||
_ = AutoConfig.from_pretrained("bert-base")
|
||||
|
||||
def test_revision_not_found(self):
|
||||
with self.assertRaisesRegex(
|
||||
EnvironmentError, r"aaaaaa is not a valid git identifier \(branch name, tag name or commit id\)"
|
||||
):
|
||||
_ = AutoConfig.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
|
||||
|
||||
def test_configuration_not_found(self):
|
||||
with self.assertRaisesRegex(
|
||||
EnvironmentError,
|
||||
"hf-internal-testing/no-config-test-repo does not appear to have a file named config.json.",
|
||||
):
|
||||
_ = AutoConfig.from_pretrained("hf-internal-testing/no-config-test-repo")
|
||||
|
@ -19,6 +19,7 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import AutoFeatureExtractor, Wav2Vec2Config, Wav2Vec2FeatureExtractor
|
||||
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
|
||||
|
||||
|
||||
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
|
||||
@ -62,3 +63,22 @@ class AutoFeatureExtractorTest(unittest.TestCase):
|
||||
def test_feature_extractor_from_local_file(self):
|
||||
config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG)
|
||||
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
|
||||
|
||||
def test_repo_not_found(self):
|
||||
with self.assertRaisesRegex(
|
||||
EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
|
||||
):
|
||||
_ = AutoFeatureExtractor.from_pretrained("bert-base")
|
||||
|
||||
def test_revision_not_found(self):
|
||||
with self.assertRaisesRegex(
|
||||
EnvironmentError, r"aaaaaa is not a valid git identifier \(branch name, tag name or commit id\)"
|
||||
):
|
||||
_ = AutoFeatureExtractor.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
|
||||
|
||||
def test_feature_extractor_not_found(self):
|
||||
with self.assertRaisesRegex(
|
||||
EnvironmentError,
|
||||
"hf-internal-testing/config-no-model does not appear to have a file named preprocessor_config.json.",
|
||||
):
|
||||
_ = AutoFeatureExtractor.from_pretrained("hf-internal-testing/config-no-model")
|
||||
|
@ -17,17 +17,22 @@ import importlib
|
||||
import io
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
import transformers
|
||||
|
||||
# Try to import everything from transformers to ensure every object can be loaded.
|
||||
from transformers import * # noqa F406
|
||||
from transformers.file_utils import (
|
||||
CONFIG_NAME,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
TF2_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
ContextManagers,
|
||||
EntryNotFoundError,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
filename_to_url,
|
||||
get_from_cache,
|
||||
has_file,
|
||||
hf_bucket_url,
|
||||
)
|
||||
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
|
||||
@ -83,13 +88,19 @@ class GetFromCacheTests(unittest.TestCase):
|
||||
def test_file_not_found(self):
|
||||
# Valid revision (None) but missing file.
|
||||
url = hf_bucket_url(MODEL_ID, filename="missing.bin")
|
||||
with self.assertRaisesRegex(requests.exceptions.HTTPError, "404 Client Error"):
|
||||
with self.assertRaisesRegex(EntryNotFoundError, "404 Client Error"):
|
||||
_ = get_from_cache(url)
|
||||
|
||||
def test_model_not_found(self):
|
||||
# Invalid model file.
|
||||
url = hf_bucket_url("bert-base", filename="pytorch_model.bin")
|
||||
with self.assertRaisesRegex(RepositoryNotFoundError, "404 Client Error"):
|
||||
_ = get_from_cache(url)
|
||||
|
||||
def test_revision_not_found(self):
|
||||
# Valid file but missing revision
|
||||
url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_INVALID)
|
||||
with self.assertRaisesRegex(requests.exceptions.HTTPError, "404 Client Error"):
|
||||
with self.assertRaisesRegex(RevisionNotFoundError, "404 Client Error"):
|
||||
_ = get_from_cache(url)
|
||||
|
||||
def test_standard_object(self):
|
||||
@ -112,6 +123,11 @@ class GetFromCacheTests(unittest.TestCase):
|
||||
metadata = filename_to_url(filepath)
|
||||
self.assertEqual(metadata, (url, f'"{PINNED_SHA256}"'))
|
||||
|
||||
def test_has_file(self):
|
||||
self.assertTrue(has_file("hf-internal-testing/tiny-bert-pt-only", WEIGHTS_NAME))
|
||||
self.assertFalse(has_file("hf-internal-testing/tiny-bert-pt-only", TF2_WEIGHTS_NAME))
|
||||
self.assertFalse(has_file("hf-internal-testing/tiny-bert-pt-only", FLAX_WEIGHTS_NAME))
|
||||
|
||||
|
||||
class ContextManagerTests(unittest.TestCase):
|
||||
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
|
||||
|
@ -389,3 +389,30 @@ class AutoModelTest(unittest.TestCase):
|
||||
):
|
||||
if NewModelConfig in mapping._extra_content:
|
||||
del mapping._extra_content[NewModelConfig]
|
||||
|
||||
def test_repo_not_found(self):
|
||||
with self.assertRaisesRegex(
|
||||
EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
|
||||
):
|
||||
_ = AutoModel.from_pretrained("bert-base")
|
||||
|
||||
def test_revision_not_found(self):
|
||||
with self.assertRaisesRegex(
|
||||
EnvironmentError, r"aaaaaa is not a valid git identifier \(branch name, tag name or commit id\)"
|
||||
):
|
||||
_ = AutoModel.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
|
||||
|
||||
def test_model_file_not_found(self):
|
||||
with self.assertRaisesRegex(
|
||||
EnvironmentError,
|
||||
"hf-internal-testing/config-no-model does not appear to have a file named pytorch_model.bin",
|
||||
):
|
||||
_ = AutoModel.from_pretrained("hf-internal-testing/config-no-model")
|
||||
|
||||
def test_model_from_tf_suggestion(self):
|
||||
with self.assertRaisesRegex(EnvironmentError, "Use `from_tf=True` to load this model"):
|
||||
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-bert-tf-only")
|
||||
|
||||
def test_model_from_flax_suggestion(self):
|
||||
with self.assertRaisesRegex(EnvironmentError, "Use `from_flax=True` to load this model"):
|
||||
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||
|
@ -15,7 +15,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers import AutoConfig, AutoTokenizer, BertConfig, TensorType, is_flax_available
|
||||
from transformers.testing_utils import require_flax, slow
|
||||
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER, require_flax, slow
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
@ -76,3 +76,26 @@ class FlaxAutoModelTest(unittest.TestCase):
|
||||
return model(**kwargs)
|
||||
|
||||
eval(**tokens).block_until_ready()
|
||||
|
||||
def test_repo_not_found(self):
|
||||
with self.assertRaisesRegex(
|
||||
EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
|
||||
):
|
||||
_ = FlaxAutoModel.from_pretrained("bert-base")
|
||||
|
||||
def test_revision_not_found(self):
|
||||
with self.assertRaisesRegex(
|
||||
EnvironmentError, r"aaaaaa is not a valid git identifier \(branch name, tag name or commit id\)"
|
||||
):
|
||||
_ = FlaxAutoModel.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
|
||||
|
||||
def test_model_file_not_found(self):
|
||||
with self.assertRaisesRegex(
|
||||
EnvironmentError,
|
||||
"hf-internal-testing/config-no-model does not appear to have a file named flax_model.msgpack",
|
||||
):
|
||||
_ = FlaxAutoModel.from_pretrained("hf-internal-testing/config-no-model")
|
||||
|
||||
def test_model_from_pt_suggestion(self):
|
||||
with self.assertRaisesRegex(EnvironmentError, "Use `from_pt=True` to load this model"):
|
||||
_ = FlaxAutoModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
|
@ -309,3 +309,26 @@ class TFAutoModelTest(unittest.TestCase):
|
||||
):
|
||||
if NewModelConfig in mapping._extra_content:
|
||||
del mapping._extra_content[NewModelConfig]
|
||||
|
||||
def test_repo_not_found(self):
|
||||
with self.assertRaisesRegex(
|
||||
EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
|
||||
):
|
||||
_ = TFAutoModel.from_pretrained("bert-base")
|
||||
|
||||
def test_revision_not_found(self):
|
||||
with self.assertRaisesRegex(
|
||||
EnvironmentError, r"aaaaaa is not a valid git identifier \(branch name, tag name or commit id\)"
|
||||
):
|
||||
_ = TFAutoModel.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
|
||||
|
||||
def test_model_file_not_found(self):
|
||||
with self.assertRaisesRegex(
|
||||
EnvironmentError,
|
||||
"hf-internal-testing/config-no-model does not appear to have a file named tf_model.h5",
|
||||
):
|
||||
_ = TFAutoModel.from_pretrained("hf-internal-testing/config-no-model")
|
||||
|
||||
def test_model_from_pt_suggestion(self):
|
||||
with self.assertRaisesRegex(EnvironmentError, "Use `from_pt=True` to load this model"):
|
||||
_ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
|
||||
|
@ -150,7 +150,8 @@ class AutoTokenizerTest(unittest.TestCase):
|
||||
def test_tokenizer_identifier_non_existent(self):
|
||||
for tokenizer_class in [BertTokenizer, BertTokenizerFast, AutoTokenizer]:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, ".*is not a local path or a model identifier on the model Hub. Did you make a typo?"
|
||||
EnvironmentError,
|
||||
"julien-c/herlolip-not-exists is not a local folder and is not a valid model identifier",
|
||||
):
|
||||
_ = tokenizer_class.from_pretrained("julien-c/herlolip-not-exists")
|
||||
|
||||
@ -310,3 +311,15 @@ class AutoTokenizerTest(unittest.TestCase):
|
||||
del CONFIG_MAPPING._extra_content["new-model"]
|
||||
if NewConfig in TOKENIZER_MAPPING._extra_content:
|
||||
del TOKENIZER_MAPPING._extra_content[NewConfig]
|
||||
|
||||
def test_repo_not_found(self):
|
||||
with self.assertRaisesRegex(
|
||||
EnvironmentError, "bert-base is not a local folder and is not a valid model identifier"
|
||||
):
|
||||
_ = AutoTokenizer.from_pretrained("bert-base")
|
||||
|
||||
def test_revision_not_found(self):
|
||||
with self.assertRaisesRegex(
|
||||
EnvironmentError, r"aaaaaa is not a valid git identifier \(branch name, tag name or commit id\)"
|
||||
):
|
||||
_ = AutoTokenizer.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
|
||||
|
@ -255,7 +255,7 @@ SPECIAL_MODULE_TO_TEST_MAP = {
|
||||
"modeling_tf_utils.py": ["test_modeling_tf_common.py", "test_modeling_tf_core.py"],
|
||||
"modeling_utils.py": ["test_modeling_common.py", "test_offline.py"],
|
||||
"models/auto/modeling_auto.py": ["test_modeling_auto.py", "test_modeling_tf_pytorch.py", "test_modeling_bort.py"],
|
||||
"models/auto/modeling_flax_auto.py": "test_flax_auto.py",
|
||||
"models/auto/modeling_flax_auto.py": "test_modeling_flax_auto.py",
|
||||
"models/auto/modeling_tf_auto.py": [
|
||||
"test_modeling_tf_auto.py",
|
||||
"test_modeling_tf_pytorch.py",
|
||||
|
Loading…
Reference in New Issue
Block a user