mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Adds dvclive callback (#27352)
* dvclive trainer callback * style fixes * dvclive link fixes
This commit is contained in:
parent
c5d7754b11
commit
791ec370d1
@ -44,6 +44,7 @@ By default, `TrainingArguments.report_to` is set to `"all"`, so a [`Trainer`] wi
|
||||
- [`~integrations.ClearMLCallback`] if [clearml](https://github.com/allegroai/clearml) is installed.
|
||||
- [`~integrations.DagsHubCallback`] if [dagshub](https://dagshub.com/) is installed.
|
||||
- [`~integrations.FlyteCallback`] if [flyte](https://flyte.org/) is installed.
|
||||
- [`~integrations.DVCLiveCallback`] if [dvclive](https://dvc.org/doc/dvclive) is installed.
|
||||
|
||||
If a package is installed but you don't wish to use the accompanying integration, you can change `TrainingArguments.report_to` to a list of just those integrations you want to use (e.g. `["azure_ml", "wandb"]`).
|
||||
|
||||
@ -88,6 +89,9 @@ Here is the list of the available [`TrainerCallback`] in the library:
|
||||
|
||||
[[autodoc]] integrations.FlyteCallback
|
||||
|
||||
[[autodoc]] integrations.DVCLiveCallback
|
||||
- setup
|
||||
|
||||
## TrainerCallback
|
||||
|
||||
[[autodoc]] TrainerCallback
|
||||
|
@ -45,6 +45,7 @@ rendered properly in your Markdown viewer.
|
||||
- [`~integrations.ClearMLCallback`] [clearml](https://github.com/allegroai/clearml) がインストールされている場合。
|
||||
- [`~integrations.DagsHubCallback`] [dagshub](https://dagshub.com/) がインストールされている場合。
|
||||
- [`~integrations.FlyteCallback`] [flyte](https://flyte.org/) がインストールされている場合。
|
||||
- [`~integrations.DVCLiveCallback`] [dvclive](https://dvc.org/doc/dvclive) がインストールされている場合。
|
||||
|
||||
パッケージがインストールされているが、付随する統合を使用したくない場合は、`TrainingArguments.report_to` を、使用したい統合のみのリストに変更できます (例: `["azure_ml", "wandb"]`) 。
|
||||
|
||||
@ -88,6 +89,9 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
[[autodoc]] integrations.FlyteCallback
|
||||
|
||||
[[autodoc]] integrations.DVCLiveCallback
|
||||
- setup
|
||||
|
||||
## TrainerCallback
|
||||
|
||||
[[autodoc]] TrainerCallback
|
||||
|
@ -201,6 +201,7 @@ You can easily log and monitor your runs code. The following are currently suppo
|
||||
* [Comet ML](https://www.comet.ml/docs/python-sdk/huggingface/)
|
||||
* [Neptune](https://docs.neptune.ai/integrations-and-supported-tools/model-training/hugging-face)
|
||||
* [ClearML](https://clear.ml/docs/latest/docs/getting_started/ds/ds_first_steps)
|
||||
* [DVCLive](https://dvc.org/doc/dvclive/ml-frameworks/huggingface)
|
||||
|
||||
### Weights & Biases
|
||||
|
||||
|
@ -108,6 +108,7 @@ _import_structure = {
|
||||
"integrations": [
|
||||
"is_clearml_available",
|
||||
"is_comet_available",
|
||||
"is_dvclive_available",
|
||||
"is_neptune_available",
|
||||
"is_optuna_available",
|
||||
"is_ray_available",
|
||||
@ -4300,6 +4301,7 @@ if TYPE_CHECKING:
|
||||
from .integrations import (
|
||||
is_clearml_available,
|
||||
is_comet_available,
|
||||
is_dvclive_available,
|
||||
is_neptune_available,
|
||||
is_optuna_available,
|
||||
is_ray_available,
|
||||
|
@ -44,6 +44,7 @@ _import_structure = {
|
||||
"CodeCarbonCallback",
|
||||
"CometCallback",
|
||||
"DagsHubCallback",
|
||||
"DVCLiveCallback",
|
||||
"FlyteCallback",
|
||||
"MLflowCallback",
|
||||
"NeptuneCallback",
|
||||
@ -58,6 +59,7 @@ _import_structure = {
|
||||
"is_codecarbon_available",
|
||||
"is_comet_available",
|
||||
"is_dagshub_available",
|
||||
"is_dvclive_available",
|
||||
"is_flyte_deck_standard_available",
|
||||
"is_flytekit_available",
|
||||
"is_mlflow_available",
|
||||
@ -105,6 +107,7 @@ if TYPE_CHECKING:
|
||||
CodeCarbonCallback,
|
||||
CometCallback,
|
||||
DagsHubCallback,
|
||||
DVCLiveCallback,
|
||||
FlyteCallback,
|
||||
MLflowCallback,
|
||||
NeptuneCallback,
|
||||
@ -119,6 +122,7 @@ if TYPE_CHECKING:
|
||||
is_codecarbon_available,
|
||||
is_comet_available,
|
||||
is_dagshub_available,
|
||||
is_dvclive_available,
|
||||
is_flyte_deck_standard_available,
|
||||
is_flytekit_available,
|
||||
is_mlflow_available,
|
||||
|
@ -26,7 +26,7 @@ import sys
|
||||
import tempfile
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -152,6 +152,10 @@ def is_flyte_deck_standard_available():
|
||||
return importlib.util.find_spec("flytekitplugins.deck") is not None
|
||||
|
||||
|
||||
def is_dvclive_available():
    """Return `True` when an import spec for the `dvclive` package can be found, `False` otherwise."""
    dvclive_spec = importlib.util.find_spec("dvclive")
    return dvclive_spec is not None
|
||||
|
||||
|
||||
def hp_params(trial):
|
||||
if is_optuna_available():
|
||||
import optuna
|
||||
@ -541,6 +545,8 @@ def get_available_reporting_integrations():
|
||||
integrations.append("comet_ml")
|
||||
if is_dagshub_available():
|
||||
integrations.append("dagshub")
|
||||
if is_dvclive_available():
|
||||
integrations.append("dvclive")
|
||||
if is_mlflow_available():
|
||||
integrations.append("mlflow")
|
||||
if is_neptune_available():
|
||||
@ -1605,6 +1611,98 @@ class FlyteCallback(TrainerCallback):
|
||||
Deck("Log History", TableRenderer().to_html(log_history_df))
|
||||
|
||||
|
||||
class DVCLiveCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [DVCLive](https://dvc.org/doc/dvclive).

    Use the environment variables below in `setup` to configure the integration. To customize this callback beyond
    those environment variables, see [here](https://dvc.org/doc/dvclive/ml-frameworks/huggingface).

    Args:
        live (`dvclive.Live`, *optional*, defaults to `None`):
            Optional Live instance. If None, a new instance is created with default arguments the first time
            `setup` runs on the world-zero process.
        log_model (Union[Literal["all"], bool], *optional*, defaults to `None`):
            Whether to use `dvclive.Live.log_artifact()` to log checkpoints created by [`Trainer`]. If set to `True`,
            the final checkpoint is logged at the end of training. If set to `"all"`, the entire
            [`TrainingArguments`]'s `output_dir` is logged at each checkpoint. If left as `None`, the value is read
            from the `HF_DVCLIVE_LOG_MODEL` environment variable in `setup`.
        kwargs:
            Accepted for forward compatibility; not used by this callback at the moment.
    """

    def __init__(
        self,
        live: Optional[Any] = None,
        log_model: Optional[Union[Literal["all"], bool]] = None,
        **kwargs,
    ):
        if not is_dvclive_available():
            raise RuntimeError("DVCLiveCallback requires dvclive to be installed. Run `pip install dvclive`.")
        from dvclive import Live

        self._log_model = log_model

        # `_initialized` stays False even when a Live instance is supplied, so that `setup` still runs once
        # (environment lookup and parameter logging) before the first event is handled.
        self._initialized = False
        self.live = None
        if isinstance(live, Live):
            self.live = live
        elif live is not None:
            raise RuntimeError(f"Found class {live.__class__} for live, expected dvclive.Live")

    def setup(self, args, state, model):
        """
        Setup the optional DVCLive integration. To customize this callback beyond the environment variables below, see
        [here](https://dvc.org/doc/dvclive/ml-frameworks/huggingface).

        Environment:
        - **HF_DVCLIVE_LOG_MODEL** (`str`, *optional*):
            Whether to use `dvclive.Live.log_artifact()` to log checkpoints created by [`Trainer`]. If set to `True` or
            *1*, the final checkpoint is logged at the end of training. If set to `all`, the entire
            [`TrainingArguments`]'s `output_dir` is logged at each checkpoint. Only consulted when `log_model` was
            not passed to the constructor.
        """
        from dvclive import Live

        self._initialized = True
        # An explicit `log_model` argument takes precedence; the environment variable is only a fallback.
        if self._log_model is None:
            # Default to "FALSE" so an unset variable does not break the `.upper()` / `.lower()` calls below.
            log_model_env = os.getenv("HF_DVCLIVE_LOG_MODEL", "FALSE")
            if log_model_env.upper() in ENV_VARS_TRUE_VALUES:
                self._log_model = True
            elif log_model_env.lower() == "all":
                self._log_model = "all"
        # Only the world-zero process owns a Live instance and records the training arguments.
        if state.is_world_process_zero:
            if not self.live:
                self.live = Live()
            self.live.log_params(args.to_dict())

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        """Run `setup` lazily the first time training starts."""
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        """Forward every entry of `logs` to DVCLive as a metric, then advance the DVCLive step."""
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            from dvclive.utils import standardize_metric_name

            # `logs` defaults to None in the signature; treat that as "nothing to log".
            for key, value in (logs or {}).items():
                self.live.log_metric(standardize_metric_name(key, "dvclive.huggingface"), value)
            self.live.next_step()

    def on_save(self, args, state, control, **kwargs):
        """Log the whole `output_dir` as an artifact at each checkpoint when `log_model` is `"all"`."""
        if self._log_model == "all" and self._initialized and state.is_world_process_zero:
            self.live.log_artifact(args.output_dir)

    def on_train_end(self, args, state, control, **kwargs):
        """Optionally save and log the final model as an artifact, then end the DVCLive run."""
        if self._initialized and state.is_world_process_zero:
            from transformers.trainer import Trainer

            if self._log_model is True:
                # A throwaway Trainer is built only to reuse its `save_model` logic for the final weights.
                fake_trainer = Trainer(args=args, model=kwargs.get("model"), tokenizer=kwargs.get("tokenizer"))
                name = "best" if args.load_best_model_at_end else "last"
                output_dir = os.path.join(args.output_dir, name)
                fake_trainer.save_model(output_dir)
                self.live.log_artifact(output_dir, name=name, type="model", copy=True)
            self.live.end()
|
||||
|
||||
|
||||
INTEGRATION_TO_CALLBACK = {
|
||||
"azure_ml": AzureMLCallback,
|
||||
"comet_ml": CometCallback,
|
||||
@ -1616,6 +1714,7 @@ INTEGRATION_TO_CALLBACK = {
|
||||
"clearml": ClearMLCallback,
|
||||
"dagshub": DagsHubCallback,
|
||||
"flyte": FlyteCallback,
|
||||
"dvclive": DVCLiveCallback,
|
||||
}
|
||||
|
||||
|
||||
|
@ -509,7 +509,7 @@ class TrainingArguments:
|
||||
instance of `Dataset`.
|
||||
report_to (`str` or `List[str]`, *optional*, defaults to `"all"`):
|
||||
The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
|
||||
`"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"flyte"`, `"mlflow"`, `"neptune"`,
|
||||
`"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`, `"neptune"`,
|
||||
`"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations installed, `"none"` for no
|
||||
integrations.
|
||||
ddp_find_unused_parameters (`bool`, *optional*):
|
||||
@ -2391,9 +2391,9 @@ class TrainingArguments:
|
||||
and lets the application set the level.
|
||||
report_to (`str` or `List[str]`, *optional*, defaults to `"all"`):
|
||||
The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
|
||||
`"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"flyte"`, `"mlflow"`, `"neptune"`,
|
||||
`"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations installed, `"none"` for no
|
||||
integrations.
|
||||
`"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`,
|
||||
`"neptune"`, `"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations installed,
|
||||
`"none"` for no integrations.
|
||||
first_step (`bool`, *optional*, defaults to `False`):
|
||||
Whether to log and evaluate the first `global_step` or not.
|
||||
nan_inf_filter (`bool`, *optional*, defaults to `True`):
|
||||
|
Loading…
Reference in New Issue
Block a user