Model versioning (#8324)

* fix typo

* rm use_cdn & references, and implement new hf_bucket_url

* I'm pretty sure we don't need to `read` this file

* same here

* [BIG] file_utils.networking: do not gobble up errors anymore

* Fix CI 😇

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Tiny doc tweak

* Add doc + pass kwarg everywhere

* Add more tests and explain

cc @sshleifer let me know if better

Co-Authored-By: Sam Shleifer <sshleifer@gmail.com>

* Also implement revision in pipelines

In the case where we're passing a task name or a string model identifier

* Fix CI 😇

* Fix CI

* [hf_api] new methods + command line implem

* make style

* Final endpoints post-migration

* Fix post-migration

* Py3.6 compat

cc @stefan-it

Thank you @stas00

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
Julien Chaumond 2020-11-10 13:11:02 +01:00 committed by GitHub
parent 4185b115d4
commit 70f622fab4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 472 additions and 210 deletions

View File

@ -12,8 +12,8 @@ inference: false
## Disclaimer
Due do it's immense size, `t5-11b` requires some special treatment.
First, `t5-11b` should be loaded with flag `use_cdn` set to `False` as follows:
**Before `transformers` v3.5.0**, due do its immense size, `t5-11b` required some special treatment.
If you're using transformers `<= v3.4.0`, `t5-11b` should be loaded with flag `use_cdn` set to `False` as follows:
```python
t5 = transformers.T5ForConditionalGeneration.from_pretrained('t5-11b', use_cdn = False)

View File

@ -56,7 +56,3 @@ cd -
perl -le 'for $f (@ARGV) { print qq[transformers-cli upload -y $_/$f --filename $_/$f] for ("wmt16-en-de-dist-12-1", "wmt16-en-de-dist-6-1", "wmt16-en-de-12-1")}' vocab-src.json vocab-tgt.json tokenizer_config.json config.json
# add/remove files as needed
# Caching note: Unfortunately due to CDN caching the uploaded model may be unavailable for up to 24hs after upload
# So the only way to start using the new model sooner is either:
# 1. download it to a local path and use that path as model_name
# 2. make sure you use: from_pretrained(..., use_cdn=False) everywhere

View File

@ -44,7 +44,3 @@ cd -
perl -le 'for $f (@ARGV) { print qq[transformers-cli upload -y $_/$f --filename $_/$f] for ("wmt19-de-en-6-6-base", "wmt19-de-en-6-6-big")}' vocab-src.json vocab-tgt.json tokenizer_config.json config.json
# add/remove files as needed
# Caching note: Unfortunately due to CDN caching the uploaded model may be unavailable for up to 24hs after upload
# So the only way to start using the new model sooner is either:
# 1. download it to a local path and use that path as model_name
# 2. make sure you use: from_pretrained(..., use_cdn=False) everywhere

View File

@ -55,7 +55,3 @@ cd -
perl -le 'for $f (@ARGV) { print qq[transformers-cli upload -y $_/$f --filename $_/$f] for map { "wmt19-$_" } ("en-ru", "ru-en", "de-en", "en-de")}' vocab-src.json vocab-tgt.json tokenizer_config.json config.json
# add/remove files as needed
# Caching note: Unfortunately due to CDN caching the uploaded model may be unavailable for up to 24hs after upload
# So the only way to start using the new model sooner is either:
# 1. download it to a local path and use that path as model_name
# 2. make sure you use: from_pretrained(..., use_cdn=False) everywhere

View File

@ -1,4 +1,5 @@
import os
import subprocess
import sys
from argparse import ArgumentParser
from getpass import getpass
@ -21,8 +22,10 @@ class UserCommands(BaseTransformersCLICommand):
whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args))
logout_parser = parser.add_parser("logout", help="Log out")
logout_parser.set_defaults(func=lambda args: LogoutCommand(args))
# s3
s3_parser = parser.add_parser("s3", help="{ls, rm} Commands to interact with the files you upload on S3.")
# s3_datasets (s3-based system)
s3_parser = parser.add_parser(
"s3_datasets", help="{ls, rm} Commands to interact with the files you upload on S3."
)
s3_subparsers = s3_parser.add_subparsers(help="s3 related commands")
ls_parser = s3_subparsers.add_parser("ls")
ls_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
@ -31,17 +34,42 @@ class UserCommands(BaseTransformersCLICommand):
rm_parser.add_argument("filename", type=str, help="individual object filename to delete from S3.")
rm_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
rm_parser.set_defaults(func=lambda args: DeleteObjCommand(args))
# upload
upload_parser = parser.add_parser("upload", help="Upload a model to S3.")
upload_parser.add_argument(
"path", type=str, help="Local path of the model folder or individual file to upload."
)
upload_parser = s3_subparsers.add_parser("upload", help="Upload a file to S3.")
upload_parser.add_argument("path", type=str, help="Local path of the folder or individual file to upload.")
upload_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
upload_parser.add_argument(
"--filename", type=str, default=None, help="Optional: override individual object filename on S3."
)
upload_parser.add_argument("-y", "--yes", action="store_true", help="Optional: answer Yes to the prompt")
upload_parser.set_defaults(func=lambda args: UploadCommand(args))
# deprecated model upload
upload_parser = parser.add_parser(
"upload",
help=(
"Deprecated: used to be the way to upload a model to S3."
" We now use a git-based system for storing models and other artifacts."
" Use the `repo create` command instead."
),
)
upload_parser.set_defaults(func=lambda args: DeprecatedUploadCommand(args))
# new system: git-based repo system
repo_parser = parser.add_parser(
"repo", help="{create, ls-files} Commands to interact with your huggingface.co repos."
)
repo_subparsers = repo_parser.add_subparsers(help="huggingface.co repos related commands")
ls_parser = repo_subparsers.add_parser("ls-files", help="List all your files on huggingface.co")
ls_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
ls_parser.set_defaults(func=lambda args: ListReposObjsCommand(args))
repo_create_parser = repo_subparsers.add_parser("create", help="Create a new repo on huggingface.co")
repo_create_parser.add_argument(
"name",
type=str,
help="Name for your model's repo. Will be namespaced under your username to build the model id.",
)
repo_create_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
repo_create_parser.add_argument("-y", "--yes", action="store_true", help="Optional: answer Yes to the prompt")
repo_create_parser.set_defaults(func=lambda args: RepoCreateCommand(args))
class ANSI:
@ -51,6 +79,7 @@ class ANSI:
_bold = "\u001b[1m"
_red = "\u001b[31m"
_gray = "\u001b[90m"
_reset = "\u001b[0m"
@classmethod
@ -61,6 +90,27 @@ class ANSI:
def red(cls, s):
return "{}{}{}".format(cls._bold + cls._red, s, cls._reset)
@classmethod
def gray(cls, s):
return "{}{}{}".format(cls._gray, s, cls._reset)
def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
"""
Inspired by:
- stackoverflow.com/a/8356620/593036
- stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
"""
col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
lines = []
lines.append(row_format.format(*headers))
lines.append(row_format.format(*["-" * w for w in col_widths]))
for row in rows:
lines.append(row_format.format(*row))
return "\n".join(lines)
class BaseUserCommand:
def __init__(self, args):
@ -124,22 +174,6 @@ class LogoutCommand(BaseUserCommand):
class ListObjsCommand(BaseUserCommand):
def tabulate(self, rows: List[List[Union[str, int]]], headers: List[str]) -> str:
"""
Inspired by:
- stackoverflow.com/a/8356620/593036
- stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
"""
col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
lines = []
lines.append(row_format.format(*headers))
lines.append(row_format.format(*["-" * w for w in col_widths]))
for row in rows:
lines.append(row_format.format(*row))
return "\n".join(lines)
def run(self):
token = HfFolder.get_token()
if token is None:
@ -155,7 +189,7 @@ class ListObjsCommand(BaseUserCommand):
print("No shared file yet")
exit()
rows = [[obj.filename, obj.LastModified, obj.ETag, obj.Size] for obj in objs]
print(self.tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"]))
print(tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"]))
class DeleteObjCommand(BaseUserCommand):
@ -173,6 +207,85 @@ class DeleteObjCommand(BaseUserCommand):
print("Done")
class ListReposObjsCommand(BaseUserCommand):
def run(self):
token = HfFolder.get_token()
if token is None:
print("Not logged in")
exit(1)
try:
objs = self._api.list_repos_objs(token, organization=self.args.organization)
except HTTPError as e:
print(e)
print(ANSI.red(e.response.text))
exit(1)
if len(objs) == 0:
print("No shared file yet")
exit()
rows = [[obj.filename, obj.lastModified, obj.commit, obj.size] for obj in objs]
print(tabulate(rows, headers=["Filename", "LastModified", "Commit-Sha", "Size"]))
class RepoCreateCommand(BaseUserCommand):
def run(self):
token = HfFolder.get_token()
if token is None:
print("Not logged in")
exit(1)
try:
stdout = subprocess.check_output(["git", "--version"]).decode("utf-8")
print(ANSI.gray(stdout.strip()))
except FileNotFoundError:
print("Looks like you do not have git installed, please install.")
try:
stdout = subprocess.check_output(["git-lfs", "--version"]).decode("utf-8")
print(ANSI.gray(stdout.strip()))
except FileNotFoundError:
print(
ANSI.red(
"Looks like you do not have git-lfs installed, please install."
" You can install from https://git-lfs.github.com/."
" Then run `git lfs install` (you only have to do this once)."
)
)
print("")
user, _ = self._api.whoami(token)
namespace = self.args.organization if self.args.organization is not None else user
print("You are about to create {}".format(ANSI.bold(namespace + "/" + self.args.name)))
if not self.args.yes:
choice = input("Proceed? [Y/n] ").lower()
if not (choice == "" or choice == "y" or choice == "yes"):
print("Abort")
exit()
try:
url = self._api.create_repo(token, name=self.args.name, organization=self.args.organization)
except HTTPError as e:
print(e)
print(ANSI.red(e.response.text))
exit(1)
print("\nYour repo now lives at:")
print(" {}".format(ANSI.bold(url)))
print("\nYou can clone it locally with the command below," " and commit/push as usual.")
print(f"\n git clone {url}")
print("")
class DeprecatedUploadCommand(BaseUserCommand):
def run(self):
print(
ANSI.red(
"Deprecated: used to be the way to upload a model to S3."
" We now use a git-based system for storing models and other artifacts."
" Use the `repo create` command instead."
)
)
exit(1)
class UploadCommand(BaseUserCommand):
def walk_dir(self, rel_path):
"""

View File

@ -289,6 +289,10 @@ class AutoConfig:
proxies (:obj:`Dict[str, str]`, `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`False`, then this function returns just the final configuration object.

View File

@ -311,6 +311,10 @@ class PretrainedConfig(object):
proxies (:obj:`Dict[str, str]`, `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`False`, then this function returns just the final configuration object.
@ -362,6 +366,7 @@ class PretrainedConfig(object):
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
if os.path.isdir(pretrained_model_name_or_path):
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
@ -369,7 +374,7 @@ class PretrainedConfig(object):
config_file = pretrained_model_name_or_path
else:
config_file = hf_bucket_url(
pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False, mirror=None
pretrained_model_name_or_path, filename=CONFIG_NAME, revision=revision, mirror=None
)
try:
@ -383,11 +388,10 @@ class PretrainedConfig(object):
local_files_only=local_files_only,
)
# Load config dict
if resolved_config_file is None:
raise EnvironmentError
config_dict = cls._dict_from_json_file(resolved_config_file)
except EnvironmentError:
except EnvironmentError as err:
logger.error(err)
msg = (
f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"

View File

@ -4,6 +4,7 @@ https://github.com/allenai/allennlp Copyright by the AllenNLP authors.
"""
import fnmatch
import io
import json
import os
import re
@ -17,7 +18,7 @@ from dataclasses import fields
from functools import partial, wraps
from hashlib import sha256
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, BinaryIO, Dict, Optional, Tuple, Union
from urllib.parse import urlparse
from zipfile import ZipFile, is_zipfile
@ -217,6 +218,8 @@ DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
HUGGINGFACE_CO_PREFIX = "https://huggingface.co/{model_id}/resolve/{revision}/{filename}"
PRESET_MIRROR_DICT = {
"tuna": "https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models",
"bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models",
@ -825,34 +828,37 @@ def is_remote_url(url_or_filename):
return parsed.scheme in ("http", "https")
def hf_bucket_url(model_id: str, filename: str, use_cdn=True, mirror=None) -> str:
def hf_bucket_url(model_id: str, filename: str, revision: Optional[str] = None, mirror=None) -> str:
"""
Resolve a model identifier, and a file name, to a HF-hosted url on either S3 or Cloudfront (a Content Delivery
Network, or CDN).
Resolve a model identifier, a file name, and an optional revision id, to a huggingface.co-hosted url, redirecting
to Cloudfront (a Content Delivery Network, or CDN) for large files.
Cloudfront is replicated over the globe so downloads are way faster for the end user (and it also lowers our
bandwidth costs). However, it is more aggressively cached by default, so may not always reflect the latest changes
to the underlying file (default TTL is 24 hours).
bandwidth costs).
In terms of client-side caching from this library, even though Cloudfront relays the ETags from S3, using one or
the other (or switching from one to the other) will affect caching: cached files are not shared between the two
because the cached file's name contains a hash of the url.
Cloudfront aggressively caches files by default (default TTL is 24 hours), however this is not an issue here
because we migrated to a git-based versioning system on huggingface.co, so we now store the files on S3/Cloudfront
in a content-addressable way (i.e., the file name is its hash). Using content-addressable filenames means cache
can't ever be stale.
In terms of client-side caching from this library, we base our caching on the objects' ETag. An object' ETag is:
its sha1 if stored in git, or its sha256 if stored in git-lfs. Files cached locally from transformers before v3.5.0
are not shared with those new files, because the cached file's name contains a hash of the url (which changed).
"""
endpoint = (
PRESET_MIRROR_DICT.get(mirror, mirror)
if mirror
else CLOUDFRONT_DISTRIB_PREFIX
if use_cdn
else S3_BUCKET_PREFIX
)
legacy_format = "/" not in model_id
if legacy_format:
return f"{endpoint}/{model_id}-{filename}"
else:
return f"{endpoint}/{model_id}/{filename}"
if mirror:
endpoint = PRESET_MIRROR_DICT.get(mirror, mirror)
legacy_format = "/" not in model_id
if legacy_format:
return f"{endpoint}/{model_id}-{filename}"
else:
return f"{endpoint}/{model_id}/{filename}"
if revision is None:
revision = "main"
return HUGGINGFACE_CO_PREFIX.format(model_id=model_id, revision=revision, filename=filename)
def url_to_filename(url, etag=None):
def url_to_filename(url: str, etag: Optional[str] = None) -> str:
"""
Convert `url` into a hashed filename in a repeatable way. If `etag` is specified, append its hash to the url's,
delimited by a period. If the url ends with .h5 (Keras HDF5 weights) adds '.h5' to the name so that TF 2.0 can
@ -860,13 +866,11 @@ def url_to_filename(url, etag=None):
https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
"""
url_bytes = url.encode("utf-8")
url_hash = sha256(url_bytes)
filename = url_hash.hexdigest()
filename = sha256(url_bytes).hexdigest()
if etag:
etag_bytes = etag.encode("utf-8")
etag_hash = sha256(etag_bytes)
filename += "." + etag_hash.hexdigest()
filename += "." + sha256(etag_bytes).hexdigest()
if url.endswith(".h5"):
filename += ".h5"
@ -927,8 +931,10 @@ def cached_path(
re-extract the archive and override the folder where it was extracted.
Return:
None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk). Local path (string)
otherwise
Local path (string) of file or if networking is off, last version of file cached on disk.
Raises:
In case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
"""
if cache_dir is None:
cache_dir = TRANSFORMERS_CACHE
@ -992,7 +998,10 @@ def cached_path(
return output_path
def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict, str, None] = None):
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
"""
Formats a user-agent string with basic info about a request.
"""
ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
if is_torch_available():
ua += "; torch/{}".format(torch.__version__)
@ -1002,13 +1011,19 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict
ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
elif isinstance(user_agent, str):
ua += "; " + user_agent
headers = {"user-agent": ua}
return ua
def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, user_agent: Union[Dict, str, None] = None):
"""
Donwload remote file. Do not gobble up errors.
"""
headers = {"user-agent": http_user_agent(user_agent)}
if resume_size > 0:
headers["Range"] = "bytes=%d-" % (resume_size,)
response = requests.get(url, stream=True, proxies=proxies, headers=headers)
if response.status_code == 416: # Range not satisfiable
return
content_length = response.headers.get("Content-Length")
r = requests.get(url, stream=True, proxies=proxies, headers=headers)
r.raise_for_status()
content_length = r.headers.get("Content-Length")
total = resume_size + int(content_length) if content_length is not None else None
progress = tqdm(
unit="B",
@ -1018,7 +1033,7 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict
desc="Downloading",
disable=bool(logging.get_verbosity() == logging.NOTSET),
)
for chunk in response.iter_content(chunk_size=1024):
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
temp_file.write(chunk)
@ -1026,7 +1041,7 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict
def get_from_cache(
url,
url: str,
cache_dir=None,
force_download=False,
proxies=None,
@ -1040,8 +1055,10 @@ def get_from_cache(
path to the cached file.
Return:
None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk). Local path (string)
otherwise
Local path (string) of file or if networking is off, last version of file cached on disk.
Raises:
In case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
"""
if cache_dir is None:
cache_dir = TRANSFORMERS_CACHE
@ -1050,13 +1067,28 @@ def get_from_cache(
os.makedirs(cache_dir, exist_ok=True)
url_to_download = url
etag = None
if not local_files_only:
try:
response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
if response.status_code == 200:
etag = response.headers.get("ETag")
except (EnvironmentError, requests.exceptions.Timeout):
headers = {"user-agent": http_user_agent(user_agent)}
r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
r.raise_for_status()
etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
# We favor a custom header indicating the etag of the linked resource, and
# we fallback to the regular etag header.
# If we don't have any of those, raise an error.
if etag is None:
raise OSError(
"Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
)
# In case of a redirect,
# save an extra redirect on the request.get call,
# and ensure we download the exact atomic version even if it changed
# between the HEAD and the GET (unlikely, but hey).
if 300 <= r.status_code <= 399:
url_to_download = r.headers["Location"]
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
# etag is already None
pass
@ -1065,7 +1097,7 @@ def get_from_cache(
# get cache path to put the file
cache_path = os.path.join(cache_dir, filename)
# etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
# etag is None == we don't have a connection or we passed local_files_only.
# try to get the last downloaded one
if etag is None:
if os.path.exists(cache_path):
@ -1088,7 +1120,11 @@ def get_from_cache(
" disabled. To enable model look-ups and downloads online, set 'local_files_only'"
" to False."
)
return None
else:
raise ValueError(
"Connection error, and we cannot find the requested files in the cached path."
" Please try again or make sure your Internet connection is on."
)
# From now on, etag is not None.
if os.path.exists(cache_path) and not force_download:
@ -1107,8 +1143,8 @@ def get_from_cache(
incomplete_path = cache_path + ".incomplete"
@contextmanager
def _resumable_file_manager():
with open(incomplete_path, "a+b") as f:
def _resumable_file_manager() -> "io.BufferedWriter":
with open(incomplete_path, "ab") as f:
yield f
temp_file_manager = _resumable_file_manager
@ -1117,7 +1153,7 @@ def get_from_cache(
else:
resume_size = 0
else:
temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
temp_file_manager = partial(tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False)
resume_size = 0
# Download to temporary file, then copy to cache dir once finished.
@ -1125,7 +1161,7 @@ def get_from_cache(
with temp_file_manager() as temp_file:
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
logger.info("storing %s in cache at %s", url, cache_path)
os.replace(temp_file.name, cache_path)

View File

@ -27,9 +27,21 @@ import requests
ENDPOINT = "https://huggingface.co"
class RepoObj:
"""
HuggingFace git-based system, data structure that represents a file belonging to the current user.
"""
def __init__(self, filename: str, lastModified: str, commit: str, size: int, **kwargs):
self.filename = filename
self.lastModified = lastModified
self.commit = commit
self.size = size
class S3Obj:
"""
Data structure that represents a file belonging to the current user.
HuggingFace S3-based system, data structure that represents a file belonging to the current user.
"""
def __init__(self, filename: str, LastModified: str, ETag: str, Size: int, **kwargs):
@ -46,38 +58,25 @@ class PresignedUrl:
self.type = type # mime-type to send to S3.
class S3Object:
class ModelSibling:
"""
Data structure that represents a public file accessible on our S3.
Data structure that represents a public file inside a model, accessible from huggingface.co
"""
def __init__(
self,
key: str, # S3 object key
etag: str,
lastModified: str,
size: int,
rfilename: str, # filename relative to config.json
**kwargs
):
self.key = key
self.etag = etag
self.lastModified = lastModified
self.size = size
self.rfilename = rfilename
def __init__(self, rfilename: str, **kwargs):
self.rfilename = rfilename # filename relative to the model root
for k, v in kwargs.items():
setattr(self, k, v)
class ModelInfo:
"""
Info about a public model accessible from our S3.
Info about a public model accessible from huggingface.co
"""
def __init__(
self,
modelId: str, # id of model
key: str, # S3 object key of config.json
modelId: Optional[str] = None, # id of model
author: Optional[str] = None,
downloads: Optional[int] = None,
tags: List[str] = [],
@ -86,12 +85,11 @@ class ModelInfo:
**kwargs
):
self.modelId = modelId
self.key = key
self.author = author
self.downloads = downloads
self.tags = tags
self.pipeline_tag = pipeline_tag
self.siblings = [S3Object(**x) for x in siblings] if siblings is not None else None
self.siblings = [ModelSibling(**x) for x in siblings] if siblings is not None else None
for k, v in kwargs.items():
setattr(self, k, v)
@ -134,9 +132,11 @@ class HfApi:
def presign(self, token: str, filename: str, organization: Optional[str] = None) -> PresignedUrl:
"""
HuggingFace S3-based system, used for datasets and metrics.
Call HF API to get a presigned url to upload `filename` to S3.
"""
path = "{}/api/presign".format(self.endpoint)
path = "{}/api/datasets/presign".format(self.endpoint)
r = requests.post(
path,
headers={"authorization": "Bearer {}".format(token)},
@ -148,6 +148,8 @@ class HfApi:
def presign_and_upload(self, token: str, filename: str, filepath: str, organization: Optional[str] = None) -> str:
"""
HuggingFace S3-based system, used for datasets and metrics.
Get a presigned url, then upload file to S3.
Outputs: url: Read-only url for the stored file on S3.
@ -169,9 +171,11 @@ class HfApi:
def list_objs(self, token: str, organization: Optional[str] = None) -> List[S3Obj]:
"""
HuggingFace S3-based system, used for datasets and metrics.
Call HF API to list all stored files for user (or one of their organizations).
"""
path = "{}/api/listObjs".format(self.endpoint)
path = "{}/api/datasets/listObjs".format(self.endpoint)
params = {"organization": organization} if organization is not None else None
r = requests.get(path, params=params, headers={"authorization": "Bearer {}".format(token)})
r.raise_for_status()
@ -180,9 +184,11 @@ class HfApi:
def delete_obj(self, token: str, filename: str, organization: Optional[str] = None):
"""
HuggingFace S3-based system, used for datasets and metrics.
Call HF API to delete a file stored by user
"""
path = "{}/api/deleteObj".format(self.endpoint)
path = "{}/api/datasets/deleteObj".format(self.endpoint)
r = requests.delete(
path,
headers={"authorization": "Bearer {}".format(token)},
@ -200,6 +206,51 @@ class HfApi:
d = r.json()
return [ModelInfo(**x) for x in d]
def list_repos_objs(self, token: str, organization: Optional[str] = None) -> List[S3Obj]:
"""
HuggingFace git-based system, used for models.
Call HF API to list all stored files for user (or one of their organizations).
"""
path = "{}/api/repos/ls".format(self.endpoint)
params = {"organization": organization} if organization is not None else None
r = requests.get(path, params=params, headers={"authorization": "Bearer {}".format(token)})
r.raise_for_status()
d = r.json()
return [RepoObj(**x) for x in d]
def create_repo(self, token: str, name: str, organization: Optional[str] = None) -> str:
"""
HuggingFace git-based system, used for models.
Call HF API to create a whole repo.
"""
path = "{}/api/repos/create".format(self.endpoint)
r = requests.post(
path,
headers={"authorization": "Bearer {}".format(token)},
json={"name": name, "organization": organization},
)
r.raise_for_status()
d = r.json()
return d["url"]
def delete_repo(self, token: str, name: str, organization: Optional[str] = None):
"""
HuggingFace git-based system, used for models.
Call HF API to delete a whole repo.
CAUTION(this is irreversible).
"""
path = "{}/api/repos/delete".format(self.endpoint)
r = requests.delete(
path,
headers={"authorization": "Bearer {}".format(token)},
json={"name": name, "organization": organization},
)
r.raise_for_status()
class TqdmProgressFileReader:
"""

View File

@ -144,9 +144,7 @@ class ModelCard:
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
model_card_file = pretrained_model_name_or_path
else:
model_card_file = hf_bucket_url(
pretrained_model_name_or_path, filename=MODEL_CARD_NAME, use_cdn=False, mirror=None
)
model_card_file = hf_bucket_url(pretrained_model_name_or_path, filename=MODEL_CARD_NAME, mirror=None)
if find_from_standard_name or pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
model_card_file = model_card_file.replace(CONFIG_NAME, MODEL_CARD_NAME)
@ -156,8 +154,6 @@ class ModelCard:
try:
# Load from URL or cache if already cached
resolved_model_card_file = cached_path(model_card_file, cache_dir=cache_dir, proxies=proxies)
if resolved_model_card_file is None:
raise EnvironmentError
if resolved_model_card_file == model_card_file:
logger.info("loading model card file {}".format(model_card_file))
else:

View File

@ -537,9 +537,10 @@ AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to only look at local files (e.g., not try downloading the model).
use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
kwargs (additional keyword arguments, `optional`):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or

View File

@ -107,7 +107,7 @@ class FlaxPreTrainedModel(ABC):
proxies = kwargs.pop("proxies", None)
# output_loading_info = kwargs.pop("output_loading_info", False)
local_files_only = kwargs.pop("local_files_only", False)
use_cdn = kwargs.pop("use_cdn", True)
revision = kwargs.pop("revision", None)
# Load config if we don't provide a configuration
if not isinstance(config, PretrainedConfig):
@ -121,6 +121,7 @@ class FlaxPreTrainedModel(ABC):
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
revision=revision,
**kwargs,
)
else:
@ -131,7 +132,7 @@ class FlaxPreTrainedModel(ABC):
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
else:
archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, use_cdn=use_cdn)
archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision)
# redirect to the cache, if necessary
try:
@ -143,16 +144,13 @@ class FlaxPreTrainedModel(ABC):
resume_download=resume_download,
local_files_only=local_files_only,
)
except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
msg = f"Couldn't reach server at '{archive_file}' to download pretrained weights."
else:
msg = (
f"Model name '{pretrained_model_name_or_path}' "
f"was not found in model name list ({', '.join(cls.pretrained_model_archive_map.keys())}). "
f"We assumed '{archive_file}' was a path or url to model weight files but "
"couldn't find any such file at this path or url."
)
except EnvironmentError as err:
logger.error(err)
msg = (
f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named {WEIGHTS_NAME}.\n\n"
)
raise EnvironmentError(msg)
if resolved_archive_file == archive_file:

View File

@ -420,9 +420,10 @@ TF_AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to only look at local files (e.g., not try downloading the model).
use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
kwargs (additional keyword arguments, `optional`):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or

View File

@ -572,9 +572,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to only look at local files (e.g., not try doanloading the model).
use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
mirror(:obj:`str`, `optional`, defaults to :obj:`None`):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
@ -616,7 +617,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
proxies = kwargs.pop("proxies", None)
output_loading_info = kwargs.pop("output_loading_info", False)
local_files_only = kwargs.pop("local_files_only", False)
use_cdn = kwargs.pop("use_cdn", True)
revision = kwargs.pop("revision", None)
mirror = kwargs.pop("mirror", None)
# Load config if we don't provide a configuration
@ -631,6 +632,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
revision=revision,
**kwargs,
)
else:
@ -659,7 +661,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
archive_file = hf_bucket_url(
pretrained_model_name_or_path,
filename=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME),
use_cdn=use_cdn,
revision=revision,
mirror=mirror,
)
@ -673,9 +675,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
resume_download=resume_download,
local_files_only=local_files_only,
)
if resolved_archive_file is None:
raise EnvironmentError
except EnvironmentError:
except EnvironmentError as err:
logger.error(err)
msg = (
f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"

View File

@ -813,9 +813,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to only look at local files (e.g., not try doanloading the model).
use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
mirror(:obj:`str`, `optional`, defaults to :obj:`None`):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
@ -857,7 +858,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
proxies = kwargs.pop("proxies", None)
output_loading_info = kwargs.pop("output_loading_info", False)
local_files_only = kwargs.pop("local_files_only", False)
use_cdn = kwargs.pop("use_cdn", True)
revision = kwargs.pop("revision", None)
mirror = kwargs.pop("mirror", None)
# Load config if we don't provide a configuration
@ -872,6 +873,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
revision=revision,
**kwargs,
)
else:
@ -909,7 +911,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
archive_file = hf_bucket_url(
pretrained_model_name_or_path,
filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
use_cdn=use_cdn,
revision=revision,
mirror=mirror,
)
@ -923,9 +925,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
resume_download=resume_download,
local_files_only=local_files_only,
)
if resolved_archive_file is None:
raise EnvironmentError
except EnvironmentError:
except EnvironmentError as err:
logger.error(err)
msg = (
f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"

View File

@ -86,7 +86,7 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
def get_framework(model):
def get_framework(model, revision: Optional[str] = None):
"""
Select framework (TensorFlow or PyTorch) to use.
@ -103,14 +103,14 @@ def get_framework(model):
)
if isinstance(model, str):
if is_torch_available() and not is_tf_available():
model = AutoModel.from_pretrained(model)
model = AutoModel.from_pretrained(model, revision=revision)
elif is_tf_available() and not is_torch_available():
model = TFAutoModel.from_pretrained(model)
model = TFAutoModel.from_pretrained(model, revision=revision)
else:
try:
model = AutoModel.from_pretrained(model)
model = AutoModel.from_pretrained(model, revision=revision)
except OSError:
model = TFAutoModel.from_pretrained(model)
model = TFAutoModel.from_pretrained(model, revision=revision)
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
return framework
@ -2730,6 +2730,7 @@ def pipeline(
config: Optional[Union[str, PretrainedConfig]] = None,
tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
framework: Optional[str] = None,
revision: Optional[str] = None,
use_fast: bool = False,
**kwargs
) -> Pipeline:
@ -2784,6 +2785,10 @@ def pipeline(
If no framework is specified, will default to the one currently installed. If no framework is specified and
both frameworks are installed, will default to the framework of the :obj:`model`, or to PyTorch if no model
is provided.
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
When passing a task name or a string model identifier: The specific model version to use. It can be a
branch name, a tag name, or a commit id, since we use a git-based system for storing models and other
artifacts on huggingface.co, so ``revision`` can be any identifier allowed by git.
use_fast (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to use a Fast tokenizer if possible (a :class:`~transformers.PreTrainedTokenizerFast`).
kwargs:
@ -2845,17 +2850,19 @@ def pipeline(
if isinstance(tokenizer, tuple):
# For tuple we have (tokenizer name, {kwargs})
use_fast = tokenizer[1].pop("use_fast", use_fast)
tokenizer = AutoTokenizer.from_pretrained(tokenizer[0], use_fast=use_fast, **tokenizer[1])
tokenizer = AutoTokenizer.from_pretrained(
tokenizer[0], use_fast=use_fast, revision=revision, **tokenizer[1]
)
else:
tokenizer = AutoTokenizer.from_pretrained(tokenizer, use_fast=use_fast)
tokenizer = AutoTokenizer.from_pretrained(tokenizer, revision=revision, use_fast=use_fast)
# Instantiate config if needed
if isinstance(config, str):
config = AutoConfig.from_pretrained(config)
config = AutoConfig.from_pretrained(config, revision=revision)
# Instantiate modelcard if needed
if isinstance(modelcard, str):
modelcard = ModelCard.from_pretrained(modelcard)
modelcard = ModelCard.from_pretrained(modelcard, revision=revision)
# Instantiate model if needed
if isinstance(model, str):
@ -2873,7 +2880,7 @@ def pipeline(
"Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. "
"Trying to load the model with Tensorflow."
)
model = model_class.from_pretrained(model, config=config, **model_kwargs)
model = model_class.from_pretrained(model, config=config, revision=revision, **model_kwargs)
if task == "translation" and model.config.task_specific_params:
for key in model.config.task_specific_params:
if key.startswith("translation"):

View File

@ -125,8 +125,6 @@ class LegacyIndex(Index):
try:
# Load from URL or cache if already cached
resolved_archive_file = cached_path(archive_file)
if resolved_archive_file is None:
raise EnvironmentError
except EnvironmentError:
msg = (
f"Can't load '{archive_file}'. Make sure that:\n\n"

View File

@ -276,6 +276,10 @@ class AutoTokenizer:
proxies (:obj:`Dict[str, str]`, `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
use_fast (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to try to load the fast version of the tokenizer.
kwargs (additional keyword arguments, `optional`):

View File

@ -29,6 +29,8 @@ from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
import numpy as np
import requests
from .file_utils import (
add_end_docstrings,
cached_path,
@ -1515,6 +1517,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
proxies (:obj:`Dict[str, str], `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
inputs (additional positional arguments, `optional`):
Will be passed along to the Tokenizer ``__init__`` method.
kwargs (additional keyword arguments, `optional`):
@ -1549,6 +1555,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
s3_models = list(cls.max_model_input_sizes.keys())
vocab_files = {}
@ -1601,18 +1608,18 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
full_file_name = None
else:
full_file_name = hf_bucket_url(
pretrained_model_name_or_path, filename=file_name, use_cdn=False, mirror=None
pretrained_model_name_or_path, filename=file_name, revision=revision, mirror=None
)
vocab_files[file_id] = full_file_name
# Get files from url, cache, or disk depending on the case
try:
resolved_vocab_files = {}
for file_id, file_path in vocab_files.items():
if file_path is None:
resolved_vocab_files[file_id] = None
else:
resolved_vocab_files = {}
for file_id, file_path in vocab_files.items():
if file_path is None:
resolved_vocab_files[file_id] = None
else:
try:
resolved_vocab_files[file_id] = cached_path(
file_path,
cache_dir=cache_dir,
@ -1621,34 +1628,20 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
resume_download=resume_download,
local_files_only=local_files_only,
)
except EnvironmentError:
if pretrained_model_name_or_path in s3_models:
msg = "Couldn't reach server at '{}' to download vocabulary files."
else:
msg = (
"Model name '{}' was not found in tokenizers model name list ({}). "
"We assumed '{}' was a path or url to a directory containing vocabulary files "
"named {}, but couldn't find such vocabulary files at this path or url.".format(
pretrained_model_name_or_path,
", ".join(s3_models),
pretrained_model_name_or_path,
list(cls.vocab_files_names.values()),
)
)
raise EnvironmentError(msg)
except requests.exceptions.HTTPError as err:
if "404 Client Error" in str(err):
logger.debug(err)
resolved_vocab_files[file_id] = None
else:
raise err
if all(full_file_name is None for full_file_name in resolved_vocab_files.values()):
raise EnvironmentError(
"Model name '{}' was not found in tokenizers model name list ({}). "
"We assumed '{}' was a path, a model identifier, or url to a directory containing vocabulary files "
"named {} but couldn't find such vocabulary files at this path or url.".format(
pretrained_model_name_or_path,
", ".join(s3_models),
pretrained_model_name_or_path,
list(cls.vocab_files_names.values()),
)
msg = (
f"Can't load tokenizer for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing relevant tokenizer files\n\n"
)
raise EnvironmentError(msg)
for file_id, file_path in vocab_files.items():
if file_path == resolved_vocab_files[file_id]:

63
tests/test_file_utils.py Normal file
View File

@ -0,0 +1,63 @@
import unittest
import requests
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME, filename_to_url, get_from_cache, hf_bucket_url
from transformers.testing_utils import DUMMY_UNKWOWN_IDENTIFIER
MODEL_ID = DUMMY_UNKWOWN_IDENTIFIER
# An actual model hosted on huggingface.co
REVISION_ID_DEFAULT = "main"
# Default branch name
REVISION_ID_ONE_SPECIFIC_COMMIT = "f2c752cfc5c0ab6f4bdec59acea69eefbee381c2"
# One particular commit (not the top of `main`)
REVISION_ID_INVALID = "aaaaaaa"
# This commit does not exist, so we should 404.
PINNED_SHA1 = "d9e9f15bc825e4b2c9249e9578f884bbcb5e3684"
# Sha-1 of config.json on the top of `main`, for checking purposes
PINNED_SHA256 = "4b243c475af8d0a7754e87d7d096c92e5199ec2fe168a2ee7998e3b8e9bcb1d3"
# Sha-256 of pytorch_model.bin on the top of `main`, for checking purposes
class GetFromCacheTests(unittest.TestCase):
def test_bogus_url(self):
# This lets us simulate no connection
# as the error raised is the same
# `ConnectionError`
url = "https://bogus"
with self.assertRaisesRegex(ValueError, "Connection error"):
_ = get_from_cache(url)
def test_file_not_found(self):
# Valid revision (None) but missing file.
url = hf_bucket_url(MODEL_ID, filename="missing.bin")
with self.assertRaisesRegex(requests.exceptions.HTTPError, "404 Client Error"):
_ = get_from_cache(url)
def test_revision_not_found(self):
# Valid file but missing revision
url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_INVALID)
with self.assertRaisesRegex(requests.exceptions.HTTPError, "404 Client Error"):
_ = get_from_cache(url)
def test_standard_object(self):
url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT)
filepath = get_from_cache(url, force_download=True)
metadata = filename_to_url(filepath)
self.assertEqual(metadata, (url, f'"{PINNED_SHA1}"'))
def test_standard_object_rev(self):
# Same object, but different revision
url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_ONE_SPECIFIC_COMMIT)
filepath = get_from_cache(url, force_download=True)
metadata = filename_to_url(filepath)
self.assertNotEqual(metadata[1], f'"{PINNED_SHA1}"')
# Caution: check that the etag is *not* equal to the one from `test_standard_object`
def test_lfs_object(self):
url = hf_bucket_url(MODEL_ID, filename=WEIGHTS_NAME, revision=REVISION_ID_DEFAULT)
filepath = get_from_cache(url, force_download=True)
metadata = filename_to_url(filepath)
self.assertEqual(metadata, (url, f'"{PINNED_SHA256}"'))

View File

@ -20,7 +20,7 @@ import unittest
import requests
from requests.exceptions import HTTPError
from transformers.hf_api import HfApi, HfFolder, ModelInfo, PresignedUrl, S3Obj
from transformers.hf_api import HfApi, HfFolder, ModelInfo, PresignedUrl, RepoObj, S3Obj
USER = "__DUMMY_TRANSFORMERS_USER__"
@ -35,6 +35,7 @@ FILES = [
os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/empty.txt"),
),
]
REPO_NAME = "my-model-{}".format(int(time.time()))
ENDPOINT_STAGING = "https://moon-staging.huggingface.co"
@ -78,15 +79,6 @@ class HfApiEndpointsTest(HfApiCommonTest):
urls = self._api.presign(token=self._token, filename="nested/valid_org.txt", organization="valid_org")
self.assertIsInstance(urls, PresignedUrl)
def test_presign_invalid(self):
try:
_ = self._api.presign(token=self._token, filename="non_nested.json")
except HTTPError as e:
self.assertIsNotNone(e.response.text)
self.assertTrue("Filename invalid" in e.response.text)
else:
self.fail("Expected an exception")
def test_presign(self):
for FILE_KEY, FILE_PATH in FILES:
urls = self._api.presign(token=self._token, filename=FILE_KEY)
@ -109,6 +101,17 @@ class HfApiEndpointsTest(HfApiCommonTest):
o = objs[-1]
self.assertIsInstance(o, S3Obj)
def test_list_repos_objs(self):
objs = self._api.list_repos_objs(token=self._token)
self.assertIsInstance(objs, list)
if len(objs) > 0:
o = objs[-1]
self.assertIsInstance(o, RepoObj)
def test_create_and_delete_repo(self):
self._api.create_repo(token=self._token, name=REPO_NAME)
self._api.delete_repo(token=self._token, name=REPO_NAME)
class HfApiPublicTest(unittest.TestCase):
def test_staging_model_list(self):

View File

@ -323,7 +323,7 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
def test_custom_load_tf_weights(self):
model, output_loading_info = TFBertForTokenClassification.from_pretrained(
"jplu/tiny-tf-bert-random", use_cdn=False, output_loading_info=True
"jplu/tiny-tf-bert-random", output_loading_info=True
)
self.assertEqual(sorted(output_loading_info["unexpected_keys"]), ["mlm___cls", "nsp___cls"])
for layer in output_loading_info["missing_keys"]:

View File

@ -165,7 +165,7 @@ DUMMY_FUNCTION = {
def read_init():
""" Read the init and exctracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects. """
""" Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects. """
with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8") as f:
lines = f.readlines()