Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Model versioning (#8324)
* fix typo
* rm use_cdn & references, and implement new hf_bucket_url
* I'm pretty sure we don't need to `read` this file
* same here
* [BIG] file_utils.networking: do not gobble up errors anymore
* Fix CI 😇
* Apply suggestions from code review
* Tiny doc tweak
* Add doc + pass kwarg everywhere
* Add more tests and explain. cc @sshleifer, let me know if better
* Also implement revision in pipelines, in the case where we're passing a task name or a string model identifier
* Fix CI 😇
* Fix CI
* [hf_api] new methods + command line implem
* make style
* Final endpoints post-migration
* Fix post-migration
* Py3.6 compat, cc @stefan-it. Thank you @stas00

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:

parent: 4185b115d4
commit: 70f622fab4
@@ -12,8 +12,8 @@ inference: false
 
 ## Disclaimer
 
-Due to its immense size, `t5-11b` requires some special treatment.
-First, `t5-11b` should be loaded with flag `use_cdn` set to `False` as follows:
+**Before `transformers` v3.5.0**, due to its immense size, `t5-11b` required some special treatment.
+If you're using transformers `<= v3.4.0`, `t5-11b` should be loaded with flag `use_cdn` set to `False` as follows:
 
 ```python
 t5 = transformers.T5ForConditionalGeneration.from_pretrained('t5-11b', use_cdn = False)
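Note: under the git-based scheme this commit introduces, the `use_cdn` workaround above goes away entirely. A minimal sketch, assuming transformers >= v3.5.0, where `from_pretrained` accepts the `revision` keyword added throughout this diff:

```python
import transformers

# With the git-based hub, files resolve through
# https://huggingface.co/{model_id}/resolve/{revision}/{filename},
# so no use_cdn flag is needed; revision defaults to "main".
t5 = transformers.T5ForConditionalGeneration.from_pretrained("t5-11b", revision="main")
```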
@@ -56,7 +56,3 @@ cd -
 perl -le 'for $f (@ARGV) { print qq[transformers-cli upload -y $_/$f --filename $_/$f] for ("wmt16-en-de-dist-12-1", "wmt16-en-de-dist-6-1", "wmt16-en-de-12-1")}' vocab-src.json vocab-tgt.json tokenizer_config.json config.json
 # add/remove files as needed
 
-# Caching note: Unfortunately due to CDN caching the uploaded model may be unavailable for up to 24hs after upload
-# So the only way to start using the new model sooner is either:
-# 1. download it to a local path and use that path as model_name
-# 2. make sure you use: from_pretrained(..., use_cdn=False) everywhere

@@ -44,7 +44,3 @@ cd -
 perl -le 'for $f (@ARGV) { print qq[transformers-cli upload -y $_/$f --filename $_/$f] for ("wmt19-de-en-6-6-base", "wmt19-de-en-6-6-big")}' vocab-src.json vocab-tgt.json tokenizer_config.json config.json
 # add/remove files as needed
 
-# Caching note: Unfortunately due to CDN caching the uploaded model may be unavailable for up to 24hs after upload
-# So the only way to start using the new model sooner is either:
-# 1. download it to a local path and use that path as model_name
-# 2. make sure you use: from_pretrained(..., use_cdn=False) everywhere

@@ -55,7 +55,3 @@ cd -
 perl -le 'for $f (@ARGV) { print qq[transformers-cli upload -y $_/$f --filename $_/$f] for map { "wmt19-$_" } ("en-ru", "ru-en", "de-en", "en-de")}' vocab-src.json vocab-tgt.json tokenizer_config.json config.json
 # add/remove files as needed
 
-# Caching note: Unfortunately due to CDN caching the uploaded model may be unavailable for up to 24hs after upload
-# So the only way to start using the new model sooner is either:
-# 1. download it to a local path and use that path as model_name
-# 2. make sure you use: from_pretrained(..., use_cdn=False) everywhere
@@ -1,4 +1,5 @@
 import os
+import subprocess
 import sys
 from argparse import ArgumentParser
 from getpass import getpass

@@ -21,8 +22,10 @@ class UserCommands(BaseTransformersCLICommand):
         whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args))
         logout_parser = parser.add_parser("logout", help="Log out")
         logout_parser.set_defaults(func=lambda args: LogoutCommand(args))
-        # s3
-        s3_parser = parser.add_parser("s3", help="{ls, rm} Commands to interact with the files you upload on S3.")
+        # s3_datasets (s3-based system)
+        s3_parser = parser.add_parser(
+            "s3_datasets", help="{ls, rm} Commands to interact with the files you upload on S3."
+        )
         s3_subparsers = s3_parser.add_subparsers(help="s3 related commands")
         ls_parser = s3_subparsers.add_parser("ls")
         ls_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")

@@ -31,17 +34,42 @@ class UserCommands(BaseTransformersCLICommand):
         rm_parser.add_argument("filename", type=str, help="individual object filename to delete from S3.")
         rm_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
         rm_parser.set_defaults(func=lambda args: DeleteObjCommand(args))
-        # upload
-        upload_parser = parser.add_parser("upload", help="Upload a model to S3.")
-        upload_parser.add_argument(
-            "path", type=str, help="Local path of the model folder or individual file to upload."
-        )
+        upload_parser = s3_subparsers.add_parser("upload", help="Upload a file to S3.")
+        upload_parser.add_argument("path", type=str, help="Local path of the folder or individual file to upload.")
         upload_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
         upload_parser.add_argument(
             "--filename", type=str, default=None, help="Optional: override individual object filename on S3."
         )
         upload_parser.add_argument("-y", "--yes", action="store_true", help="Optional: answer Yes to the prompt")
         upload_parser.set_defaults(func=lambda args: UploadCommand(args))
+        # deprecated model upload
+        upload_parser = parser.add_parser(
+            "upload",
+            help=(
+                "Deprecated: used to be the way to upload a model to S3."
+                " We now use a git-based system for storing models and other artifacts."
+                " Use the `repo create` command instead."
+            ),
+        )
+        upload_parser.set_defaults(func=lambda args: DeprecatedUploadCommand(args))
+
+        # new system: git-based repo system
+        repo_parser = parser.add_parser(
+            "repo", help="{create, ls-files} Commands to interact with your huggingface.co repos."
+        )
+        repo_subparsers = repo_parser.add_subparsers(help="huggingface.co repos related commands")
+        ls_parser = repo_subparsers.add_parser("ls-files", help="List all your files on huggingface.co")
+        ls_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
+        ls_parser.set_defaults(func=lambda args: ListReposObjsCommand(args))
+        repo_create_parser = repo_subparsers.add_parser("create", help="Create a new repo on huggingface.co")
+        repo_create_parser.add_argument(
+            "name",
+            type=str,
+            help="Name for your model's repo. Will be namespaced under your username to build the model id.",
+        )
+        repo_create_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
+        repo_create_parser.add_argument("-y", "--yes", action="store_true", help="Optional: answer Yes to the prompt")
+        repo_create_parser.set_defaults(func=lambda args: RepoCreateCommand(args))
 
 
 class ANSI:
@@ -51,6 +79,7 @@ class ANSI:
 
     _bold = "\u001b[1m"
     _red = "\u001b[31m"
+    _gray = "\u001b[90m"
     _reset = "\u001b[0m"
 
     @classmethod

@@ -61,6 +90,27 @@ class ANSI:
     def red(cls, s):
         return "{}{}{}".format(cls._bold + cls._red, s, cls._reset)
 
+    @classmethod
+    def gray(cls, s):
+        return "{}{}{}".format(cls._gray, s, cls._reset)
+
+
+def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
+    """
+    Inspired by:
+
+    - stackoverflow.com/a/8356620/593036
+    - stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
+    """
+    col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
+    row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
+    lines = []
+    lines.append(row_format.format(*headers))
+    lines.append(row_format.format(*["-" * w for w in col_widths]))
+    for row in rows:
+        lines.append(row_format.format(*row))
+    return "\n".join(lines)
+
 
 class BaseUserCommand:
     def __init__(self, args):
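A quick illustration of the module-level `tabulate` helper added above (row values are hypothetical; spacing is determined by the computed column widths):

```python
rows = [["config.json", "2020-11-09", "70f622fab4", 512]]
print(tabulate(rows, headers=["Filename", "LastModified", "Commit-Sha", "Size"]))
# Filename    LastModified Commit-Sha Size
# ----------- ------------ ---------- ----
# config.json 2020-11-09   70f622fab4  512
```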
@@ -124,22 +174,6 @@ class LogoutCommand(BaseUserCommand):
 
 
 class ListObjsCommand(BaseUserCommand):
-    def tabulate(self, rows: List[List[Union[str, int]]], headers: List[str]) -> str:
-        """
-        Inspired by:
-
-        - stackoverflow.com/a/8356620/593036
-        - stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
-        """
-        col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
-        row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
-        lines = []
-        lines.append(row_format.format(*headers))
-        lines.append(row_format.format(*["-" * w for w in col_widths]))
-        for row in rows:
-            lines.append(row_format.format(*row))
-        return "\n".join(lines)
-
     def run(self):
         token = HfFolder.get_token()
         if token is None:

@@ -155,7 +189,7 @@ class ListObjsCommand(BaseUserCommand):
             print("No shared file yet")
             exit()
         rows = [[obj.filename, obj.LastModified, obj.ETag, obj.Size] for obj in objs]
-        print(self.tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"]))
+        print(tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"]))
 
 
 class DeleteObjCommand(BaseUserCommand):

@@ -173,6 +207,85 @@ class DeleteObjCommand(BaseUserCommand):
         print("Done")
 
 
+class ListReposObjsCommand(BaseUserCommand):
+    def run(self):
+        token = HfFolder.get_token()
+        if token is None:
+            print("Not logged in")
+            exit(1)
+        try:
+            objs = self._api.list_repos_objs(token, organization=self.args.organization)
+        except HTTPError as e:
+            print(e)
+            print(ANSI.red(e.response.text))
+            exit(1)
+        if len(objs) == 0:
+            print("No shared file yet")
+            exit()
+        rows = [[obj.filename, obj.lastModified, obj.commit, obj.size] for obj in objs]
+        print(tabulate(rows, headers=["Filename", "LastModified", "Commit-Sha", "Size"]))
+
+
+class RepoCreateCommand(BaseUserCommand):
+    def run(self):
+        token = HfFolder.get_token()
+        if token is None:
+            print("Not logged in")
+            exit(1)
+        try:
+            stdout = subprocess.check_output(["git", "--version"]).decode("utf-8")
+            print(ANSI.gray(stdout.strip()))
+        except FileNotFoundError:
+            print("Looks like you do not have git installed, please install.")
+
+        try:
+            stdout = subprocess.check_output(["git-lfs", "--version"]).decode("utf-8")
+            print(ANSI.gray(stdout.strip()))
+        except FileNotFoundError:
+            print(
+                ANSI.red(
+                    "Looks like you do not have git-lfs installed, please install."
+                    " You can install from https://git-lfs.github.com/."
+                    " Then run `git lfs install` (you only have to do this once)."
+                )
+            )
+        print("")
+
+        user, _ = self._api.whoami(token)
+        namespace = self.args.organization if self.args.organization is not None else user
+
+        print("You are about to create {}".format(ANSI.bold(namespace + "/" + self.args.name)))
+
+        if not self.args.yes:
+            choice = input("Proceed? [Y/n] ").lower()
+            if not (choice == "" or choice == "y" or choice == "yes"):
+                print("Abort")
+                exit()
+        try:
+            url = self._api.create_repo(token, name=self.args.name, organization=self.args.organization)
+        except HTTPError as e:
+            print(e)
+            print(ANSI.red(e.response.text))
+            exit(1)
+        print("\nYour repo now lives at:")
+        print(" {}".format(ANSI.bold(url)))
+        print("\nYou can clone it locally with the command below," " and commit/push as usual.")
+        print(f"\n git clone {url}")
+        print("")
+
+
+class DeprecatedUploadCommand(BaseUserCommand):
+    def run(self):
+        print(
+            ANSI.red(
+                "Deprecated: used to be the way to upload a model to S3."
+                " We now use a git-based system for storing models and other artifacts."
+                " Use the `repo create` command instead."
+            )
+        )
+        exit(1)
+
+
 class UploadCommand(BaseUserCommand):
     def walk_dir(self, rel_path):
         """
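For context, the flow these new CLI commands wrap can also be driven directly from Python. A minimal sketch, assuming you have already logged in with `transformers-cli login` and that the module path is `transformers.hf_api`, as in this commit:

```python
from transformers.hf_api import HfApi, HfFolder

api = HfApi()
token = HfFolder.get_token()  # saved locally by `transformers-cli login`

# Equivalent of `transformers-cli repo create my-model`
url = api.create_repo(token, name="my-model")
print(url)  # e.g. https://huggingface.co/<username>/my-model
```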
@@ -289,6 +289,10 @@ class AutoConfig:
         proxies (:obj:`Dict[str, str]`, `optional`):
             A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
             'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+        revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
+            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+            git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+            identifier allowed by git.
         return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`):
             If :obj:`False`, then this function returns just the final configuration object.
 
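The new kwarg in use (the model id is illustrative):

```python
from transformers import AutoConfig

# Pin the config to an exact git revision; branch names, tags,
# and commit shas are all valid since files now live in a git repo.
config = AutoConfig.from_pretrained("bert-base-uncased", revision="main")
```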
@@ -311,6 +311,10 @@ class PretrainedConfig(object):
         proxies (:obj:`Dict[str, str]`, `optional`):
             A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
             'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+        revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
+            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+            git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+            identifier allowed by git.
         return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`):
             If :obj:`False`, then this function returns just the final configuration object.
 

@@ -362,6 +366,7 @@ class PretrainedConfig(object):
         resume_download = kwargs.pop("resume_download", False)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)
+        revision = kwargs.pop("revision", None)
 
         if os.path.isdir(pretrained_model_name_or_path):
             config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)

@@ -369,7 +374,7 @@ class PretrainedConfig(object):
             config_file = pretrained_model_name_or_path
         else:
             config_file = hf_bucket_url(
-                pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False, mirror=None
+                pretrained_model_name_or_path, filename=CONFIG_NAME, revision=revision, mirror=None
             )
 
         try:

@@ -383,11 +388,10 @@ class PretrainedConfig(object):
                 local_files_only=local_files_only,
             )
             # Load config dict
-            if resolved_config_file is None:
-                raise EnvironmentError
             config_dict = cls._dict_from_json_file(resolved_config_file)
 
-        except EnvironmentError:
+        except EnvironmentError as err:
+            logger.error(err)
             msg = (
                 f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                 f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
@@ -4,6 +4,7 @@ https://github.com/allenai/allennlp Copyright by the AllenNLP authors.
 """
 
 import fnmatch
+import io
 import json
 import os
 import re

@@ -17,7 +18,7 @@ from dataclasses import fields
 from functools import partial, wraps
 from hashlib import sha256
 from pathlib import Path
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, BinaryIO, Dict, Optional, Tuple, Union
 from urllib.parse import urlparse
 from zipfile import ZipFile, is_zipfile
 

@@ -217,6 +218,8 @@ DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
 
 S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
 CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
+HUGGINGFACE_CO_PREFIX = "https://huggingface.co/{model_id}/resolve/{revision}/{filename}"
+
 PRESET_MIRROR_DICT = {
     "tuna": "https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models",
     "bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models",
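The new URL template resolves straightforwardly:

```python
from transformers.file_utils import HUGGINGFACE_CO_PREFIX  # module path as in this diff

url = HUGGINGFACE_CO_PREFIX.format(model_id="bert-base-uncased", revision="main", filename="config.json")
print(url)  # https://huggingface.co/bert-base-uncased/resolve/main/config.json
```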
@@ -825,34 +828,37 @@ def is_remote_url(url_or_filename):
     return parsed.scheme in ("http", "https")
 
 
-def hf_bucket_url(model_id: str, filename: str, use_cdn=True, mirror=None) -> str:
+def hf_bucket_url(model_id: str, filename: str, revision: Optional[str] = None, mirror=None) -> str:
     """
-    Resolve a model identifier, and a file name, to a HF-hosted url on either S3 or Cloudfront (a Content Delivery
-    Network, or CDN).
+    Resolve a model identifier, a file name, and an optional revision id, to a huggingface.co-hosted url, redirecting
+    to Cloudfront (a Content Delivery Network, or CDN) for large files.
 
     Cloudfront is replicated over the globe so downloads are way faster for the end user (and it also lowers our
-    bandwidth costs). However, it is more aggressively cached by default, so may not always reflect the latest changes
-    to the underlying file (default TTL is 24 hours).
+    bandwidth costs).
 
-    In terms of client-side caching from this library, even though Cloudfront relays the ETags from S3, using one or
-    the other (or switching from one to the other) will affect caching: cached files are not shared between the two
-    because the cached file's name contains a hash of the url.
+    Cloudfront aggressively caches files by default (default TTL is 24 hours), however this is not an issue here
+    because we migrated to a git-based versioning system on huggingface.co, so we now store the files on S3/Cloudfront
+    in a content-addressable way (i.e., the file name is its hash). Using content-addressable filenames means cache
+    can't ever be stale.
+
+    In terms of client-side caching from this library, we base our caching on the objects' ETag. An object's ETag is
+    its sha1 if stored in git, or its sha256 if stored in git-lfs. Files cached locally from transformers before v3.5.0
+    are not shared with those new files, because the cached file's name contains a hash of the url (which changed).
     """
-    endpoint = (
-        PRESET_MIRROR_DICT.get(mirror, mirror)
-        if mirror
-        else CLOUDFRONT_DISTRIB_PREFIX
-        if use_cdn
-        else S3_BUCKET_PREFIX
-    )
-    legacy_format = "/" not in model_id
-    if legacy_format:
-        return f"{endpoint}/{model_id}-{filename}"
-    else:
-        return f"{endpoint}/{model_id}/{filename}"
+    if mirror:
+        endpoint = PRESET_MIRROR_DICT.get(mirror, mirror)
+        legacy_format = "/" not in model_id
+        if legacy_format:
+            return f"{endpoint}/{model_id}-{filename}"
+        else:
+            return f"{endpoint}/{model_id}/{filename}"
+
+    if revision is None:
+        revision = "main"
+    return HUGGINGFACE_CO_PREFIX.format(model_id=model_id, revision=revision, filename=filename)
 
 
-def url_to_filename(url, etag=None):
+def url_to_filename(url: str, etag: Optional[str] = None) -> str:
     """
     Convert `url` into a hashed filename in a repeatable way. If `etag` is specified, append its hash to the url's,
     delimited by a period. If the url ends with .h5 (Keras HDF5 weights) adds '.h5' to the name so that TF 2.0 can
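Illustrating the two paths through the rewritten `hf_bucket_url` (the revision tag is an example value):

```python
from transformers.file_utils import hf_bucket_url

# Default: git-based huggingface.co URL; revision falls back to "main".
hf_bucket_url("t5-11b", filename="config.json")
# -> https://huggingface.co/t5-11b/resolve/main/config.json

# Pinned to an arbitrary git identifier (branch, tag, or commit sha).
hf_bucket_url("t5-11b", filename="config.json", revision="v1.0")
# -> https://huggingface.co/t5-11b/resolve/v1.0/config.json

# With a mirror, the pre-existing mirror layout is kept
# ("legacy" dash-joined form for model ids without a namespace).
hf_bucket_url("t5-11b", filename="config.json", mirror="tuna")
# -> https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models/t5-11b-config.json
```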
@@ -860,13 +866,11 @@ def url_to_filename(url, etag=None):
     https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
     """
     url_bytes = url.encode("utf-8")
-    url_hash = sha256(url_bytes)
-    filename = url_hash.hexdigest()
+    filename = sha256(url_bytes).hexdigest()
 
     if etag:
         etag_bytes = etag.encode("utf-8")
-        etag_hash = sha256(etag_bytes)
-        filename += "." + etag_hash.hexdigest()
+        filename += "." + sha256(etag_bytes).hexdigest()
 
     if url.endswith(".h5"):
         filename += ".h5"
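So a cached entry's on-disk name is fully determined by the url and the ETag; a quick sketch (the ETag value is illustrative):

```python
from hashlib import sha256

url = "https://huggingface.co/bert-base-uncased/resolve/main/config.json"
etag = '"abc123"'  # hypothetical ETag returned by the server

filename = sha256(url.encode("utf-8")).hexdigest()
filename += "." + sha256(etag.encode("utf-8")).hexdigest()
# -> "<64 hex chars>.<64 hex chars>" inside the transformers cache dir,
# which is why pre-v3.5.0 caches (old urls) are not reused.
```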
@@ -927,8 +931,10 @@ def cached_path(
             re-extract the archive and override the folder where it was extracted.
 
     Return:
-        None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk). Local path (string)
-        otherwise
+        Local path (string) of file or if networking is off, last version of file cached on disk.
+
+    Raises:
+        In case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
     """
     if cache_dir is None:
         cache_dir = TRANSFORMERS_CACHE

@@ -992,7 +998,10 @@ def cached_path(
         return output_path
 
 
-def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict, str, None] = None):
+def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
+    """
+    Formats a user-agent string with basic info about a request.
+    """
     ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
     if is_torch_available():
         ua += "; torch/{}".format(torch.__version__)

@@ -1002,13 +1011,19 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict
         ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
     elif isinstance(user_agent, str):
         ua += "; " + user_agent
-    headers = {"user-agent": ua}
+    return ua
+
+
+def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, user_agent: Union[Dict, str, None] = None):
+    """
+    Download remote file. Do not gobble up errors.
+    """
+    headers = {"user-agent": http_user_agent(user_agent)}
     if resume_size > 0:
         headers["Range"] = "bytes=%d-" % (resume_size,)
-    response = requests.get(url, stream=True, proxies=proxies, headers=headers)
-    if response.status_code == 416:  # Range not satisfiable
-        return
-    content_length = response.headers.get("Content-Length")
+    r = requests.get(url, stream=True, proxies=proxies, headers=headers)
+    r.raise_for_status()
+    content_length = r.headers.get("Content-Length")
     total = resume_size + int(content_length) if content_length is not None else None
     progress = tqdm(
         unit="B",

@@ -1018,7 +1033,7 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict
         desc="Downloading",
         disable=bool(logging.get_verbosity() == logging.NOTSET),
     )
-    for chunk in response.iter_content(chunk_size=1024):
+    for chunk in r.iter_content(chunk_size=1024):
         if chunk:  # filter out keep-alive new chunks
             progress.update(len(chunk))
             temp_file.write(chunk)
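For instance, the factored-out `http_user_agent` helper produces strings along these lines (version numbers and the dict key are illustrative, not fixed API values):

```python
headers = {"user-agent": http_user_agent({"from_pipeline": "fill-mask"})}
# e.g. "transformers/3.5.0; python/3.8.5; torch/1.7.0; from_pipeline/fill-mask"
```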
@@ -1026,7 +1041,7 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict
 
 
 def get_from_cache(
-    url,
+    url: str,
     cache_dir=None,
     force_download=False,
     proxies=None,

@@ -1040,8 +1055,10 @@ def get_from_cache(
     path to the cached file.
 
     Return:
-        None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk). Local path (string)
-        otherwise
+        Local path (string) of file or if networking is off, last version of file cached on disk.
+
+    Raises:
+        In case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
     """
     if cache_dir is None:
         cache_dir = TRANSFORMERS_CACHE

@@ -1050,13 +1067,28 @@ def get_from_cache(
 
     os.makedirs(cache_dir, exist_ok=True)
 
+    url_to_download = url
     etag = None
     if not local_files_only:
         try:
-            response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
-            if response.status_code == 200:
-                etag = response.headers.get("ETag")
-        except (EnvironmentError, requests.exceptions.Timeout):
+            headers = {"user-agent": http_user_agent(user_agent)}
+            r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
+            r.raise_for_status()
+            etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
+            # We favor a custom header indicating the etag of the linked resource, and
+            # we fallback to the regular etag header.
+            # If we don't have any of those, raise an error.
+            if etag is None:
+                raise OSError(
+                    "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
+                )
+            # In case of a redirect,
+            # save an extra redirect on the request.get call,
+            # and ensure we download the exact atomic version even if it changed
+            # between the HEAD and the GET (unlikely, but hey).
+            if 300 <= r.status_code <= 399:
+                url_to_download = r.headers["Location"]
+        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
             # etag is already None
             pass
 
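To see what that HEAD logic is doing, here is a standalone sketch under the assumptions the diff's comments state (model id is hypothetical; large git-lfs files are served via a redirect carrying an `X-Linked-Etag` header):

```python
import requests

url = "https://huggingface.co/bert-base-uncased/resolve/main/pytorch_model.bin"
r = requests.head(url, allow_redirects=False, timeout=10)
r.raise_for_status()  # a 3xx does not raise, only 4xx/5xx do

# Prefer the etag of the linked (final) resource, fall back to the plain ETag.
etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
if 300 <= r.status_code <= 399:
    # Download exactly the version the HEAD request just observed.
    url = r.headers["Location"]
```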
@@ -1065,7 +1097,7 @@ def get_from_cache(
     # get cache path to put the file
     cache_path = os.path.join(cache_dir, filename)
 
-    # etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
+    # etag is None == we don't have a connection or we passed local_files_only.
     # try to get the last downloaded one
     if etag is None:
         if os.path.exists(cache_path):

@@ -1088,7 +1120,11 @@ def get_from_cache(
                 " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
                 " to False."
             )
-        return None
+        else:
+            raise ValueError(
+                "Connection error, and we cannot find the requested files in the cached path."
+                " Please try again or make sure your Internet connection is on."
+            )
 
     # From now on, etag is not None.
     if os.path.exists(cache_path) and not force_download:

@@ -1107,8 +1143,8 @@ def get_from_cache(
     incomplete_path = cache_path + ".incomplete"
 
     @contextmanager
-    def _resumable_file_manager():
-        with open(incomplete_path, "a+b") as f:
+    def _resumable_file_manager() -> "io.BufferedWriter":
+        with open(incomplete_path, "ab") as f:
             yield f
 
     temp_file_manager = _resumable_file_manager

@@ -1117,7 +1153,7 @@ def get_from_cache(
         else:
             resume_size = 0
     else:
-        temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
+        temp_file_manager = partial(tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False)
         resume_size = 0
 
     # Download to temporary file, then copy to cache dir once finished.

@@ -1125,7 +1161,7 @@ def get_from_cache(
     with temp_file_manager() as temp_file:
         logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
 
-        http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
+        http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
 
     logger.info("storing %s in cache at %s", url, cache_path)
     os.replace(temp_file.name, cache_path)
@@ -27,9 +27,21 @@ import requests
 ENDPOINT = "https://huggingface.co"
 
 
+class RepoObj:
+    """
+    HuggingFace git-based system, data structure that represents a file belonging to the current user.
+    """
+
+    def __init__(self, filename: str, lastModified: str, commit: str, size: int, **kwargs):
+        self.filename = filename
+        self.lastModified = lastModified
+        self.commit = commit
+        self.size = size
+
+
 class S3Obj:
     """
-    Data structure that represents a file belonging to the current user.
+    HuggingFace S3-based system, data structure that represents a file belonging to the current user.
     """
 
     def __init__(self, filename: str, LastModified: str, ETag: str, Size: int, **kwargs):

@@ -46,38 +58,25 @@ class PresignedUrl:
         self.type = type  # mime-type to send to S3.
 
 
-class S3Object:
+class ModelSibling:
     """
-    Data structure that represents a public file accessible on our S3.
+    Data structure that represents a public file inside a model, accessible from huggingface.co
     """
 
-    def __init__(
-        self,
-        key: str,  # S3 object key
-        etag: str,
-        lastModified: str,
-        size: int,
-        rfilename: str,  # filename relative to config.json
-        **kwargs
-    ):
-        self.key = key
-        self.etag = etag
-        self.lastModified = lastModified
-        self.size = size
-        self.rfilename = rfilename
+    def __init__(self, rfilename: str, **kwargs):
+        self.rfilename = rfilename  # filename relative to the model root
         for k, v in kwargs.items():
             setattr(self, k, v)
 
 
 class ModelInfo:
     """
-    Info about a public model accessible from our S3.
+    Info about a public model accessible from huggingface.co
     """
 
     def __init__(
         self,
-        modelId: str,  # id of model
-        key: str,  # S3 object key of config.json
+        modelId: Optional[str] = None,  # id of model
         author: Optional[str] = None,
         downloads: Optional[int] = None,
         tags: List[str] = [],

@@ -86,12 +85,11 @@ class ModelInfo:
         **kwargs
     ):
         self.modelId = modelId
-        self.key = key
         self.author = author
         self.downloads = downloads
         self.tags = tags
         self.pipeline_tag = pipeline_tag
-        self.siblings = [S3Object(**x) for x in siblings] if siblings is not None else None
+        self.siblings = [ModelSibling(**x) for x in siblings] if siblings is not None else None
         for k, v in kwargs.items():
             setattr(self, k, v)
 
@@ -134,9 +132,11 @@ class HfApi:
 
     def presign(self, token: str, filename: str, organization: Optional[str] = None) -> PresignedUrl:
         """
+        HuggingFace S3-based system, used for datasets and metrics.
+
         Call HF API to get a presigned url to upload `filename` to S3.
         """
-        path = "{}/api/presign".format(self.endpoint)
+        path = "{}/api/datasets/presign".format(self.endpoint)
         r = requests.post(
             path,
             headers={"authorization": "Bearer {}".format(token)},

@@ -148,6 +148,8 @@ class HfApi:
 
     def presign_and_upload(self, token: str, filename: str, filepath: str, organization: Optional[str] = None) -> str:
         """
+        HuggingFace S3-based system, used for datasets and metrics.
+
         Get a presigned url, then upload file to S3.
 
         Outputs: url: Read-only url for the stored file on S3.

@@ -169,9 +171,11 @@ class HfApi:
 
     def list_objs(self, token: str, organization: Optional[str] = None) -> List[S3Obj]:
         """
+        HuggingFace S3-based system, used for datasets and metrics.
+
         Call HF API to list all stored files for user (or one of their organizations).
         """
-        path = "{}/api/listObjs".format(self.endpoint)
+        path = "{}/api/datasets/listObjs".format(self.endpoint)
         params = {"organization": organization} if organization is not None else None
         r = requests.get(path, params=params, headers={"authorization": "Bearer {}".format(token)})
         r.raise_for_status()

@@ -180,9 +184,11 @@ class HfApi:
 
     def delete_obj(self, token: str, filename: str, organization: Optional[str] = None):
         """
+        HuggingFace S3-based system, used for datasets and metrics.
+
         Call HF API to delete a file stored by user
         """
-        path = "{}/api/deleteObj".format(self.endpoint)
+        path = "{}/api/datasets/deleteObj".format(self.endpoint)
         r = requests.delete(
             path,
             headers={"authorization": "Bearer {}".format(token)},

@@ -200,6 +206,51 @@ class HfApi:
         d = r.json()
         return [ModelInfo(**x) for x in d]
 
+    def list_repos_objs(self, token: str, organization: Optional[str] = None) -> List[S3Obj]:
+        """
+        HuggingFace git-based system, used for models.
+
+        Call HF API to list all stored files for user (or one of their organizations).
+        """
+        path = "{}/api/repos/ls".format(self.endpoint)
+        params = {"organization": organization} if organization is not None else None
+        r = requests.get(path, params=params, headers={"authorization": "Bearer {}".format(token)})
+        r.raise_for_status()
+        d = r.json()
+        return [RepoObj(**x) for x in d]
+
+    def create_repo(self, token: str, name: str, organization: Optional[str] = None) -> str:
+        """
+        HuggingFace git-based system, used for models.
+
+        Call HF API to create a whole repo.
+        """
+        path = "{}/api/repos/create".format(self.endpoint)
+        r = requests.post(
+            path,
+            headers={"authorization": "Bearer {}".format(token)},
+            json={"name": name, "organization": organization},
+        )
+        r.raise_for_status()
+        d = r.json()
+        return d["url"]
+
+    def delete_repo(self, token: str, name: str, organization: Optional[str] = None):
+        """
+        HuggingFace git-based system, used for models.
+
+        Call HF API to delete a whole repo.
+
+        CAUTION(this is irreversible).
+        """
+        path = "{}/api/repos/delete".format(self.endpoint)
+        r = requests.delete(
+            path,
+            headers={"authorization": "Bearer {}".format(token)},
+            json={"name": name, "organization": organization},
+        )
+        r.raise_for_status()
+
 
 class TqdmProgressFileReader:
     """
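Rounding out the earlier `create_repo` sketch, listing and deleting repos follows the same pattern (continuing with the `api` and `token` from that sketch; the repo name is hypothetical):

```python
objs = api.list_repos_objs(token)  # -> list of RepoObj
for obj in objs:
    print(obj.filename, obj.lastModified, obj.commit, obj.size)

api.delete_repo(token, name="my-model")  # CAUTION: irreversible
```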
@@ -144,9 +144,7 @@ class ModelCard:
         elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
             model_card_file = pretrained_model_name_or_path
         else:
-            model_card_file = hf_bucket_url(
-                pretrained_model_name_or_path, filename=MODEL_CARD_NAME, use_cdn=False, mirror=None
-            )
+            model_card_file = hf_bucket_url(pretrained_model_name_or_path, filename=MODEL_CARD_NAME, mirror=None)
 
         if find_from_standard_name or pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
             model_card_file = model_card_file.replace(CONFIG_NAME, MODEL_CARD_NAME)

@@ -156,8 +154,6 @@ class ModelCard:
         try:
             # Load from URL or cache if already cached
             resolved_model_card_file = cached_path(model_card_file, cache_dir=cache_dir, proxies=proxies)
-            if resolved_model_card_file is None:
-                raise EnvironmentError
             if resolved_model_card_file == model_card_file:
                 logger.info("loading model card file {}".format(model_card_file))
             else:
@@ -537,9 +537,10 @@ AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
             Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
         local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether or not to only look at local files (e.g., not try downloading the model).
-        use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
-            Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
-            our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
+        revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
+            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+            git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+            identifier allowed by git.
         kwargs (additional keyword arguments, `optional`):
             Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
             :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
@@ -107,7 +107,7 @@ class FlaxPreTrainedModel(ABC):
         proxies = kwargs.pop("proxies", None)
         # output_loading_info = kwargs.pop("output_loading_info", False)
         local_files_only = kwargs.pop("local_files_only", False)
-        use_cdn = kwargs.pop("use_cdn", True)
+        revision = kwargs.pop("revision", None)
 
         # Load config if we don't provide a configuration
         if not isinstance(config, PretrainedConfig):

@@ -121,6 +121,7 @@ class FlaxPreTrainedModel(ABC):
                 resume_download=resume_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
+                revision=revision,
                 **kwargs,
             )
         else:

@@ -131,7 +132,7 @@ class FlaxPreTrainedModel(ABC):
         if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
             archive_file = pretrained_model_name_or_path
         else:
-            archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, use_cdn=use_cdn)
+            archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision)
 
         # redirect to the cache, if necessary
         try:

@@ -143,16 +144,13 @@ class FlaxPreTrainedModel(ABC):
                 resume_download=resume_download,
                 local_files_only=local_files_only,
             )
-        except EnvironmentError:
-            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
-                msg = f"Couldn't reach server at '{archive_file}' to download pretrained weights."
-            else:
-                msg = (
-                    f"Model name '{pretrained_model_name_or_path}' "
-                    f"was not found in model name list ({', '.join(cls.pretrained_model_archive_map.keys())}). "
-                    f"We assumed '{archive_file}' was a path or url to model weight files but "
-                    "couldn't find any such file at this path or url."
-                )
+        except EnvironmentError as err:
+            logger.error(err)
+            msg = (
+                f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
+                f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
+                f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named {WEIGHTS_NAME}.\n\n"
+            )
             raise EnvironmentError(msg)
 
         if resolved_archive_file == archive_file:
@ -420,9 +420,10 @@ TF_AUTO_MODEL_PRETRAINED_DOCSTRING = r"""
|
|||||||
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||||
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
|
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether or not to only look at local files (e.g., not try downloading the model).
|
Whether or not to only look at local files (e.g., not try downloading the model).
|
||||||
use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
|
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
|
||||||
Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
|
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||||
our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
|
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
|
||||||
|
identifier allowed by git.
|
||||||
kwargs (additional keyword arguments, `optional`):
|
kwargs (additional keyword arguments, `optional`):
|
||||||
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||||
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
|
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
|
||||||
|
@ -572,9 +572,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
|
|||||||
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||||
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
|
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether or not to only look at local files (e.g., not try doanloading the model).
|
Whether or not to only look at local files (e.g., not try doanloading the model).
|
||||||
use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
|
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
|
||||||
Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
|
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||||
our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
|
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
|
||||||
|
identifier allowed by git.
|
||||||
mirror(:obj:`str`, `optional`, defaults to :obj:`None`):
|
mirror(:obj:`str`, `optional`, defaults to :obj:`None`):
|
||||||
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
|
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
|
||||||
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
||||||
@@ -616,7 +617,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
         proxies = kwargs.pop("proxies", None)
         output_loading_info = kwargs.pop("output_loading_info", False)
         local_files_only = kwargs.pop("local_files_only", False)
-        use_cdn = kwargs.pop("use_cdn", True)
+        revision = kwargs.pop("revision", None)
         mirror = kwargs.pop("mirror", None)

         # Load config if we don't provide a configuration
@@ -631,6 +632,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
                 resume_download=resume_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
+                revision=revision,
                 **kwargs,
             )
         else:
@@ -659,7 +661,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
                 archive_file = hf_bucket_url(
                     pretrained_model_name_or_path,
                     filename=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME),
-                    use_cdn=use_cdn,
+                    revision=revision,
                     mirror=mirror,
                 )

@@ -673,9 +675,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
                 resume_download=resume_download,
                 local_files_only=local_files_only,
             )
-            if resolved_archive_file is None:
-                raise EnvironmentError
-        except EnvironmentError:
+        except EnvironmentError as err:
+            logger.error(err)
             msg = (
                 f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                 f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
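The net effect of the new error path, sketched below (not part of the diff; the identifier is deliberately invalid): the underlying network error is logged instead of being gobbled up, and the caller still gets a readable `EnvironmentError`.

```python
from transformers import TFAutoModel

# "no-such-org/no-such-model" does not exist on huggingface.co, so the HTTP
# error is logged first, then an EnvironmentError with a hint is raised.
try:
    TFAutoModel.from_pretrained("no-such-org/no-such-model")
except EnvironmentError as err:
    print(err)  # the message now tells you what to check
```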
@@ -813,9 +813,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
                 Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
             local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
                 Whether or not to only look at local files (e.g., not try doanloading the model).
-            use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
-                Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
-                our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
+            revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+                identifier allowed by git.
             mirror(:obj:`str`, `optional`, defaults to :obj:`None`):
                 Mirror source to accelerate downloads in China. If you are from China and have an accessibility
                 problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
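The PyTorch `from_pretrained` gains the same argument; a sketch (placeholders again, with `"main"` standing in for the tag or commit sha you would pin):

```python
from transformers import AutoModel

# Pinning to an exact commit sha makes the load reproducible even if the repo's
# `main` branch moves later; "main" is used here so the snippet runs as-is.
model = AutoModel.from_pretrained("bert-base-uncased", revision="main")
```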
@@ -857,7 +858,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
         proxies = kwargs.pop("proxies", None)
         output_loading_info = kwargs.pop("output_loading_info", False)
         local_files_only = kwargs.pop("local_files_only", False)
-        use_cdn = kwargs.pop("use_cdn", True)
+        revision = kwargs.pop("revision", None)
         mirror = kwargs.pop("mirror", None)

         # Load config if we don't provide a configuration
@@ -872,6 +873,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
                 resume_download=resume_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
+                revision=revision,
                 **kwargs,
             )
         else:
@@ -909,7 +911,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
                 archive_file = hf_bucket_url(
                     pretrained_model_name_or_path,
                     filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
-                    use_cdn=use_cdn,
+                    revision=revision,
                     mirror=mirror,
                 )

@@ -923,9 +925,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
                 resume_download=resume_download,
                 local_files_only=local_files_only,
             )
-            if resolved_archive_file is None:
-                raise EnvironmentError
-        except EnvironmentError:
+        except EnvironmentError as err:
+            logger.error(err)
             msg = (
                 f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                 f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
@@ -86,7 +86,7 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)


-def get_framework(model):
+def get_framework(model, revision: Optional[str] = None):
     """
     Select framework (TensorFlow or PyTorch) to use.

@@ -103,14 +103,14 @@ def get_framework(model):
         )
     if isinstance(model, str):
         if is_torch_available() and not is_tf_available():
-            model = AutoModel.from_pretrained(model)
+            model = AutoModel.from_pretrained(model, revision=revision)
         elif is_tf_available() and not is_torch_available():
-            model = TFAutoModel.from_pretrained(model)
+            model = TFAutoModel.from_pretrained(model, revision=revision)
         else:
             try:
-                model = AutoModel.from_pretrained(model)
+                model = AutoModel.from_pretrained(model, revision=revision)
             except OSError:
-                model = TFAutoModel.from_pretrained(model)
+                model = TFAutoModel.from_pretrained(model, revision=revision)

     framework = "tf" if model.__class__.__name__.startswith("TF") else "pt"
     return framework
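A sketch of the updated helper in use (internal utility; the import path below assumes the `transformers.pipelines` module layout of this commit):

```python
from transformers.pipelines import get_framework

# The revision is forwarded to AutoModel/TFAutoModel.from_pretrained when the
# model is passed as a string identifier.
framework = get_framework("bert-base-uncased", revision="main")
print(framework)  # "pt" or "tf", depending on the installed backends
```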
@@ -2730,6 +2730,7 @@ def pipeline(
     config: Optional[Union[str, PretrainedConfig]] = None,
     tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
     framework: Optional[str] = None,
+    revision: Optional[str] = None,
     use_fast: bool = False,
     **kwargs
 ) -> Pipeline:
@@ -2784,6 +2785,10 @@ def pipeline(
         If no framework is specified, will default to the one currently installed. If no framework is specified and
         both frameworks are installed, will default to the framework of the :obj:`model`, or to PyTorch if no model
         is provided.
+    revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
+        When passing a task name or a string model identifier: The specific model version to use. It can be a
+        branch name, a tag name, or a commit id, since we use a git-based system for storing models and other
+        artifacts on huggingface.co, so ``revision`` can be any identifier allowed by git.
     use_fast (:obj:`bool`, `optional`, defaults to :obj:`False`):
         Whether or not to use a Fast tokenizer if possible (a :class:`~transformers.PreTrainedTokenizerFast`).
     kwargs:
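In user code this looks roughly like the following sketch (the model id is the stock sentiment checkpoint; the revision is a placeholder):

```python
from transformers import pipeline

# The same revision is applied to the config, the tokenizer, and the model
# weights resolved from the string identifier.
nlp = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    revision="main",
)
print(nlp("Pinned revisions make pipelines reproducible."))
```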
@@ -2845,17 +2850,19 @@ def pipeline(
         if isinstance(tokenizer, tuple):
             # For tuple we have (tokenizer name, {kwargs})
             use_fast = tokenizer[1].pop("use_fast", use_fast)
-            tokenizer = AutoTokenizer.from_pretrained(tokenizer[0], use_fast=use_fast, **tokenizer[1])
+            tokenizer = AutoTokenizer.from_pretrained(
+                tokenizer[0], use_fast=use_fast, revision=revision, **tokenizer[1]
+            )
         else:
-            tokenizer = AutoTokenizer.from_pretrained(tokenizer, use_fast=use_fast)
+            tokenizer = AutoTokenizer.from_pretrained(tokenizer, revision=revision, use_fast=use_fast)

     # Instantiate config if needed
     if isinstance(config, str):
-        config = AutoConfig.from_pretrained(config)
+        config = AutoConfig.from_pretrained(config, revision=revision)

     # Instantiate modelcard if needed
     if isinstance(modelcard, str):
-        modelcard = ModelCard.from_pretrained(modelcard)
+        modelcard = ModelCard.from_pretrained(modelcard, revision=revision)

     # Instantiate model if needed
     if isinstance(model, str):
@@ -2873,7 +2880,7 @@ def pipeline(
                 "Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. "
                 "Trying to load the model with Tensorflow."
             )
-        model = model_class.from_pretrained(model, config=config, **model_kwargs)
+        model = model_class.from_pretrained(model, config=config, revision=revision, **model_kwargs)
     if task == "translation" and model.config.task_specific_params:
         for key in model.config.task_specific_params:
             if key.startswith("translation"):
@@ -125,8 +125,6 @@ class LegacyIndex(Index):
         try:
             # Load from URL or cache if already cached
             resolved_archive_file = cached_path(archive_file)
-            if resolved_archive_file is None:
-                raise EnvironmentError
         except EnvironmentError:
             msg = (
                 f"Can't load '{archive_file}'. Make sure that:\n\n"
@@ -276,6 +276,10 @@ class AutoTokenizer:
             proxies (:obj:`Dict[str, str]`, `optional`):
                 A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+                identifier allowed by git.
             use_fast (:obj:`bool`, `optional`, defaults to :obj:`False`):
                 Whether or not to try to load the fast version of the tokenizer.
             kwargs (additional keyword arguments, `optional`):
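A sketch on the tokenizer side (placeholders as before); pinning the tokenizer and model to the same revision keeps the pair consistent:

```python
from transformers import AutoModel, AutoTokenizer

# Fetch tokenizer files and model weights from the same revision; "main" is a
# placeholder for the tag or commit sha you would pin.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", revision="main")
model = AutoModel.from_pretrained("bert-base-uncased", revision="main")
```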
@@ -29,6 +29,8 @@ from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union

 import numpy as np

+import requests
+
 from .file_utils import (
     add_end_docstrings,
     cached_path,
@@ -1515,6 +1517,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
             proxies (:obj:`Dict[str, str], `optional`):
                 A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
+                identifier allowed by git.
             inputs (additional positional arguments, `optional`):
                 Will be passed along to the Tokenizer ``__init__`` method.
             kwargs (additional keyword arguments, `optional`):
@@ -1549,6 +1555,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
         resume_download = kwargs.pop("resume_download", False)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)
+        revision = kwargs.pop("revision", None)

         s3_models = list(cls.max_model_input_sizes.keys())
         vocab_files = {}
@@ -1601,18 +1608,18 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
                     full_file_name = None
                 else:
                     full_file_name = hf_bucket_url(
-                        pretrained_model_name_or_path, filename=file_name, use_cdn=False, mirror=None
+                        pretrained_model_name_or_path, filename=file_name, revision=revision, mirror=None
                     )

                 vocab_files[file_id] = full_file_name

         # Get files from url, cache, or disk depending on the case
-        try:
-            resolved_vocab_files = {}
-            for file_id, file_path in vocab_files.items():
-                if file_path is None:
-                    resolved_vocab_files[file_id] = None
-                else:
+        resolved_vocab_files = {}
+        for file_id, file_path in vocab_files.items():
+            if file_path is None:
+                resolved_vocab_files[file_id] = None
+            else:
+                try:
                     resolved_vocab_files[file_id] = cached_path(
                         file_path,
                         cache_dir=cache_dir,
@@ -1621,34 +1628,20 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
                         resume_download=resume_download,
                         local_files_only=local_files_only,
                     )
-        except EnvironmentError:
-            if pretrained_model_name_or_path in s3_models:
-                msg = "Couldn't reach server at '{}' to download vocabulary files."
-            else:
-                msg = (
-                    "Model name '{}' was not found in tokenizers model name list ({}). "
-                    "We assumed '{}' was a path or url to a directory containing vocabulary files "
-                    "named {}, but couldn't find such vocabulary files at this path or url.".format(
-                        pretrained_model_name_or_path,
-                        ", ".join(s3_models),
-                        pretrained_model_name_or_path,
-                        list(cls.vocab_files_names.values()),
-                    )
-                )
-
-            raise EnvironmentError(msg)
+                except requests.exceptions.HTTPError as err:
+                    if "404 Client Error" in str(err):
+                        logger.debug(err)
+                        resolved_vocab_files[file_id] = None
+                    else:
+                        raise err

         if all(full_file_name is None for full_file_name in resolved_vocab_files.values()):
-            raise EnvironmentError(
-                "Model name '{}' was not found in tokenizers model name list ({}). "
-                "We assumed '{}' was a path, a model identifier, or url to a directory containing vocabulary files "
-                "named {} but couldn't find such vocabulary files at this path or url.".format(
-                    pretrained_model_name_or_path,
-                    ", ".join(s3_models),
-                    pretrained_model_name_or_path,
-                    list(cls.vocab_files_names.values()),
-                )
-            )
+            msg = (
+                f"Can't load tokenizer for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
+                f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
+                f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing relevant tokenizer files\n\n"
+            )
+            raise EnvironmentError(msg)

         for file_id, file_path in vocab_files.items():
             if file_path == resolved_vocab_files[file_id]:
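For reference, a sketch of the URL construction the loaders now go through (the `/resolve/<revision>/` shape in the comment reflects the post-migration `hf_bucket_url` and should be treated as an implementation detail):

```python
from transformers.file_utils import hf_bucket_url

# Resolve config.json for a model at a given revision; the returned URL points
# at huggingface.co rather than the old S3/CDN buckets.
url = hf_bucket_url("bert-base-uncased", filename="config.json", revision="main")
print(url)  # e.g. https://huggingface.co/bert-base-uncased/resolve/main/config.json
```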
tests/test_file_utils.py (new file, 63 lines)
@@ -0,0 +1,63 @@
+import unittest
+
+import requests
+from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME, filename_to_url, get_from_cache, hf_bucket_url
+from transformers.testing_utils import DUMMY_UNKWOWN_IDENTIFIER
+
+
+MODEL_ID = DUMMY_UNKWOWN_IDENTIFIER
+# An actual model hosted on huggingface.co
+
+REVISION_ID_DEFAULT = "main"
+# Default branch name
+REVISION_ID_ONE_SPECIFIC_COMMIT = "f2c752cfc5c0ab6f4bdec59acea69eefbee381c2"
+# One particular commit (not the top of `main`)
+REVISION_ID_INVALID = "aaaaaaa"
+# This commit does not exist, so we should 404.
+
+PINNED_SHA1 = "d9e9f15bc825e4b2c9249e9578f884bbcb5e3684"
+# Sha-1 of config.json on the top of `main`, for checking purposes
+PINNED_SHA256 = "4b243c475af8d0a7754e87d7d096c92e5199ec2fe168a2ee7998e3b8e9bcb1d3"
+# Sha-256 of pytorch_model.bin on the top of `main`, for checking purposes
+
+
+class GetFromCacheTests(unittest.TestCase):
+    def test_bogus_url(self):
+        # This lets us simulate no connection
+        # as the error raised is the same
+        # `ConnectionError`
+        url = "https://bogus"
+        with self.assertRaisesRegex(ValueError, "Connection error"):
+            _ = get_from_cache(url)
+
+    def test_file_not_found(self):
+        # Valid revision (None) but missing file.
+        url = hf_bucket_url(MODEL_ID, filename="missing.bin")
+        with self.assertRaisesRegex(requests.exceptions.HTTPError, "404 Client Error"):
+            _ = get_from_cache(url)
+
+    def test_revision_not_found(self):
+        # Valid file but missing revision
+        url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_INVALID)
+        with self.assertRaisesRegex(requests.exceptions.HTTPError, "404 Client Error"):
+            _ = get_from_cache(url)
+
+    def test_standard_object(self):
+        url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT)
+        filepath = get_from_cache(url, force_download=True)
+        metadata = filename_to_url(filepath)
+        self.assertEqual(metadata, (url, f'"{PINNED_SHA1}"'))
+
+    def test_standard_object_rev(self):
+        # Same object, but different revision
+        url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_ONE_SPECIFIC_COMMIT)
+        filepath = get_from_cache(url, force_download=True)
+        metadata = filename_to_url(filepath)
+        self.assertNotEqual(metadata[1], f'"{PINNED_SHA1}"')
+        # Caution: check that the etag is *not* equal to the one from `test_standard_object`
+
+    def test_lfs_object(self):
+        url = hf_bucket_url(MODEL_ID, filename=WEIGHTS_NAME, revision=REVISION_ID_DEFAULT)
+        filepath = get_from_cache(url, force_download=True)
+        metadata = filename_to_url(filepath)
+        self.assertEqual(metadata, (url, f'"{PINNED_SHA256}"'))
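Note that these tests hit huggingface.co for real files, so they need network access; assuming the repository's usual pytest setup, `python -m pytest tests/test_file_utils.py -v` should run just this file.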
@@ -20,7 +20,7 @@ import unittest

 import requests
 from requests.exceptions import HTTPError
-from transformers.hf_api import HfApi, HfFolder, ModelInfo, PresignedUrl, S3Obj
+from transformers.hf_api import HfApi, HfFolder, ModelInfo, PresignedUrl, RepoObj, S3Obj


 USER = "__DUMMY_TRANSFORMERS_USER__"
@@ -35,6 +35,7 @@ FILES = [
         os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/empty.txt"),
     ),
 ]
+REPO_NAME = "my-model-{}".format(int(time.time()))
 ENDPOINT_STAGING = "https://moon-staging.huggingface.co"


@@ -78,15 +79,6 @@ class HfApiEndpointsTest(HfApiCommonTest):
         urls = self._api.presign(token=self._token, filename="nested/valid_org.txt", organization="valid_org")
         self.assertIsInstance(urls, PresignedUrl)

-    def test_presign_invalid(self):
-        try:
-            _ = self._api.presign(token=self._token, filename="non_nested.json")
-        except HTTPError as e:
-            self.assertIsNotNone(e.response.text)
-            self.assertTrue("Filename invalid" in e.response.text)
-        else:
-            self.fail("Expected an exception")
-
     def test_presign(self):
         for FILE_KEY, FILE_PATH in FILES:
             urls = self._api.presign(token=self._token, filename=FILE_KEY)
@@ -109,6 +101,17 @@ class HfApiEndpointsTest(HfApiCommonTest):
             o = objs[-1]
             self.assertIsInstance(o, S3Obj)

+    def test_list_repos_objs(self):
+        objs = self._api.list_repos_objs(token=self._token)
+        self.assertIsInstance(objs, list)
+        if len(objs) > 0:
+            o = objs[-1]
+            self.assertIsInstance(o, RepoObj)
+
+    def test_create_and_delete_repo(self):
+        self._api.create_repo(token=self._token, name=REPO_NAME)
+        self._api.delete_repo(token=self._token, name=REPO_NAME)
+

 class HfApiPublicTest(unittest.TestCase):
     def test_staging_model_list(self):
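The new repo endpoints in plain form, as a sketch (the token is a placeholder; `create_repo`, `list_repos_objs`, and `delete_repo` are the methods exercised by the tests above):

```python
from transformers.hf_api import HfApi

api = HfApi()
# A real user token is required here; at this point in the library's history,
# one way to obtain it is api.login(username, password). "xxx" is a placeholder.
token = "xxx"
api.create_repo(token=token, name="my-test-model")
print(api.list_repos_objs(token=token))  # list of RepoObj entries
api.delete_repo(token=token, name="my-test-model")
```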
@@ -323,7 +323,7 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):

     def test_custom_load_tf_weights(self):
         model, output_loading_info = TFBertForTokenClassification.from_pretrained(
-            "jplu/tiny-tf-bert-random", use_cdn=False, output_loading_info=True
+            "jplu/tiny-tf-bert-random", output_loading_info=True
         )
         self.assertEqual(sorted(output_loading_info["unexpected_keys"]), ["mlm___cls", "nsp___cls"])
         for layer in output_loading_info["missing_keys"]:
@@ -165,7 +165,7 @@ DUMMY_FUNCTION = {


 def read_init():
-    """ Read the init and exctracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects. """
+    """ Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects. """
     with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8") as f:
         lines = f.readlines()