mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Adds a FlyteCallback (#23759)
* initial flyte callback * lint * logs should still be saved to Flyte even if pandas isn't installed (unlikely) * cr - flyte team * add docs for FlyteCallback * fix doc string - cr sgugger * Apply suggestions from code review cr - sgugger fix doc strings Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
867316670a
commit
62ba64b90a
@ -39,6 +39,7 @@ By default a [`Trainer`] will use the following callbacks:
|
||||
installed.
|
||||
- [`~integrations.ClearMLCallback`] if [clearml](https://github.com/allegroai/clearml) is installed.
|
||||
- [`~integrations.DagsHubCallback`] if [dagshub](https://dagshub.com/) is installed.
|
||||
- [`~integrations.FlyteCallback`] if [flyte](https://flyte.org/) is installed.
|
||||
|
||||
The main class that implements callbacks is [`TrainerCallback`]. It gets the
|
||||
[`TrainingArguments`] used to instantiate the [`Trainer`], can access that
|
||||
@ -79,6 +80,8 @@ Here is the list of the available [`TrainerCallback`] in the library:
|
||||
|
||||
[[autodoc]] integrations.DagsHubCallback
|
||||
|
||||
[[autodoc]] integrations.FlyteCallback
|
||||
|
||||
## TrainerCallback
|
||||
|
||||
[[autodoc]] TrainerCallback
|
||||
|
@ -30,7 +30,7 @@ from typing import TYPE_CHECKING, Dict, Optional
|
||||
import numpy as np
|
||||
|
||||
from . import __version__ as version
|
||||
from .utils import flatten_dict, is_datasets_available, is_torch_available, logging
|
||||
from .utils import flatten_dict, is_datasets_available, is_pandas_available, is_torch_available, logging
|
||||
from .utils.versions import importlib_metadata
|
||||
|
||||
|
||||
@ -146,6 +146,16 @@ def is_codecarbon_available():
|
||||
return importlib.util.find_spec("codecarbon") is not None
|
||||
|
||||
|
||||
def is_flytekit_available():
    """Return `True` when the `flytekit` package can be imported, `False` otherwise."""
    spec = importlib.util.find_spec("flytekit")
    return spec is not None
|
||||
|
||||
|
||||
def is_flyte_deck_standard_available():
    """Return `True` when the `flytekitplugins.deck` plugin (and `flytekit` itself) can be imported."""
    # Probe for the plugin only when the core `flytekit` package is present;
    # otherwise report unavailability immediately.
    if is_flytekit_available():
        return importlib.util.find_spec("flytekitplugins.deck") is not None
    return False
|
||||
|
||||
|
||||
def hp_params(trial):
|
||||
if is_optuna_available():
|
||||
import optuna
|
||||
@ -1537,6 +1547,69 @@ class ClearMLCallback(TrainerCallback):
|
||||
self._clearml_task.update_output_model(artifact_path, iteration=state.global_step, auto_delete_file=False)
|
||||
|
||||
|
||||
class FlyteCallback(TrainerCallback):
    """A [`TrainerCallback`] that sends the logs to [Flyte](https://flyte.org/).
    NOTE: This callback only works within a Flyte task.

    Args:
        save_log_history (`bool`, *optional*, defaults to `True`):
            If `True`, the trainer's log history is rendered as a Flyte Deck when training ends.
        sync_checkpoints (`bool`, *optional*, defaults to `True`):
            If `True`, each saved checkpoint is synced to Flyte so that training can be resumed after an
            interruption.

    Example:

    ```python
    # Note: This example skips over some setup steps for brevity.
    from flytekit import current_context, task


    @task
    def train_hf_transformer():
        cp = current_context().checkpoint
        trainer = Trainer(..., callbacks=[FlyteCallback()])
        output = trainer.train(resume_from_checkpoint=cp.restore())
    ```
    """

    def __init__(self, save_log_history: bool = True, sync_checkpoints: bool = True):
        super().__init__()

        # flytekit is a hard requirement: without it there is no Flyte context to talk to.
        if not is_flytekit_available():
            raise ImportError("FlyteCallback requires flytekit to be installed. Run `pip install flytekit`.")

        # Rendering the log history additionally needs the deck plugin and pandas;
        # degrade gracefully (skip the deck) rather than fail when either is missing.
        if not is_flyte_deck_standard_available() or not is_pandas_available():
            logger.warning(
                "Syncing log history requires both flytekitplugins-deck-standard and pandas to be installed. "
                "Run `pip install flytekitplugins-deck-standard pandas` to enable this feature."
            )
            save_log_history = False

        from flytekit import current_context

        # Checkpoint handle of the surrounding Flyte task — this is why the
        # callback only works when executed inside a Flyte task.
        self.cp = current_context().checkpoint
        self.save_log_history = save_log_history
        self.sync_checkpoints = sync_checkpoints

    def on_save(self, args, state, control, **kwargs):
        """Sync the just-written checkpoint directory to Flyte (main process only)."""
        if not (self.sync_checkpoints and state.is_world_process_zero):
            return

        ckpt_dir = f"checkpoint-{state.global_step}"
        artifact_path = os.path.join(args.output_dir, ckpt_dir)

        logger.info(f"Syncing checkpoint in {ckpt_dir} to Flyte. This may take time.")
        self.cp.save(artifact_path)

    def on_train_end(self, args, state, control, **kwargs):
        """Publish the accumulated log history as an HTML table in a Flyte Deck."""
        if not self.save_log_history:
            return

        # Imported lazily: these optional dependencies were already verified in __init__.
        import pandas as pd

        from flytekit import Deck
        from flytekitplugins.deck.renderer import TableRenderer

        history_frame = pd.DataFrame(state.log_history)
        Deck("Log History", TableRenderer().to_html(history_frame))
|
||||
|
||||
|
||||
INTEGRATION_TO_CALLBACK = {
|
||||
"azure_ml": AzureMLCallback,
|
||||
"comet_ml": CometCallback,
|
||||
@ -1547,6 +1620,7 @@ INTEGRATION_TO_CALLBACK = {
|
||||
"codecarbon": CodeCarbonCallback,
|
||||
"clearml": ClearMLCallback,
|
||||
"dagshub": DagsHubCallback,
|
||||
"flyte": FlyteCallback,
|
||||
}
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user