mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
feat(wandb): save model as artifact (#8119)
* feat(wandb): log artifacts * fix: typo * feat(wandb): ensure name is allowed * feat(wandb): log artifact * feat(wandb): saving logic * style: improve formatting * fix: unrelated typo * feat: use a fake trainer * fix: simplify * feat(wandb): log model files as artifact * style: fix style * docs(wandb): correct description * feat: unpack model + allow env Truethy values * feat: TrainerCallback can access tokenizer * style: fix style * feat(wandb): log more interesting metadata * feat: unpack tokenizer * feat(wandb): metadata with load_best_model_at_end * feat(wandb): more robust metadata * style(wandb): fix formatting
This commit is contained in:
parent
143289dcf7
commit
30fa0b780f
@ -15,8 +15,13 @@
|
||||
Integrations with other Python libraries.
|
||||
"""
|
||||
import math
|
||||
import numbers
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from .file_utils import ENV_VARS_TRUE_VALUES
|
||||
from .trainer_utils import EvaluationStrategy
|
||||
from .utils import logging
|
||||
|
||||
@ -369,6 +374,8 @@ class WandbCallback(TrainerCallback):
|
||||
<https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:
|
||||
|
||||
Environment:
|
||||
WANDB_LOG_MODEL (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to log model as artifact at the end of training.
|
||||
WANDB_WATCH (:obj:`str`, `optional`, defaults to :obj:`"gradients"`):
|
||||
Can be :obj:`"gradients"`, :obj:`"all"` or :obj:`"false"`. Set to :obj:`"false"` to disable gradient
|
||||
logging or :obj:`"all"` to log gradients and parameters.
|
||||
@ -407,12 +414,44 @@ class WandbCallback(TrainerCallback):
|
||||
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
|
||||
wandb.watch(model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps))
|
||||
|
||||
# log outputs
|
||||
self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})
|
||||
|
||||
def on_train_begin(self, args, state, control, model=None, **kwargs):
    """Set up (or re-set-up) the wandb run at the start of training.

    During a hyperparameter search each trial needs a fresh wandb run,
    so setup is re-executed with ``reinit=True`` in that case.
    """
    hp_search = state.is_hyper_param_search
    if not self._initialized or hp_search:
        # Removed stray debug `print(args.run_name)` left over from development.
        self.setup(args, state, model, reinit=hp_search, **kwargs)
|
||||
|
||||
def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
    """Optionally upload the trained model to wandb as an Artifact.

    First commits the last logged step, then — when ``WANDB_LOG_MODEL`` is
    enabled, the callback is initialized, and this is world process zero —
    serializes the model (and tokenizer) through a throwaway ``Trainer`` and
    uploads every produced file as a wandb Artifact named after the run.
    """
    # commit last step
    wandb.log({})
    if self._log_model and self._initialized and state.is_world_process_zero:
        # Local import to avoid a circular dependency (trainer imports integrations).
        from .trainer import Trainer

        # A minimal Trainer gives us `save_model` (model + tokenizer
        # serialization) without redoing any training setup.
        fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
        with tempfile.TemporaryDirectory() as temp_dir:
            fake_trainer.save_model(temp_dir)
            # use run name and ensure it's a valid Artifact name
            artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", wandb.run.name)
            # With load_best_model_at_end the run summary reflects the last
            # step, not the best model, so record best-model metrics instead.
            metadata = (
                {
                    k: v
                    for k, v in dict(wandb.summary).items()
                    if isinstance(v, numbers.Number) and not k.startswith("_")
                }
                if not args.load_best_model_at_end
                else {
                    f"eval/{args.metric_for_best_model}": state.best_metric,
                    # fix: key was misspelled "train/total_floss"; align with
                    # the `state.total_flos` attribute it reports.
                    "train/total_flos": state.total_flos,
                }
            )
            artifact = wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata)
            for f in Path(temp_dir).glob("*"):
                if f.is_file():
                    # Stream each saved file into the artifact.
                    with artifact.new_file(f.name, mode="wb") as fa:
                        fa.write(f.read_bytes())
            wandb.run.log_artifact(artifact)
|
||||
|
||||
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
|
||||
if not self._initialized:
|
||||
self.setup(args, state, model, reinit=False)
|
||||
|
@ -261,7 +261,9 @@ class Trainer:
|
||||
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
|
||||
)
|
||||
callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
|
||||
self.callback_handler = CallbackHandler(callbacks, self.model, self.optimizer, self.lr_scheduler)
|
||||
self.callback_handler = CallbackHandler(
|
||||
callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
|
||||
)
|
||||
self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
|
||||
|
||||
# Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
|
||||
|
@ -168,6 +168,8 @@ class TrainerCallback:
|
||||
The object that is returned to the :class:`~transformers.Trainer` and can be used to make some decisions.
|
||||
model (:class:`~transformers.PreTrainedModel` or :obj:`torch.nn.Module`):
|
||||
The model being trained.
|
||||
tokenizer (:class:`~transformers.PreTrainedTokenizer`):
|
||||
The tokenizer used for encoding the data.
|
||||
optimizer (:obj:`torch.optim.Optimizer`):
|
||||
The optimizer used for the training steps.
|
||||
lr_scheduler (:obj:`torch.optim.lr_scheduler.LambdaLR`):
|
||||
@ -274,11 +276,12 @@ class TrainerCallback:
|
||||
class CallbackHandler(TrainerCallback):
|
||||
""" Internal class that just calls the list of callbacks in order. """
|
||||
|
||||
def __init__(self, callbacks, model, optimizer, lr_scheduler):
|
||||
def __init__(self, callbacks, model, tokenizer, optimizer, lr_scheduler):
|
||||
self.callbacks = []
|
||||
for cb in callbacks:
|
||||
self.add_callback(cb)
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.optimizer = optimizer
|
||||
self.lr_scheduler = lr_scheduler
|
||||
self.train_dataloader = None
|
||||
@ -376,6 +379,7 @@ class CallbackHandler(TrainerCallback):
|
||||
state,
|
||||
control,
|
||||
model=self.model,
|
||||
tokenizer=self.tokenizer,
|
||||
optimizer=self.optimizer,
|
||||
lr_scheduler=self.lr_scheduler,
|
||||
train_dataloader=self.train_dataloader,
|
||||
|
Loading…
Reference in New Issue
Block a user