mirror of https://github.com/huggingface/transformers.git
Model versioning (#8324)
* fix typo
* rm use_cdn & references, and implement new hf_bucket_url
* I'm pretty sure we don't need to `read` this file
* same here
* [BIG] file_utils.networking: do not gobble up errors anymore
* Fix CI 😇
* Apply suggestions from code review

  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Tiny doc tweak
* Add doc + pass kwarg everywhere
* Add more tests and explain

  cc @sshleifer let me know if better

  Co-Authored-By: Sam Shleifer <sshleifer@gmail.com>
* Also implement revision in pipelines

  In the case where we're passing a task name or a string model identifier
* Fix CI 😇
* Fix CI
* [hf_api] new methods + command line implem
* make style
* Final endpoints post-migration
* Fix post-migration
* Py3.6 compat

  cc @stefan-it
  Thank you @stas00

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
parent 4185b115d4
commit 70f622fab4
@@ -12,8 +12,8 @@ inference: false

 ## Disclaimer

-Due do it's immense size, `t5-11b` requires some special treatment.
-First, `t5-11b` should be loaded with flag `use_cdn` set to `False` as follows:
+**Before `transformers` v3.5.0**, due to its immense size, `t5-11b` required some special treatment.
+If you're using transformers `<= v3.4.0`, `t5-11b` should be loaded with flag `use_cdn` set to `False` as follows:

 ```python
 t5 = transformers.T5ForConditionalGeneration.from_pretrained('t5-11b', use_cdn = False)
 ```
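Since `transformers` v3.5.0 the `use_cdn` flag is gone and a model version is selected with `revision` instead. A minimal sketch of the replacement call (the `revision` value is illustrative):

```python
import transformers

# Post-v3.5.0: files resolve through the git-based hub; `revision` can be a
# branch name, tag name, or commit sha (the value here is illustrative).
t5 = transformers.T5ForConditionalGeneration.from_pretrained("t5-11b", revision="main")
```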
@@ -56,7 +56,3 @@ cd -
 perl -le 'for $f (@ARGV) { print qq[transformers-cli upload -y $_/$f --filename $_/$f] for ("wmt16-en-de-dist-12-1", "wmt16-en-de-dist-6-1", "wmt16-en-de-12-1")}' vocab-src.json vocab-tgt.json tokenizer_config.json config.json
 # add/remove files as needed

-# Caching note: Unfortunately due to CDN caching the uploaded model may be unavailable for up to 24hs after upload
-# So the only way to start using the new model sooner is either:
-# 1. download it to a local path and use that path as model_name
-# 2. make sure you use: from_pretrained(..., use_cdn=False) everywhere
@@ -44,7 +44,3 @@ cd -
 perl -le 'for $f (@ARGV) { print qq[transformers-cli upload -y $_/$f --filename $_/$f] for ("wmt19-de-en-6-6-base", "wmt19-de-en-6-6-big")}' vocab-src.json vocab-tgt.json tokenizer_config.json config.json
 # add/remove files as needed

-# Caching note: Unfortunately due to CDN caching the uploaded model may be unavailable for up to 24hs after upload
-# So the only way to start using the new model sooner is either:
-# 1. download it to a local path and use that path as model_name
-# 2. make sure you use: from_pretrained(..., use_cdn=False) everywhere
@@ -55,7 +55,3 @@ cd -
 perl -le 'for $f (@ARGV) { print qq[transformers-cli upload -y $_/$f --filename $_/$f] for map { "wmt19-$_" } ("en-ru", "ru-en", "de-en", "en-de")}' vocab-src.json vocab-tgt.json tokenizer_config.json config.json
 # add/remove files as needed

-# Caching note: Unfortunately due to CDN caching the uploaded model may be unavailable for up to 24hs after upload
-# So the only way to start using the new model sooner is either:
-# 1. download it to a local path and use that path as model_name
-# 2. make sure you use: from_pretrained(..., use_cdn=False) everywhere
@@ -1,4 +1,5 @@
 import os
+import subprocess
 import sys
 from argparse import ArgumentParser
 from getpass import getpass
@@ -21,8 +22,10 @@ class UserCommands(BaseTransformersCLICommand):
         whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args))
         logout_parser = parser.add_parser("logout", help="Log out")
         logout_parser.set_defaults(func=lambda args: LogoutCommand(args))
-        # s3
-        s3_parser = parser.add_parser("s3", help="{ls, rm} Commands to interact with the files you upload on S3.")
+        # s3_datasets (s3-based system)
+        s3_parser = parser.add_parser(
+            "s3_datasets", help="{ls, rm} Commands to interact with the files you upload on S3."
+        )
         s3_subparsers = s3_parser.add_subparsers(help="s3 related commands")
         ls_parser = s3_subparsers.add_parser("ls")
         ls_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
@@ -31,17 +34,42 @@ class UserCommands(BaseTransformersCLICommand):
         rm_parser.add_argument("filename", type=str, help="individual object filename to delete from S3.")
         rm_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
         rm_parser.set_defaults(func=lambda args: DeleteObjCommand(args))
-        # upload
-        upload_parser = parser.add_parser("upload", help="Upload a model to S3.")
-        upload_parser.add_argument(
-            "path", type=str, help="Local path of the model folder or individual file to upload."
-        )
+        upload_parser = s3_subparsers.add_parser("upload", help="Upload a file to S3.")
+        upload_parser.add_argument("path", type=str, help="Local path of the folder or individual file to upload.")
         upload_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
         upload_parser.add_argument(
             "--filename", type=str, default=None, help="Optional: override individual object filename on S3."
         )
         upload_parser.add_argument("-y", "--yes", action="store_true", help="Optional: answer Yes to the prompt")
         upload_parser.set_defaults(func=lambda args: UploadCommand(args))
+        # deprecated model upload
+        upload_parser = parser.add_parser(
+            "upload",
+            help=(
+                "Deprecated: used to be the way to upload a model to S3."
+                " We now use a git-based system for storing models and other artifacts."
+                " Use the `repo create` command instead."
+            ),
+        )
+        upload_parser.set_defaults(func=lambda args: DeprecatedUploadCommand(args))
+
+        # new system: git-based repo system
+        repo_parser = parser.add_parser(
+            "repo", help="{create, ls-files} Commands to interact with your huggingface.co repos."
+        )
+        repo_subparsers = repo_parser.add_subparsers(help="huggingface.co repos related commands")
+        ls_parser = repo_subparsers.add_parser("ls-files", help="List all your files on huggingface.co")
+        ls_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
+        ls_parser.set_defaults(func=lambda args: ListReposObjsCommand(args))
+        repo_create_parser = repo_subparsers.add_parser("create", help="Create a new repo on huggingface.co")
+        repo_create_parser.add_argument(
+            "name",
+            type=str,
+            help="Name for your model's repo. Will be namespaced under your username to build the model id.",
+        )
+        repo_create_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
+        repo_create_parser.add_argument("-y", "--yes", action="store_true", help="Optional: answer Yes to the prompt")
+        repo_create_parser.set_defaults(func=lambda args: RepoCreateCommand(args))
@@ -51,6 +79,7 @@ class ANSI:

     _bold = "\u001b[1m"
     _red = "\u001b[31m"
+    _gray = "\u001b[90m"
     _reset = "\u001b[0m"

     @classmethod
@@ -61,6 +90,27 @@ class ANSI:
     def red(cls, s):
         return "{}{}{}".format(cls._bold + cls._red, s, cls._reset)

+    @classmethod
+    def gray(cls, s):
+        return "{}{}{}".format(cls._gray, s, cls._reset)
+
+
+def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
+    """
+    Inspired by:
+
+    - stackoverflow.com/a/8356620/593036
+    - stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
+    """
+    col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
+    row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
+    lines = []
+    lines.append(row_format.format(*headers))
+    lines.append(row_format.format(*["-" * w for w in col_widths]))
+    for row in rows:
+        lines.append(row_format.format(*row))
+    return "\n".join(lines)
+
+
 class BaseUserCommand:
     def __init__(self, args):
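A quick sanity check of the module-level helper above, with invented values:

```python
rows = [["config.json", "2020-11-09", 512]]
print(tabulate(rows, headers=["Filename", "LastModified", "Size"]))
# Filename    LastModified Size
# ----------- ------------ ----
# config.json 2020-11-09   512
```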
@@ -124,22 +174,6 @@ class LogoutCommand(BaseUserCommand):


 class ListObjsCommand(BaseUserCommand):
-    def tabulate(self, rows: List[List[Union[str, int]]], headers: List[str]) -> str:
-        """
-        Inspired by:
-
-        - stackoverflow.com/a/8356620/593036
-        - stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
-        """
-        col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
-        row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
-        lines = []
-        lines.append(row_format.format(*headers))
-        lines.append(row_format.format(*["-" * w for w in col_widths]))
-        for row in rows:
-            lines.append(row_format.format(*row))
-        return "\n".join(lines)
-
     def run(self):
         token = HfFolder.get_token()
         if token is None:
@@ -155,7 +189,7 @@ class ListObjsCommand(BaseUserCommand):
             print("No shared file yet")
             exit()
         rows = [[obj.filename, obj.LastModified, obj.ETag, obj.Size] for obj in objs]
-        print(self.tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"]))
+        print(tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"]))


 class DeleteObjCommand(BaseUserCommand):
@@ -173,6 +207,85 @@ class DeleteObjCommand(BaseUserCommand):
         print("Done")


+class ListReposObjsCommand(BaseUserCommand):
+    def run(self):
+        token = HfFolder.get_token()
+        if token is None:
+            print("Not logged in")
+            exit(1)
+        try:
+            objs = self._api.list_repos_objs(token, organization=self.args.organization)
+        except HTTPError as e:
+            print(e)
+            print(ANSI.red(e.response.text))
+            exit(1)
+        if len(objs) == 0:
+            print("No shared file yet")
+            exit()
+        rows = [[obj.filename, obj.lastModified, obj.commit, obj.size] for obj in objs]
+        print(tabulate(rows, headers=["Filename", "LastModified", "Commit-Sha", "Size"]))
+
+
+class RepoCreateCommand(BaseUserCommand):
+    def run(self):
+        token = HfFolder.get_token()
+        if token is None:
+            print("Not logged in")
+            exit(1)
+        try:
+            stdout = subprocess.check_output(["git", "--version"]).decode("utf-8")
+            print(ANSI.gray(stdout.strip()))
+        except FileNotFoundError:
+            print("Looks like you do not have git installed, please install.")
+
+        try:
+            stdout = subprocess.check_output(["git-lfs", "--version"]).decode("utf-8")
+            print(ANSI.gray(stdout.strip()))
+        except FileNotFoundError:
+            print(
+                ANSI.red(
+                    "Looks like you do not have git-lfs installed, please install."
+                    " You can install from https://git-lfs.github.com/."
+                    " Then run `git lfs install` (you only have to do this once)."
+                )
+            )
+        print("")
+
+        user, _ = self._api.whoami(token)
+        namespace = self.args.organization if self.args.organization is not None else user
+
+        print("You are about to create {}".format(ANSI.bold(namespace + "/" + self.args.name)))
+
+        if not self.args.yes:
+            choice = input("Proceed? [Y/n] ").lower()
+            if not (choice == "" or choice == "y" or choice == "yes"):
+                print("Abort")
+                exit()
+        try:
+            url = self._api.create_repo(token, name=self.args.name, organization=self.args.organization)
+        except HTTPError as e:
+            print(e)
+            print(ANSI.red(e.response.text))
+            exit(1)
+        print("\nYour repo now lives at:")
+        print("  {}".format(ANSI.bold(url)))
+        print("\nYou can clone it locally with the command below," " and commit/push as usual.")
+        print(f"\n  git clone {url}")
+        print("")
+
+
+class DeprecatedUploadCommand(BaseUserCommand):
+    def run(self):
+        print(
+            ANSI.red(
+                "Deprecated: used to be the way to upload a model to S3."
+                " We now use a git-based system for storing models and other artifacts."
+                " Use the `repo create` command instead."
+            )
+        )
+        exit(1)
+
+
 class UploadCommand(BaseUserCommand):
     def walk_dir(self, rel_path):
         """
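The `RepoCreateCommand` above boils down to one API call; a rough Python equivalent using the `HfApi` methods added further down in this commit (token and repo name are placeholders):

```python
from transformers.hf_api import HfApi

api = HfApi()
# A real token comes from `transformers-cli login` / HfFolder.get_token().
url = api.create_repo(token="<token>", name="my-model")
print(url)  # the new repo's url; clone it with `git clone <url>` and push as usual
```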
@@ -289,6 +289,10 @@ class AutoConfig:
             proxies (:obj:`Dict[str, str]`, `optional`):
                 A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+                identifier allowed by git.
             return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`):
                 If :obj:`False`, then this function returns just the final configuration object.
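With the new kwarg, a config can be pinned to any git identifier; a hedged example (model id and revision value are illustrative):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("bert-base-uncased", revision="main")
```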
@@ -311,6 +311,10 @@ class PretrainedConfig(object):
             proxies (:obj:`Dict[str, str]`, `optional`):
                 A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+                identifier allowed by git.
             return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`):
                 If :obj:`False`, then this function returns just the final configuration object.
@@ -362,6 +366,7 @@ class PretrainedConfig(object):
         resume_download = kwargs.pop("resume_download", False)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)
+        revision = kwargs.pop("revision", None)

         if os.path.isdir(pretrained_model_name_or_path):
             config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
@@ -369,7 +374,7 @@ class PretrainedConfig(object):
             config_file = pretrained_model_name_or_path
         else:
             config_file = hf_bucket_url(
-                pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False, mirror=None
+                pretrained_model_name_or_path, filename=CONFIG_NAME, revision=revision, mirror=None
             )

         try:
@@ -383,11 +388,10 @@ class PretrainedConfig(object):
                 local_files_only=local_files_only,
             )
             # Load config dict
-            if resolved_config_file is None:
-                raise EnvironmentError
             config_dict = cls._dict_from_json_file(resolved_config_file)

-        except EnvironmentError:
+        except EnvironmentError as err:
+            logger.error(err)
             msg = (
                 f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                 f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
@@ -4,6 +4,7 @@ https://github.com/allenai/allennlp Copyright by the AllenNLP authors.
 """

 import fnmatch
+import io
 import json
 import os
 import re
@@ -17,7 +18,7 @@ from dataclasses import fields
 from functools import partial, wraps
 from hashlib import sha256
 from pathlib import Path
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, BinaryIO, Dict, Optional, Tuple, Union
 from urllib.parse import urlparse
 from zipfile import ZipFile, is_zipfile
@@ -217,6 +218,8 @@ DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]

 S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
 CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
+HUGGINGFACE_CO_PREFIX = "https://huggingface.co/{model_id}/resolve/{revision}/{filename}"
+
 PRESET_MIRROR_DICT = {
     "tuna": "https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models",
     "bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models",
@@ -825,34 +828,37 @@ def is_remote_url(url_or_filename):
     return parsed.scheme in ("http", "https")


-def hf_bucket_url(model_id: str, filename: str, use_cdn=True, mirror=None) -> str:
+def hf_bucket_url(model_id: str, filename: str, revision: Optional[str] = None, mirror=None) -> str:
     """
-    Resolve a model identifier, and a file name, to a HF-hosted url on either S3 or Cloudfront (a Content Delivery
-    Network, or CDN).
+    Resolve a model identifier, a file name, and an optional revision id, to a huggingface.co-hosted url, redirecting
+    to Cloudfront (a Content Delivery Network, or CDN) for large files.

     Cloudfront is replicated over the globe so downloads are way faster for the end user (and it also lowers our
-    bandwidth costs). However, it is more aggressively cached by default, so may not always reflect the latest changes
-    to the underlying file (default TTL is 24 hours).
+    bandwidth costs).

-    In terms of client-side caching from this library, even though Cloudfront relays the ETags from S3, using one or
-    the other (or switching from one to the other) will affect caching: cached files are not shared between the two
-    because the cached file's name contains a hash of the url.
+    Cloudfront aggressively caches files by default (default TTL is 24 hours), however this is not an issue here
+    because we migrated to a git-based versioning system on huggingface.co, so we now store the files on S3/Cloudfront
+    in a content-addressable way (i.e., the file name is its hash). Using content-addressable filenames means cache
+    can't ever be stale.
+
+    In terms of client-side caching from this library, we base our caching on the objects' ETag. An object's ETag is:
+    its sha1 if stored in git, or its sha256 if stored in git-lfs. Files cached locally from transformers before v3.5.0
+    are not shared with those new files, because the cached file's name contains a hash of the url (which changed).
     """
-    endpoint = (
-        PRESET_MIRROR_DICT.get(mirror, mirror)
-        if mirror
-        else CLOUDFRONT_DISTRIB_PREFIX
-        if use_cdn
-        else S3_BUCKET_PREFIX
-    )
-    legacy_format = "/" not in model_id
-    if legacy_format:
-        return f"{endpoint}/{model_id}-{filename}"
-    else:
-        return f"{endpoint}/{model_id}/{filename}"
+    if mirror:
+        endpoint = PRESET_MIRROR_DICT.get(mirror, mirror)
+        legacy_format = "/" not in model_id
+        if legacy_format:
+            return f"{endpoint}/{model_id}-{filename}"
+        else:
+            return f"{endpoint}/{model_id}/{filename}"
+
+    if revision is None:
+        revision = "main"
+    return HUGGINGFACE_CO_PREFIX.format(model_id=model_id, revision=revision, filename=filename)


-def url_to_filename(url, etag=None):
+def url_to_filename(url: str, etag: Optional[str] = None) -> str:
     """
     Convert `url` into a hashed filename in a repeatable way. If `etag` is specified, append its hash to the url's,
     delimited by a period. If the url ends with .h5 (Keras HDF5 weights) adds '.h5' to the name so that TF 2.0 can
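Illustrative behavior of the rewritten resolver (model ids and revision values are examples):

```python
hf_bucket_url("julien-c/dummy-unknown", filename="config.json")
# -> "https://huggingface.co/julien-c/dummy-unknown/resolve/main/config.json"

hf_bucket_url("julien-c/dummy-unknown", filename="config.json", revision="f2c752cf")
# -> "https://huggingface.co/julien-c/dummy-unknown/resolve/f2c752cf/config.json"

hf_bucket_url("bert-base-uncased", filename="config.json", mirror="tuna")
# -> legacy mirror layout: ".../hugging-face-models/bert-base-uncased-config.json"
```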
@@ -860,13 +866,11 @@ def url_to_filename(url, etag=None):
     https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
     """
     url_bytes = url.encode("utf-8")
-    url_hash = sha256(url_bytes)
-    filename = url_hash.hexdigest()
+    filename = sha256(url_bytes).hexdigest()

     if etag:
         etag_bytes = etag.encode("utf-8")
-        etag_hash = sha256(etag_bytes)
-        filename += "." + etag_hash.hexdigest()
+        filename += "." + sha256(etag_bytes).hexdigest()

     if url.endswith(".h5"):
         filename += ".h5"
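For reference, the cache key produced by this function is the sha256 of the url, optionally followed by a period and the sha256 of the etag (values abbreviated):

```python
url = "https://huggingface.co/bert-base-uncased/resolve/main/config.json"
url_to_filename(url, etag='"d9e9f15b..."')
# -> "<sha256(url)>.<sha256(etag)>"  (plus a ".h5" suffix for Keras weight files)
```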
@@ -927,8 +931,10 @@ def cached_path(
             re-extract the archive and override the folder where it was extracted.

     Return:
-        None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk). Local path (string)
-        otherwise
+        Local path (string) of file or if networking is off, last version of file cached on disk.
+
+    Raises:
+        In case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
     """
     if cache_dir is None:
         cache_dir = TRANSFORMERS_CACHE
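Callers that previously checked for a `None` return now catch instead; a sketch grounded in the new tests at the bottom of this commit:

```python
import requests
from transformers.file_utils import cached_path, hf_bucket_url

try:
    path = cached_path(hf_bucket_url("bert-base-uncased", filename="config.json"))
except requests.exceptions.HTTPError:
    ...  # missing file or bad revision now surfaces as a 404, not a silent None
except ValueError:
    ...  # connection error with no copy in the local cache
```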
@@ -992,7 +998,10 @@ def cached_path(
     return output_path


-def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict, str, None] = None):
+def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
+    """
+    Formats a user-agent string with basic info about a request.
+    """
     ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
     if is_torch_available():
         ua += "; torch/{}".format(torch.__version__)
@@ -1002,13 +1011,19 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict
         ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
     elif isinstance(user_agent, str):
         ua += "; " + user_agent
-    headers = {"user-agent": ua}
+    return ua
+
+
+def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, user_agent: Union[Dict, str, None] = None):
+    """
+    Download remote file. Do not gobble up errors.
+    """
+    headers = {"user-agent": http_user_agent(user_agent)}
     if resume_size > 0:
         headers["Range"] = "bytes=%d-" % (resume_size,)
-    response = requests.get(url, stream=True, proxies=proxies, headers=headers)
-    if response.status_code == 416:  # Range not satisfiable
-        return
-    content_length = response.headers.get("Content-Length")
+    r = requests.get(url, stream=True, proxies=proxies, headers=headers)
+    r.raise_for_status()
+    content_length = r.headers.get("Content-Length")
     total = resume_size + int(content_length) if content_length is not None else None
     progress = tqdm(
         unit="B",
@@ -1018,7 +1033,7 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict
         desc="Downloading",
         disable=bool(logging.get_verbosity() == logging.NOTSET),
     )
-    for chunk in response.iter_content(chunk_size=1024):
+    for chunk in r.iter_content(chunk_size=1024):
         if chunk:  # filter out keep-alive new chunks
             progress.update(len(chunk))
             temp_file.write(chunk)
@@ -1026,7 +1041,7 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict


 def get_from_cache(
-    url,
+    url: str,
     cache_dir=None,
     force_download=False,
     proxies=None,
@@ -1040,8 +1055,10 @@ def get_from_cache(
     path to the cached file.

     Return:
-        None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk). Local path (string)
-        otherwise
+        Local path (string) of file or if networking is off, last version of file cached on disk.
+
+    Raises:
+        In case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
     """
     if cache_dir is None:
         cache_dir = TRANSFORMERS_CACHE
@@ -1050,13 +1067,28 @@ def get_from_cache(

     os.makedirs(cache_dir, exist_ok=True)

+    url_to_download = url
     etag = None
     if not local_files_only:
         try:
-            response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
-            if response.status_code == 200:
-                etag = response.headers.get("ETag")
-        except (EnvironmentError, requests.exceptions.Timeout):
+            headers = {"user-agent": http_user_agent(user_agent)}
+            r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
+            r.raise_for_status()
+            etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
+            # We favor a custom header indicating the etag of the linked resource, and
+            # we fallback to the regular etag header.
+            # If we don't have any of those, raise an error.
+            if etag is None:
+                raise OSError(
+                    "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
+                )
+            # In case of a redirect,
+            # save an extra redirect on the request.get call,
+            # and ensure we download the exact atomic version even if it changed
+            # between the HEAD and the GET (unlikely, but hey).
+            if 300 <= r.status_code <= 399:
+                url_to_download = r.headers["Location"]
+        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
+            # etag is already None
             pass
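A minimal sketch of what that HEAD request sees on the new hub (url illustrative): git-stored files answer with a plain `ETag` (roughly, their sha1), while git-lfs files answer with an `X-Linked-Etag` (their sha256) plus a 302 whose `Location` is the content-addressed CDN url:

```python
import requests

r = requests.head(
    "https://huggingface.co/bert-base-uncased/resolve/main/pytorch_model.bin",
    allow_redirects=False,
)
etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
print(r.status_code, etag)  # e.g. a 3xx status and a sha256-based etag for an LFS file
```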
@@ -1065,7 +1097,7 @@ def get_from_cache(
     # get cache path to put the file
     cache_path = os.path.join(cache_dir, filename)

-    # etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
+    # etag is None == we don't have a connection or we passed local_files_only.
     # try to get the last downloaded one
     if etag is None:
         if os.path.exists(cache_path):
@@ -1088,7 +1120,11 @@ def get_from_cache(
                     " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
                     " to False."
                 )
-        return None
+            else:
+                raise ValueError(
+                    "Connection error, and we cannot find the requested files in the cached path."
+                    " Please try again or make sure your Internet connection is on."
+                )

     # From now on, etag is not None.
     if os.path.exists(cache_path) and not force_download:
@@ -1107,8 +1143,8 @@ def get_from_cache(
         incomplete_path = cache_path + ".incomplete"

         @contextmanager
-        def _resumable_file_manager():
-            with open(incomplete_path, "a+b") as f:
+        def _resumable_file_manager() -> "io.BufferedWriter":
+            with open(incomplete_path, "ab") as f:
                 yield f

         temp_file_manager = _resumable_file_manager
@@ -1117,7 +1153,7 @@ def get_from_cache(
         else:
             resume_size = 0
     else:
-        temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
+        temp_file_manager = partial(tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False)
         resume_size = 0

     # Download to temporary file, then copy to cache dir once finished.
@@ -1125,7 +1161,7 @@ def get_from_cache(
     with temp_file_manager() as temp_file:
         logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)

-        http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
+        http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)

     logger.info("storing %s in cache at %s", url, cache_path)
     os.replace(temp_file.name, cache_path)
@@ -27,9 +27,21 @@ import requests
 ENDPOINT = "https://huggingface.co"


+class RepoObj:
+    """
+    HuggingFace git-based system, data structure that represents a file belonging to the current user.
+    """
+
+    def __init__(self, filename: str, lastModified: str, commit: str, size: int, **kwargs):
+        self.filename = filename
+        self.lastModified = lastModified
+        self.commit = commit
+        self.size = size
+
+
 class S3Obj:
     """
-    Data structure that represents a file belonging to the current user.
+    HuggingFace S3-based system, data structure that represents a file belonging to the current user.
     """

     def __init__(self, filename: str, LastModified: str, ETag: str, Size: int, **kwargs):
@@ -46,38 +58,25 @@ class PresignedUrl:
         self.type = type  # mime-type to send to S3.


-class S3Object:
+class ModelSibling:
     """
-    Data structure that represents a public file accessible on our S3.
+    Data structure that represents a public file inside a model, accessible from huggingface.co
     """

-    def __init__(
-        self,
-        key: str,  # S3 object key
-        etag: str,
-        lastModified: str,
-        size: int,
-        rfilename: str,  # filename relative to config.json
-        **kwargs
-    ):
-        self.key = key
-        self.etag = etag
-        self.lastModified = lastModified
-        self.size = size
-        self.rfilename = rfilename
+    def __init__(self, rfilename: str, **kwargs):
+        self.rfilename = rfilename  # filename relative to the model root
         for k, v in kwargs.items():
             setattr(self, k, v)


 class ModelInfo:
     """
-    Info about a public model accessible from our S3.
+    Info about a public model accessible from huggingface.co
     """

     def __init__(
         self,
-        modelId: str,  # id of model
-        key: str,  # S3 object key of config.json
+        modelId: Optional[str] = None,  # id of model
         author: Optional[str] = None,
         downloads: Optional[int] = None,
         tags: List[str] = [],
@@ -86,12 +85,11 @@ class ModelInfo:
         **kwargs
     ):
         self.modelId = modelId
-        self.key = key
         self.author = author
         self.downloads = downloads
         self.tags = tags
         self.pipeline_tag = pipeline_tag
-        self.siblings = [S3Object(**x) for x in siblings] if siblings is not None else None
+        self.siblings = [ModelSibling(**x) for x in siblings] if siblings is not None else None
         for k, v in kwargs.items():
             setattr(self, k, v)
@@ -134,9 +132,11 @@ class HfApi:

     def presign(self, token: str, filename: str, organization: Optional[str] = None) -> PresignedUrl:
         """
+        HuggingFace S3-based system, used for datasets and metrics.
+
         Call HF API to get a presigned url to upload `filename` to S3.
         """
-        path = "{}/api/presign".format(self.endpoint)
+        path = "{}/api/datasets/presign".format(self.endpoint)
         r = requests.post(
             path,
             headers={"authorization": "Bearer {}".format(token)},
@@ -148,6 +148,8 @@ class HfApi:

     def presign_and_upload(self, token: str, filename: str, filepath: str, organization: Optional[str] = None) -> str:
         """
+        HuggingFace S3-based system, used for datasets and metrics.
+
         Get a presigned url, then upload file to S3.

         Outputs: url: Read-only url for the stored file on S3.
@@ -169,9 +171,11 @@ class HfApi:

     def list_objs(self, token: str, organization: Optional[str] = None) -> List[S3Obj]:
         """
+        HuggingFace S3-based system, used for datasets and metrics.
+
         Call HF API to list all stored files for user (or one of their organizations).
         """
-        path = "{}/api/listObjs".format(self.endpoint)
+        path = "{}/api/datasets/listObjs".format(self.endpoint)
         params = {"organization": organization} if organization is not None else None
         r = requests.get(path, params=params, headers={"authorization": "Bearer {}".format(token)})
         r.raise_for_status()
@@ -180,9 +184,11 @@ class HfApi:

     def delete_obj(self, token: str, filename: str, organization: Optional[str] = None):
         """
+        HuggingFace S3-based system, used for datasets and metrics.
+
         Call HF API to delete a file stored by user
         """
-        path = "{}/api/deleteObj".format(self.endpoint)
+        path = "{}/api/datasets/deleteObj".format(self.endpoint)
         r = requests.delete(
             path,
             headers={"authorization": "Bearer {}".format(token)},
@@ -200,6 +206,51 @@ class HfApi:
         d = r.json()
         return [ModelInfo(**x) for x in d]

+    def list_repos_objs(self, token: str, organization: Optional[str] = None) -> List[RepoObj]:
+        """
+        HuggingFace git-based system, used for models.
+
+        Call HF API to list all stored files for user (or one of their organizations).
+        """
+        path = "{}/api/repos/ls".format(self.endpoint)
+        params = {"organization": organization} if organization is not None else None
+        r = requests.get(path, params=params, headers={"authorization": "Bearer {}".format(token)})
+        r.raise_for_status()
+        d = r.json()
+        return [RepoObj(**x) for x in d]
+
+    def create_repo(self, token: str, name: str, organization: Optional[str] = None) -> str:
+        """
+        HuggingFace git-based system, used for models.
+
+        Call HF API to create a whole repo.
+        """
+        path = "{}/api/repos/create".format(self.endpoint)
+        r = requests.post(
+            path,
+            headers={"authorization": "Bearer {}".format(token)},
+            json={"name": name, "organization": organization},
+        )
+        r.raise_for_status()
+        d = r.json()
+        return d["url"]
+
+    def delete_repo(self, token: str, name: str, organization: Optional[str] = None):
+        """
+        HuggingFace git-based system, used for models.
+
+        Call HF API to delete a whole repo.
+
+        CAUTION(this is irreversible).
+        """
+        path = "{}/api/repos/delete".format(self.endpoint)
+        r = requests.delete(
+            path,
+            headers={"authorization": "Bearer {}".format(token)},
+            json={"name": name, "organization": organization},
+        )
+        r.raise_for_status()
+
+
 class TqdmProgressFileReader:
     """
@@ -144,9 +144,7 @@ class ModelCard:
         elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
             model_card_file = pretrained_model_name_or_path
         else:
-            model_card_file = hf_bucket_url(
-                pretrained_model_name_or_path, filename=MODEL_CARD_NAME, use_cdn=False, mirror=None
-            )
+            model_card_file = hf_bucket_url(pretrained_model_name_or_path, filename=MODEL_CARD_NAME, mirror=None)

         if find_from_standard_name or pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
             model_card_file = model_card_file.replace(CONFIG_NAME, MODEL_CARD_NAME)
@@ -156,8 +154,6 @@ class ModelCard:
         try:
             # Load from URL or cache if already cached
             resolved_model_card_file = cached_path(model_card_file, cache_dir=cache_dir, proxies=proxies)
-            if resolved_model_card_file is None:
-                raise EnvironmentError
             if resolved_model_card_file == model_card_file:
                 logger.info("loading model card file {}".format(model_card_file))
             else:
@@ -537,9 +537,10 @@ AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
             Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
         local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether or not to only look at local files (e.g., not try downloading the model).
-        use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
-            Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
-            our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
+        revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
+            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+            git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+            identifier allowed by git.
         kwargs (additional keyword arguments, `optional`):
             Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
             :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
@@ -107,7 +107,7 @@ class FlaxPreTrainedModel(ABC):
         proxies = kwargs.pop("proxies", None)
         # output_loading_info = kwargs.pop("output_loading_info", False)
         local_files_only = kwargs.pop("local_files_only", False)
-        use_cdn = kwargs.pop("use_cdn", True)
+        revision = kwargs.pop("revision", None)

         # Load config if we don't provide a configuration
         if not isinstance(config, PretrainedConfig):
@@ -121,6 +121,7 @@ class FlaxPreTrainedModel(ABC):
                 resume_download=resume_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
+                revision=revision,
                 **kwargs,
             )
         else:
@@ -131,7 +132,7 @@ class FlaxPreTrainedModel(ABC):
         if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
             archive_file = pretrained_model_name_or_path
         else:
-            archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, use_cdn=use_cdn)
+            archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision)

         # redirect to the cache, if necessary
         try:
@@ -143,16 +144,13 @@ class FlaxPreTrainedModel(ABC):
                 resume_download=resume_download,
                 local_files_only=local_files_only,
             )
-        except EnvironmentError:
-            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
-                msg = f"Couldn't reach server at '{archive_file}' to download pretrained weights."
-            else:
-                msg = (
-                    f"Model name '{pretrained_model_name_or_path}' "
-                    f"was not found in model name list ({', '.join(cls.pretrained_model_archive_map.keys())}). "
-                    f"We assumed '{archive_file}' was a path or url to model weight files but "
-                    "couldn't find any such file at this path or url."
-                )
+        except EnvironmentError as err:
+            logger.error(err)
+            msg = (
+                f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
+                f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
+                f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named {WEIGHTS_NAME}.\n\n"
+            )
             raise EnvironmentError(msg)

         if resolved_archive_file == archive_file:
@@ -420,9 +420,10 @@ TF_AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
             Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
         local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether or not to only look at local files (e.g., not try downloading the model).
-        use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
-            Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
-            our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
+        revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
+            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+            git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+            identifier allowed by git.
         kwargs (additional keyword arguments, `optional`):
             Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
             :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
@@ -572,9 +572,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
             Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
         local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether or not to only look at local files (e.g., not try downloading the model).
-        use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
-            Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
-            our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
+        revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
+            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+            git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+            identifier allowed by git.
         mirror(:obj:`str`, `optional`, defaults to :obj:`None`):
             Mirror source to accelerate downloads in China. If you are from China and have an accessibility
             problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
@@ -616,7 +617,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
         proxies = kwargs.pop("proxies", None)
         output_loading_info = kwargs.pop("output_loading_info", False)
         local_files_only = kwargs.pop("local_files_only", False)
-        use_cdn = kwargs.pop("use_cdn", True)
+        revision = kwargs.pop("revision", None)
         mirror = kwargs.pop("mirror", None)

         # Load config if we don't provide a configuration
@@ -631,6 +632,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
                 resume_download=resume_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
+                revision=revision,
                 **kwargs,
             )
         else:
@@ -659,7 +661,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
             archive_file = hf_bucket_url(
                 pretrained_model_name_or_path,
                 filename=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME),
-                use_cdn=use_cdn,
+                revision=revision,
                 mirror=mirror,
             )
@@ -673,9 +675,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
                 resume_download=resume_download,
                 local_files_only=local_files_only,
             )
-            if resolved_archive_file is None:
-                raise EnvironmentError
-        except EnvironmentError:
+        except EnvironmentError as err:
+            logger.error(err)
             msg = (
                 f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                 f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
@@ -813,9 +813,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
             Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
         local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether or not to only look at local files (e.g., not try downloading the model).
-        use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
-            Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
-            our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
+        revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
+            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+            git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+            identifier allowed by git.
         mirror(:obj:`str`, `optional`, defaults to :obj:`None`):
             Mirror source to accelerate downloads in China. If you are from China and have an accessibility
             problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
@@ -857,7 +858,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
         proxies = kwargs.pop("proxies", None)
         output_loading_info = kwargs.pop("output_loading_info", False)
         local_files_only = kwargs.pop("local_files_only", False)
-        use_cdn = kwargs.pop("use_cdn", True)
+        revision = kwargs.pop("revision", None)
         mirror = kwargs.pop("mirror", None)

         # Load config if we don't provide a configuration
@@ -872,6 +873,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
                 resume_download=resume_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
+                revision=revision,
                 **kwargs,
             )
         else:
@@ -909,7 +911,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
             archive_file = hf_bucket_url(
                 pretrained_model_name_or_path,
                 filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
-                use_cdn=use_cdn,
+                revision=revision,
                 mirror=mirror,
             )
@@ -923,9 +925,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
                 resume_download=resume_download,
                 local_files_only=local_files_only,
             )
-            if resolved_archive_file is None:
-                raise EnvironmentError
-        except EnvironmentError:
+        except EnvironmentError as err:
+            logger.error(err)
             msg = (
                 f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                 f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
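End-to-end, the new kwarg lets weights be pinned to an exact commit; a sketch reusing the model id and sha from the test fixtures added at the bottom of this commit:

```python
from transformers import AutoModel

# `revision` accepts a branch, tag, or commit sha.
model = AutoModel.from_pretrained(
    "julien-c/dummy-unknown", revision="f2c752cfc5c0ab6f4bdec59acea69eefbee381c2"
)
```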
@@ -86,7 +86,7 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)


-def get_framework(model):
+def get_framework(model, revision: Optional[str] = None):
     """
     Select framework (TensorFlow or PyTorch) to use.
@@ -103,14 +103,14 @@ def get_framework(model):
         )
     if isinstance(model, str):
         if is_torch_available() and not is_tf_available():
-            model = AutoModel.from_pretrained(model)
+            model = AutoModel.from_pretrained(model, revision=revision)
         elif is_tf_available() and not is_torch_available():
-            model = TFAutoModel.from_pretrained(model)
+            model = TFAutoModel.from_pretrained(model, revision=revision)
         else:
             try:
-                model = AutoModel.from_pretrained(model)
+                model = AutoModel.from_pretrained(model, revision=revision)
             except OSError:
-                model = TFAutoModel.from_pretrained(model)
+                model = TFAutoModel.from_pretrained(model, revision=revision)

     framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
     return framework
@@ -2730,6 +2730,7 @@ def pipeline(
     config: Optional[Union[str, PretrainedConfig]] = None,
     tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
     framework: Optional[str] = None,
+    revision: Optional[str] = None,
     use_fast: bool = False,
     **kwargs
 ) -> Pipeline:
@@ -2784,6 +2785,10 @@ def pipeline(
         If no framework is specified, will default to the one currently installed. If no framework is specified and
         both frameworks are installed, will default to the framework of the :obj:`model`, or to PyTorch if no model
         is provided.
+    revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
+        When passing a task name or a string model identifier: The specific model version to use. It can be a
+        branch name, a tag name, or a commit id, since we use a git-based system for storing models and other
+        artifacts on huggingface.co, so ``revision`` can be any identifier allowed by git.
     use_fast (:obj:`bool`, `optional`, defaults to :obj:`False`):
         Whether or not to use a Fast tokenizer if possible (a :class:`~transformers.PreTrainedTokenizerFast`).
     kwargs:
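Together with the model, tokenizer, config, and modelcard plumbing below, this allows version-pinned pipelines; an illustrative call (values are examples):

```python
from transformers import pipeline

# `revision` is forwarded to every from_pretrained call the pipeline makes.
nlp = pipeline("sentiment-analysis", revision="main")
```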
@@ -2845,17 +2850,19 @@ def pipeline(
     if isinstance(tokenizer, tuple):
         # For tuple we have (tokenizer name, {kwargs})
         use_fast = tokenizer[1].pop("use_fast", use_fast)
-        tokenizer = AutoTokenizer.from_pretrained(tokenizer[0], use_fast=use_fast, **tokenizer[1])
+        tokenizer = AutoTokenizer.from_pretrained(
+            tokenizer[0], use_fast=use_fast, revision=revision, **tokenizer[1]
+        )
     else:
-        tokenizer = AutoTokenizer.from_pretrained(tokenizer, use_fast=use_fast)
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer, revision=revision, use_fast=use_fast)

     # Instantiate config if needed
     if isinstance(config, str):
-        config = AutoConfig.from_pretrained(config)
+        config = AutoConfig.from_pretrained(config, revision=revision)

     # Instantiate modelcard if needed
     if isinstance(modelcard, str):
-        modelcard = ModelCard.from_pretrained(modelcard)
+        modelcard = ModelCard.from_pretrained(modelcard, revision=revision)

     # Instantiate model if needed
     if isinstance(model, str):
@@ -2873,7 +2880,7 @@ def pipeline(
                 "Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. "
                 "Trying to load the model with Tensorflow."
             )
-        model = model_class.from_pretrained(model, config=config, **model_kwargs)
+        model = model_class.from_pretrained(model, config=config, revision=revision, **model_kwargs)
         if task == "translation" and model.config.task_specific_params:
             for key in model.config.task_specific_params:
                 if key.startswith("translation"):
|
@ -125,8 +125,6 @@ class LegacyIndex(Index):
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
resolved_archive_file = cached_path(archive_file)
|
||||
if resolved_archive_file is None:
|
||||
raise EnvironmentError
|
||||
except EnvironmentError:
|
||||
msg = (
|
||||
f"Can't load '{archive_file}'. Make sure that:\n\n"
|
||||
|
@ -276,6 +276,10 @@ class AutoTokenizer:
|
||||
proxies (:obj:`Dict[str, str]`, `optional`):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
|
||||
identifier allowed by git.
|
||||
use_fast (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to try to load the fast version of the tokenizer.
|
||||
kwargs (additional keyword arguments, `optional`):
|
||||
|
@@ -29,6 +29,8 @@ from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union

 import numpy as np

+import requests
+
 from .file_utils import (
     add_end_docstrings,
     cached_path,
@@ -1515,6 +1517,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
             proxies (:obj:`Dict[str, str]`, `optional`):
                 A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+                identifier allowed by git.
             inputs (additional positional arguments, `optional`):
                 Will be passed along to the Tokenizer ``__init__`` method.
             kwargs (additional keyword arguments, `optional`):
@@ -1549,6 +1555,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
         resume_download = kwargs.pop("resume_download", False)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)
+        revision = kwargs.pop("revision", None)

         s3_models = list(cls.max_model_input_sizes.keys())
         vocab_files = {}
@@ -1601,18 +1608,18 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
                     full_file_name = None
             else:
                 full_file_name = hf_bucket_url(
-                    pretrained_model_name_or_path, filename=file_name, use_cdn=False, mirror=None
+                    pretrained_model_name_or_path, filename=file_name, revision=revision, mirror=None
                 )

             vocab_files[file_id] = full_file_name

         # Get files from url, cache, or disk depending on the case
-        try:
-            resolved_vocab_files = {}
-            for file_id, file_path in vocab_files.items():
-                if file_path is None:
-                    resolved_vocab_files[file_id] = None
-                else:
+        resolved_vocab_files = {}
+        for file_id, file_path in vocab_files.items():
+            if file_path is None:
+                resolved_vocab_files[file_id] = None
+            else:
+                try:
                     resolved_vocab_files[file_id] = cached_path(
                         file_path,
                         cache_dir=cache_dir,
@@ -1621,34 +1628,20 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
                         resume_download=resume_download,
                         local_files_only=local_files_only,
                     )
-        except EnvironmentError:
-            if pretrained_model_name_or_path in s3_models:
-                msg = "Couldn't reach server at '{}' to download vocabulary files."
-            else:
-                msg = (
-                    "Model name '{}' was not found in tokenizers model name list ({}). "
-                    "We assumed '{}' was a path or url to a directory containing vocabulary files "
-                    "named {}, but couldn't find such vocabulary files at this path or url.".format(
-                        pretrained_model_name_or_path,
-                        ", ".join(s3_models),
-                        pretrained_model_name_or_path,
-                        list(cls.vocab_files_names.values()),
-                    )
-                )
-
-            raise EnvironmentError(msg)
+                except requests.exceptions.HTTPError as err:
+                    if "404 Client Error" in str(err):
+                        logger.debug(err)
+                        resolved_vocab_files[file_id] = None
+                    else:
+                        raise err

         if all(full_file_name is None for full_file_name in resolved_vocab_files.values()):
-            raise EnvironmentError(
-                "Model name '{}' was not found in tokenizers model name list ({}). "
-                "We assumed '{}' was a path, a model identifier, or url to a directory containing vocabulary files "
-                "named {} but couldn't find such vocabulary files at this path or url.".format(
-                    pretrained_model_name_or_path,
-                    ", ".join(s3_models),
-                    pretrained_model_name_or_path,
-                    list(cls.vocab_files_names.values()),
-                )
-            )
+            msg = (
+                f"Can't load tokenizer for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
+                f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
+                f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing relevant tokenizer files\n\n"
+            )
+            raise EnvironmentError(msg)

         for file_id, file_path in vocab_files.items():
             if file_path == resolved_vocab_files[file_id]:
tests/test_file_utils.py (new file, 63 lines)
@@ -0,0 +1,63 @@
import unittest

import requests
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME, filename_to_url, get_from_cache, hf_bucket_url
from transformers.testing_utils import DUMMY_UNKWOWN_IDENTIFIER


MODEL_ID = DUMMY_UNKWOWN_IDENTIFIER
# An actual model hosted on huggingface.co

REVISION_ID_DEFAULT = "main"
# Default branch name
REVISION_ID_ONE_SPECIFIC_COMMIT = "f2c752cfc5c0ab6f4bdec59acea69eefbee381c2"
# One particular commit (not the top of `main`)
REVISION_ID_INVALID = "aaaaaaa"
# This commit does not exist, so we should 404.

PINNED_SHA1 = "d9e9f15bc825e4b2c9249e9578f884bbcb5e3684"
# Sha-1 of config.json on the top of `main`, for checking purposes
PINNED_SHA256 = "4b243c475af8d0a7754e87d7d096c92e5199ec2fe168a2ee7998e3b8e9bcb1d3"
# Sha-256 of pytorch_model.bin on the top of `main`, for checking purposes


class GetFromCacheTests(unittest.TestCase):
    def test_bogus_url(self):
        # This lets us simulate no connection
        # as the error raised is the same
        # `ConnectionError`
        url = "https://bogus"
        with self.assertRaisesRegex(ValueError, "Connection error"):
            _ = get_from_cache(url)

    def test_file_not_found(self):
        # Valid revision (None) but missing file.
        url = hf_bucket_url(MODEL_ID, filename="missing.bin")
        with self.assertRaisesRegex(requests.exceptions.HTTPError, "404 Client Error"):
            _ = get_from_cache(url)

    def test_revision_not_found(self):
        # Valid file but missing revision
        url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_INVALID)
        with self.assertRaisesRegex(requests.exceptions.HTTPError, "404 Client Error"):
            _ = get_from_cache(url)

    def test_standard_object(self):
        url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT)
        filepath = get_from_cache(url, force_download=True)
        metadata = filename_to_url(filepath)
        self.assertEqual(metadata, (url, f'"{PINNED_SHA1}"'))

    def test_standard_object_rev(self):
        # Same object, but different revision
        url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_ONE_SPECIFIC_COMMIT)
        filepath = get_from_cache(url, force_download=True)
        metadata = filename_to_url(filepath)
        self.assertNotEqual(metadata[1], f'"{PINNED_SHA1}"')
        # Caution: check that the etag is *not* equal to the one from `test_standard_object`

    def test_lfs_object(self):
        url = hf_bucket_url(MODEL_ID, filename=WEIGHTS_NAME, revision=REVISION_ID_DEFAULT)
        filepath = get_from_cache(url, force_download=True)
        metadata = filename_to_url(filepath)
        self.assertEqual(metadata, (url, f'"{PINNED_SHA256}"'))
@@ -20,7 +20,7 @@ import unittest

 import requests
 from requests.exceptions import HTTPError
-from transformers.hf_api import HfApi, HfFolder, ModelInfo, PresignedUrl, S3Obj
+from transformers.hf_api import HfApi, HfFolder, ModelInfo, PresignedUrl, RepoObj, S3Obj


 USER = "__DUMMY_TRANSFORMERS_USER__"
@@ -35,6 +35,7 @@ FILES = [
         os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/empty.txt"),
     ),
 ]
+REPO_NAME = "my-model-{}".format(int(time.time()))
 ENDPOINT_STAGING = "https://moon-staging.huggingface.co"

@@ -78,15 +79,6 @@ class HfApiEndpointsTest(HfApiCommonTest):
         urls = self._api.presign(token=self._token, filename="nested/valid_org.txt", organization="valid_org")
         self.assertIsInstance(urls, PresignedUrl)

-    def test_presign_invalid(self):
-        try:
-            _ = self._api.presign(token=self._token, filename="non_nested.json")
-        except HTTPError as e:
-            self.assertIsNotNone(e.response.text)
-            self.assertTrue("Filename invalid" in e.response.text)
-        else:
-            self.fail("Expected an exception")
-
     def test_presign(self):
         for FILE_KEY, FILE_PATH in FILES:
             urls = self._api.presign(token=self._token, filename=FILE_KEY)
@@ -109,6 +101,17 @@ class HfApiEndpointsTest(HfApiCommonTest):
             o = objs[-1]
             self.assertIsInstance(o, S3Obj)

+    def test_list_repos_objs(self):
+        objs = self._api.list_repos_objs(token=self._token)
+        self.assertIsInstance(objs, list)
+        if len(objs) > 0:
+            o = objs[-1]
+            self.assertIsInstance(o, RepoObj)
+
+    def test_create_and_delete_repo(self):
+        self._api.create_repo(token=self._token, name=REPO_NAME)
+        self._api.delete_repo(token=self._token, name=REPO_NAME)
+

 class HfApiPublicTest(unittest.TestCase):
     def test_staging_model_list(self):
@@ -323,7 +323,7 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):

     def test_custom_load_tf_weights(self):
         model, output_loading_info = TFBertForTokenClassification.from_pretrained(
-            "jplu/tiny-tf-bert-random", use_cdn=False, output_loading_info=True
+            "jplu/tiny-tf-bert-random", output_loading_info=True
         )
         self.assertEqual(sorted(output_loading_info["unexpected_keys"]), ["mlm___cls", "nsp___cls"])
         for layer in output_loading_info["missing_keys"]:
@@ -165,7 +165,7 @@ DUMMY_FUNCTION = {


 def read_init():
-    """ Read the init and exctracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects. """
+    """ Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects. """
     with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8") as f:
         lines = f.readlines()