Fix rag finetuning + add finetuning test (#8585)

* replace init_ddp_connection for index init

* style

* add finetune test

* add test data

* move generate tensors to device

* add test on EM metric

* style

* allow multi process test

* keep gloo process group for retrieval

* add multi-gpu test

* use custom accelerator

* clean test finetune

* minor

* style

* style

* typo

* use python call instead of imported main fumction

* return_dict fix in modeling_rag

* use float32 in retrieval

* store as float32 as well in the custom knowledge dataset example

* style

* rename to finetune_rag

* style

* update readme

* rename utils and callbacks to utils_rag and callbacks_rag

* fix test

* patrick's comments

* generate dummy data in the finetue test script

* remove dummy data files

* style
This commit is contained in:
Quentin Lhoest 2020-11-20 19:05:03 +01:00 committed by GitHub
parent 63e91f5fde
commit 8062fa63c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 200 additions and 56 deletions

View File

@ -384,6 +384,8 @@ def generic_train(
train_params["distributed_backend"] = "ddp"
train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
train_params["accelerator"] = extra_train_kwargs.get("accelerator", None)
train_params["profiler"] = extra_train_kwargs.get("profiler", None)
trainer = pl.Trainer.from_argparse_args(
args,

View File

@ -7,8 +7,8 @@ to the retriever to extract relevant context documents. The documents are then p
Such contextualized inputs are passed to the generator.
Read more about RAG at https://arxiv.org/abs/2005.11401.
# Finetuning
# Finetuning
Our finetuning logic is based on scripts from [`examples/seq2seq`](https://github.com/huggingface/transformers/tree/master/examples/seq2seq). We accept training data in the same format as specified there - we expect a directory consisting of 6 text files:
```bash
@ -20,10 +20,10 @@ test.source
test.target
```
A sample finetuning command (run ` ./examples/rag/finetune.py --help` to list all available options):
A sample finetuning command (run ` ./examples/rag/finetune_rag.py --help` to list all available options):
```bash
python examples/rag/finetune.py \
python examples/rag/finetune_rag.py \
--data_dir $DATA_DIR \
--output_dir $OUTPUT_DIR \
--model_name_or_path $MODEL_NAME_OR_PATH \
@ -45,7 +45,7 @@ python examples/rag/consolidate_rag_checkpoint.py \
--question_encoder_name_or_path facebook/dpr-question_encoder-single-nq-base \
--dest path/to/checkpoint
```
You will then be able to pass `path/to/checkpoint` as `model_name_or_path` to the `finetune.py` script.
You will then be able to pass `path/to/checkpoint` as `model_name_or_path` to the `finetune_rag.py` script.
# Evaluation
@ -130,3 +130,29 @@ python examples/rag/eval_rag.py \
--print_predictions \
--recalculate \ # adding this parameter will force recalculating predictions even if predictions_path already exists
```
# Use your own knowledge source
By default, RAG uses the English Wikipedia as a knowledge source, known as the 'wiki_dpr' dataset.
With `use_custom_knowledge_dataset.py` you can build your own knowledge source, *e.g.* for RAG.
For instance, if documents are serialized as tab-separated csv files with the columns "title" and "text", one can use `use_own_knowledge_dataset.py` as follows:
```bash
python examples/rag/use_own_knowledge_dataset.py \
--csv_path path/to/my_csv \
--output_dir path/to/my_knowledge_dataset \
```
The created outputs in `path/to/my_knowledge_dataset` can then be used to finetune RAG as follows:
```bash
python examples/rag/finetune_rag.py \
--data_dir $DATA_DIR \
--output_dir $OUTPUT_DIR \
--model_name_or_path $MODEL_NAME_OR_PATH \
--model_type rag_sequence \
--fp16 \
--gpus 8
--index_name custom
--passages_path path/to/data/my_knowledge_dataset
--index_path path/to/my_knowledge_dataset_hnsw_index.faiss
```

View File

@ -8,7 +8,7 @@ import torch
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_only
from utils import save_json
from utils_rag import save_json
def count_trainable_parameters(model):
@ -38,7 +38,7 @@ def get_checkpoint_callback(output_dir, metric):
monitor=f"val_{metric}",
mode="max",
save_top_k=3,
period=0, # maybe save a checkpoint every time val is run, not just end of epoch.
period=1, # maybe save a checkpoint every time val is run, not just end of epoch.
)
return checkpoint_callback

View File

@ -40,7 +40,6 @@ class RagPyTorchDistributedRetriever(RagRetriever):
generator_tokenizer=generator_tokenizer,
index=index,
)
self.process_group = None
def init_retrieval(self, distributed_port: int):

View File

@ -1,12 +1,10 @@
"""Finetuning script for RAG models. Adapted from examples.seq2seq.finetune.py"""
import argparse
import glob
import logging
import os
import sys
import time
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple
@ -15,29 +13,31 @@ import numpy as np
import pytorch_lightning as pl
import torch
import torch.distributed as dist
from pytorch_lightning.accelerators.ddp_accelerator import DDPAccelerator
from pytorch_lightning.cluster_environments import TorchElasticEnvironment
from torch.utils.data import DataLoader
from transformers import (
AutoConfig,
AutoTokenizer,
BartForConditionalGeneration,
BatchEncoding,
RagConfig,
RagSequenceForGeneration,
RagTokenForGeneration,
RagTokenizer,
T5ForConditionalGeneration,
get_linear_schedule_with_warmup,
)
from transformers import logging as transformers_logging
from callbacks import ( # noqa: E402 # isort:skipq
from callbacks_rag import ( # noqa: E402 # isort:skipq
get_checkpoint_callback,
get_early_stopping_callback,
Seq2SeqLoggingCallback,
)
from distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
from utils import ( # noqa: E402 # isort:skip
from utils_rag import ( # noqa: E402 # isort:skip
calculate_exact_match,
flatten_list,
get_git_info,
@ -67,6 +67,30 @@ class AttrDict(dict):
self.__dict__ = self
# In PTL >v1.0, `init_ddp_connection` method in the `LightningModule`
# is no longer used, and is moved into DDPAccelerator instead.
# We override DDPAccelerator to add our custom logic for initializing the
# retriever.
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/backends/test_accelerator_connector.py
class CustomAccel(DDPAccelerator):
def __init__(self, trainer=None, **kwargs):
# Trainer is set later.
super().__init__(trainer, **kwargs)
def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True):
logger.info("Custom init_ddp_connection.")
module = self.trainer.model
if self.cluster_environment is None:
self.cluster_environment = TorchElasticEnvironment()
self.distributed_port = module.hparams.distributed_port
os.environ["MASTER_PORT"] = str(self.distributed_port)
super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
if module.is_rag_model:
module.model.rag.retriever.init_retrieval(self.distributed_port)
class GenerativeQAModule(BaseTransformer):
mode = "generative_qa"
loss_names = ["loss"]
@ -91,23 +115,24 @@ class GenerativeQAModule(BaseTransformer):
config = config_class.from_pretrained(hparams.model_name_or_path)
# set retriever parameters
config.index_name = args.index_name or config.index_name
config.passages_path = args.passages_path or config.passages_path
config.index_path = args.index_path or config.index_path
config.index_name = hparams.index_name or config.index_name
config.passages_path = hparams.passages_path or config.passages_path
config.index_path = hparams.index_path or config.index_path
config.use_dummy_dataset = hparams.use_dummy_dataset
# set extra_model_params for generator configs and load_model
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "attention_dropout", "dropout")
if self.is_rag_model:
if args.prefix is not None:
config.generator.prefix = args.prefix
if hparams.prefix is not None:
config.generator.prefix = hparams.prefix
config.label_smoothing = hparams.label_smoothing
hparams, config.generator = set_extra_model_params(extra_model_params, hparams, config.generator)
retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path, config=config)
model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config, retriever=retriever)
prefix = config.question_encoder.prefix
else:
if args.prefix is not None:
config.prefix = args.prefix
if hparams.prefix is not None:
config.prefix = hparams.prefix
hparams, config = set_extra_model_params(extra_model_params, hparams, config)
model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config)
prefix = config.prefix
@ -152,11 +177,9 @@ class GenerativeQAModule(BaseTransformer):
self.num_workers = hparams.num_workers
self.distributed_port = self.hparams.distributed_port
def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True):
logger.info("Custom init_ddp_connection.")
os.environ["MASTER_PORT"] = str(self.distributed_port)
super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
if self.is_rag_model:
# For single GPU training, init_ddp_connection is not called.
# So we need to initialize the retrievers here.
if hparams.gpus <= 1:
self.model.retriever.init_retrieval(self.distributed_port)
def forward(self, input_ids, **kwargs):
@ -270,6 +293,7 @@ class GenerativeQAModule(BaseTransformer):
def _generative_step(self, batch: dict) -> dict:
start_time = time.time()
batch = BatchEncoding(batch).to(device=self.model.device)
generated_ids = self.model.generate(
batch["input_ids"],
attention_mask=batch["attention_mask"],
@ -322,17 +346,6 @@ class GenerativeQAModule(BaseTransformer):
def train_dataloader(self) -> DataLoader:
dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
t_total = (
(len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
// self.hparams.accumulate_grad_batches
* float(self.hparams.max_epochs)
)
scheduler = get_linear_schedule_with_warmup(
self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
)
if max(scheduler.get_last_lr()) > 0:
warnings.warn("All learning rates are 0")
self.lr_scheduler = scheduler
return dataloader
def val_dataloader(self) -> DataLoader:
@ -429,10 +442,24 @@ class GenerativeQAModule(BaseTransformer):
default=None,
help="Path to the faiss index for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
)
parser.add_argument(
"--use_dummy_dataset",
type=bool,
default=False,
help="Whether to use the dummy version of the dataset index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
)
return parser
def main(args, model=None) -> GenerativeQAModule:
def main(args=None, model=None) -> GenerativeQAModule:
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
parser = GenerativeQAModule.add_retriever_specific_args(parser)
args = args or parser.parse_args()
Path(args.output_dir).mkdir(exist_ok=True)
if model is None:
model: GenerativeQAModule = GenerativeQAModule(args)
@ -461,6 +488,7 @@ def main(args, model=None) -> GenerativeQAModule:
if args.early_stopping_patience >= 0
else False
)
trainer: pl.Trainer = generic_train(
model,
args,
@ -468,31 +496,17 @@ def main(args, model=None) -> GenerativeQAModule:
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
early_stopping_callback=es_callback,
logger=logger,
accelerator=CustomAccel() if args.gpus > 1 else None,
)
pickle_save(model.hparams, model.output_dir / "hparams.pkl")
if not args.do_predict:
return model
model.hparams.test_checkpoint = ""
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
if checkpoints:
model.hparams.test_checkpoint = checkpoints[-1]
trainer.resume_from_checkpoint = checkpoints[-1] # best checkpoint
trainer.logger.log_hyperparams(model.hparams)
# test() without a model tests using the best checkpoint automatically
trainer.test()
return model
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
parser = GenerativeQAModule.add_retriever_specific_args(parser)
args = parser.parse_args()
main(args)
main()

View File

@ -4,7 +4,7 @@ export PYTHONPATH="../":"${PYTHONPATH}"
# A sample finetuning run, you need to specify data_dir, output_dir and model_name_or_path
# run ./examples/rag/finetune.sh --help to see all the possible options
python examples/rag/finetune.py \
python examples/rag/finetune_rag.py \
--data_dir $DATA_DIR \
--output_dir $OUTPUT_DIR \
--model_name_or_path $MODEL_NAME_OR_PATH \

View File

@ -0,0 +1,96 @@
import json
import logging
import os
import sys
from pathlib import Path
import finetune_rag
from transformers.file_utils import is_apex_available
from transformers.testing_utils import (
TestCasePlus,
execute_subprocess_async,
require_torch_gpu,
require_torch_multi_gpu,
)
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
class RagFinetuneExampleTests(TestCasePlus):
def _create_dummy_data(self, data_dir):
os.makedirs(data_dir, exist_ok=True)
contents = {"source": "What is love ?", "target": "life"}
n_lines = {"train": 12, "val": 2, "test": 2}
for split in ["train", "test", "val"]:
for field in ["source", "target"]:
content = "\n".join([contents[field]] * n_lines[split])
with open(os.path.join(data_dir, f"{split}.{field}"), "w") as f:
f.write(content)
def _run_finetune(self, gpus: int):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
output_dir = os.path.join(tmp_dir, "output")
data_dir = os.path.join(tmp_dir, "data")
self._create_dummy_data(data_dir=data_dir)
testargs = f"""
--data_dir {data_dir} \
--output_dir {output_dir} \
--model_name_or_path facebook/rag-sequence-base \
--model_type rag_sequence \
--do_train \
--do_predict \
--n_val -1 \
--val_check_interval 1.0 \
--train_batch_size 2 \
--eval_batch_size 1 \
--max_source_length 25 \
--max_target_length 25 \
--val_max_target_length 25 \
--test_max_target_length 25 \
--label_smoothing 0.1 \
--dropout 0.1 \
--attention_dropout 0.1 \
--weight_decay 0.001 \
--adam_epsilon 1e-08 \
--max_grad_norm 0.1 \
--lr_scheduler polynomial \
--learning_rate 3e-04 \
--num_train_epochs 1 \
--warmup_steps 4 \
--gradient_accumulation_steps 1 \
--distributed-port 8787 \
--use_dummy_dataset 1 \
""".split()
if gpus > 0:
testargs.append(f"--gpus={gpus}")
if is_apex_available():
testargs.append("--fp16")
else:
testargs.append("--gpus=0")
testargs.append("--distributed_backend=ddp_cpu")
testargs.append("--num_processes=2")
cmd = [sys.executable, str(Path(finetune_rag.__file__).resolve())] + testargs
execute_subprocess_async(cmd, env=self.get_env())
metrics_save_path = os.path.join(output_dir, "metrics.json")
with open(metrics_save_path) as f:
result = json.load(f)
return result
@require_torch_gpu
def test_finetune_gpu(self):
result = self._run_finetune(gpus=1)
self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)
@require_torch_multi_gpu
def test_finetune_multigpu(self):
result = self._run_finetune(gpus=2)
self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)

View File

@ -7,7 +7,7 @@ from tempfile import TemporaryDirectory
from typing import List, Optional
import torch
from datasets import load_dataset
from datasets import Features, Sequence, Value, load_dataset
import faiss
from transformers import (
@ -82,10 +82,14 @@ def main(
# And compute the embeddings
ctx_encoder = DPRContextEncoder.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name).to(device=device)
ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name)
new_features = Features(
{"text": Value("string"), "title": Value("string"), "embeddings": Sequence(Value("float32"))}
) # optional, save as float32 instead of float64 to save space
dataset = dataset.map(
partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer),
batched=True,
batch_size=processing_args.batch_size,
features=new_features,
)
# And finally save your dataset

View File

@ -556,7 +556,9 @@ class RagModel(RagPreTrainedModel):
if encoder_outputs is None:
if has_to_retrieve:
question_enc_outputs = self.question_encoder(input_ids, attention_mask=attention_mask)
question_enc_outputs = self.question_encoder(
input_ids, attention_mask=attention_mask, return_dict=True
)
question_encoder_last_hidden_state = question_enc_outputs[0] # hidden states of question encoder
retriever_outputs = self.retriever(
@ -616,6 +618,7 @@ class RagModel(RagPreTrainedModel):
decoder_attention_mask=decoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
return_dict=True,
)
if not has_to_retrieve:

View File

@ -196,7 +196,7 @@ class HFIndexBase(Index):
self.dataset = dataset
self._index_initialized = index_initialized
self._check_dataset_format(with_index=index_initialized)
dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True)
dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True, dtype="float32")
def _check_dataset_format(self, with_index: bool):
if not isinstance(self.dataset, Dataset):