Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Upgrade examples to pl=0.8.1 (#5146)
This commit is contained in: parent 06b60c8b05, commit f5c2a122e3
@@ -8,6 +8,7 @@ from typing import Any, Dict
import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only

from transformers import (
AdamW,
@@ -60,10 +61,9 @@ class BaseTransformer(pl.LightningModule):
model=None,
**config_kwargs
):
"Initialize a model."

"""Initialize a model, tokenizer and config."""
super().__init__()
self.hparams = hparams
self.hparams = hparams # TODO: move to self.save_hyperparameters()
self.step_count = 0
self.tfmr_ckpts = {}
self.output_dir = Path(self.hparams.output_dir)
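Note: the TODO above points at Lightning's save_hyperparameters() helper. A minimal sketch of what that later migration could look like, assuming a Lightning version that ships the helper (illustrative only, not part of this commit):

import argparse

import pytorch_lightning as pl


class BaseTransformerSketch(pl.LightningModule):
    # Illustrative sketch, not part of this commit.
    def __init__(self, hparams: argparse.Namespace, **config_kwargs):
        super().__init__()
        # Let Lightning record and checkpoint the hyperparameters instead of
        # assigning self.hparams directly (available in newer Lightning releases).
        self.save_hyperparameters(hparams)
        self.step_count = 0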
@@ -84,8 +84,8 @@ class BaseTransformer(pl.LightningModule):
)
else:
self.tokenizer: PreTrainedTokenizer = tokenizer
self.model_type = MODEL_MODES[mode]
if model is None:
self.model_type = MODEL_MODES[mode]
self.model = self.model_type.from_pretrained(
self.hparams.model_name_or_path,
from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
@@ -93,18 +93,13 @@ class BaseTransformer(pl.LightningModule):
cache_dir=cache_dir,
)
else:
self.model_type = None
self.model = model

def load_hf_checkpoint(self, *args, **kwargs):
self.model = self.model_type.from_pretrained(*args, **kwargs)

def is_logger(self):
return self.trainer.proc_rank <= 0

def configure_optimizers(self):
"Prepare optimizer and schedule (linear warmup and decay)"

model = self.model
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
@@ -121,23 +116,10 @@ class BaseTransformer(pl.LightningModule):
self.opt = optimizer
return [optimizer]

def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
if self.trainer.use_tpu:
xm.optimizer_step(optimizer)
else:
optimizer.step()
optimizer.zero_grad()
self.lr_scheduler.step()

def get_tqdm_dict(self):
avg_loss = getattr(self.trainer, "avg_loss", 0.0)
tqdm_dict = {"loss": "{:.3f}".format(avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]}
return tqdm_dict

def test_step(self, batch, batch_nb):
return self.validation_step(batch, batch_nb)

def test_end(self, outputs):
def test_epoch_end(self, outputs):
return self.validation_end(outputs)

def train_dataloader(self):
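Note: pl 0.8 favors the *_epoch_end hook names over the older validation_end/test_end, which is what the rename above tracks. A minimal, self-contained sketch of the hook pair (illustrative only, not part of this commit):

import torch
import pytorch_lightning as pl


class EpochEndSketch(pl.LightningModule):
    # Illustrative sketch, not part of this commit.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        return {"val_loss": torch.nn.functional.mse_loss(self.layer(x), y)}

    def validation_epoch_end(self, outputs):
        avg = torch.stack([x["val_loss"] for x in outputs]).mean()
        return {"val_loss": avg, "log": {"val_loss": avg}}

    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx)

    def test_epoch_end(self, outputs):
        return self.validation_epoch_end(outputs)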
@@ -208,6 +190,7 @@ class BaseTransformer(pl.LightningModule):
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--warmup_steps", default=500, type=int, help="Linear warmup over warmup_steps.")
parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
parser.add_argument(
"--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform."
)
@@ -217,28 +200,26 @@ class BaseTransformer(pl.LightningModule):


class LoggingCallback(pl.Callback):
@rank_zero_only
def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
logger.info("***** Validation results *****")
if pl_module.is_logger():
metrics = trainer.callback_metrics
# Log results
rank_zero_info("***** Validation results *****")
metrics = trainer.callback_metrics
# Log results
for key in sorted(metrics):
if key not in ["log", "progress_bar"]:
rank_zero_info("{} = {}\n".format(key, str(metrics[key])))

@rank_zero_only
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
logger.info("***** Test results *****")
metrics = trainer.callback_metrics
# Log and save results to file
output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
with open(output_test_results_file, "w") as writer:
for key in sorted(metrics):
if key not in ["log", "progress_bar"]:
logger.info("{} = {}\n".format(key, str(metrics[key])))

def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
logger.info("***** Test results *****")

if pl_module.is_logger():
metrics = trainer.callback_metrics

# Log and save results to file
output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
with open(output_test_results_file, "w") as writer:
for key in sorted(metrics):
if key not in ["log", "progress_bar"]:
logger.info("{} = {}\n".format(key, str(metrics[key])))
writer.write("{} = {}\n".format(key, str(metrics[key])))
writer.write("{} = {}\n".format(key, str(metrics[key])))


def add_generic_args(parser, root_dir) -> None:
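Note: the callback now leans on rank_zero_only/rank_zero_info from pytorch_lightning.utilities instead of the pl_module.is_logger() guard. A small sketch of that pattern outside a callback (illustrative only, not part of this commit):

import logging

from pytorch_lightning.utilities import rank_zero_info, rank_zero_only

logger = logging.getLogger(__name__)


@rank_zero_only
def report_metrics(metrics: dict) -> None:
    # Illustrative sketch, not part of this commit.
    # The decorator makes this a no-op on every process except rank 0,
    # so metrics are printed once instead of once per GPU/TPU worker.
    for key in sorted(metrics):
        rank_zero_info("{} = {}".format(key, metrics[key]))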
@@ -5,7 +5,7 @@ psutil
sacrebleu
rouge-score
tensorflow_datasets
pytorch-lightning==0.7.6
pytorch-lightning==0.8.1
matplotlib
git-python==1.0.3
faiss
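Note: after reinstalling from the updated requirements file, the new pin can be checked from Python (illustrative only, not part of this commit):

import pytorch_lightning as pl

# Expect "0.8.1" once the upgraded requirements are installed.
print(pl.__version__)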
@@ -19,12 +19,11 @@ logger = logging.getLogger(__name__)


class Seq2SeqLoggingCallback(pl.Callback):
@rank_zero_only
def _write_logs(
self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
) -> None:
logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
if not pl_module.is_logger():
return
metrics = trainer.callback_metrics
trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]})
# Log results
@@ -271,6 +271,7 @@ class SummarizationDistiller(SummarizationModule):

class T5SummarizationDistiller(SummarizationDistiller):
def pre_init(self, hparams):
raise NotImplementedError("T5 Distillation does not work yet")
teacher = T5ForConditionalGeneration.from_pretrained(hparams.teacher)
n_layer = hparams.student_decoder_layers
assert n_layer == hparams.student_encoder_layers # TODO(SS): relax this
@@ -85,7 +85,7 @@ class SummarizationModule(BaseTransformer):
if self.hparams.freeze_encoder:
freeze_params(self.model.model.encoder) # TODO: this will break for t5
self.hparams.git_sha = get_git_info()["repo_sha"]
self.num_workers = 4 if self.hparams.gpus <= 1 else None # passing num_workers breaks lightning for multigpu
self.num_workers = hparams.num_workers

def freeze_embeds(self):
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
@@ -126,7 +126,7 @@ class SummarizationModule(BaseTransformer):
def validation_step(self, batch, batch_idx) -> Dict:
return self._generative_step(batch)

def validation_end(self, outputs, prefix="val") -> Dict:
def validation_epoch_end(self, outputs, prefix="val") -> Dict:
self.step_count += 1
losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
loss = losses["loss"]
@@ -144,14 +144,12 @@ class SummarizationModule(BaseTransformer):
self.metrics[prefix].append(metrics)
pickle_save(self.metrics, self.metrics_save_path)

def _generative_step(self, batch):
def _generative_step(self, batch: dict) -> dict:
pad_token_id = self.tokenizer.pad_token_id
source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
# TODO(SS): task specific params

t0 = time.time()
generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True,)
gen_time = time.time() - t0
gen_time = time.time() - t0 / source_ids.shape[0]
preds = self.ids_to_clean_text(generated_ids)
target = self.ids_to_clean_text(y)
loss_tensors = self._step(batch)
@@ -164,24 +162,8 @@ class SummarizationModule(BaseTransformer):
def test_step(self, batch, batch_idx):
return self._generative_step(batch)

def test_end(self, outputs):
return self.validation_end(outputs, prefix="test")

def test_epoch_end(self, outputs):
output_test_predictions_file = os.path.join(self.hparams.output_dir, "test_predictions.txt")
output_test_targets_file = os.path.join(self.hparams.output_dir, "test_targets.txt")
# write predictions and targets for later rouge evaluation.
with open(output_test_predictions_file, "w+") as p_writer, open(output_test_targets_file, "w+") as t_writer:
for output_batch in outputs:
p_writer.writelines(s + "\n" for s in output_batch["preds"])
t_writer.writelines(s + "\n" for s in output_batch["target"])
p_writer.close()
t_writer.close()

return self.test_end(outputs)

def validation_epoch_end(self, outputs):
self.validation_end(outputs, "val")
return self.validation_epoch_end(outputs, prefix="test")

def get_dataset(self, type_path) -> SummarizationDataset:
n_obs = self.n_obs[type_path]
@@ -310,6 +292,7 @@ def main(args, model=None) -> SummarizationModule:
logger=logger,
# TODO: early stopping callback seems messed up
)
pickle_save(model.hparams, model.output_dir / "hparams.pkl")
if not args.do_predict:
return model
@@ -7,6 +7,5 @@ python distillation.py \
--learning_rate=3e-4 \
--do_train \
--do_predict \
--fp16 \
--val_check_interval 0.1 \
$@
@@ -26,6 +26,7 @@ def generate_summaries(
examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE, fp16=False,
) -> None:
fout = Path(out_file).open("w", encoding="utf-8")
model_name = str(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
if fp16:
model = model.half()
@@ -24,6 +24,7 @@ logger = logging.getLogger()
FP16_EVER = False
CHEAP_ARGS = {
"logger": "default",
"num_workers": 2,
"alpha_hid": 0,
"freeze_embeds": True,
"enc_only": False,
@@ -79,7 +80,8 @@ def _dump_articles(path: Path, articles: list):
f.write("\n".join(articles))


BDIR = Path("~/transformers_fork/examples/summarization/bart/").absolute()
MSG = "T5 is broken at the moment"
T5_TINY = "patrickvonplaten/t5-tiny-random"


def make_test_data_dir():
@@ -92,7 +94,6 @@ def make_test_data_dir():
return tmp_dir


@unittest.skip("These wont' pass until hidden_states kwarg is merged.")
class TestSummarizationDistiller(unittest.TestCase):
@classmethod
def setUpClass(cls):
@@ -108,47 +109,22 @@ class TestSummarizationDistiller(unittest.TestCase):
freeze_encoder=True,
gpus=2,
sortish_sampler=False,
)
self._bart_distiller_cli(updates)

@unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test")
def test_bdc_fp16(self):
updates = dict(
student_encoder_layers=2,
student_decoder_layers=1,
alpha_hid=3.0,
freeze_encoder=True,
gpus=1,
fp16=FP16_EVER,
fp16_opt_level="O1",
fp16=FP16_EVER,
)
self._bart_distiller_cli(updates)

@unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test")
def test_bdc_t5_eval_fp16(self):
def test_bdc_t5_train(self):
updates = dict(
fp16=FP16_EVER,
gpus=1,
gpus=1 if torch.cuda.is_available() else 0,
model_type="t5",
model_name_or_path="patrickvonplaten/t5-tiny-random",
do_train=False,
do_predict=True,
tokenizer_name=None,
no_teacher=True,
)
self._bart_distiller_cli(updates, check_contents=False)

@unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test")
def test_bdc_t5_train_fp16(self):
updates = dict(
fp16=FP16_EVER,
gpus=1,
model_type="t5",
model_name_or_path="patrickvonplaten/t5-tiny-random",
model_name_or_path=T5_TINY,
do_train=True,
do_predict=True,
tokenizer_name="patrickvonplaten/t5-tiny-random",
tokenizer_name=T5_TINY,
no_teacher=True,
alpha_hid=2.0,
)
self._bart_distiller_cli(updates)
@@ -161,7 +137,6 @@ class TestSummarizationDistiller(unittest.TestCase):
self._bart_distiller_cli(updates)

def test_bdc_checkpointing(self):

updates = dict(
student_encoder_layers=2,
student_decoder_layers=1,
@@ -184,32 +159,8 @@ class TestSummarizationDistiller(unittest.TestCase):

evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))

def test_bdc_t5(self):
updates = dict(
student_encoder_layers=1,
student_decoder_layers=1,
alpha_hid=2.0,
teacher="patrickvonplaten/t5-tiny-random",
model_type="t5",
model_name_or_path="patrickvonplaten/t5-tiny-random",
tokenizer_name="patrickvonplaten/t5-tiny-random",
)
self._bart_distiller_cli(updates)

def test_bdc_t5_eval(self):
updates = dict(
model_type="t5",
model_name_or_path="patrickvonplaten/t5-tiny-random",
do_train=False,
do_predict=True,
tokenizer_name="patrickvonplaten/t5-tiny-random",
no_teacher=True,
)
self._bart_distiller_cli(updates, check_contents=False)

def _bart_distiller_cli(self, updates, check_contents=True):
default_updates = dict(
model_type="bart",
train_batch_size=1,
eval_batch_size=2,
num_train_epochs=2,
@@ -237,21 +188,14 @@ class TestSummarizationDistiller(unittest.TestCase):
self.assertIn(ckpt_name, contents)
self.assertIn("metrics.pkl", contents)
self.assertIn("test_generations.txt", contents)
self.assertIn("val_generations_1.txt", contents)
self.assertIn("val_1_results.txt", contents)
self.assertIn("val_generations_00001.txt", contents)
self.assertIn("val_results_00001.txt", contents)
self.assertIn("test_results.txt", contents)
# self.assertEqual(len(contents), 15)

metrics = pickle_load(Path(output_dir) / "metrics.pkl")
import pandas as pd

val_df = pd.DataFrame(metrics["val"])
train_df = pd.DataFrame(metrics["train"])
test_df = pd.DataFrame(metrics["test"])
desired_n_evals = args_d["num_train_epochs"] * 2 + 1
self.assertEqual(val_df.shape[0], desired_n_evals) #
self.assertEqual(test_df.shape[1], val_df.shape[1])
self.assertEqual(train_df.shape[0], 0)
desired_n_evals = int(args_d["num_train_epochs"] * (1 / args_d["val_check_interval"]) + 1)
self.assertEqual(len(metrics["val"]), desired_n_evals)
self.assertEqual(len(metrics["train"]), 0) # doesn't get logged here
return model
@@ -281,9 +225,8 @@ class TestBartExamples(unittest.TestCase):
output_dir = tempfile.mkdtemp(prefix="output_")
args_d.update(
data_dir=tmp_dir,
model_type="t5",
model_name_or_path="patrickvonplaten/t5-tiny-random",
tokenizer_name=None, # "patrickvonplaten/t5-tiny-random",
model_name_or_path=T5_TINY,
tokenizer_name=None, # T5_TINY,
train_batch_size=2,
eval_batch_size=2,
gpus=0,
@@ -45,8 +45,10 @@ def encode_file(
max_length=max_length,
pad_to_max_length=pad_to_max_length,
add_prefix_space=True,
truncation=True,
return_tensors=return_tensors,
)
assert tokenized.input_ids.shape[1] == max_length
examples.append(tokenized)
torch.save(lmap(dict, examples), cache_path.open("wb"))
return examples
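Note: the added truncation=True asks the tokenizer to cut inputs to max_length explicitly rather than relying on older implicit behavior. A standalone sketch of the call pattern, with a model name chosen only for illustration (not part of this commit):

from transformers import AutoTokenizer

# Illustrative sketch, not part of this commit; the model choice is arbitrary.
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
tokenized = tokenizer.batch_encode_plus(
    ["a short example sentence"],
    max_length=8,
    pad_to_max_length=True,
    add_prefix_space=True,
    truncation=True,
    return_tensors="pt",
)
assert tokenized["input_ids"].shape[1] == 8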
@@ -108,7 +108,7 @@ class GLUETransformer(BaseTransformer):

return {"val_loss": tmp_eval_loss.detach().cpu(), "pred": preds, "target": out_label_ids}

def _eval_end(self, outputs):
def _eval_end(self, outputs) -> tuple:
val_loss_mean = torch.stack([x["val_loss"] for x in outputs]).mean().detach().cpu().item()
preds = np.concatenate([x["pred"] for x in outputs], axis=0)
@@ -132,20 +132,14 @@ class GLUETransformer(BaseTransformer):
logs = ret["log"]
return {"val_loss": logs["val_loss"], "log": logs, "progress_bar": logs}

def test_epoch_end(self, outputs):
# updating to test_epoch_end instead of deprecated test_end
def test_epoch_end(self, outputs) -> dict:
ret, predictions, targets = self._eval_end(outputs)

# Converting to the dic required by pl
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/\
# pytorch_lightning/trainer/logging.py#L139
logs = ret["log"]
# `val_loss` is the key returned by `self._eval_end()` but actually refers to `test_loss`
return {"avg_test_loss": logs["val_loss"], "log": logs, "progress_bar": logs}

@staticmethod
def add_model_specific_args(parser, root_dir):
# Add NER specific options
BaseTransformer.add_model_specific_args(parser, root_dir)
parser.add_argument(
"--max_seq_length",
@@ -205,7 +205,7 @@ class AutoTokenizer:
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

if "bert-base-japanese" in pretrained_model_name_or_path:
if "bert-base-japanese" in str(pretrained_model_name_or_path):
return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

use_fast = kwargs.pop("use_fast", False)
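Note: the str(...) wrapper matters because pretrained_model_name_or_path may be a pathlib.Path, and Python's in operator cannot do substring checks against a Path. A tiny sketch of the failure mode the change guards against (illustrative only, not part of this commit):

from pathlib import Path

name_or_path = Path("/models/bert-base-japanese")
# "bert-base-japanese" in name_or_path  -> TypeError: Path objects do not support substring tests
print("bert-base-japanese" in str(name_or_path))  # True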