Mirror of https://github.com/huggingface/transformers.git
[pl_examples] default warmup steps=0 (#5316)
parent bf0d12c220
commit 5543b30aa6
@@ -122,12 +122,9 @@ class BaseTransformer(pl.LightningModule):
         else:
             optimizer.step()
         optimizer.zero_grad()
-        self.lr_scheduler.step()
-
-    def get_tqdm_dict(self):
-        avg_loss = getattr(self.trainer, "avg_loss", 0.0)
-        tqdm_dict = {"loss": "{:.3f}".format(avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]}
-        return tqdm_dict
+        self.lr_scheduler.step()  # By default, PL will only step every epoch.
+        lrs = {f"lr_group_{i}": lr for i, lr in enumerate(self.lr_scheduler.get_lr())}
+        self.logger.log_metrics(lrs)
 
     def test_step(self, batch, batch_nb):
         return self.validation_step(batch, batch_nb)
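The hunk above moves scheduler stepping into optimizer_step, so the learning rate is updated once per batch rather than once per epoch (Lightning's default), and logs one learning rate per optimizer parameter group. A minimal, self-contained sketch of that pattern in plain PyTorch; the two-group optimizer and the 10-step warmup lambda are illustrative, not taken from the repo:

import torch

# Two parameter groups so the logged dict has more than one entry.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(
    [
        {"params": model.weight, "lr": 5e-5},
        {"params": model.bias, "lr": 1e-4},
    ]
)
# Linear warmup over 10 steps, then constant (a stand-in for the real schedule).
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / 10)
)

for step in range(20):
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()  # step every batch, not every epoch
    # Same shape as the dict logged in the diff: one entry per parameter group.
    lrs = {f"lr_group_{i}": lr for i, lr in enumerate(scheduler.get_last_lr())}
    if step % 5 == 0:
        print(step, lrs)

The sketch reads the rates via get_last_lr(); the diff itself calls get_lr(), which newer PyTorch releases discourage calling outside of step().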
@@ -202,7 +199,7 @@ class BaseTransformer(pl.LightningModule):
         parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
         parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
         parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
-        parser.add_argument("--warmup_steps", default=500, type=int, help="Linear warmup over warmup_steps.")
+        parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
         parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
         parser.add_argument(
             "--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform."
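Dropping the --warmup_steps default from 500 to 0 means the linear schedule starts at the full learning rate unless warmup is requested explicitly. A standalone sketch of how a parsed warmup_steps value typically reaches get_linear_schedule_with_warmup; the one-parameter optimizer and num_training_steps=1000 are placeholders, not values from the examples:

import argparse

import torch
from transformers import get_linear_schedule_with_warmup

parser = argparse.ArgumentParser()
# Mirrors the changed default above: no warmup unless requested.
parser.add_argument("--learning_rate", default=5e-5, type=float)
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
args = parser.parse_args([])  # pass ["--warmup_steps", "500"] to restore the old behaviour

optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=args.learning_rate)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=1000
)

# With warmup_steps=0 the very first reported LR is already the full learning_rate;
# with warmup_steps=500 it would start at 0.0 and ramp up linearly.
print(scheduler.get_last_lr())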
@@ -64,6 +64,7 @@ The following command should work on a 16GB GPU:

Tips:
- 1 epoch at batch size 1 for bart-large takes 24 hours and requires 13GB GPU RAM with fp16 on an NVIDIA-V100.
- since you need to run from `examples/seq2seq`, and likely need to modify code, it is easiest to fork, then clone transformers and run `pip install -e .` before you get started.
- try `bart-base`, `--freeze_encoder` or `--freeze_embeds` for faster training/larger batch size. (3hr/epoch with bs=8, see the "xsum_shared_task" command below)
- `fp16_opt_level=O1` (the default works best).
- If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
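The `--freeze_encoder` / `--freeze_embeds` tip above amounts to switching off requires_grad for part of the model before training. A rough sketch of that idea on a tiny, randomly initialised BART; the freeze_params helper and the config sizes are invented for the sketch, and it assumes BartModel's shared embedding and get_encoder() accessors rather than the example scripts' actual implementation:

from transformers import BartConfig, BartModel

def freeze_params(module):
    """Disable gradient updates for every parameter in `module`."""
    for p in module.parameters():
        p.requires_grad = False

# Tiny random config so the sketch runs without downloading any weights.
config = BartConfig(
    d_model=16,
    encoder_layers=1,
    decoder_layers=1,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
    vocab_size=100,
)
model = BartModel(config)

freeze_params(model.shared)         # roughly what --freeze_embeds targets
freeze_params(model.get_encoder())  # roughly what --freeze_encoder targets

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable parameters: {trainable}/{total}")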
@@ -3,6 +3,7 @@ import glob
 import logging
 import os
 import time
+import warnings
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Tuple
@@ -216,6 +217,8 @@ class SummarizationModule(BaseTransformer):
         scheduler = get_linear_schedule_with_warmup(
             self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
         )
+        if max(scheduler.get_last_lr()) > 0:
+            warnings.warn("All learning rates are 0")
         self.lr_scheduler = scheduler
         return dataloader
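The guard added above inspects scheduler.get_last_lr() immediately after the scheduler is built, which is where the warmup setting first shows up. A standalone sketch of what that call returns for the two warmup values involved in this commit; the first_lr helper and its defaults are illustrative:

import torch
from transformers import get_linear_schedule_with_warmup

def first_lr(warmup_steps, base_lr=3e-5, total_steps=1000):
    """Return the LR the scheduler reports immediately after construction."""
    optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=base_lr)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
    )
    return max(scheduler.get_last_lr())

print(first_lr(warmup_steps=0))    # full base LR right away: 3e-05
print(first_lr(warmup_steps=500))  # 0.0 until warmup has advanced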
@@ -193,13 +193,12 @@ class TestSummarizationDistiller(unittest.TestCase):
 
 @pytest.mark.parametrize(["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
 def test_run_eval_bart(model):
-    tmp = Path(tempfile.gettempdir()) / "utest_generations_bart_sum.hypo"
-
-    output_file_name = Path(tempfile.gettempdir()) / "utest_output_bart_sum.hypo"
+    input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
+    output_file_name = input_file_name.parent / "utest_output.txt"
+    assert not output_file_name.exists()
     articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
-    _dump_articles(tmp, articles)
-    testargs = ["run_eval.py", str(tmp), str(output_file_name), model]  # TODO: test score_path
+    _dump_articles(input_file_name, articles)
+    testargs = ["run_eval.py", str(input_file_name), str(output_file_name), model]  # TODO: test score_path
     with patch.object(sys, "argv", testargs):
         run_generate()
     assert Path(output_file_name).exists()
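The test above drives a script's command-line interface by patching sys.argv and then calling its main function directly, so no subprocess is needed. A self-contained sketch of the same pattern against a toy entry point; toy_generate() and its file handling are stand-ins, not the real run_eval.py:

import sys
import tempfile
from pathlib import Path
from unittest.mock import patch

def toy_generate():
    """Stand-in entry point: argv[1] is the input path, argv[2] the output path."""
    source, destination = Path(sys.argv[1]), Path(sys.argv[2])
    destination.write_text(source.read_text().upper())

input_file = Path(tempfile.mkdtemp()) / "utest_input.source"
output_file = input_file.parent / "utest_output.txt"
input_file.write_text("a single test article\n")
assert not output_file.exists()

testargs = ["toy_generate.py", str(input_file), str(output_file)]
with patch.object(sys, "argv", testargs):
    toy_generate()

assert output_file.exists()
print(output_file.read_text())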
@@ -16,9 +16,9 @@ python finetune.py \
    --freeze_encoder --freeze_embeds --data_dir $CNN_DIR \
    --max_target_length 142 --val_max_target_length=142 \
    --train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS \
    --data_dir $CNN_DIR \
    --model_name_or_path sshleifer/student_cnn_12_6 \
    --tokenizer_name facebook/bart-large \
    --warmup_steps 500 \
    --output_dir distilbart-cnn-12-6 \
    $@
@@ -16,5 +16,6 @@ python distillation.py \
     --alpha_hid=3. --length_penalty=0.5 \
     --train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS --num_train_epochs=6 \
     --tokenizer_name facebook/bart-large \
+    --warmup_steps 500 \
     --output_dir distilbart_xsum_12_6 \
     $@