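"""Fine-tune a seq2seq language model (e.g. BART) for summarization with PyTorch Lightning.

The generic training machinery comes from ``lightning_base`` (``BaseTransformer`` /
``generic_train``); this module adds the summarization-specific pieces: dataset wiring,
the teacher-forced loss, generation at test time, and the BART-specific CLI options.
"""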
import argparse
import glob
import logging
import os
import time

import torch
from torch.utils.data import DataLoader

from lightning_base import BaseTransformer, add_generic_args, generic_train, get_linear_schedule_with_warmup


try:
    from .utils import SummarizationDataset
except ImportError:
    from utils import SummarizationDataset


logger = logging.getLogger(__name__)

class SummarizationTrainer(BaseTransformer):

    mode = "language-modeling"

    def __init__(self, hparams):
        super().__init__(hparams, num_labels=None, mode=self.mode)
        self.dataset_kwargs: dict = dict(
            data_dir=self.hparams.data_dir,
            max_source_length=self.hparams.max_source_length,
            max_target_length=self.hparams.max_target_length,
        )
    def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, lm_labels=None):
        return self.model(
            input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, lm_labels=lm_labels,
        )
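
    # Teacher forcing: the decoder input is the target sequence without its last
    # token, and the labels are the target without its first token. Padded label
    # positions are replaced with -100, the ignore_index of PyTorch's cross-entropy
    # loss, so padding does not contribute to the training loss.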
    def _step(self, batch):
        pad_token_id = self.tokenizer.pad_token_id
        source_ids, source_mask, y = batch["source_ids"], batch["source_mask"], batch["target_ids"]
        y_ids = y[:, :-1].contiguous()
        lm_labels = y[:, 1:].clone()
        lm_labels[y[:, 1:] == pad_token_id] = -100
        outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=y_ids, lm_labels=lm_labels,)

        loss = outputs[0]

        return loss
    def training_step(self, batch, batch_idx):
        loss = self._step(batch)

        tensorboard_logs = {"train_loss": loss}
        return {"loss": loss, "log": tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        loss = self._step(batch)
        return {"val_loss": loss}
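
    # validation_end / test_end are the epoch-aggregation hooks of the older
    # PyTorch Lightning API this example was written against (later versions
    # renamed them to validation_epoch_end / test_epoch_end).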
    def validation_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        tensorboard_logs = {"val_loss": avg_loss}
        return {"avg_val_loss": avg_loss, "log": tensorboard_logs}
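
    # During testing, in addition to computing the loss, summaries are generated
    # and decoded so that predictions and references can be written out for ROUGE
    # evaluation in test_epoch_end.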
    def test_step(self, batch, batch_idx):
        pad_token_id = self.tokenizer.pad_token_id
        source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
        # NOTE: the following kwargs get more speed and lower quality summaries than those in evaluate_cnn.py
        generated_ids = self.model.generate(
            input_ids=source_ids,
            attention_mask=source_mask,
            num_beams=1,
            max_length=80,
            repetition_penalty=2.5,
            length_penalty=1.0,
            early_stopping=True,
            use_cache=True,
        )
        preds = [
            self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for g in generated_ids
        ]
        target = [self.tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True) for t in y]
        loss = self._step(batch)

        return {"val_loss": loss, "preds": preds, "target": target}
    def test_end(self, outputs):
        return self.validation_end(outputs)

    def test_epoch_end(self, outputs):
        output_test_predictions_file = os.path.join(self.hparams.output_dir, "test_predictions.txt")
        output_test_targets_file = os.path.join(self.hparams.output_dir, "test_targets.txt")
        # write predictions and targets for later rouge evaluation.
        with open(output_test_predictions_file, "w+") as p_writer, open(output_test_targets_file, "w+") as t_writer:
            for output_batch in outputs:
                p_writer.writelines(s + "\n" for s in output_batch["preds"])
                t_writer.writelines(s + "\n" for s in output_batch["target"])
            p_writer.close()
            t_writer.close()

        return self.test_end(outputs)
    def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
        dataset = SummarizationDataset(self.tokenizer, type_path=type_path, **self.dataset_kwargs)
        dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=shuffle)
        return dataloader
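
    # The total number of optimizer steps (examples / effective batch size /
    # gradient-accumulation steps, times epochs) is needed up front to configure
    # the linear warmup-then-decay learning-rate schedule.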
    def train_dataloader(self) -> DataLoader:
        dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
        t_total = (
            (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.n_gpu)))
            // self.hparams.gradient_accumulation_steps
            * float(self.hparams.num_train_epochs)
        )
        scheduler = get_linear_schedule_with_warmup(
            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
        )
        self.lr_scheduler = scheduler
        return dataloader

    def val_dataloader(self) -> DataLoader:
        return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)

    def test_dataloader(self) -> DataLoader:
        return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)
    @staticmethod
    def add_model_specific_args(parser, root_dir):
        BaseTransformer.add_model_specific_args(parser, root_dir)
        # Add BART-specific options
        parser.add_argument(
            "--max_source_length",
            default=1024,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--max_target_length",
            default=56,
            type=int,
            help="The maximum total target sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )

        parser.add_argument(
            "--data_dir",
            default=None,
            type=str,
            required=True,
            help="The input data dir. Should contain the dataset files for the CNN/DM summarization task.",
        )
        return parser


def main(args):

    # If output_dir not provided, a folder will be generated in pwd
    if not args.output_dir:
        args.output_dir = os.path.join("./results", f"{args.task}_{time.strftime('%Y%m%d_%H%M%S')}",)
        os.makedirs(args.output_dir)
    model = SummarizationTrainer(args)
    trainer = generic_train(model, args)

    # Optionally, predict on dev set and write to output_dir
    if args.do_predict:
        # See https://github.com/huggingface/transformers/issues/3159
        # PyTorch Lightning uses this format to create a checkpoint:
        # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
        # /pytorch_lightning/callbacks/model_checkpoint.py#L169
        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
        model = model.load_from_checkpoint(checkpoints[-1])
        trainer.test(model)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    add_generic_args(parser, os.getcwd())
    parser = SummarizationTrainer.add_model_specific_args(parser, os.getcwd())
    args = parser.parse_args()

    main(args)
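
# A hypothetical invocation, assuming the generic arguments added by
# lightning_base.add_generic_args / BaseTransformer include --model_name_or_path,
# --output_dir and --do_predict (only --data_dir, --max_source_length and
# --max_target_length are defined in this file; flag names may differ):
#
#   python finetune.py \
#       --data_dir=/path/to/cnn_dm \
#       --model_name_or_path=bart-large \
#       --output_dir=./results \
#       --do_predict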