Update CometCallback to allow reusing of the running experiment (#31366)

* Update CometCallback to allow reusing of the running experiment

* Fixups

* Remove useless TODO

* Add checks for minimum version of the Comet SDK

* Fix documentation and links.

Also simplify how the Comet Experiment name is passed
This commit is contained in:
Boris Feld 2024-07-05 08:13:46 +02:00 committed by GitHub
parent d19b5a90c2
commit 9e599d1d94
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 150 additions and 56 deletions

View File

@ -34,7 +34,7 @@ By default, `TrainingArguments.report_to` is set to `"all"`, so a [`Trainer`] wi
- [`~integrations.TensorBoardCallback`] if tensorboard is accessible (either through PyTorch >= 1.4
or tensorboardX).
- [`~integrations.WandbCallback`] if [wandb](https://www.wandb.com/) is installed.
- [`~integrations.CometCallback`] if [comet_ml](https://www.comet.ml/site/) is installed.
- [`~integrations.CometCallback`] if [comet_ml](https://www.comet.com/site/) is installed.
- [`~integrations.MLflowCallback`] if [mlflow](https://www.mlflow.org/) is installed.
- [`~integrations.NeptuneCallback`] if [neptune](https://neptune.ai/) is installed.
- [`~integrations.AzureMLCallback`] if [azureml-sdk](https://pypi.org/project/azureml-sdk/) is

View File

@ -35,7 +35,7 @@ rendered properly in your Markdown viewer.
- [`~integrations.TensorBoardCallback`] (PyTorch >= 1.4 を介して) tensorboard にアクセスできる場合
またはテンソルボードX
- [`~integrations.WandbCallback`] [wandb](https://www.wandb.com/) がインストールされている場合。
- [`~integrations.CometCallback`] [comet_ml](https://www.comet.ml/site/) がインストールされている場合。
- [`~integrations.CometCallback`] [comet_ml](https://www.comet.com/site/) がインストールされている場合。
- [mlflow](https://www.mlflow.org/) がインストールされている場合は [`~integrations.MLflowCallback`]。
- [`~integrations.NeptuneCallback`] [neptune](https://neptune.ai/) がインストールされている場合。
- [`~integrations.AzureMLCallback`] [azureml-sdk](https://pypi.org/project/azureml-sdk/) の場合

View File

@ -28,7 +28,7 @@ Callbacks是“只读”的代码片段除了它们返回的[TrainerControl]
- [`PrinterCallback`] 或 [`ProgressCallback`],用于显示进度和打印日志(如果通过[`TrainingArguments`]停用tqdm则使用第一个函数否则使用第二个
- [`~integrations.TensorBoardCallback`]如果TensorBoard可访问通过PyTorch版本 >= 1.4 或者 tensorboardX
- [`~integrations.WandbCallback`],如果安装了[wandb](https://www.wandb.com/)。
- [`~integrations.CometCallback`],如果安装了[comet_ml](https://www.comet.ml/site/)。
- [`~integrations.CometCallback`],如果安装了[comet_ml](https://www.comet.com/site/)。
- [`~integrations.MLflowCallback`],如果安装了[mlflow](https://www.mlflow.org/)。
- [`~integrations.NeptuneCallback`],如果安装了[neptune](https://neptune.ai/)。
- [`~integrations.AzureMLCallback`],如果安装了[azureml-sdk](https://pypi.org/project/azureml-sdk/)。

View File

@ -200,7 +200,7 @@ You can easily log and monitor your runs code. The following are currently suppo
* [TensorBoard](https://www.tensorflow.org/tensorboard)
* [Weights & Biases](https://docs.wandb.ai/integrations/huggingface)
* [Comet ML](https://www.comet.ml/docs/python-sdk/huggingface/)
* [Comet ML](https://www.comet.com/docs/v2/integrations/ml-frameworks/transformers/)
* [Neptune](https://docs.neptune.ai/integrations-and-supported-tools/model-training/hugging-face)
* [ClearML](https://clear.ml/docs/latest/docs/getting_started/ds/ds_first_steps)
* [DVCLive](https://dvc.org/doc/dvclive/ml-frameworks/huggingface)
@ -244,7 +244,7 @@ Additional configuration options are available through generic [wandb environmen
Refer to related [documentation & examples](https://docs.wandb.ai/integrations/huggingface).
### Comet.ml
### Comet
To use `comet_ml`, install the Python package with:

View File

@ -51,19 +51,25 @@ if is_torch_available():
import torch
# comet_ml requires to be imported before any ML frameworks
_has_comet = importlib.util.find_spec("comet_ml") is not None and os.getenv("COMET_MODE", "").upper() != "DISABLED"
if _has_comet:
try:
import comet_ml # noqa: F401
_MIN_COMET_VERSION = "3.43.2"
try:
_comet_version = importlib.metadata.version("comet_ml")
_is_comet_installed = True
if hasattr(comet_ml, "config") and comet_ml.config.get_config("comet.api_key"):
_has_comet = True
else:
if os.getenv("COMET_MODE", "").upper() != "DISABLED":
logger.warning("comet_ml is installed but `COMET_API_KEY` is not set.")
_has_comet = False
except (ImportError, ValueError):
_has_comet = False
_is_comet_recent_enough = packaging.version.parse(_comet_version) >= packaging.version.parse(_MIN_COMET_VERSION)
# Check if the Comet API Key is set
import comet_ml
if comet_ml.config.get_config("comet.api_key") is not None:
_is_comet_configured = True
else:
_is_comet_configured = False
except (importlib.metadata.PackageNotFoundError, ImportError, ValueError, TypeError, AttributeError, KeyError):
_comet_version = None
_is_comet_installed = False
_is_comet_recent_enough = False
_is_comet_configured = False
_has_neptune = (
importlib.util.find_spec("neptune") is not None or importlib.util.find_spec("neptune-client") is not None
@ -103,7 +109,36 @@ def is_clearml_available():
def is_comet_available():
return _has_comet
if os.getenv("COMET_MODE", "").upper() == "DISABLED":
logger.warning(
"Using the `COMET_MODE=DISABLED` environment variable is deprecated and will be removed in v5. Use the "
"--report_to flag to control the integrations used for logging result (for instance --report_to none)."
)
return False
if _is_comet_installed is False:
return False
if _is_comet_recent_enough is False:
logger.warning(
"comet_ml version %s is installed, but version %s or higher is required. "
"Please update comet_ml to the latest version to enable Comet logging with pip install 'comet-ml>=%s'.",
_comet_version,
_MIN_COMET_VERSION,
_MIN_COMET_VERSION,
)
return False
if _is_comet_configured is False:
logger.warning(
"comet_ml is installed but the Comet API Key is not configured. "
"Please set the `COMET_API_KEY` environment variable to enable Comet logging. "
"Check out the documentation for other ways of configuring it: "
"https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#set-the-api-key"
)
return False
return True
def is_tensorboard_available():
@ -936,56 +971,109 @@ class WandbCallback(TrainerCallback):
class CometCallback(TrainerCallback):
"""
A [`TrainerCallback`] that sends the logs to [Comet ML](https://www.comet.ml/site/).
A [`TrainerCallback`] that sends the logs to [Comet ML](https://www.comet.com/site/).
"""
def __init__(self):
if not _has_comet:
raise RuntimeError("CometCallback requires comet-ml to be installed. Run `pip install comet-ml`.")
if _is_comet_installed is False or _is_comet_recent_enough is False:
raise RuntimeError(
f"CometCallback requires comet-ml>={_MIN_COMET_VERSION} to be installed. Run `pip install comet-ml>={_MIN_COMET_VERSION}`."
)
self._initialized = False
self._log_assets = False
self._experiment = None
def setup(self, args, state, model):
"""
Setup the optional Comet.ml integration.
Setup the optional Comet integration.
Environment:
- **COMET_MODE** (`str`, *optional*, defaults to `ONLINE`):
Whether to create an online, offline experiment or disable Comet logging. Can be `OFFLINE`, `ONLINE`, or
`DISABLED`.
- **COMET_MODE** (`str`, *optional*, default to `get_or_create`):
Control whether to create and log to a new Comet experiment or append to an existing experiment.
It accepts the following values:
* `get_or_create`: Decides automatically depending if
`COMET_EXPERIMENT_KEY` is set and whether an Experiment
with that key already exists or not.
* `create`: Always create a new Comet Experiment.
* `get`: Always try to append to an Existing Comet Experiment.
Requires `COMET_EXPERIMENT_KEY` to be set.
* `ONLINE`: **deprecated**, used to create an online
Experiment. Use `COMET_START_ONLINE=1` instead.
* `OFFLINE`: **deprecated**, used to created an offline
Experiment. Use `COMET_START_ONLINE=0` instead.
* `DISABLED`: **deprecated**, used to disable Comet logging.
Use the `--report_to` flag to control the integrations used
for logging result instead.
- **COMET_PROJECT_NAME** (`str`, *optional*):
Comet project name for experiments.
- **COMET_OFFLINE_DIRECTORY** (`str`, *optional*):
Folder to use for saving offline experiments when `COMET_MODE` is `OFFLINE`.
- **COMET_LOG_ASSETS** (`str`, *optional*, defaults to `TRUE`):
Whether or not to log training assets (tf event logs, checkpoints, etc), to Comet. Can be `TRUE`, or
`FALSE`.
For a number of configurable items in the environment, see
[here](https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables).
[here](https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#explore-comet-configuration-options).
"""
self._initialized = True
log_assets = os.getenv("COMET_LOG_ASSETS", "FALSE").upper()
if log_assets in {"TRUE", "1"}:
self._log_assets = True
if state.is_world_process_zero:
comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
experiment = None
experiment_kwargs = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
if comet_mode == "ONLINE":
experiment = comet_ml.Experiment(**experiment_kwargs)
experiment.log_other("Created from", "transformers")
logger.info("Automatic Comet.ml online logging enabled")
elif comet_mode == "OFFLINE":
experiment_kwargs["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
experiment = comet_ml.OfflineExperiment(**experiment_kwargs)
experiment.log_other("Created from", "transformers")
logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
if experiment is not None:
experiment._set_model_graph(model, framework="transformers")
experiment._log_parameters(args, prefix="args/", framework="transformers")
if hasattr(model, "config"):
experiment._log_parameters(model.config, prefix="config/", framework="transformers")
comet_old_mode = os.getenv("COMET_MODE")
mode = None
online = None
if comet_old_mode is not None:
comet_old_mode = comet_old_mode.lower()
if comet_old_mode == "online":
online = True
elif comet_old_mode == "offline":
online = False
elif comet_old_mode in ("get", "get_or_create", "create"):
mode = comet_old_mode
elif comet_old_mode:
logger.warning("Invalid COMET_MODE env value %r, Comet logging is disabled", comet_old_mode)
return
# For HPO, we always create a new experiment for each trial
if state.is_hyper_param_search:
if mode is not None:
logger.warning(
"Hyperparameter Search is enabled, forcing the creation of new experimetns, COMET_MODE value %r is ignored",
comet_old_mode,
)
mode = "create"
import comet_ml
# Do not use the default run_name as the experiment name
if args.run_name is not None and args.run_name != args.output_dir:
experiment_config = comet_ml.ExperimentConfig(name=args.run_name)
else:
experiment_config = comet_ml.ExperimentConfig()
self._experiment = comet_ml.start(online=online, mode=mode, experiment_config=experiment_config)
self._experiment.__internal_api__set_model_graph__(model, framework="transformers")
params = {"args": args.to_dict()}
if hasattr(model, "config") and model.config is not None:
model_config = model.config.to_dict()
params["config"] = model_config
if hasattr(model, "peft_config") and model.peft_config is not None:
peft_config = model.peft_config
params["peft_config"] = peft_config
self._experiment.__internal_api__log_parameters__(
params, framework="transformers", source="manual", flatten_nested=True
)
if state.is_hyper_param_search:
optimization_id = getattr(state, "trial_name", None)
optimization_params = getattr(state, "trial_params", None)
self._experiment.log_optimization(optimization_id=optimization_id, parameters=optimization_params)
def on_train_begin(self, args, state, control, model=None, **kwargs):
if not self._initialized:
@ -995,20 +1083,24 @@ class CometCallback(TrainerCallback):
if not self._initialized:
self.setup(args, state, model)
if state.is_world_process_zero:
experiment = comet_ml.config.get_global_experiment()
if experiment is not None:
experiment._log_metrics(logs, step=state.global_step, epoch=state.epoch, framework="transformers")
if self._experiment is not None:
self._experiment.__internal_api__log_metrics__(
logs, step=state.global_step, epoch=state.epoch, framework="transformers"
)
def on_train_end(self, args, state, control, **kwargs):
if self._initialized and state.is_world_process_zero:
experiment = comet_ml.config.get_global_experiment()
if experiment is not None:
if self._experiment is not None:
if self._log_assets is True:
logger.info("Logging checkpoints. This may take time.")
experiment.log_asset_folder(
self._experiment.log_asset_folder(
args.output_dir, recursive=True, log_file_name=True, step=state.global_step
)
experiment.end()
# We create one experiment per trial in HPO mode
if state.is_hyper_param_search:
self._experiment.clean()
self._initialized = False
class AzureMLCallback(TrainerCallback):

View File

@ -436,8 +436,9 @@ class TrainingArguments:
use the corresponding output (usually index 2) as the past state and feed it to the model at the next
training step under the keyword argument `mems`.
run_name (`str`, *optional*, defaults to `output_dir`):
A descriptor for the run. Typically used for [wandb](https://www.wandb.com/) and
[mlflow](https://www.mlflow.org/) logging. If not specified, will be the same as `output_dir`.
A descriptor for the run. Typically used for [wandb](https://www.wandb.com/),
[mlflow](https://www.mlflow.org/) and [comet](https://www.comet.com/site) logging. If not specified, will
be the same as `output_dir`.
disable_tqdm (`bool`, *optional*):
Whether or not to disable the tqdm progress bars and table of metrics produced by
[`~notebook.NotebookTrainingTracker`] in Jupyter Notebooks. Will default to `True` if the logging level is
@ -1149,7 +1150,8 @@ class TrainingArguments:
)
run_name: Optional[str] = field(
default=None, metadata={"help": "An optional descriptor for the run. Notably used for wandb logging."}
default=None,
metadata={"help": "An optional descriptor for the run. Notably used for wandb, mlflow and comet logging."},
)
disable_tqdm: Optional[bool] = field(
default=None, metadata={"help": "Whether or not to disable the tqdm progress bars."}

View File

@ -160,7 +160,7 @@ class TFTrainingArguments(TrainingArguments):
Google Cloud Project name for the Cloud TPU-enabled project. If not specified, we will attempt to
automatically detect from metadata.
run_name (`str`, *optional*):
A descriptor for the run. Notably used for wandb logging.
A descriptor for the run. Notably used for wandb, mlflow and comet logging.
xla (`bool`, *optional*):
Whether to activate the XLA compilation or not.
"""