Upgrade examples to pl=0.8.1(#5146)

This commit is contained in:
Sam Shleifer 2020-06-22 20:40:10 -04:00 committed by GitHub
parent 06b60c8b05
commit f5c2a122e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 53 additions and 150 deletions

View File

@ -8,6 +8,7 @@ from typing import Any, Dict
import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from transformers import (
AdamW,
@ -60,10 +61,9 @@ class BaseTransformer(pl.LightningModule):
model=None,
**config_kwargs
):
"Initialize a model."
"""Initialize a model, tokenizer and config."""
super().__init__()
self.hparams = hparams
self.hparams = hparams # TODO: move to self.save_hyperparameters()
self.step_count = 0
self.tfmr_ckpts = {}
self.output_dir = Path(self.hparams.output_dir)
@ -84,8 +84,8 @@ class BaseTransformer(pl.LightningModule):
)
else:
self.tokenizer: PreTrainedTokenizer = tokenizer
self.model_type = MODEL_MODES[mode]
if model is None:
self.model_type = MODEL_MODES[mode]
self.model = self.model_type.from_pretrained(
self.hparams.model_name_or_path,
from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
@ -93,18 +93,13 @@ class BaseTransformer(pl.LightningModule):
cache_dir=cache_dir,
)
else:
self.model_type = None
self.model = model
def load_hf_checkpoint(self, *args, **kwargs):
self.model = self.model_type.from_pretrained(*args, **kwargs)
def is_logger(self):
return self.trainer.proc_rank <= 0
def configure_optimizers(self):
"Prepare optimizer and schedule (linear warmup and decay)"
model = self.model
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
@ -121,23 +116,10 @@ class BaseTransformer(pl.LightningModule):
self.opt = optimizer
return [optimizer]
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
if self.trainer.use_tpu:
xm.optimizer_step(optimizer)
else:
optimizer.step()
optimizer.zero_grad()
self.lr_scheduler.step()
def get_tqdm_dict(self):
avg_loss = getattr(self.trainer, "avg_loss", 0.0)
tqdm_dict = {"loss": "{:.3f}".format(avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]}
return tqdm_dict
def test_step(self, batch, batch_nb):
return self.validation_step(batch, batch_nb)
def test_end(self, outputs):
def test_epoch_end(self, outputs):
return self.validation_end(outputs)
def train_dataloader(self):
@ -208,6 +190,7 @@ class BaseTransformer(pl.LightningModule):
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--warmup_steps", default=500, type=int, help="Linear warmup over warmup_steps.")
parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
parser.add_argument(
"--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform."
)
@ -217,28 +200,26 @@ class BaseTransformer(pl.LightningModule):
class LoggingCallback(pl.Callback):
@rank_zero_only
def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
logger.info("***** Validation results *****")
if pl_module.is_logger():
metrics = trainer.callback_metrics
# Log results
rank_zero_info("***** Validation results *****")
metrics = trainer.callback_metrics
# Log results
for key in sorted(metrics):
if key not in ["log", "progress_bar"]:
rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
@rank_zero_only
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
logger.info("***** Test results *****")
metrics = trainer.callback_metrics
# Log and save results to file
output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
with open(output_test_results_file, "w") as writer:
for key in sorted(metrics):
if key not in ["log", "progress_bar"]:
logger.info("{} = {}\n".format(key, str(metrics[key])))
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
logger.info("***** Test results *****")
if pl_module.is_logger():
metrics = trainer.callback_metrics
# Log and save results to file
output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
with open(output_test_results_file, "w") as writer:
for key in sorted(metrics):
if key not in ["log", "progress_bar"]:
logger.info("{} = {}\n".format(key, str(metrics[key])))
writer.write("{} = {}\n".format(key, str(metrics[key])))
writer.write("{} = {}\n".format(key, str(metrics[key])))
def add_generic_args(parser, root_dir) -> None:

View File

