Add AzureML in integrations via dedicated callback (#8062)

* first attempt to add AzureML callbacks

* func arg fix

* var name fix, but still won't fix error...

* fixing as in https://discuss.huggingface.co/t/how-to-integrate-an-azuremlcallback-for-logging-in-azure/1713/2

* Avoid lint check of azureml import

* black compliance

* Make isort happy

* Fix point typo in docs

* Add AzureML to Callbacks docs

* Attempt to make sphinx happy

* Format callback docs

* Make documentation style happy

* Make docs compliant to style

Co-authored-by: Davide Fiocco <davide.fiocco@frontiersin.net>
This commit is contained in:
Davide Fiocco 2020-10-27 19:21:54 +01:00 committed by GitHub
parent a0906068cf
commit 995006eabb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 42 additions and 2 deletions

View File

@ -13,7 +13,7 @@ subclass :class:`~transformers.Trainer` and override the methods you need (see :
By default a :class:`~transformers.Trainer` will use the following callbacks:
- :class:`~transformers.DefaultFlowCallback` which handles the default behavior for logging, saving and evaluation.
- :class:`~transformers.PrinterCallback` or :class:`~transformers.ProrgressCallback` to display progress and print the
- :class:`~transformers.PrinterCallback` or :class:`~transformers.ProgressCallback` to display progress and print the
logs (the first one is used if you deactivate tqdm through the :class:`~transformers.TrainingArguments`, otherwise
it's the second one).
- :class:`~transformers.integrations.TensorBoardCallback` if tensorboard is accessible (either through PyTorch >= 1.4
@ -21,6 +21,8 @@ By default a :class:`~transformers.Trainer` will use the following callbacks:
- :class:`~transformers.integrations.WandbCallback` if `wandb <https://www.wandb.com/>`__ is installed.
- :class:`~transformers.integrations.CometCallback` if `comet_ml <https://www.comet.ml/site/>`__ is installed.
- :class:`~transformers.integrations.MLflowCallback` if `mlflow <https://www.mlflow.org/>`__ is installed.
- :class:`~transformers.integrations.AzureMLCallback` if `azureml-sdk <https://pypi.org/project/azureml-sdk/>`__ is
installed.
The main class that implements callbacks is :class:`~transformers.TrainerCallback`. It gets the
:class:`~transformers.TrainingArguments` used to instantiate the :class:`~transformers.Trainer`, can access that
@ -50,6 +52,7 @@ Here is the list of the available :class:`~transformers.TrainerCallback` in the
.. autoclass:: transformers.integrations.MLflowCallback
:members: setup
.. autoclass:: transformers.integrations.AzureMLCallback
TrainerCallback
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

View File

@ -61,6 +61,13 @@ except ImportError:
except ImportError:
_has_tensorboard = False
try:
from azureml.core.run import Run # noqa: F401
_has_azureml = True
except ImportError:
_has_azureml = False
try:
import mlflow # noqa: F401
@ -68,7 +75,6 @@ try:
except ImportError:
_has_mlflow = False
# No transformer imports above this point
from .file_utils import is_torch_tpu_available # noqa: E402
@ -97,6 +103,10 @@ def is_ray_available():
return _has_ray
def is_azureml_available():
return _has_azureml
def is_mlflow_available():
return _has_mlflow
@ -424,6 +434,27 @@ class CometCallback(TrainerCallback):
experiment._log_metrics(logs, step=state.global_step, epoch=state.epoch, framework="transformers")
class AzureMLCallback(TrainerCallback):
"""
A :class:`~transformers.TrainerCallback` that sends the logs to `AzureML
<https://pypi.org/project/azureml-sdk/>`__.
"""
def __init__(self, azureml_run=None):
assert _has_azureml, "AzureMLCallback requires azureml to be installed. Run `pip install azureml-sdk`."
self.azureml_run = azureml_run
def on_init_end(self, args, state, control, **kwargs):
if self.azureml_run is None and state.is_world_process_zero:
self.azureml_run = Run.get_context()
def on_log(self, args, state, control, logs=None, **kwargs):
if self.azureml_run:
for k, v in logs.items():
if isinstance(v, (int, float)):
self.azureml_run.log(k, v, description=k)
class MLflowCallback(TrainerCallback):
"""
A :class:`~transformers.TrainerCallback` that sends the logs to `MLflow <https://www.mlflow.org/>`__.

View File

@ -31,6 +31,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from .integrations import ( # isort: split
default_hp_search_backend,
hp_params,
is_azureml_available,
is_comet_available,
is_mlflow_available,
is_optuna_available,
@ -154,6 +155,11 @@ if is_optuna_available():
if is_ray_available():
from ray import tune
if is_azureml_available():
from .integrations import AzureMLCallback
DEFAULT_CALLBACKS.append(AzureMLCallback)
logger = logging.get_logger(__name__)