diff --git a/.github/workflows/github-torch-hub.yml b/.github/workflows/github-torch-hub.yml index 93b9c777bfe..71bdbe1da64 100644 --- a/.github/workflows/github-torch-hub.yml +++ b/.github/workflows/github-torch-hub.yml @@ -33,7 +33,7 @@ jobs: run: | pip install --upgrade pip pip install torch - pip install numpy filelock protobuf requests tqdm regex sentencepiece sacremoses tokenizers packaging + pip install numpy filelock protobuf requests tqdm regex sentencepiece sacremoses tokenizers packaging importlib_metadata - name: Torch hub list run: | diff --git a/hubconf.py b/hubconf.py index f9970e1a5f8..84924438210 100644 --- a/hubconf.py +++ b/hubconf.py @@ -30,7 +30,7 @@ from transformers import ( ) -dependencies = ["torch", "numpy", "tokenizers", "filelock", "requests", "tqdm", "regex", "sentencepiece", "sacremoses"] +dependencies = ["torch", "numpy", "tokenizers", "filelock", "requests", "tqdm", "regex", "sentencepiece", "sacremoses", "importlib_metadata"] @add_start_docstrings(AutoConfig.__doc__) diff --git a/setup.py b/setup.py index 860cb8e4a12..e6a7d89b45a 100644 --- a/setup.py +++ b/setup.py @@ -99,6 +99,7 @@ _deps = [ "flake8>=3.8.3", "flax>=0.2.2", "fugashi>=1.0", + "importlib_metadata", "ipadic>=1.0.0,<2.0", "isort>=5.5.4", "jax>=0.2.0", @@ -232,6 +233,7 @@ extras["dev"] = ( # when modifying the following list, make sure to update src/transformers/dependency_versions_check.py install_requires = [ deps["dataclasses"] + ";python_version<'3.7'", # dataclasses for Python versions that don't have it + deps["importlib_metadata"] + ";python_version<'3.8'", # importlib_metadata for Python versions that don't have it deps["filelock"], # filesystem locks, e.g., to prevent parallel downloads deps["numpy"], deps["packaging"], # utilities from PyPA to e.g., compare versions diff --git a/src/transformers/dependency_versions_check.py b/src/transformers/dependency_versions_check.py index 76ee9a1e810..7e36aaef309 100644 --- a/src/transformers/dependency_versions_check.py +++ b/src/transformers/dependency_versions_check.py @@ -26,6 +26,8 @@ from .utils.versions import require_version_core pkgs_to_check_at_runtime = "python tqdm regex sacremoses requests packaging filelock numpy tokenizers".split() if sys.version_info < (3, 7): pkgs_to_check_at_runtime.append("dataclasses") +if sys.version_info < (3, 8): + pkgs_to_check_at_runtime.append("importlib_metadata") for pkg in pkgs_to_check_at_runtime: if pkg in deps: diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index b07c53058ff..69fc388f930 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -12,6 +12,7 @@ deps = { "flake8": "flake8>=3.8.3", "flax": "flax>=0.2.2", "fugashi": "fugashi>=1.0", + "importlib_metadata": "importlib_metadata", "ipadic": "ipadic>=1.0.0,<2.0", "isort": "isort>=5.5.4", "jax": "jax>=0.2.0", diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 546d0e819bb..1f765c1a670 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -18,6 +18,7 @@ https://github.com/allenai/allennlp. 
import copy import fnmatch +import importlib.util import io import json import os @@ -37,8 +38,10 @@ from urllib.parse import urlparse from zipfile import ZipFile, is_zipfile import numpy as np +from packaging import version from tqdm.auto import tqdm +import importlib_metadata import requests from filelock import FileLock @@ -52,195 +55,88 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name ENV_VARS_TRUE_VALUES = {"1", "ON", "YES"} ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) +USE_TF = os.environ.get("USE_TF", "AUTO").upper() +USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() +USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() + +if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: + _torch_available = importlib.util.find_spec("torch") is not None + if _torch_available: + try: + _torch_version = importlib_metadata.version("torch") + logger.info(f"PyTorch version {_torch_version} available.") + except importlib_metadata.PackageNotFoundError: + _torch_available = False +else: + logger.info("Disabling PyTorch because USE_TF is set") + _torch_available = False + + +if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES: + _tf_available = importlib.util.find_spec("tensorflow") is not None + if _tf_available: + # For the metadata, we have to look for both tensorflow and tensorflow-cpu + try: + _tf_version = importlib_metadata.version("tensorflow") + except importlib_metadata.PackageNotFoundError: + try: + _tf_version = importlib_metadata.version("tensorflow-cpu") + except importlib_metadata.PackageNotFoundError: + _tf_version = None + _tf_available = False + if _tf_available: + if version.parse(_tf_version) < version.parse("2"): + logger.info(f"TensorFlow found but with version {_tf_version}. 
Transformers requires version 2 minimum.") + _tf_available = False + else: + logger.info(f"TensorFlow version {_tf_version} available.") +else: + logger.info("Disabling Tensorflow because USE_TORCH is set") + _tf_available = False + + +if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: + _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None + if _flax_available: + try: + _jax_version = importlib_metadata.version("jax") + _flax_version = importlib_metadata.version("flax") + logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.") + except importlib_metadata.PackageNotFoundError: + _flax_available = False +else: + _flax_available = False + + +_datasets_available = importlib.util.find_spec("datasets") is not None try: - USE_TF = os.environ.get("USE_TF", "AUTO").upper() - USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() - if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: - import torch - - _torch_available = True # pylint: disable=invalid-name - logger.info("PyTorch version {} available.".format(torch.__version__)) - else: - logger.info("Disabling PyTorch because USE_TF is set") - _torch_available = False -except ImportError: - _torch_available = False # pylint: disable=invalid-name - -try: - USE_TF = os.environ.get("USE_TF", "AUTO").upper() - USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() - - if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES: - import tensorflow as tf - - assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2 - _tf_available = True # pylint: disable=invalid-name - logger.info("TensorFlow version {} available.".format(tf.__version__)) - else: - logger.info("Disabling Tensorflow because USE_TORCH is set") - _tf_available = False -except (ImportError, AssertionError): - _tf_available = False # pylint: disable=invalid-name - - -try: - USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() - - if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: - import flax - import jax - - logger.info("JAX version {}, Flax: available".format(jax.__version__)) - logger.info("Flax available: {}".format(flax)) - _flax_available = True - else: - _flax_available = False -except ImportError: - _flax_available = False # pylint: disable=invalid-name - - -try: - import datasets # noqa: F401 - - # Check we're not importing a "datasets" directory somewhere - _datasets_available = hasattr(datasets, "__version__") and hasattr(datasets, "load_dataset") - if _datasets_available: - logger.debug(f"Successfully imported datasets version {datasets.__version__}") - else: - logger.debug("Imported a datasets object but this doesn't seem to be the 🤗 datasets library.") - -except ImportError: + # Check we're not importing a "datasets" directory somewhere but the actual library by trying to grab the version + # AND checking it has an author field in the metadata that is HuggingFace. 
+    _ = importlib_metadata.version("datasets")
+    _datasets_metadata = importlib_metadata.metadata("datasets")
+    if _datasets_metadata.get("author", "") != "HuggingFace Inc.":
+        _datasets_available = False
+except importlib_metadata.PackageNotFoundError:
     _datasets_available = False
 
+
+_faiss_available = importlib.util.find_spec("faiss") is not None
 try:
-    from torch.hub import _get_torch_home
-
-    torch_cache_home = _get_torch_home()
-except ImportError:
-    torch_cache_home = os.path.expanduser(
-        os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
-    )
-
-
-try:
-    import torch_xla.core.xla_model as xm  # noqa: F401
-
-    if _torch_available:
-        _torch_tpu_available = True  # pylint: disable=
-    else:
-        _torch_tpu_available = False
-except ImportError:
-    _torch_tpu_available = False
-
-
-try:
-    import psutil  # noqa: F401
-
-    _psutil_available = True
-
-except ImportError:
-    _psutil_available = False
-
-
-try:
-    import py3nvml  # noqa: F401
-
-    _py3nvml_available = True
-
-except ImportError:
-    _py3nvml_available = False
-
-
-try:
-    from apex import amp  # noqa: F401
-
-    _has_apex = True
-except ImportError:
-    _has_apex = False
-
-
-try:
-    import faiss  # noqa: F401
-
-    _faiss_available = True
-    logger.debug(f"Successfully imported faiss version {faiss.__version__}")
-except ImportError:
+    _faiss_version = importlib_metadata.version("faiss")
+    logger.debug(f"Successfully imported faiss version {_faiss_version}")
+except importlib_metadata.PackageNotFoundError:
     _faiss_available = False
 
+
+_scatter_available = importlib.util.find_spec("torch_scatter") is not None
 try:
-    import sklearn.metrics  # noqa: F401
-
-    import scipy.stats  # noqa: F401
-
-    _has_sklearn = True
-except (AttributeError, ImportError):
-    _has_sklearn = False
-
-try:
-    # Test copied from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
-    get_ipython = sys.modules["IPython"].get_ipython
-    if "IPKernelApp" not in get_ipython().config:
-        raise ImportError("console")
-    if "VSCODE_PID" in os.environ:
-        raise ImportError("vscode")
-
-    import IPython  # noqa: F401
-
-    _in_notebook = True
-except (AttributeError, ImportError, KeyError):
-    _in_notebook = False
-
-
-try:
-    import sentencepiece  # noqa: F401
-
-    _sentencepiece_available = True
-
-except ImportError:
-    _sentencepiece_available = False
-
-
-try:
-    import google.protobuf  # noqa: F401
-
-    _protobuf_available = True
-
-except ImportError:
-    _protobuf_available = False
-
-
-try:
-    import tokenizers  # noqa: F401
-
-    _tokenizers_available = True
-
-except ImportError:
-    _tokenizers_available = False
-
-
-try:
-    import pandas  # noqa: F401
-
-    _pandas_available = True
-
-except ImportError:
-    _pandas_available = False
-
-
-try:
-    import torch_scatter
-
-    # Check we're not importing a "torch_scatter" directory somewhere
-    _scatter_available = hasattr(torch_scatter, "__version__") and hasattr(torch_scatter, "scatter")
-    if _scatter_available:
-        logger.debug(f"Succesfully imported torch-scatter version {torch_scatter.__version__}")
-    else:
-        logger.debug("Imported a torch_scatter object but this doesn't seem to be the torch-scatter library.")
-
-except ImportError:
+    _scatter_version = importlib_metadata.version("torch_scatter")
+    logger.debug(f"Successfully imported torch-scatter version {_scatter_version}")
+except importlib_metadata.PackageNotFoundError:
     _scatter_available = False
 
+torch_cache_home = os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
 old_default_cache_path = 
os.path.join(torch_cache_home, "transformers") # New default cache, shared with the Datasets library hf_cache_home = os.path.expanduser( @@ -308,7 +204,14 @@ def is_flax_available(): def is_torch_tpu_available(): - return _torch_tpu_available + if not _torch_available: + return False + # This test is probably enough, but just in case, we unpack a bit. + if importlib.util.find_spec("torch_xla") is None: + return False + if importlib.util.find_spec("torch_xla.core") is None: + return False + return importlib.util.find_spec("torch_xla.core.xla_model") is not None def is_datasets_available(): @@ -316,15 +219,15 @@ def is_datasets_available(): def is_psutil_available(): - return _psutil_available + return importlib.util.find_spec("psutil") is not None def is_py3nvml_available(): - return _py3nvml_available + return importlib.util.find_spec("py3nvml") is not None def is_apex_available(): - return _has_apex + return importlib.util.find_spec("apex") is not None def is_faiss_available(): @@ -332,23 +235,39 @@ def is_faiss_available(): def is_sklearn_available(): - return _has_sklearn + if importlib.util.find_spec("sklearn") is None: + return False + if importlib.util.find_spec("scipy") is None: + return False + return importlib.util.find_spec("sklearn.metrics") and importlib.util.find_spec("scipy.stats") def is_sentencepiece_available(): - return _sentencepiece_available + return importlib.util.find_spec("sentencepiece") is not None def is_protobuf_available(): - return _protobuf_available + if importlib.util.find_spec("google") is None: + return False + return importlib.util.find_spec("google.protobuf") is not None def is_tokenizers_available(): - return _tokenizers_available + return importlib.util.find_spec("tokenizers") is not None def is_in_notebook(): - return _in_notebook + try: + # Test adapted from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py + get_ipython = sys.modules["IPython"].get_ipython + if "IPKernelApp" not in get_ipython().config: + raise ImportError("console") + if "VSCODE_PID" in os.environ: + raise ImportError("vscode") + + return importlib.util.find_spec("IPython") is not None + except (AttributeError, ImportError, KeyError): + return False def is_scatter_available(): @@ -356,7 +275,7 @@ def is_scatter_available(): def is_pandas_available(): - return _pandas_available + return importlib.util.find_spec("pandas") is not None def torch_only_method(fn): @@ -1167,9 +1086,9 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str: """ ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0]) if is_torch_available(): - ua += "; torch/{}".format(torch.__version__) + ua += f"; torch/{_torch_version}" if is_tf_available(): - ua += "; tensorflow/{}".format(tf.__version__) + ua += f"; tensorflow/{_tf_version}" if isinstance(user_agent, dict): ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items()) elif isinstance(user_agent, str): diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py index 4053582d3ae..db97827b04e 100644 --- a/src/transformers/integrations.py +++ b/src/transformers/integrations.py @@ -14,6 +14,7 @@ """ Integrations with other Python libraries. 
""" +import importlib.util import math import numbers import os @@ -21,107 +22,38 @@ import re import tempfile from pathlib import Path -from .file_utils import ENV_VARS_TRUE_VALUES -from .trainer_utils import EvaluationStrategy from .utils import logging logger = logging.get_logger(__name__) -# Import 3rd-party integrations before ML frameworks: +# comet_ml requires to be imported before any ML frameworks +_has_comet = importlib.util.find_spec("comet_ml") and os.getenv("COMET_MODE", "").upper() != "DISABLED" +if _has_comet: + try: + import comet_ml # noqa: F401 -try: - # Comet needs to be imported before any ML frameworks - import comet_ml # noqa: F401 - - if hasattr(comet_ml, "config") and comet_ml.config.get_config("comet.api_key"): - _has_comet = True - else: - if os.getenv("COMET_MODE", "").upper() != "DISABLED": - logger.warning("comet_ml is installed but `COMET_API_KEY` is not set.") + if hasattr(comet_ml, "config") and comet_ml.config.get_config("comet.api_key"): + _has_comet = True + else: + if os.getenv("COMET_MODE", "").upper() != "DISABLED": + logger.warning("comet_ml is installed but `COMET_API_KEY` is not set.") + _has_comet = False + except (ImportError, ValueError): _has_comet = False -except (ImportError, ValueError): - _has_comet = False -try: - import wandb - wandb.ensure_configured() - if wandb.api.api_key is None: - _has_wandb = False - if os.getenv("WANDB_DISABLED"): - logger.warning("W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.") - else: - _has_wandb = False if os.getenv("WANDB_DISABLED") else True -except (ImportError, AttributeError): - _has_wandb = False - -try: - import optuna # noqa: F401 - - _has_optuna = True -except (ImportError): - _has_optuna = False - -try: - import ray # noqa: F401 - - _has_ray = True - try: - # Ray Tune has additional dependencies. 
- from ray import tune # noqa: F401 - - _has_ray_tune = True - except (ImportError): - _has_ray_tune = False -except (ImportError): - _has_ray = False - _has_ray_tune = False - -try: - from torch.utils.tensorboard import SummaryWriter # noqa: F401 - - _has_tensorboard = True -except ImportError: - try: - from tensorboardX import SummaryWriter # noqa: F401 - - _has_tensorboard = True - except ImportError: - _has_tensorboard = False - -try: - from azureml.core.run import Run # noqa: F401 - - _has_azureml = True -except ImportError: - _has_azureml = False - -try: - import mlflow # noqa: F401 - - _has_mlflow = True -except ImportError: - _has_mlflow = False - -try: - import fairscale # noqa: F401 - - _has_fairscale = True -except ImportError: - _has_fairscale = False - -# No transformer imports above this point - -from .file_utils import is_torch_tpu_available # noqa: E402 +from .file_utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available # noqa: E402 from .trainer_callback import TrainerCallback # noqa: E402 -from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun # noqa: E402 +from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, EvaluationStrategy # noqa: E402 # Integration functions: def is_wandb_available(): - return _has_wandb + if os.getenv("WANDB_DISABLED"): + return False + return importlib.util.find_spec("wandb") is not None def is_comet_available(): @@ -129,35 +61,43 @@ def is_comet_available(): def is_tensorboard_available(): - return _has_tensorboard + return importlib.util.find_spec("tensorboard") is not None or importlib.util.find_spec("tensorboardX") is not None def is_optuna_available(): - return _has_optuna + return importlib.util.find_spec("optuna") is not None def is_ray_available(): - return _has_ray + return importlib.util.find_spec("ray") is not None def is_ray_tune_available(): - return _has_ray_tune + if not is_ray_available(): + return False + return importlib.util.find_spec("ray.tune") is not None def is_azureml_available(): - return _has_azureml + if importlib.util.find_spec("azureml") is None: + return False + if importlib.util.find_spec("azureml.core") is None: + return False + return importlib.util.find_spec("azureml.core.run") is not None def is_mlflow_available(): - return _has_mlflow + return importlib.util.find_spec("mlflow") is not None def is_fairscale_available(): - return _has_fairscale + return importlib.util.find_spec("fairscale") is not None def hp_params(trial): if is_optuna_available(): + import optuna + if isinstance(trial, optuna.Trial): return trial.params if is_ray_tune_available(): @@ -175,6 +115,8 @@ def default_hp_search_backend(): def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun: + import optuna + def _objective(trial, checkpoint_dir=None): model_path = None if checkpoint_dir: @@ -198,6 +140,8 @@ def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> Be def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun: + import ray + def _objective(trial, checkpoint_dir=None): model_path = None if checkpoint_dir: @@ -297,14 +241,29 @@ class TensorBoardCallback(TrainerCallback): """ def __init__(self, tb_writer=None): + has_tensorboard = is_tensorboard_available() assert ( - _has_tensorboard + has_tensorboard ), "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or install tensorboardX." 
+        if has_tensorboard:
+            try:
+                from torch.utils.tensorboard import SummaryWriter  # noqa: F401
+
+                self._SummaryWriter = SummaryWriter
+            except ImportError:
+                try:
+                    from tensorboardX import SummaryWriter
+
+                    self._SummaryWriter = SummaryWriter
+                except ImportError:
+                    self._SummaryWriter = None
         self.tb_writer = tb_writer
 
     def _init_summary_writer(self, args, log_dir=None):
         log_dir = log_dir or args.logging_dir
-        self.tb_writer = SummaryWriter(log_dir=log_dir)
+        if self._SummaryWriter is not None:
+            self.tb_writer = self._SummaryWriter(log_dir=log_dir)
 
     def on_train_begin(self, args, state, control, **kwargs):
         if not state.is_world_process_zero:
@@ -335,7 +294,7 @@ class TensorBoardCallback(TrainerCallback):
         if self.tb_writer is None:
             self._init_summary_writer(args)
 
-        if self.tb_writer:
+        if self.tb_writer is not None:
             logs = rewrite_logs(logs)
             for k, v in logs.items():
                 if isinstance(v, (int, float)):
@@ -363,7 +322,20 @@ class WandbCallback(TrainerCallback):
     """
 
     def __init__(self):
-        assert _has_wandb, "WandbCallback requires wandb to be installed. Run `pip install wandb`."
+        has_wandb = is_wandb_available()
+        assert has_wandb, "WandbCallback requires wandb to be installed. Run `pip install wandb`."
+        if has_wandb:
+            import wandb
+
+            wandb.ensure_configured()
+            if wandb.api.api_key is None:
+                logger.warning(
+                    "W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable."
+                )
+                self._wandb = None
+            else:
+                self._wandb = wandb
+        else:
+            self._wandb = None
         self._initialized = False
 
     def setup(self, args, state, model, reinit, **kwargs):
@@ -384,6 +356,8 @@ class WandbCallback(TrainerCallback):
             WANDB_DISABLED (:obj:`bool`, `optional`, defaults to :obj:`False`):
                 Whether or not to disable wandb entirely. 
""" + if self._wandb is None: + return self._initialized = True if state.is_world_process_zero: logger.info( @@ -402,7 +376,7 @@ class WandbCallback(TrainerCallback): else: run_name = args.run_name - wandb.init( + self._wandb.init( project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=run_name, @@ -412,19 +386,25 @@ class WandbCallback(TrainerCallback): # keep track of model topology and gradients, unsupported on TPU if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false": - wandb.watch(model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps)) + self._wandb.watch( + model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps) + ) # log outputs self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"}) def on_train_begin(self, args, state, control, model=None, **kwargs): + if self._wandb is None: + return hp_search = state.is_hyper_param_search if not self._initialized or hp_search: self.setup(args, state, model, reinit=hp_search, **kwargs) def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs): + if self._wandb is None: + return # commit last step - wandb.log({}) + self._wandb.log({}) if self._log_model and self._initialized and state.is_world_process_zero: from .trainer import Trainer @@ -432,11 +412,11 @@ class WandbCallback(TrainerCallback): with tempfile.TemporaryDirectory() as temp_dir: fake_trainer.save_model(temp_dir) # use run name and ensure it's a valid Artifact name - artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", wandb.run.name) + artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", self._wandb.run.name) metadata = ( { k: v - for k, v in dict(wandb.summary).items() + for k, v in dict(self._wandb.summary).items() if isinstance(v, numbers.Number) and not k.startswith("_") } if not args.load_best_model_at_end @@ -445,19 +425,21 @@ class WandbCallback(TrainerCallback): "train/total_floss": state.total_flos, } ) - artifact = wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata) + artifact = self._wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata) for f in Path(temp_dir).glob("*"): if f.is_file(): with artifact.new_file(f.name, mode="wb") as fa: fa.write(f.read_bytes()) - wandb.run.log_artifact(artifact) + self._wandb.run.log_artifact(artifact) def on_log(self, args, state, control, model=None, logs=None, **kwargs): + if self._wandb is None: + return if not self._initialized: self.setup(args, state, model, reinit=False) if state.is_world_process_zero: logs = rewrite_logs(logs) - wandb.log(logs, step=state.global_step) + self._wandb.log(logs, step=state.global_step) class CometCallback(TrainerCallback): @@ -522,10 +504,14 @@ class AzureMLCallback(TrainerCallback): """ def __init__(self, azureml_run=None): - assert _has_azureml, "AzureMLCallback requires azureml to be installed. Run `pip install azureml-sdk`." + assert ( + is_azureml_available() + ), "AzureMLCallback requires azureml to be installed. Run `pip install azureml-sdk`." self.azureml_run = azureml_run def on_init_end(self, args, state, control, **kwargs): + from azureml.core.run import Run + if self.azureml_run is None and state.is_world_process_zero: self.azureml_run = Run.get_context() @@ -544,9 +530,12 @@ class MLflowCallback(TrainerCallback): MAX_LOG_SIZE = 100 def __init__(self): - assert _has_mlflow, "MLflowCallback requires mlflow to be installed. Run `pip install mlflow`." 
+ assert is_mlflow_available(), "MLflowCallback requires mlflow to be installed. Run `pip install mlflow`." + import mlflow + self._initialized = False self._log_artifacts = False + self._ml_flow = mlflow def setup(self, args, state, model): """ @@ -564,7 +553,7 @@ class MLflowCallback(TrainerCallback): if log_artifacts in {"TRUE", "1"}: self._log_artifacts = True if state.is_world_process_zero: - mlflow.start_run() + self._ml_flow.start_run() combined_dict = args.to_dict() if hasattr(model, "config") and model.config is not None: model_config = model.config.to_dict() @@ -572,7 +561,7 @@ class MLflowCallback(TrainerCallback): # MLflow cannot log more than 100 values in one go, so we have to split it combined_dict_items = list(combined_dict.items()) for i in range(0, len(combined_dict_items), MLflowCallback.MAX_LOG_SIZE): - mlflow.log_params(dict(combined_dict_items[i : i + MLflowCallback.MAX_LOG_SIZE])) + self._ml_flow.log_params(dict(combined_dict_items[i : i + MLflowCallback.MAX_LOG_SIZE])) self._initialized = True def on_train_begin(self, args, state, control, model=None, **kwargs): @@ -585,7 +574,7 @@ class MLflowCallback(TrainerCallback): if state.is_world_process_zero: for k, v in logs.items(): if isinstance(v, (int, float)): - mlflow.log_metric(k, v, step=state.global_step) + self._ml_flow.log_metric(k, v, step=state.global_step) else: logger.warning( "Trainer is attempting to log a value of " @@ -601,11 +590,11 @@ class MLflowCallback(TrainerCallback): if self._initialized and state.is_world_process_zero: if self._log_artifacts: logger.info("Logging artifacts. This may take time.") - mlflow.log_artifacts(args.output_dir) - mlflow.end_run() + self._ml_flow.log_artifacts(args.output_dir) + self._ml_flow.end_run() def __del__(self): # if the previous run is not terminated correctly, the fluent API will # not let you start a new run before the previous one is killed - if mlflow.active_run is not None: - mlflow.end_run(status="KILLED") + if self._ml_flow.active_run is not None: + self._ml_flow.end_run(status="KILLED") diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index aa92b1771ea..86ffb9f189e 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -25,18 +25,18 @@ from io import StringIO from pathlib import Path from .file_utils import ( - _datasets_available, - _faiss_available, - _flax_available, - _pandas_available, - _scatter_available, - _sentencepiece_available, - _tf_available, - _tokenizers_available, - _torch_available, - _torch_tpu_available, + is_datasets_available, + is_faiss_available, + is_flax_available, + is_pandas_available, + is_scatter_available, + is_sentencepiece_available, + is_tf_available, + is_tokenizers_available, + is_torch_available, + is_torch_tpu_available, ) -from .integrations import _has_optuna, _has_ray +from .integrations import is_optuna_available, is_ray_available SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy" @@ -90,7 +90,7 @@ def is_pt_tf_cross_test(test_case): to a truthy value and selecting the is_pt_tf_cross_test pytest mark. """ - if not _run_pt_tf_cross_tests or not _torch_available or not _tf_available: + if not _run_pt_tf_cross_tests or not is_torch_available() or not is_tf_available(): return unittest.skip("test is PT+TF test")(test_case) else: try: @@ -166,7 +166,7 @@ def require_torch(test_case): These tests are skipped when PyTorch isn't installed. 
""" - if not _torch_available: + if not is_torch_available(): return unittest.skip("test requires PyTorch")(test_case) else: return test_case @@ -179,7 +179,7 @@ def require_torch_scatter(test_case): These tests are skipped when PyTorch scatter isn't installed. """ - if not _scatter_available: + if not is_scatter_available(): return unittest.skip("test requires PyTorch scatter")(test_case) else: return test_case @@ -192,7 +192,7 @@ def require_tf(test_case): These tests are skipped when TensorFlow isn't installed. """ - if not _tf_available: + if not is_tf_available(): return unittest.skip("test requires TensorFlow")(test_case) else: return test_case @@ -205,7 +205,7 @@ def require_flax(test_case): These tests are skipped when one / both are not installed """ - if not _flax_available: + if not is_flax_available(): test_case = unittest.skip("test requires JAX & Flax")(test_case) return test_case @@ -217,7 +217,7 @@ def require_sentencepiece(test_case): These tests are skipped when SentencePiece isn't installed. """ - if not _sentencepiece_available: + if not is_sentencepiece_available(): return unittest.skip("test requires SentencePiece")(test_case) else: return test_case @@ -230,7 +230,7 @@ def require_tokenizers(test_case): These tests are skipped when 🤗 Tokenizers isn't installed. """ - if not _tokenizers_available: + if not is_tokenizers_available(): return unittest.skip("test requires tokenizers")(test_case) else: return test_case @@ -240,7 +240,7 @@ def require_pandas(test_case): """ Decorator marking a test that requires pandas. These tests are skipped when pandas isn't installed. """ - if not _pandas_available: + if not is_pandas_available(): return unittest.skip("test requires pandas")(test_case) else: return test_case @@ -251,7 +251,7 @@ def require_scatter(test_case): Decorator marking a test that requires PyTorch Scatter. These tests are skipped when PyTorch Scatter isn't installed. """ - if not _scatter_available: + if not is_scatter_available(): return unittest.skip("test requires PyTorch Scatter")(test_case) else: return test_case @@ -265,7 +265,7 @@ def require_torch_multi_gpu(test_case): To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests -k "multi_gpu" """ - if not _torch_available: + if not is_torch_available(): return unittest.skip("test requires PyTorch")(test_case) import torch @@ -280,7 +280,7 @@ def require_torch_non_multi_gpu(test_case): """ Decorator marking a test that requires 0 or 1 GPU setup (in PyTorch). """ - if not _torch_available: + if not is_torch_available(): return unittest.skip("test requires PyTorch")(test_case) import torch @@ -301,13 +301,13 @@ def require_torch_tpu(test_case): """ Decorator marking a test that requires a TPU (in PyTorch). 
""" - if not _torch_tpu_available: + if not is_torch_tpu_available(): return unittest.skip("test requires PyTorch TPU") else: return test_case -if _torch_available: +if is_torch_available(): # Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode import torch @@ -327,7 +327,7 @@ def require_torch_gpu(test_case): def require_datasets(test_case): """Decorator marking a test that requires datasets.""" - if not _datasets_available: + if not is_datasets_available(): return unittest.skip("test requires `datasets`")(test_case) else: return test_case @@ -335,7 +335,7 @@ def require_datasets(test_case): def require_faiss(test_case): """Decorator marking a test that requires faiss.""" - if not _faiss_available: + if not is_faiss_available(): return unittest.skip("test requires `faiss`")(test_case) else: return test_case @@ -348,7 +348,7 @@ def require_optuna(test_case): These tests are skipped when optuna isn't installed. """ - if not _has_optuna: + if not is_optuna_available(): return unittest.skip("test requires optuna")(test_case) else: return test_case @@ -361,7 +361,7 @@ def require_ray(test_case): These tests are skipped when Ray/tune isn't installed. """ - if not _has_ray: + if not is_ray_available(): return unittest.skip("test requires Ray/tune")(test_case) else: return test_case @@ -371,11 +371,11 @@ def get_gpu_count(): """ Return the number of available gpus (regardless of whether torch or tf is used) """ - if _torch_available: + if is_torch_available(): import torch return torch.cuda.device_count() - elif _tf_available: + elif is_tf_available(): import tensorflow as tf return len(tf.config.list_physical_devices("GPU")) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 0c0f8ed9fc4..effa50b5a92 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -25,7 +25,7 @@ import shutil import time import warnings from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union # Integrations must be imported before ML frameworks: @@ -143,12 +143,6 @@ if is_mlflow_available(): DEFAULT_CALLBACKS.append(MLflowCallback) -if is_optuna_available(): - import optuna - -if is_ray_tune_available(): - from ray import tune - if is_azureml_available(): from .integrations import AzureMLCallback @@ -159,6 +153,10 @@ if is_fairscale_available(): from fairscale.optim import OSS from fairscale.optim.grad_scaler import ShardedGradScaler + +if TYPE_CHECKING: + import optuna + logger = logging.get_logger(__name__) @@ -611,15 +609,21 @@ class Trainer: return self.objective = self.compute_objective(metrics.copy()) if self.hp_search_backend == HPSearchBackend.OPTUNA: + import optuna + trial.report(self.objective, epoch) if trial.should_prune(): raise optuna.TrialPruned() elif self.hp_search_backend == HPSearchBackend.RAY: + from ray import tune + if self.state.global_step % self.args.save_steps == 0: self._tune_save_checkpoint() tune.report(objective=self.objective, **metrics) def _tune_save_checkpoint(self): + from ray import tune + if not self.use_tune_checkpoints: return with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir: @@ -981,7 +985,12 @@ class Trainer: checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" if self.hp_search_backend is not None and trial is not None: - run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id() + if self.hp_search_backend == 
HPSearchBackend.OPTUNA: + run_id = trial.number + else: + from ray import tune + + run_id = tune.get_trial_id() run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}" output_dir = os.path.join(self.args.output_dir, run_name, checkpoint_folder) else: diff --git a/tests/test_tokenization_camembert.py b/tests/test_tokenization_camembert.py index c64f91ede80..953806146db 100644 --- a/tests/test_tokenization_camembert.py +++ b/tests/test_tokenization_camembert.py @@ -18,14 +18,15 @@ import os import unittest from transformers import CamembertTokenizer, CamembertTokenizerFast -from transformers.testing_utils import _torch_available, require_sentencepiece, require_tokenizers +from transformers.file_utils import is_torch_available +from transformers.testing_utils import require_sentencepiece, require_tokenizers from .test_tokenization_common import TokenizerTesterMixin SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model") -FRAMEWORK = "pt" if _torch_available else "tf" +FRAMEWORK = "pt" if is_torch_available() else "tf" @require_sentencepiece diff --git a/tests/test_tokenization_marian.py b/tests/test_tokenization_marian.py index 7c50184f5c5..7f9e776a063 100644 --- a/tests/test_tokenization_marian.py +++ b/tests/test_tokenization_marian.py @@ -21,10 +21,11 @@ from pathlib import Path from shutil import copyfile from transformers import BatchEncoding, MarianTokenizer -from transformers.testing_utils import _sentencepiece_available, _torch_available, require_sentencepiece +from transformers.file_utils import is_sentencepiece_available, is_torch_available +from transformers.testing_utils import require_sentencepiece -if _sentencepiece_available: +if is_sentencepiece_available(): from transformers.models.marian.tokenization_marian import save_json, vocab_files_names from .test_tokenization_common import TokenizerTesterMixin @@ -35,7 +36,7 @@ SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/t mock_tokenizer_config = {"target_lang": "fi", "source_lang": "en"} zh_code = ">>zh<<" ORG_NAME = "Helsinki-NLP/" -FRAMEWORK = "pt" if _torch_available else "tf" +FRAMEWORK = "pt" if is_torch_available() else "tf" @require_sentencepiece diff --git a/tests/test_tokenization_mbart.py b/tests/test_tokenization_mbart.py index 5bd987eaba5..1376cd7e8bb 100644 --- a/tests/test_tokenization_mbart.py +++ b/tests/test_tokenization_mbart.py @@ -16,17 +16,13 @@ import tempfile import unittest from transformers import SPIECE_UNDERLINE, BatchEncoding, MBartTokenizer, MBartTokenizerFast, is_torch_available -from transformers.testing_utils import ( - _sentencepiece_available, - require_sentencepiece, - require_tokenizers, - require_torch, -) +from transformers.file_utils import is_sentencepiece_available +from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch from .test_tokenization_common import TokenizerTesterMixin -if _sentencepiece_available: +if is_sentencepiece_available(): from .test_tokenization_xlm_roberta import SAMPLE_VOCAB diff --git a/tests/test_tokenization_t5.py b/tests/test_tokenization_t5.py index 7ef4b931bf4..9fbd50eaf5e 100644 --- a/tests/test_tokenization_t5.py +++ b/tests/test_tokenization_t5.py @@ -17,15 +17,15 @@ import unittest from transformers import SPIECE_UNDERLINE, BatchEncoding, T5Tokenizer, T5TokenizerFast -from transformers.file_utils import cached_property -from transformers.testing_utils import _torch_available, get_tests_dir, require_sentencepiece, 
require_tokenizers +from transformers.file_utils import cached_property, is_torch_available +from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers from .test_tokenization_common import TokenizerTesterMixin SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model") -FRAMEWORK = "pt" if _torch_available else "tf" +FRAMEWORK = "pt" if is_torch_available() else "tf" @require_sentencepiece
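
Taken together, the file_utils.py changes replace every import-and-catch probe with the same two-step check. Below is a minimal, self-contained sketch of that pattern, using torch purely as an illustration; it summarizes the approach rather than reproducing code from the diff:

# Sketch of the detection pattern adopted throughout file_utils.py.
# importlib.util.find_spec() checks that a module is importable without
# actually importing it; importlib_metadata.version() (the PyPI backport of
# the stdlib importlib.metadata from Python 3.8+, which is why setup.py adds
# it only for python_version<'3.8') reads the installed version from package
# metadata and raises PackageNotFoundError when no distribution is installed.
import importlib.util

import importlib_metadata

_torch_available = importlib.util.find_spec("torch") is not None
if _torch_available:
    try:
        _torch_version = importlib_metadata.version("torch")
    except importlib_metadata.PackageNotFoundError:
        # A "torch" module is on sys.path but no installed distribution backs
        # it (e.g. a stray local directory), so treat it as unavailable.
        _torch_available = False


def is_torch_available():
    return _torch_available

The gain over the old try/except ImportError blocks is that merely importing transformers no longer imports torch, tensorflow, or any other optional backend as a side effect; the heavy import happens only where the framework is actually used.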
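
The integrations.py changes follow a second recurring shape: module-level imports of wandb, mlflow, and friends move into the callback constructors, and the imported module is cached on the instance so later hooks can bail out cheaply. A minimal sketch with a hypothetical LoggingCallback class (the class name and its on_log signature are illustrative, not part of this PR):

import importlib.util


class LoggingCallback:
    # Hypothetical class for illustration; not a class from this PR.
    def __init__(self):
        # Import once, at construction time, and keep a handle on the
        # instance instead of relying on a module-level import.
        if importlib.util.find_spec("wandb") is not None:
            import wandb

            self._wandb = wandb
        else:
            self._wandb = None

    def on_log(self, logs):
        # Each hook no-ops when the integration is missing, so the callback
        # can be instantiated and wired up without the dependency installed.
        if self._wandb is None:
            return
        self._wandb.log(logs)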