[examples/s2s] clean up finetune_trainer (#7509)

This commit is contained in:
Suraj Patil 2020-10-01 21:49:29 +05:30 committed by GitHub
parent bd2621583b
commit 72d363d979
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 107 additions and 105 deletions

View File

@ -2,37 +2,29 @@ import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple
import numpy as np
import torch
from typing import Optional
from seq2seq_trainer import Seq2SeqTrainer
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
BartTokenizer,
EvalPrediction,
HfArgumentParser,
MBartTokenizer,
T5Tokenizer,
TrainingArguments,
set_seed,
)
from transformers.modeling_bart import shift_tokens_right
from transformers.trainer_utils import EvaluationStrategy
from utils import (
LegacySeq2SeqDataset,
Seq2SeqDataCollator,
Seq2SeqDataset,
assert_all_frozen,
calculate_bleu,
calculate_rouge,
build_compute_metrics_fn,
freeze_embeds,
freeze_params,
lmap,
save_json,
trim_batch,
use_task_specific_params,
write_txt_file,
)
@ -41,66 +33,6 @@ from utils import (
logger = logging.getLogger(__name__)
class Seq2SeqDataCollator:
def __init__(self, tokenizer, data_args, tpu_num_cores=None):
self.tokenizer = tokenizer
self.pad_token_id = tokenizer.pad_token_id
assert self.pad_token_id is not None, "self.pad_token_id must be defined"
self.data_args = data_args
self.tpu_num_cores = tpu_num_cores
self.add_prefix_space = isinstance(tokenizer, BartTokenizer)
def __call__(self, batch) -> Dict[str, torch.Tensor]:
if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
batch = self._encode(batch)
input_ids, attention_mask, labels = (
batch["input_ids"],
batch["attention_mask"],
batch["labels"],
)
else:
input_ids = torch.stack([x["input_ids"] for x in batch])
attention_mask = torch.stack([x["attention_mask"] for x in batch])
labels = torch.stack([x["labels"] for x in batch])
labels = trim_batch(labels, self.pad_token_id)
input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)
if isinstance(self.tokenizer, T5Tokenizer):
decoder_input_ids = self._shift_right_t5(labels)
else:
decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)
batch = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"labels": labels,
}
return batch
def _shift_right_t5(self, input_ids):
# shift inputs to the right
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
shifted_input_ids[..., 0] = self.pad_token_id
return shifted_input_ids
def _encode(self, batch) -> Dict[str, torch.Tensor]:
batch_encoding = self.tokenizer.prepare_seq2seq_batch(
[x["src_texts"] for x in batch],
src_lang=self.data_args.src_lang,
tgt_texts=[x["tgt_texts"] for x in batch],
tgt_lang=self.data_args.tgt_lang,
max_length=self.data_args.max_source_length,
max_target_length=self.data_args.max_target_length,
padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack
return_tensors="pt",
add_prefix_space=self.add_prefix_space,
)
return batch_encoding.data
@dataclass
class Seq2SeqTrainingArguments(TrainingArguments):
"""
@ -271,34 +203,6 @@ def main():
), "mBart requires --tgt_lang and --src_lang"
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
def non_pad_len(tokens: np.ndarray) -> int:
return np.count_nonzero(tokens != tokenizer.pad_token_id)
def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)
label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
pred_str = lmap(str.strip, pred_str)
label_str = lmap(str.strip, label_str)
return pred_str, label_str
def summarization_metrics(pred: EvalPrediction) -> Dict:
pred_str, label_str = decode_pred(pred)
rouge: Dict = calculate_rouge(pred_str, label_str)
summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
rouge.update({"gen_len": summ_len})
return rouge
def translation_metrics(pred: EvalPrediction) -> Dict:
pred_str, label_str = decode_pred(pred)
bleu: Dict = calculate_bleu(pred_str, label_str)
gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
bleu.update({"gen_len": gen_len})
return bleu
compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics
return compute_metrics_fn
if model_args.freeze_embeds:
freeze_embeds(model)
if model_args.freeze_encoder:
@ -349,13 +253,17 @@ def main():
)
# Initialize our Trainer
compute_metrics_fn = (
build_compute_metrics_fn(data_args.task, tokenizer) if training_args.predict_with_generate else None
)
trainer = Seq2SeqTrainer(
model=model,
config=config,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
compute_metrics=build_compute_metrics_fn(data_args.task) if training_args.predict_with_generate else None,
compute_metrics=compute_metrics_fn,
data_args=data_args,
)

View File

@ -20,11 +20,13 @@ logger = logging.getLogger(__name__)
class Seq2SeqTrainer(Trainer):
def __init__(self, data_args, *args, **kwargs):
def __init__(self, config, data_args, *args, **kwargs):
super().__init__(*args, **kwargs)
self.config = config
self.data_args = data_args
self.max_gen_length = data_args.val_max_target_length
self.pad_token_id = self.model.config.pad_token_id
self.pad_token_id = self.config.pad_token_id
self.vocab_size = self.config.vocab_size
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
@ -53,7 +55,7 @@ class Seq2SeqTrainer(Trainer):
if self.args.label_smoothing == 0:
# Same behavior as modeling_bart.py
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
assert logits.shape[-1] == self.model.config.vocab_size
assert logits.shape[-1] == self.vocab_size
loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
else:
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)

View File

@ -7,7 +7,7 @@ import pickle
import socket
from logging import getLogger
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Union
from typing import Callable, Dict, Iterable, List, Tuple, Union
import git
import numpy as np
@ -19,8 +19,9 @@ from torch import nn
from torch.utils.data import Dataset, Sampler
from sentence_splitter import add_newline_to_end_of_each_sentence
from transformers import BartTokenizer
from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer
from transformers.file_utils import cached_property
from transformers.modeling_bart import shift_tokens_right
try:
@ -62,6 +63,35 @@ def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)}
def build_compute_metrics_fn(task_name: str, tokenizer: PreTrainedTokenizer) -> Callable[[EvalPrediction], Dict]:
def non_pad_len(tokens: np.ndarray) -> int:
return np.count_nonzero(tokens != tokenizer.pad_token_id)
def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)
label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
pred_str = lmap(str.strip, pred_str)
label_str = lmap(str.strip, label_str)
return pred_str, label_str
def summarization_metrics(pred: EvalPrediction) -> Dict:
pred_str, label_str = decode_pred(pred)
rouge: Dict = calculate_rouge(pred_str, label_str)
summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
rouge.update({"gen_len": summ_len})
return rouge
def translation_metrics(pred: EvalPrediction) -> Dict:
pred_str, label_str = decode_pred(pred)
bleu: Dict = calculate_bleu(pred_str, label_str)
gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
bleu.update({"gen_len": gen_len})
return bleu
compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics
return compute_metrics_fn
def trim_batch(
input_ids,
pad_token_id,
@ -230,6 +260,68 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
return batch_encoding
class Seq2SeqDataCollator:
def __init__(self, tokenizer, data_args, tpu_num_cores=None):
self.tokenizer = tokenizer
self.pad_token_id = tokenizer.pad_token_id
assert (
self.pad_token_id is not None
), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
self.data_args = data_args
self.tpu_num_cores = tpu_num_cores
self.add_prefix_space = isinstance(tokenizer, BartTokenizer)
def __call__(self, batch) -> Dict[str, torch.Tensor]:
if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
batch = self._encode(batch)
input_ids, attention_mask, labels = (
batch["input_ids"],
batch["attention_mask"],
batch["labels"],
)
else:
input_ids = torch.stack([x["input_ids"] for x in batch])
attention_mask = torch.stack([x["attention_mask"] for x in batch])
labels = torch.stack([x["labels"] for x in batch])
labels = trim_batch(labels, self.pad_token_id)
input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)
if isinstance(self.tokenizer, T5Tokenizer):
decoder_input_ids = self._shift_right_t5(labels)
else:
decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)
batch = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"labels": labels,
}
return batch
def _shift_right_t5(self, input_ids):
# shift inputs to the right
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
shifted_input_ids[..., 0] = self.pad_token_id
return shifted_input_ids
def _encode(self, batch) -> Dict[str, torch.Tensor]:
batch_encoding = self.tokenizer.prepare_seq2seq_batch(
[x["src_texts"] for x in batch],
src_lang=self.data_args.src_lang,
tgt_texts=[x["tgt_texts"] for x in batch],
tgt_lang=self.data_args.tgt_lang,
max_length=self.data_args.max_source_length,
max_target_length=self.data_args.max_target_length,
padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack
return_tensors="pt",
add_prefix_space=self.add_prefix_space,
)
return batch_encoding.data
class SortishSampler(Sampler):
"Go through the text data by order of src length with a bit of randomness. From fastai repo."