Mirror of https://github.com/huggingface/transformers.git
Fast transformers import part 1 (#9441)
* Don't import libs to check they are available
* Don't import integrations at init
* Add importlib_metadata to deps
* Remove old vars references
* Avoid syntax error
* Adapt testing utils
* Try to appease torchhub
* Add dependency
* Remove more private variables
* Fix typo
* Another typo
* Refine the tf availability test
This commit is contained in:
  parent c89f1bc92e
  commit 0c96262f7d

.github/workflows/github-torch-hub.yml (vendored, 2 changed lines)
@@ -33,7 +33,7 @@ jobs:
       run: |
         pip install --upgrade pip
         pip install torch
-        pip install numpy filelock protobuf requests tqdm regex sentencepiece sacremoses tokenizers packaging
+        pip install numpy filelock protobuf requests tqdm regex sentencepiece sacremoses tokenizers packaging importlib_metadata

     - name: Torch hub list
       run: |
@@ -30,7 +30,7 @@ from transformers import (
 )


-dependencies = ["torch", "numpy", "tokenizers", "filelock", "requests", "tqdm", "regex", "sentencepiece", "sacremoses"]
+dependencies = ["torch", "numpy", "tokenizers", "filelock", "requests", "tqdm", "regex", "sentencepiece", "sacremoses", "importlib_metadata"]


 @add_start_docstrings(AutoConfig.__doc__)
setup.py (2 changed lines)
@@ -99,6 +99,7 @@ _deps = [
     "flake8>=3.8.3",
     "flax>=0.2.2",
     "fugashi>=1.0",
+    "importlib_metadata",
     "ipadic>=1.0.0,<2.0",
     "isort>=5.5.4",
     "jax>=0.2.0",
@@ -232,6 +233,7 @@ extras["dev"] = (
 # when modifying the following list, make sure to update src/transformers/dependency_versions_check.py
 install_requires = [
     deps["dataclasses"] + ";python_version<'3.7'",  # dataclasses for Python versions that don't have it
+    deps["importlib_metadata"] + ";python_version<'3.8'",  # importlib_metadata for Python versions that don't have it
     deps["filelock"],  # filesystem locks, e.g., to prevent parallel downloads
     deps["numpy"],
     deps["packaging"],  # utilities from PyPA to e.g., compare versions
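Note: the two guarded entries above rely on PEP 508 environment markers, so the backport wheels are only pulled in on interpreters that actually need them. A rough sketch of what the rendered requirement strings amount to (illustrative values, not copied from setup.py):

# Illustrative only: roughly what the guarded requirements above resolve to.
install_requires = [
    "dataclasses; python_version < '3.7'",         # dataclasses joined the stdlib in Python 3.7
    "importlib_metadata; python_version < '3.8'",  # importlib.metadata joined the stdlib in Python 3.8
    "filelock",
    "numpy",
    "packaging",
]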
@@ -26,6 +26,8 @@ from .utils.versions import require_version_core
 pkgs_to_check_at_runtime = "python tqdm regex sacremoses requests packaging filelock numpy tokenizers".split()
 if sys.version_info < (3, 7):
     pkgs_to_check_at_runtime.append("dataclasses")
+if sys.version_info < (3, 8):
+    pkgs_to_check_at_runtime.append("importlib_metadata")

 for pkg in pkgs_to_check_at_runtime:
     if pkg in deps:
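Note: the runtime check above reads installed package metadata instead of importing each dependency. A minimal sketch of that pattern (illustrative only; the helper name and version bound below are made up and are not the repository's require_version_core implementation):

import sys

if sys.version_info < (3, 8):
    import importlib_metadata  # backport added to the deps above
else:
    import importlib.metadata as importlib_metadata  # stdlib equivalent from Python 3.8 on

from packaging import version


def require_minimum(package: str, minimum: str) -> None:
    # Fails fast if `package` is missing or too old, without importing the package itself.
    try:
        installed = importlib_metadata.version(package)
    except importlib_metadata.PackageNotFoundError:
        raise ImportError(f"{package}>={minimum} is required but {package} is not installed.")
    if version.parse(installed) < version.parse(minimum):
        raise ImportError(f"{package}>={minimum} is required but version {installed} was found.")


require_minimum("tqdm", "4.27")  # example call with a made-up bound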
@@ -12,6 +12,7 @@ deps = {
     "flake8": "flake8>=3.8.3",
     "flax": "flax>=0.2.2",
     "fugashi": "fugashi>=1.0",
+    "importlib_metadata": "importlib_metadata",
     "ipadic": "ipadic>=1.0.0,<2.0",
     "isort": "isort>=5.5.4",
     "jax": "jax>=0.2.0",
@@ -18,6 +18,7 @@ https://github.com/allenai/allennlp.

 import copy
 import fnmatch
+import importlib.util
 import io
 import json
 import os
@@ -37,8 +38,10 @@ from urllib.parse import urlparse
 from zipfile import ZipFile, is_zipfile

 import numpy as np
+from packaging import version
 from tqdm.auto import tqdm

+import importlib_metadata
 import requests
 from filelock import FileLock

@@ -52,195 +55,88 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 ENV_VARS_TRUE_VALUES = {"1", "ON", "YES"}
 ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})

-try:
-    USE_TF = os.environ.get("USE_TF", "AUTO").upper()
-    USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
-    if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
-        import torch
-
-        _torch_available = True  # pylint: disable=invalid-name
-        logger.info("PyTorch version {} available.".format(torch.__version__))
-    else:
-        logger.info("Disabling PyTorch because USE_TF is set")
-        _torch_available = False
-except ImportError:
-    _torch_available = False  # pylint: disable=invalid-name
+USE_TF = os.environ.get("USE_TF", "AUTO").upper()
+USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
+USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
+
+if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
+    _torch_available = importlib.util.find_spec("torch") is not None
+    if _torch_available:
+        try:
+            _torch_version = importlib_metadata.version("torch")
+            logger.info(f"PyTorch version {_torch_version} available.")
+        except importlib_metadata.PackageNotFoundError:
+            _torch_available = False
+else:
+    logger.info("Disabling PyTorch because USE_TF is set")
+    _torch_available = False

-try:
-    USE_TF = os.environ.get("USE_TF", "AUTO").upper()
-    USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
-
-    if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
-        import tensorflow as tf
-
-        assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
-        _tf_available = True  # pylint: disable=invalid-name
-        logger.info("TensorFlow version {} available.".format(tf.__version__))
-    else:
-        logger.info("Disabling Tensorflow because USE_TORCH is set")
-        _tf_available = False
-except (ImportError, AssertionError):
-    _tf_available = False  # pylint: disable=invalid-name
+if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
+    _tf_available = importlib.util.find_spec("tensorflow") is not None
+    if _tf_available:
+        # For the metadata, we have to look for both tensorflow and tensorflow-cpu
+        try:
+            _tf_version = importlib_metadata.version("tensorflow")
+        except importlib_metadata.PackageNotFoundError:
+            try:
+                _tf_version = importlib_metadata.version("tensorflow-cpu")
+            except importlib_metadata.PackageNotFoundError:
+                _tf_version = None
+                _tf_available = False
+    if _tf_available:
+        if version.parse(_tf_version) < version.parse("2"):
+            logger.info(f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum.")
+            _tf_available = False
+        else:
+            logger.info(f"TensorFlow version {_tf_version} available.")
+else:
+    logger.info("Disabling Tensorflow because USE_TORCH is set")
+    _tf_available = False

-try:
-    USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
-
-    if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
-        import flax
-        import jax
-
-        logger.info("JAX version {}, Flax: available".format(jax.__version__))
-        logger.info("Flax available: {}".format(flax))
-        _flax_available = True
-    else:
-        _flax_available = False
-except ImportError:
-    _flax_available = False  # pylint: disable=invalid-name
+if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
+    _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
+    if _flax_available:
+        try:
+            _jax_version = importlib_metadata.version("jax")
+            _flax_version = importlib_metadata.version("flax")
+            logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
+        except importlib_metadata.PackageNotFoundError:
+            _flax_available = False
+else:
+    _flax_available = False

-try:
-    import datasets  # noqa: F401
-
-    # Check we're not importing a "datasets" directory somewhere
-    _datasets_available = hasattr(datasets, "__version__") and hasattr(datasets, "load_dataset")
-    if _datasets_available:
-        logger.debug(f"Successfully imported datasets version {datasets.__version__}")
-    else:
-        logger.debug("Imported a datasets object but this doesn't seem to be the 🤗 datasets library.")
-
-except ImportError:
+_datasets_available = importlib.util.find_spec("datasets") is not None
+try:
+    # Check we're not importing a "datasets" directory somewhere but the actual library by trying to grab the version
+    # AND checking it has an author field in the metadata that is HuggingFace.
+    _ = importlib_metadata.version("datasets")
+    _datasets_metadata = importlib_metadata.metadata("datasets")
+    if _datasets_metadata.get("author", "") != "HuggingFace Inc.":
+        _datasets_available = False
+except importlib_metadata.PackageNotFoundError:
     _datasets_available = False

-try:
-    from torch.hub import _get_torch_home
-
-    torch_cache_home = _get_torch_home()
-except ImportError:
-    torch_cache_home = os.path.expanduser(
-        os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
-    )
-
-try:
-    import torch_xla.core.xla_model as xm  # noqa: F401
-
-    if _torch_available:
-        _torch_tpu_available = True  # pylint: disable=
-    else:
-        _torch_tpu_available = False
-except ImportError:
-    _torch_tpu_available = False
-
-try:
-    import psutil  # noqa: F401
-
-    _psutil_available = True
-
-except ImportError:
-    _psutil_available = False
-
-try:
-    import py3nvml  # noqa: F401
-
-    _py3nvml_available = True
-
-except ImportError:
-    _py3nvml_available = False
-
-try:
-    from apex import amp  # noqa: F401
-
-    _has_apex = True
-except ImportError:
-    _has_apex = False
-
-try:
-    import faiss  # noqa: F401
-
-    _faiss_available = True
-    logger.debug(f"Successfully imported faiss version {faiss.__version__}")
-except ImportError:
+_faiss_available = importlib.util.find_spec("faiss") is not None
+try:
+    _faiss_version = importlib_metadata.version("faiss")
+    logger.debug(f"Successfully imported faiss version {_faiss_version}")
+except importlib_metadata.PackageNotFoundError:
     _faiss_available = False

-try:
-    import sklearn.metrics  # noqa: F401
-
-    import scipy.stats  # noqa: F401
-
-    _has_sklearn = True
-except (AttributeError, ImportError):
-    _has_sklearn = False
-
-try:
-    # Test copied from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
-    get_ipython = sys.modules["IPython"].get_ipython
-    if "IPKernelApp" not in get_ipython().config:
-        raise ImportError("console")
-    if "VSCODE_PID" in os.environ:
-        raise ImportError("vscode")
-
-    import IPython  # noqa: F401
-
-    _in_notebook = True
-except (AttributeError, ImportError, KeyError):
-    _in_notebook = False
-
-try:
-    import sentencepiece  # noqa: F401
-
-    _sentencepiece_available = True
-
-except ImportError:
-    _sentencepiece_available = False
-
-try:
-    import google.protobuf  # noqa: F401
-
-    _protobuf_available = True
-
-except ImportError:
-    _protobuf_available = False
-
-try:
-    import tokenizers  # noqa: F401
-
-    _tokenizers_available = True
-
-except ImportError:
-    _tokenizers_available = False
-
-try:
-    import pandas  # noqa: F401
-
-    _pandas_available = True
-
-except ImportError:
-    _pandas_available = False
-
-try:
-    import torch_scatter
-
-    # Check we're not importing a "torch_scatter" directory somewhere
-    _scatter_available = hasattr(torch_scatter, "__version__") and hasattr(torch_scatter, "scatter")
-    if _scatter_available:
-        logger.debug(f"Succesfully imported torch-scatter version {torch_scatter.__version__}")
-    else:
-        logger.debug("Imported a torch_scatter object but this doesn't seem to be the torch-scatter library.")
-
-except ImportError:
+_scatter_available = importlib.util.find_spec("torch_scatter") is not None
+try:
+    _scatter_version = importlib_metadata.version("torch_scatterr")
+    logger.debug(f"Successfully imported torch-scatter version {_scatter_version}")
+except importlib_metadata.PackageNotFoundError:
     _scatter_available = False

+
+torch_cache_home = os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
 old_default_cache_path = os.path.join(torch_cache_home, "transformers")
 # New default cache, shared with the Datasets library
 hf_cache_home = os.path.expanduser(
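The hunk above is the heart of the change: frameworks are now detected with importlib.util.find_spec and their versions are read from package metadata, so nothing heavy is imported at library init time. A condensed sketch of the pattern for a single package (simplified, not a verbatim excerpt of file_utils):

import importlib.util

import importlib_metadata

_torch_available = importlib.util.find_spec("torch") is not None
_torch_version = None
if _torch_available:
    try:
        _torch_version = importlib_metadata.version("torch")
    except importlib_metadata.PackageNotFoundError:
        # A stray "torch" directory on sys.path has no installed metadata, so it is not the real package.
        _torch_available = False


def is_torch_available() -> bool:
    return _torch_available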
@@ -308,7 +204,14 @@ def is_flax_available():


 def is_torch_tpu_available():
-    return _torch_tpu_available
+    if not _torch_available:
+        return False
+    # This test is probably enough, but just in case, we unpack a bit.
+    if importlib.util.find_spec("torch_xla") is None:
+        return False
+    if importlib.util.find_spec("torch_xla.core") is None:
+        return False
+    return importlib.util.find_spec("torch_xla.core.xla_model") is not None


 def is_datasets_available():
@@ -316,15 +219,15 @@ def is_datasets_available():


 def is_psutil_available():
-    return _psutil_available
+    return importlib.util.find_spec("psutil") is not None


 def is_py3nvml_available():
-    return _py3nvml_available
+    return importlib.util.find_spec("py3nvml") is not None


 def is_apex_available():
-    return _has_apex
+    return importlib.util.find_spec("apex") is not None


 def is_faiss_available():
@@ -332,23 +235,39 @@ def is_faiss_available():


 def is_sklearn_available():
-    return _has_sklearn
+    if importlib.util.find_spec("sklearn") is None:
+        return False
+    if importlib.util.find_spec("scipy") is None:
+        return False
+    return importlib.util.find_spec("sklearn.metrics") and importlib.util.find_spec("scipy.stats")


 def is_sentencepiece_available():
-    return _sentencepiece_available
+    return importlib.util.find_spec("sentencepiece") is not None


 def is_protobuf_available():
-    return _protobuf_available
+    if importlib.util.find_spec("google") is None:
+        return False
+    return importlib.util.find_spec("google.protobuf") is not None


 def is_tokenizers_available():
-    return _tokenizers_available
+    return importlib.util.find_spec("tokenizers") is not None


 def is_in_notebook():
-    return _in_notebook
+    try:
+        # Test adapted from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
+        get_ipython = sys.modules["IPython"].get_ipython
+        if "IPKernelApp" not in get_ipython().config:
+            raise ImportError("console")
+        if "VSCODE_PID" in os.environ:
+            raise ImportError("vscode")
+
+        return importlib.util.find_spec("IPython") is not None
+    except (AttributeError, ImportError, KeyError):
+        return False


 def is_scatter_available():
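One detail worth noting in the checks above: importlib.util.find_spec on a dotted name imports the parent package first and raises if the parent is missing, which is why the parent ("sklearn", "scipy", "google") is probed before the submodule. A small illustrative sketch (the helper name is made up):

import importlib.util


def has_submodule(parent: str, dotted: str) -> bool:
    if importlib.util.find_spec(parent) is None:
        # Asking for the dotted name directly would raise ModuleNotFoundError here.
        return False
    return importlib.util.find_spec(dotted) is not None


print(has_submodule("google", "google.protobuf"))
print(has_submodule("sklearn", "sklearn.metrics"))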
@@ -356,7 +275,7 @@ def is_scatter_available():


 def is_pandas_available():
-    return _pandas_available
+    return importlib.util.find_spec("pandas") is not None


 def torch_only_method(fn):
@@ -1167,9 +1086,9 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
     """
     ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
     if is_torch_available():
-        ua += "; torch/{}".format(torch.__version__)
+        ua += f"; torch/{_torch_version}"
     if is_tf_available():
-        ua += "; tensorflow/{}".format(tf.__version__)
+        ua += f"; tensorflow/{_tf_version}"
     if isinstance(user_agent, dict):
         ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
     elif isinstance(user_agent, str):
@@ -14,6 +14,7 @@
 """
 Integrations with other Python libraries.
 """
+import importlib.util
 import math
 import numbers
 import os
@@ -21,18 +22,16 @@ import re
 import tempfile
 from pathlib import Path

-from .file_utils import ENV_VARS_TRUE_VALUES
-from .trainer_utils import EvaluationStrategy
 from .utils import logging


 logger = logging.get_logger(__name__)


-# Import 3rd-party integrations before ML frameworks:
-
-try:
-    # Comet needs to be imported before any ML frameworks
-    import comet_ml  # noqa: F401
+# comet_ml requires to be imported before any ML frameworks
+_has_comet = importlib.util.find_spec("comet_ml") and os.getenv("COMET_MODE", "").upper() != "DISABLED"
+if _has_comet:
+    try:
+        import comet_ml  # noqa: F401

-    if hasattr(comet_ml, "config") and comet_ml.config.get_config("comet.api_key"):
+        if hasattr(comet_ml, "config") and comet_ml.config.get_config("comet.api_key"):
@@ -41,87 +40,20 @@ try:
         if os.getenv("COMET_MODE", "").upper() != "DISABLED":
             logger.warning("comet_ml is installed but `COMET_API_KEY` is not set.")
             _has_comet = False
     except (ImportError, ValueError):
         _has_comet = False

-try:
-    import wandb
-
-    wandb.ensure_configured()
-    if wandb.api.api_key is None:
-        _has_wandb = False
-        if os.getenv("WANDB_DISABLED"):
-            logger.warning("W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.")
-    else:
-        _has_wandb = False if os.getenv("WANDB_DISABLED") else True
-except (ImportError, AttributeError):
-    _has_wandb = False
-
-try:
-    import optuna  # noqa: F401
-
-    _has_optuna = True
-except (ImportError):
-    _has_optuna = False
-
-try:
-    import ray  # noqa: F401
-
-    _has_ray = True
-    try:
-        # Ray Tune has additional dependencies.
-        from ray import tune  # noqa: F401
-
-        _has_ray_tune = True
-    except (ImportError):
-        _has_ray_tune = False
-except (ImportError):
-    _has_ray = False
-    _has_ray_tune = False
-
-try:
-    from torch.utils.tensorboard import SummaryWriter  # noqa: F401
-
-    _has_tensorboard = True
-except ImportError:
-    try:
-        from tensorboardX import SummaryWriter  # noqa: F401
-
-        _has_tensorboard = True
-    except ImportError:
-        _has_tensorboard = False
-
-try:
-    from azureml.core.run import Run  # noqa: F401
-
-    _has_azureml = True
-except ImportError:
-    _has_azureml = False
-
-try:
-    import mlflow  # noqa: F401
-
-    _has_mlflow = True
-except ImportError:
-    _has_mlflow = False
-
-try:
-    import fairscale  # noqa: F401
-
-    _has_fairscale = True
-except ImportError:
-    _has_fairscale = False
-
-# No transformer imports above this point
-
-from .file_utils import is_torch_tpu_available  # noqa: E402
+from .file_utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available  # noqa: E402
 from .trainer_callback import TrainerCallback  # noqa: E402
-from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun  # noqa: E402
+from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, EvaluationStrategy  # noqa: E402


 # Integration functions:
 def is_wandb_available():
-    return _has_wandb
+    if os.getenv("WANDB_DISABLED"):
+        return False
+    return importlib.util.find_spec("wandb") is not None


 def is_comet_available():
@@ -129,35 +61,43 @@ def is_comet_available():


 def is_tensorboard_available():
-    return _has_tensorboard
+    return importlib.util.find_spec("tensorboard") is not None or importlib.util.find_spec("tensorboardX") is not None


 def is_optuna_available():
-    return _has_optuna
+    return importlib.util.find_spec("optuna") is not None


 def is_ray_available():
-    return _has_ray
+    return importlib.util.find_spec("ray") is not None


 def is_ray_tune_available():
-    return _has_ray_tune
+    if not is_ray_available():
+        return False
+    return importlib.util.find_spec("ray.tune") is not None


 def is_azureml_available():
-    return _has_azureml
+    if importlib.util.find_spec("azureml") is None:
+        return False
+    if importlib.util.find_spec("azureml.core") is None:
+        return False
+    return importlib.util.find_spec("azureml.core.run") is not None


 def is_mlflow_available():
-    return _has_mlflow
+    return importlib.util.find_spec("mlflow") is not None


 def is_fairscale_available():
-    return _has_fairscale
+    return importlib.util.find_spec("fairscale") is not None


 def hp_params(trial):
     if is_optuna_available():
+        import optuna
+
         if isinstance(trial, optuna.Trial):
             return trial.params
     if is_ray_tune_available():
@@ -175,6 +115,8 @@ def default_hp_search_backend():


 def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
+    import optuna
+
     def _objective(trial, checkpoint_dir=None):
         model_path = None
         if checkpoint_dir:
@@ -198,6 +140,8 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:


 def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
+    import ray
+
     def _objective(trial, checkpoint_dir=None):
         model_path = None
         if checkpoint_dir:
@@ -297,14 +241,29 @@ class TensorBoardCallback(TrainerCallback):
     """

     def __init__(self, tb_writer=None):
+        has_tensorboard = is_tensorboard_available()
         assert (
-            _has_tensorboard
+            has_tensorboard
         ), "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or install tensorboardX."
+        if has_tensorboard:
+            try:
+                from torch.utils.tensorboard import SummaryWriter  # noqa: F401
+
+                self._SummaryWriter = SummaryWriter
+            except ImportError:
+                try:
+                    from tensorboardX import SummaryWriter
+
+                    self._SummaryWriter = SummaryWriter
+                except ImportError:
+                    self._SummaryWriter = None
         self.tb_writer = tb_writer
+        self._SummaryWriter = SummaryWriter

     def _init_summary_writer(self, args, log_dir=None):
         log_dir = log_dir or args.logging_dir
-        self.tb_writer = SummaryWriter(log_dir=log_dir)
+        if self._SummaryWriter is not None:
+            self.tb_writer = self._SummaryWriter(log_dir=log_dir)

     def on_train_begin(self, args, state, control, **kwargs):
         if not state.is_world_process_zero:
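The callback above now resolves the optional writer once and stores the class, instantiating it only when a writer is actually needed. A condensed sketch of that lazy-handle pattern (class and method names here are illustrative, not the real callback):

class LazyTensorBoardWriter:
    def __init__(self):
        try:
            from torch.utils.tensorboard import SummaryWriter
        except ImportError:
            try:
                from tensorboardX import SummaryWriter
            except ImportError:
                SummaryWriter = None
        self._SummaryWriter = SummaryWriter  # a class object (or None); nothing is instantiated yet
        self.writer = None

    def init_writer(self, log_dir: str) -> None:
        if self._SummaryWriter is not None:
            self.writer = self._SummaryWriter(log_dir=log_dir)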
@@ -335,7 +294,7 @@ class TensorBoardCallback(TrainerCallback):
         if self.tb_writer is None:
             self._init_summary_writer(args)

-        if self.tb_writer:
+        if self.tb_writer is not None:
             logs = rewrite_logs(logs)
             for k, v in logs.items():
                 if isinstance(v, (int, float)):
@@ -363,7 +322,20 @@ class WandbCallback(TrainerCallback):
     """

     def __init__(self):
-        assert _has_wandb, "WandbCallback requires wandb to be installed. Run `pip install wandb`."
+        has_wandb = is_wandb_available()
+        assert has_wandb, "WandbCallback requires wandb to be installed. Run `pip install wandb`."
+        if has_wandb:
+            import wandb
+
+            wandb.ensure_configured()
+            if wandb.api.api_key is None:
+                has_wandb = False
+                logger.warning(
+                    "W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable."
+                )
+            self._wandb = wandb
+        else:
+            self._wandb = None
         self._initialized = False

     def setup(self, args, state, model, reinit, **kwargs):
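The same guard pattern runs through the rest of the callback: the module handle lives on the instance (or is None) and every hook returns early when the integration is unavailable. A simplified sketch (hypothetical class, not the real WandbCallback):

class OptionalLoggerCallback:
    def __init__(self):
        try:
            import wandb  # only attempted when the callback is constructed

            self._wandb = wandb
        except ImportError:
            self._wandb = None

    def on_log(self, logs: dict, step: int) -> None:
        if self._wandb is None:
            return  # silently a no-op when wandb is not installed
        self._wandb.log(logs, step=step)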
@@ -384,6 +356,8 @@ class WandbCallback(TrainerCallback):
             WANDB_DISABLED (:obj:`bool`, `optional`, defaults to :obj:`False`):
                 Whether or not to disable wandb entirely.
         """
+        if self._wandb is None:
+            return
         self._initialized = True
         if state.is_world_process_zero:
             logger.info(
|
|||||||
else:
|
else:
|
||||||
run_name = args.run_name
|
run_name = args.run_name
|
||||||
|
|
||||||
wandb.init(
|
self._wandb.init(
|
||||||
project=os.getenv("WANDB_PROJECT", "huggingface"),
|
project=os.getenv("WANDB_PROJECT", "huggingface"),
|
||||||
config=combined_dict,
|
config=combined_dict,
|
||||||
name=run_name,
|
name=run_name,
|
||||||
@ -412,19 +386,25 @@ class WandbCallback(TrainerCallback):
|
|||||||
|
|
||||||
# keep track of model topology and gradients, unsupported on TPU
|
# keep track of model topology and gradients, unsupported on TPU
|
||||||
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
|
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
|
||||||
wandb.watch(model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps))
|
self._wandb.watch(
|
||||||
|
model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps)
|
||||||
|
)
|
||||||
|
|
||||||
# log outputs
|
# log outputs
|
||||||
self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})
|
self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})
|
||||||
|
|
||||||
def on_train_begin(self, args, state, control, model=None, **kwargs):
|
def on_train_begin(self, args, state, control, model=None, **kwargs):
|
||||||
|
if self._wandb is None:
|
||||||
|
return
|
||||||
hp_search = state.is_hyper_param_search
|
hp_search = state.is_hyper_param_search
|
||||||
if not self._initialized or hp_search:
|
if not self._initialized or hp_search:
|
||||||
self.setup(args, state, model, reinit=hp_search, **kwargs)
|
self.setup(args, state, model, reinit=hp_search, **kwargs)
|
||||||
|
|
||||||
def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
|
def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
|
||||||
|
if self._wandb is None:
|
||||||
|
return
|
||||||
# commit last step
|
# commit last step
|
||||||
wandb.log({})
|
self._wandb.log({})
|
||||||
if self._log_model and self._initialized and state.is_world_process_zero:
|
if self._log_model and self._initialized and state.is_world_process_zero:
|
||||||
from .trainer import Trainer
|
from .trainer import Trainer
|
||||||
|
|
||||||
@ -432,11 +412,11 @@ class WandbCallback(TrainerCallback):
|
|||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
fake_trainer.save_model(temp_dir)
|
fake_trainer.save_model(temp_dir)
|
||||||
# use run name and ensure it's a valid Artifact name
|
# use run name and ensure it's a valid Artifact name
|
||||||
artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", wandb.run.name)
|
artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", self._wandb.run.name)
|
||||||
metadata = (
|
metadata = (
|
||||||
{
|
{
|
||||||
k: v
|
k: v
|
||||||
for k, v in dict(wandb.summary).items()
|
for k, v in dict(self._wandb.summary).items()
|
||||||
if isinstance(v, numbers.Number) and not k.startswith("_")
|
if isinstance(v, numbers.Number) and not k.startswith("_")
|
||||||
}
|
}
|
||||||
if not args.load_best_model_at_end
|
if not args.load_best_model_at_end
|
||||||
@@ -445,19 +425,21 @@ class WandbCallback(TrainerCallback):
                         "train/total_floss": state.total_flos,
                     }
                 )
-                artifact = wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata)
+                artifact = self._wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata)
                 for f in Path(temp_dir).glob("*"):
                     if f.is_file():
                         with artifact.new_file(f.name, mode="wb") as fa:
                             fa.write(f.read_bytes())
-                wandb.run.log_artifact(artifact)
+                self._wandb.run.log_artifact(artifact)

     def on_log(self, args, state, control, model=None, logs=None, **kwargs):
+        if self._wandb is None:
+            return
         if not self._initialized:
             self.setup(args, state, model, reinit=False)
         if state.is_world_process_zero:
             logs = rewrite_logs(logs)
-            wandb.log(logs, step=state.global_step)
+            self._wandb.log(logs, step=state.global_step)


 class CometCallback(TrainerCallback):
@@ -522,10 +504,14 @@ class AzureMLCallback(TrainerCallback):
     """

     def __init__(self, azureml_run=None):
-        assert _has_azureml, "AzureMLCallback requires azureml to be installed. Run `pip install azureml-sdk`."
+        assert (
+            is_azureml_available()
+        ), "AzureMLCallback requires azureml to be installed. Run `pip install azureml-sdk`."
         self.azureml_run = azureml_run

     def on_init_end(self, args, state, control, **kwargs):
+        from azureml.core.run import Run
+
         if self.azureml_run is None and state.is_world_process_zero:
             self.azureml_run = Run.get_context()

@@ -544,9 +530,12 @@ class MLflowCallback(TrainerCallback):
     MAX_LOG_SIZE = 100

     def __init__(self):
-        assert _has_mlflow, "MLflowCallback requires mlflow to be installed. Run `pip install mlflow`."
+        assert is_mlflow_available(), "MLflowCallback requires mlflow to be installed. Run `pip install mlflow`."
+        import mlflow
+
         self._initialized = False
         self._log_artifacts = False
+        self._ml_flow = mlflow

     def setup(self, args, state, model):
         """
|
|||||||
if log_artifacts in {"TRUE", "1"}:
|
if log_artifacts in {"TRUE", "1"}:
|
||||||
self._log_artifacts = True
|
self._log_artifacts = True
|
||||||
if state.is_world_process_zero:
|
if state.is_world_process_zero:
|
||||||
mlflow.start_run()
|
self._ml_flow.start_run()
|
||||||
combined_dict = args.to_dict()
|
combined_dict = args.to_dict()
|
||||||
if hasattr(model, "config") and model.config is not None:
|
if hasattr(model, "config") and model.config is not None:
|
||||||
model_config = model.config.to_dict()
|
model_config = model.config.to_dict()
|
||||||
@ -572,7 +561,7 @@ class MLflowCallback(TrainerCallback):
|
|||||||
# MLflow cannot log more than 100 values in one go, so we have to split it
|
# MLflow cannot log more than 100 values in one go, so we have to split it
|
||||||
combined_dict_items = list(combined_dict.items())
|
combined_dict_items = list(combined_dict.items())
|
||||||
for i in range(0, len(combined_dict_items), MLflowCallback.MAX_LOG_SIZE):
|
for i in range(0, len(combined_dict_items), MLflowCallback.MAX_LOG_SIZE):
|
||||||
mlflow.log_params(dict(combined_dict_items[i : i + MLflowCallback.MAX_LOG_SIZE]))
|
self._ml_flow.log_params(dict(combined_dict_items[i : i + MLflowCallback.MAX_LOG_SIZE]))
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
|
||||||
def on_train_begin(self, args, state, control, model=None, **kwargs):
|
def on_train_begin(self, args, state, control, model=None, **kwargs):
|
||||||
@ -585,7 +574,7 @@ class MLflowCallback(TrainerCallback):
|
|||||||
if state.is_world_process_zero:
|
if state.is_world_process_zero:
|
||||||
for k, v in logs.items():
|
for k, v in logs.items():
|
||||||
if isinstance(v, (int, float)):
|
if isinstance(v, (int, float)):
|
||||||
mlflow.log_metric(k, v, step=state.global_step)
|
self._ml_flow.log_metric(k, v, step=state.global_step)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Trainer is attempting to log a value of "
|
"Trainer is attempting to log a value of "
|
||||||
@ -601,11 +590,11 @@ class MLflowCallback(TrainerCallback):
|
|||||||
if self._initialized and state.is_world_process_zero:
|
if self._initialized and state.is_world_process_zero:
|
||||||
if self._log_artifacts:
|
if self._log_artifacts:
|
||||||
logger.info("Logging artifacts. This may take time.")
|
logger.info("Logging artifacts. This may take time.")
|
||||||
mlflow.log_artifacts(args.output_dir)
|
self._ml_flow.log_artifacts(args.output_dir)
|
||||||
mlflow.end_run()
|
self._ml_flow.end_run()
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
# if the previous run is not terminated correctly, the fluent API will
|
# if the previous run is not terminated correctly, the fluent API will
|
||||||
# not let you start a new run before the previous one is killed
|
# not let you start a new run before the previous one is killed
|
||||||
if mlflow.active_run is not None:
|
if self._ml_flow.active_run is not None:
|
||||||
mlflow.end_run(status="KILLED")
|
self._ml_flow.end_run(status="KILLED")
|
||||||
|
@@ -25,18 +25,18 @@ from io import StringIO
 from pathlib import Path

 from .file_utils import (
-    _datasets_available,
-    _faiss_available,
-    _flax_available,
-    _pandas_available,
-    _scatter_available,
-    _sentencepiece_available,
-    _tf_available,
-    _tokenizers_available,
-    _torch_available,
-    _torch_tpu_available,
+    is_datasets_available,
+    is_faiss_available,
+    is_flax_available,
+    is_pandas_available,
+    is_scatter_available,
+    is_sentencepiece_available,
+    is_tf_available,
+    is_tokenizers_available,
+    is_torch_available,
+    is_torch_tpu_available,
 )
-from .integrations import _has_optuna, _has_ray
+from .integrations import is_optuna_available, is_ray_available


 SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
@@ -90,7 +90,7 @@ def is_pt_tf_cross_test(test_case):
     to a truthy value and selecting the is_pt_tf_cross_test pytest mark.

     """
-    if not _run_pt_tf_cross_tests or not _torch_available or not _tf_available:
+    if not _run_pt_tf_cross_tests or not is_torch_available() or not is_tf_available():
         return unittest.skip("test is PT+TF test")(test_case)
     else:
         try:

@@ -166,7 +166,7 @@ def require_torch(test_case):
     These tests are skipped when PyTorch isn't installed.

     """
-    if not _torch_available:
+    if not is_torch_available():
         return unittest.skip("test requires PyTorch")(test_case)
     else:
         return test_case

@@ -179,7 +179,7 @@ def require_torch_scatter(test_case):
     These tests are skipped when PyTorch scatter isn't installed.

     """
-    if not _scatter_available:
+    if not is_scatter_available():
         return unittest.skip("test requires PyTorch scatter")(test_case)
     else:
         return test_case

@@ -192,7 +192,7 @@ def require_tf(test_case):
     These tests are skipped when TensorFlow isn't installed.

     """
-    if not _tf_available:
+    if not is_tf_available():
         return unittest.skip("test requires TensorFlow")(test_case)
     else:
         return test_case

@@ -205,7 +205,7 @@ def require_flax(test_case):
     These tests are skipped when one / both are not installed

     """
-    if not _flax_available:
+    if not is_flax_available():
         test_case = unittest.skip("test requires JAX & Flax")(test_case)
     return test_case

@@ -217,7 +217,7 @@ def require_sentencepiece(test_case):
     These tests are skipped when SentencePiece isn't installed.

     """
-    if not _sentencepiece_available:
+    if not is_sentencepiece_available():
         return unittest.skip("test requires SentencePiece")(test_case)
     else:
         return test_case

@@ -230,7 +230,7 @@ def require_tokenizers(test_case):
     These tests are skipped when 🤗 Tokenizers isn't installed.

     """
-    if not _tokenizers_available:
+    if not is_tokenizers_available():
         return unittest.skip("test requires tokenizers")(test_case)
     else:
         return test_case

@@ -240,7 +240,7 @@ def require_pandas(test_case):
     """
     Decorator marking a test that requires pandas. These tests are skipped when pandas isn't installed.
     """
-    if not _pandas_available:
+    if not is_pandas_available():
         return unittest.skip("test requires pandas")(test_case)
     else:
         return test_case

@@ -251,7 +251,7 @@ def require_scatter(test_case):
     Decorator marking a test that requires PyTorch Scatter. These tests are skipped when PyTorch Scatter isn't
     installed.
     """
-    if not _scatter_available:
+    if not is_scatter_available():
         return unittest.skip("test requires PyTorch Scatter")(test_case)
     else:
         return test_case

@@ -265,7 +265,7 @@ def require_torch_multi_gpu(test_case):

     To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests -k "multi_gpu"
     """
-    if not _torch_available:
+    if not is_torch_available():
         return unittest.skip("test requires PyTorch")(test_case)

     import torch

@@ -280,7 +280,7 @@ def require_torch_non_multi_gpu(test_case):
     """
     Decorator marking a test that requires 0 or 1 GPU setup (in PyTorch).
     """
-    if not _torch_available:
+    if not is_torch_available():
         return unittest.skip("test requires PyTorch")(test_case)

     import torch

@@ -301,13 +301,13 @@ def require_torch_tpu(test_case):
     """
     Decorator marking a test that requires a TPU (in PyTorch).
     """
-    if not _torch_tpu_available:
+    if not is_torch_tpu_available():
         return unittest.skip("test requires PyTorch TPU")
     else:
         return test_case


-if _torch_available:
+if is_torch_available():
     # Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
     import torch

@@ -327,7 +327,7 @@ def require_torch_gpu(test_case):
 def require_datasets(test_case):
     """Decorator marking a test that requires datasets."""

-    if not _datasets_available:
+    if not is_datasets_available():
         return unittest.skip("test requires `datasets`")(test_case)
     else:
         return test_case

@@ -335,7 +335,7 @@ def require_datasets(test_case):

 def require_faiss(test_case):
     """Decorator marking a test that requires faiss."""
-    if not _faiss_available:
+    if not is_faiss_available():
         return unittest.skip("test requires `faiss`")(test_case)
     else:
         return test_case

@@ -348,7 +348,7 @@ def require_optuna(test_case):
     These tests are skipped when optuna isn't installed.

     """
-    if not _has_optuna:
+    if not is_optuna_available():
         return unittest.skip("test requires optuna")(test_case)
     else:
         return test_case

@@ -361,7 +361,7 @@ def require_ray(test_case):
     These tests are skipped when Ray/tune isn't installed.

     """
-    if not _has_ray:
+    if not is_ray_available():
         return unittest.skip("test requires Ray/tune")(test_case)
     else:
         return test_case

@@ -371,11 +371,11 @@ def get_gpu_count():
     """
     Return the number of available gpus (regardless of whether torch or tf is used)
     """
-    if _torch_available:
+    if is_torch_available():
         import torch

         return torch.cuda.device_count()
-    elif _tf_available:
+    elif is_tf_available():
         import tensorflow as tf

         return len(tf.config.list_physical_devices("GPU"))
@@ -25,7 +25,7 @@ import shutil
 import time
 import warnings
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union


 # Integrations must be imported before ML frameworks:
@@ -143,12 +143,6 @@ if is_mlflow_available():

     DEFAULT_CALLBACKS.append(MLflowCallback)

-if is_optuna_available():
-    import optuna
-
-if is_ray_tune_available():
-    from ray import tune
-
 if is_azureml_available():
     from .integrations import AzureMLCallback

@@ -159,6 +153,10 @@ if is_fairscale_available():
     from fairscale.optim import OSS
     from fairscale.optim.grad_scaler import ShardedGradScaler


+if TYPE_CHECKING:
+    import optuna
+
 logger = logging.get_logger(__name__)

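The TYPE_CHECKING import above keeps optuna visible to static type checkers while the real import now happens inside the methods that need it. A minimal sketch of the pattern (the function name is illustrative):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import optuna  # evaluated only by type checkers, never at runtime


def report_to_optuna(trial: "optuna.Trial", objective: float, step: int) -> None:
    import optuna  # deferred runtime import, paid only when an Optuna search is running

    trial.report(objective, step)
    if trial.should_prune():
        raise optuna.TrialPruned()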
@@ -611,15 +609,21 @@ class Trainer:
             return
         self.objective = self.compute_objective(metrics.copy())
         if self.hp_search_backend == HPSearchBackend.OPTUNA:
+            import optuna
+
             trial.report(self.objective, epoch)
             if trial.should_prune():
                 raise optuna.TrialPruned()
         elif self.hp_search_backend == HPSearchBackend.RAY:
+            from ray import tune
+
             if self.state.global_step % self.args.save_steps == 0:
                 self._tune_save_checkpoint()
                 tune.report(objective=self.objective, **metrics)

     def _tune_save_checkpoint(self):
+        from ray import tune
+
         if not self.use_tune_checkpoints:
             return
         with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
|
|||||||
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
||||||
|
|
||||||
if self.hp_search_backend is not None and trial is not None:
|
if self.hp_search_backend is not None and trial is not None:
|
||||||
run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id()
|
if self.hp_search_backend == HPSearchBackend.OPTUNA:
|
||||||
|
run_id = trial.number
|
||||||
|
else:
|
||||||
|
from ray import tune
|
||||||
|
|
||||||
|
run_id = tune.get_trial_id()
|
||||||
run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
|
run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
|
||||||
output_dir = os.path.join(self.args.output_dir, run_name, checkpoint_folder)
|
output_dir = os.path.join(self.args.output_dir, run_name, checkpoint_folder)
|
||||||
else:
|
else:
|
||||||
|
@@ -18,14 +18,15 @@ import os
 import unittest

 from transformers import CamembertTokenizer, CamembertTokenizerFast
-from transformers.testing_utils import _torch_available, require_sentencepiece, require_tokenizers
+from transformers.file_utils import is_torch_available
+from transformers.testing_utils import require_sentencepiece, require_tokenizers

 from .test_tokenization_common import TokenizerTesterMixin


 SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")

-FRAMEWORK = "pt" if _torch_available else "tf"
+FRAMEWORK = "pt" if is_torch_available() else "tf"


 @require_sentencepiece
@@ -21,10 +21,11 @@ from pathlib import Path
 from shutil import copyfile

 from transformers import BatchEncoding, MarianTokenizer
-from transformers.testing_utils import _sentencepiece_available, _torch_available, require_sentencepiece
+from transformers.file_utils import is_sentencepiece_available, is_torch_available
+from transformers.testing_utils import require_sentencepiece


-if _sentencepiece_available:
+if is_sentencepiece_available():
     from transformers.models.marian.tokenization_marian import save_json, vocab_files_names

 from .test_tokenization_common import TokenizerTesterMixin

@@ -35,7 +36,7 @@ SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/t
 mock_tokenizer_config = {"target_lang": "fi", "source_lang": "en"}
 zh_code = ">>zh<<"
 ORG_NAME = "Helsinki-NLP/"
-FRAMEWORK = "pt" if _torch_available else "tf"
+FRAMEWORK = "pt" if is_torch_available() else "tf"


 @require_sentencepiece
@@ -16,17 +16,13 @@ import tempfile
 import unittest

 from transformers import SPIECE_UNDERLINE, BatchEncoding, MBartTokenizer, MBartTokenizerFast, is_torch_available
-from transformers.testing_utils import (
-    _sentencepiece_available,
-    require_sentencepiece,
-    require_tokenizers,
-    require_torch,
-)
+from transformers.file_utils import is_sentencepiece_available
+from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch

 from .test_tokenization_common import TokenizerTesterMixin


-if _sentencepiece_available:
+if is_sentencepiece_available():
     from .test_tokenization_xlm_roberta import SAMPLE_VOCAB

@@ -17,15 +17,15 @@
 import unittest

 from transformers import SPIECE_UNDERLINE, BatchEncoding, T5Tokenizer, T5TokenizerFast
-from transformers.file_utils import cached_property
-from transformers.testing_utils import _torch_available, get_tests_dir, require_sentencepiece, require_tokenizers
+from transformers.file_utils import cached_property, is_torch_available
+from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers

 from .test_tokenization_common import TokenizerTesterMixin


 SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")

-FRAMEWORK = "pt" if _torch_available else "tf"
+FRAMEWORK = "pt" if is_torch_available() else "tf"


 @require_sentencepiece