Model versioning (#8324)

* fix typo

* rm use_cdn & references, and implement new hf_bucket_url

* I'm pretty sure we don't need to `read` this file

* same here

* [BIG] file_utils.networking: do not gobble up errors anymore

* Fix CI 😇

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Tiny doc tweak

* Add doc + pass kwarg everywhere

* Add more tests and explain

cc @sshleifer let me know if better

Co-Authored-By: Sam Shleifer <sshleifer@gmail.com>

* Also implement revision in pipelines

In the case where we're passing a task name or a string model identifier

* Fix CI 😇

* Fix CI

* [hf_api] new methods + command line implem

* make style

* Final endpoints post-migration

* Fix post-migration

* Py3.6 compat

cc @stefan-it

Thank you @stas00

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
Julien Chaumond 2020-11-10 13:11:02 +01:00 committed by GitHub
parent 4185b115d4
commit 70f622fab4
23 changed files with 472 additions and 210 deletions


@@ -12,8 +12,8 @@ inference: false

 ## Disclaimer

-Due do it's immense size, `t5-11b` requires some special treatment.
-First, `t5-11b` should be loaded with flag `use_cdn` set to `False` as follows:
+**Before `transformers` v3.5.0**, due to its immense size, `t5-11b` required some special treatment.
+If you're using transformers `<= v3.4.0`, `t5-11b` should be loaded with flag `use_cdn` set to `False` as follows:

 ```python
 t5 = transformers.T5ForConditionalGeneration.from_pretrained('t5-11b', use_cdn = False)
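
With the git-based hosting this commit introduces, that workaround goes away; a minimal sketch of the post-v3.5.0 usage (illustrative, not part of the diff):

```python
import transformers

# transformers >= v3.5.0 resolves files via
# https://huggingface.co/{model_id}/resolve/{revision}/{filename}
# and redirects large files to the CDN automatically, so no `use_cdn` flag is needed:
t5 = transformers.T5ForConditionalGeneration.from_pretrained("t5-11b")

# A git revision (branch, tag, or commit sha) can be pinned explicitly:
t5 = transformers.T5ForConditionalGeneration.from_pretrained("t5-11b", revision="main")
```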


@@ -56,7 +56,3 @@ cd -
 perl -le 'for $f (@ARGV) { print qq[transformers-cli upload -y $_/$f --filename $_/$f] for ("wmt16-en-de-dist-12-1", "wmt16-en-de-dist-6-1", "wmt16-en-de-12-1")}' vocab-src.json vocab-tgt.json tokenizer_config.json config.json
 # add/remove files as needed
-# Caching note: Unfortunately due to CDN caching the uploaded model may be unavailable for up to 24hs after upload
-# So the only way to start using the new model sooner is either:
-# 1. download it to a local path and use that path as model_name
-# 2. make sure you use: from_pretrained(..., use_cdn=False) everywhere


@@ -44,7 +44,3 @@ cd -
 perl -le 'for $f (@ARGV) { print qq[transformers-cli upload -y $_/$f --filename $_/$f] for ("wmt19-de-en-6-6-base", "wmt19-de-en-6-6-big")}' vocab-src.json vocab-tgt.json tokenizer_config.json config.json
 # add/remove files as needed
-# Caching note: Unfortunately due to CDN caching the uploaded model may be unavailable for up to 24hs after upload
-# So the only way to start using the new model sooner is either:
-# 1. download it to a local path and use that path as model_name
-# 2. make sure you use: from_pretrained(..., use_cdn=False) everywhere


@@ -55,7 +55,3 @@ cd -
 perl -le 'for $f (@ARGV) { print qq[transformers-cli upload -y $_/$f --filename $_/$f] for map { "wmt19-$_" } ("en-ru", "ru-en", "de-en", "en-de")}' vocab-src.json vocab-tgt.json tokenizer_config.json config.json
 # add/remove files as needed
-# Caching note: Unfortunately due to CDN caching the uploaded model may be unavailable for up to 24hs after upload
-# So the only way to start using the new model sooner is either:
-# 1. download it to a local path and use that path as model_name
-# 2. make sure you use: from_pretrained(..., use_cdn=False) everywhere


@@ -1,4 +1,5 @@
 import os
+import subprocess
 import sys
 from argparse import ArgumentParser
 from getpass import getpass
@@ -21,8 +22,10 @@ class UserCommands(BaseTransformersCLICommand):
         whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args))
         logout_parser = parser.add_parser("logout", help="Log out")
         logout_parser.set_defaults(func=lambda args: LogoutCommand(args))
-        # s3
-        s3_parser = parser.add_parser("s3", help="{ls, rm} Commands to interact with the files you upload on S3.")
+        # s3_datasets (s3-based system)
+        s3_parser = parser.add_parser(
+            "s3_datasets", help="{ls, rm} Commands to interact with the files you upload on S3."
+        )
         s3_subparsers = s3_parser.add_subparsers(help="s3 related commands")
         ls_parser = s3_subparsers.add_parser("ls")
         ls_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
@@ -31,17 +34,42 @@ class UserCommands(BaseTransformersCLICommand):
         rm_parser.add_argument("filename", type=str, help="individual object filename to delete from S3.")
         rm_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
         rm_parser.set_defaults(func=lambda args: DeleteObjCommand(args))
-        # upload
-        upload_parser = parser.add_parser("upload", help="Upload a model to S3.")
-        upload_parser.add_argument(
-            "path", type=str, help="Local path of the model folder or individual file to upload."
-        )
+        upload_parser = s3_subparsers.add_parser("upload", help="Upload a file to S3.")
+        upload_parser.add_argument("path", type=str, help="Local path of the folder or individual file to upload.")
         upload_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
         upload_parser.add_argument(
             "--filename", type=str, default=None, help="Optional: override individual object filename on S3."
         )
         upload_parser.add_argument("-y", "--yes", action="store_true", help="Optional: answer Yes to the prompt")
         upload_parser.set_defaults(func=lambda args: UploadCommand(args))
+        # deprecated model upload
+        upload_parser = parser.add_parser(
+            "upload",
+            help=(
+                "Deprecated: used to be the way to upload a model to S3."
+                " We now use a git-based system for storing models and other artifacts."
+                " Use the `repo create` command instead."
+            ),
+        )
+        upload_parser.set_defaults(func=lambda args: DeprecatedUploadCommand(args))
+        # new system: git-based repo system
+        repo_parser = parser.add_parser(
+            "repo", help="{create, ls-files} Commands to interact with your huggingface.co repos."
+        )
+        repo_subparsers = repo_parser.add_subparsers(help="huggingface.co repos related commands")
+        ls_parser = repo_subparsers.add_parser("ls-files", help="List all your files on huggingface.co")
+        ls_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
+        ls_parser.set_defaults(func=lambda args: ListReposObjsCommand(args))
+        repo_create_parser = repo_subparsers.add_parser("create", help="Create a new repo on huggingface.co")
+        repo_create_parser.add_argument(
+            "name",
+            type=str,
+            help="Name for your model's repo. Will be namespaced under your username to build the model id.",
+        )
+        repo_create_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
+        repo_create_parser.add_argument("-y", "--yes", action="store_true", help="Optional: answer Yes to the prompt")
+        repo_create_parser.set_defaults(func=lambda args: RepoCreateCommand(args))


 class ANSI:
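
For reference, the subcommands wired up above would be driven like this; a sketch assuming a logged-in `transformers-cli` on PATH (the repo name is made up):

```python
import subprocess

# New git-based repo commands:
subprocess.run(["transformers-cli", "repo", "create", "my-new-model", "-y"], check=True)
subprocess.run(["transformers-cli", "repo", "ls-files"], check=True)

# The S3 dataset/metric commands moved under `s3_datasets`:
subprocess.run(["transformers-cli", "s3_datasets", "ls"], check=True)
```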
@@ -51,6 +79,7 @@ class ANSI:
     _bold = "\u001b[1m"
     _red = "\u001b[31m"
+    _gray = "\u001b[90m"
     _reset = "\u001b[0m"

     @classmethod
@@ -61,6 +90,27 @@ class ANSI:
     def red(cls, s):
         return "{}{}{}".format(cls._bold + cls._red, s, cls._reset)

+    @classmethod
+    def gray(cls, s):
+        return "{}{}{}".format(cls._gray, s, cls._reset)
+
+
+def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
+    """
+    Inspired by:
+
+    - stackoverflow.com/a/8356620/593036
+    - stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
+    """
+    col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
+    row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
+    lines = []
+    lines.append(row_format.format(*headers))
+    lines.append(row_format.format(*["-" * w for w in col_widths]))
+    for row in rows:
+        lines.append(row_format.format(*row))
+    return "\n".join(lines)
+

 class BaseUserCommand:
     def __init__(self, args):
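
The now module-level `tabulate` can be exercised standalone; a quick sketch with made-up rows (the function body is copied from the diff above so the snippet runs on its own):

```python
from typing import List, Union


def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
    # Pad every column to its widest cell (header included), as in the diff above.
    col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
    row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
    lines = [row_format.format(*headers)]
    lines.append(row_format.format(*["-" * w for w in col_widths]))
    for row in rows:
        lines.append(row_format.format(*row))
    return "\n".join(lines)


print(tabulate([["config.json", "2020-11-10", "70f622f", 480]],
               headers=["Filename", "LastModified", "Commit-Sha", "Size"]))
```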
@@ -124,22 +174,6 @@ class LogoutCommand(BaseUserCommand):


 class ListObjsCommand(BaseUserCommand):
-    def tabulate(self, rows: List[List[Union[str, int]]], headers: List[str]) -> str:
-        """
-        Inspired by:
-
-        - stackoverflow.com/a/8356620/593036
-        - stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
-        """
-        col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
-        row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
-        lines = []
-        lines.append(row_format.format(*headers))
-        lines.append(row_format.format(*["-" * w for w in col_widths]))
-        for row in rows:
-            lines.append(row_format.format(*row))
-        return "\n".join(lines)
-
     def run(self):
         token = HfFolder.get_token()
         if token is None:
@@ -155,7 +189,7 @@ class ListObjsCommand(BaseUserCommand):
             print("No shared file yet")
             exit()
         rows = [[obj.filename, obj.LastModified, obj.ETag, obj.Size] for obj in objs]
-        print(self.tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"]))
+        print(tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"]))


 class DeleteObjCommand(BaseUserCommand):
@@ -173,6 +207,85 @@ class DeleteObjCommand(BaseUserCommand):
         print("Done")


+class ListReposObjsCommand(BaseUserCommand):
+    def run(self):
+        token = HfFolder.get_token()
+        if token is None:
+            print("Not logged in")
+            exit(1)
+        try:
+            objs = self._api.list_repos_objs(token, organization=self.args.organization)
+        except HTTPError as e:
+            print(e)
+            print(ANSI.red(e.response.text))
+            exit(1)
+        if len(objs) == 0:
+            print("No shared file yet")
+            exit()
+        rows = [[obj.filename, obj.lastModified, obj.commit, obj.size] for obj in objs]
+        print(tabulate(rows, headers=["Filename", "LastModified", "Commit-Sha", "Size"]))
+
+
+class RepoCreateCommand(BaseUserCommand):
+    def run(self):
+        token = HfFolder.get_token()
+        if token is None:
+            print("Not logged in")
+            exit(1)
+        try:
+            stdout = subprocess.check_output(["git", "--version"]).decode("utf-8")
+            print(ANSI.gray(stdout.strip()))
+        except FileNotFoundError:
+            print("Looks like you do not have git installed, please install.")
+
+        try:
+            stdout = subprocess.check_output(["git-lfs", "--version"]).decode("utf-8")
+            print(ANSI.gray(stdout.strip()))
+        except FileNotFoundError:
+            print(
+                ANSI.red(
+                    "Looks like you do not have git-lfs installed, please install."
+                    " You can install from https://git-lfs.github.com/."
+                    " Then run `git lfs install` (you only have to do this once)."
+                )
+            )
+        print("")
+
+        user, _ = self._api.whoami(token)
+        namespace = self.args.organization if self.args.organization is not None else user
+
+        print("You are about to create {}".format(ANSI.bold(namespace + "/" + self.args.name)))
+
+        if not self.args.yes:
+            choice = input("Proceed? [Y/n] ").lower()
+            if not (choice == "" or choice == "y" or choice == "yes"):
+                print("Abort")
+                exit()
+        try:
+            url = self._api.create_repo(token, name=self.args.name, organization=self.args.organization)
+        except HTTPError as e:
+            print(e)
+            print(ANSI.red(e.response.text))
+            exit(1)
+        print("\nYour repo now lives at:")
+        print("  {}".format(ANSI.bold(url)))
+        print("\nYou can clone it locally with the command below," " and commit/push as usual.")
+        print(f"\n  git clone {url}")
+        print("")
+
+
+class DeprecatedUploadCommand(BaseUserCommand):
+    def run(self):
+        print(
+            ANSI.red(
+                "Deprecated: used to be the way to upload a model to S3."
+                " We now use a git-based system for storing models and other artifacts."
+                " Use the `repo create` command instead."
+            )
+        )
+        exit(1)
+

 class UploadCommand(BaseUserCommand):
     def walk_dir(self, rel_path):
         """


@@ -289,6 +289,10 @@ class AutoConfig:
             proxies (:obj:`Dict[str, str]`, `optional`):
                 A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+                identifier allowed by git.
             return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`):
                 If :obj:`False`, then this function returns just the final configuration object.
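
A usage sketch of the new argument (model id and revision are illustrative):

```python
from transformers import AutoConfig

# `revision` accepts any git identifier: a branch name, a tag name, or a commit sha.
config = AutoConfig.from_pretrained("bert-base-uncased", revision="main")
```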


@@ -311,6 +311,10 @@ class PretrainedConfig(object):
             proxies (:obj:`Dict[str, str]`, `optional`):
                 A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+            revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+                identifier allowed by git.
             return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`):
                 If :obj:`False`, then this function returns just the final configuration object.
@@ -362,6 +366,7 @@ class PretrainedConfig(object):
         resume_download = kwargs.pop("resume_download", False)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)
+        revision = kwargs.pop("revision", None)

         if os.path.isdir(pretrained_model_name_or_path):
             config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
@@ -369,7 +374,7 @@ class PretrainedConfig(object):
             config_file = pretrained_model_name_or_path
         else:
             config_file = hf_bucket_url(
-                pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False, mirror=None
+                pretrained_model_name_or_path, filename=CONFIG_NAME, revision=revision, mirror=None
             )

         try:
@@ -383,11 +388,10 @@ class PretrainedConfig(object):
                 local_files_only=local_files_only,
             )
             # Load config dict
-            if resolved_config_file is None:
-                raise EnvironmentError
             config_dict = cls._dict_from_json_file(resolved_config_file)
-        except EnvironmentError:
+        except EnvironmentError as err:
+            logger.error(err)
             msg = (
                 f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                 f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"


@@ -4,6 +4,7 @@ https://github.com/allenai/allennlp Copyright by the AllenNLP authors.
 """

 import fnmatch
+import io
 import json
 import os
 import re
@@ -17,7 +18,7 @@ from dataclasses import fields
 from functools import partial, wraps
 from hashlib import sha256
 from pathlib import Path
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, BinaryIO, Dict, Optional, Tuple, Union
 from urllib.parse import urlparse
 from zipfile import ZipFile, is_zipfile
@@ -217,6 +218,8 @@ DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
 S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
 CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"

+HUGGINGFACE_CO_PREFIX = "https://huggingface.co/{model_id}/resolve/{revision}/{filename}"
+
 PRESET_MIRROR_DICT = {
     "tuna": "https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models",
     "bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models",
@@ -825,34 +828,37 @@ def is_remote_url(url_or_filename):
     return parsed.scheme in ("http", "https")


-def hf_bucket_url(model_id: str, filename: str, use_cdn=True, mirror=None) -> str:
+def hf_bucket_url(model_id: str, filename: str, revision: Optional[str] = None, mirror=None) -> str:
     """
-    Resolve a model identifier, and a file name, to a HF-hosted url on either S3 or Cloudfront (a Content Delivery
-    Network, or CDN).
+    Resolve a model identifier, a file name, and an optional revision id, to a huggingface.co-hosted url, redirecting
+    to Cloudfront (a Content Delivery Network, or CDN) for large files.

     Cloudfront is replicated over the globe so downloads are way faster for the end user (and it also lowers our
-    bandwidth costs). However, it is more aggressively cached by default, so may not always reflect the latest changes
-    to the underlying file (default TTL is 24 hours).
+    bandwidth costs).

-    In terms of client-side caching from this library, even though Cloudfront relays the ETags from S3, using one or
-    the other (or switching from one to the other) will affect caching: cached files are not shared between the two
-    because the cached file's name contains a hash of the url.
+    Cloudfront aggressively caches files by default (default TTL is 24 hours), however this is not an issue here
+    because we migrated to a git-based versioning system on huggingface.co, so we now store the files on S3/Cloudfront
+    in a content-addressable way (i.e., the file name is its hash). Using content-addressable filenames means cache
+    can't ever be stale.
+
+    In terms of client-side caching from this library, we base our caching on the objects' ETag. An object's ETag is:
+    its sha1 if stored in git, or its sha256 if stored in git-lfs. Files cached locally from transformers before v3.5.0
+    are not shared with those new files, because the cached file's name contains a hash of the url (which changed).
     """
-    endpoint = (
-        PRESET_MIRROR_DICT.get(mirror, mirror)
-        if mirror
-        else CLOUDFRONT_DISTRIB_PREFIX
-        if use_cdn
-        else S3_BUCKET_PREFIX
-    )
-    legacy_format = "/" not in model_id
-    if legacy_format:
-        return f"{endpoint}/{model_id}-{filename}"
-    else:
-        return f"{endpoint}/{model_id}/{filename}"
+    if mirror:
+        endpoint = PRESET_MIRROR_DICT.get(mirror, mirror)
+        legacy_format = "/" not in model_id
+        if legacy_format:
+            return f"{endpoint}/{model_id}-{filename}"
+        else:
+            return f"{endpoint}/{model_id}/{filename}"

+    if revision is None:
+        revision = "main"
+    return HUGGINGFACE_CO_PREFIX.format(model_id=model_id, revision=revision, filename=filename)


-def url_to_filename(url, etag=None):
+def url_to_filename(url: str, etag: Optional[str] = None) -> str:
     """
     Convert `url` into a hashed filename in a repeatable way. If `etag` is specified, append its hash to the url's,
     delimited by a period. If the url ends with .h5 (Keras HDF5 weights) adds '.h5' to the name so that TF 2.0 can
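
Concretely, the rewritten resolver produces URLs like these (model ids are illustrative; the expected strings follow `HUGGINGFACE_CO_PREFIX` above):

```python
from transformers.file_utils import hf_bucket_url

# revision defaults to "main":
assert (
    hf_bucket_url("bert-base-uncased", filename="config.json")
    == "https://huggingface.co/bert-base-uncased/resolve/main/config.json"
)

# any branch name, tag name, or commit sha can be pinned:
assert (
    hf_bucket_url("julien-c/dummy", filename="pytorch_model.bin", revision="v1.0")
    == "https://huggingface.co/julien-c/dummy/resolve/v1.0/pytorch_model.bin"
)
```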
@@ -860,13 +866,11 @@ def url_to_filename(url, etag=None):
     https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
     """
     url_bytes = url.encode("utf-8")
-    url_hash = sha256(url_bytes)
-    filename = url_hash.hexdigest()
+    filename = sha256(url_bytes).hexdigest()

     if etag:
         etag_bytes = etag.encode("utf-8")
-        etag_hash = sha256(etag_bytes)
-        filename += "." + etag_hash.hexdigest()
+        filename += "." + sha256(etag_bytes).hexdigest()

     if url.endswith(".h5"):
         filename += ".h5"
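
The cache filename is thus just sha256 digests of the url and (optionally) the ETag; a standalone re-derivation with a made-up ETag:

```python
from hashlib import sha256

url = "https://huggingface.co/bert-base-uncased/resolve/main/config.json"
etag = '"some-etag"'  # made up; real ETags are the git sha1 / git-lfs sha256 of the file

filename = sha256(url.encode("utf-8")).hexdigest()
filename += "." + sha256(etag.encode("utf-8")).hexdigest()
# ".h5" would additionally be appended for Keras weight urls.
print(filename)
```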
@@ -927,8 +931,10 @@ def cached_path(
             re-extract the archive and override the folder where it was extracted.

     Return:
-        None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk). Local path (string)
-        otherwise
+        Local path (string) of file or if networking is off, last version of file cached on disk.
+
+    Raises:
+        In case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
     """
     if cache_dir is None:
         cache_dir = TRANSFORMERS_CACHE
@@ -992,7 +998,10 @@ def cached_path(
     return output_path


-def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict, str, None] = None):
+def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
+    """
+    Formats a user-agent string with basic info about a request.
+    """
     ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
     if is_torch_available():
         ua += "; torch/{}".format(torch.__version__)
@@ -1002,13 +1011,19 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict
         ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
     elif isinstance(user_agent, str):
         ua += "; " + user_agent
-    headers = {"user-agent": ua}
+    return ua
+
+
+def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, user_agent: Union[Dict, str, None] = None):
+    """
+    Download remote file. Do not gobble up errors.
+    """
+    headers = {"user-agent": http_user_agent(user_agent)}
     if resume_size > 0:
         headers["Range"] = "bytes=%d-" % (resume_size,)
-    response = requests.get(url, stream=True, proxies=proxies, headers=headers)
-    if response.status_code == 416:  # Range not satisfiable
-        return
-    content_length = response.headers.get("Content-Length")
+    r = requests.get(url, stream=True, proxies=proxies, headers=headers)
+    r.raise_for_status()
+    content_length = r.headers.get("Content-Length")
     total = resume_size + int(content_length) if content_length is not None else None
     progress = tqdm(
         unit="B",
@@ -1018,7 +1033,7 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict
         desc="Downloading",
         disable=bool(logging.get_verbosity() == logging.NOTSET),
     )
-    for chunk in response.iter_content(chunk_size=1024):
+    for chunk in r.iter_content(chunk_size=1024):
         if chunk:  # filter out keep-alive new chunks
             progress.update(len(chunk))
             temp_file.write(chunk)
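
A minimal sketch of calling the reworked `http_get` (url illustrative); note that a bad status now raises instead of returning silently:

```python
import tempfile

from transformers.file_utils import http_get

url = "https://huggingface.co/bert-base-uncased/resolve/main/config.json"
with tempfile.NamedTemporaryFile(mode="wb", delete=False) as temp_file:
    http_get(url, temp_file)  # raises requests.exceptions.HTTPError on 4xx/5xx
print(temp_file.name)
```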
@@ -1026,7 +1041,7 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict


 def get_from_cache(
-    url,
+    url: str,
     cache_dir=None,
     force_download=False,
     proxies=None,
@@ -1040,8 +1055,10 @@ def get_from_cache(
         path to the cached file.

     Return:
-        None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk). Local path (string)
-        otherwise
+        Local path (string) of file or if networking is off, last version of file cached on disk.
+
+    Raises:
+        In case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
     """
     if cache_dir is None:
         cache_dir = TRANSFORMERS_CACHE
@@ -1050,13 +1067,28 @@ def get_from_cache(
     os.makedirs(cache_dir, exist_ok=True)

+    url_to_download = url
     etag = None
     if not local_files_only:
         try:
-            response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
-            if response.status_code == 200:
-                etag = response.headers.get("ETag")
-        except (EnvironmentError, requests.exceptions.Timeout):
+            headers = {"user-agent": http_user_agent(user_agent)}
+            r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
+            r.raise_for_status()
+            etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
+            # We favor a custom header indicating the etag of the linked resource, and
+            # we fallback to the regular etag header.
+            # If we don't have any of those, raise an error.
+            if etag is None:
+                raise OSError(
+                    "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
+                )
+            # In case of a redirect,
+            # save an extra redirect on the request.get call,
+            # and ensure we download the exact atomic version even if it changed
+            # between the HEAD and the GET (unlikely, but hey).
+            if 300 <= r.status_code <= 399:
+                url_to_download = r.headers["Location"]
+        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
             # etag is already None
             pass
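
The HEAD probe above, isolated for illustration (url is illustrative): `allow_redirects=False` keeps the 302 visible, so both the linked ETag and the final download location can be read off the response:

```python
import requests

url = "https://huggingface.co/bert-base-uncased/resolve/main/pytorch_model.bin"
r = requests.head(url, allow_redirects=False, timeout=10)
r.raise_for_status()  # 3xx does not raise; only 4xx/5xx do

# git-lfs files redirect to the CDN and carry the real file's hash in X-Linked-Etag:
etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
url_to_download = r.headers["Location"] if 300 <= r.status_code <= 399 else url
```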
@@ -1065,7 +1097,7 @@ def get_from_cache(
     # get cache path to put the file
     cache_path = os.path.join(cache_dir, filename)

-    # etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
+    # etag is None == we don't have a connection or we passed local_files_only.
     # try to get the last downloaded one
     if etag is None:
         if os.path.exists(cache_path):
@@ -1088,7 +1120,11 @@ def get_from_cache(
                     " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
                     " to False."
                 )
-            return None
+            else:
+                raise ValueError(
+                    "Connection error, and we cannot find the requested files in the cached path."
+                    " Please try again or make sure your Internet connection is on."
+                )

     # From now on, etag is not None.
     if os.path.exists(cache_path) and not force_download:
@@ -1107,8 +1143,8 @@ def get_from_cache(
         incomplete_path = cache_path + ".incomplete"

         @contextmanager
-        def _resumable_file_manager():
-            with open(incomplete_path, "a+b") as f:
+        def _resumable_file_manager() -> "io.BufferedWriter":
+            with open(incomplete_path, "ab") as f:
                 yield f

         temp_file_manager = _resumable_file_manager
@@ -1117,7 +1153,7 @@ def get_from_cache(
         else:
             resume_size = 0
     else:
-        temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
+        temp_file_manager = partial(tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False)
         resume_size = 0

     # Download to temporary file, then copy to cache dir once finished.
@@ -1125,7 +1161,7 @@ def get_from_cache(
     with temp_file_manager() as temp_file:
         logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)

-        http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
+        http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)

     logger.info("storing %s in cache at %s", url, cache_path)
     os.replace(temp_file.name, cache_path)


@@ -27,9 +27,21 @@ import requests

 ENDPOINT = "https://huggingface.co"


+class RepoObj:
+    """
+    HuggingFace git-based system, data structure that represents a file belonging to the current user.
+    """
+
+    def __init__(self, filename: str, lastModified: str, commit: str, size: int, **kwargs):
+        self.filename = filename
+        self.lastModified = lastModified
+        self.commit = commit
+        self.size = size
+
+
 class S3Obj:
     """
-    Data structure that represents a file belonging to the current user.
+    HuggingFace S3-based system, data structure that represents a file belonging to the current user.
     """

     def __init__(self, filename: str, LastModified: str, ETag: str, Size: int, **kwargs):
@@ -46,38 +58,25 @@ class PresignedUrl:
         self.type = type  # mime-type to send to S3.


-class S3Object:
+class ModelSibling:
     """
-    Data structure that represents a public file accessible on our S3.
+    Data structure that represents a public file inside a model, accessible from huggingface.co
     """

-    def __init__(
-        self,
-        key: str,  # S3 object key
-        etag: str,
-        lastModified: str,
-        size: int,
-        rfilename: str,  # filename relative to config.json
-        **kwargs
-    ):
-        self.key = key
-        self.etag = etag
-        self.lastModified = lastModified
-        self.size = size
-        self.rfilename = rfilename
+    def __init__(self, rfilename: str, **kwargs):
+        self.rfilename = rfilename  # filename relative to the model root
         for k, v in kwargs.items():
             setattr(self, k, v)


 class ModelInfo:
     """
-    Info about a public model accessible from our S3.
+    Info about a public model accessible from huggingface.co
     """

     def __init__(
         self,
-        modelId: str,  # id of model
-        key: str,  # S3 object key of config.json
+        modelId: Optional[str] = None,  # id of model
         author: Optional[str] = None,
         downloads: Optional[int] = None,
         tags: List[str] = [],
@@ -86,12 +85,11 @@ class ModelInfo:
         **kwargs
     ):
         self.modelId = modelId
-        self.key = key
         self.author = author
         self.downloads = downloads
         self.tags = tags
         self.pipeline_tag = pipeline_tag
-        self.siblings = [S3Object(**x) for x in siblings] if siblings is not None else None
+        self.siblings = [ModelSibling(**x) for x in siblings] if siblings is not None else None
         for k, v in kwargs.items():
             setattr(self, k, v)
@@ -134,9 +132,11 @@ class HfApi:
     def presign(self, token: str, filename: str, organization: Optional[str] = None) -> PresignedUrl:
         """
+        HuggingFace S3-based system, used for datasets and metrics.
+
         Call HF API to get a presigned url to upload `filename` to S3.
         """
-        path = "{}/api/presign".format(self.endpoint)
+        path = "{}/api/datasets/presign".format(self.endpoint)
         r = requests.post(
             path,
             headers={"authorization": "Bearer {}".format(token)},
@@ -148,6 +148,8 @@ class HfApi:
     def presign_and_upload(self, token: str, filename: str, filepath: str, organization: Optional[str] = None) -> str:
         """
+        HuggingFace S3-based system, used for datasets and metrics.
+
         Get a presigned url, then upload file to S3.

         Outputs: url: Read-only url for the stored file on S3.
@@ -169,9 +171,11 @@ class HfApi:
     def list_objs(self, token: str, organization: Optional[str] = None) -> List[S3Obj]:
         """
+        HuggingFace S3-based system, used for datasets and metrics.
+
         Call HF API to list all stored files for user (or one of their organizations).
         """
-        path = "{}/api/listObjs".format(self.endpoint)
+        path = "{}/api/datasets/listObjs".format(self.endpoint)
         params = {"organization": organization} if organization is not None else None
         r = requests.get(path, params=params, headers={"authorization": "Bearer {}".format(token)})
         r.raise_for_status()
@@ -180,9 +184,11 @@ class HfApi:
     def delete_obj(self, token: str, filename: str, organization: Optional[str] = None):
         """
+        HuggingFace S3-based system, used for datasets and metrics.
+
         Call HF API to delete a file stored by user
         """
-        path = "{}/api/deleteObj".format(self.endpoint)
+        path = "{}/api/datasets/deleteObj".format(self.endpoint)
         r = requests.delete(
             path,
             headers={"authorization": "Bearer {}".format(token)},
@@ -200,6 +206,51 @@ class HfApi:
         d = r.json()
         return [ModelInfo(**x) for x in d]

+    def list_repos_objs(self, token: str, organization: Optional[str] = None) -> List[S3Obj]:
+        """
+        HuggingFace git-based system, used for models.
+
+        Call HF API to list all stored files for user (or one of their organizations).
+        """
+        path = "{}/api/repos/ls".format(self.endpoint)
+        params = {"organization": organization} if organization is not None else None
+        r = requests.get(path, params=params, headers={"authorization": "Bearer {}".format(token)})
+        r.raise_for_status()
+        d = r.json()
+        return [RepoObj(**x) for x in d]
+
+    def create_repo(self, token: str, name: str, organization: Optional[str] = None) -> str:
+        """
+        HuggingFace git-based system, used for models.
+
+        Call HF API to create a whole repo.
+        """
+        path = "{}/api/repos/create".format(self.endpoint)
+        r = requests.post(
+            path,
+            headers={"authorization": "Bearer {}".format(token)},
+            json={"name": name, "organization": organization},
+        )
+        r.raise_for_status()
+        d = r.json()
+        return d["url"]
+
+    def delete_repo(self, token: str, name: str, organization: Optional[str] = None):
+        """
+        HuggingFace git-based system, used for models.
+
+        Call HF API to delete a whole repo.
+
+        CAUTION(this is irreversible).
+        """
+        path = "{}/api/repos/delete".format(self.endpoint)
+        r = requests.delete(
+            path,
+            headers={"authorization": "Bearer {}".format(token)},
+            json={"name": name, "organization": organization},
+        )
+        r.raise_for_status()
+

 class TqdmProgressFileReader:
     """


@@ -144,9 +144,7 @@ class ModelCard:
         elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
             model_card_file = pretrained_model_name_or_path
         else:
-            model_card_file = hf_bucket_url(
-                pretrained_model_name_or_path, filename=MODEL_CARD_NAME, use_cdn=False, mirror=None
-            )
+            model_card_file = hf_bucket_url(pretrained_model_name_or_path, filename=MODEL_CARD_NAME, mirror=None)

         if find_from_standard_name or pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
             model_card_file = model_card_file.replace(CONFIG_NAME, MODEL_CARD_NAME)
@@ -156,8 +154,6 @@ class ModelCard:
         try:
             # Load from URL or cache if already cached
             resolved_model_card_file = cached_path(model_card_file, cache_dir=cache_dir, proxies=proxies)
-            if resolved_model_card_file is None:
-                raise EnvironmentError
             if resolved_model_card_file == model_card_file:
                 logger.info("loading model card file {}".format(model_card_file))
             else:


@@ -537,9 +537,10 @@ AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
             Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
         local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether or not to only look at local files (e.g., not try downloading the model).
-        use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
-            Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
-            our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
+        revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
+            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+            git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+            identifier allowed by git.
         kwargs (additional keyword arguments, `optional`):
             Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
             :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or


@@ -107,7 +107,7 @@ class FlaxPreTrainedModel(ABC):
         proxies = kwargs.pop("proxies", None)
         # output_loading_info = kwargs.pop("output_loading_info", False)
         local_files_only = kwargs.pop("local_files_only", False)
-        use_cdn = kwargs.pop("use_cdn", True)
+        revision = kwargs.pop("revision", None)

         # Load config if we don't provide a configuration
         if not isinstance(config, PretrainedConfig):
@@ -121,6 +121,7 @@ class FlaxPreTrainedModel(ABC):
                 resume_download=resume_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
+                revision=revision,
                 **kwargs,
             )
         else:
@@ -131,7 +132,7 @@ class FlaxPreTrainedModel(ABC):
             if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
                 archive_file = pretrained_model_name_or_path
             else:
-                archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, use_cdn=use_cdn)
+                archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision)

             # redirect to the cache, if necessary
             try:
@@ -143,16 +144,13 @@ class FlaxPreTrainedModel(ABC):
                     resume_download=resume_download,
                     local_files_only=local_files_only,
                 )
-            except EnvironmentError:
-                if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
-                    msg = f"Couldn't reach server at '{archive_file}' to download pretrained weights."
-                else:
-                    msg = (
-                        f"Model name '{pretrained_model_name_or_path}' "
-                        f"was not found in model name list ({', '.join(cls.pretrained_model_archive_map.keys())}). "
-                        f"We assumed '{archive_file}' was a path or url to model weight files but "
-                        "couldn't find any such file at this path or url."
-                    )
+            except EnvironmentError as err:
+                logger.error(err)
+                msg = (
+                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
+                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
+                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named {WEIGHTS_NAME}.\n\n"
+                )
                 raise EnvironmentError(msg)

             if resolved_archive_file == archive_file:


@@ -420,9 +420,10 @@ TF_AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
             Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
         local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether or not to only look at local files (e.g., not try downloading the model).
-        use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
-            Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
-            our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
+        revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
+            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+            git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+            identifier allowed by git.
         kwargs (additional keyword arguments, `optional`):
             Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
             :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or


@@ -572,9 +572,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
                 Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
             local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
                 Whether or not to only look at local files (e.g., not try doanloading the model).
-            use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
-                Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
-                our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
+            revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+                identifier allowed by git.
             mirror(:obj:`str`, `optional`, defaults to :obj:`None`):
                 Mirror source to accelerate downloads in China. If you are from China and have an accessibility
                 problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
@@ -616,7 +617,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
         proxies = kwargs.pop("proxies", None)
         output_loading_info = kwargs.pop("output_loading_info", False)
         local_files_only = kwargs.pop("local_files_only", False)
-        use_cdn = kwargs.pop("use_cdn", True)
+        revision = kwargs.pop("revision", None)
         mirror = kwargs.pop("mirror", None)

         # Load config if we don't provide a configuration
@@ -631,6 +632,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
                 resume_download=resume_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
+                revision=revision,
                 **kwargs,
             )
         else:
@@ -659,7 +661,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
                 archive_file = hf_bucket_url(
                     pretrained_model_name_or_path,
                     filename=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME),
-                    use_cdn=use_cdn,
+                    revision=revision,
                     mirror=mirror,
                 )
@@ -673,9 +675,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
                     resume_download=resume_download,
                     local_files_only=local_files_only,
                 )
-                if resolved_archive_file is None:
-                    raise EnvironmentError
-            except EnvironmentError:
+            except EnvironmentError as err:
+                logger.error(err)
                 msg = (
                     f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                     f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"


@@ -813,9 +813,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
                 Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
             local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
                 Whether or not to only look at local files (e.g., not try doanloading the model).
-            use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
-                Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
-                our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
+            revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+                identifier allowed by git.
             mirror(:obj:`str`, `optional`, defaults to :obj:`None`):
                 Mirror source to accelerate downloads in China. If you are from China and have an accessibility
                 problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
@@ -857,7 +858,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
         proxies = kwargs.pop("proxies", None)
         output_loading_info = kwargs.pop("output_loading_info", False)
         local_files_only = kwargs.pop("local_files_only", False)
-        use_cdn = kwargs.pop("use_cdn", True)
+        revision = kwargs.pop("revision", None)
         mirror = kwargs.pop("mirror", None)

         # Load config if we don't provide a configuration
@@ -872,6 +873,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
                 resume_download=resume_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
+                revision=revision,
                 **kwargs,
             )
         else:
@@ -909,7 +911,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
                 archive_file = hf_bucket_url(
                     pretrained_model_name_or_path,
                     filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
-                    use_cdn=use_cdn,
+                    revision=revision,
                     mirror=mirror,
                 )
@@ -923,9 +925,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
                     resume_download=resume_download,
                     local_files_only=local_files_only,
                 )
-                if resolved_archive_file is None:
-                    raise EnvironmentError
-            except EnvironmentError:
+            except EnvironmentError as err:
+                logger.error(err)
                 msg = (
                     f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                     f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"


@ -86,7 +86,7 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
def get_framework(model): def get_framework(model, revision: Optional[str] = None):
""" """
Select framework (TensorFlow or PyTorch) to use. Select framework (TensorFlow or PyTorch) to use.
@ -103,14 +103,14 @@ def get_framework(model):
) )
if isinstance(model, str): if isinstance(model, str):
if is_torch_available() and not is_tf_available(): if is_torch_available() and not is_tf_available():
model = AutoModel.from_pretrained(model) model = AutoModel.from_pretrained(model, revision=revision)
elif is_tf_available() and not is_torch_available(): elif is_tf_available() and not is_torch_available():
model = TFAutoModel.from_pretrained(model) model = TFAutoModel.from_pretrained(model, revision=revision)
else: else:
try: try:
model = AutoModel.from_pretrained(model) model = AutoModel.from_pretrained(model, revision=revision)
except OSError: except OSError:
model = TFAutoModel.from_pretrained(model) model = TFAutoModel.from_pretrained(model, revision=revision)
framework = "tf" if model.__class__.__name__.startswith("TF") else "pt" framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
return framework return framework
@@ -2730,6 +2730,7 @@ def pipeline(
     config: Optional[Union[str, PretrainedConfig]] = None,
     tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
     framework: Optional[str] = None,
+    revision: Optional[str] = None,
     use_fast: bool = False,
     **kwargs
 ) -> Pipeline:
@@ -2784,6 +2785,10 @@ def pipeline(
             If no framework is specified, will default to the one currently installed. If no framework is specified and
             both frameworks are installed, will default to the framework of the :obj:`model`, or to PyTorch if no model
             is provided.
+        revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
+            When passing a task name or a string model identifier: The specific model version to use. It can be a
+            branch name, a tag name, or a commit id, since we use a git-based system for storing models and other
+            artifacts on huggingface.co, so ``revision`` can be any identifier allowed by git.
         use_fast (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether or not to use a Fast tokenizer if possible (a :class:`~transformers.PreTrainedTokenizerFast`).
         kwargs:
@@ -2845,17 +2850,19 @@ def pipeline(
         if isinstance(tokenizer, tuple):
             # For tuple we have (tokenizer name, {kwargs})
             use_fast = tokenizer[1].pop("use_fast", use_fast)
-            tokenizer = AutoTokenizer.from_pretrained(tokenizer[0], use_fast=use_fast, **tokenizer[1])
+            tokenizer = AutoTokenizer.from_pretrained(
+                tokenizer[0], use_fast=use_fast, revision=revision, **tokenizer[1]
+            )
         else:
-            tokenizer = AutoTokenizer.from_pretrained(tokenizer, use_fast=use_fast)
+            tokenizer = AutoTokenizer.from_pretrained(tokenizer, revision=revision, use_fast=use_fast)
 
     # Instantiate config if needed
     if isinstance(config, str):
-        config = AutoConfig.from_pretrained(config)
+        config = AutoConfig.from_pretrained(config, revision=revision)
 
     # Instantiate modelcard if needed
     if isinstance(modelcard, str):
-        modelcard = ModelCard.from_pretrained(modelcard)
+        modelcard = ModelCard.from_pretrained(modelcard, revision=revision)
 
     # Instantiate model if needed
     if isinstance(model, str):
@@ -2873,7 +2880,7 @@ def pipeline(
                 "Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. "
                 "Trying to load the model with Tensorflow."
             )
-        model = model_class.from_pretrained(model, config=config, **model_kwargs)
+        model = model_class.from_pretrained(model, config=config, revision=revision, **model_kwargs)
     if task == "translation" and model.config.task_specific_params:
         for key in model.config.task_specific_params:
             if key.startswith("translation"):
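
Taken together, the changes above let a caller pin every artifact a pipeline loads. A minimal usage sketch (the model identifier below is illustrative, not part of this diff; per the new docstring, `revision` accepts any git identifier):

```python
from transformers import pipeline

# Identifier and revision are examples only; `revision` may be a
# branch name, a tag name, or a commit sha.
nlp = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    revision="main",
)
print(nlp("Pinning a revision makes results reproducible."))
```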

View File

@@ -125,8 +125,6 @@ class LegacyIndex(Index):
         try:
             # Load from URL or cache if already cached
             resolved_archive_file = cached_path(archive_file)
-            if resolved_archive_file is None:
-                raise EnvironmentError
         except EnvironmentError:
             msg = (
                 f"Can't load '{archive_file}'. Make sure that:\n\n"

View File

@@ -276,6 +276,10 @@ class AutoTokenizer:
             proxies (:obj:`Dict[str, str]`, `optional`):
                 A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+                identifier allowed by git.
             use_fast (:obj:`bool`, `optional`, defaults to :obj:`False`):
                 Whether or not to try to load the fast version of the tokenizer.
             kwargs (additional keyword arguments, `optional`):
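
For reference, a hedged sketch of the documented argument at a call site (the tokenizer name is an example, not part of this diff):

```python
from transformers import AutoTokenizer

# "bert-base-uncased" is illustrative; `revision` may be a branch name,
# a tag name, or a commit id, as documented above.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", revision="main")
```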

View File

@@ -29,6 +29,8 @@ from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
 import numpy as np
 
+import requests
+
 from .file_utils import (
     add_end_docstrings,
     cached_path,
@@ -1515,6 +1517,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
             proxies (:obj:`Dict[str, str], `optional`):
                 A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+                identifier allowed by git.
             inputs (additional positional arguments, `optional`):
                 Will be passed along to the Tokenizer ``__init__`` method.
             kwargs (additional keyword arguments, `optional`):
@@ -1549,6 +1555,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
         resume_download = kwargs.pop("resume_download", False)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)
+        revision = kwargs.pop("revision", None)
 
         s3_models = list(cls.max_model_input_sizes.keys())
         vocab_files = {}
@@ -1601,18 +1608,18 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
                         full_file_name = None
                 else:
                     full_file_name = hf_bucket_url(
-                        pretrained_model_name_or_path, filename=file_name, use_cdn=False, mirror=None
+                        pretrained_model_name_or_path, filename=file_name, revision=revision, mirror=None
                     )
 
                 vocab_files[file_id] = full_file_name
 
         # Get files from url, cache, or disk depending on the case
-        try:
-            resolved_vocab_files = {}
-            for file_id, file_path in vocab_files.items():
-                if file_path is None:
-                    resolved_vocab_files[file_id] = None
-                else:
+        resolved_vocab_files = {}
+        for file_id, file_path in vocab_files.items():
+            if file_path is None:
+                resolved_vocab_files[file_id] = None
+            else:
+                try:
                     resolved_vocab_files[file_id] = cached_path(
                         file_path,
                         cache_dir=cache_dir,
@@ -1621,34 +1628,20 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
                         resume_download=resume_download,
                         local_files_only=local_files_only,
                     )
-        except EnvironmentError:
-            if pretrained_model_name_or_path in s3_models:
-                msg = "Couldn't reach server at '{}' to download vocabulary files."
-            else:
-                msg = (
-                    "Model name '{}' was not found in tokenizers model name list ({}). "
-                    "We assumed '{}' was a path or url to a directory containing vocabulary files "
-                    "named {}, but couldn't find such vocabulary files at this path or url.".format(
-                        pretrained_model_name_or_path,
-                        ", ".join(s3_models),
-                        pretrained_model_name_or_path,
-                        list(cls.vocab_files_names.values()),
-                    )
-                )
-            raise EnvironmentError(msg)
+                except requests.exceptions.HTTPError as err:
+                    if "404 Client Error" in str(err):
+                        logger.debug(err)
+                        resolved_vocab_files[file_id] = None
+                    else:
+                        raise err
 
         if all(full_file_name is None for full_file_name in resolved_vocab_files.values()):
-            raise EnvironmentError(
-                "Model name '{}' was not found in tokenizers model name list ({}). "
-                "We assumed '{}' was a path, a model identifier, or url to a directory containing vocabulary files "
-                "named {} but couldn't find such vocabulary files at this path or url.".format(
-                    pretrained_model_name_or_path,
-                    ", ".join(s3_models),
-                    pretrained_model_name_or_path,
-                    list(cls.vocab_files_names.values()),
-                )
-            )
+            msg = (
+                f"Can't load tokenizer for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
+                f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
+                f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing relevant tokenizer files\n\n"
+            )
+            raise EnvironmentError(msg)
 
         for file_id, file_path in vocab_files.items():
             if file_path == resolved_vocab_files[file_id]:
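
The restructured loop above narrows error handling from one blanket `try` around the whole download phase to a per-file check: a 404 on an optional vocab file now yields `None` instead of aborting the load, while other HTTP errors still propagate. A distilled sketch of that pattern (the helper name is hypothetical, not library code):

```python
import requests

def fetch_optional(url, fetch):
    # Hypothetical helper mirroring the per-file handling above:
    # a 404 means "this optional file does not exist remotely".
    try:
        return fetch(url)
    except requests.exceptions.HTTPError as err:
        if "404 Client Error" in str(err):
            return None
        raise
```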

tests/test_file_utils.py Normal file
View File

@@ -0,0 +1,63 @@
+import unittest
+
+import requests
+
+from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME, filename_to_url, get_from_cache, hf_bucket_url
+from transformers.testing_utils import DUMMY_UNKWOWN_IDENTIFIER
+
+
+MODEL_ID = DUMMY_UNKWOWN_IDENTIFIER
+# An actual model hosted on huggingface.co
+REVISION_ID_DEFAULT = "main"
+# Default branch name
+REVISION_ID_ONE_SPECIFIC_COMMIT = "f2c752cfc5c0ab6f4bdec59acea69eefbee381c2"
+# One particular commit (not the top of `main`)
+REVISION_ID_INVALID = "aaaaaaa"
+# This commit does not exist, so we should 404.
+
+PINNED_SHA1 = "d9e9f15bc825e4b2c9249e9578f884bbcb5e3684"
+# Sha-1 of config.json on the top of `main`, for checking purposes
+PINNED_SHA256 = "4b243c475af8d0a7754e87d7d096c92e5199ec2fe168a2ee7998e3b8e9bcb1d3"
+# Sha-256 of pytorch_model.bin on the top of `main`, for checking purposes
+
+
+class GetFromCacheTests(unittest.TestCase):
+    def test_bogus_url(self):
+        # This lets us simulate no connection
+        # as the error raised is the same
+        # `ConnectionError`
+        url = "https://bogus"
+        with self.assertRaisesRegex(ValueError, "Connection error"):
+            _ = get_from_cache(url)
+
+    def test_file_not_found(self):
+        # Valid revision (None) but missing file.
+        url = hf_bucket_url(MODEL_ID, filename="missing.bin")
+        with self.assertRaisesRegex(requests.exceptions.HTTPError, "404 Client Error"):
+            _ = get_from_cache(url)
+
+    def test_revision_not_found(self):
+        # Valid file but missing revision
+        url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_INVALID)
+        with self.assertRaisesRegex(requests.exceptions.HTTPError, "404 Client Error"):
+            _ = get_from_cache(url)
+
+    def test_standard_object(self):
+        url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT)
+        filepath = get_from_cache(url, force_download=True)
+        metadata = filename_to_url(filepath)
+        self.assertEqual(metadata, (url, f'"{PINNED_SHA1}"'))
+
+    def test_standard_object_rev(self):
+        # Same object, but different revision
+        url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_ONE_SPECIFIC_COMMIT)
+        filepath = get_from_cache(url, force_download=True)
+        metadata = filename_to_url(filepath)
+        self.assertNotEqual(metadata[1], f'"{PINNED_SHA1}"')
+        # Caution: check that the etag is *not* equal to the one from `test_standard_object`
+
+    def test_lfs_object(self):
+        url = hf_bucket_url(MODEL_ID, filename=WEIGHTS_NAME, revision=REVISION_ID_DEFAULT)
+        filepath = get_from_cache(url, force_download=True)
+        metadata = filename_to_url(filepath)
+        self.assertEqual(metadata, (url, f'"{PINNED_SHA256}"'))
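
These tests pin a config file by revision and check the cached etag against a known hash. For intuition, a sketch of the revision-aware URLs they exercise (our reading of this commit is that `hf_bucket_url` builds roughly `https://huggingface.co/<model_id>/resolve/<revision>/<filename>`; the identifier below is an example):

```python
from transformers.file_utils import hf_bucket_url

# Omitting `revision` resolves against the default branch.
print(hf_bucket_url("bert-base-uncased", filename="config.json"))
print(hf_bucket_url("bert-base-uncased", filename="config.json", revision="main"))
```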

View File

@@ -20,7 +20,7 @@ import unittest
 import requests
 from requests.exceptions import HTTPError
 
-from transformers.hf_api import HfApi, HfFolder, ModelInfo, PresignedUrl, S3Obj
+from transformers.hf_api import HfApi, HfFolder, ModelInfo, PresignedUrl, RepoObj, S3Obj
 
 USER = "__DUMMY_TRANSFORMERS_USER__"
@@ -35,6 +35,7 @@ FILES = [
         os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/empty.txt"),
     ),
 ]
+REPO_NAME = "my-model-{}".format(int(time.time()))
 
 ENDPOINT_STAGING = "https://moon-staging.huggingface.co"
@@ -78,15 +79,6 @@ class HfApiEndpointsTest(HfApiCommonTest):
         urls = self._api.presign(token=self._token, filename="nested/valid_org.txt", organization="valid_org")
         self.assertIsInstance(urls, PresignedUrl)
 
-    def test_presign_invalid(self):
-        try:
-            _ = self._api.presign(token=self._token, filename="non_nested.json")
-        except HTTPError as e:
-            self.assertIsNotNone(e.response.text)
-            self.assertTrue("Filename invalid" in e.response.text)
-        else:
-            self.fail("Expected an exception")
-
     def test_presign(self):
         for FILE_KEY, FILE_PATH in FILES:
             urls = self._api.presign(token=self._token, filename=FILE_KEY)
@@ -109,6 +101,17 @@ class HfApiEndpointsTest(HfApiCommonTest):
             o = objs[-1]
             self.assertIsInstance(o, S3Obj)
 
+    def test_list_repos_objs(self):
+        objs = self._api.list_repos_objs(token=self._token)
+        self.assertIsInstance(objs, list)
+        if len(objs) > 0:
+            o = objs[-1]
+            self.assertIsInstance(o, RepoObj)
+
+    def test_create_and_delete_repo(self):
+        self._api.create_repo(token=self._token, name=REPO_NAME)
+        self._api.delete_repo(token=self._token, name=REPO_NAME)
+
 class HfApiPublicTest(unittest.TestCase):
     def test_staging_model_list(self):
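
The two new tests cover the repo-level methods added to `HfApi` in this commit. A hedged sketch of calling them directly (token retrieval via `HfFolder` is an assumption based on the existing login flow; the repo name is illustrative):

```python
from transformers.hf_api import HfApi, HfFolder

api = HfApi()
token = HfFolder.get_token()  # assumes a prior `transformers-cli login`

api.create_repo(token=token, name="my-test-model")
print(api.list_repos_objs(token=token))  # files across your repos
api.delete_repo(token=token, name="my-test-model")
```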

View File

@@ -323,7 +323,7 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
     def test_custom_load_tf_weights(self):
         model, output_loading_info = TFBertForTokenClassification.from_pretrained(
-            "jplu/tiny-tf-bert-random", use_cdn=False, output_loading_info=True
+            "jplu/tiny-tf-bert-random", output_loading_info=True
         )
         self.assertEqual(sorted(output_loading_info["unexpected_keys"]), ["mlm___cls", "nsp___cls"])
         for layer in output_loading_info["missing_keys"]:

View File

@@ -165,7 +165,7 @@ DUMMY_FUNCTION = {
 def read_init():
-    """ Read the init and exctracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects. """
+    """ Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects. """
     with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8") as f:
         lines = f.readlines()