@ -5,7 +5,7 @@ psutil
sacrebleu
rouge-score
tensorflow_datasets
pytorch-lightning==0.7.6
pytorch-lightning==0.8.1
matplotlib
git-python==1.0.3
faiss

View File

@ -19,12 +19,11 @@ logger = logging.getLogger(__name__)
class Seq2SeqLoggingCallback(pl.Callback):
@rank_zero_only
def _write_logs(
self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
) -> None:
logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
if not pl_module.is_logger():
return
metrics = trainer.callback_metrics
trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]})
# Log results

View File

@ -271,6 +271,7 @@ class SummarizationDistiller(SummarizationModule):
class T5SummarizationDistiller(SummarizationDistiller):
def pre_init(self, hparams):
raise NotImplementedError("T5 Distillation does not work yet")
teacher = T5ForConditionalGeneration.from_pretrained(hparams.teacher)
n_layer = hparams.student_decoder_layers
assert n_layer == hparams.student_encoder_layers # TODO(SS): relax this

View File

@ -85,7 +85,7 @@ class SummarizationModule(BaseTransformer):
if self.hparams.freeze_encoder:
freeze_params(self.model.model.encoder) # TODO: this will break for t5
self.hparams.git_sha = get_git_info()["repo_sha"]
self.num_workers = 4 if self.hparams.gpus <= 1 else None # passing num_workers breaks lightning for multigpu
self.num_workers = hparams.num_workers
def freeze_embeds(self):
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
@ -126,7 +126,7 @@ class SummarizationModule(BaseTransformer):
def validation_step(self, batch, batch_idx) -> Dict:
return self._generative_step(batch)
def validation_end(self, outputs, prefix="val") -> Dict:
def validation_epoch_end(self, outputs, prefix="val") -> Dict:
self.step_count += 1
losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
loss = losses["loss"]
@ -144,14 +144,12 @@ class SummarizationModule(BaseTransformer):
self.metrics[prefix].append(metrics)
pickle_save(self.metrics, self.metrics_save_path)
def _generative_step(self, batch):
def _generative_step(self, batch: dict) -> dict:
pad_token_id = self.tokenizer.pad_token_id
source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
# TODO(SS): task specific params
t0 = time.time()
generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True,)
gen_time = time.time() - t0
gen_time = time.time() - t0 / source_ids.shape[0]
preds = self.ids_to_clean_text(generated_ids)
target = self.ids_to_clean_text(y)
loss_tensors = self._step(batch)
@ -164,24 +162,8 @@ class SummarizationModule(BaseTransformer):
def test_step(self, batch, batch_idx):
return self._generative_step(batch)
def test_end(self, outputs):
return self.validation_end(outputs, prefix="test")
def test_epoch_end(self, outputs):
output_test_predictions_file = os.path.join(self.hparams.output_dir, "test_predictions.txt")
output_test_targets_file = os.path.join(self.hparams.output_dir, "test_targets.txt")
# write predictions and targets for later rouge evaluation.
with open(output_test_predictions_file, "w+") as p_writer, open(output_test_targets_file, "w+") as t_writer:
for output_batch in outputs:
p_writer.writelines(s + "\n" for s in output_batch["preds"])
t_writer.writelines(s + "\n" for s in output_batch["target"])
p_writer.close()
t_writer.close()
return self.test_end(outputs)
def validation_epoch_end(self, outputs):
self.validation_end(outputs, "val")
return self.validation_epoch_end(outputs, prefix="test")
def get_dataset(self, type_path) -> SummarizationDataset:
n_obs = self.n_obs[type_path]
@ -310,6 +292,7 @@ def main(args, model=None) -> SummarizationModule:
logger=logger,
# TODO: early stopping callback seems messed up
)
pickle_save(model.hparams, model.output_dir / "hparams.pkl")
if not args.do_predict:
return model

View File

@ -7,6 +7,5 @@ python distillation.py \
--learning_rate=3e-4 \
--do_train \
--do_predict \
--fp16 \
--val_check_interval 0.1 \
$@

