don't log base model architecture in wandb if log model is false (#32143)

* don't log base model architecture in wandb if log model is false

* Update src/transformers/integrations/integration_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* convert log model setting into an enum

* fix formatting

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
João Nadkarni 2024-07-26 09:38:59 +02:00 committed by GitHub
parent c46edfb823
commit 1c7ebf1d6e

@@ -26,6 +26,7 @@ import shutil
 import sys
 import tempfile
 from dataclasses import asdict, fields
+from enum import Enum
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
 
@@ -726,6 +727,35 @@ def save_model_architecture_to_file(model: Any, output_dir: str):
             print(model, file=f)
 
 
+class WandbLogModel(str, Enum):
+    """Enum of possible log model values in W&B."""
+
+    CHECKPOINT = "checkpoint"
+    END = "end"
+    FALSE = "false"
+
+    @property
+    def is_enabled(self) -> bool:
+        """Check if the value corresponds to a state where the `WANDB_LOG_MODEL` setting is enabled."""
+        return self in (WandbLogModel.CHECKPOINT, WandbLogModel.END)
+
+    @classmethod
+    def _missing_(cls, value: Any) -> "WandbLogModel":
+        if not isinstance(value, str):
+            raise ValueError(f"Expecting to have a string `WANDB_LOG_MODEL` setting, but got {type(value)}")
+        if value.upper() in ENV_VARS_TRUE_VALUES:
+            DeprecationWarning(
+                f"Setting `WANDB_LOG_MODEL` as {os.getenv('WANDB_LOG_MODEL')} is deprecated and will be removed in "
+                "version 5 of transformers. Use one of `'end'` or `'checkpoint'` instead."
+            )
+            logger.info(f"Setting `WANDB_LOG_MODEL` from {os.getenv('WANDB_LOG_MODEL')} to `end` instead")
+            return WandbLogModel.END
+        logger.warning(
+            f"Received unrecognized `WANDB_LOG_MODEL` setting value={value}; so disabling `WANDB_LOG_MODEL`"
+        )
+        return WandbLogModel.FALSE
+
+
 class WandbCallback(TrainerCallback):
     """
     A [`TrainerCallback`] that logs metrics, media, model checkpoints to [Weight and Biases](https://www.wandb.com/).
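
Note: the snippet below is a minimal, standalone sketch of the parsing behaviour the enum above introduces. `ENV_VARS_TRUE_VALUES` is a local stand-in for the transformers constant, and the deprecation warning / logger calls are left out so the sketch runs on its own.

    # Standalone sketch of the WandbLogModel parsing behaviour above.
    # ENV_VARS_TRUE_VALUES is a local stand-in; deprecation/log messages are omitted.
    from enum import Enum
    from typing import Any

    ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}

    class WandbLogModel(str, Enum):
        CHECKPOINT = "checkpoint"
        END = "end"
        FALSE = "false"

        @property
        def is_enabled(self) -> bool:
            # only "end" and "checkpoint" turn model logging on
            return self in (WandbLogModel.CHECKPOINT, WandbLogModel.END)

        @classmethod
        def _missing_(cls, value: Any) -> "WandbLogModel":
            # Enum calls this only when the raw value matches no member:
            # legacy truthy strings map to END, anything else disables logging.
            if isinstance(value, str) and value.upper() in ENV_VARS_TRUE_VALUES:
                return WandbLogModel.END
            return WandbLogModel.FALSE

    print(WandbLogModel("checkpoint") is WandbLogModel.CHECKPOINT)  # True
    print(WandbLogModel("TRUE").is_enabled)                         # True  ("true" is treated as "end")
    print(WandbLogModel("bogus").is_enabled)                        # False (unknown values fall back to FALSE)
    print(WandbLogModel.CHECKPOINT == "checkpoint")                 # True  (str mix-in keeps string comparisons working)

The `str` mix-in is what keeps any remaining plain-string comparisons against `self._log_model` working unchanged.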
@@ -740,16 +770,7 @@ class WandbCallback(TrainerCallback):
 
             self._wandb = wandb
         self._initialized = False
-        # log model
-        if os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"}):
-            DeprecationWarning(
-                f"Setting `WANDB_LOG_MODEL` as {os.getenv('WANDB_LOG_MODEL')} is deprecated and will be removed in "
-                "version 5 of transformers. Use one of `'end'` or `'checkpoint'` instead."
-            )
-            logger.info(f"Setting `WANDB_LOG_MODEL` from {os.getenv('WANDB_LOG_MODEL')} to `end` instead")
-            self._log_model = "end"
-        else:
-            self._log_model = os.getenv("WANDB_LOG_MODEL", "false").lower()
+        self._log_model = WandbLogModel(os.getenv("WANDB_LOG_MODEL", "false"))
 
     def setup(self, args, state, model, **kwargs):
         """
@@ -834,37 +855,38 @@ class WandbCallback(TrainerCallback):
             logger.info("Could not log the number of model parameters in Weights & Biases.")
 
         # log the initial model architecture to an artifact
-        with tempfile.TemporaryDirectory() as temp_dir:
-            model_name = (
-                f"model-{self._wandb.run.id}"
-                if (args.run_name is None or args.run_name == args.output_dir)
-                else f"model-{self._wandb.run.name}"
-            )
-            model_artifact = self._wandb.Artifact(
-                name=model_name,
-                type="model",
-                metadata={
-                    "model_config": model.config.to_dict() if hasattr(model, "config") else None,
-                    "num_parameters": self._wandb.config.get("model/num_parameters"),
-                    "initial_model": True,
-                },
-            )
-            # add the architecture to a separate text file
-            save_model_architecture_to_file(model, temp_dir)
-
-            for f in Path(temp_dir).glob("*"):
-                if f.is_file():
-                    with model_artifact.new_file(f.name, mode="wb") as fa:
-                        fa.write(f.read_bytes())
-            self._wandb.run.log_artifact(model_artifact, aliases=["base_model"])
-
-            badge_markdown = (
-                f'[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge'
-                f'-28.svg" alt="Visualize in Weights & Biases" width="20'
-                f'0" height="32"/>]({self._wandb.run.get_url()})'
-            )
-
-            modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
+        if self._log_model.is_enabled:
+            with tempfile.TemporaryDirectory() as temp_dir:
+                model_name = (
+                    f"model-{self._wandb.run.id}"
+                    if (args.run_name is None or args.run_name == args.output_dir)
+                    else f"model-{self._wandb.run.name}"
+                )
+                model_artifact = self._wandb.Artifact(
+                    name=model_name,
+                    type="model",
+                    metadata={
+                        "model_config": model.config.to_dict() if hasattr(model, "config") else None,
+                        "num_parameters": self._wandb.config.get("model/num_parameters"),
+                        "initial_model": True,
+                    },
+                )
+                # add the architecture to a separate text file
+                save_model_architecture_to_file(model, temp_dir)
+
+                for f in Path(temp_dir).glob("*"):
+                    if f.is_file():
+                        with model_artifact.new_file(f.name, mode="wb") as fa:
+                            fa.write(f.read_bytes())
+                self._wandb.run.log_artifact(model_artifact, aliases=["base_model"])
+
+                badge_markdown = (
+                    f'[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge'
+                    f'-28.svg" alt="Visualize in Weights & Biases" width="20'
+                    f'0" height="32"/>]({self._wandb.run.get_url()})'
+                )
+
+                modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
 
     def on_train_begin(self, args, state, control, model=None, **kwargs):
         if self._wandb is None:
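
Note: the block that is now guarded by `is_enabled` builds a "base_model" artifact holding the model config plus a plain-text dump of the architecture. Below is a rough standalone sketch of that dump step, mirroring the `save_model_architecture_to_file` helper named in the hunk header; the `model_architecture.txt` filename and the toy torch model are assumptions for illustration, not taken from the diff.

    # Sketch of the helper referenced above: write the model's repr into a text file
    # inside the temporary directory that later becomes the W&B artifact.
    # The filename and the toy torch model are illustrative assumptions (requires torch).
    import os
    import tempfile

    import torch.nn as nn

    def save_model_architecture_to_file(model, output_dir: str) -> None:
        with open(os.path.join(output_dir, "model_architecture.txt"), "w+") as f:
            print(model, file=f)  # matches the `print(model, file=f)` context line in the diff

    model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 2))
    with tempfile.TemporaryDirectory() as temp_dir:
        save_model_architecture_to_file(model, temp_dir)
        print(open(os.path.join(temp_dir, "model_architecture.txt")).read())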
@@ -880,7 +902,7 @@ class WandbCallback(TrainerCallback):
     def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
         if self._wandb is None:
             return
-        if self._log_model in ("end", "checkpoint") and self._initialized and state.is_world_process_zero:
+        if self._log_model.is_enabled and self._initialized and state.is_world_process_zero:
             from ..trainer import Trainer
 
             fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
@@ -938,7 +960,7 @@ class WandbCallback(TrainerCallback):
             self._wandb.log({**non_scalar_logs, "train/global_step": state.global_step})
 
     def on_save(self, args, state, control, **kwargs):
-        if self._log_model == "checkpoint" and self._initialized and state.is_world_process_zero:
+        if self._log_model == WandbLogModel.CHECKPOINT and self._initialized and state.is_world_process_zero:
             checkpoint_metadata = {
                 k: v
                 for k, v in dict(self._wandb.summary).items()
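
Note: a hedged usage sketch of how the three settings are driven from the training side. The `WANDB_LOG_MODEL` environment variable and the `TrainingArguments` fields are standard; the `WANDB_PROJECT` value, project name, and step value are placeholders.

    # Usage sketch (assumes transformers, torch and wandb are installed):
    # set WANDB_LOG_MODEL before the Trainer/WandbCallback is constructed.
    import os

    os.environ["WANDB_LOG_MODEL"] = "checkpoint"  # "end" = one artifact after training, "false" = no model artifacts
    os.environ["WANDB_PROJECT"] = "my-project"    # placeholder project name

    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="out",
        report_to=["wandb"],  # attach the WandbCallback
        save_steps=500,       # with "checkpoint", each save also uploads a checkpoint artifact
    )
    # A Trainer built with these args uploads a checkpoint artifact on every save and,
    # after this change, logs the base-model architecture artifact only when
    # WANDB_LOG_MODEL is "end" or "checkpoint".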