mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
[examples/s2s] clean up finetune_trainer (#7509)
This commit is contained in:
parent
bd2621583b
commit
72d363d979
@ -2,37 +2,29 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing import Optional
|
||||
|
||||
from seq2seq_trainer import Seq2SeqTrainer
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
BartTokenizer,
|
||||
EvalPrediction,
|
||||
HfArgumentParser,
|
||||
MBartTokenizer,
|
||||
T5Tokenizer,
|
||||
TrainingArguments,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.modeling_bart import shift_tokens_right
|
||||
from transformers.trainer_utils import EvaluationStrategy
|
||||
from utils import (
|
||||
LegacySeq2SeqDataset,
|
||||
Seq2SeqDataCollator,
|
||||
Seq2SeqDataset,
|
||||
assert_all_frozen,
|
||||
calculate_bleu,
|
||||
calculate_rouge,
|
||||
build_compute_metrics_fn,
|
||||
freeze_embeds,
|
||||
freeze_params,
|
||||
lmap,
|
||||
save_json,
|
||||
trim_batch,
|
||||
use_task_specific_params,
|
||||
write_txt_file,
|
||||
)
|
||||
@ -41,66 +33,6 @@ from utils import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Seq2SeqDataCollator:
|
||||
def __init__(self, tokenizer, data_args, tpu_num_cores=None):
|
||||
self.tokenizer = tokenizer
|
||||
self.pad_token_id = tokenizer.pad_token_id
|
||||
assert self.pad_token_id is not None, "self.pad_token_id must be defined"
|
||||
self.data_args = data_args
|
||||
self.tpu_num_cores = tpu_num_cores
|
||||
self.add_prefix_space = isinstance(tokenizer, BartTokenizer)
|
||||
|
||||
def __call__(self, batch) -> Dict[str, torch.Tensor]:
|
||||
if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
|
||||
batch = self._encode(batch)
|
||||
input_ids, attention_mask, labels = (
|
||||
batch["input_ids"],
|
||||
batch["attention_mask"],
|
||||
batch["labels"],
|
||||
)
|
||||
else:
|
||||
input_ids = torch.stack([x["input_ids"] for x in batch])
|
||||
attention_mask = torch.stack([x["attention_mask"] for x in batch])
|
||||
labels = torch.stack([x["labels"] for x in batch])
|
||||
|
||||
labels = trim_batch(labels, self.pad_token_id)
|
||||
input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)
|
||||
|
||||
if isinstance(self.tokenizer, T5Tokenizer):
|
||||
decoder_input_ids = self._shift_right_t5(labels)
|
||||
else:
|
||||
decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)
|
||||
|
||||
batch = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"labels": labels,
|
||||
}
|
||||
return batch
|
||||
|
||||
def _shift_right_t5(self, input_ids):
|
||||
# shift inputs to the right
|
||||
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
||||
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
||||
shifted_input_ids[..., 0] = self.pad_token_id
|
||||
return shifted_input_ids
|
||||
|
||||
def _encode(self, batch) -> Dict[str, torch.Tensor]:
|
||||
batch_encoding = self.tokenizer.prepare_seq2seq_batch(
|
||||
[x["src_texts"] for x in batch],
|
||||
src_lang=self.data_args.src_lang,
|
||||
tgt_texts=[x["tgt_texts"] for x in batch],
|
||||
tgt_lang=self.data_args.tgt_lang,
|
||||
max_length=self.data_args.max_source_length,
|
||||
max_target_length=self.data_args.max_target_length,
|
||||
padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack
|
||||
return_tensors="pt",
|
||||
add_prefix_space=self.add_prefix_space,
|
||||
)
|
||||
return batch_encoding.data
|
||||
|
||||
|
||||
@dataclass
|
||||
class Seq2SeqTrainingArguments(TrainingArguments):
|
||||
"""
|
||||
@ -271,34 +203,6 @@ def main():
|
||||
), "mBart requires --tgt_lang and --src_lang"
|
||||
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.tgt_lang]
|
||||
|
||||
def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
|
||||
def non_pad_len(tokens: np.ndarray) -> int:
|
||||
return np.count_nonzero(tokens != tokenizer.pad_token_id)
|
||||
|
||||
def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
|
||||
pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)
|
||||
label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
|
||||
pred_str = lmap(str.strip, pred_str)
|
||||
label_str = lmap(str.strip, label_str)
|
||||
return pred_str, label_str
|
||||
|
||||
def summarization_metrics(pred: EvalPrediction) -> Dict:
|
||||
pred_str, label_str = decode_pred(pred)
|
||||
rouge: Dict = calculate_rouge(pred_str, label_str)
|
||||
summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
|
||||
rouge.update({"gen_len": summ_len})
|
||||
return rouge
|
||||
|
||||
def translation_metrics(pred: EvalPrediction) -> Dict:
|
||||
pred_str, label_str = decode_pred(pred)
|
||||
bleu: Dict = calculate_bleu(pred_str, label_str)
|
||||
gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
|
||||
bleu.update({"gen_len": gen_len})
|
||||
return bleu
|
||||
|
||||
compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics
|
||||
return compute_metrics_fn
|
||||
|
||||
if model_args.freeze_embeds:
|
||||
freeze_embeds(model)
|
||||
if model_args.freeze_encoder:
|
||||
@ -349,13 +253,17 @@ def main():
|
||||
)
|
||||
|
||||
# Initialize our Trainer
|
||||
compute_metrics_fn = (
|
||||
build_compute_metrics_fn(data_args.task, tokenizer) if training_args.predict_with_generate else None
|
||||
)
|
||||
trainer = Seq2SeqTrainer(
|
||||
model=model,
|
||||
config=config,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
|
||||
compute_metrics=build_compute_metrics_fn(data_args.task) if training_args.predict_with_generate else None,
|
||||
compute_metrics=compute_metrics_fn,
|
||||
data_args=data_args,
|
||||
)
|
||||
|
||||
|
@ -20,11 +20,13 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Seq2SeqTrainer(Trainer):
|
||||
def __init__(self, data_args, *args, **kwargs):
|
||||
def __init__(self, config, data_args, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.config = config
|
||||
self.data_args = data_args
|
||||
self.max_gen_length = data_args.val_max_target_length
|
||||
self.pad_token_id = self.model.config.pad_token_id
|
||||
self.pad_token_id = self.config.pad_token_id
|
||||
self.vocab_size = self.config.vocab_size
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
||||
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
|
||||
@ -53,7 +55,7 @@ class Seq2SeqTrainer(Trainer):
|
||||
if self.args.label_smoothing == 0:
|
||||
# Same behavior as modeling_bart.py
|
||||
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
|
||||
assert logits.shape[-1] == self.model.config.vocab_size
|
||||
assert logits.shape[-1] == self.vocab_size
|
||||
loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
|
||||
else:
|
||||
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
||||
|
@ -7,7 +7,7 @@ import pickle
|
||||
import socket
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Iterable, List, Union
|
||||
from typing import Callable, Dict, Iterable, List, Tuple, Union
|
||||
|
||||
import git
|
||||
import numpy as np
|
||||
@ -19,8 +19,9 @@ from torch import nn
|
||||
from torch.utils.data import Dataset, Sampler
|
||||
|
||||
from sentence_splitter import add_newline_to_end_of_each_sentence
|
||||
from transformers import BartTokenizer
|
||||
from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.modeling_bart import shift_tokens_right
|
||||
|
||||
|
||||
try:
|
||||
@ -62,6 +63,35 @@ def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
|
||||
return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)}
|
||||
|
||||
|
||||
def build_compute_metrics_fn(task_name: str, tokenizer: PreTrainedTokenizer) -> Callable[[EvalPrediction], Dict]:
|
||||
def non_pad_len(tokens: np.ndarray) -> int:
|
||||
return np.count_nonzero(tokens != tokenizer.pad_token_id)
|
||||
|
||||
def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
|
||||
pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)
|
||||
label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
|
||||
pred_str = lmap(str.strip, pred_str)
|
||||
label_str = lmap(str.strip, label_str)
|
||||
return pred_str, label_str
|
||||
|
||||
def summarization_metrics(pred: EvalPrediction) -> Dict:
|
||||
pred_str, label_str = decode_pred(pred)
|
||||
rouge: Dict = calculate_rouge(pred_str, label_str)
|
||||
summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
|
||||
rouge.update({"gen_len": summ_len})
|
||||
return rouge
|
||||
|
||||
def translation_metrics(pred: EvalPrediction) -> Dict:
|
||||
pred_str, label_str = decode_pred(pred)
|
||||
bleu: Dict = calculate_bleu(pred_str, label_str)
|
||||
gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
|
||||
bleu.update({"gen_len": gen_len})
|
||||
return bleu
|
||||
|
||||
compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics
|
||||
return compute_metrics_fn
|
||||
|
||||
|
||||
def trim_batch(
|
||||
input_ids,
|
||||
pad_token_id,
|
||||
@ -230,6 +260,68 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
|
||||
return batch_encoding
|
||||
|
||||
|
||||
class Seq2SeqDataCollator:
|
||||
def __init__(self, tokenizer, data_args, tpu_num_cores=None):
|
||||
self.tokenizer = tokenizer
|
||||
self.pad_token_id = tokenizer.pad_token_id
|
||||
assert (
|
||||
self.pad_token_id is not None
|
||||
), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
|
||||
self.data_args = data_args
|
||||
self.tpu_num_cores = tpu_num_cores
|
||||
self.add_prefix_space = isinstance(tokenizer, BartTokenizer)
|
||||
|
||||
def __call__(self, batch) -> Dict[str, torch.Tensor]:
|
||||
if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
|
||||
batch = self._encode(batch)
|
||||
input_ids, attention_mask, labels = (
|
||||
batch["input_ids"],
|
||||
batch["attention_mask"],
|
||||
batch["labels"],
|
||||
)
|
||||
else:
|
||||
input_ids = torch.stack([x["input_ids"] for x in batch])
|
||||
attention_mask = torch.stack([x["attention_mask"] for x in batch])
|
||||
labels = torch.stack([x["labels"] for x in batch])
|
||||
|
||||
labels = trim_batch(labels, self.pad_token_id)
|
||||
input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)
|
||||
|
||||
if isinstance(self.tokenizer, T5Tokenizer):
|
||||
decoder_input_ids = self._shift_right_t5(labels)
|
||||
else:
|
||||
decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)
|
||||
|
||||
batch = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"labels": labels,
|
||||
}
|
||||
return batch
|
||||
|
||||
def _shift_right_t5(self, input_ids):
|
||||
# shift inputs to the right
|
||||
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
||||
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
||||
shifted_input_ids[..., 0] = self.pad_token_id
|
||||
return shifted_input_ids
|
||||
|
||||
def _encode(self, batch) -> Dict[str, torch.Tensor]:
|
||||
batch_encoding = self.tokenizer.prepare_seq2seq_batch(
|
||||
[x["src_texts"] for x in batch],
|
||||
src_lang=self.data_args.src_lang,
|
||||
tgt_texts=[x["tgt_texts"] for x in batch],
|
||||
tgt_lang=self.data_args.tgt_lang,
|
||||
max_length=self.data_args.max_source_length,
|
||||
max_target_length=self.data_args.max_target_length,
|
||||
padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack
|
||||
return_tensors="pt",
|
||||
add_prefix_space=self.add_prefix_space,
|
||||
)
|
||||
return batch_encoding.data
|
||||
|
||||
|
||||
class SortishSampler(Sampler):
|
||||
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user