Adds dvclive callback (#27352)

* dvclive trainer callback

* style fixes

* dvclive link fixes
This commit is contained in:
Dave Berenbaum 2023-11-09 07:19:31 -05:00 committed by GitHub
parent c5d7754b11
commit 791ec370d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 119 additions and 5 deletions

View File

@ -44,6 +44,7 @@ By default, `TrainingArguments.report_to` is set to `"all"`, so a [`Trainer`] wi
- [`~integrations.ClearMLCallback`] if [clearml](https://github.com/allegroai/clearml) is installed.
- [`~integrations.DagsHubCallback`] if [dagshub](https://dagshub.com/) is installed.
- [`~integrations.FlyteCallback`] if [flyte](https://flyte.org/) is installed.
- [`~integrations.DVCLiveCallback`] if [dvclive](https://dvc.org/doc/dvclive) is installed.
If a package is installed but you don't wish to use the accompanying integration, you can change `TrainingArguments.report_to` to a list of just those integrations you want to use (e.g. `["azure_ml", "wandb"]`).
@ -88,6 +89,9 @@ Here is the list of the available [`TrainerCallback`] in the library:
[[autodoc]] integrations.FlyteCallback
[[autodoc]] integrations.DVCLiveCallback
- setup
## TrainerCallback
[[autodoc]] TrainerCallback

View File

@ -45,6 +45,7 @@ rendered properly in your Markdown viewer.
- [`~integrations.ClearMLCallback`] [clearml](https://github.com/allegroai/clearml) がインストールされている場合。
- [`~integrations.DagsHubCallback`] [dagshub](https://dagshub.com/) がインストールされている場合。
- [`~integrations.FlyteCallback`] [flyte](https://flyte.org/) がインストールされている場合。
- [`~integrations.DVCLiveCallback`] [dvclive](https://dvc.org/doc/dvclive) がインストールされている場合。
パッケージがインストールされているが、付随する統合を使用したくない場合は、`TrainingArguments.report_to` を、使用したい統合のみのリストに変更できます (例: `["azure_ml", "wandb"]`) 。
@ -88,6 +89,9 @@ rendered properly in your Markdown viewer.
[[autodoc]] integrations.FlyteCallback
[[autodoc]] integrations.DVCLiveCallback
- setup
## TrainerCallback
[[autodoc]] TrainerCallback

View File

@ -201,6 +201,7 @@ You can easily log and monitor your runs code. The following are currently suppo
* [Comet ML](https://www.comet.ml/docs/python-sdk/huggingface/)
* [Neptune](https://docs.neptune.ai/integrations-and-supported-tools/model-training/hugging-face)
* [ClearML](https://clear.ml/docs/latest/docs/getting_started/ds/ds_first_steps)
* [DVCLive](https://dvc.org/doc/dvclive/ml-frameworks/huggingface)
### Weights & Biases

View File

@ -108,6 +108,7 @@ _import_structure = {
"integrations": [
"is_clearml_available",
"is_comet_available",
"is_dvclive_available",
"is_neptune_available",
"is_optuna_available",
"is_ray_available",
@ -4300,6 +4301,7 @@ if TYPE_CHECKING:
from .integrations import (
is_clearml_available,
is_comet_available,
is_dvclive_available,
is_neptune_available,
is_optuna_available,
is_ray_available,

View File

@ -44,6 +44,7 @@ _import_structure = {
"CodeCarbonCallback",
"CometCallback",
"DagsHubCallback",
"DVCLiveCallback",
"FlyteCallback",
"MLflowCallback",
"NeptuneCallback",
@ -58,6 +59,7 @@ _import_structure = {
"is_codecarbon_available",
"is_comet_available",
"is_dagshub_available",
"is_dvclive_available",
"is_flyte_deck_standard_available",
"is_flytekit_available",
"is_mlflow_available",
@ -105,6 +107,7 @@ if TYPE_CHECKING:
CodeCarbonCallback,
CometCallback,
DagsHubCallback,
DVCLiveCallback,
FlyteCallback,
MLflowCallback,
NeptuneCallback,
@ -119,6 +122,7 @@ if TYPE_CHECKING:
is_codecarbon_available,
is_comet_available,
is_dagshub_available,
is_dvclive_available,
is_flyte_deck_standard_available,
is_flytekit_available,
is_mlflow_available,

View File

@ -26,7 +26,7 @@ import sys
import tempfile
from dataclasses import asdict
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
import numpy as np
@ -152,6 +152,10 @@ def is_flyte_deck_standard_available():
return importlib.util.find_spec("flytekitplugins.deck") is not None
def is_dvclive_available():
    """Return `True` when an importable `dvclive` package can be located, `False` otherwise."""
    # find_spec only locates the package; it does not import it.
    dvclive_spec = importlib.util.find_spec("dvclive")
    return dvclive_spec is not None
def hp_params(trial):
if is_optuna_available():
import optuna
@ -541,6 +545,8 @@ def get_available_reporting_integrations():
integrations.append("comet_ml")
if is_dagshub_available():
integrations.append("dagshub")
if is_dvclive_available():
integrations.append("dvclive")
if is_mlflow_available():
integrations.append("mlflow")
if is_neptune_available():
@ -1605,6 +1611,98 @@ class FlyteCallback(TrainerCallback):
Deck("Log History", TableRenderer().to_html(log_history_df))
class DVCLiveCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [DVCLive](https://dvc.org/doc/dvclive).

    Use the environment variables below in `setup` to configure the integration. To customize this callback beyond
    those environment variables, see [here](https://dvc.org/doc/dvclive/ml-frameworks/huggingface).

    Args:
        live (`dvclive.Live`, *optional*, defaults to `None`):
            Optional Live instance. If None, a new instance will be created using **kwargs.
        log_model (Union[Literal["all"], bool], *optional*, defaults to `None`):
            Whether to use `dvclive.Live.log_artifact()` to log checkpoints created by [`Trainer`]. If set to `True`,
            the final checkpoint is logged at the end of training. If set to `"all"`, the entire
            [`TrainingArguments`]'s `output_dir` is logged at each checkpoint. If left as `None`, the
            `HF_DVCLIVE_LOG_MODEL` environment variable is consulted in `setup`.
        **kwargs:
            Forwarded to `dvclive.Live(...)` when no `live` instance is supplied.

    Raises:
        RuntimeError: If `dvclive` is not installed, or if `live` is neither `None` nor a `dvclive.Live` instance.
    """

    def __init__(
        self,
        live: Optional[Any] = None,
        log_model: Optional[Union[Literal["all"], bool]] = None,
        **kwargs,
    ):
        if not is_dvclive_available():
            raise RuntimeError("DVCLiveCallback requires dvclive to be installed. Run `pip install dvclive`.")
        from dvclive import Live

        self._log_model = log_model
        # Kept so that `setup` can build the Live instance lazily, as the class docstring promises.
        self._live_kwargs = kwargs
        # Stays False even when a Live instance is supplied, so that `setup` still runs once
        # (hyperparameter logging and the environment-variable lookup happen there).
        self._initialized = False
        self.live = None
        if isinstance(live, Live):
            self.live = live
        elif live is not None:
            raise RuntimeError(f"Found class {live.__class__} for live, expected dvclive.Live")

    def setup(self, args, state, model):
        """
        Setup the optional DVCLive integration. To customize this callback beyond the environment variables below, see
        [here](https://dvc.org/doc/dvclive/ml-frameworks/huggingface).

        Environment:
        - **HF_DVCLIVE_LOG_MODEL** (`str`, *optional*):
            Whether to use `dvclive.Live.log_artifact()` to log checkpoints created by [`Trainer`]. If set to `True` or
            *1*, the final checkpoint is logged at the end of training. If set to `all`, the entire
            [`TrainingArguments`]'s `output_dir` is logged at each checkpoint. Only consulted when `log_model` was not
            passed to the constructor.
        """
        from dvclive import Live

        self._initialized = True
        # The environment variable is only a fallback: an explicit `log_model` argument wins.
        if self._log_model is None:
            # Default to "" so an unset variable does not raise on `.upper()` / `.lower()`.
            log_model_env = os.getenv("HF_DVCLIVE_LOG_MODEL", "")
            if log_model_env.upper() in ENV_VARS_TRUE_VALUES:
                self._log_model = True
            elif log_model_env.lower() == "all":
                self._log_model = "all"
        # Only the main process owns a Live instance and reports to it.
        if state.is_world_process_zero:
            if not self.live:
                self.live = Live(**self._live_kwargs)
            self.live.log_params(args.to_dict())

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        """Run `setup` once at the start of training if it has not been run yet."""
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        """Log every entry of `logs` as a DVCLive metric and advance the DVCLive step (main process only)."""
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            from dvclive.utils import standardize_metric_name

            # `logs` defaults to None; treat that as "nothing to report" instead of crashing.
            for key, value in (logs or {}).items():
                self.live.log_metric(standardize_metric_name(key, "dvclive.huggingface"), value)
            self.live.next_step()

    def on_save(self, args, state, control, **kwargs):
        """Log the whole `output_dir` as an artifact on each save when `log_model` is `"all"`."""
        if self._log_model == "all" and self._initialized and state.is_world_process_zero:
            self.live.log_artifact(args.output_dir)

    def on_train_end(self, args, state, control, **kwargs):
        """Optionally save and log the final model as an artifact, then end the DVCLive run (main process only)."""
        if self._initialized and state.is_world_process_zero:
            from transformers.trainer import Trainer

            if self._log_model is True:
                # A throwaway Trainer is built only to reuse `Trainer.save_model` for writing the model to disk.
                fake_trainer = Trainer(args=args, model=kwargs.get("model"), tokenizer=kwargs.get("tokenizer"))
                name = "best" if args.load_best_model_at_end else "last"
                output_dir = os.path.join(args.output_dir, name)
                fake_trainer.save_model(output_dir)
                self.live.log_artifact(output_dir, name=name, type="model", copy=True)
            self.live.end()
INTEGRATION_TO_CALLBACK = {
"azure_ml": AzureMLCallback,
"comet_ml": CometCallback,
@ -1616,6 +1714,7 @@ INTEGRATION_TO_CALLBACK = {
"clearml": ClearMLCallback,
"dagshub": DagsHubCallback,
"flyte": FlyteCallback,
"dvclive": DVCLiveCallback,
}

View File

@ -509,7 +509,7 @@ class TrainingArguments:
instance of `Dataset`.
report_to (`str` or `List[str]`, *optional*, defaults to `"all"`):
The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
`"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"flyte"`, `"mlflow"`, `"neptune"`,
`"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`, `"neptune"`,
`"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations installed, `"none"` for no
integrations.
ddp_find_unused_parameters (`bool`, *optional*):
@ -2391,9 +2391,9 @@ class TrainingArguments:
and lets the application set the level.
report_to (`str` or `List[str]`, *optional*, defaults to `"all"`):
The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
`"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"flyte"`, `"mlflow"`, `"neptune"`,
`"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations installed, `"none"` for no
integrations.
`"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`,
`"neptune"`, `"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations installed,
`"none"` for no integrations.
first_step (`bool`, *optional*, defaults to `False`):
Whether to log and evaluate the first `global_step` or not.
nan_inf_filter (`bool`, *optional*, defaults to `True`):