Mirror of https://github.com/huggingface/transformers.git
[examples] summarization/bart/finetune.py supports t5 (#3824)
Renames `run_bart_sum.py` to `finetune.py`.
parent 0cec4fab7d
commit f0c96fafd1
@@ -19,7 +19,7 @@ except ImportError:
 logger = logging.getLogger(__name__)
 
 
-class BartSystem(BaseTransformer):
+class SummarizationTrainer(BaseTransformer):
 
     mode = "language-modeling"
 
@@ -64,18 +64,18 @@ class BartSystem(BaseTransformer):
         return {"avg_val_loss": avg_loss, "log": tensorboard_logs}
 
     def test_step(self, batch, batch_idx):
-        # NOTE: this generation will not use the cache.
         pad_token_id = self.tokenizer.pad_token_id
         source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
-        # NOTE: these kwargs get more speed and lower quality summaries than those in evaluate_cnn.py.
+        # NOTE: the following kwargs get more speed and lower quality summaries than those in evaluate_cnn.py
         generated_ids = self.model.generate(
-            source_ids,
-            source_mask,
+            input_ids=source_ids,
+            attention_mask=source_mask,
             num_beams=1,
             max_length=80,
             repetition_penalty=2.5,
             length_penalty=1.0,
             early_stopping=True,
+            use_cache=True,
         )
         preds = [
             self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
@@ -161,20 +161,20 @@ def main(args):
     if not args.output_dir:
         args.output_dir = os.path.join("./results", f"{args.task}_{args.model_type}_{time.strftime('%Y%m%d_%H%M%S')}",)
         os.makedirs(args.output_dir)
-    model = BartSystem(args)
+    model = SummarizationTrainer(args)
     trainer = generic_train(model, args)
 
     # Optionally, predict on dev set and write to output_dir
     if args.do_predict:
         checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
-        BartSystem.load_from_checkpoint(checkpoints[-1])
+        SummarizationTrainer.load_from_checkpoint(checkpoints[-1])
         trainer.test(model)
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     add_generic_args(parser, os.getcwd())
-    parser = BartSystem.add_model_specific_args(parser, os.getcwd())
+    parser = SummarizationTrainer.add_model_specific_args(parser, os.getcwd())
     args = parser.parse_args()
 
     main(args)
@@ -8,7 +8,7 @@ mkdir -p $OUTPUT_DIR
 # Add parent directory to python path to access transformer_base.py
 export PYTHONPATH="../../":"${PYTHONPATH}"
 
-python run_bart_sum.py \
+python finetune.py \
     --data_dir=./cnn-dailymail/cnn_dm \
     --model_type=bart \
     --model_name_or_path=bart-large \
@@ -14,7 +14,7 @@ mkdir -p $OUTPUT_DIR
 
 # Add parent directory to python path to access transformer_base.py and utils.py
 export PYTHONPATH="../../":"${PYTHONPATH}"
-python run_bart_sum.py \
+python finetune.py \
     --data_dir=cnn_tiny/ \
     --model_type=bart \
     --model_name_or_path=sshleifer/bart-tiny-random \
@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader
 from transformers import BartTokenizer
 
 from .evaluate_cnn import run_generate
-from .run_bart_sum import main
+from .finetune import main
 from .utils import SummarizationDataset
 
 
@@ -92,9 +92,27 @@ class TestBartExamples(unittest.TestCase):
         args_d.update(
             data_dir=tmp_dir, model_type="bart", train_batch_size=2, eval_batch_size=2, n_gpu=0, output_dir=output_dir,
         )
         main(argparse.Namespace(**args_d))
-        args_d.update({"do_train": False, "do_predict": True})
-        main(argparse.Namespace(**args_d))
+        args = argparse.Namespace(**args_d)
+        main(args)
+
+    def test_t5_run_sum_cli(self):
+        args_d: dict = DEFAULT_ARGS.copy()
+        tmp_dir = make_test_data_dir()
+        output_dir = tempfile.mkdtemp(prefix="output_")
+        args_d.update(
+            data_dir=tmp_dir,
+            model_type="t5",
+            model_name_or_path="patrickvonplaten/t5-tiny-random",
+            train_batch_size=2,
+            eval_batch_size=2,
+            n_gpu=0,
+            output_dir=output_dir,
+            do_predict=True,
+        )
+        main(argparse.Namespace(**args_d))
+        # args_d.update({"do_train": False, "do_predict": True})
+        # main(argparse.Namespace(**args_d))
 
     def test_bart_summarization_dataset(self):
         tmp_dir = Path(tempfile.gettempdir())
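The new `test_t5_run_sum_cli` case can be run on its own; a minimal sketch, assuming the test module keeps the name `test_bart_examples.py` and is invoked from the repository root so the package-relative imports resolve:

```bash
# Hypothetical invocation: run only the new T5 CLI test (module path assumed)
python -m pytest examples/summarization/bart/test_bart_examples.py -k test_t5_run_sum_cli
```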
@@ -15,7 +15,7 @@ wc -l cnn_articles_input_data.txt # should print 11490
 wc -l cnn_articles_reference_summaries.txt # should print 11490
 ```
 
-### Usage
+### Generating Summaries
 
 To create summaries for each article in dataset, run:
 ```bash
@@ -23,3 +23,7 @@ python evaluate_cnn.py cnn_articles_input_data.txt cnn_generated_articles_summar
 ```
 The default batch size, 8, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
 The rouge scores "rouge1, rouge2, rougeL" are automatically created and saved in ``rouge_score.txt``.
+
+
+### Finetuning
+Pass `model_type=t5` and a t5 model name (`model_name_or_path`) to `examples/summarization/bart/finetune.py`.
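For context, a minimal T5 finetuning sketch, assuming the same CLI flags used in the `run_train.sh` hunk above; the `t5-small` checkpoint, output directory, and data path are illustrative and not taken from this commit:

```bash
# Illustrative only: same flags as run_train.sh, pointed at a T5 checkpoint instead of BART
export PYTHONPATH="../../":"${PYTHONPATH}"  # make transformer_base.py importable

python finetune.py \
  --data_dir=./cnn-dailymail/cnn_dm \
  --model_type=t5 \
  --model_name_or_path=t5-small \
  --output_dir=t5_cnn_results \
  --do_train \
  --do_predict
```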