mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
feat: Upgrade Weights & Biases callback (#30135)
* feat: upgrade wandb callback with new features * fix: ci issues with imports and run fixup
This commit is contained in:
parent
30b453206d
commit
4ab7a28216
@ -31,8 +31,17 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
|
||||
import numpy as np
|
||||
import packaging.version
|
||||
|
||||
from .. import PreTrainedModel, TFPreTrainedModel
|
||||
from .. import __version__ as version
|
||||
from ..utils import flatten_dict, is_datasets_available, is_pandas_available, is_torch_available, logging
|
||||
from ..utils import (
|
||||
PushToHubMixin,
|
||||
flatten_dict,
|
||||
is_datasets_available,
|
||||
is_pandas_available,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -69,6 +78,7 @@ if TYPE_CHECKING and _has_neptune:
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
_has_neptune = False
|
||||
|
||||
from .. import modelcard # noqa: E402
|
||||
from ..trainer_callback import ProgressCallback, TrainerCallback # noqa: E402
|
||||
from ..trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402
|
||||
from ..training_args import ParallelMode # noqa: E402
|
||||
@ -663,6 +673,22 @@ class TensorBoardCallback(TrainerCallback):
|
||||
self.tb_writer = None
|
||||
|
||||
|
||||
def save_model_architecture_to_file(model: Any, output_dir: str):
|
||||
with open(f"{output_dir}/model_architecture.txt", "w+") as f:
|
||||
if isinstance(model, PreTrainedModel):
|
||||
print(model, file=f)
|
||||
elif is_tf_available() and isinstance(model, TFPreTrainedModel):
|
||||
|
||||
def print_to_file(s):
|
||||
print(s, file=f)
|
||||
|
||||
model.summary(print_fn=print_to_file)
|
||||
elif is_torch_available() and (
|
||||
isinstance(model, (torch.nn.Module, PushToHubMixin)) and hasattr(model, "base_model")
|
||||
):
|
||||
print(model, file=f)
|
||||
|
||||
|
||||
class WandbCallback(TrainerCallback):
|
||||
"""
|
||||
A [`TrainerCallback`] that logs metrics, media, model checkpoints to [Weight and Biases](https://www.wandb.com/).
|
||||
@ -728,6 +754,9 @@ class WandbCallback(TrainerCallback):
|
||||
if hasattr(model, "config") and model.config is not None:
|
||||
model_config = model.config.to_dict()
|
||||
combined_dict = {**model_config, **combined_dict}
|
||||
if hasattr(model, "peft_config") and model.peft_config is not None:
|
||||
peft_config = model.peft_config
|
||||
combined_dict = {**{"peft_config": peft_config}, **combined_dict}
|
||||
trial_name = state.trial_name
|
||||
init_args = {}
|
||||
if trial_name is not None:
|
||||
@ -756,6 +785,51 @@ class WandbCallback(TrainerCallback):
|
||||
self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps))
|
||||
self._wandb.run._label(code="transformers_trainer")
|
||||
|
||||
# add number of model parameters to wandb config
|
||||
if any(
|
||||
(
|
||||
isinstance(model, PreTrainedModel),
|
||||
isinstance(model, PushToHubMixin),
|
||||
(is_tf_available() and isinstance(model, TFPreTrainedModel)),
|
||||
(is_torch_available() and isinstance(model, torch.nn.Module)),
|
||||
)
|
||||
):
|
||||
self._wandb.config["model/num_parameters"] = model.num_parameters()
|
||||
|
||||
# log the initial model and architecture to an artifact
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
model_name = (
|
||||
f"model-{self._wandb.run.id}"
|
||||
if (args.run_name is None or args.run_name == args.output_dir)
|
||||
else f"model-{self._wandb.run.name}"
|
||||
)
|
||||
model_artifact = self._wandb.Artifact(
|
||||
name=model_name,
|
||||
type="model",
|
||||
metadata={
|
||||
"model_config": model.config.to_dict() if hasattr(model, "config") else None,
|
||||
"num_parameters": self._wandb.config.get("model/num_parameters"),
|
||||
"initial_model": True,
|
||||
},
|
||||
)
|
||||
model.save_pretrained(temp_dir)
|
||||
# add the architecture to a separate text file
|
||||
save_model_architecture_to_file(model, temp_dir)
|
||||
|
||||
for f in Path(temp_dir).glob("*"):
|
||||
if f.is_file():
|
||||
with model_artifact.new_file(f.name, mode="wb") as fa:
|
||||
fa.write(f.read_bytes())
|
||||
self._wandb.run.log_artifact(model_artifact, aliases=["base_model"])
|
||||
|
||||
badge_markdown = (
|
||||
f'[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge'
|
||||
f'-28.svg" alt="Visualize in Weights & Biases" width="20'
|
||||
f'0" height="32"/>]({self._wandb.run.get_url()})'
|
||||
)
|
||||
|
||||
modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
|
||||
|
||||
def on_train_begin(self, args, state, control, model=None, **kwargs):
|
||||
if self._wandb is None:
|
||||
return
|
||||
@ -786,20 +860,25 @@ class WandbCallback(TrainerCallback):
|
||||
else {
|
||||
f"eval/{args.metric_for_best_model}": state.best_metric,
|
||||
"train/total_floss": state.total_flos,
|
||||
"model/num_parameters": self._wandb.config.get("model/num_parameters"),
|
||||
}
|
||||
)
|
||||
metadata["final_model"] = True
|
||||
logger.info("Logging model artifacts. ...")
|
||||
model_name = (
|
||||
f"model-{self._wandb.run.id}"
|
||||
if (args.run_name is None or args.run_name == args.output_dir)
|
||||
else f"model-{self._wandb.run.name}"
|
||||
)
|
||||
# add the model architecture to a separate text file
|
||||
save_model_architecture_to_file(model, temp_dir)
|
||||
|
||||
artifact = self._wandb.Artifact(name=model_name, type="model", metadata=metadata)
|
||||
for f in Path(temp_dir).glob("*"):
|
||||
if f.is_file():
|
||||
with artifact.new_file(f.name, mode="wb") as fa:
|
||||
fa.write(f.read_bytes())
|
||||
self._wandb.run.log_artifact(artifact)
|
||||
self._wandb.run.log_artifact(artifact, aliases=["final_model"])
|
||||
|
||||
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
|
||||
single_value_scalars = [
|
||||
@ -829,18 +908,30 @@ class WandbCallback(TrainerCallback):
|
||||
for k, v in dict(self._wandb.summary).items()
|
||||
if isinstance(v, numbers.Number) and not k.startswith("_")
|
||||
}
|
||||
checkpoint_metadata["model/num_parameters"] = self._wandb.config.get("model/num_parameters")
|
||||
|
||||
ckpt_dir = f"checkpoint-{state.global_step}"
|
||||
artifact_path = os.path.join(args.output_dir, ckpt_dir)
|
||||
logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. ...")
|
||||
checkpoint_name = (
|
||||
f"checkpoint-{self._wandb.run.id}"
|
||||
f"model-{self._wandb.run.id}"
|
||||
if (args.run_name is None or args.run_name == args.output_dir)
|
||||
else f"checkpoint-{self._wandb.run.name}"
|
||||
else f"model-{self._wandb.run.name}"
|
||||
)
|
||||
artifact = self._wandb.Artifact(name=checkpoint_name, type="model", metadata=checkpoint_metadata)
|
||||
artifact.add_dir(artifact_path)
|
||||
self._wandb.log_artifact(artifact, aliases=[f"checkpoint-{state.global_step}"])
|
||||
self._wandb.log_artifact(
|
||||
artifact, aliases=[f"epoch_{round(state.epoch, 2)}", f"checkpoint_global_step_{state.global_step}"]
|
||||
)
|
||||
|
||||
def on_predict(self, args, state, control, metrics, **kwargs):
|
||||
if self._wandb is None:
|
||||
return
|
||||
if not self._initialized:
|
||||
self.setup(args, state, **kwargs)
|
||||
if state.is_world_process_zero:
|
||||
metrics = rewrite_logs(metrics)
|
||||
self._wandb.log(metrics)
|
||||
|
||||
|
||||
class CometCallback(TrainerCallback):
|
||||
|
Loading…
Reference in New Issue
Block a user