Fast transformers import part 1 (#9441)

* Don't import libs to check they are available

* Don't import integrations at init

* Add importlib_metadata to deps

* Remove old vars references

* Avoid syntax error

* Adapt testing utils

* Try to appease torchhub

* Add dependency

* Remove more private variables

* Fix typo

* Another typo

* Refine the tf availability test
Author: Sylvain Gugger
Date: 2021-01-06 12:17:24 -05:00 (committed by GitHub)
Parent: c89f1bc92e
Commit: 0c96262f7d
13 changed files with 280 additions and 360 deletions
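
The thread running through all of these changes: to decide whether an optional backend is usable, the library no longer imports it (which can take seconds for torch or tensorflow) but instead asks the import machinery whether the module exists and reads its version from the installed distribution's metadata. A minimal sketch of that pattern, with a hypothetical helper name (not part of the transformers API), assuming the importlib_metadata backport added to the dependencies below:

import importlib.util

import importlib_metadata  # PyPI backport of importlib.metadata (Python < 3.8)


def _is_available(module_name, dist_name=None):
    # find_spec only consults the finders; it does not execute the module.
    if importlib.util.find_spec(module_name) is None:
        return False
    try:
        # Reading the version also proves there is a real installed
        # distribution, not just a directory shadowing the module name.
        importlib_metadata.version(dist_name or module_name)
    except importlib_metadata.PackageNotFoundError:
        return False
    return True


print(_is_available("numpy"))  # True in any environment with numpy installed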

File: GitHub Actions workflow for the Torch hub tests

@@ -33,7 +33,7 @@ jobs:
run: |
pip install --upgrade pip
pip install torch
pip install numpy filelock protobuf requests tqdm regex sentencepiece sacremoses tokenizers packaging
pip install numpy filelock protobuf requests tqdm regex sentencepiece sacremoses tokenizers packaging importlib_metadata
- name: Torch hub list
run: |

File: hubconf.py

@@ -30,7 +30,7 @@ from transformers import (
)
dependencies = ["torch", "numpy", "tokenizers", "filelock", "requests", "tqdm", "regex", "sentencepiece", "sacremoses"]
dependencies = ["torch", "numpy", "tokenizers", "filelock", "requests", "tqdm", "regex", "sentencepiece", "sacremoses", "importlib_metadata"]
@add_start_docstrings(AutoConfig.__doc__)

File: setup.py

@@ -99,6 +99,7 @@ _deps = [
"flake8>=3.8.3",
"flax>=0.2.2",
"fugashi>=1.0",
"importlib_metadata",
"ipadic>=1.0.0,<2.0",
"isort>=5.5.4",
"jax>=0.2.0",
@@ -232,6 +233,7 @@ extras["dev"] = (
# when modifying the following list, make sure to update src/transformers/dependency_versions_check.py
install_requires = [
deps["dataclasses"] + ";python_version<'3.7'", # dataclasses for Python versions that don't have it
deps["importlib_metadata"] + ";python_version<'3.8'", # importlib_metadata for Python versions that don't have it
deps["filelock"], # filesystem locks, e.g., to prevent parallel downloads
deps["numpy"],
deps["packaging"], # utilities from PyPA to e.g., compare versions

File: src/transformers/dependency_versions_check.py

@@ -26,6 +26,8 @@ from .utils.versions import require_version_core
pkgs_to_check_at_runtime = "python tqdm regex sacremoses requests packaging filelock numpy tokenizers".split()
if sys.version_info < (3, 7):
pkgs_to_check_at_runtime.append("dataclasses")
if sys.version_info < (3, 8):
pkgs_to_check_at_runtime.append("importlib_metadata")
for pkg in pkgs_to_check_at_runtime:
if pkg in deps:
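
Since the backport mirrors the stdlib module's API, code that must run on both sides of the 3.8 boundary often selects between them at import time. An illustrative compatibility shim (not the approach taken here; transformers simply depends on and imports the backport directly):

import sys

if sys.version_info >= (3, 8):
    import importlib.metadata as importlib_metadata
else:
    import importlib_metadata

print(importlib_metadata.version("pip"))  # prints the installed pip version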

File: src/transformers/dependency_versions_table.py

@@ -12,6 +12,7 @@ deps = {
"flake8": "flake8>=3.8.3",
"flax": "flax>=0.2.2",
"fugashi": "fugashi>=1.0",
"importlib_metadata": "importlib_metadata",
"ipadic": "ipadic>=1.0.0,<2.0",
"isort": "isort>=5.5.4",
"jax": "jax>=0.2.0",

File: src/transformers/file_utils.py

@@ -18,6 +18,7 @@ https://github.com/allenai/allennlp.
import copy
import fnmatch
import importlib.util
import io
import json
import os
@@ -37,8 +38,10 @@ from urllib.parse import urlparse
from zipfile import ZipFile, is_zipfile
import numpy as np
from packaging import version
from tqdm.auto import tqdm
import importlib_metadata
import requests
from filelock import FileLock
@@ -52,195 +55,88 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES"}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
try:
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
import torch
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
_torch_available = True # pylint: disable=invalid-name
logger.info("PyTorch version {} available.".format(torch.__version__))
else:
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
_torch_available = importlib.util.find_spec("torch") is not None
if _torch_available:
try:
_torch_version = importlib_metadata.version("torch")
logger.info(f"PyTorch version {_torch_version} available.")
except importlib_metadata.PackageNotFoundError:
_torch_available = False
else:
logger.info("Disabling PyTorch because USE_TF is set")
_torch_available = False
except ImportError:
_torch_available = False # pylint: disable=invalid-name
try:
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
import tensorflow as tf
assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
_tf_available = True # pylint: disable=invalid-name
logger.info("TensorFlow version {} available.".format(tf.__version__))
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
_tf_available = importlib.util.find_spec("tensorflow") is not None
if _tf_available:
# For the metadata, we have to look for both tensorflow and tensorflow-cpu
try:
_tf_version = importlib_metadata.version("tensorflow")
except importlib_metadata.PackageNotFoundError:
try:
_tf_version = importlib_metadata.version("tensorflow-cpu")
except importlib_metadata.PackageNotFoundError:
_tf_version = None
_tf_available = False
if _tf_available:
if version.parse(_tf_version) < version.parse("2"):
logger.info(f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum.")
_tf_available = False
else:
logger.info(f"TensorFlow version {_tf_version} available.")
else:
logger.info("Disabling Tensorflow because USE_TORCH is set")
_tf_available = False
except (ImportError, AssertionError):
_tf_available = False # pylint: disable=invalid-name
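
The same importable module can be shipped by several PyPI distributions (tensorflow vs. tensorflow-cpu), so the metadata lookup has to try each candidate distribution name in turn. A generalized sketch with a hypothetical helper; "tensorflow-gpu" is an extra candidate added purely for illustration, the commit itself checks only the two names above:

import importlib_metadata


def _version_from_candidates(*dist_names):
    for name in dist_names:
        try:
            return importlib_metadata.version(name)
        except importlib_metadata.PackageNotFoundError:
            continue
    return None


# Prints the installed TensorFlow version, or None if no candidate is found.
print(_version_from_candidates("tensorflow", "tensorflow-cpu", "tensorflow-gpu"))
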
try:
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
import flax
import jax
logger.info("JAX version {}, Flax: available".format(jax.__version__))
logger.info("Flax available: {}".format(flax))
_flax_available = True
else:
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
_flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
if _flax_available:
try:
_jax_version = importlib_metadata.version("jax")
_flax_version = importlib_metadata.version("flax")
logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
except importlib_metadata.PackageNotFoundError:
_flax_available = False
else:
_flax_available = False
except ImportError:
_flax_available = False # pylint: disable=invalid-name
_datasets_available = importlib.util.find_spec("datasets") is not None
try:
import datasets # noqa: F401
# Check we're not importing a "datasets" directory somewhere
_datasets_available = hasattr(datasets, "__version__") and hasattr(datasets, "load_dataset")
if _datasets_available:
logger.debug(f"Successfully imported datasets version {datasets.__version__}")
else:
logger.debug("Imported a datasets object but this doesn't seem to be the 🤗 datasets library.")
except ImportError:
# Check we're not importing a "datasets" directory somewhere but the actual library by trying to grab the version
# AND checking it has an author field in the metadata that is HuggingFace.
_ = importlib_metadata.version("datasets")
_datasets_metadata = importlib_metadata.metadata("datasets")
if _datasets_metadata.get("author", "") != "HuggingFace Inc.":
_datasets_available = False
except importlib_metadata.PackageNotFoundError:
_datasets_available = False
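
The metadata probe doubles as a shadowing guard: a stray datasets/ directory on sys.path satisfies find_spec but has no installed distribution, and an unrelated same-named project fails the author check. The guard in isolation, as a hypothetical helper:

import importlib_metadata


def _is_hf_datasets_installed():
    try:
        meta = importlib_metadata.metadata("datasets")
    except importlib_metadata.PackageNotFoundError:
        return False  # importable name without an installed distribution
    return meta.get("author", "") == "HuggingFace Inc."
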
_faiss_available = importlib.util.find_spec("faiss") is not None
try:
from torch.hub import _get_torch_home
torch_cache_home = _get_torch_home()
except ImportError:
torch_cache_home = os.path.expanduser(
os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
)
try:
import torch_xla.core.xla_model as xm # noqa: F401
if _torch_available:
_torch_tpu_available = True # pylint: disable=
else:
_torch_tpu_available = False
except ImportError:
_torch_tpu_available = False
try:
import psutil # noqa: F401
_psutil_available = True
except ImportError:
_psutil_available = False
try:
import py3nvml # noqa: F401
_py3nvml_available = True
except ImportError:
_py3nvml_available = False
try:
from apex import amp # noqa: F401
_has_apex = True
except ImportError:
_has_apex = False
try:
import faiss # noqa: F401
_faiss_available = True
logger.debug(f"Successfully imported faiss version {faiss.__version__}")
except ImportError:
_faiss_version = importlib_metadata.version("faiss")
logger.debug(f"Successfully imported faiss version {_faiss_version}")
except importlib_metadata.PackageNotFoundError:
_faiss_available = False
_scatter_available = importlib.util.find_spec("torch_scatter") is not None
try:
import sklearn.metrics # noqa: F401
import scipy.stats # noqa: F401
_has_sklearn = True
except (AttributeError, ImportError):
_has_sklearn = False
try:
# Test copied from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
get_ipython = sys.modules["IPython"].get_ipython
if "IPKernelApp" not in get_ipython().config:
raise ImportError("console")
if "VSCODE_PID" in os.environ:
raise ImportError("vscode")
import IPython # noqa: F401
_in_notebook = True
except (AttributeError, ImportError, KeyError):
_in_notebook = False
try:
import sentencepiece # noqa: F401
_sentencepiece_available = True
except ImportError:
_sentencepiece_available = False
try:
import google.protobuf # noqa: F401
_protobuf_available = True
except ImportError:
_protobuf_available = False
try:
import tokenizers # noqa: F401
_tokenizers_available = True
except ImportError:
_tokenizers_available = False
try:
import pandas # noqa: F401
_pandas_available = True
except ImportError:
_pandas_available = False
try:
import torch_scatter
# Check we're not importing a "torch_scatter" directory somewhere
_scatter_available = hasattr(torch_scatter, "__version__") and hasattr(torch_scatter, "scatter")
if _scatter_available:
logger.debug(f"Succesfully imported torch-scatter version {torch_scatter.__version__}")
else:
logger.debug("Imported a torch_scatter object but this doesn't seem to be the torch-scatter library.")
except ImportError:
_scatter_version = importlib_metadata.version("torch_scatter")
logger.debug(f"Successfully imported torch-scatter version {_scatter_version}")
except importlib_metadata.PackageNotFoundError:
_scatter_available = False
torch_cache_home = os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
old_default_cache_path = os.path.join(torch_cache_home, "transformers")
# New default cache, shared with the Datasets library
hf_cache_home = os.path.expanduser(
@@ -308,7 +204,14 @@ def is_flax_available():
def is_torch_tpu_available():
return _torch_tpu_available
if not _torch_available:
return False
# This test is probably enough, but just in case, we unpack a bit.
if importlib.util.find_spec("torch_xla") is None:
return False
if importlib.util.find_spec("torch_xla.core") is None:
return False
return importlib.util.find_spec("torch_xla.core.xla_model") is not None
def is_datasets_available():
@@ -316,15 +219,15 @@ def is_datasets_available():
def is_psutil_available():
return _psutil_available
return importlib.util.find_spec("psutil") is not None
def is_py3nvml_available():
return _py3nvml_available
return importlib.util.find_spec("py3nvml") is not None
def is_apex_available():
return _has_apex
return importlib.util.find_spec("apex") is not None
def is_faiss_available():
@@ -332,23 +235,39 @@ def is_faiss_available():
def is_sklearn_available():
return _has_sklearn
if importlib.util.find_spec("sklearn") is None:
return False
if importlib.util.find_spec("scipy") is None:
return False
return importlib.util.find_spec("sklearn.metrics") and importlib.util.find_spec("scipy.stats")
def is_sentencepiece_available():
return _sentencepiece_available
return importlib.util.find_spec("sentencepiece") is not None
def is_protobuf_available():
return _protobuf_available
if importlib.util.find_spec("google") is None:
return False
return importlib.util.find_spec("google.protobuf") is not None
def is_tokenizers_available():
return _tokenizers_available
return importlib.util.find_spec("tokenizers") is not None
def is_in_notebook():
return _in_notebook
try:
# Test adapted from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
get_ipython = sys.modules["IPython"].get_ipython
if "IPKernelApp" not in get_ipython().config:
raise ImportError("console")
if "VSCODE_PID" in os.environ:
raise ImportError("vscode")
return importlib.util.find_spec("IPython") is not None
except (AttributeError, ImportError, KeyError):
return False
def is_scatter_available():
@@ -356,7 +275,7 @@ def is_scatter_available():
def is_pandas_available():
return _pandas_available
return importlib.util.find_spec("pandas") is not None
def torch_only_method(fn):
@@ -1167,9 +1086,9 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
"""
ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
if is_torch_available():
ua += "; torch/{}".format(torch.__version__)
ua += f"; torch/{_torch_version}"
if is_tf_available():
ua += "; tensorflow/{}".format(tf.__version__)
ua += f"; tensorflow/{_tf_version}"
if isinstance(user_agent, dict):
ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
elif isinstance(user_agent, str):

File: src/transformers/integrations.py

@@ -14,6 +14,7 @@
"""
Integrations with other Python libraries.
"""
import importlib.util
import math
import numbers
import os
@@ -21,18 +22,16 @@ import re
import tempfile
from pathlib import Path
from .file_utils import ENV_VARS_TRUE_VALUES
from .trainer_utils import EvaluationStrategy
from .utils import logging
logger = logging.get_logger(__name__)
# Import 3rd-party integrations before ML frameworks:
try:
# Comet needs to be imported before any ML frameworks
# comet_ml requires to be imported before any ML frameworks
_has_comet = importlib.util.find_spec("comet_ml") and os.getenv("COMET_MODE", "").upper() != "DISABLED"
if _has_comet:
try:
import comet_ml # noqa: F401
if hasattr(comet_ml, "config") and comet_ml.config.get_config("comet.api_key"):
@@ -41,87 +40,20 @@ try:
if os.getenv("COMET_MODE", "").upper() != "DISABLED":
logger.warning("comet_ml is installed but `COMET_API_KEY` is not set.")
_has_comet = False
except (ImportError, ValueError):
except (ImportError, ValueError):
_has_comet = False
try:
import wandb
wandb.ensure_configured()
if wandb.api.api_key is None:
_has_wandb = False
if os.getenv("WANDB_DISABLED"):
logger.warning("W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.")
else:
_has_wandb = False if os.getenv("WANDB_DISABLED") else True
except (ImportError, AttributeError):
_has_wandb = False
try:
import optuna # noqa: F401
_has_optuna = True
except (ImportError):
_has_optuna = False
try:
import ray # noqa: F401
_has_ray = True
try:
# Ray Tune has additional dependencies.
from ray import tune # noqa: F401
_has_ray_tune = True
except (ImportError):
_has_ray_tune = False
except (ImportError):
_has_ray = False
_has_ray_tune = False
try:
from torch.utils.tensorboard import SummaryWriter # noqa: F401
_has_tensorboard = True
except ImportError:
try:
from tensorboardX import SummaryWriter # noqa: F401
_has_tensorboard = True
except ImportError:
_has_tensorboard = False
try:
from azureml.core.run import Run # noqa: F401
_has_azureml = True
except ImportError:
_has_azureml = False
try:
import mlflow # noqa: F401
_has_mlflow = True
except ImportError:
_has_mlflow = False
try:
import fairscale # noqa: F401
_has_fairscale = True
except ImportError:
_has_fairscale = False
# No transformer imports above this point
from .file_utils import is_torch_tpu_available # noqa: E402
from .file_utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available # noqa: E402
from .trainer_callback import TrainerCallback # noqa: E402
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun # noqa: E402
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, EvaluationStrategy # noqa: E402
# Integration functions:
def is_wandb_available():
return _has_wandb
if os.getenv("WANDB_DISABLED"):
return False
return importlib.util.find_spec("wandb") is not None
def is_comet_available():
@@ -129,35 +61,43 @@ def is_comet_available():
def is_tensorboard_available():
return _has_tensorboard
return importlib.util.find_spec("tensorboard") is not None or importlib.util.find_spec("tensorboardX") is not None
def is_optuna_available():
return _has_optuna
return importlib.util.find_spec("optuna") is not None
def is_ray_available():
return _has_ray
return importlib.util.find_spec("ray") is not None
def is_ray_tune_available():
return _has_ray_tune
if not is_ray_available():
return False
return importlib.util.find_spec("ray.tune") is not None
def is_azureml_available():
return _has_azureml
if importlib.util.find_spec("azureml") is None:
return False
if importlib.util.find_spec("azureml.core") is None:
return False
return importlib.util.find_spec("azureml.core.run") is not None
def is_mlflow_available():
return _has_mlflow
return importlib.util.find_spec("mlflow") is not None
def is_fairscale_available():
return _has_fairscale
return importlib.util.find_spec("fairscale") is not None
def hp_params(trial):
if is_optuna_available():
import optuna
if isinstance(trial, optuna.Trial):
return trial.params
if is_ray_tune_available():
@@ -175,6 +115,8 @@ def default_hp_search_backend():
def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
import optuna
def _objective(trial, checkpoint_dir=None):
model_path = None
if checkpoint_dir:
@@ -198,6 +140,8 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> Be
def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
import ray
def _objective(trial, checkpoint_dir=None):
model_path = None
if checkpoint_dir:
@@ -297,14 +241,29 @@ class TensorBoardCallback(TrainerCallback):
"""
def __init__(self, tb_writer=None):
has_tensorboard = is_tensorboard_available()
assert (
_has_tensorboard
has_tensorboard
), "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or install tensorboardX."
if has_tensorboard:
try:
from torch.utils.tensorboard import SummaryWriter # noqa: F401
self._SummaryWriter = SummaryWriter
except ImportError:
try:
from tensorboardX import SummaryWriter
self._SummaryWriter = SummaryWriter
except ImportError:
self._SummaryWriter = None
self.tb_writer = tb_writer
self._SummaryWriter = SummaryWriter
def _init_summary_writer(self, args, log_dir=None):
log_dir = log_dir or args.logging_dir
self.tb_writer = SummaryWriter(log_dir=log_dir)
if self._SummaryWriter is not None:
self.tb_writer = self._SummaryWriter(log_dir=log_dir)
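
The callback now resolves SummaryWriter inside __init__ and keeps the class on the instance, so the tensorboard import only happens when a TensorBoardCallback is actually constructed, never when the module is merely imported. The same deferred-import shape reduced to a minimal, hypothetical example:

class LazyTensorBoardWriter:
    def __init__(self):
        try:
            from torch.utils.tensorboard import SummaryWriter
        except ImportError:
            try:
                from tensorboardX import SummaryWriter
            except ImportError:
                SummaryWriter = None  # neither backend available
        self._summary_writer_cls = SummaryWriter

    def open(self, log_dir):
        # Returns a live writer, or None when no backend is installed.
        if self._summary_writer_cls is None:
            return None
        return self._summary_writer_cls(log_dir=log_dir)
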
def on_train_begin(self, args, state, control, **kwargs):
if not state.is_world_process_zero:
@@ -335,7 +294,7 @@ class TensorBoardCallback(TrainerCallback):
if self.tb_writer is None:
self._init_summary_writer(args)
if self.tb_writer:
if self.tb_writer is not None:
logs = rewrite_logs(logs)
for k, v in logs.items():
if isinstance(v, (int, float)):
@@ -363,7 +322,20 @@ class WandbCallback(TrainerCallback):
"""
def __init__(self):
assert _has_wandb, "WandbCallback requires wandb to be installed. Run `pip install wandb`."
has_wandb = is_wandb_available()
assert has_wandb, "WandbCallback requires wandb to be installed. Run `pip install wandb`."
if has_wandb:
import wandb
wandb.ensure_configured()
if wandb.api.api_key is None:
has_wandb = False
logger.warning(
"W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable."
)
self._wandb = wandb
else:
self._wandb = None
self._initialized = False
def setup(self, args, state, model, reinit, **kwargs):
@@ -384,6 +356,8 @@ class WandbCallback(TrainerCallback):
WANDB_DISABLED (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to disable wandb entirely.
"""
if self._wandb is None:
return
self._initialized = True
if state.is_world_process_zero:
logger.info(
@@ -402,7 +376,7 @@ class WandbCallback(TrainerCallback):
else:
run_name = args.run_name
wandb.init(
self._wandb.init(
project=os.getenv("WANDB_PROJECT", "huggingface"),
config=combined_dict,
name=run_name,
@@ -412,19 +386,25 @@ class WandbCallback(TrainerCallback):
# keep track of model topology and gradients, unsupported on TPU
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
wandb.watch(model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps))
self._wandb.watch(
model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps)
)
# log outputs
self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})
def on_train_begin(self, args, state, control, model=None, **kwargs):
if self._wandb is None:
return
hp_search = state.is_hyper_param_search
if not self._initialized or hp_search:
self.setup(args, state, model, reinit=hp_search, **kwargs)
def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
if self._wandb is None:
return
# commit last step
wandb.log({})
self._wandb.log({})
if self._log_model and self._initialized and state.is_world_process_zero:
from .trainer import Trainer
@@ -432,11 +412,11 @@ class WandbCallback(TrainerCallback):
with tempfile.TemporaryDirectory() as temp_dir:
fake_trainer.save_model(temp_dir)
# use run name and ensure it's a valid Artifact name
artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", wandb.run.name)
artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", self._wandb.run.name)
metadata = (
{
k: v
for k, v in dict(wandb.summary).items()
for k, v in dict(self._wandb.summary).items()
if isinstance(v, numbers.Number) and not k.startswith("_")
}
if not args.load_best_model_at_end
@@ -445,19 +425,21 @@ class WandbCallback(TrainerCallback):
"train/total_floss": state.total_flos,
}
)
artifact = wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata)
artifact = self._wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata)
for f in Path(temp_dir).glob("*"):
if f.is_file():
with artifact.new_file(f.name, mode="wb") as fa:
fa.write(f.read_bytes())
wandb.run.log_artifact(artifact)
self._wandb.run.log_artifact(artifact)
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
if self._wandb is None:
return
if not self._initialized:
self.setup(args, state, model, reinit=False)
if state.is_world_process_zero:
logs = rewrite_logs(logs)
wandb.log(logs, step=state.global_step)
self._wandb.log(logs, step=state.global_step)
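
WandbCallback follows the companion pattern: import once in the constructor, keep the module handle on self, and start every hook with a None guard so a missing dependency degrades to a no-op. In miniature, as a hypothetical class:

import importlib.util


class MetricsLogger:
    def __init__(self):
        if importlib.util.find_spec("wandb") is not None:
            import wandb  # deferred: runs only when an instance is created
            self._wandb = wandb
        else:
            self._wandb = None

    def log(self, metrics, step):
        if self._wandb is None:
            return  # dependency missing: silently skip
        self._wandb.log(metrics, step=step)
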
class CometCallback(TrainerCallback):
@@ -522,10 +504,14 @@ class AzureMLCallback(TrainerCallback):
"""
def __init__(self, azureml_run=None):
assert _has_azureml, "AzureMLCallback requires azureml to be installed. Run `pip install azureml-sdk`."
assert (
is_azureml_available()
), "AzureMLCallback requires azureml to be installed. Run `pip install azureml-sdk`."
self.azureml_run = azureml_run
def on_init_end(self, args, state, control, **kwargs):
from azureml.core.run import Run
if self.azureml_run is None and state.is_world_process_zero:
self.azureml_run = Run.get_context()
@@ -544,9 +530,12 @@ class MLflowCallback(TrainerCallback):
MAX_LOG_SIZE = 100
def __init__(self):
assert _has_mlflow, "MLflowCallback requires mlflow to be installed. Run `pip install mlflow`."
assert is_mlflow_available(), "MLflowCallback requires mlflow to be installed. Run `pip install mlflow`."
import mlflow
self._initialized = False
self._log_artifacts = False
self._ml_flow = mlflow
def setup(self, args, state, model):
"""
@@ -564,7 +553,7 @@ class MLflowCallback(TrainerCallback):
if log_artifacts in {"TRUE", "1"}:
self._log_artifacts = True
if state.is_world_process_zero:
mlflow.start_run()
self._ml_flow.start_run()
combined_dict = args.to_dict()
if hasattr(model, "config") and model.config is not None:
model_config = model.config.to_dict()
@@ -572,7 +561,7 @@ class MLflowCallback(TrainerCallback):
# MLflow cannot log more than 100 values in one go, so we have to split it
combined_dict_items = list(combined_dict.items())
for i in range(0, len(combined_dict_items), MLflowCallback.MAX_LOG_SIZE):
mlflow.log_params(dict(combined_dict_items[i : i + MLflowCallback.MAX_LOG_SIZE]))
self._ml_flow.log_params(dict(combined_dict_items[i : i + MLflowCallback.MAX_LOG_SIZE]))
self._initialized = True
def on_train_begin(self, args, state, control, model=None, **kwargs):
@@ -585,7 +574,7 @@ class MLflowCallback(TrainerCallback):
if state.is_world_process_zero:
for k, v in logs.items():
if isinstance(v, (int, float)):
mlflow.log_metric(k, v, step=state.global_step)
self._ml_flow.log_metric(k, v, step=state.global_step)
else:
logger.warning(
"Trainer is attempting to log a value of "
@@ -601,11 +590,11 @@ class MLflowCallback(TrainerCallback):
if self._initialized and state.is_world_process_zero:
if self._log_artifacts:
logger.info("Logging artifacts. This may take time.")
mlflow.log_artifacts(args.output_dir)
mlflow.end_run()
self._ml_flow.log_artifacts(args.output_dir)
self._ml_flow.end_run()
def __del__(self):
# if the previous run is not terminated correctly, the fluent API will
# not let you start a new run before the previous one is killed
if mlflow.active_run is not None:
mlflow.end_run(status="KILLED")
if self._ml_flow.active_run is not None:
self._ml_flow.end_run(status="KILLED")

File: src/transformers/testing_utils.py

@@ -25,18 +25,18 @@ from io import StringIO
from pathlib import Path
from .file_utils import (
_datasets_available,
_faiss_available,
_flax_available,
_pandas_available,
_scatter_available,
_sentencepiece_available,
_tf_available,
_tokenizers_available,
_torch_available,
_torch_tpu_available,
is_datasets_available,
is_faiss_available,
is_flax_available,
is_pandas_available,
is_scatter_available,
is_sentencepiece_available,
is_tf_available,
is_tokenizers_available,
is_torch_available,
is_torch_tpu_available,
)
from .integrations import _has_optuna, _has_ray
from .integrations import is_optuna_available, is_ray_available
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
@@ -90,7 +90,7 @@ def is_pt_tf_cross_test(test_case):
to a truthy value and selecting the is_pt_tf_cross_test pytest mark.
"""
if not _run_pt_tf_cross_tests or not _torch_available or not _tf_available:
if not _run_pt_tf_cross_tests or not is_torch_available() or not is_tf_available():
return unittest.skip("test is PT+TF test")(test_case)
else:
try:
@@ -166,7 +166,7 @@ def require_torch(test_case):
These tests are skipped when PyTorch isn't installed.
"""
if not _torch_available:
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
else:
return test_case
@@ -179,7 +179,7 @@ def require_torch_scatter(test_case):
These tests are skipped when PyTorch scatter isn't installed.
"""
if not _scatter_available:
if not is_scatter_available():
return unittest.skip("test requires PyTorch scatter")(test_case)
else:
return test_case
@@ -192,7 +192,7 @@ def require_tf(test_case):
These tests are skipped when TensorFlow isn't installed.
"""
if not _tf_available:
if not is_tf_available():
return unittest.skip("test requires TensorFlow")(test_case)
else:
return test_case
@@ -205,7 +205,7 @@ def require_flax(test_case):
These tests are skipped when one / both are not installed
"""
if not _flax_available:
if not is_flax_available():
test_case = unittest.skip("test requires JAX & Flax")(test_case)
return test_case
@@ -217,7 +217,7 @@ def require_sentencepiece(test_case):
These tests are skipped when SentencePiece isn't installed.
"""
if not _sentencepiece_available:
if not is_sentencepiece_available():
return unittest.skip("test requires SentencePiece")(test_case)
else:
return test_case
@@ -230,7 +230,7 @@ def require_tokenizers(test_case):
These tests are skipped when 🤗 Tokenizers isn't installed.
"""
if not _tokenizers_available:
if not is_tokenizers_available():
return unittest.skip("test requires tokenizers")(test_case)
else:
return test_case
@@ -240,7 +240,7 @@ def require_pandas(test_case):
"""
Decorator marking a test that requires pandas. These tests are skipped when pandas isn't installed.
"""
if not _pandas_available:
if not is_pandas_available():
return unittest.skip("test requires pandas")(test_case)
else:
return test_case
@@ -251,7 +251,7 @@ def require_scatter(test_case):
Decorator marking a test that requires PyTorch Scatter. These tests are skipped when PyTorch Scatter isn't
installed.
"""
if not _scatter_available:
if not is_scatter_available():
return unittest.skip("test requires PyTorch Scatter")(test_case)
else:
return test_case
@@ -265,7 +265,7 @@ def require_torch_multi_gpu(test_case):
To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests -k "multi_gpu"
"""
if not _torch_available:
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
import torch
@@ -280,7 +280,7 @@ def require_torch_non_multi_gpu(test_case):
"""
Decorator marking a test that requires 0 or 1 GPU setup (in PyTorch).
"""
if not _torch_available:
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
import torch
@@ -301,13 +301,13 @@ def require_torch_tpu(test_case):
"""
Decorator marking a test that requires a TPU (in PyTorch).
"""
if not _torch_tpu_available:
if not is_torch_tpu_available():
return unittest.skip("test requires PyTorch TPU")
else:
return test_case
if _torch_available:
if is_torch_available():
# Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
import torch
@@ -327,7 +327,7 @@ def require_torch_gpu(test_case):
def require_datasets(test_case):
"""Decorator marking a test that requires datasets."""
if not _datasets_available:
if not is_datasets_available():
return unittest.skip("test requires `datasets`")(test_case)
else:
return test_case
@@ -335,7 +335,7 @@ def require_datasets(test_case):
def require_faiss(test_case):
"""Decorator marking a test that requires faiss."""
if not _faiss_available:
if not is_faiss_available():
return unittest.skip("test requires `faiss`")(test_case)
else:
return test_case
@@ -348,7 +348,7 @@ def require_optuna(test_case):
These tests are skipped when optuna isn't installed.
"""
if not _has_optuna:
if not is_optuna_available():
return unittest.skip("test requires optuna")(test_case)
else:
return test_case
@@ -361,7 +361,7 @@ def require_ray(test_case):
These tests are skipped when Ray/tune isn't installed.
"""
if not _has_ray:
if not is_ray_available():
return unittest.skip("test requires Ray/tune")(test_case)
else:
return test_case
@@ -371,11 +371,11 @@ def get_gpu_count():
"""
Return the number of available gpus (regardless of whether torch or tf is used)
"""
if _torch_available:
if is_torch_available():
import torch
return torch.cuda.device_count()
elif _tf_available:
elif is_tf_available():
import tensorflow as tf
return len(tf.config.list_physical_devices("GPU"))
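
With the private _x_available flags gone, these decorators are plain wrappers around the public is_x_available() helpers, which makes it easy to build project-specific ones in the same shape. A hypothetical require_psutil (not part of testing_utils; it assumes the is_psutil_available helper shown in file_utils.py above):

import unittest

from transformers.file_utils import is_psutil_available


def require_psutil(test_case):
    """Skip the decorated test when psutil is not installed."""
    if not is_psutil_available():
        return unittest.skip("test requires psutil")(test_case)
    return test_case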

File: src/transformers/trainer.py

@@ -25,7 +25,7 @@ import shutil
import time
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
# Integrations must be imported before ML frameworks:
@@ -143,12 +143,6 @@ if is_mlflow_available():
DEFAULT_CALLBACKS.append(MLflowCallback)
if is_optuna_available():
import optuna
if is_ray_tune_available():
from ray import tune
if is_azureml_available():
from .integrations import AzureMLCallback
@@ -159,6 +153,10 @@ if is_fairscale_available():
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler
if TYPE_CHECKING:
import optuna
logger = logging.get_logger(__name__)
@@ -611,15 +609,21 @@ class Trainer:
return
self.objective = self.compute_objective(metrics.copy())
if self.hp_search_backend == HPSearchBackend.OPTUNA:
import optuna
trial.report(self.objective, epoch)
if trial.should_prune():
raise optuna.TrialPruned()
elif self.hp_search_backend == HPSearchBackend.RAY:
from ray import tune
if self.state.global_step % self.args.save_steps == 0:
self._tune_save_checkpoint()
tune.report(objective=self.objective, **metrics)
def _tune_save_checkpoint(self):
from ray import tune
if not self.use_tune_checkpoints:
return
with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
@@ -981,7 +985,12 @@ class Trainer:
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
if self.hp_search_backend is not None and trial is not None:
run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id()
if self.hp_search_backend == HPSearchBackend.OPTUNA:
run_id = trial.number
else:
from ray import tune
run_id = tune.get_trial_id()
run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
output_dir = os.path.join(self.args.output_dir, run_name, checkpoint_folder)
else:
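
Moving the optuna import under typing.TYPE_CHECKING keeps it out of the runtime import graph while still giving static type checkers the name; the price is that runtime annotations referring to it must be strings (or postponed via from __future__ import annotations). The pattern in isolation, with a hypothetical function:

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    import optuna  # visible to type checkers only; never executed at runtime


def report_objective(trial: Optional["optuna.Trial"], value: float) -> None:
    # The string annotation resolves against the TYPE_CHECKING import.
    if trial is not None:
        trial.report(value, step=0)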

File: tests/test_tokenization_camembert.py

@@ -18,14 +18,15 @@ import os
import unittest
from transformers import CamembertTokenizer, CamembertTokenizerFast
from transformers.testing_utils import _torch_available, require_sentencepiece, require_tokenizers
from transformers.file_utils import is_torch_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers
from .test_tokenization_common import TokenizerTesterMixin
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
FRAMEWORK = "pt" if _torch_available else "tf"
FRAMEWORK = "pt" if is_torch_available() else "tf"
@require_sentencepiece

File: tests/test_tokenization_marian.py

@@ -21,10 +21,11 @@ from pathlib import Path
from shutil import copyfile
from transformers import BatchEncoding, MarianTokenizer
from transformers.testing_utils import _sentencepiece_available, _torch_available, require_sentencepiece
from transformers.file_utils import is_sentencepiece_available, is_torch_available
from transformers.testing_utils import require_sentencepiece
if _sentencepiece_available:
if is_sentencepiece_available():
from transformers.models.marian.tokenization_marian import save_json, vocab_files_names
from .test_tokenization_common import TokenizerTesterMixin
@@ -35,7 +36,7 @@ SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/t
mock_tokenizer_config = {"target_lang": "fi", "source_lang": "en"}
zh_code = ">>zh<<"
ORG_NAME = "Helsinki-NLP/"
FRAMEWORK = "pt" if _torch_available else "tf"
FRAMEWORK = "pt" if is_torch_available() else "tf"
@require_sentencepiece

File: tests/test_tokenization_mbart.py

@@ -16,17 +16,13 @@ import tempfile
import unittest
from transformers import SPIECE_UNDERLINE, BatchEncoding, MBartTokenizer, MBartTokenizerFast, is_torch_available
from transformers.testing_utils import (
_sentencepiece_available,
require_sentencepiece,
require_tokenizers,
require_torch,
)
from transformers.file_utils import is_sentencepiece_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch
from .test_tokenization_common import TokenizerTesterMixin
if _sentencepiece_available:
if is_sentencepiece_available():
from .test_tokenization_xlm_roberta import SAMPLE_VOCAB

File: tests/test_tokenization_t5.py

@@ -17,15 +17,15 @@
import unittest
from transformers import SPIECE_UNDERLINE, BatchEncoding, T5Tokenizer, T5TokenizerFast
from transformers.file_utils import cached_property
from transformers.testing_utils import _torch_available, get_tests_dir, require_sentencepiece, require_tokenizers
from transformers.file_utils import cached_property, is_torch_available
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers
from .test_tokenization_common import TokenizerTesterMixin
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
FRAMEWORK = "pt" if _torch_available else "tf"
FRAMEWORK = "pt" if is_torch_available() else "tf"
@require_sentencepiece