View File

@ -26,6 +26,7 @@ def generate_summaries(
examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE, fp16=False,
) -> None:
fout = Path(out_file).open("w", encoding="utf-8")
model_name = str(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
if fp16:
model = model.half()

View File

@ -24,6 +24,7 @@ logger = logging.getLogger()
FP16_EVER = False
CHEAP_ARGS = {
"logger": "default",
"num_workers": 2,
"alpha_hid": 0,
"freeze_embeds": True,
"enc_only": False,
@ -79,7 +80,8 @@ def _dump_articles(path: Path, articles: list):
f.write("\n".join(articles))
BDIR = Path("~/transformers_fork/examples/summarization/bart/").absolute()
MSG = "T5 is broken at the moment"
T5_TINY = "patrickvonplaten/t5-tiny-random"
def make_test_data_dir():
@ -92,7 +94,6 @@ def make_test_data_dir():
return tmp_dir
@unittest.skip("These wont' pass until hidden_states kwarg is merged.")
class TestSummarizationDistiller(unittest.TestCase):
@classmethod
def setUpClass(cls):
@ -108,47 +109,22 @@ class TestSummarizationDistiller(unittest.TestCase):
freeze_encoder=True,
gpus=2,
sortish_sampler=False,
)
self._bart_distiller_cli(updates)
@unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test")
def test_bdc_fp16(self):
updates = dict(
student_encoder_layers=2,
student_decoder_layers=1,
alpha_hid=3.0,
freeze_encoder=True,
gpus=1,
fp16=FP16_EVER,
fp16_opt_level="O1",
fp16=FP16_EVER,
)
self._bart_distiller_cli(updates)
@unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test")
def test_bdc_t5_eval_fp16(self):
def test_bdc_t5_train(self):
updates = dict(
fp16=FP16_EVER,
gpus=1,
gpus=1 if torch.cuda.is_available() else 0,
model_type="t5",
model_name_or_path="patrickvonplaten/t5-tiny-random",
do_train=False,
do_predict=True,
tokenizer_name=None,
no_teacher=True,
)
self._bart_distiller_cli(updates, check_contents=False)
@unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test")
def test_bdc_t5_train_fp16(self):
updates = dict(
fp16=FP16_EVER,
gpus=1,
model_type="t5",
model_name_or_path="patrickvonplaten/t5-tiny-random",
model_name_or_path=T5_TINY,
do_train=True,
do_predict=True,
tokenizer_name="patrickvonplaten/t5-tiny-random",
tokenizer_name=T5_TINY,
no_teacher=True,
alpha_hid=2.0,
)
self._bart_distiller_cli(updates)
@ -161,7 +137,6 @@ class TestSummarizationDistiller(unittest.TestCase):
self._bart_distiller_cli(updates)
def test_bdc_checkpointing(self):
updates = dict(
student_encoder_layers=2,
student_decoder_layers=1,
@ -184,32 +159,8 @@ class TestSummarizationDistiller(unittest.TestCase):
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
def test_bdc_t5(self):
updates = dict(
student_encoder_layers=1,
student_decoder_layers=1,
alpha_hid=2.0,
teacher="patrickvonplaten/t5-tiny-random",
model_type="t5",
model_name_or_path="patrickvonplaten/t5-tiny-random",
tokenizer_name="patrickvonplaten/t5-tiny-random",
)
self._bart_distiller_cli(updates)
def test_bdc_t5_eval(self):
updates = dict(
model_type="t5",
model_name_or_path="patrickvonplaten/t5-tiny-random",
do_train=False,
do_predict=True,
tokenizer_name="patrickvonplaten/t5-tiny-random",
no_teacher=True,
)
self._bart_distiller_cli(updates, check_contents=False)
def _bart_distiller_cli(self, updates, check_contents=True):
default_updates = dict(
model_type="bart",
train_batch_size=1,
eval_batch_size=2,
num_train_epochs=2,
@ -237,21 +188,14 @@ class TestSummarizationDistiller(unittest.TestCase):
self.assertIn(ckpt_name, contents)
self.assertIn("metrics.pkl", contents)
self.assertIn("test_generations.txt", contents)
self.assertIn("val_generations_1.txt", contents)
self.assertIn("val_1_results.txt", contents)
self.assertIn("val_generations_00001.txt", contents)
self.assertIn("val_results_00001.txt", contents)
self.assertIn("test_results.txt", contents)
# self.assertEqual(len(contents), 15)
metrics = pickle_load(Path(output_dir) / "metrics.pkl")
import pandas as pd
val_df = pd.DataFrame(metrics["val"])
train_df = pd.DataFrame(metrics["train"])
test_df = pd.DataFrame(metrics["test"])
desired_n_evals = args_d["num_train_epochs"] * 2 + 1
self.assertEqual(val_df.shape[0], desired_n_evals) #
self.assertEqual(test_df.shape[1], val_df.shape[1])
self.assertEqual(train_df.shape[0], 0)
desired_n_evals = int(args_d["num_train_epochs"] * (1 / args_d["val_check_interval"]) + 1)
self.assertEqual(len(metrics["val"]), desired_n_evals)
self.assertEqual(len(metrics["train"]), 0) # doesn't get logged here
return model
@ -281,9 +225,8 @@ class TestBartExamples(unittest.TestCase):
output_dir = tempfile.mkdtemp(prefix="output_")
args_d.update(
data_dir=tmp_dir,
model_type="t5",
model_name_or_path="patrickvonplaten/t5-tiny-random",
tokenizer_name=None, # "patrickvonplaten/t5-tiny-random",
model_name_or_path=T5_TINY,
tokenizer_name=None, # T5_TINY,
train_batch_size=2,
eval_batch_size=2,
gpus=0,

View File

@ -45,8 +45,10 @@ def encode_file(
max_length=max_length,
pad_to_max_length=pad_to_max_length,
add_prefix_space=True,
truncation=True,
return_tensors=return_tensors,
)
assert tokenized.input_ids.shape[1] == max_length
examples.append(tokenized)
torch.save(lmap(dict, examples), cache_path.open("wb"))
return examples

View File

@ -108,7 +108,7 @@ class GLUETransformer(BaseTransformer):
return {"val_loss": tmp_eval_loss.detach().cpu(), "pred": preds, "target": out_label_ids}
def _eval_end(self, outputs):
def _eval_end(self, outputs) -> tuple:
val_loss_mean = torch.stack([x["val_loss"] for x in outputs]).mean().detach().cpu().item()
preds = np.concatenate([x["pred"] for x in outputs], axis=0)
@ -132,20 +132,14 @@ class GLUETransformer(BaseTransformer):
logs = ret["log"]
return {"val_loss": logs["val_loss"], "log": logs, "progress_bar": logs}
def test_epoch_end(self, outputs):
# updating to test_epoch_end instead of deprecated test_end
def test_epoch_end(self, outputs) -> dict:
ret, predictions, targets = self._eval_end(outputs)
# Converting to the dic required by pl
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/\
# pytorch_lightning/trainer/logging.py#L139
logs = ret["log"]
# `val_loss` is the key returned by `self._eval_end()` but actually refers to `test_loss`
return {"avg_test_loss": logs["val_loss"], "log": logs, "progress_bar": logs}
@staticmethod
def add_model_specific_args(parser, root_dir):
# Add NER specific options
BaseTransformer.add_model_specific_args(parser, root_dir)
parser.add_argument(
"--max_seq_length",

View File

@ -205,7 +205,7 @@ class AutoTokenizer:
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
if "bert-base-japanese" in pretrained_model_name_or_path:
if "bert-base-japanese" in str(pretrained_model_name_or_path):
return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
use_fast = kwargs.pop("use_fast", False)