feat: Upgrade Weights & Biases callback (#30135)

* feat: upgrade wandb callback with new features

* fix: ci issues with imports and run fixup
This commit is contained in:
Bharat Ramanathan 2024-04-19 15:33:32 +05:30 committed by GitHub
parent 30b453206d
commit 4ab7a28216
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -31,8 +31,17 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
import numpy as np
import packaging.version
from .. import PreTrainedModel, TFPreTrainedModel
from .. import __version__ as version
from ..utils import flatten_dict, is_datasets_available, is_pandas_available, is_torch_available, logging
from ..utils import (
PushToHubMixin,
flatten_dict,
is_datasets_available,
is_pandas_available,
is_tf_available,
is_torch_available,
logging,
)
logger = logging.get_logger(__name__)
@ -69,6 +78,7 @@ if TYPE_CHECKING and _has_neptune:
except importlib.metadata.PackageNotFoundError:
_has_neptune = False
from .. import modelcard # noqa: E402
from ..trainer_callback import ProgressCallback, TrainerCallback # noqa: E402
from ..trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402
from ..training_args import ParallelMode # noqa: E402
@ -663,6 +673,22 @@ class TensorBoardCallback(TrainerCallback):
self.tb_writer = None
def save_model_architecture_to_file(model: Any, output_dir: str):
with open(f"{output_dir}/model_architecture.txt", "w+") as f:
if isinstance(model, PreTrainedModel):
print(model, file=f)
elif is_tf_available() and isinstance(model, TFPreTrainedModel):
def print_to_file(s):
print(s, file=f)
model.summary(print_fn=print_to_file)
elif is_torch_available() and (
isinstance(model, (torch.nn.Module, PushToHubMixin)) and hasattr(model, "base_model")
):
print(model, file=f)
class WandbCallback(TrainerCallback):
"""
A [`TrainerCallback`] that logs metrics, media, model checkpoints to [Weight and Biases](https://www.wandb.com/).
@ -728,6 +754,9 @@ class WandbCallback(TrainerCallback):
if hasattr(model, "config") and model.config is not None:
model_config = model.config.to_dict()
combined_dict = {**model_config, **combined_dict}
if hasattr(model, "peft_config") and model.peft_config is not None:
peft_config = model.peft_config
combined_dict = {**{"peft_config": peft_config}, **combined_dict}
trial_name = state.trial_name
init_args = {}
if trial_name is not None:
@ -756,6 +785,51 @@ class WandbCallback(TrainerCallback):
self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps))
self._wandb.run._label(code="transformers_trainer")
# add number of model parameters to wandb config
if any(
(
isinstance(model, PreTrainedModel),
isinstance(model, PushToHubMixin),
(is_tf_available() and isinstance(model, TFPreTrainedModel)),
(is_torch_available() and isinstance(model, torch.nn.Module)),
)
):
self._wandb.config["model/num_parameters"] = model.num_parameters()
# log the initial model and architecture to an artifact
with tempfile.TemporaryDirectory() as temp_dir:
model_name = (
f"model-{self._wandb.run.id}"
if (args.run_name is None or args.run_name == args.output_dir)
else f"model-{self._wandb.run.name}"
)
model_artifact = self._wandb.Artifact(
name=model_name,
type="model",
metadata={
"model_config": model.config.to_dict() if hasattr(model, "config") else None,
"num_parameters": self._wandb.config.get("model/num_parameters"),
"initial_model": True,
},
)
model.save_pretrained(temp_dir)
# add the architecture to a separate text file
save_model_architecture_to_file(model, temp_dir)
for f in Path(temp_dir).glob("*"):
if f.is_file():
with model_artifact.new_file(f.name, mode="wb") as fa:
fa.write(f.read_bytes())
self._wandb.run.log_artifact(model_artifact, aliases=["base_model"])
badge_markdown = (
f'[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge'
f'-28.svg" alt="Visualize in Weights & Biases" width="20'
f'0" height="32"/>]({self._wandb.run.get_url()})'
)
modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
def on_train_begin(self, args, state, control, model=None, **kwargs):
if self._wandb is None:
return
@ -786,20 +860,25 @@ class WandbCallback(TrainerCallback):
else {
f"eval/{args.metric_for_best_model}": state.best_metric,
"train/total_floss": state.total_flos,
"model/num_parameters": self._wandb.config.get("model/num_parameters"),
}
)
metadata["final_model"] = True
logger.info("Logging model artifacts. ...")
model_name = (
f"model-{self._wandb.run.id}"
if (args.run_name is None or args.run_name == args.output_dir)
else f"model-{self._wandb.run.name}"
)
# add the model architecture to a separate text file
save_model_architecture_to_file(model, temp_dir)
artifact = self._wandb.Artifact(name=model_name, type="model", metadata=metadata)
for f in Path(temp_dir).glob("*"):
if f.is_file():
with artifact.new_file(f.name, mode="wb") as fa:
fa.write(f.read_bytes())
self._wandb.run.log_artifact(artifact)
self._wandb.run.log_artifact(artifact, aliases=["final_model"])
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
single_value_scalars = [
@ -829,18 +908,30 @@ class WandbCallback(TrainerCallback):
for k, v in dict(self._wandb.summary).items()
if isinstance(v, numbers.Number) and not k.startswith("_")
}
checkpoint_metadata["model/num_parameters"] = self._wandb.config.get("model/num_parameters")
ckpt_dir = f"checkpoint-{state.global_step}"
artifact_path = os.path.join(args.output_dir, ckpt_dir)
logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. ...")
checkpoint_name = (
f"checkpoint-{self._wandb.run.id}"
f"model-{self._wandb.run.id}"
if (args.run_name is None or args.run_name == args.output_dir)
else f"checkpoint-{self._wandb.run.name}"
else f"model-{self._wandb.run.name}"
)
artifact = self._wandb.Artifact(name=checkpoint_name, type="model", metadata=checkpoint_metadata)
artifact.add_dir(artifact_path)
self._wandb.log_artifact(artifact, aliases=[f"checkpoint-{state.global_step}"])
self._wandb.log_artifact(
artifact, aliases=[f"epoch_{round(state.epoch, 2)}", f"checkpoint_global_step_{state.global_step}"]
)
def on_predict(self, args, state, control, metrics, **kwargs):
if self._wandb is None:
return
if not self._initialized:
self.setup(args, state, **kwargs)
if state.is_world_process_zero:
metrics = rewrite_logs(metrics)
self._wandb.log(metrics)
class CometCallback(TrainerCallback):