mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 04:40:06 +06:00
[core] implement support for run-time dependency version checking (#8645)
* implement support for run-time dependency version checking * try not escaping ! * use findall that works on py36 * small tweaks * autoformatter worship * simplify * shorter names * add support for non-versioned checks * add deps * revert * tokenizers not required, check version only if installed * make a proper distutils cmd and add make target * tqdm must be checked before tokenizers * workaround the DistributionNotFound peculiar setup * handle the rest of packages in setup.py * fully sync setup.py's install_requires - to check them all * nit * make install_requires more readable * typo * Update setup.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * restyle * add types * simplify * simplify2 Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
a7d73cfdd4
commit
82d443a7fd
11
Makefile
11
Makefile
@ -1,4 +1,4 @@
|
||||
.PHONY: modified_only_fixup extra_quality_checks quality style fixup fix-copies test test-examples docs
|
||||
.PHONY: deps_table_update modified_only_fixup extra_quality_checks quality style fixup fix-copies test test-examples docs
|
||||
|
||||
|
||||
check_dirs := examples tests src utils
|
||||
@ -14,9 +14,14 @@ modified_only_fixup:
|
||||
echo "No library .py files were modified"; \
|
||||
fi
|
||||
|
||||
# Update src/transformers/dependency_versions_table.py
|
||||
|
||||
deps_table_update:
|
||||
@python setup.py deps_table_update
|
||||
|
||||
# Check that source code meets quality standards
|
||||
|
||||
extra_quality_checks:
|
||||
extra_quality_checks: deps_table_update
|
||||
python utils/check_copies.py
|
||||
python utils/check_dummies.py
|
||||
python utils/check_repo.py
|
||||
@ -32,7 +37,7 @@ quality:
|
||||
|
||||
# Format source code automatically and check is there are any problems left that need manual fixing
|
||||
|
||||
style:
|
||||
style: deps_table_update
|
||||
black $(check_dirs)
|
||||
isort $(check_dirs)
|
||||
python utils/style_doc.py src/transformers docs/source --max_len 119
|
||||
|
@ -4,11 +4,9 @@ import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import packaging
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.utilities import rank_zero_info
|
||||
|
||||
import pkg_resources
|
||||
from transformers import (
|
||||
AdamW,
|
||||
AutoConfig,
|
||||
@ -30,21 +28,12 @@ from transformers.optimization import (
|
||||
get_linear_schedule_with_warmup,
|
||||
get_polynomial_decay_schedule_with_warmup,
|
||||
)
|
||||
from transformers.utils.versions import require_version_examples
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def require_min_ver(pkg, min_ver):
|
||||
got_ver = pkg_resources.get_distribution(pkg).version
|
||||
if packaging.version.parse(got_ver) < packaging.version.parse(min_ver):
|
||||
logger.warning(
|
||||
f"{pkg}>={min_ver} is required for a normal functioning of this module, but found {pkg}=={got_ver}. "
|
||||
"Try: pip install -r examples/requirements.txt"
|
||||
)
|
||||
|
||||
|
||||
require_min_ver("pytorch_lightning", "1.0.4")
|
||||
require_version_examples("pytorch_lightning>=1.0.4")
|
||||
|
||||
MODEL_MODES = {
|
||||
"base": AutoModel,
|
||||
|
199
setup.py
199
setup.py
@ -47,7 +47,9 @@ To create the package for pypi.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
from distutils.core import Command
|
||||
from pathlib import Path
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
@ -69,54 +71,163 @@ if stale_egg_info.exists():
|
||||
shutil.rmtree(stale_egg_info)
|
||||
|
||||
|
||||
# IMPORTANT:
|
||||
# 1. all dependencies should be listed here with their version requirements if any
|
||||
# 2. once modified, run: `make deps_table_update` to update src/transformers/dependency_versions_table.py
|
||||
_deps = [
|
||||
"black>=20.8b1",
|
||||
"cookiecutter==1.7.2",
|
||||
"dataclasses",
|
||||
"datasets",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"filelock",
|
||||
"flake8>=3.8.3",
|
||||
"flax==0.2.2",
|
||||
"fugashi>=1.0",
|
||||
"ipadic>=1.0.0,<2.0",
|
||||
"isort>=5.5.4",
|
||||
"jax>=0.2.0",
|
||||
"jaxlib==0.1.55",
|
||||
"keras2onnx",
|
||||
"numpy",
|
||||
"onnxconverter-common",
|
||||
"onnxruntime-tools>=1.4.2",
|
||||
"onnxruntime>=1.4.0",
|
||||
"packaging",
|
||||
"parameterized",
|
||||
"protobuf",
|
||||
"psutil",
|
||||
"pydantic",
|
||||
"pytest",
|
||||
"pytest-xdist",
|
||||
"python>=3.6.0",
|
||||
"recommonmark",
|
||||
"regex!=2019.12.17",
|
||||
"requests",
|
||||
"sacremoses",
|
||||
"scikit-learn",
|
||||
"sentencepiece==0.1.91",
|
||||
"sphinx-copybutton",
|
||||
"sphinx-markdown-tables",
|
||||
"sphinx-rtd-theme==0.4.3", # sphinx-rtd-theme==0.5.0 introduced big changes in the style.
|
||||
"sphinx==3.2.1",
|
||||
"starlette",
|
||||
"tensorflow-cpu>=2.0",
|
||||
"tensorflow>=2.0",
|
||||
"timeout-decorator",
|
||||
"tokenizers==0.9.4",
|
||||
"torch>=1.0",
|
||||
"tqdm>=4.27",
|
||||
"unidic>=1.0.2",
|
||||
"unidic_lite>=1.0.7",
|
||||
"uvicorn",
|
||||
]
|
||||
|
||||
|
||||
# tokenizers: "tokenizers==0.9.4" lookup table
|
||||
# support non-versions file too so that they can be checked at run time
|
||||
deps = {b: a for a, b in (re.findall(r"^(([^!=<>]+)(?:[!=<>].*)?$)", x)[0] for x in _deps)}
|
||||
|
||||
|
||||
def deps_list(*pkgs):
|
||||
return [deps[pkg] for pkg in pkgs]
|
||||
|
||||
|
||||
class DepsTableUpdateCommand(Command):
|
||||
"""
|
||||
A custom distutils command that updates the dependency table.
|
||||
usage: python setup.py deps_table_update
|
||||
"""
|
||||
|
||||
description = "build runtime dependency table"
|
||||
user_options = [
|
||||
# format: (long option, short option, description).
|
||||
("dep-table-update", None, "updates src/transformers/dependency_versions_table.py"),
|
||||
]
|
||||
|
||||
def initialize_options(self):
|
||||
pass
|
||||
|
||||
def finalize_options(self):
|
||||
pass
|
||||
|
||||
def run(self):
|
||||
entries = "\n".join([f' "{k}": "{v}",' for k, v in deps.items()])
|
||||
content = [
|
||||
"# THIS FILE HAS BEEN AUTOGENERATED. To update:",
|
||||
"# 1. modify the `_deps` dict in setup.py",
|
||||
"# 2. run `make deps_table_update``",
|
||||
"deps = {",
|
||||
entries,
|
||||
"}",
|
||||
""
|
||||
]
|
||||
target = "src/transformers/dependency_versions_table.py"
|
||||
print(f"updating {target}")
|
||||
with open(target, "w") as f:
|
||||
f.write("\n".join(content))
|
||||
|
||||
|
||||
extras = {}
|
||||
|
||||
extras["ja"] = ["fugashi>=1.0", "ipadic>=1.0.0,<2.0", "unidic_lite>=1.0.7", "unidic>=1.0.2"]
|
||||
extras["sklearn"] = ["scikit-learn"]
|
||||
extras["ja"] = deps_list("fugashi", "ipadic", "unidic_lite", "unidic")
|
||||
extras["sklearn"] = deps_list("scikit-learn")
|
||||
|
||||
# keras2onnx and onnxconverter-common version is specific through a commit until 1.7.0 lands on pypi
|
||||
extras["tf"] = [
|
||||
"tensorflow>=2.0",
|
||||
"onnxconverter-common",
|
||||
"keras2onnx"
|
||||
# "onnxconverter-common @ git+git://github.com/microsoft/onnxconverter-common.git@f64ca15989b6dc95a1f3507ff6e4c395ba12dff5#egg=onnxconverter-common",
|
||||
# "keras2onnx @ git+git://github.com/onnx/keras-onnx.git@cbdc75cb950b16db7f0a67be96a278f8d2953b48#egg=keras2onnx",
|
||||
]
|
||||
extras["tf-cpu"] = [
|
||||
"tensorflow-cpu>=2.0",
|
||||
"onnxconverter-common",
|
||||
"keras2onnx"
|
||||
# "onnxconverter-common @ git+git://github.com/microsoft/onnxconverter-common.git@f64ca15989b6dc95a1f3507ff6e4c395ba12dff5#egg=onnxconverter-common",
|
||||
# "keras2onnx @ git+git://github.com/onnx/keras-onnx.git@cbdc75cb950b16db7f0a67be96a278f8d2953b48#egg=keras2onnx",
|
||||
]
|
||||
extras["torch"] = ["torch>=1.0"]
|
||||
extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "keras2onnx")
|
||||
extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "keras2onnx")
|
||||
|
||||
extras["torch"] = deps_list("torch")
|
||||
|
||||
if os.name == "nt": # windows
|
||||
extras["retrieval"] = ["datasets"] # faiss is not supported on windows
|
||||
extras["flax"] = [] # jax is not supported on windows
|
||||
extras["retrieval"] = deps_list("datasets") # faiss is not supported on windows
|
||||
extras["flax"] = [] # jax is not supported on windows
|
||||
else:
|
||||
extras["retrieval"] = ["faiss-cpu", "datasets"]
|
||||
extras["flax"] = ["jaxlib==0.1.55", "jax>=0.2.0", "flax==0.2.2"]
|
||||
extras["retrieval"] = deps_list("faiss-cpu", "datasets")
|
||||
extras["flax"] = deps_list("jax", "jaxlib", "flax")
|
||||
|
||||
extras["tokenizers"] = ["tokenizers==0.9.4"]
|
||||
extras["onnxruntime"] = ["onnxruntime>=1.4.0", "onnxruntime-tools>=1.4.2"]
|
||||
extras["modelcreation"] = ["cookiecutter==1.7.2"]
|
||||
extras["tokenizers"] = deps_list("tokenizers")
|
||||
extras["onnxruntime"] = deps_list("onnxruntime", "onnxruntime-tools")
|
||||
extras["modelcreation"] = deps_list("cookiecutter")
|
||||
|
||||
extras["serving"] = ["pydantic", "uvicorn", "fastapi", "starlette"]
|
||||
|
||||
extras["sentencepiece"] = ["sentencepiece==0.1.91", "protobuf"]
|
||||
extras["retrieval"] = ["faiss-cpu", "datasets"]
|
||||
extras["testing"] = ["pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil"] + extras["retrieval"] + extras["modelcreation"]
|
||||
# sphinx-rtd-theme==0.5.0 introduced big changes in the style.
|
||||
extras["docs"] = ["recommonmark", "sphinx==3.2.1", "sphinx-markdown-tables", "sphinx-rtd-theme==0.4.3", "sphinx-copybutton"]
|
||||
extras["quality"] = ["black >= 20.8b1", "isort >= 5.5.4", "flake8 >= 3.8.3"]
|
||||
extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
|
||||
|
||||
extras["sentencepiece"] = deps_list("sentencepiece", "protobuf")
|
||||
extras["retrieval"] = deps_list("faiss-cpu", "datasets")
|
||||
extras["testing"] = (
|
||||
deps_list("pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil")
|
||||
+ extras["retrieval"]
|
||||
+ extras["modelcreation"]
|
||||
)
|
||||
extras["docs"] = deps_list("recommonmark", "sphinx", "sphinx-markdown-tables", "sphinx-rtd-theme", "sphinx-copybutton")
|
||||
extras["quality"] = deps_list("black", "isort", "flake8")
|
||||
|
||||
extras["all"] = extras["tf"] + extras["torch"] + extras["flax"] + extras["sentencepiece"] + extras["tokenizers"]
|
||||
|
||||
extras["dev"] = extras["all"] + extras["testing"] + extras["quality"] + extras["ja"] + extras["docs"] + extras["sklearn"] + extras["modelcreation"]
|
||||
extras["dev"] = (
|
||||
extras["all"]
|
||||
+ extras["testing"]
|
||||
+ extras["quality"]
|
||||
+ extras["ja"]
|
||||
+ extras["docs"]
|
||||
+ extras["sklearn"]
|
||||
+ extras["modelcreation"]
|
||||
)
|
||||
|
||||
|
||||
# when modifying the following list, make sure to update src/transformers/dependency_versions_check.py
|
||||
install_requires = [
|
||||
deps["dataclasses"] + ";python_version<'3.7'", # dataclasses for Python versions that don't have it
|
||||
deps["filelock"], # filesystem locks, e.g., to prevent parallel downloads
|
||||
deps["numpy"],
|
||||
deps["packaging"], # utilities from PyPA to e.g., compare versions
|
||||
deps["regex"], # for OpenAI GPT
|
||||
deps["requests"], # for downloading models over HTTPS
|
||||
deps["sacremoses"], # for XLM
|
||||
deps["tokenizers"],
|
||||
deps["tqdm"], # progress bars in model download and training scripts
|
||||
]
|
||||
|
||||
setup(
|
||||
name="transformers",
|
||||
version="4.0.0-rc-1",
|
||||
@ -130,27 +241,10 @@ setup(
|
||||
url="https://github.com/huggingface/transformers",
|
||||
package_dir={"": "src"},
|
||||
packages=find_packages("src"),
|
||||
install_requires=[
|
||||
"numpy",
|
||||
"tokenizers == 0.9.4",
|
||||
# dataclasses for Python versions that don't have it
|
||||
"dataclasses;python_version<'3.7'",
|
||||
# utilities from PyPA to e.g. compare versions
|
||||
"packaging",
|
||||
# filesystem locks e.g. to prevent parallel downloads
|
||||
"filelock",
|
||||
# for downloading models over HTTPS
|
||||
"requests",
|
||||
# progress bars in model download and training scripts
|
||||
"tqdm >= 4.27",
|
||||
# for OpenAI GPT
|
||||
"regex != 2019.12.17",
|
||||
# for XLM
|
||||
"sacremoses",
|
||||
],
|
||||
extras_require=extras,
|
||||
entry_points={"console_scripts": ["transformers-cli=transformers.commands.transformers_cli:main"]},
|
||||
python_requires=">=3.6.0",
|
||||
install_requires=install_requires,
|
||||
classifiers=[
|
||||
"Development Status :: 5 - Production/Stable",
|
||||
"Intended Audience :: Developers",
|
||||
@ -163,4 +257,5 @@ setup(
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
],
|
||||
cmdclass={"deps_table_update": DepsTableUpdateCommand},
|
||||
)
|
||||
|
@ -17,15 +17,7 @@ else:
|
||||
absl.logging.set_stderrthreshold("info")
|
||||
absl.logging._warn_preinit_stderr = False
|
||||
|
||||
# Integrations: this needs to come before other ml imports
|
||||
# in order to allow any 3rd-party code to initialize properly
|
||||
from .integrations import ( # isort:skip
|
||||
is_comet_available,
|
||||
is_optuna_available,
|
||||
is_ray_available,
|
||||
is_tensorboard_available,
|
||||
is_wandb_available,
|
||||
)
|
||||
from . import dependency_versions_check
|
||||
|
||||
# Configuration
|
||||
from .configuration_utils import PretrainedConfig
|
||||
@ -203,6 +195,17 @@ from .tokenization_utils_base import (
|
||||
)
|
||||
|
||||
|
||||
# Integrations: this needs to come before other ml imports
|
||||
# in order to allow any 3rd-party code to initialize properly
|
||||
from .integrations import ( # isort:skip
|
||||
is_comet_available,
|
||||
is_optuna_available,
|
||||
is_ray_available,
|
||||
is_tensorboard_available,
|
||||
is_wandb_available,
|
||||
)
|
||||
|
||||
|
||||
if is_sentencepiece_available():
|
||||
from .models.albert import AlbertTokenizer
|
||||
from .models.bert_generation import BertGenerationTokenizer
|
||||
|
28
src/transformers/dependency_versions_check.py
Normal file
28
src/transformers/dependency_versions_check.py
Normal file
@ -0,0 +1,28 @@
|
||||
import sys
|
||||
|
||||
from .dependency_versions_table import deps
|
||||
from .utils.versions import require_version_core
|
||||
|
||||
|
||||
# define which module versions we always want to check at run time
|
||||
# (usually the ones defined in `install_requires` in setup.py)
|
||||
#
|
||||
# order specific notes:
|
||||
# - tqdm must be checked before tokenizers
|
||||
|
||||
pkgs_to_check_at_runtime = "python tqdm regex sacremoses requests packaging filelock numpy tokenizers".split()
|
||||
if sys.version_info < (3, 7):
|
||||
pkgs_to_check_at_runtime.append("dataclasses")
|
||||
|
||||
for pkg in pkgs_to_check_at_runtime:
|
||||
if pkg in deps:
|
||||
if pkg == "tokenizers":
|
||||
# must be loaded here, or else tqdm check may fail
|
||||
from .file_utils import is_tokenizers_available
|
||||
|
||||
if not is_tokenizers_available():
|
||||
continue # not required, check version only if installed
|
||||
|
||||
require_version_core(deps[pkg])
|
||||
else:
|
||||
raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
|
52
src/transformers/dependency_versions_table.py
Normal file
52
src/transformers/dependency_versions_table.py
Normal file
@ -0,0 +1,52 @@
|
||||
# THIS FILE HAS BEEN AUTOGENERATED. To update:
|
||||
# 1. modify the `_deps` dict in setup.py
|
||||
# 2. run `make deps_table_update``
|
||||
deps = {
|
||||
"black": "black>=20.8b1",
|
||||
"cookiecutter": "cookiecutter==1.7.2",
|
||||
"dataclasses": "dataclasses",
|
||||
"datasets": "datasets",
|
||||
"faiss-cpu": "faiss-cpu",
|
||||
"fastapi": "fastapi",
|
||||
"filelock": "filelock",
|
||||
"flake8": "flake8>=3.8.3",
|
||||
"flax": "flax==0.2.2",
|
||||
"fugashi": "fugashi>=1.0",
|
||||
"ipadic": "ipadic>=1.0.0,<2.0",
|
||||
"isort": "isort>=5.5.4",
|
||||
"jax": "jax>=0.2.0",
|
||||
"jaxlib": "jaxlib==0.1.55",
|
||||
"keras2onnx": "keras2onnx",
|
||||
"numpy": "numpy",
|
||||
"onnxconverter-common": "onnxconverter-common",
|
||||
"onnxruntime-tools": "onnxruntime-tools>=1.4.2",
|
||||
"onnxruntime": "onnxruntime>=1.4.0",
|
||||
"packaging": "packaging",
|
||||
"parameterized": "parameterized",
|
||||
"protobuf": "protobuf",
|
||||
"psutil": "psutil",
|
||||
"pydantic": "pydantic",
|
||||
"pytest": "pytest",
|
||||
"pytest-xdist": "pytest-xdist",
|
||||
"python": "python>=3.6.0",
|
||||
"recommonmark": "recommonmark",
|
||||
"regex": "regex!=2019.12.17",
|
||||
"requests": "requests",
|
||||
"sacremoses": "sacremoses",
|
||||
"scikit-learn": "scikit-learn",
|
||||
"sentencepiece": "sentencepiece==0.1.91",
|
||||
"sphinx-copybutton": "sphinx-copybutton",
|
||||
"sphinx-markdown-tables": "sphinx-markdown-tables",
|
||||
"sphinx-rtd-theme": "sphinx-rtd-theme==0.4.3",
|
||||
"sphinx": "sphinx==3.2.1",
|
||||
"starlette": "starlette",
|
||||
"tensorflow-cpu": "tensorflow-cpu>=2.0",
|
||||
"tensorflow": "tensorflow>=2.0",
|
||||
"timeout-decorator": "timeout-decorator",
|
||||
"tokenizers": "tokenizers==0.9.4",
|
||||
"torch": "torch>=1.0",
|
||||
"tqdm": "tqdm>=4.27",
|
||||
"unidic": "unidic>=1.0.2",
|
||||
"unidic_lite": "unidic_lite>=1.0.7",
|
||||
"uvicorn": "uvicorn",
|
||||
}
|
87
src/transformers/utils/versions.py
Normal file
87
src/transformers/utils/versions.py
Normal file
@ -0,0 +1,87 @@
|
||||
"""
|
||||
Utilities for working with package versions
|
||||
"""
|
||||
|
||||
import operator
|
||||
import re
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
from packaging import version
|
||||
|
||||
import pkg_resources
|
||||
|
||||
|
||||
ops = {
|
||||
"<": operator.lt,
|
||||
"<=": operator.le,
|
||||
"==": operator.eq,
|
||||
"!=": operator.ne,
|
||||
">=": operator.ge,
|
||||
">": operator.gt,
|
||||
}
|
||||
|
||||
|
||||
def require_version(requirement: str, hint: Optional[str] = None) -> None:
|
||||
"""
|
||||
Perform a runtime check of the dependency versions, using the exact same syntax used by pip.
|
||||
|
||||
The installed module version comes from the `site-packages` dir via `pkg_resources`.
|
||||
|
||||
Args:
|
||||
requirement (:obj:`str`): pip style definition, e.g., "tokenizers==0.9.4", "tqdm>=4.27", "numpy"
|
||||
hint (:obj:`str`, `optional`): what suggestion to print in case of requirements not being met
|
||||
"""
|
||||
|
||||
# note: while pkg_resources.require_version(requirement) is a much simpler way to do it, it
|
||||
# fails if some of the dependencies of the dependencies are not matching, which is not necessarily
|
||||
# bad, hence the more complicated check - which also should be faster, since it doesn't check
|
||||
# dependencies of dependencies.
|
||||
|
||||
hint = f"\n{hint}" if hint is not None else ""
|
||||
|
||||
# non-versioned check
|
||||
if re.match(r"^[\w_\-\d]+$", requirement):
|
||||
pkg, op, want_ver = requirement, None, None
|
||||
else:
|
||||
match = re.findall(r"^([^!=<>\s]+)([\s!=<>]{1,2})(.+)", requirement)
|
||||
if not match:
|
||||
raise ValueError(
|
||||
f"requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but got {requirement}"
|
||||
)
|
||||
pkg, op, want_ver = match[0]
|
||||
if op not in ops:
|
||||
raise ValueError(f"need one of {list(ops.keys())}, but got {op}")
|
||||
|
||||
# special case
|
||||
if pkg == "python":
|
||||
got_ver = ".".join([str(x) for x in sys.version_info[:3]])
|
||||
if not ops[op](version.parse(got_ver), version.parse(want_ver)):
|
||||
raise pkg_resources.VersionConflict(
|
||||
f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}."
|
||||
)
|
||||
return
|
||||
|
||||
# check if any version is installed
|
||||
try:
|
||||
got_ver = pkg_resources.get_distribution(pkg).version
|
||||
except pkg_resources.DistributionNotFound:
|
||||
raise pkg_resources.DistributionNotFound(requirement, ["this application", hint])
|
||||
|
||||
# check that the right version is installed if version number was provided
|
||||
if want_ver is not None and not ops[op](version.parse(got_ver), version.parse(want_ver)):
|
||||
raise pkg_resources.VersionConflict(
|
||||
f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}.{hint}"
|
||||
)
|
||||
|
||||
|
||||
def require_version_core(requirement):
|
||||
""" require_version wrapper which emits a core-specific hint on failure """
|
||||
hint = "Try: pip install transformers -U or pip install -e '.[dev]' if you're working with git master"
|
||||
return require_version(requirement, hint)
|
||||
|
||||
|
||||
def require_version_examples(requirement):
|
||||
""" require_version wrapper which emits examples-specific hint on failure """
|
||||
hint = "Try: pip install -r examples/requirements.txt"
|
||||
return require_version(requirement, hint)
|
91
tests/test_versions_utils.py
Normal file
91
tests/test_versions_utils.py
Normal file
@ -0,0 +1,91 @@
|
||||
import sys
|
||||
|
||||
import numpy
|
||||
|
||||
import pkg_resources
|
||||
from transformers.testing_utils import TestCasePlus
|
||||
from transformers.utils.versions import require_version, require_version_core, require_version_examples
|
||||
|
||||
|
||||
numpy_ver = numpy.__version__
|
||||
python_ver = ".".join([str(x) for x in sys.version_info[:3]])
|
||||
|
||||
|
||||
class DependencyVersionCheckTest(TestCasePlus):
|
||||
def test_core(self):
|
||||
# lt + different version strings
|
||||
require_version_core("numpy<1000.4.5")
|
||||
require_version_core("numpy<1000.4")
|
||||
require_version_core("numpy<1000")
|
||||
|
||||
# le
|
||||
require_version_core("numpy<=1000.4.5")
|
||||
require_version_core(f"numpy<={numpy_ver}")
|
||||
|
||||
# eq
|
||||
require_version_core(f"numpy=={numpy_ver}")
|
||||
|
||||
# ne
|
||||
require_version_core("numpy!=1000.4.5")
|
||||
|
||||
# ge
|
||||
require_version_core("numpy>=1.0")
|
||||
require_version_core("numpy>=1.0.0")
|
||||
require_version_core(f"numpy>={numpy_ver}")
|
||||
|
||||
# gt
|
||||
require_version_core("numpy>1.0.0")
|
||||
|
||||
# requirement w/o version
|
||||
require_version_core("numpy")
|
||||
|
||||
# unmet requirements due to version conflict
|
||||
for req in ["numpy==1.0.0", "numpy>=1000.0.0", f"numpy<{numpy_ver}"]:
|
||||
try:
|
||||
require_version_core(req)
|
||||
except pkg_resources.VersionConflict as e:
|
||||
self.assertIn(f"{req} is required", str(e))
|
||||
self.assertIn("but found", str(e))
|
||||
|
||||
# unmet requirements due to missing module
|
||||
for req in ["numpipypie>1", "numpipypie2"]:
|
||||
try:
|
||||
require_version_core(req)
|
||||
except pkg_resources.DistributionNotFound as e:
|
||||
self.assertIn(f"The '{req}' distribution was not found and is required by this application", str(e))
|
||||
self.assertIn("Try: pip install transformers -U", str(e))
|
||||
|
||||
# bogus requirements formats:
|
||||
# 1. whole thing
|
||||
for req in ["numpy??1.0.0", "numpy1.0.0"]:
|
||||
try:
|
||||
require_version_core(req)
|
||||
except ValueError as e:
|
||||
self.assertIn("requirement needs to be in the pip package format", str(e))
|
||||
# 2. only operators
|
||||
for req in ["numpy=1.0.0", "numpy == 1.00", "numpy<>1.0.0", "numpy><1.00", "numpy>>1.0.0"]:
|
||||
try:
|
||||
require_version_core(req)
|
||||
except ValueError as e:
|
||||
self.assertIn("need one of ", str(e))
|
||||
|
||||
def test_examples(self):
|
||||
# the main functionality is tested in `test_core`, this is just the hint check
|
||||
try:
|
||||
require_version_examples("numpy>1000.4.5")
|
||||
except pkg_resources.VersionConflict as e:
|
||||
self.assertIn("is required", str(e))
|
||||
self.assertIn("pip install -r examples/requirements.txt", str(e))
|
||||
|
||||
def test_python(self):
|
||||
|
||||
# matching requirement
|
||||
require_version("python>=3.6.0")
|
||||
|
||||
# not matching requirements
|
||||
for req in ["python>9.9.9", "python<3.0.0"]:
|
||||
try:
|
||||
require_version_core(req)
|
||||
except pkg_resources.VersionConflict as e:
|
||||
self.assertIn(f"{req} is required", str(e))
|
||||
self.assertIn(f"but found python=={python_ver}", str(e))
|
Loading…
Reference in New Issue
Block a user