[pl_examples] default warmup steps=0 (#5316)

Repository: https://github.com/huggingface/transformers.git
commit 5543b30aa6 (parent bf0d12c220)
@@ -122,12 +122,9 @@ class BaseTransformer(pl.LightningModule):
         else:
             optimizer.step()
         optimizer.zero_grad()
-        self.lr_scheduler.step()
-
-    def get_tqdm_dict(self):
-        avg_loss = getattr(self.trainer, "avg_loss", 0.0)
-        tqdm_dict = {"loss": "{:.3f}".format(avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]}
-        return tqdm_dict
+        self.lr_scheduler.step()  # By default, PL will only step every epoch.
+        lrs = {f"lr_group_{i}": lr for i, lr in enumerate(self.lr_scheduler.get_lr())}
+        self.logger.log_metrics(lrs)
 
     def test_step(self, batch, batch_nb):
         return self.validation_step(batch, batch_nb)
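The change above steps the LR scheduler after every optimizer step (Lightning's default would only step it once per epoch) and logs the current learning rate of every parameter group. A minimal standalone sketch of that pattern with a plain PyTorch optimizer and a LambdaLR scheduler; the model, group layout, and decay factor are illustrative only, and get_last_lr() is used in place of the example's get_lr():

import torch

# Hypothetical two-group setup just to show per-group LR logging.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(
    [
        {"params": [model.weight], "lr": 5e-5},
        {"params": [model.bias], "lr": 1e-4},
    ]
)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 0.95 ** step)

for step in range(3):
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()  # stepped manually every batch, since PL would otherwise step it once per epoch
    lrs = {f"lr_group_{i}": lr for i, lr in enumerate(scheduler.get_last_lr())}
    print(lrs)  # in the example these values go to self.logger.log_metrics(lrs)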
@@ -202,7 +199,7 @@ class BaseTransformer(pl.LightningModule):
         parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
         parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
         parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
-        parser.add_argument("--warmup_steps", default=500, type=int, help="Linear warmup over warmup_steps.")
+        parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
         parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
         parser.add_argument(
             "--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform."
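With the default changed from 500 to 0, runs that do not pass --warmup_steps now start at the full learning rate instead of warming up from 0. A small illustrative comparison; the numbers are made up and only get_linear_schedule_with_warmup and get_last_lr are real APIs:

import torch
from transformers import get_linear_schedule_with_warmup

def first_lrs(num_warmup_steps, num_training_steps=1000, base_lr=5e-5, n=3):
    # Record the first n learning rates produced by a linear-warmup schedule.
    opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=base_lr)
    sched = get_linear_schedule_with_warmup(opt, num_warmup_steps, num_training_steps)
    lrs = []
    for _ in range(n):
        lrs.append(sched.get_last_lr()[0])
        opt.step()
        sched.step()
    return lrs

print(first_lrs(500))  # old default: starts at 0.0 and ramps up linearly, e.g. [0.0, 1e-07, 2e-07]
print(first_lrs(0))    # new default: starts immediately at the base LR, e.g. [5e-05, ...]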
@@ -64,6 +64,7 @@ The following command should work on a 16GB GPU:
 
 Tips:
 - 1 epoch at batch size 1 for bart-large takes 24 hours and requires 13GB GPU RAM with fp16 on an NVIDIA-V100.
+- since you need to run from `examples/seq2seq`, and likely need to modify code, it is easiest to fork, then clone transformers and run `pip install -e .` before you get started.
 - try `bart-base`, `--freeze_encoder` or `--freeze_embeds` for faster training/larger batch size. (3hr/epoch with bs=8, see the "xsum_shared_task" command below)
 - `fp16_opt_level=O1` (the default works best).
 - If you are finetuning on your own dataset, start from `distilbart-cnn-12-6` if you want long summaries and `distilbart-xsum-12-6` if you want short summaries.
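For the --freeze_encoder / --freeze_embeds tip above, freezing amounts to turning off gradients for the relevant sub-modules so the optimizer skips them. A minimal sketch; the real flags are implemented by helpers in the seq2seq examples, and the attribute names below assume a BART-style layout:

import torch.nn as nn

def freeze_params(module: nn.Module) -> None:
    # Parameters with requires_grad=False receive no gradients and are not updated.
    for p in module.parameters():
        p.requires_grad = False

# Hypothetical usage on a BART-style model object:
# freeze_params(model.model.encoder)   # --freeze_encoder
# freeze_params(model.model.shared)    # --freeze_embeds (shared token embeddings)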
@@ -3,6 +3,7 @@ import glob
 import logging
 import os
 import time
+import warnings
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Tuple
@@ -216,6 +217,8 @@ class SummarizationModule(BaseTransformer):
         scheduler = get_linear_schedule_with_warmup(
             self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
         )
+        if max(scheduler.get_last_lr()) > 0:
+            warnings.warn("All learning rates are 0")
         self.lr_scheduler = scheduler
         return dataloader
 
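The added check inspects the scheduler's learning rates right after construction: with a linear-warmup schedule and num_warmup_steps > 0, every group starts at exactly 0, which is the situation the warning is meant to surface now that warmup can be skipped entirely. A standalone sketch of that kind of sanity check, written here to warn when the maximum initial LR is 0; only get_linear_schedule_with_warmup and get_last_lr are real APIs:

import warnings
import torch
from transformers import get_linear_schedule_with_warmup

optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=5e-5)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=500, num_training_steps=1000)

# Immediately after construction the warmup schedule sits at step 0, so every LR is 0.0.
if max(scheduler.get_last_lr()) == 0:
    warnings.warn("All learning rates are 0 at the start of training (linear warmup from step 0)")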
@@ -193,13 +193,12 @@ class TestSummarizationDistiller(unittest.TestCase):
 
 @pytest.mark.parametrize(["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
 def test_run_eval_bart(model):
-    tmp = Path(tempfile.gettempdir()) / "utest_generations_bart_sum.hypo"
-    output_file_name = Path(tempfile.gettempdir()) / "utest_output_bart_sum.hypo"
+    input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
+    output_file_name = input_file_name.parent / "utest_output.txt"
     assert not output_file_name.exists()
     articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
-    _dump_articles(tmp, articles)
-    testargs = ["run_eval.py", str(tmp), str(output_file_name), model]  # TODO: test score_path
+    _dump_articles(input_file_name, articles)
+    testargs = ["run_eval.py", str(input_file_name), str(output_file_name), model]  # TODO: test score_path
     with patch.object(sys, "argv", testargs):
         run_generate()
         assert Path(output_file_name).exists()
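The reworked test writes its input into a fresh tempfile.mkdtemp() directory and drives the script through a patched sys.argv. A self-contained sketch of that testing pattern; fake_main below is a stand-in for the example's run_generate entry point, which parses sys.argv itself:

import sys
import tempfile
from pathlib import Path
from unittest.mock import patch

def fake_main():
    # Stand-in for a CLI entry point that reads its arguments from sys.argv.
    in_path, out_path = sys.argv[1], sys.argv[2]
    Path(out_path).write_text(Path(in_path).read_text().upper())

input_file = Path(tempfile.mkdtemp()) / "utest_input.source"
output_file = input_file.parent / "utest_output.txt"
input_file.write_text("some article text\n")

with patch.object(sys, "argv", ["run_eval.py", str(input_file), str(output_file)]):
    fake_main()

assert output_file.exists()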
@@ -16,9 +16,9 @@ python finetune.py \
     --freeze_encoder --freeze_embeds --data_dir $CNN_DIR \
     --max_target_length 142 --val_max_target_length=142 \
     --train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS \
-    --data_dir $CNN_DIR \
     --model_name_or_path sshleifer/student_cnn_12_6 \
     --tokenizer_name facebook/bart-large \
+    --warmup_steps 500 \
     --output_dir distilbart-cnn-12-6 \
     $@
@@ -16,5 +16,6 @@ python distillation.py \
     --alpha_hid=3. --length_penalty=0.5 \
     --train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps=$GAS --num_train_epochs=6 \
     --tokenizer_name facebook/bart-large \
+    --warmup_steps 500 \
     --output_dir distilbart_xsum_12_6 \
     $@