[pl_examples] default warmup steps=0 (#5316)

Author: Sam Shleifer, 2020-06-26 15:03:41 -04:00 (committed by GitHub)
parent bf0d12c220
commit 5543b30aa6
6 changed files with 14 additions and 13 deletions


@@ -122,12 +122,9 @@ class BaseTransformer(pl.LightningModule):
         else:
             optimizer.step()
         optimizer.zero_grad()
-        self.lr_scheduler.step()
+        self.lr_scheduler.step()  # By default, PL will only step every epoch.
+        lrs = {f"lr_group_{i}": lr for i, lr in enumerate(self.lr_scheduler.get_lr())}
+        self.logger.log_metrics(lrs)
 
-    def get_tqdm_dict(self):
-        avg_loss = getattr(self.trainer, "avg_loss", 0.0)
-        tqdm_dict = {"loss": "{:.3f}".format(avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]}
-        return tqdm_dict
-
     def test_step(self, batch, batch_nb):
         return self.validation_step(batch, batch_nb)
@@ -202,7 +199,7 @@ class BaseTransformer(pl.LightningModule):
         parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
         parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
         parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
-        parser.add_argument("--warmup_steps", default=500, type=int, help="Linear warmup over warmup_steps.")
+        parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
         parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
         parser.add_argument(
             "--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform."


@@ -64,6 +64,7 @@ The following command should work on a 16GB GPU:
 Tips:
 - 1 epoch at batch size 1 for bart-large takes 24 hours and requires 13GB GPU RAM with fp16 on an NVIDIA-V100.
+- since you need to run from `examples/seq2seq`, and likely need to modify code, it is easiest to fork, then clone transformers and run `pip install -e .` before you get started.
 - try `bart-base`, `--freeze_encoder` or `--freeze_embeds` for faster training/larger batch size. (3hr/epoch with bs=8, see the "xsum_shared_task" command below)
 - `fp16_opt_level=O1` (the default works best).
 - If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.


@@ -3,6 +3,7 @@ import glob
 import logging
 import os
 import time
+import warnings
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Tuple
@@ -216,6 +217,8 @@ class SummarizationModule(BaseTransformer):
         scheduler = get_linear_schedule_with_warmup(
             self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
         )
+        if max(scheduler.get_last_lr()) > 0:
+            warnings.warn("All learning rates are 0")
         self.lr_scheduler = scheduler
         return dataloader
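
For context on the new check: `scheduler.get_last_lr()` queried right after `get_linear_schedule_with_warmup` is constructed returns the step-0 learning rates, which are all zero whenever `warmup_steps > 0` and equal to the configured learning rate when `warmup_steps == 0`. As committed, the `> 0` comparison fires when the rates are nonzero, so the "All learning rates are 0" message appears to describe the opposite case. The standalone sketch below (not the repo's code) shows the step-0 values and a hypothetical rewording of the guard, assuming the intent is to flag an all-zero start.

# Standalone sketch: the value the new warning inspects at construction time.
import warnings

import torch
from transformers import get_linear_schedule_with_warmup


def step0_lrs(warmup_steps: int) -> list:
    # Throwaway one-parameter optimizer, for illustration only.
    optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=5e-5)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=1000
    )
    return scheduler.get_last_lr()


print(step0_lrs(warmup_steps=500))  # [0.0]   -- linear warmup starts from zero
print(step0_lrs(warmup_steps=0))    # [5e-05] -- new default: full LR at step 0

# Hypothetical rewording of the guard, assuming an all-zero start is the condition to flag.
if max(step0_lrs(warmup_steps=500)) == 0:
    warnings.warn("All learning rates are 0 at step 0")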


@@ -193,13 +193,12 @@ class TestSummarizationDistiller(unittest.TestCase):
 @pytest.mark.parametrize(["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
 def test_run_eval_bart(model):
-    tmp = Path(tempfile.gettempdir()) / "utest_generations_bart_sum.hypo"
-    output_file_name = Path(tempfile.gettempdir()) / "utest_output_bart_sum.hypo"
+    input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
+    output_file_name = input_file_name.parent / "utest_output.txt"
     assert not output_file_name.exists()
     articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
-    _dump_articles(tmp, articles)
-    testargs = ["run_eval.py", str(tmp), str(output_file_name), model]  # TODO: test score_path
+    _dump_articles(input_file_name, articles)
+    testargs = ["run_eval.py", str(input_file_name), str(output_file_name), model]  # TODO: test score_path
     with patch.object(sys, "argv", testargs):
         run_generate()
     assert Path(output_file_name).exists()


@@ -16,9 +16,9 @@ python finetune.py \
     --freeze_encoder --freeze_embeds --data_dir $CNN_DIR \
     --max_target_length 142 --val_max_target_length=142 \
     --train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS \
-    --data_dir $CNN_DIR \
     --model_name_or_path sshleifer/student_cnn_12_6 \
     --tokenizer_name facebook/bart-large \
+    --warmup_steps 500 \
     --output_dir distilbart-cnn-12-6 \
     $@


@@ -16,5 +16,6 @@ python distillation.py \
     --alpha_hid=3. --length_penalty=0.5 \
     --train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS --num_train_epochs=6 \
     --tokenizer_name facebook/bart-large \
+    --warmup_steps 500 \
     --output_dir distilbart_xsum_12_6 \
     $@