[s2s] support early stopping based on loss, rather than rouge (#6927)

parent 207ed8cb78
commit e95d262f25
callbacks.py

@@ -75,21 +75,23 @@ class Seq2SeqLoggingCallback(pl.Callback):
         return self._write_logs(trainer, pl_module, "test")
 
 
-def get_checkpoint_callback(output_dir, metric, save_top_k=1):
+def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=False):
     """Saves the best model by validation ROUGE2 score."""
     if metric == "rouge2":
         exp = "{val_avg_rouge2:.4f}-{step_count}"
     elif metric == "bleu":
         exp = "{val_avg_bleu:.4f}-{step_count}"
+    elif metric == "loss":
+        exp = "{val_avg_loss:.4f}-{step_count}"
     else:
         raise NotImplementedError(
-            f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this function."
+            f"seq2seq callbacks only support rouge2, bleu and loss, got {metric}, You can make your own by adding to this function."
         )
 
     checkpoint_callback = ModelCheckpoint(
         filepath=os.path.join(output_dir, exp),
         monitor=f"val_{metric}",
-        mode="max",
+        mode="min" if "loss" in metric else "max",
         save_top_k=save_top_k,
         period=0,  # maybe save a checkpoint every time val is run, not just end of epoch.
     )
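For metric="loss" the new branch resolves to a checkpoint callback that minimizes its monitored quantity instead of maximizing it, and names checkpoints after the loss value. A minimal sketch of the resulting object, assuming the pytorch_lightning 0.x ModelCheckpoint signature this file targets:

import os
from pytorch_lightning.callbacks import ModelCheckpoint

# What get_checkpoint_callback("output_dir", "loss") now builds:
checkpoint_callback = ModelCheckpoint(
    filepath=os.path.join("output_dir", "{val_avg_loss:.4f}-{step_count}"),
    monitor="val_loss",  # f"val_{metric}" with metric == "loss"
    mode="min",          # "loss" in metric, so lower values are better
    save_top_k=1,
    period=0,
)

Because the filename pattern now depends on the chosen metric, checkpoint names are no longer fixed, which is why the test at the bottom of this diff stops asserting an exact rouge-based filename.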
@@ -98,8 +100,8 @@ def get_checkpoint_callback(output_dir, metric, save_top_k=1):
 
 def get_early_stopping_callback(metric, patience):
     return EarlyStopping(
-        monitor=f"val_{metric}",
-        mode="max",
+        monitor=f"val_{metric}",  # does this need avg?
+        mode="min" if "loss" in metric else "max",
         patience=patience,
         verbose=True,
     )
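The same direction logic applies to early stopping, so training now halts once validation loss stops improving. A sketch of what get_early_stopping_callback("loss", patience=2) returns, under the same pytorch_lightning 0.x assumption:

from pytorch_lightning.callbacks import EarlyStopping

es_callback = EarlyStopping(
    monitor="val_loss",  # f"val_{metric}" with metric == "loss"
    mode="min",          # stop when the monitored loss stops decreasing
    patience=2,          # validation runs without improvement before stopping
    verbose=True,
)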
finetune.py

@@ -148,10 +148,10 @@ class SummarizationModule(BaseTransformer):
         lm_logits = outputs[0]
         if self.hparams.label_smoothing == 0:
             # Same behavior as modeling_bart.py, besides ignoring pad_token_id
-            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
+            ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
 
             assert lm_logits.shape[-1] == self.model.config.vocab_size
-            loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
+            loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
         else:
             lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
             loss, nll_loss = label_smoothed_nll_loss(
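The rename from loss_fct to ce_loss_fct only disambiguates the unsmoothed branch, which is plain token-level cross-entropy that skips padding. A self-contained illustration with made-up shapes (batch of 2, sequence length 5, vocabulary of 100):

import torch

pad_token_id = 0
ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)

lm_logits = torch.randn(2, 5, 100)       # (batch, seq_len, vocab_size)
tgt_ids = torch.randint(1, 100, (2, 5))  # target token ids
tgt_ids[:, -1] = pad_token_id            # padded positions contribute no loss

loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))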
@@ -178,15 +178,25 @@ class SummarizationModule(BaseTransformer):
         self.step_count += 1
         losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
         loss = losses["loss"]
-        rouges = {k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]}
-        rouge_tensor: torch.FloatTensor = torch.tensor(rouges[self.val_metric]).type_as(loss)
-        rouges.update({k: v.item() for k, v in losses.items()})
-        losses.update(rouges)
-        metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
-        metrics["step_count"] = self.step_count
-        self.save_metrics(metrics, prefix)  # writes to self.metrics_save_path
+        generative_metrics = {
+            k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
+        }
+        metric_val = (
+            generative_metrics[self.val_metric] if self.val_metric in generative_metrics else losses[self.val_metric]
+        )
+        metric_tensor: torch.FloatTensor = torch.tensor(metric_val).type_as(loss)
+        generative_metrics.update({k: v.item() for k, v in losses.items()})
+        losses.update(generative_metrics)
+        all_metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
+        all_metrics["step_count"] = self.step_count
+        self.save_metrics(all_metrics, prefix)  # writes to self.metrics_save_path
         preds = flatten_list([x["preds"] for x in outputs])
-        return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_{self.val_metric}": rouge_tensor}
+        return {
+            "log": all_metrics,
+            "preds": preds,
+            f"{prefix}_loss": loss,
+            f"{prefix}_{self.val_metric}": metric_tensor,
+        }
 
     def save_metrics(self, latest_metrics, type_path) -> None:
         self.metrics[type_path].append(latest_metrics)
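The key change in this epilogue is that the monitored value no longer has to be a generative metric: it is looked up in generative_metrics first and falls back to the averaged losses dict. A toy illustration of that lookup, with made-up numbers:

losses = {"loss": 2.31}
generative_metrics = {"rouge2": 0.18, "bleu": 27.4}

for val_metric in ("rouge2", "loss"):
    metric_val = (
        generative_metrics[val_metric] if val_metric in generative_metrics else losses[val_metric]
    )
    print(val_metric, "->", metric_val)  # rouge2 -> 0.18, loss -> 2.31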
@@ -306,7 +316,9 @@ class SummarizationModule(BaseTransformer):
         parser.add_argument("--src_lang", type=str, default="", required=False)
         parser.add_argument("--tgt_lang", type=str, default="", required=False)
         parser.add_argument("--eval_beams", type=int, default=None, required=False)
-        parser.add_argument("--val_metric", type=str, default=None, required=False)
+        parser.add_argument(
+            "--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
+        )
         parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
         parser.add_argument(
             "--early_stopping_patience",
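Constraining --val_metric with choices means an unsupported metric now fails at parse time instead of reaching the NotImplementedError in callbacks.py. A standalone sketch of the argparse behavior:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
)

print(parser.parse_args(["--val_metric", "loss"]).val_metric)  # loss
# parser.parse_args(["--val_metric", "accuracy"]) exits with an error listing the valid choices.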
@@ -366,14 +378,17 @@ def main(args, model=None) -> SummarizationModule:
         es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
     else:
         es_callback = False
+
+    lower_is_better = args.val_metric == "loss"
     trainer: pl.Trainer = generic_train(
         model,
         args,
         logging_callback=Seq2SeqLoggingCallback(),
-        checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric, args.save_top_k),
+        checkpoint_callback=get_checkpoint_callback(
+            args.output_dir, model.val_metric, args.save_top_k, lower_is_better
+        ),
         early_stopping_callback=es_callback,
         logger=logger,
-        # TODO: early stopping callback seems messed up
     )
     pickle_save(model.hparams, model.output_dir / "hparams.pkl")
     if not args.do_predict:
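With these pieces, selecting the metric also selects its direction. A condensed sketch of the calls main() now makes for a run launched with --val_metric loss (the import path and argument values are hypothetical, assuming execution from the examples/seq2seq directory):

from callbacks import get_checkpoint_callback, get_early_stopping_callback

val_metric = "loss"                     # from --val_metric
lower_is_better = val_metric == "loss"  # True: keep the checkpoints with the lowest val_loss

checkpoint_callback = get_checkpoint_callback(
    "output_dir", val_metric, save_top_k=1, lower_is_better=lower_is_better
)
es_callback = get_early_stopping_callback(val_metric, patience=2)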
test_seq2seq_examples.py

@@ -33,7 +33,7 @@ CUDA_AVAILABLE = torch.cuda.is_available()
 CHEAP_ARGS = {
     "label_smoothing": 0.2,
     "eval_beams": 1,
-    "val_metric": None,
+    "val_metric": "loss",
     "save_top_k": 1,
     "adafactor": True,
     "early_stopping_patience": 2,
@@ -262,9 +262,9 @@ class TestSummarizationDistiller(unittest.TestCase):
         if not check_contents:
             return model
         contents = os.listdir(output_dir)
-        ckpt_name = "val_avg_rouge2=0.0000-step_count=2.ckpt"  # "val_avg_rouge2=0.0000-epoch=1.ckpt" # "epoch=1-val_avg_rouge2=0.0000.ckpt"
         contents = {os.path.basename(p) for p in contents}
-        self.assertIn(ckpt_name, contents)
+        ckpt_files = [p for p in contents if p.endswith("ckpt")]
+        assert len(ckpt_files) > 0
 
         self.assertIn("test_generations.txt", contents)
         self.assertIn("test_results.txt", contents)