Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-06 06:10:04 +06:00)
Fast transformers import part 1 (#9441)
* Don't import libs to check they are available
* Don't import integrations at init
* Add importlib_metadata to deps
* Remove old vars references
* Avoid syntax error
* Adapt testing utils
* Try to appease torchhub
* Add dependency
* Remove more private variables
* Fix typo
* Another typo
* Refine the tf availability test
This commit is contained in: parent c89f1bc92e, commit 0c96262f7d
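The theme of the diff below is to stop importing heavy libraries just to find out whether they are installed: a package's presence is detected with importlib.util.find_spec and its version is read from the installed distribution's metadata, so the module itself is never executed at import time. A minimal standalone sketch of that pattern (the probe helper is illustrative, not part of the commit):

    import importlib.util

    import importlib_metadata  # backport of importlib.metadata; on Python 3.8+ importlib.metadata behaves the same


    def probe(pkg_name: str):
        """Report whether a package is installed, and its version, without importing it."""
        if importlib.util.find_spec(pkg_name) is None:
            return False, None
        try:
            # Reads the installed distribution's metadata; the module is never executed.
            return True, importlib_metadata.version(pkg_name)
        except importlib_metadata.PackageNotFoundError:
            return False, None


    print(probe("torch"))  # e.g. (True, "1.7.1") if installed, (False, None) otherwise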
.github/workflows/github-torch-hub.yml (vendored): 2 changes
@@ -33,7 +33,7 @@ jobs:
       run: |
         pip install --upgrade pip
         pip install torch
-        pip install numpy filelock protobuf requests tqdm regex sentencepiece sacremoses tokenizers packaging
+        pip install numpy filelock protobuf requests tqdm regex sentencepiece sacremoses tokenizers packaging importlib_metadata
 
     - name: Torch hub list
       run: |
@@ -30,7 +30,7 @@ from transformers import (
 )
 
 
-dependencies = ["torch", "numpy", "tokenizers", "filelock", "requests", "tqdm", "regex", "sentencepiece", "sacremoses"]
+dependencies = ["torch", "numpy", "tokenizers", "filelock", "requests", "tqdm", "regex", "sentencepiece", "sacremoses", "importlib_metadata"]
 
 
 @add_start_docstrings(AutoConfig.__doc__)
setup.py: 2 changes

@@ -99,6 +99,7 @@ _deps = [
     "flake8>=3.8.3",
     "flax>=0.2.2",
     "fugashi>=1.0",
+    "importlib_metadata",
     "ipadic>=1.0.0,<2.0",
     "isort>=5.5.4",
     "jax>=0.2.0",
@@ -232,6 +233,7 @@ extras["dev"] = (
 # when modifying the following list, make sure to update src/transformers/dependency_versions_check.py
 install_requires = [
     deps["dataclasses"] + ";python_version<'3.7'",  # dataclasses for Python versions that don't have it
+    deps["importlib_metadata"] + ";python_version<'3.8'",  # importlib_metadata for Python versions that don't have it
     deps["filelock"],  # filesystem locks, e.g., to prevent parallel downloads
     deps["numpy"],
     deps["packaging"],  # utilities from PyPA to e.g., compare versions
@@ -26,6 +26,8 @@ from .utils.versions import require_version_core
 pkgs_to_check_at_runtime = "python tqdm regex sacremoses requests packaging filelock numpy tokenizers".split()
 if sys.version_info < (3, 7):
     pkgs_to_check_at_runtime.append("dataclasses")
+if sys.version_info < (3, 8):
+    pkgs_to_check_at_runtime.append("importlib_metadata")
 
 for pkg in pkgs_to_check_at_runtime:
     if pkg in deps:
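The hunk above only appends importlib_metadata to the runtime check list on Python < 3.8; the check itself compares installed versions against the pinned table. A simplified sketch of that kind of runtime check (the pins table and the loop are illustrative, not the transformers implementation):

    import sys

    import importlib_metadata
    from packaging import version

    # Illustrative pinned requirements, in the spirit of the deps table shown below.
    pins = {"tqdm": "4.27", "regex": None}  # None means "any version is fine"

    to_check = list(pins)
    if sys.version_info < (3, 8):
        # The importlib_metadata backport itself only needs checking on old interpreters.
        to_check.append("importlib_metadata")

    for pkg in to_check:
        got = importlib_metadata.version(pkg)  # raises PackageNotFoundError if the package is absent
        want = pins.get(pkg)
        if want is not None and version.parse(got) < version.parse(want):
            raise ImportError(f"{pkg}>={want} is required by this sketch, found {got}")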
@@ -12,6 +12,7 @@ deps = {
     "flake8": "flake8>=3.8.3",
     "flax": "flax>=0.2.2",
     "fugashi": "fugashi>=1.0",
+    "importlib_metadata": "importlib_metadata",
     "ipadic": "ipadic>=1.0.0,<2.0",
     "isort": "isort>=5.5.4",
     "jax": "jax>=0.2.0",
@@ -18,6 +18,7 @@ https://github.com/allenai/allennlp.
 
 import copy
 import fnmatch
+import importlib.util
 import io
 import json
 import os
@@ -37,8 +38,10 @@ from urllib.parse import urlparse
 from zipfile import ZipFile, is_zipfile
 
 import numpy as np
+from packaging import version
 from tqdm.auto import tqdm
 
+import importlib_metadata
 import requests
 from filelock import FileLock
 
@@ -52,195 +55,88 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 ENV_VARS_TRUE_VALUES = {"1", "ON", "YES"}
 ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
 
-try:
-    USE_TF = os.environ.get("USE_TF", "AUTO").upper()
-    USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
-    if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
-        import torch
+USE_TF = os.environ.get("USE_TF", "AUTO").upper()
+USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
+USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
 
-        _torch_available = True  # pylint: disable=invalid-name
-        logger.info("PyTorch version {} available.".format(torch.__version__))
-    else:
+if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
+    _torch_available = importlib.util.find_spec("torch") is not None
+    if _torch_available:
+        try:
+            _torch_version = importlib_metadata.version("torch")
+            logger.info(f"PyTorch version {_torch_version} available.")
+        except importlib_metadata.PackageNotFoundError:
+            _torch_available = False
+else:
+    logger.info("Disabling PyTorch because USE_TF is set")
+    _torch_available = False
-except ImportError:
-    _torch_available = False  # pylint: disable=invalid-name
 
-try:
-    USE_TF = os.environ.get("USE_TF", "AUTO").upper()
-    USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
-
-    if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
-        import tensorflow as tf
-
-        assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
-        _tf_available = True  # pylint: disable=invalid-name
-        logger.info("TensorFlow version {} available.".format(tf.__version__))
+if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
+    _tf_available = importlib.util.find_spec("tensorflow") is not None
+    if _tf_available:
+        # For the metadata, we have to look for both tensorflow and tensorflow-cpu
+        try:
+            _tf_version = importlib_metadata.version("tensorflow")
+        except importlib_metadata.PackageNotFoundError:
+            try:
+                _tf_version = importlib_metadata.version("tensorflow-cpu")
+            except importlib_metadata.PackageNotFoundError:
+                _tf_version = None
+                _tf_available = False
+    if _tf_available:
+        if version.parse(_tf_version) < version.parse("2"):
+            logger.info(f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum.")
+            _tf_available = False
+        else:
+            logger.info(f"TensorFlow version {_tf_version} available.")
+else:
+    logger.info("Disabling Tensorflow because USE_TORCH is set")
+    _tf_available = False
-except (ImportError, AssertionError):
-    _tf_available = False  # pylint: disable=invalid-name
 
-try:
-    USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
-
-    if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
-        import flax
-        import jax
-
-        logger.info("JAX version {}, Flax: available".format(jax.__version__))
-        logger.info("Flax available: {}".format(flax))
-        _flax_available = True
-    else:
+if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
+    _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
+    if _flax_available:
+        try:
+            _jax_version = importlib_metadata.version("jax")
+            _flax_version = importlib_metadata.version("flax")
+            logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
+        except importlib_metadata.PackageNotFoundError:
+            _flax_available = False
+else:
+    _flax_available = False
-except ImportError:
-    _flax_available = False  # pylint: disable=invalid-name
 
+_datasets_available = importlib.util.find_spec("datasets") is not None
 try:
-    import datasets  # noqa: F401
-
-    # Check we're not importing a "datasets" directory somewhere
-    _datasets_available = hasattr(datasets, "__version__") and hasattr(datasets, "load_dataset")
-    if _datasets_available:
-        logger.debug(f"Successfully imported datasets version {datasets.__version__}")
-    else:
-        logger.debug("Imported a datasets object but this doesn't seem to be the 🤗 datasets library.")
-
-except ImportError:
+    # Check we're not importing a "datasets" directory somewhere but the actual library by trying to grab the version
+    # AND checking it has an author field in the metadata that is HuggingFace.
+    _ = importlib_metadata.version("datasets")
+    _datasets_metadata = importlib_metadata.metadata("datasets")
+    if _datasets_metadata.get("author", "") != "HuggingFace Inc.":
+        _datasets_available = False
+except importlib_metadata.PackageNotFoundError:
     _datasets_available = False
 
+_faiss_available = importlib.util.find_spec("faiss") is not None
-try:
-    from torch.hub import _get_torch_home
-
-    torch_cache_home = _get_torch_home()
-except ImportError:
-    torch_cache_home = os.path.expanduser(
-        os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
-    )
-
-try:
-    import torch_xla.core.xla_model as xm  # noqa: F401
-
-    if _torch_available:
-        _torch_tpu_available = True  # pylint: disable=
-    else:
-        _torch_tpu_available = False
-except ImportError:
-    _torch_tpu_available = False
-
-try:
-    import psutil  # noqa: F401
-
-    _psutil_available = True
-
-except ImportError:
-    _psutil_available = False
-
-try:
-    import py3nvml  # noqa: F401
-
-    _py3nvml_available = True
-
-except ImportError:
-    _py3nvml_available = False
-
-try:
-    from apex import amp  # noqa: F401
-
-    _has_apex = True
-except ImportError:
-    _has_apex = False
-
-try:
-    import faiss  # noqa: F401
-
-    _faiss_available = True
-    logger.debug(f"Successfully imported faiss version {faiss.__version__}")
-except ImportError:
+try:
+    _faiss_version = importlib_metadata.version("faiss")
+    logger.debug(f"Successfully imported faiss version {_faiss_version}")
+except importlib_metadata.PackageNotFoundError:
     _faiss_available = False
 
+_scatter_available = importlib.util.find_spec("torch_scatter") is not None
-try:
-    import sklearn.metrics  # noqa: F401
-
-    import scipy.stats  # noqa: F401
-
-    _has_sklearn = True
-except (AttributeError, ImportError):
-    _has_sklearn = False
-
-try:
-    # Test copied from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
-    get_ipython = sys.modules["IPython"].get_ipython
-    if "IPKernelApp" not in get_ipython().config:
-        raise ImportError("console")
-    if "VSCODE_PID" in os.environ:
-        raise ImportError("vscode")
-
-    import IPython  # noqa: F401
-
-    _in_notebook = True
-except (AttributeError, ImportError, KeyError):
-    _in_notebook = False
-
-try:
-    import sentencepiece  # noqa: F401
-
-    _sentencepiece_available = True
-
-except ImportError:
-    _sentencepiece_available = False
-
-try:
-    import google.protobuf  # noqa: F401
-
-    _protobuf_available = True
-
-except ImportError:
-    _protobuf_available = False
-
-try:
-    import tokenizers  # noqa: F401
-
-    _tokenizers_available = True
-
-except ImportError:
-    _tokenizers_available = False
-
-try:
-    import pandas  # noqa: F401
-
-    _pandas_available = True
-
-except ImportError:
-    _pandas_available = False
-
-try:
-    import torch_scatter
-
-    # Check we're not importing a "torch_scatter" directory somewhere
-    _scatter_available = hasattr(torch_scatter, "__version__") and hasattr(torch_scatter, "scatter")
-    if _scatter_available:
-        logger.debug(f"Succesfully imported torch-scatter version {torch_scatter.__version__}")
-    else:
-        logger.debug("Imported a torch_scatter object but this doesn't seem to be the torch-scatter library.")
-
-except ImportError:
+try:
+    _scatter_version = importlib_metadata.version("torch_scatterr")
+    logger.debug(f"Successfully imported torch-scatter version {_scatter_version}")
+except importlib_metadata.PackageNotFoundError:
     _scatter_available = False
 
+torch_cache_home = os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
 old_default_cache_path = os.path.join(torch_cache_home, "transformers")
 # New default cache, shared with the Datasets library
 hf_cache_home = os.path.expanduser(
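The availability flags above follow a two-step probe: find_spec answers "is it importable?" and importlib_metadata.version supplies the version string for logging, with a fallback because the TensorFlow module can be shipped by either the tensorflow or the tensorflow-cpu distribution. A condensed restatement of that probe, using a loop instead of nested try/except:

    import importlib.util

    import importlib_metadata

    _tf_available = importlib.util.find_spec("tensorflow") is not None
    _tf_version = None
    if _tf_available:
        # The module is named "tensorflow" either way, but the installed distribution
        # may be published as "tensorflow" or "tensorflow-cpu", so try both for metadata.
        for dist in ("tensorflow", "tensorflow-cpu"):
            try:
                _tf_version = importlib_metadata.version(dist)
                break
            except importlib_metadata.PackageNotFoundError:
                continue
        if _tf_version is None:
            _tf_available = False

    print(_tf_available, _tf_version)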
@@ -308,7 +204,14 @@ def is_flax_available():
 
 
 def is_torch_tpu_available():
-    return _torch_tpu_available
+    if not _torch_available:
+        return False
+    # This test is probably enough, but just in case, we unpack a bit.
+    if importlib.util.find_spec("torch_xla") is None:
+        return False
+    if importlib.util.find_spec("torch_xla.core") is None:
+        return False
+    return importlib.util.find_spec("torch_xla.core.xla_model") is not None
 
 
 def is_datasets_available():
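find_spec on a dotted name such as torch_xla.core.xla_model imports the parent packages and raises if one of them is missing, which is why the rewrite above checks each level before asking for the full path. A small helper illustrating the same walk (the function name is illustrative, not part of the commit):

    import importlib.util


    def submodule_importable(dotted_name: str) -> bool:
        """find_spec() imports the parents of a dotted name and raises if a parent
        is missing, so walk the path one level at a time instead."""
        parts = dotted_name.split(".")
        for i in range(1, len(parts) + 1):
            if importlib.util.find_spec(".".join(parts[:i])) is None:
                return False
        return True


    print(submodule_importable("torch_xla.core.xla_model"))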
@@ -316,15 +219,15 @@ def is_datasets_available():
 
 
 def is_psutil_available():
-    return _psutil_available
+    return importlib.util.find_spec("psutil") is not None
 
 
 def is_py3nvml_available():
-    return _py3nvml_available
+    return importlib.util.find_spec("py3nvml") is not None
 
 
 def is_apex_available():
-    return _has_apex
+    return importlib.util.find_spec("apex") is not None
 
 
 def is_faiss_available():
@@ -332,23 +235,39 @@ def is_faiss_available():
 
 
 def is_sklearn_available():
-    return _has_sklearn
+    if importlib.util.find_spec("sklearn") is None:
+        return False
+    if importlib.util.find_spec("scipy") is None:
+        return False
+    return importlib.util.find_spec("sklearn.metrics") and importlib.util.find_spec("scipy.stats")
 
 
 def is_sentencepiece_available():
-    return _sentencepiece_available
+    return importlib.util.find_spec("sentencepiece") is not None
 
 
 def is_protobuf_available():
-    return _protobuf_available
+    if importlib.util.find_spec("google") is None:
+        return False
+    return importlib.util.find_spec("google.protobuf") is not None
 
 
 def is_tokenizers_available():
-    return _tokenizers_available
+    return importlib.util.find_spec("tokenizers") is not None
 
 
 def is_in_notebook():
-    return _in_notebook
+    try:
+        # Test adapted from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
+        get_ipython = sys.modules["IPython"].get_ipython
+        if "IPKernelApp" not in get_ipython().config:
+            raise ImportError("console")
+        if "VSCODE_PID" in os.environ:
+            raise ImportError("vscode")
+
+        return importlib.util.find_spec("IPython") is not None
+    except (AttributeError, ImportError, KeyError):
+        return False
 
 
 def is_scatter_available():
@@ -356,7 +275,7 @@ def is_scatter_available():
 
 
 def is_pandas_available():
-    return _pandas_available
+    return importlib.util.find_spec("pandas") is not None
 
 
 def torch_only_method(fn):
@@ -1167,9 +1086,9 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
     """
     ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
     if is_torch_available():
-        ua += "; torch/{}".format(torch.__version__)
+        ua += f"; torch/{_torch_version}"
     if is_tf_available():
-        ua += "; tensorflow/{}".format(tf.__version__)
+        ua += f"; tensorflow/{_tf_version}"
     if isinstance(user_agent, dict):
         ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
     elif isinstance(user_agent, str):
@@ -14,6 +14,7 @@
 """
 Integrations with other Python libraries.
 """
+import importlib.util
 import math
 import numbers
 import os
@@ -21,18 +22,16 @@ import re
 import tempfile
 from pathlib import Path
 
-from .file_utils import ENV_VARS_TRUE_VALUES
-from .trainer_utils import EvaluationStrategy
 from .utils import logging
 
 
 logger = logging.get_logger(__name__)
 
 
 # Import 3rd-party integrations before ML frameworks:
-try:
-    # Comet needs to be imported before any ML frameworks
+# comet_ml requires to be imported before any ML frameworks
+_has_comet = importlib.util.find_spec("comet_ml") and os.getenv("COMET_MODE", "").upper() != "DISABLED"
+if _has_comet:
+    try:
         import comet_ml  # noqa: F401
 
         if hasattr(comet_ml, "config") and comet_ml.config.get_config("comet.api_key"):
@@ -41,87 +40,20 @@ try:
             if os.getenv("COMET_MODE", "").upper() != "DISABLED":
                 logger.warning("comet_ml is installed but `COMET_API_KEY` is not set.")
             _has_comet = False
-except (ImportError, ValueError):
-    _has_comet = False
+    except (ImportError, ValueError):
+        _has_comet = False
 
-try:
-    import wandb
-
-    wandb.ensure_configured()
-    if wandb.api.api_key is None:
-        _has_wandb = False
-        if os.getenv("WANDB_DISABLED"):
-            logger.warning("W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.")
-    else:
-        _has_wandb = False if os.getenv("WANDB_DISABLED") else True
-except (ImportError, AttributeError):
-    _has_wandb = False
-
-try:
-    import optuna  # noqa: F401
-
-    _has_optuna = True
-except (ImportError):
-    _has_optuna = False
-
-try:
-    import ray  # noqa: F401
-
-    _has_ray = True
-    try:
-        # Ray Tune has additional dependencies.
-        from ray import tune  # noqa: F401
-
-        _has_ray_tune = True
-    except (ImportError):
-        _has_ray_tune = False
-except (ImportError):
-    _has_ray = False
-    _has_ray_tune = False
-
-try:
-    from torch.utils.tensorboard import SummaryWriter  # noqa: F401
-
-    _has_tensorboard = True
-except ImportError:
-    try:
-        from tensorboardX import SummaryWriter  # noqa: F401
-
-        _has_tensorboard = True
-    except ImportError:
-        _has_tensorboard = False
-
-try:
-    from azureml.core.run import Run  # noqa: F401
-
-    _has_azureml = True
-except ImportError:
-    _has_azureml = False
-
-try:
-    import mlflow  # noqa: F401
-
-    _has_mlflow = True
-except ImportError:
-    _has_mlflow = False
-
-try:
-    import fairscale  # noqa: F401
-
-    _has_fairscale = True
-except ImportError:
-    _has_fairscale = False
-
 
 # No transformer imports above this point
 
-from .file_utils import is_torch_tpu_available  # noqa: E402
+from .file_utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available  # noqa: E402
 from .trainer_callback import TrainerCallback  # noqa: E402
-from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun  # noqa: E402
+from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, EvaluationStrategy  # noqa: E402
 
 
 # Integration functions:
 def is_wandb_available():
-    return _has_wandb
+    if os.getenv("WANDB_DISABLED"):
+        return False
+    return importlib.util.find_spec("wandb") is not None
 
 
 def is_comet_available():
@@ -129,35 +61,43 @@ def is_comet_available():
 
 
 def is_tensorboard_available():
-    return _has_tensorboard
+    return importlib.util.find_spec("tensorboard") is not None or importlib.util.find_spec("tensorboardX") is not None
 
 
 def is_optuna_available():
-    return _has_optuna
+    return importlib.util.find_spec("optuna") is not None
 
 
 def is_ray_available():
-    return _has_ray
+    return importlib.util.find_spec("ray") is not None
 
 
 def is_ray_tune_available():
-    return _has_ray_tune
+    if not is_ray_available():
+        return False
+    return importlib.util.find_spec("ray.tune") is not None
 
 
 def is_azureml_available():
-    return _has_azureml
+    if importlib.util.find_spec("azureml") is None:
+        return False
+    if importlib.util.find_spec("azureml.core") is None:
+        return False
+    return importlib.util.find_spec("azureml.core.run") is not None
 
 
 def is_mlflow_available():
-    return _has_mlflow
+    return importlib.util.find_spec("mlflow") is not None
 
 
 def is_fairscale_available():
-    return _has_fairscale
+    return importlib.util.find_spec("fairscale") is not None
 
 
 def hp_params(trial):
     if is_optuna_available():
         import optuna
 
         if isinstance(trial, optuna.Trial):
             return trial.params
     if is_ray_tune_available():
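With the integrations, the module-level try/except imports are gone: availability becomes a cheap find_spec probe, and the real import happens only inside the code path that needs the library. A standalone sketch of that shape, reusing the is_optuna_available name from the diff but with an illustrative run_search caller:

    import importlib
    import importlib.util


    def is_optuna_available() -> bool:
        # Cheap presence check; optuna is not imported here.
        return importlib.util.find_spec("optuna") is not None


    def run_search(objective, n_trials: int = 10):
        # The heavy import happens only when the feature is actually used.
        if not is_optuna_available():
            raise RuntimeError("This feature requires optuna. Run `pip install optuna`.")
        optuna = importlib.import_module("optuna")
        study = optuna.create_study()
        study.optimize(objective, n_trials=n_trials)  # objective(trial) -> float
        return study.best_params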
@@ -175,6 +115,8 @@ def default_hp_search_backend():
 
 
 def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
+    import optuna
+
     def _objective(trial, checkpoint_dir=None):
         model_path = None
         if checkpoint_dir:
@@ -198,6 +140,8 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> Be
 
 
 def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
+    import ray
+
     def _objective(trial, checkpoint_dir=None):
         model_path = None
         if checkpoint_dir:
@@ -297,14 +241,29 @@ class TensorBoardCallback(TrainerCallback):
     """
 
     def __init__(self, tb_writer=None):
+        has_tensorboard = is_tensorboard_available()
         assert (
-            _has_tensorboard
+            has_tensorboard
         ), "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or install tensorboardX."
+        if has_tensorboard:
+            try:
+                from torch.utils.tensorboard import SummaryWriter  # noqa: F401
+
+                self._SummaryWriter = SummaryWriter
+            except ImportError:
+                try:
+                    from tensorboardX import SummaryWriter
+
+                    self._SummaryWriter = SummaryWriter
+                except ImportError:
+                    self._SummaryWriter = None
         self.tb_writer = tb_writer
-        self._SummaryWriter = SummaryWriter
 
     def _init_summary_writer(self, args, log_dir=None):
         log_dir = log_dir or args.logging_dir
-        self.tb_writer = SummaryWriter(log_dir=log_dir)
+        if self._SummaryWriter is not None:
+            self.tb_writer = self._SummaryWriter(log_dir=log_dir)
 
     def on_train_begin(self, args, state, control, **kwargs):
         if not state.is_world_process_zero:
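The callback above resolves the writer class once in __init__ and keeps it on the instance, so a missing backend degrades to a no-op instead of failing at module import. Reduced to its core (the class name here is illustrative, not part of the commit):

    class LazyWriterHolder:
        """Illustrative reduction: resolve a SummaryWriter class lazily, tolerate its absence."""

        def __init__(self):
            try:
                from torch.utils.tensorboard import SummaryWriter
            except ImportError:
                try:
                    from tensorboardX import SummaryWriter
                except ImportError:
                    SummaryWriter = None
            self._SummaryWriter = SummaryWriter

        def open(self, log_dir):
            # No-op when neither backend is installed.
            return None if self._SummaryWriter is None else self._SummaryWriter(log_dir=log_dir)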
@@ -335,7 +294,7 @@ class TensorBoardCallback(TrainerCallback):
         if self.tb_writer is None:
             self._init_summary_writer(args)
 
-        if self.tb_writer:
+        if self.tb_writer is not None:
             logs = rewrite_logs(logs)
             for k, v in logs.items():
                 if isinstance(v, (int, float)):
@@ -363,7 +322,20 @@ class WandbCallback(TrainerCallback):
     """
 
     def __init__(self):
-        assert _has_wandb, "WandbCallback requires wandb to be installed. Run `pip install wandb`."
+        has_wandb = is_wandb_available()
+        assert has_wandb, "WandbCallback requires wandb to be installed. Run `pip install wandb`."
+        if has_wandb:
+            import wandb
+
+            wandb.ensure_configured()
+            if wandb.api.api_key is None:
+                has_wandb = False
+                logger.warning(
+                    "W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable."
+                )
+            self._wandb = wandb
+        else:
+            self._wandb = None
         self._initialized = False
 
     def setup(self, args, state, model, reinit, **kwargs):
@@ -384,6 +356,8 @@ class WandbCallback(TrainerCallback):
             WANDB_DISABLED (:obj:`bool`, `optional`, defaults to :obj:`False`):
                 Whether or not to disable wandb entirely.
         """
+        if self._wandb is None:
+            return
         self._initialized = True
         if state.is_world_process_zero:
             logger.info(
@@ -402,7 +376,7 @@ class WandbCallback(TrainerCallback):
             else:
                 run_name = args.run_name
 
-            wandb.init(
+            self._wandb.init(
                 project=os.getenv("WANDB_PROJECT", "huggingface"),
                 config=combined_dict,
                 name=run_name,
@@ -412,19 +386,25 @@ class WandbCallback(TrainerCallback):
 
             # keep track of model topology and gradients, unsupported on TPU
             if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
-                wandb.watch(model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps))
+                self._wandb.watch(
+                    model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps)
+                )
 
             # log outputs
             self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})
 
     def on_train_begin(self, args, state, control, model=None, **kwargs):
+        if self._wandb is None:
+            return
         hp_search = state.is_hyper_param_search
         if not self._initialized or hp_search:
             self.setup(args, state, model, reinit=hp_search, **kwargs)
 
     def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
+        if self._wandb is None:
+            return
         # commit last step
-        wandb.log({})
+        self._wandb.log({})
         if self._log_model and self._initialized and state.is_world_process_zero:
             from .trainer import Trainer
 
@@ -432,11 +412,11 @@ class WandbCallback(TrainerCallback):
             with tempfile.TemporaryDirectory() as temp_dir:
                 fake_trainer.save_model(temp_dir)
                 # use run name and ensure it's a valid Artifact name
-                artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", wandb.run.name)
+                artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", self._wandb.run.name)
                 metadata = (
                     {
                         k: v
-                        for k, v in dict(wandb.summary).items()
+                        for k, v in dict(self._wandb.summary).items()
                         if isinstance(v, numbers.Number) and not k.startswith("_")
                     }
                     if not args.load_best_model_at_end
@@ -445,19 +425,21 @@ class WandbCallback(TrainerCallback):
                         "train/total_floss": state.total_flos,
                     }
                 )
-                artifact = wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata)
+                artifact = self._wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata)
                 for f in Path(temp_dir).glob("*"):
                     if f.is_file():
                         with artifact.new_file(f.name, mode="wb") as fa:
                             fa.write(f.read_bytes())
-                wandb.run.log_artifact(artifact)
+                self._wandb.run.log_artifact(artifact)
 
     def on_log(self, args, state, control, model=None, logs=None, **kwargs):
+        if self._wandb is None:
+            return
         if not self._initialized:
             self.setup(args, state, model, reinit=False)
         if state.is_world_process_zero:
             logs = rewrite_logs(logs)
-            wandb.log(logs, step=state.global_step)
+            self._wandb.log(logs, step=state.global_step)
 
 
 class CometCallback(TrainerCallback):
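The wandb changes follow the same recipe: the module handle lives on the instance and every hook returns early when it is None. A stripped-down sketch of that shape (the class name is illustrative, not part of the commit):

    import importlib
    import importlib.util


    class ReporterCallback:
        """Hold the integration module on the instance; every hook bails out if it is absent."""

        def __init__(self):
            self._wandb = importlib.import_module("wandb") if importlib.util.find_spec("wandb") else None

        def on_log(self, logs: dict, step: int):
            if self._wandb is None:
                return
            self._wandb.log(logs, step=step)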
@@ -522,10 +504,14 @@ class AzureMLCallback(TrainerCallback):
     """
 
     def __init__(self, azureml_run=None):
-        assert _has_azureml, "AzureMLCallback requires azureml to be installed. Run `pip install azureml-sdk`."
+        assert (
+            is_azureml_available()
+        ), "AzureMLCallback requires azureml to be installed. Run `pip install azureml-sdk`."
         self.azureml_run = azureml_run
 
     def on_init_end(self, args, state, control, **kwargs):
+        from azureml.core.run import Run
+
         if self.azureml_run is None and state.is_world_process_zero:
             self.azureml_run = Run.get_context()
 
@@ -544,9 +530,12 @@ class MLflowCallback(TrainerCallback):
     MAX_LOG_SIZE = 100
 
     def __init__(self):
-        assert _has_mlflow, "MLflowCallback requires mlflow to be installed. Run `pip install mlflow`."
+        assert is_mlflow_available(), "MLflowCallback requires mlflow to be installed. Run `pip install mlflow`."
+        import mlflow
+
         self._initialized = False
         self._log_artifacts = False
+        self._ml_flow = mlflow
 
     def setup(self, args, state, model):
         """
@@ -564,7 +553,7 @@ class MLflowCallback(TrainerCallback):
             if log_artifacts in {"TRUE", "1"}:
                 self._log_artifacts = True
         if state.is_world_process_zero:
-            mlflow.start_run()
+            self._ml_flow.start_run()
             combined_dict = args.to_dict()
             if hasattr(model, "config") and model.config is not None:
                 model_config = model.config.to_dict()
@@ -572,7 +561,7 @@ class MLflowCallback(TrainerCallback):
             # MLflow cannot log more than 100 values in one go, so we have to split it
             combined_dict_items = list(combined_dict.items())
             for i in range(0, len(combined_dict_items), MLflowCallback.MAX_LOG_SIZE):
-                mlflow.log_params(dict(combined_dict_items[i : i + MLflowCallback.MAX_LOG_SIZE]))
+                self._ml_flow.log_params(dict(combined_dict_items[i : i + MLflowCallback.MAX_LOG_SIZE]))
         self._initialized = True
 
     def on_train_begin(self, args, state, control, model=None, **kwargs):
@@ -585,7 +574,7 @@ class MLflowCallback(TrainerCallback):
         if state.is_world_process_zero:
             for k, v in logs.items():
                 if isinstance(v, (int, float)):
-                    mlflow.log_metric(k, v, step=state.global_step)
+                    self._ml_flow.log_metric(k, v, step=state.global_step)
                 else:
                     logger.warning(
                         "Trainer is attempting to log a value of "
@@ -601,11 +590,11 @@ class MLflowCallback(TrainerCallback):
         if self._initialized and state.is_world_process_zero:
             if self._log_artifacts:
                 logger.info("Logging artifacts. This may take time.")
-                mlflow.log_artifacts(args.output_dir)
-        mlflow.end_run()
+                self._ml_flow.log_artifacts(args.output_dir)
+        self._ml_flow.end_run()
 
     def __del__(self):
         # if the previous run is not terminated correctly, the fluent API will
         # not let you start a new run before the previous one is killed
-        if mlflow.active_run is not None:
-            mlflow.end_run(status="KILLED")
+        if self._ml_flow.active_run is not None:
+            self._ml_flow.end_run(status="KILLED")
@@ -25,18 +25,18 @@ from io import StringIO
 from pathlib import Path
 
 from .file_utils import (
-    _datasets_available,
-    _faiss_available,
-    _flax_available,
-    _pandas_available,
-    _scatter_available,
-    _sentencepiece_available,
-    _tf_available,
-    _tokenizers_available,
-    _torch_available,
-    _torch_tpu_available,
+    is_datasets_available,
+    is_faiss_available,
+    is_flax_available,
+    is_pandas_available,
+    is_scatter_available,
+    is_sentencepiece_available,
+    is_tf_available,
+    is_tokenizers_available,
+    is_torch_available,
+    is_torch_tpu_available,
 )
-from .integrations import _has_optuna, _has_ray
+from .integrations import is_optuna_available, is_ray_available
 
 
 SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
@@ -90,7 +90,7 @@ def is_pt_tf_cross_test(test_case):
     to a truthy value and selecting the is_pt_tf_cross_test pytest mark.
 
     """
-    if not _run_pt_tf_cross_tests or not _torch_available or not _tf_available:
+    if not _run_pt_tf_cross_tests or not is_torch_available() or not is_tf_available():
         return unittest.skip("test is PT+TF test")(test_case)
     else:
         try:
@@ -166,7 +166,7 @@ def require_torch(test_case):
     These tests are skipped when PyTorch isn't installed.
 
     """
-    if not _torch_available:
+    if not is_torch_available():
        return unittest.skip("test requires PyTorch")(test_case)
     else:
         return test_case
@@ -179,7 +179,7 @@ def require_torch_scatter(test_case):
     These tests are skipped when PyTorch scatter isn't installed.
 
     """
-    if not _scatter_available:
+    if not is_scatter_available():
         return unittest.skip("test requires PyTorch scatter")(test_case)
     else:
         return test_case
@@ -192,7 +192,7 @@ def require_tf(test_case):
     These tests are skipped when TensorFlow isn't installed.
 
     """
-    if not _tf_available:
+    if not is_tf_available():
         return unittest.skip("test requires TensorFlow")(test_case)
     else:
         return test_case
@@ -205,7 +205,7 @@ def require_flax(test_case):
     These tests are skipped when one / both are not installed
 
     """
-    if not _flax_available:
+    if not is_flax_available():
         test_case = unittest.skip("test requires JAX & Flax")(test_case)
     return test_case
 
@@ -217,7 +217,7 @@ def require_sentencepiece(test_case):
     These tests are skipped when SentencePiece isn't installed.
 
     """
-    if not _sentencepiece_available:
+    if not is_sentencepiece_available():
         return unittest.skip("test requires SentencePiece")(test_case)
     else:
         return test_case
@@ -230,7 +230,7 @@ def require_tokenizers(test_case):
     These tests are skipped when 🤗 Tokenizers isn't installed.
 
     """
-    if not _tokenizers_available:
+    if not is_tokenizers_available():
         return unittest.skip("test requires tokenizers")(test_case)
     else:
         return test_case
@@ -240,7 +240,7 @@ def require_pandas(test_case):
     """
     Decorator marking a test that requires pandas. These tests are skipped when pandas isn't installed.
     """
-    if not _pandas_available:
+    if not is_pandas_available():
         return unittest.skip("test requires pandas")(test_case)
     else:
         return test_case
@@ -251,7 +251,7 @@ def require_scatter(test_case):
     Decorator marking a test that requires PyTorch Scatter. These tests are skipped when PyTorch Scatter isn't
     installed.
     """
-    if not _scatter_available:
+    if not is_scatter_available():
         return unittest.skip("test requires PyTorch Scatter")(test_case)
     else:
         return test_case
@@ -265,7 +265,7 @@ def require_torch_multi_gpu(test_case):
 
     To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests -k "multi_gpu"
     """
-    if not _torch_available:
+    if not is_torch_available():
         return unittest.skip("test requires PyTorch")(test_case)
 
     import torch
@@ -280,7 +280,7 @@ def require_torch_non_multi_gpu(test_case):
     """
     Decorator marking a test that requires 0 or 1 GPU setup (in PyTorch).
     """
-    if not _torch_available:
+    if not is_torch_available():
         return unittest.skip("test requires PyTorch")(test_case)
 
     import torch
@@ -301,13 +301,13 @@ def require_torch_tpu(test_case):
     """
     Decorator marking a test that requires a TPU (in PyTorch).
     """
-    if not _torch_tpu_available:
+    if not is_torch_tpu_available():
         return unittest.skip("test requires PyTorch TPU")
     else:
         return test_case
 
 
-if _torch_available:
+if is_torch_available():
     # Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
     import torch
 
@@ -327,7 +327,7 @@ def require_torch_gpu(test_case):
 def require_datasets(test_case):
     """Decorator marking a test that requires datasets."""
 
-    if not _datasets_available:
+    if not is_datasets_available():
         return unittest.skip("test requires `datasets`")(test_case)
     else:
         return test_case
@@ -335,7 +335,7 @@ def require_datasets(test_case):
 
 def require_faiss(test_case):
     """Decorator marking a test that requires faiss."""
-    if not _faiss_available:
+    if not is_faiss_available():
         return unittest.skip("test requires `faiss`")(test_case)
     else:
         return test_case
@@ -348,7 +348,7 @@ def require_optuna(test_case):
     These tests are skipped when optuna isn't installed.
 
     """
-    if not _has_optuna:
+    if not is_optuna_available():
         return unittest.skip("test requires optuna")(test_case)
     else:
         return test_case
@@ -361,7 +361,7 @@ def require_ray(test_case):
     These tests are skipped when Ray/tune isn't installed.
 
     """
-    if not _has_ray:
+    if not is_ray_available():
         return unittest.skip("test requires Ray/tune")(test_case)
     else:
         return test_case
@@ -371,11 +371,11 @@ def get_gpu_count():
     """
     Return the number of available gpus (regardless of whether torch or tf is used)
     """
-    if _torch_available:
+    if is_torch_available():
         import torch
 
         return torch.cuda.device_count()
-    elif _tf_available:
+    elif is_tf_available():
         import tensorflow as tf
 
         return len(tf.config.list_physical_devices("GPU"))
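The test decorators above now ask the public is_*_available() helpers instead of reading private module flags; the shape of such a decorator, written here directly against find_spec for brevity rather than transformers' helper, is roughly:

    import importlib.util
    import unittest


    def require_torch(test_case):
        """Skip the decorated test when PyTorch is not installed."""
        if importlib.util.find_spec("torch") is None:
            return unittest.skip("test requires PyTorch")(test_case)
        return test_case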
@@ -25,7 +25,7 @@ import shutil
 import time
 import warnings
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 
 # Integrations must be imported before ML frameworks:
@@ -143,12 +143,6 @@ if is_mlflow_available():
 
     DEFAULT_CALLBACKS.append(MLflowCallback)
 
-if is_optuna_available():
-    import optuna
-
-if is_ray_tune_available():
-    from ray import tune
-
 if is_azureml_available():
     from .integrations import AzureMLCallback
 
@@ -159,6 +153,10 @@ if is_fairscale_available():
     from fairscale.optim import OSS
     from fairscale.optim.grad_scaler import ShardedGradScaler
 
+
+if TYPE_CHECKING:
+    import optuna
+
 logger = logging.get_logger(__name__)
 
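The TYPE_CHECKING block satisfies type checkers without importing optuna at runtime; the annotation is written as a string so it is never evaluated when the module loads. Roughly (the function name is illustrative):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Only evaluated by static type checkers; never imported at runtime.
        import optuna


    def report_trial(trial: "optuna.Trial") -> None:
        print(trial.number)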
@@ -611,15 +609,21 @@ class Trainer:
             return
         self.objective = self.compute_objective(metrics.copy())
         if self.hp_search_backend == HPSearchBackend.OPTUNA:
+            import optuna
+
             trial.report(self.objective, epoch)
             if trial.should_prune():
                 raise optuna.TrialPruned()
         elif self.hp_search_backend == HPSearchBackend.RAY:
+            from ray import tune
+
             if self.state.global_step % self.args.save_steps == 0:
                 self._tune_save_checkpoint()
             tune.report(objective=self.objective, **metrics)
 
     def _tune_save_checkpoint(self):
+        from ray import tune
+
         if not self.use_tune_checkpoints:
             return
         with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
@@ -981,7 +985,12 @@ class Trainer:
         checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
 
         if self.hp_search_backend is not None and trial is not None:
-            run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id()
+            if self.hp_search_backend == HPSearchBackend.OPTUNA:
+                run_id = trial.number
+            else:
+                from ray import tune
+
+                run_id = tune.get_trial_id()
             run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
             output_dir = os.path.join(self.args.output_dir, run_name, checkpoint_folder)
         else:
@@ -18,14 +18,15 @@ import os
 import unittest
 
 from transformers import CamembertTokenizer, CamembertTokenizerFast
-from transformers.testing_utils import _torch_available, require_sentencepiece, require_tokenizers
+from transformers.file_utils import is_torch_available
+from transformers.testing_utils import require_sentencepiece, require_tokenizers
 
 from .test_tokenization_common import TokenizerTesterMixin
 
 
 SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
 
-FRAMEWORK = "pt" if _torch_available else "tf"
+FRAMEWORK = "pt" if is_torch_available() else "tf"
 
 
 @require_sentencepiece
@@ -21,10 +21,11 @@ from pathlib import Path
 from shutil import copyfile
 
 from transformers import BatchEncoding, MarianTokenizer
-from transformers.testing_utils import _sentencepiece_available, _torch_available, require_sentencepiece
+from transformers.file_utils import is_sentencepiece_available, is_torch_available
+from transformers.testing_utils import require_sentencepiece
 
 
-if _sentencepiece_available:
+if is_sentencepiece_available():
     from transformers.models.marian.tokenization_marian import save_json, vocab_files_names
 
 from .test_tokenization_common import TokenizerTesterMixin
@@ -35,7 +36,7 @@ SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/t
 mock_tokenizer_config = {"target_lang": "fi", "source_lang": "en"}
 zh_code = ">>zh<<"
 ORG_NAME = "Helsinki-NLP/"
-FRAMEWORK = "pt" if _torch_available else "tf"
+FRAMEWORK = "pt" if is_torch_available() else "tf"
 
 
 @require_sentencepiece
@@ -16,17 +16,13 @@ import tempfile
 import unittest
 
 from transformers import SPIECE_UNDERLINE, BatchEncoding, MBartTokenizer, MBartTokenizerFast, is_torch_available
-from transformers.testing_utils import (
-    _sentencepiece_available,
-    require_sentencepiece,
-    require_tokenizers,
-    require_torch,
-)
+from transformers.file_utils import is_sentencepiece_available
+from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch
 
 from .test_tokenization_common import TokenizerTesterMixin
 
 
-if _sentencepiece_available:
+if is_sentencepiece_available():
     from .test_tokenization_xlm_roberta import SAMPLE_VOCAB
 
@@ -17,15 +17,15 @@
 import unittest
 
 from transformers import SPIECE_UNDERLINE, BatchEncoding, T5Tokenizer, T5TokenizerFast
-from transformers.file_utils import cached_property
-from transformers.testing_utils import _torch_available, get_tests_dir, require_sentencepiece, require_tokenizers
+from transformers.file_utils import cached_property, is_torch_available
+from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers
 
 from .test_tokenization_common import TokenizerTesterMixin
 
 
 SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
 
-FRAMEWORK = "pt" if _torch_available else "tf"
+FRAMEWORK = "pt" if is_torch_available() else "tf"
 
 
 @require_sentencepiece