Fast transformers import part 1 (#9441)

* Don't import libs to check they are available

* Don't import integrations at init

* Add importlib_metadata to deps

* Remove old vars references

* Avoid syntax error

* Adapt testing utils

* Try to appease torchhub

* Add dependency

* Remove more private variables

* Fix typo

* Another typo

* Refine the tf availability test
Sylvain Gugger 2021-01-06 12:17:24 -05:00 committed by GitHub
parent c89f1bc92e
commit 0c96262f7d
13 changed files with 280 additions and 360 deletions
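
The pattern behind every change below: instead of importing a library (and paying its full import cost) just to learn whether it is installed, the new code asks importlib whether the package could be imported and reads its version from the package metadata. A minimal sketch of that pattern, using torch as the example; importlib_metadata is the third-party backport of the stdlib importlib.metadata (added in Python 3.8):

import importlib.util

import importlib_metadata  # on Python 3.8+, importlib.metadata works the same way

# Cheap: scans sys.path for a "torch" package without executing it.
_torch_available = importlib.util.find_spec("torch") is not None

if _torch_available:
    try:
        # Also cheap: reads the installed distribution's metadata, no import.
        _torch_version = importlib_metadata.version("torch")
    except importlib_metadata.PackageNotFoundError:
        # A "torch" directory was on sys.path but no distribution is installed.
        _torch_available = False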

.github/workflows/github-torch-hub.yml

@@ -33,7 +33,7 @@ jobs:
         run: |
           pip install --upgrade pip
           pip install torch
-          pip install numpy filelock protobuf requests tqdm regex sentencepiece sacremoses tokenizers packaging
+          pip install numpy filelock protobuf requests tqdm regex sentencepiece sacremoses tokenizers packaging importlib_metadata
       - name: Torch hub list
         run: |

hubconf.py

@@ -30,7 +30,7 @@ from transformers import (
 )

-dependencies = ["torch", "numpy", "tokenizers", "filelock", "requests", "tqdm", "regex", "sentencepiece", "sacremoses"]
+dependencies = ["torch", "numpy", "tokenizers", "filelock", "requests", "tqdm", "regex", "sentencepiece", "sacremoses", "importlib_metadata"]

 @add_start_docstrings(AutoConfig.__doc__)

setup.py

@@ -99,6 +99,7 @@ _deps = [
     "flake8>=3.8.3",
     "flax>=0.2.2",
     "fugashi>=1.0",
+    "importlib_metadata",
     "ipadic>=1.0.0,<2.0",
     "isort>=5.5.4",
     "jax>=0.2.0",
@@ -232,6 +233,7 @@ extras["dev"] = (
 # when modifying the following list, make sure to update src/transformers/dependency_versions_check.py
 install_requires = [
     deps["dataclasses"] + ";python_version<'3.7'",  # dataclasses for Python versions that don't have it
+    deps["importlib_metadata"] + ";python_version<'3.8'",  # importlib_metadata for Python versions that don't have it
     deps["filelock"],  # filesystem locks, e.g., to prevent parallel downloads
     deps["numpy"],
     deps["packaging"],  # utilities from PyPA to e.g., compare versions

src/transformers/dependency_versions_check.py

@@ -26,6 +26,8 @@ from .utils.versions import require_version_core
 pkgs_to_check_at_runtime = "python tqdm regex sacremoses requests packaging filelock numpy tokenizers".split()
 if sys.version_info < (3, 7):
     pkgs_to_check_at_runtime.append("dataclasses")
+if sys.version_info < (3, 8):
+    pkgs_to_check_at_runtime.append("importlib_metadata")

 for pkg in pkgs_to_check_at_runtime:
     if pkg in deps:

src/transformers/dependency_versions_table.py

@@ -12,6 +12,7 @@ deps = {
     "flake8": "flake8>=3.8.3",
     "flax": "flax>=0.2.2",
     "fugashi": "fugashi>=1.0",
+    "importlib_metadata": "importlib_metadata",
     "ipadic": "ipadic>=1.0.0,<2.0",
     "isort": "isort>=5.5.4",
     "jax": "jax>=0.2.0",

src/transformers/file_utils.py

@@ -18,6 +18,7 @@ https://github.com/allenai/allennlp.
 import copy
 import fnmatch
+import importlib.util
 import io
 import json
 import os
@@ -37,8 +38,10 @@ from urllib.parse import urlparse
 from zipfile import ZipFile, is_zipfile

 import numpy as np
+from packaging import version
 from tqdm.auto import tqdm

+import importlib_metadata
 import requests
 from filelock import FileLock
@@ -52,195 +55,88 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 ENV_VARS_TRUE_VALUES = {"1", "ON", "YES"}
 ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})

-try:
-    USE_TF = os.environ.get("USE_TF", "AUTO").upper()
-    USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
-    if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
-        import torch
-
-        _torch_available = True  # pylint: disable=invalid-name
-        logger.info("PyTorch version {} available.".format(torch.__version__))
-    else:
-        logger.info("Disabling PyTorch because USE_TF is set")
-        _torch_available = False
-except ImportError:
-    _torch_available = False  # pylint: disable=invalid-name
+USE_TF = os.environ.get("USE_TF", "AUTO").upper()
+USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
+USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
+
+if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
+    _torch_available = importlib.util.find_spec("torch") is not None
+    if _torch_available:
+        try:
+            _torch_version = importlib_metadata.version("torch")
+            logger.info(f"PyTorch version {_torch_version} available.")
+        except importlib_metadata.PackageNotFoundError:
+            _torch_available = False
+else:
+    logger.info("Disabling PyTorch because USE_TF is set")
+    _torch_available = False

-try:
-    USE_TF = os.environ.get("USE_TF", "AUTO").upper()
-    USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
-
-    if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
-        import tensorflow as tf
-
-        assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
-        _tf_available = True  # pylint: disable=invalid-name
-        logger.info("TensorFlow version {} available.".format(tf.__version__))
-    else:
-        logger.info("Disabling Tensorflow because USE_TORCH is set")
-        _tf_available = False
-except (ImportError, AssertionError):
-    _tf_available = False  # pylint: disable=invalid-name
+if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
+    _tf_available = importlib.util.find_spec("tensorflow") is not None
+    if _tf_available:
+        # For the metadata, we have to look for both tensorflow and tensorflow-cpu
+        try:
+            _tf_version = importlib_metadata.version("tensorflow")
+        except importlib_metadata.PackageNotFoundError:
+            try:
+                _tf_version = importlib_metadata.version("tensorflow-cpu")
+            except importlib_metadata.PackageNotFoundError:
+                _tf_version = None
+                _tf_available = False
+    if _tf_available:
+        if version.parse(_tf_version) < version.parse("2"):
+            logger.info(f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum.")
+            _tf_available = False
+        else:
+            logger.info(f"TensorFlow version {_tf_version} available.")
+else:
+    logger.info("Disabling Tensorflow because USE_TORCH is set")
+    _tf_available = False

-try:
-    USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
-
-    if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
-        import flax
-        import jax
-
-        logger.info("JAX version {}, Flax: available".format(jax.__version__))
-        logger.info("Flax available: {}".format(flax))
-        _flax_available = True
-    else:
-        _flax_available = False
-except ImportError:
-    _flax_available = False  # pylint: disable=invalid-name
+if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
+    _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
+    if _flax_available:
+        try:
+            _jax_version = importlib_metadata.version("jax")
+            _flax_version = importlib_metadata.version("flax")
+            logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
+        except importlib_metadata.PackageNotFoundError:
+            _flax_available = False
+else:
+    _flax_available = False

-try:
-    import datasets  # noqa: F401
-
-    # Check we're not importing a "datasets" directory somewhere
-    _datasets_available = hasattr(datasets, "__version__") and hasattr(datasets, "load_dataset")
-    if _datasets_available:
-        logger.debug(f"Successfully imported datasets version {datasets.__version__}")
-    else:
-        logger.debug("Imported a datasets object but this doesn't seem to be the 🤗 datasets library.")
-except ImportError:
-    _datasets_available = False
+_datasets_available = importlib.util.find_spec("datasets") is not None
+try:
+    # Check we're not importing a "datasets" directory somewhere but the actual library by trying to grab the version
+    # AND checking it has an author field in the metadata that is HuggingFace.
+    _ = importlib_metadata.version("datasets")
+    _datasets_metadata = importlib_metadata.metadata("datasets")
+    if _datasets_metadata.get("author", "") != "HuggingFace Inc.":
+        _datasets_available = False
+except importlib_metadata.PackageNotFoundError:
+    _datasets_available = False

-try:
-    from torch.hub import _get_torch_home
-
-    torch_cache_home = _get_torch_home()
-except ImportError:
-    torch_cache_home = os.path.expanduser(
-        os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
-    )
-
-try:
-    import torch_xla.core.xla_model as xm  # noqa: F401
-
-    if _torch_available:
-        _torch_tpu_available = True  # pylint: disable=
-    else:
-        _torch_tpu_available = False
-except ImportError:
-    _torch_tpu_available = False
-
-try:
-    import psutil  # noqa: F401
-
-    _psutil_available = True
-except ImportError:
-    _psutil_available = False
-
-try:
-    import py3nvml  # noqa: F401
-
-    _py3nvml_available = True
-except ImportError:
-    _py3nvml_available = False
-
-try:
-    from apex import amp  # noqa: F401
-
-    _has_apex = True
-except ImportError:
-    _has_apex = False
-
-try:
-    import faiss  # noqa: F401
-
-    _faiss_available = True
-    logger.debug(f"Successfully imported faiss version {faiss.__version__}")
-except ImportError:
-    _faiss_available = False
+_faiss_available = importlib.util.find_spec("faiss") is not None
+try:
+    _faiss_version = importlib_metadata.version("faiss")
+    logger.debug(f"Successfully imported faiss version {_faiss_version}")
+except importlib_metadata.PackageNotFoundError:
+    _faiss_available = False

-try:
-    import sklearn.metrics  # noqa: F401
-
-    import scipy.stats  # noqa: F401
-
-    _has_sklearn = True
-except (AttributeError, ImportError):
-    _has_sklearn = False
-
-try:
-    # Test copied from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
-    get_ipython = sys.modules["IPython"].get_ipython
-
-    if "IPKernelApp" not in get_ipython().config:
-        raise ImportError("console")
-    if "VSCODE_PID" in os.environ:
-        raise ImportError("vscode")
-
-    import IPython  # noqa: F401
-
-    _in_notebook = True
-except (AttributeError, ImportError, KeyError):
-    _in_notebook = False
-
-try:
-    import sentencepiece  # noqa: F401
-
-    _sentencepiece_available = True
-except ImportError:
-    _sentencepiece_available = False
-
-try:
-    import google.protobuf  # noqa: F401
-
-    _protobuf_available = True
-except ImportError:
-    _protobuf_available = False
-
-try:
-    import tokenizers  # noqa: F401
-
-    _tokenizers_available = True
-except ImportError:
-    _tokenizers_available = False
-
-try:
-    import pandas  # noqa: F401
-
-    _pandas_available = True
-except ImportError:
-    _pandas_available = False
-
-try:
-    import torch_scatter
-
-    # Check we're not importing a "torch_scatter" directory somewhere
-    _scatter_available = hasattr(torch_scatter, "__version__") and hasattr(torch_scatter, "scatter")
-    if _scatter_available:
-        logger.debug(f"Succesfully imported torch-scatter version {torch_scatter.__version__}")
-    else:
-        logger.debug("Imported a torch_scatter object but this doesn't seem to be the torch-scatter library.")
-except ImportError:
-    _scatter_available = False
+_scatter_available = importlib.util.find_spec("torch_scatter") is not None
+try:
+    _scatter_version = importlib_metadata.version("torch_scatter")
+    logger.debug(f"Successfully imported torch-scatter version {_scatter_version}")
+except importlib_metadata.PackageNotFoundError:
+    _scatter_available = False

+torch_cache_home = os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
 old_default_cache_path = os.path.join(torch_cache_home, "transformers")
 # New default cache, shared with the Datasets library
 hf_cache_home = os.path.expanduser(
@@ -308,7 +204,14 @@ def is_flax_available():

 def is_torch_tpu_available():
-    return _torch_tpu_available
+    if not _torch_available:
+        return False
+    # This test is probably enough, but just in case, we unpack a bit.
+    if importlib.util.find_spec("torch_xla") is None:
+        return False
+    if importlib.util.find_spec("torch_xla.core") is None:
+        return False
+    return importlib.util.find_spec("torch_xla.core.xla_model") is not None


 def is_datasets_available():
@@ -316,15 +219,15 @@ def is_datasets_available():

 def is_psutil_available():
-    return _psutil_available
+    return importlib.util.find_spec("psutil") is not None


 def is_py3nvml_available():
-    return _py3nvml_available
+    return importlib.util.find_spec("py3nvml") is not None


 def is_apex_available():
-    return _has_apex
+    return importlib.util.find_spec("apex") is not None


 def is_faiss_available():
@@ -332,23 +235,39 @@ def is_faiss_available():

 def is_sklearn_available():
-    return _has_sklearn
+    if importlib.util.find_spec("sklearn") is None:
+        return False
+    if importlib.util.find_spec("scipy") is None:
+        return False
+    return importlib.util.find_spec("sklearn.metrics") and importlib.util.find_spec("scipy.stats")


 def is_sentencepiece_available():
-    return _sentencepiece_available
+    return importlib.util.find_spec("sentencepiece") is not None


 def is_protobuf_available():
-    return _protobuf_available
+    if importlib.util.find_spec("google") is None:
+        return False
+    return importlib.util.find_spec("google.protobuf") is not None


 def is_tokenizers_available():
-    return _tokenizers_available
+    return importlib.util.find_spec("tokenizers") is not None


 def is_in_notebook():
-    return _in_notebook
+    try:
+        # Test adapted from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
+        get_ipython = sys.modules["IPython"].get_ipython
+
+        if "IPKernelApp" not in get_ipython().config:
+            raise ImportError("console")
+        if "VSCODE_PID" in os.environ:
+            raise ImportError("vscode")
+
+        return importlib.util.find_spec("IPython") is not None
+    except (AttributeError, ImportError, KeyError):
+        return False


 def is_scatter_available():
@@ -356,7 +275,7 @@ def is_scatter_available():

 def is_pandas_available():
-    return _pandas_available
+    return importlib.util.find_spec("pandas") is not None


 def torch_only_method(fn):
@@ -1167,9 +1086,9 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
     """
     ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
     if is_torch_available():
-        ua += "; torch/{}".format(torch.__version__)
+        ua += f"; torch/{_torch_version}"
     if is_tf_available():
-        ua += "; tensorflow/{}".format(tf.__version__)
+        ua += f"; tensorflow/{_tf_version}"
     if isinstance(user_agent, dict):
         ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
     elif isinstance(user_agent, str):
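
The two-step checks above (is_protobuf_available, is_torch_tpu_available, is_azureml_available below) exist because find_spec on a dotted name imports the parent package first, and that import raises ModuleNotFoundError when the parent is missing. A short sketch of the safe idiom; submodule_available is an illustrative helper, not transformers API:

import importlib.util

def submodule_available(parent: str, dotted: str) -> bool:
    # find_spec("google.protobuf") would import "google" and raise
    # ModuleNotFoundError if it isn't installed, so test the parent first.
    if importlib.util.find_spec(parent) is None:
        return False
    return importlib.util.find_spec(dotted) is not None

print(submodule_available("google", "google.protobuf"))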

src/transformers/integrations.py

@@ -14,6 +14,7 @@
 """
 Integrations with other Python libraries.
 """
+import importlib.util
 import math
 import numbers
 import os
@@ -21,18 +22,16 @@ import re
 import tempfile
 from pathlib import Path

-from .file_utils import ENV_VARS_TRUE_VALUES
-from .trainer_utils import EvaluationStrategy
 from .utils import logging


 logger = logging.get_logger(__name__)


-# Import 3rd-party integrations before ML frameworks:
-
-try:
-    # Comet needs to be imported before any ML frameworks
-    import comet_ml  # noqa: F401
+# comet_ml requires to be imported before any ML frameworks
+_has_comet = importlib.util.find_spec("comet_ml") and os.getenv("COMET_MODE", "").upper() != "DISABLED"
+if _has_comet:
+    try:
+        import comet_ml  # noqa: F401

         if hasattr(comet_ml, "config") and comet_ml.config.get_config("comet.api_key"):
@@ -41,87 +40,20 @@ try:
             if os.getenv("COMET_MODE", "").upper() != "DISABLED":
                 logger.warning("comet_ml is installed but `COMET_API_KEY` is not set.")
             _has_comet = False
     except (ImportError, ValueError):
         _has_comet = False

-try:
-    import wandb
-
-    wandb.ensure_configured()
-    if wandb.api.api_key is None:
-        _has_wandb = False
-        if os.getenv("WANDB_DISABLED"):
-            logger.warning("W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.")
-    else:
-        _has_wandb = False if os.getenv("WANDB_DISABLED") else True
-except (ImportError, AttributeError):
-    _has_wandb = False
-
-try:
-    import optuna  # noqa: F401
-
-    _has_optuna = True
-except (ImportError):
-    _has_optuna = False
-
-try:
-    import ray  # noqa: F401
-
-    _has_ray = True
-
-    try:
-        # Ray Tune has additional dependencies.
-        from ray import tune  # noqa: F401
-
-        _has_ray_tune = True
-    except (ImportError):
-        _has_ray_tune = False
-except (ImportError):
-    _has_ray = False
-    _has_ray_tune = False
-
-try:
-    from torch.utils.tensorboard import SummaryWriter  # noqa: F401
-
-    _has_tensorboard = True
-except ImportError:
-    try:
-        from tensorboardX import SummaryWriter  # noqa: F401
-
-        _has_tensorboard = True
-    except ImportError:
-        _has_tensorboard = False
-
-try:
-    from azureml.core.run import Run  # noqa: F401
-
-    _has_azureml = True
-except ImportError:
-    _has_azureml = False
-
-try:
-    import mlflow  # noqa: F401
-
-    _has_mlflow = True
-except ImportError:
-    _has_mlflow = False
-
-try:
-    import fairscale  # noqa: F401
-
-    _has_fairscale = True
-except ImportError:
-    _has_fairscale = False
-
-# No transformer imports above this point
-
-from .file_utils import is_torch_tpu_available  # noqa: E402
+
+from .file_utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available  # noqa: E402
 from .trainer_callback import TrainerCallback  # noqa: E402
-from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun  # noqa: E402
+from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, EvaluationStrategy  # noqa: E402
 # Integration functions:
 def is_wandb_available():
-    return _has_wandb
+    if os.getenv("WANDB_DISABLED"):
+        return False
+    return importlib.util.find_spec("wandb") is not None


 def is_comet_available():
@@ -129,35 +61,43 @@ def is_comet_available():

 def is_tensorboard_available():
-    return _has_tensorboard
+    return importlib.util.find_spec("tensorboard") is not None or importlib.util.find_spec("tensorboardX") is not None


 def is_optuna_available():
-    return _has_optuna
+    return importlib.util.find_spec("optuna") is not None


 def is_ray_available():
-    return _has_ray
+    return importlib.util.find_spec("ray") is not None


 def is_ray_tune_available():
-    return _has_ray_tune
+    if not is_ray_available():
+        return False
+    return importlib.util.find_spec("ray.tune") is not None


 def is_azureml_available():
-    return _has_azureml
+    if importlib.util.find_spec("azureml") is None:
+        return False
+    if importlib.util.find_spec("azureml.core") is None:
+        return False
+    return importlib.util.find_spec("azureml.core.run") is not None


 def is_mlflow_available():
-    return _has_mlflow
+    return importlib.util.find_spec("mlflow") is not None


 def is_fairscale_available():
-    return _has_fairscale
+    return importlib.util.find_spec("fairscale") is not None


 def hp_params(trial):
     if is_optuna_available():
+        import optuna
+
         if isinstance(trial, optuna.Trial):
             return trial.params
     if is_ray_tune_available():
@@ -175,6 +115,8 @@ def default_hp_search_backend():

 def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
+    import optuna
+
     def _objective(trial, checkpoint_dir=None):
         model_path = None
         if checkpoint_dir:
@@ -198,6 +140,8 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> Be

 def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
+    import ray
+
     def _objective(trial, checkpoint_dir=None):
         model_path = None
         if checkpoint_dir:
@@ -297,14 +241,29 @@ class TensorBoardCallback(TrainerCallback):
     """

     def __init__(self, tb_writer=None):
+        has_tensorboard = is_tensorboard_available()
         assert (
-            _has_tensorboard
+            has_tensorboard
         ), "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or install tensorboardX."
+        if has_tensorboard:
+            try:
+                from torch.utils.tensorboard import SummaryWriter  # noqa: F401
+
+                self._SummaryWriter = SummaryWriter
+            except ImportError:
+                try:
+                    from tensorboardX import SummaryWriter
+
+                    self._SummaryWriter = SummaryWriter
+                except ImportError:
+                    self._SummaryWriter = None
         self.tb_writer = tb_writer
-        self._SummaryWriter = SummaryWriter

     def _init_summary_writer(self, args, log_dir=None):
         log_dir = log_dir or args.logging_dir
-        self.tb_writer = SummaryWriter(log_dir=log_dir)
+        if self._SummaryWriter is not None:
+            self.tb_writer = self._SummaryWriter(log_dir=log_dir)

     def on_train_begin(self, args, state, control, **kwargs):
         if not state.is_world_process_zero:
@@ -335,7 +294,7 @@ class TensorBoardCallback(TrainerCallback):
         if self.tb_writer is None:
             self._init_summary_writer(args)

-        if self.tb_writer:
+        if self.tb_writer is not None:
             logs = rewrite_logs(logs)
             for k, v in logs.items():
                 if isinstance(v, (int, float)):
@@ -363,7 +322,20 @@ class WandbCallback(TrainerCallback):
     """

     def __init__(self):
-        assert _has_wandb, "WandbCallback requires wandb to be installed. Run `pip install wandb`."
+        has_wandb = is_wandb_available()
+        assert has_wandb, "WandbCallback requires wandb to be installed. Run `pip install wandb`."
+        if has_wandb:
+            import wandb
+
+            wandb.ensure_configured()
+            if wandb.api.api_key is None:
+                has_wandb = False
+                logger.warning(
+                    "W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable."
+                )
+            self._wandb = wandb
+        else:
+            self._wandb = None
         self._initialized = False

     def setup(self, args, state, model, reinit, **kwargs):
@@ -384,6 +356,8 @@ class WandbCallback(TrainerCallback):
             WANDB_DISABLED (:obj:`bool`, `optional`, defaults to :obj:`False`):
                 Whether or not to disable wandb entirely.
         """
+        if self._wandb is None:
+            return
         self._initialized = True
         if state.is_world_process_zero:
             logger.info(
@@ -402,7 +376,7 @@ class WandbCallback(TrainerCallback):
             else:
                 run_name = args.run_name

-            wandb.init(
+            self._wandb.init(
                 project=os.getenv("WANDB_PROJECT", "huggingface"),
                 config=combined_dict,
                 name=run_name,
@@ -412,19 +386,25 @@ class WandbCallback(TrainerCallback):

             # keep track of model topology and gradients, unsupported on TPU
             if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
-                wandb.watch(model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps))
+                self._wandb.watch(
+                    model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps)
+                )

             # log outputs
             self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})

     def on_train_begin(self, args, state, control, model=None, **kwargs):
+        if self._wandb is None:
+            return
         hp_search = state.is_hyper_param_search
         if not self._initialized or hp_search:
             self.setup(args, state, model, reinit=hp_search, **kwargs)

     def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
+        if self._wandb is None:
+            return
         # commit last step
-        wandb.log({})
+        self._wandb.log({})
         if self._log_model and self._initialized and state.is_world_process_zero:
             from .trainer import Trainer
@@ -432,11 +412,11 @@ class WandbCallback(TrainerCallback):
             with tempfile.TemporaryDirectory() as temp_dir:
                 fake_trainer.save_model(temp_dir)
                 # use run name and ensure it's a valid Artifact name
-                artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", wandb.run.name)
+                artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", self._wandb.run.name)
                 metadata = (
                     {
                         k: v
-                        for k, v in dict(wandb.summary).items()
+                        for k, v in dict(self._wandb.summary).items()
                         if isinstance(v, numbers.Number) and not k.startswith("_")
                     }
                     if not args.load_best_model_at_end
@@ -445,19 +425,21 @@ class WandbCallback(TrainerCallback):
                         "train/total_floss": state.total_flos,
                     }
                 )
-                artifact = wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata)
+                artifact = self._wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata)
                 for f in Path(temp_dir).glob("*"):
                     if f.is_file():
                         with artifact.new_file(f.name, mode="wb") as fa:
                             fa.write(f.read_bytes())
-                wandb.run.log_artifact(artifact)
+                self._wandb.run.log_artifact(artifact)

     def on_log(self, args, state, control, model=None, logs=None, **kwargs):
+        if self._wandb is None:
+            return
         if not self._initialized:
             self.setup(args, state, model, reinit=False)
         if state.is_world_process_zero:
             logs = rewrite_logs(logs)
-            wandb.log(logs, step=state.global_step)
+            self._wandb.log(logs, step=state.global_step)
 class CometCallback(TrainerCallback):
@@ -522,10 +504,14 @@ class AzureMLCallback(TrainerCallback):
     """

     def __init__(self, azureml_run=None):
-        assert _has_azureml, "AzureMLCallback requires azureml to be installed. Run `pip install azureml-sdk`."
+        assert (
+            is_azureml_available()
+        ), "AzureMLCallback requires azureml to be installed. Run `pip install azureml-sdk`."
         self.azureml_run = azureml_run

     def on_init_end(self, args, state, control, **kwargs):
+        from azureml.core.run import Run
+
         if self.azureml_run is None and state.is_world_process_zero:
             self.azureml_run = Run.get_context()
@@ -544,9 +530,12 @@ class MLflowCallback(TrainerCallback):
     MAX_LOG_SIZE = 100

     def __init__(self):
-        assert _has_mlflow, "MLflowCallback requires mlflow to be installed. Run `pip install mlflow`."
+        assert is_mlflow_available(), "MLflowCallback requires mlflow to be installed. Run `pip install mlflow`."
+        import mlflow
+
         self._initialized = False
         self._log_artifacts = False
+        self._ml_flow = mlflow

     def setup(self, args, state, model):
         """
@@ -564,7 +553,7 @@ class MLflowCallback(TrainerCallback):
             if log_artifacts in {"TRUE", "1"}:
                 self._log_artifacts = True
         if state.is_world_process_zero:
-            mlflow.start_run()
+            self._ml_flow.start_run()
             combined_dict = args.to_dict()
             if hasattr(model, "config") and model.config is not None:
                 model_config = model.config.to_dict()
@@ -572,7 +561,7 @@ class MLflowCallback(TrainerCallback):
             # MLflow cannot log more than 100 values in one go, so we have to split it
             combined_dict_items = list(combined_dict.items())
             for i in range(0, len(combined_dict_items), MLflowCallback.MAX_LOG_SIZE):
-                mlflow.log_params(dict(combined_dict_items[i : i + MLflowCallback.MAX_LOG_SIZE]))
+                self._ml_flow.log_params(dict(combined_dict_items[i : i + MLflowCallback.MAX_LOG_SIZE]))
         self._initialized = True

     def on_train_begin(self, args, state, control, model=None, **kwargs):
@@ -585,7 +574,7 @@ class MLflowCallback(TrainerCallback):
         if state.is_world_process_zero:
             for k, v in logs.items():
                 if isinstance(v, (int, float)):
-                    mlflow.log_metric(k, v, step=state.global_step)
+                    self._ml_flow.log_metric(k, v, step=state.global_step)
                 else:
                     logger.warning(
                         "Trainer is attempting to log a value of "
@@ -601,11 +590,11 @@ class MLflowCallback(TrainerCallback):
         if self._initialized and state.is_world_process_zero:
             if self._log_artifacts:
                 logger.info("Logging artifacts. This may take time.")
-                mlflow.log_artifacts(args.output_dir)
-        mlflow.end_run()
+                self._ml_flow.log_artifacts(args.output_dir)
+        self._ml_flow.end_run()

     def __del__(self):
         # if the previous run is not terminated correctly, the fluent API will
         # not let you start a new run before the previous one is killed
-        if mlflow.active_run is not None:
-            mlflow.end_run(status="KILLED")
+        if self._ml_flow.active_run is not None:
+            self._ml_flow.end_run(status="KILLED")
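
The recurring refactor in these callbacks: the integration module is imported once in __init__ (only if it is actually installed and configured) and stashed on the instance, and every hook then guards on the stored handle instead of a module-level global. A minimal sketch of that pattern; LazyLogger and its hook are illustrative names, not transformers API:

import importlib.util


class LazyLogger:
    """Sketch: hold the integration module as an instance handle."""

    def __init__(self):
        if importlib.util.find_spec("wandb") is not None:
            import wandb  # imported only when the package is really installed

            self._wandb = wandb
        else:
            self._wandb = None

    def on_log(self, logs):
        if self._wandb is None:
            return  # every hook degrades to a no-op without the package
        self._wandb.log(logs)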

src/transformers/testing_utils.py

@@ -25,18 +25,18 @@ from io import StringIO
 from pathlib import Path

 from .file_utils import (
-    _datasets_available,
-    _faiss_available,
-    _flax_available,
-    _pandas_available,
-    _scatter_available,
-    _sentencepiece_available,
-    _tf_available,
-    _tokenizers_available,
-    _torch_available,
-    _torch_tpu_available,
+    is_datasets_available,
+    is_faiss_available,
+    is_flax_available,
+    is_pandas_available,
+    is_scatter_available,
+    is_sentencepiece_available,
+    is_tf_available,
+    is_tokenizers_available,
+    is_torch_available,
+    is_torch_tpu_available,
 )
-from .integrations import _has_optuna, _has_ray
+from .integrations import is_optuna_available, is_ray_available


 SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy" SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
@@ -90,7 +90,7 @@ def is_pt_tf_cross_test(test_case):
     to a truthy value and selecting the is_pt_tf_cross_test pytest mark.

     """
-    if not _run_pt_tf_cross_tests or not _torch_available or not _tf_available:
+    if not _run_pt_tf_cross_tests or not is_torch_available() or not is_tf_available():
         return unittest.skip("test is PT+TF test")(test_case)
     else:
         try:
@@ -166,7 +166,7 @@ def require_torch(test_case):
     These tests are skipped when PyTorch isn't installed.

     """
-    if not _torch_available:
+    if not is_torch_available():
         return unittest.skip("test requires PyTorch")(test_case)
     else:
         return test_case
@@ -179,7 +179,7 @@ def require_torch_scatter(test_case):
     These tests are skipped when PyTorch scatter isn't installed.

     """
-    if not _scatter_available:
+    if not is_scatter_available():
         return unittest.skip("test requires PyTorch scatter")(test_case)
     else:
         return test_case
@@ -192,7 +192,7 @@ def require_tf(test_case):
     These tests are skipped when TensorFlow isn't installed.

     """
-    if not _tf_available:
+    if not is_tf_available():
         return unittest.skip("test requires TensorFlow")(test_case)
     else:
         return test_case
@@ -205,7 +205,7 @@ def require_flax(test_case):
     These tests are skipped when one / both are not installed

     """
-    if not _flax_available:
+    if not is_flax_available():
         test_case = unittest.skip("test requires JAX & Flax")(test_case)
     return test_case
@@ -217,7 +217,7 @@ def require_sentencepiece(test_case):
     These tests are skipped when SentencePiece isn't installed.

     """
-    if not _sentencepiece_available:
+    if not is_sentencepiece_available():
         return unittest.skip("test requires SentencePiece")(test_case)
     else:
         return test_case
@@ -230,7 +230,7 @@ def require_tokenizers(test_case):
     These tests are skipped when 🤗 Tokenizers isn't installed.

     """
-    if not _tokenizers_available:
+    if not is_tokenizers_available():
         return unittest.skip("test requires tokenizers")(test_case)
     else:
         return test_case
@@ -240,7 +240,7 @@ def require_pandas(test_case):
     """
     Decorator marking a test that requires pandas. These tests are skipped when pandas isn't installed.
     """
-    if not _pandas_available:
+    if not is_pandas_available():
         return unittest.skip("test requires pandas")(test_case)
     else:
         return test_case
@@ -251,7 +251,7 @@ def require_scatter(test_case):
     Decorator marking a test that requires PyTorch Scatter. These tests are skipped when PyTorch Scatter isn't
     installed.
     """
-    if not _scatter_available:
+    if not is_scatter_available():
         return unittest.skip("test requires PyTorch Scatter")(test_case)
     else:
         return test_case
@@ -265,7 +265,7 @@ def require_torch_multi_gpu(test_case):
     To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests -k "multi_gpu"
     """
-    if not _torch_available:
+    if not is_torch_available():
         return unittest.skip("test requires PyTorch")(test_case)

     import torch
@@ -280,7 +280,7 @@ def require_torch_non_multi_gpu(test_case):
     """
     Decorator marking a test that requires 0 or 1 GPU setup (in PyTorch).
     """
-    if not _torch_available:
+    if not is_torch_available():
         return unittest.skip("test requires PyTorch")(test_case)

     import torch
@@ -301,13 +301,13 @@ def require_torch_tpu(test_case):
     """
     Decorator marking a test that requires a TPU (in PyTorch).
     """
-    if not _torch_tpu_available:
+    if not is_torch_tpu_available():
         return unittest.skip("test requires PyTorch TPU")
     else:
         return test_case


-if _torch_available:
+if is_torch_available():
     # Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
     import torch
@@ -327,7 +327,7 @@ def require_torch_gpu(test_case):

 def require_datasets(test_case):
     """Decorator marking a test that requires datasets."""
-    if not _datasets_available:
+    if not is_datasets_available():
         return unittest.skip("test requires `datasets`")(test_case)
     else:
         return test_case
@@ -335,7 +335,7 @@ def require_datasets(test_case):

 def require_faiss(test_case):
     """Decorator marking a test that requires faiss."""
-    if not _faiss_available:
+    if not is_faiss_available():
         return unittest.skip("test requires `faiss`")(test_case)
     else:
         return test_case
@@ -348,7 +348,7 @@ def require_optuna(test_case):
     These tests are skipped when optuna isn't installed.

     """
-    if not _has_optuna:
+    if not is_optuna_available():
         return unittest.skip("test requires optuna")(test_case)
     else:
         return test_case
@@ -361,7 +361,7 @@ def require_ray(test_case):
     These tests are skipped when Ray/tune isn't installed.

     """
-    if not _has_ray:
+    if not is_ray_available():
         return unittest.skip("test requires Ray/tune")(test_case)
     else:
         return test_case
@@ -371,11 +371,11 @@ def get_gpu_count():
     """
     Return the number of available gpus (regardless of whether torch or tf is used)
     """
-    if _torch_available:
+    if is_torch_available():
         import torch

         return torch.cuda.device_count()
-    elif _tf_available:
+    elif is_tf_available():
         import tensorflow as tf

         return len(tf.config.list_physical_devices("GPU"))
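
Since the private flags no longer exist, test modules gate on the public decorators instead. A short usage sketch with a hypothetical test class (ExampleTest is illustrative; the decorators are real transformers.testing_utils API):

import unittest

from transformers.testing_utils import require_sentencepiece, require_torch


@require_torch
@require_sentencepiece
class ExampleTest(unittest.TestCase):  # hypothetical test class
    def test_something(self):
        import torch  # safe: the decorator guarantees torch is installed

        self.assertEqual(torch.tensor([1]).sum().item(), 1)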

src/transformers/trainer.py

@@ -25,7 +25,7 @@ import shutil
 import time
 import warnings
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

 # Integrations must be imported before ML frameworks:
@@ -143,12 +143,6 @@ if is_mlflow_available():
     DEFAULT_CALLBACKS.append(MLflowCallback)

-if is_optuna_available():
-    import optuna
-
-if is_ray_tune_available():
-    from ray import tune
-
 if is_azureml_available():
     from .integrations import AzureMLCallback
@@ -159,6 +153,10 @@ if is_fairscale_available():
     from fairscale.optim import OSS
     from fairscale.optim.grad_scaler import ShardedGradScaler

+if TYPE_CHECKING:
+    import optuna
+

 logger = logging.get_logger(__name__)
@@ -611,15 +609,21 @@ class Trainer:
             return
         self.objective = self.compute_objective(metrics.copy())
         if self.hp_search_backend == HPSearchBackend.OPTUNA:
+            import optuna
+
             trial.report(self.objective, epoch)
             if trial.should_prune():
                 raise optuna.TrialPruned()
         elif self.hp_search_backend == HPSearchBackend.RAY:
+            from ray import tune
+
             if self.state.global_step % self.args.save_steps == 0:
                 self._tune_save_checkpoint()
             tune.report(objective=self.objective, **metrics)

     def _tune_save_checkpoint(self):
+        from ray import tune
+
         if not self.use_tune_checkpoints:
             return
         with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
@@ -981,7 +985,12 @@ class Trainer:
             checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

             if self.hp_search_backend is not None and trial is not None:
-                run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id()
+                if self.hp_search_backend == HPSearchBackend.OPTUNA:
+                    run_id = trial.number
+                else:
+                    from ray import tune
+
+                    run_id = tune.get_trial_id()
                 run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
                 output_dir = os.path.join(self.args.output_dir, run_name, checkpoint_folder)
             else:
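
Two complementary tricks keep trainer.py importable without optuna or Ray installed: function-scoped imports defer the real module until the hyperparameter-search code path actually runs, and a TYPE_CHECKING import keeps type hints resolvable at zero runtime cost. A minimal self-contained sketch (report_to_optuna is an illustrative name, not Trainer API; the optuna calls themselves are real):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers; never imported at runtime.
    import optuna


def report_to_optuna(trial: "optuna.Trial", objective: float, step: int) -> None:
    import optuna  # deferred: importing this module doesn't require optuna

    trial.report(objective, step)
    if trial.should_prune():
        raise optuna.TrialPruned()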

tests/test_tokenization_camembert.py

@@ -18,14 +18,15 @@ import os
 import unittest

 from transformers import CamembertTokenizer, CamembertTokenizerFast
-from transformers.testing_utils import _torch_available, require_sentencepiece, require_tokenizers
+from transformers.file_utils import is_torch_available
+from transformers.testing_utils import require_sentencepiece, require_tokenizers

 from .test_tokenization_common import TokenizerTesterMixin


 SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")

-FRAMEWORK = "pt" if _torch_available else "tf"
+FRAMEWORK = "pt" if is_torch_available() else "tf"


 @require_sentencepiece

tests/test_tokenization_marian.py

@@ -21,10 +21,11 @@ from pathlib import Path
 from shutil import copyfile

 from transformers import BatchEncoding, MarianTokenizer
-from transformers.testing_utils import _sentencepiece_available, _torch_available, require_sentencepiece
+from transformers.file_utils import is_sentencepiece_available, is_torch_available
+from transformers.testing_utils import require_sentencepiece

-if _sentencepiece_available:
+if is_sentencepiece_available():
     from transformers.models.marian.tokenization_marian import save_json, vocab_files_names

 from .test_tokenization_common import TokenizerTesterMixin
@@ -35,7 +36,7 @@ SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/t
 mock_tokenizer_config = {"target_lang": "fi", "source_lang": "en"}
 zh_code = ">>zh<<"
 ORG_NAME = "Helsinki-NLP/"
-FRAMEWORK = "pt" if _torch_available else "tf"
+FRAMEWORK = "pt" if is_torch_available() else "tf"


 @require_sentencepiece

tests/test_tokenization_mbart.py

@@ -16,17 +16,13 @@ import tempfile
 import unittest

 from transformers import SPIECE_UNDERLINE, BatchEncoding, MBartTokenizer, MBartTokenizerFast, is_torch_available
-from transformers.testing_utils import (
-    _sentencepiece_available,
-    require_sentencepiece,
-    require_tokenizers,
-    require_torch,
-)
+from transformers.file_utils import is_sentencepiece_available
+from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch

 from .test_tokenization_common import TokenizerTesterMixin


-if _sentencepiece_available:
+if is_sentencepiece_available():
     from .test_tokenization_xlm_roberta import SAMPLE_VOCAB

tests/test_tokenization_t5.py

@@ -17,15 +17,15 @@
 import unittest

 from transformers import SPIECE_UNDERLINE, BatchEncoding, T5Tokenizer, T5TokenizerFast
-from transformers.file_utils import cached_property
-from transformers.testing_utils import _torch_available, get_tests_dir, require_sentencepiece, require_tokenizers
+from transformers.file_utils import cached_property, is_torch_available
+from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers

 from .test_tokenization_common import TokenizerTesterMixin


 SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")

-FRAMEWORK = "pt" if _torch_available else "tf"
+FRAMEWORK = "pt" if is_torch_available() else "tf"


 @require_sentencepiece