Fix rag finetuning + add finetuning test (#8585)
* replace init_ddp_connection for index init
* style
* add finetune test
* add test data
* move generate tensors to device
* add test on EM metric
* style
* allow multi process test
* keep gloo process group for retrieval
* add multi-gpu test
* use custom accelerator
* clean test finetune
* minor
* style
* style
* typo
* use python call instead of imported main function
* return_dict fix in modeling_rag
* use float32 in retrieval
* store as float32 as well in the custom knowledge dataset example
* style
* rename to finetune_rag
* style
* update readme
* rename utils and callbacks to utils_rag and callbacks_rag
* fix test
* Patrick's comments
* generate dummy data in the finetune test script
* remove dummy data files
* style
commit 8062fa63c5 (parent 63e91f5fde)
@@ -384,6 +384,8 @@ def generic_train(
         train_params["distributed_backend"] = "ddp"
 
     train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
+    train_params["accelerator"] = extra_train_kwargs.get("accelerator", None)
+    train_params["profiler"] = extra_train_kwargs.get("profiler", None)
 
     trainer = pl.Trainer.from_argparse_args(
         args,
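The two added lines let callers thread an accelerator and a profiler through `generic_train` into the `pl.Trainer` construction. A minimal call-site sketch, assuming the keyword-argument style of the `generic_train` call that appears later in this diff (the profiler value here is an assumption, not part of the commit):

```python
import pytorch_lightning as pl

# Sketch: extra kwargs are picked up via extra_train_kwargs and copied into
# train_params before pl.Trainer.from_argparse_args(...) runs.
trainer = generic_train(
    model,
    args,
    accelerator=CustomAccel() if args.gpus > 1 else None,
    profiler=pl.profiler.AdvancedProfiler(),  # assumption: any PL profiler is accepted
)
```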
@@ -7,8 +7,8 @@ to the retriever to extract relevant context documents. The documents are then p
 Such contextualized inputs are passed to the generator.
 
 Read more about RAG at https://arxiv.org/abs/2005.11401.
 
-# Finetuning
+# Finetuning
 
 Our finetuning logic is based on scripts from [`examples/seq2seq`](https://github.com/huggingface/transformers/tree/master/examples/seq2seq). We accept training data in the same format as specified there - we expect a directory consisting of 6 text files:
 ```bash
@@ -20,10 +20,10 @@ test.source
 test.target
 ```
 
-A sample finetuning command (run ` ./examples/rag/finetune.py --help` to list all available options):
+A sample finetuning command (run ` ./examples/rag/finetune_rag.py --help` to list all available options):
 
 ```bash
-python examples/rag/finetune.py \
+python examples/rag/finetune_rag.py \
     --data_dir $DATA_DIR \
     --output_dir $OUTPUT_DIR \
     --model_name_or_path $MODEL_NAME_OR_PATH \
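For a quick smoke test, the six data files described above can be stubbed with placeholder content. A minimal sketch, mirroring the dummy data that the new `test_finetune_rag.py` further down generates (question and answer strings are placeholders):

```python
import os

data_dir = "dummy_data"  # assumption: any writable directory works
os.makedirs(data_dir, exist_ok=True)
n_lines = {"train": 12, "val": 2, "test": 2}
for split, n in n_lines.items():
    for field, text in [("source", "What is love ?"), ("target", "life")]:
        # one line per example; source and target files are aligned by line number
        with open(os.path.join(data_dir, f"{split}.{field}"), "w") as f:
            f.write("\n".join([text] * n))
```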
@@ -45,7 +45,7 @@ python examples/rag/consolidate_rag_checkpoint.py \
     --question_encoder_name_or_path facebook/dpr-question_encoder-single-nq-base \
     --dest path/to/checkpoint
 ```
-You will then be able to pass `path/to/checkpoint` as `model_name_or_path` to the `finetune.py` script.
+You will then be able to pass `path/to/checkpoint` as `model_name_or_path` to the `finetune_rag.py` script.
 
 
 # Evaluation
@@ -130,3 +130,29 @@ python examples/rag/eval_rag.py \
     --print_predictions \
     --recalculate \ # adding this parameter will force recalculating predictions even if predictions_path already exists
 ```
+
+# Use your own knowledge source
+
+By default, RAG uses the English Wikipedia as a knowledge source, known as the `wiki_dpr` dataset.
+With `use_own_knowledge_dataset.py` you can build your own knowledge source, *e.g.* for RAG.
+
+For instance, if documents are serialized as tab-separated csv files with the columns "title" and "text", one can use `use_own_knowledge_dataset.py` as follows:
+```bash
+python examples/rag/use_own_knowledge_dataset.py \
+    --csv_path path/to/my_csv \
+    --output_dir path/to/my_knowledge_dataset \
+```
+
+The created outputs in `path/to/my_knowledge_dataset` can then be used to finetune RAG as follows:
+```bash
+python examples/rag/finetune_rag.py \
+    --data_dir $DATA_DIR \
+    --output_dir $OUTPUT_DIR \
+    --model_name_or_path $MODEL_NAME_OR_PATH \
+    --model_type rag_sequence \
+    --fp16 \
+    --gpus 8 \
+    --index_name custom \
+    --passages_path path/to/data/my_knowledge_dataset \
+    --index_path path/to/my_knowledge_dataset_hnsw_index.faiss
+```
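The `--index_path` argument above expects a serialized FAISS index built over the dataset's `embeddings` column. A rough sketch of what `use_own_knowledge_dataset.py` automates (768 is the DPR embedding size; the HNSW parameters below are assumptions):

```python
import faiss
from datasets import load_from_disk

dataset = load_from_disk("path/to/my_knowledge_dataset")
# Assumed HNSW configuration; the script may use different parameters.
index = faiss.IndexHNSWFlat(768, 128, faiss.METRIC_INNER_PRODUCT)
dataset.add_faiss_index("embeddings", custom_index=index)
dataset.get_index("embeddings").save("path/to/my_knowledge_dataset_hnsw_index.faiss")
```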
@@ -8,7 +8,7 @@ import torch
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.utilities import rank_zero_only
 
-from utils import save_json
+from utils_rag import save_json
 
 
 def count_trainable_parameters(model):
@@ -38,7 +38,7 @@ def get_checkpoint_callback(output_dir, metric):
         monitor=f"val_{metric}",
         mode="max",
         save_top_k=3,
-        period=0,  # maybe save a checkpoint every time val is run, not just end of epoch.
+        period=1,  # maybe save a checkpoint every time val is run, not just end of epoch.
     )
     return checkpoint_callback
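For context, a hypothetical use of this factory: the returned `ModelCheckpoint` tracks `val_{metric}` and, with `period=1`, saves at most once per epoch. Names below are illustrative, not from the diff:

```python
# Hypothetical wiring: keep the top-3 checkpoints by validation EM.
checkpoint_callback = get_checkpoint_callback(output_dir="outputs", metric="em")
trainer = pl.Trainer(checkpoint_callback=checkpoint_callback, max_epochs=3)
```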
@@ -40,7 +40,6 @@ class RagPyTorchDistributedRetriever(RagRetriever):
             generator_tokenizer=generator_tokenizer,
             index=index,
         )
-
         self.process_group = None
 
     def init_retrieval(self, distributed_port: int):
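The commit message's "keep gloo process group for retrieval" refers to `init_retrieval`, whose body is outside this hunk. A sketch of the idea, assuming the standard `torch.distributed` API rather than the literal implementation:

```python
import torch.distributed as dist

class RagPyTorchDistributedRetrieverSketch:
    """Sketch only; not the actual implementation from this commit."""

    def init_retrieval(self, distributed_port: int):
        # Assumption: retrieved documents are gathered/scattered as CPU tensors,
        # so a dedicated gloo group is kept even when training itself uses NCCL.
        self.process_group = dist.new_group(ranks=None, backend="gloo")
```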
@@ -1,12 +1,10 @@
 """Finetuning script for RAG models. Adapted from examples.seq2seq.finetune.py"""
 
 import argparse
 import glob
 import logging
 import os
 import sys
 import time
-import warnings
 from collections import defaultdict
 from pathlib import Path
 from typing import Any, Dict, List, Tuple
@@ -15,29 +13,31 @@ import numpy as np
 import pytorch_lightning as pl
 import torch
 import torch.distributed as dist
+from pytorch_lightning.accelerators.ddp_accelerator import DDPAccelerator
+from pytorch_lightning.cluster_environments import TorchElasticEnvironment
 from torch.utils.data import DataLoader
 
 from transformers import (
     AutoConfig,
     AutoTokenizer,
     BartForConditionalGeneration,
     BatchEncoding,
     RagConfig,
     RagSequenceForGeneration,
     RagTokenForGeneration,
     RagTokenizer,
     T5ForConditionalGeneration,
     get_linear_schedule_with_warmup,
 )
 from transformers import logging as transformers_logging
 
 
-from callbacks import (  # noqa: E402 # isort:skip
+from callbacks_rag import (  # noqa: E402 # isort:skip
     get_checkpoint_callback,
     get_early_stopping_callback,
     Seq2SeqLoggingCallback,
 )
 from distributed_retriever import RagPyTorchDistributedRetriever  # noqa: E402 # isort:skip
-from utils import (  # noqa: E402 # isort:skip
+from utils_rag import (  # noqa: E402 # isort:skip
     calculate_exact_match,
     flatten_list,
     get_git_info,
@@ -67,6 +67,30 @@ class AttrDict(dict):
         self.__dict__ = self
 
 
+# In PTL >v1.0, the `init_ddp_connection` method in the `LightningModule`
+# is no longer used; it has moved into the DDPAccelerator instead.
+# We override DDPAccelerator to add our custom logic for initializing the
+# retriever.
+# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/backends/test_accelerator_connector.py
+
+
+class CustomAccel(DDPAccelerator):
+    def __init__(self, trainer=None, **kwargs):
+        # Trainer is set later.
+        super().__init__(trainer, **kwargs)
+
+    def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True):
+        logger.info("Custom init_ddp_connection.")
+        module = self.trainer.model
+        if self.cluster_environment is None:
+            self.cluster_environment = TorchElasticEnvironment()
+        self.distributed_port = module.hparams.distributed_port
+        os.environ["MASTER_PORT"] = str(self.distributed_port)
+        super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
+        if module.is_rag_model:
+            module.model.rag.retriever.init_retrieval(self.distributed_port)
+
+
 class GenerativeQAModule(BaseTransformer):
     mode = "generative_qa"
     loss_names = ["loss"]
@@ -91,23 +115,24 @@ class GenerativeQAModule(BaseTransformer):
             config = config_class.from_pretrained(hparams.model_name_or_path)
 
         # set retriever parameters
-        config.index_name = args.index_name or config.index_name
-        config.passages_path = args.passages_path or config.passages_path
-        config.index_path = args.index_path or config.index_path
+        config.index_name = hparams.index_name or config.index_name
+        config.passages_path = hparams.passages_path or config.passages_path
+        config.index_path = hparams.index_path or config.index_path
+        config.use_dummy_dataset = hparams.use_dummy_dataset
 
         # set extra_model_params for generator configs and load_model
         extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "attention_dropout", "dropout")
         if self.is_rag_model:
-            if args.prefix is not None:
-                config.generator.prefix = args.prefix
+            if hparams.prefix is not None:
+                config.generator.prefix = hparams.prefix
             config.label_smoothing = hparams.label_smoothing
             hparams, config.generator = set_extra_model_params(extra_model_params, hparams, config.generator)
             retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path, config=config)
             model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config, retriever=retriever)
             prefix = config.question_encoder.prefix
         else:
-            if args.prefix is not None:
-                config.prefix = args.prefix
+            if hparams.prefix is not None:
+                config.prefix = hparams.prefix
             hparams, config = set_extra_model_params(extra_model_params, hparams, config)
             model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config)
             prefix = config.prefix
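Switching from `args` to `hparams` means the module now reads every retriever override from the namespace it was constructed with, rather than from an `args` name that only exists when the script is run directly. A hypothetical call site enabled by this plumbing (checkpoint name and paths are illustrative):

```python
from transformers import RagConfig

config = RagConfig.from_pretrained("facebook/rag-sequence-base")
config.index_name = "custom"                                         # --index_name
config.passages_path = "path/to/my_knowledge_dataset"                # --passages_path
config.index_path = "path/to/my_knowledge_dataset_hnsw_index.faiss"  # --index_path
retriever = RagPyTorchDistributedRetriever.from_pretrained(
    "facebook/rag-sequence-base", config=config
)
```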
@@ -152,11 +177,9 @@ class GenerativeQAModule(BaseTransformer):
         self.num_workers = hparams.num_workers
         self.distributed_port = self.hparams.distributed_port
 
-    def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True):
-        logger.info("Custom init_ddp_connection.")
-        os.environ["MASTER_PORT"] = str(self.distributed_port)
-        super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
-        if self.is_rag_model:
-            self.model.retriever.init_retrieval(self.distributed_port)
+        # For single GPU training, init_ddp_connection is not called.
+        # So we need to initialize the retrievers here.
+        if hparams.gpus <= 1:
+            self.model.retriever.init_retrieval(self.distributed_port)
 
     def forward(self, input_ids, **kwargs):
@@ -270,6 +293,7 @@ class GenerativeQAModule(BaseTransformer):
 
     def _generative_step(self, batch: dict) -> dict:
         start_time = time.time()
+        batch = BatchEncoding(batch).to(device=self.model.device)
        generated_ids = self.model.generate(
             batch["input_ids"],
             attention_mask=batch["attention_mask"],
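The added line implements the commit message's "move generate tensors to device": wrapping the raw dict in a `BatchEncoding` provides a `.to(...)` that relocates every tensor in the batch at once. A toy illustration:

```python
import torch
from transformers import BatchEncoding

batch = BatchEncoding(
    {"input_ids": torch.tensor([[0, 1]]), "attention_mask": torch.tensor([[1, 1]])}
)
batch = batch.to(device="cpu")  # in the module this is self.model.device
```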
@@ -322,17 +346,6 @@ class GenerativeQAModule(BaseTransformer):
 
     def train_dataloader(self) -> DataLoader:
         dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
-        t_total = (
-            (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
-            // self.hparams.accumulate_grad_batches
-            * float(self.hparams.max_epochs)
-        )
-        scheduler = get_linear_schedule_with_warmup(
-            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
-        )
-        if max(scheduler.get_last_lr()) > 0:
-            warnings.warn("All learning rates are 0")
-        self.lr_scheduler = scheduler
         return dataloader
 
     def val_dataloader(self) -> DataLoader:
@@ -429,10 +442,24 @@ class GenerativeQAModule(BaseTransformer):
             default=None,
             help="Path to the faiss index for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
         )
+        parser.add_argument(
+            "--use_dummy_dataset",
+            type=bool,
+            default=False,
+            help="Whether to use the dummy version of the dataset index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
+        )
         return parser
 
 
-def main(args, model=None) -> GenerativeQAModule:
+def main(args=None, model=None) -> GenerativeQAModule:
+
+    parser = argparse.ArgumentParser()
+    parser = pl.Trainer.add_argparse_args(parser)
+    parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
+    parser = GenerativeQAModule.add_retriever_specific_args(parser)
+
+    args = args or parser.parse_args()
+
     Path(args.output_dir).mkdir(exist_ok=True)
     if model is None:
         model: GenerativeQAModule = GenerativeQAModule(args)
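One caveat with the new `--use_dummy_dataset` option: argparse's `type=bool` is just `bool(...)` applied to the raw string, so any non-empty value parses as `True`. That is why the test below passes `--use_dummy_dataset 1`. A quick demonstration:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_dummy_dataset", type=bool, default=False)
print(parser.parse_args(["--use_dummy_dataset", "1"]).use_dummy_dataset)      # True
print(parser.parse_args(["--use_dummy_dataset", "False"]).use_dummy_dataset)  # True: bool("False") is True
print(parser.parse_args([]).use_dummy_dataset)                                # False (the default)
```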
@@ -461,6 +488,7 @@ def main(args, model=None) -> GenerativeQAModule:
         if args.early_stopping_patience >= 0
         else False
     )
+
     trainer: pl.Trainer = generic_train(
         model,
         args,
|
||||
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
|
||||
early_stopping_callback=es_callback,
|
||||
logger=logger,
|
||||
accelerator=CustomAccel() if args.gpus > 1 else None,
|
||||
)
|
||||
pickle_save(model.hparams, model.output_dir / "hparams.pkl")
|
||||
|
||||
if not args.do_predict:
|
||||
return model
|
||||
|
||||
model.hparams.test_checkpoint = ""
|
||||
checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
|
||||
if checkpoints:
|
||||
model.hparams.test_checkpoint = checkpoints[-1]
|
||||
trainer.resume_from_checkpoint = checkpoints[-1] # best checkpoint
|
||||
trainer.logger.log_hyperparams(model.hparams)
|
||||
|
||||
# test() without a model tests using the best checkpoint automatically
|
||||
trainer.test()
|
||||
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
|
||||
parser = GenerativeQAModule.add_retriever_specific_args(parser)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
main()
|
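With the new `main(args=None, model=None)` signature, the script both runs from the CLI (a bare `main()` parses `sys.argv`) and can be driven programmatically; the test below instead launches it as a subprocess, per the "use python call instead of imported main function" bullet. A hypothetical programmatic invocation (the flag set here is illustrative, not complete):

```python
import sys

import finetune_rag

sys.argv = [
    "finetune_rag.py",
    "--data_dir", "data",
    "--output_dir", "out",
    "--model_name_or_path", "facebook/rag-sequence-base",
    "--model_type", "rag_sequence",
    "--do_train",
]
model = finetune_rag.main()  # parses sys.argv, trains, returns the GenerativeQAModule
```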
@@ -4,7 +4,7 @@ export PYTHONPATH="../":"${PYTHONPATH}"
 # A sample finetuning run, you need to specify data_dir, output_dir and model_name_or_path
 # run ./examples/rag/finetune.sh --help to see all the possible options
 
-python examples/rag/finetune.py \
+python examples/rag/finetune_rag.py \
     --data_dir $DATA_DIR \
     --output_dir $OUTPUT_DIR \
     --model_name_or_path $MODEL_NAME_OR_PATH \
examples/rag/test_finetune_rag.py (new file, 96 lines)
@@ -0,0 +1,96 @@
+import json
+import logging
+import os
+import sys
+from pathlib import Path
+
+import finetune_rag
+from transformers.file_utils import is_apex_available
+from transformers.testing_utils import (
+    TestCasePlus,
+    execute_subprocess_async,
+    require_torch_gpu,
+    require_torch_multi_gpu,
+)
+
+
+logging.basicConfig(level=logging.DEBUG)
+logger = logging.getLogger()
+
+
+class RagFinetuneExampleTests(TestCasePlus):
+    def _create_dummy_data(self, data_dir):
+        os.makedirs(data_dir, exist_ok=True)
+        contents = {"source": "What is love ?", "target": "life"}
+        n_lines = {"train": 12, "val": 2, "test": 2}
+        for split in ["train", "test", "val"]:
+            for field in ["source", "target"]:
+                content = "\n".join([contents[field]] * n_lines[split])
+                with open(os.path.join(data_dir, f"{split}.{field}"), "w") as f:
+                    f.write(content)
+
+    def _run_finetune(self, gpus: int):
+        stream_handler = logging.StreamHandler(sys.stdout)
+        logger.addHandler(stream_handler)
+
+        tmp_dir = self.get_auto_remove_tmp_dir()
+        output_dir = os.path.join(tmp_dir, "output")
+        data_dir = os.path.join(tmp_dir, "data")
+        self._create_dummy_data(data_dir=data_dir)
+
+        testargs = f"""
+            --data_dir {data_dir} \
+            --output_dir {output_dir} \
+            --model_name_or_path facebook/rag-sequence-base \
+            --model_type rag_sequence \
+            --do_train \
+            --do_predict \
+            --n_val -1 \
+            --val_check_interval 1.0 \
+            --train_batch_size 2 \
+            --eval_batch_size 1 \
+            --max_source_length 25 \
+            --max_target_length 25 \
+            --val_max_target_length 25 \
+            --test_max_target_length 25 \
+            --label_smoothing 0.1 \
+            --dropout 0.1 \
+            --attention_dropout 0.1 \
+            --weight_decay 0.001 \
+            --adam_epsilon 1e-08 \
+            --max_grad_norm 0.1 \
+            --lr_scheduler polynomial \
+            --learning_rate 3e-04 \
+            --num_train_epochs 1 \
+            --warmup_steps 4 \
+            --gradient_accumulation_steps 1 \
+            --distributed-port 8787 \
+            --use_dummy_dataset 1 \
+        """.split()
+
+        if gpus > 0:
+            testargs.append(f"--gpus={gpus}")
+            if is_apex_available():
+                testargs.append("--fp16")
+        else:
+            testargs.append("--gpus=0")
+            testargs.append("--distributed_backend=ddp_cpu")
+            testargs.append("--num_processes=2")
+
+        cmd = [sys.executable, str(Path(finetune_rag.__file__).resolve())] + testargs
+        execute_subprocess_async(cmd, env=self.get_env())
+
+        metrics_save_path = os.path.join(output_dir, "metrics.json")
+        with open(metrics_save_path) as f:
+            result = json.load(f)
+        return result
+
+    @require_torch_gpu
+    def test_finetune_gpu(self):
+        result = self._run_finetune(gpus=1)
+        self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)
+
+    @require_torch_multi_gpu
+    def test_finetune_multigpu(self):
+        result = self._run_finetune(gpus=2)
+        self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)
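The new tests run the full finetuning loop end-to-end in a subprocess and assert an exact-match score on the dummy data; `@require_torch_gpu` / `@require_torch_multi_gpu` skip them on machines without the required hardware. A sketch of running them directly:

```python
import pytest

pytest.main(["examples/rag/test_finetune_rag.py", "-v"])
```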
@@ -7,7 +7,7 @@ from tempfile import TemporaryDirectory
 from typing import List, Optional
 
 import torch
-from datasets import load_dataset
+from datasets import Features, Sequence, Value, load_dataset
 
 import faiss
 from transformers import (
@@ -82,10 +82,14 @@ def main(
     # And compute the embeddings
     ctx_encoder = DPRContextEncoder.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name).to(device=device)
     ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(rag_example_args.dpr_ctx_encoder_model_name)
+    new_features = Features(
+        {"text": Value("string"), "title": Value("string"), "embeddings": Sequence(Value("float32"))}
+    )  # optional, save as float32 instead of float64 to save space
     dataset = dataset.map(
         partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=ctx_tokenizer),
         batched=True,
         batch_size=processing_args.batch_size,
+        features=new_features,
     )
 
     # And finally save your dataset
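The `embed` function mapped above is defined elsewhere in the script; a sketch of what it plausibly looks like, with the float32 cast that `new_features` declares (the exact signature is an assumption):

```python
import numpy as np
import torch

def embed(documents: dict, ctx_encoder, ctx_tokenizer) -> dict:
    inputs = ctx_tokenizer(
        documents["title"], documents["text"], truncation=True, padding="longest", return_tensors="pt"
    )
    with torch.no_grad():
        embeddings = ctx_encoder(inputs["input_ids"], return_dict=True).pooler_output
    # float32 matches the Sequence(Value("float32")) feature declared above.
    return {"embeddings": embeddings.cpu().numpy().astype(np.float32)}
```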
@@ -556,7 +556,9 @@ class RagModel(RagPreTrainedModel):
         if encoder_outputs is None:
 
             if has_to_retrieve:
-                question_enc_outputs = self.question_encoder(input_ids, attention_mask=attention_mask)
+                question_enc_outputs = self.question_encoder(
+                    input_ids, attention_mask=attention_mask, return_dict=True
+                )
                 question_encoder_last_hidden_state = question_enc_outputs[0]  # hidden states of question encoder
 
                 retriever_outputs = self.retriever(
@@ -616,6 +618,7 @@ class RagModel(RagPreTrainedModel):
             decoder_attention_mask=decoder_attention_mask,
             past_key_values=past_key_values,
             use_cache=use_cache,
+            return_dict=True,
         )
 
         if not has_to_retrieve:
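These additions are the "return_dict fix in modeling_rag" from the commit message: they pin the sub-model outputs to `ModelOutput` objects regardless of the global config, so both index and attribute access behave predictably. A toy illustration with the DPR question encoder (checkpoint name taken from the README above):

```python
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
model = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
inputs = tokenizer("What is love ?", return_tensors="pt")
outputs = model(**inputs, return_dict=True)
print(outputs.pooler_output.shape)  # a ModelOutput: outputs[0] works as well
```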
@@ -196,7 +196,7 @@ class HFIndexBase(Index):
         self.dataset = dataset
         self._index_initialized = index_initialized
         self._check_dataset_format(with_index=index_initialized)
-        dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True)
+        dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True, dtype="float32")
 
     def _check_dataset_format(self, with_index: bool):
         if not isinstance(self.dataset, Dataset):
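This pairs with the "use float32 in retrieval" bullet: the numpy views served during retrieval now come back as float32, matching the question-encoder output and the float32 embeddings written by `use_own_knowledge_dataset.py`. A toy check (dataset contents are placeholders):

```python
from datasets import Dataset

ds = Dataset.from_dict({"embeddings": [[0.1, 0.2], [0.3, 0.4]], "text": ["a", "b"]})
ds.set_format("numpy", columns=["embeddings"], output_all_columns=True, dtype="float32")
print(ds[0]["embeddings"].dtype)  # float32
```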