mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Add AzureML in integrations via dedicated callback (#8062)
* first attempt to add AzureML callbacks * func arg fix * var name fix, but still won't fix error... * fixing as in https://discuss.huggingface.co/t/how-to-integrate-an-azuremlcallback-for-logging-in-azure/1713/2 * Avoid lint check of azureml import * black compliance * Make isort happy * Fix point typo in docs * Add AzureML to Callbacks docs * Attempt to make sphinx happy * Format callback docs * Make documentation style happy * Make docs compliant to style Co-authored-by: Davide Fiocco <davide.fiocco@frontiersin.net>
This commit is contained in:
parent
a0906068cf
commit
995006eabb
@ -13,7 +13,7 @@ subclass :class:`~transformers.Trainer` and override the methods you need (see :
|
||||
By default a :class:`~transformers.Trainer` will use the following callbacks:
|
||||
|
||||
- :class:`~transformers.DefaultFlowCallback` which handles the default behavior for logging, saving and evaluation.
|
||||
- :class:`~transformers.PrinterCallback` or :class:`~transformers.ProrgressCallback` to display progress and print the
|
||||
- :class:`~transformers.PrinterCallback` or :class:`~transformers.ProgressCallback` to display progress and print the
|
||||
logs (the first one is used if you deactivate tqdm through the :class:`~transformers.TrainingArguments`, otherwise
|
||||
it's the second one).
|
||||
- :class:`~transformers.integrations.TensorBoardCallback` if tensorboard is accessible (either through PyTorch >= 1.4
|
||||
@ -21,6 +21,8 @@ By default a :class:`~transformers.Trainer` will use the following callbacks:
|
||||
- :class:`~transformers.integrations.WandbCallback` if `wandb <https://www.wandb.com/>`__ is installed.
|
||||
- :class:`~transformers.integrations.CometCallback` if `comet_ml <https://www.comet.ml/site/>`__ is installed.
|
||||
- :class:`~transformers.integrations.MLflowCallback` if `mlflow <https://www.mlflow.org/>`__ is installed.
|
||||
- :class:`~transformers.integrations.AzureMLCallback` if `azureml-sdk <https://pypi.org/project/azureml-sdk/>`__ is
|
||||
installed.
|
||||
|
||||
The main class that implements callbacks is :class:`~transformers.TrainerCallback`. It gets the
|
||||
:class:`~transformers.TrainingArguments` used to instantiate the :class:`~transformers.Trainer`, can access that
|
||||
@ -50,6 +52,7 @@ Here is the list of the available :class:`~transformers.TrainerCallback` in the
|
||||
.. autoclass:: transformers.integrations.MLflowCallback
|
||||
:members: setup
|
||||
|
||||
.. autoclass:: transformers.integrations.AzureMLCallback
|
||||
|
||||
TrainerCallback
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
@ -61,6 +61,13 @@ except ImportError:
|
||||
except ImportError:
|
||||
_has_tensorboard = False
|
||||
|
||||
# Optional AzureML dependency: record whether the SDK can be imported in `_has_azureml` instead of failing when
# this module is loaded. The `try` body is limited to the import itself.
try:
    from azureml.core.run import Run  # noqa: F401
except ImportError:
    _has_azureml = False
else:
    _has_azureml = True
|
||||
|
||||
try:
|
||||
import mlflow # noqa: F401
|
||||
|
||||
@ -68,7 +75,6 @@ try:
|
||||
except ImportError:
|
||||
_has_mlflow = False
|
||||
|
||||
|
||||
# No transformer imports above this point
|
||||
|
||||
from .file_utils import is_torch_tpu_available # noqa: E402
|
||||
@ -97,6 +103,10 @@ def is_ray_available():
|
||||
return _has_ray
|
||||
|
||||
|
||||
def is_azureml_available():
    """Return the module-level ``_has_azureml`` flag, set when the ``azureml`` import was attempted at load time."""
    return _has_azureml
|
||||
|
||||
|
||||
def is_mlflow_available():
    """Return the module-level ``_has_mlflow`` flag, set when the ``mlflow`` import was attempted at load time."""
    return _has_mlflow
|
||||
|
||||
@ -424,6 +434,27 @@ class CometCallback(TrainerCallback):
|
||||
experiment._log_metrics(logs, step=state.global_step, epoch=state.epoch, framework="transformers")
|
||||
|
||||
|
||||
class AzureMLCallback(TrainerCallback):
    """
    A :class:`~transformers.TrainerCallback` that sends the logs to `AzureML
    <https://pypi.org/project/azureml-sdk/>`__.

    Args:
        azureml_run (optional):
            Run object whose ``log`` method receives the metrics. When left to :obj:`None`, the run is obtained with
            ``Run.get_context()`` in :meth:`on_init_end`, on the world-process-zero only.
    """

    def __init__(self, azureml_run=None):
        # Raised explicitly rather than through `assert` so the check is not stripped under `python -O`. The
        # exception type and message are kept so existing callers observe the same error.
        if not _has_azureml:
            raise AssertionError("AzureMLCallback requires azureml to be installed. Run `pip install azureml-sdk`.")
        self.azureml_run = azureml_run

    def on_init_end(self, args, state, control, **kwargs):
        """Fetch the current AzureML run on the world-process-zero when none was passed to the constructor."""
        if self.azureml_run is None and state.is_world_process_zero:
            self.azureml_run = Run.get_context()

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Forward every ``int``/``float`` entry of ``logs`` to the AzureML run, using the key as description."""
        # Nothing to do without a run or without logs (`logs` defaults to None, which has no `.items()`).
        if not self.azureml_run or not logs:
            return
        # Mirror `on_init_end`: only the world-process-zero reports, so a run passed explicitly to the
        # constructor does not receive the same metric once per process.
        if not state.is_world_process_zero:
            return
        for name, value in logs.items():
            if isinstance(value, (int, float)):
                self.azureml_run.log(name, value, description=name)
|
||||
|
||||
|
||||
class MLflowCallback(TrainerCallback):
|
||||
"""
|
||||
A :class:`~transformers.TrainerCallback` that sends the logs to `MLflow <https://www.mlflow.org/>`__.
|
||||
|
@ -31,6 +31,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from .integrations import ( # isort: split
|
||||
default_hp_search_backend,
|
||||
hp_params,
|
||||
is_azureml_available,
|
||||
is_comet_available,
|
||||
is_mlflow_available,
|
||||
is_optuna_available,
|
||||
@ -154,6 +155,11 @@ if is_optuna_available():
|
||||
if is_ray_available():
|
||||
from ray import tune
|
||||
|
||||
# Add the AzureML callback to the default callbacks only when `is_azureml_available()` reports the SDK as usable;
# the import stays inside the guard so it is only executed in that case.
if is_azureml_available():
    from .integrations import AzureMLCallback

    DEFAULT_CALLBACKS.append(AzureMLCallback)
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user