Adds a FlyteCallback (#23759)

* initial flyte callback

* lint

* logs should still be saved to Flyte even if pandas isn't installed (unlikely)

* cr - flyte team

* add docs for FlyteCallback

* fix doc string - cr sgugger

* Apply suggestions from code review

cr - sgugger fix doc strings

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
peridotml 2023-05-30 07:08:07 -07:00 committed by GitHub
parent 867316670a
commit 62ba64b90a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 78 additions and 1 deletions

View File

@ -39,6 +39,7 @@ By default a [`Trainer`] will use the following callbacks:
installed.
- [`~integrations.ClearMLCallback`] if [clearml](https://github.com/allegroai/clearml) is installed.
- [`~integrations.DagsHubCallback`] if [dagshub](https://dagshub.com/) is installed.
- [`~integrations.FlyteCallback`] if [flyte](https://flyte.org/) is installed.
The main class that implements callbacks is [`TrainerCallback`]. It gets the
[`TrainingArguments`] used to instantiate the [`Trainer`], can access that
@ -79,6 +80,8 @@ Here is the list of the available [`TrainerCallback`] in the library:
[[autodoc]] integrations.DagsHubCallback
[[autodoc]] integrations.FlyteCallback
## TrainerCallback
[[autodoc]] TrainerCallback

View File

@ -30,7 +30,7 @@ from typing import TYPE_CHECKING, Dict, Optional
import numpy as np
from . import __version__ as version
from .utils import flatten_dict, is_datasets_available, is_torch_available, logging
from .utils import flatten_dict, is_datasets_available, is_pandas_available, is_torch_available, logging
from .utils.versions import importlib_metadata
@ -146,6 +146,16 @@ def is_codecarbon_available():
return importlib.util.find_spec("codecarbon") is not None
def is_flytekit_available():
    """Report whether the `flytekit` package can be located in the current environment."""
    flytekit_spec = importlib.util.find_spec("flytekit")
    return flytekit_spec is not None
def is_flyte_deck_standard_available():
    """Report whether the `flytekitplugins.deck` module can be located.

    Returns `False` right away when `flytekit` itself is not available.
    """
    if not is_flytekit_available():
        return False
    try:
        # Resolving a dotted name makes importlib import the parent package first, so a missing
        # `flytekitplugins` package raises ModuleNotFoundError instead of yielding a `None` spec.
        return importlib.util.find_spec("flytekitplugins.deck") is not None
    except ModuleNotFoundError:
        return False
def hp_params(trial):
if is_optuna_available():
import optuna
@ -1537,6 +1547,69 @@ class ClearMLCallback(TrainerCallback):
self._clearml_task.update_output_model(artifact_path, iteration=state.global_step, auto_delete_file=False)
class FlyteCallback(TrainerCallback):
    """A [`TrainerCallback`] that sends the logs to [Flyte](https://flyte.org/).
    NOTE: This callback only works within a Flyte task.

    Args:
        save_log_history (`bool`, *optional*, defaults to `True`):
            When set to True, the training logs are saved as a Flyte Deck. If `flytekitplugins-deck-standard` or
            `pandas` is not installed, a warning is logged and this option is turned off.

        sync_checkpoints (`bool`, *optional*, defaults to `True`):
            When set to True, checkpoints are synced with Flyte and can be used to resume training in the case of an
            interruption.

    Example:

    ```python
    # Note: This example skips over some setup steps for brevity.
    from flytekit import current_context, task


    @task
    def train_hf_transformer():
        cp = current_context().checkpoint
        trainer = Trainer(..., callbacks=[FlyteCallback()])
        output = trainer.train(resume_from_checkpoint=cp.restore())
    ```
    """

    def __init__(self, save_log_history: bool = True, sync_checkpoints: bool = True):
        super().__init__()
        if not is_flytekit_available():
            raise ImportError("FlyteCallback requires flytekit to be installed. Run `pip install flytekit`.")

        # Only look for the optional dependencies when log-history saving was actually requested, so that
        # callers who pass `save_log_history=False` do not get a warning about a feature they disabled.
        if save_log_history and not (is_flyte_deck_standard_available() and is_pandas_available()):
            logger.warning(
                "Syncing log history requires both flytekitplugins-deck-standard and pandas to be installed. "
                "Run `pip install flytekitplugins-deck-standard pandas` to enable this feature."
            )
            save_log_history = False

        # Imported lazily so that the module can be loaded without flytekit installed.
        from flytekit import current_context

        self.cp = current_context().checkpoint
        self.save_log_history = save_log_history
        self.sync_checkpoints = sync_checkpoints

    def on_save(self, args, state, control, **kwargs):
        """Pass the `checkpoint-<global_step>` directory under `args.output_dir` to the Flyte checkpoint."""
        # Only the process flagged as world process zero performs the sync.
        if self.sync_checkpoints and state.is_world_process_zero:
            ckpt_dir = f"checkpoint-{state.global_step}"
            artifact_path = os.path.join(args.output_dir, ckpt_dir)

            logger.info(f"Syncing checkpoint in {ckpt_dir} to Flyte. This may take time.")
            self.cp.save(artifact_path)

    def on_train_end(self, args, state, control, **kwargs):
        """Render `state.log_history` as an HTML table in a Flyte Deck named "Log History"."""
        if self.save_log_history:
            # Optional dependencies: imported here because `__init__` disables this path when they are missing.
            import pandas as pd
            from flytekit import Deck
            from flytekitplugins.deck.renderer import TableRenderer

            log_history_df = pd.DataFrame(state.log_history)
            Deck("Log History", TableRenderer().to_html(log_history_df))
INTEGRATION_TO_CALLBACK = {
"azure_ml": AzureMLCallback,
"comet_ml": CometCallback,
@ -1547,6 +1620,7 @@ INTEGRATION_TO_CALLBACK = {
"codecarbon": CodeCarbonCallback,
"clearml": ClearMLCallback,
"dagshub": DagsHubCallback,
"flyte": FlyteCallback,
}