mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00

* added rag WIP * path fix * Formatting / renaming prior to actual work * added rag WIP * path fix * Formatting / renaming prior to actual work * added rag WIP * path fix * Formatting / renaming prior to actual work * added rag WIP * Formatting / renaming prior to actual work * First commit * improve comments * Retrieval evaluation scripts * refactor to include modeling outputs + MPI retriever * Fix rag-token model + refactor * Various fixes + finetuning logic * use_bos fix * Retrieval refactor * Finetuning refactoring and cleanup * Add documentation and cleanup * Remove set_up_rag_env.sh file * Fix retrieval wit HF index * Fix import errors * Fix quality errors * Refactor as per suggestions in https://github.com/huggingface/transformers/pull/6813#issuecomment-687208867 * fix quality * Fix RAG Sequence generation * minor cleanup plus initial tests * fix test * fix tests 2 * Comments fix * post-merge fixes * Improve readme + post-rebase refactor * Extra dependencied for tests * Fix tests * Fix tests 2 * Refactor test requirements * Fix tests 3 * Post-rebase refactor * rename nlp->datasets * RAG integration tests * add tokenizer to slow integration test and allow retriever to run on cpu * add tests; fix position ids warning * change structure * change structure * add from encoder generator * save working solution * make all integration tests pass * add RagTokenizer.save/from_pretrained and RagRetriever.save/from_pretrained * don't save paths * delete unnecessary imports * pass config to AutoTokenizer.from_pretrained for Rag tokenizers * init wiki_dpr only once * hardcode legacy index and passages paths (todo: add the right urls) * finalize config * finalize retriver api and config api * LegacyIndex index download refactor * add dpr to autotokenizer * make from pretrained more flexible * fix ragfortokengeneration * small name changes in tokenizer * add labels to models * change default index name * add retrieval tests * finish token generate * align test with previous 
version and make all tests pass * add tests * finalize tests * implement thoms suggestions * add first version of test * make first tests work * make retriever platform agnostic * naming * style * add legacy index URL * docstrings + simple retrieval test for distributed * clean model api * add doc_ids to retriever's outputs * fix retrieval tests * finish model outputs * finalize model api * fix generate problem for rag * fix generate for other modles * fix some tests * save intermediate * set generate to default * big refactor generate * delete rag_api * correct pip faiss install * fix auto tokenization test * fix faiss install * fix test * move the distributed logic to examples * model page * docs * finish tests * fix dependencies * fix import in __init__ * Refactor eval_rag and finetune scripts * start docstring * add psutil to test * fix tf test * move require torch to top * fix retrieval test * align naming * finish automodel * fix repo consistency * test ragtokenizer save/load * add rag model output docs * fix ragtokenizer save/load from pretrained * fix tokenizer dir * remove torch in retrieval * fix docs * fixe finetune scripts * finish model docs * finish docs * remove auto model for now * add require torch * remove solved todos * integrate sylvains suggestions * sams comments * correct mistake on purpose * improve README * Add generation test cases * fix rag token * clean token generate * fix test * add note to test * fix attention mask * add t5 test for rag * Fix handling prefix in finetune.py * don't overwrite index_name Co-authored-by: Patrick Lewis <plewis@fb.com> Co-authored-by: Aleksandra Piktus <piktus@devfair0141.h2.fair> Co-authored-by: Aleksandra Piktus <piktus@learnfair5102.h2.fair> Co-authored-by: Aleksandra Piktus <piktus@learnfair5067.h2.fair> Co-authored-by: Your Name <you@example.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Quentin Lhoest <lhoest.q@gmail.com>
475 lines
19 KiB
Python
475 lines
19 KiB
Python
"""Finetuning script for RAG models. Adapted from examples.seq2seq.finetune.py"""
|
|
|
|
import argparse
|
|
import glob
|
|
import logging
|
|
import os
|
|
import sys
|
|
import time
|
|
import warnings
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Tuple
|
|
|
|
import numpy as np
|
|
import pytorch_lightning as pl
|
|
import torch
|
|
import torch.distributed as dist
|
|
from torch.utils.data import DataLoader
|
|
|
|
from transformers import (
|
|
AutoConfig,
|
|
AutoTokenizer,
|
|
BartForConditionalGeneration,
|
|
RagConfig,
|
|
RagSequenceForGeneration,
|
|
RagTokenForGeneration,
|
|
RagTokenizer,
|
|
T5ForConditionalGeneration,
|
|
get_linear_schedule_with_warmup,
|
|
)
|
|
from transformers import logging as transformers_logging
|
|
|
|
|
|
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # noqa: E402 # isort:skip
|
|
|
|
from examples.lightning_base import BaseTransformer, add_generic_args, generic_train # noqa: E402 # isort:skip
|
|
from examples.rag.callbacks import get_checkpoint_callback # noqa: E402 # isort:skip
|
|
from examples.rag.distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
|
from examples.rag.utils import ( # noqa: E402 # isort:skip
|
|
Seq2SeqDataset,
|
|
calculate_exact_match,
|
|
is_rag_model,
|
|
set_extra_model_params,
|
|
)
|
|
from examples.seq2seq.callbacks import Seq2SeqLoggingCallback, get_early_stopping_callback # noqa: E402 # isort:skip
|
|
from examples.seq2seq.utils import ( # noqa: E402 # isort:skip
|
|
flatten_list,
|
|
get_git_info,
|
|
lmap,
|
|
pickle_save,
|
|
save_git_info,
|
|
save_json,
|
|
)
|
|
|
|
# Script-wide logging: INFO level for the root logger and for the transformers library.
logging.basicConfig(level=logging.INFO)
transformers_logging.set_verbosity_info()

# Module-level logger used by the classes and functions below.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class AttrDict(dict):
    """A ``dict`` whose keys can also be read and written as attributes."""

    def __init__(self, *args, **kwargs):
        """Build the mapping exactly like ``dict(*args, **kwargs)``."""
        super().__init__(*args, **kwargs)
        # Use the mapping itself as the instance namespace, so ``d.key`` and
        # ``d["key"]`` always refer to the same storage.
        self.__dict__ = self
|
|
|
|
|
|
class GenerativeQAModule(BaseTransformer):
    """Lightning module that fine-tunes a RAG, BART or T5 model on a generative QA dataset.

    The concrete model class is selected from ``hparams.model_type``. Training
    optimises the loss returned by the model; validation and test steps also
    generate answers and score them with ``calculate_exact_match``.
    """

    mode = "generative_qa"
    loss_names = ["loss"]
    metric_names = ["em"]
    val_metric = "em"

    def __init__(self, hparams, **kwargs):
        """Load config, model and tokenizer (plus retriever for RAG) and set up bookkeeping.

        Args:
            hparams: Parsed hyper-parameters. A plain ``dict`` (as passed when
                loading from a PyTorch Lightning checkpoint) is wrapped in
                :class:`AttrDict` so it supports attribute access.
            **kwargs: Accepted for signature compatibility; not used here.

        Raises:
            AssertionError: If the training target length exceeds the
                validation or test target length.
        """
        # when loading from a pytorch lightning checkpoint, hparams are passed as dict
        if isinstance(hparams, dict):
            hparams = AttrDict(hparams)
        if hparams.model_type == "rag_sequence":
            self.model_class = RagSequenceForGeneration
        elif hparams.model_type == "rag_token":
            self.model_class = RagTokenForGeneration
        elif hparams.model_type == "bart":
            self.model_class = BartForConditionalGeneration
        else:
            self.model_class = T5ForConditionalGeneration
        self.is_rag_model = is_rag_model(hparams.model_type)

        config_class = RagConfig if self.is_rag_model else AutoConfig
        config = config_class.from_pretrained(hparams.model_name_or_path)

        # The prefix must come from this module's own hparams. Reading a
        # module-level ``args`` only works when the file runs as a script and
        # fails with NameError on import or checkpoint restore. ``getattr``
        # tolerates checkpoints whose hparams have no ``prefix`` entry.
        prefix_override = getattr(hparams, "prefix", None)

        # set extra_model_params for generator configs and load_model
        extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "attention_dropout", "dropout")
        if self.is_rag_model:
            if prefix_override is not None:
                config.generator.prefix = prefix_override
            config.label_smoothing = hparams.label_smoothing
            hparams, config.generator = set_extra_model_params(extra_model_params, hparams, config.generator)
            retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path)
            model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config, retriever=retriever)
            # For RAG the dataset prefix is taken from the question encoder config.
            prefix = config.question_encoder.prefix
        else:
            if prefix_override is not None:
                config.prefix = prefix_override
            hparams, config = set_extra_model_params(extra_model_params, hparams, config)
            model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config)
            prefix = config.prefix

        tokenizer = (
            RagTokenizer.from_pretrained(hparams.model_name_or_path)
            if self.is_rag_model
            else AutoTokenizer.from_pretrained(hparams.model_name_or_path)
        )

        super().__init__(hparams, config=config, tokenizer=tokenizer, model=model)

        save_git_info(self.hparams.output_dir)
        self.output_dir = Path(self.hparams.output_dir)
        self.metrics_save_path = Path(self.output_dir) / "metrics.json"
        self.hparams_save_path = Path(self.output_dir) / "hparams.pkl"
        pickle_save(self.hparams, self.hparams_save_path)
        self.step_count = 0
        self.metrics = defaultdict(list)

        self.dataset_kwargs: dict = dict(
            data_dir=self.hparams.data_dir,
            max_source_length=self.hparams.max_source_length,
            prefix=prefix or "",
        )
        n_observations_per_split = {
            "train": self.hparams.n_train,
            "val": self.hparams.n_val,
            "test": self.hparams.n_test,
        }
        # A negative count means "use all examples" and is stored as None.
        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}

        self.target_lens = {
            "train": self.hparams.max_target_length,
            "val": self.hparams.val_max_target_length,
            "test": self.hparams.test_max_target_length,
        }
        assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
        assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}"

        self.hparams.git_sha = get_git_info()["repo_sha"]
        self.num_workers = hparams.num_workers
        self.distributed_port = self.hparams.distributed_port

    def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True):
        """Set ``MASTER_PORT``, run the parent DDP setup, then initialise RAG retrieval.

        Args:
            global_rank: Forwarded unchanged to the parent implementation.
            world_size: Forwarded unchanged to the parent implementation.
            is_slurm_managing_tasks: Forwarded unchanged to the parent implementation.
        """
        logger.info("Custom init_ddp_connection.")
        os.environ["MASTER_PORT"] = str(self.distributed_port)
        super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
        if self.is_rag_model:
            self.model.retriever.init_retrieval(self.distributed_port)

    def forward(self, input_ids, **kwargs):
        """Call the wrapped model with ``input_ids`` and any extra keyword arguments."""
        return self.model(input_ids, **kwargs)

    def ids_to_clean_text(self, generated_ids: List[int]):
        """Decode token ids with the tokenizer and strip surrounding whitespace.

        Special tokens are skipped and tokenization spaces are cleaned up.

        Returns:
            The result of mapping ``str.strip`` over the decoded strings.
        """
        gen_text = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return lmap(str.strip, gen_text)

    def _step(self, batch: dict) -> Tuple:
        """Run one teacher-forced forward pass and return the loss.

        Args:
            batch: Mapping with ``"input_ids"``, ``"attention_mask"`` and
                ``"decoder_input_ids"`` entries.

        Returns:
            A one-element tuple ``(loss,)``.

        Raises:
            AssertionError: If the model is neither T5, BART nor a RAG model
                with a T5 or BART generator.
        """
        source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]

        rag_kwargs = {}
        # Pre-bind both names so the assertion below reports an unsupported
        # model instead of the code failing with UnboundLocalError.
        decoder_input_ids = None
        lm_labels = None
        if isinstance(self.model, T5ForConditionalGeneration):
            decoder_input_ids = self.model._shift_right(target_ids)
            lm_labels = target_ids
        elif isinstance(self.model, BartForConditionalGeneration):
            # Decoder sees every token but the last; labels are every token but the first.
            decoder_input_ids = target_ids[:, :-1].contiguous()
            lm_labels = target_ids[:, 1:].clone()
        else:
            assert self.is_rag_model
            generator = self.model.rag.generator
            if isinstance(generator, T5ForConditionalGeneration):
                decoder_start_token_id = generator.config.decoder_start_token_id
                # Compare the sequence dimension (dim 1) with the maximum
                # target length. Dim 0 is the batch dimension: the start-token
                # column below has one row per element of dim 0.
                if target_ids.shape[1] < self.target_lens["train"]:
                    start_column = torch.Tensor([[decoder_start_token_id]] * target_ids.shape[0]).to(target_ids)
                    decoder_input_ids = torch.cat([start_column, target_ids], dim=1)
                else:
                    decoder_input_ids = generator._shift_right(target_ids)
            elif isinstance(generator, BartForConditionalGeneration):
                decoder_input_ids = target_ids
            lm_labels = decoder_input_ids
            rag_kwargs["reduce_loss"] = True

        assert decoder_input_ids is not None

        outputs = self(
            source_ids,
            attention_mask=source_mask,
            decoder_input_ids=decoder_input_ids,
            use_cache=False,
            labels=lm_labels,
            return_dict=True,
            **rag_kwargs,
        )

        loss = outputs["loss"]
        return (loss,)

    @property
    def pad(self) -> int:
        """Not supported by this module; always raises ``NotImplementedError``."""
        raise NotImplementedError("pad not implemented")

    def training_step(self, batch, batch_idx) -> Dict:
        """Compute the training loss and a tokens-per-batch count for logging.

        Returns:
            A dict with the loss under ``"loss"`` and the logged values under ``"log"``.
        """
        loss_tensors = self._step(batch)

        logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        # tokens per batch
        # A RagTokenizer wraps separate question-encoder and generator
        # tokenizers, so source and target may use different pad ids.
        tgt_pad_token_id = (
            self.tokenizer.generator.pad_token_id
            if isinstance(self.tokenizer, RagTokenizer)
            else self.tokenizer.pad_token_id
        )
        src_pad_token_id = (
            self.tokenizer.question_encoder.pad_token_id
            if isinstance(self.tokenizer, RagTokenizer)
            else self.tokenizer.pad_token_id
        )
        logs["tpb"] = (
            batch["input_ids"].ne(src_pad_token_id).sum() + batch["decoder_input_ids"].ne(tgt_pad_token_id).sum()
        )

        return {"loss": loss_tensors[0], "log": logs}

    def validation_step(self, batch, batch_idx) -> Dict:
        """Generate and score one validation batch (see ``_generative_step``)."""
        return self._generative_step(batch)

    def validation_epoch_end(self, outputs, prefix="val") -> Dict:
        """Average per-batch results, persist them and return the epoch summary.

        Args:
            outputs: Per-batch dicts as returned by ``_generative_step``.
            prefix: Split name used in metric keys and as the key under which
                the metrics are saved.

        Returns:
            A dict with the logged metrics, the flattened predictions, the
            mean loss and the validation metric.
        """
        self.step_count += 1
        losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
        loss = losses["loss"]
        gen_metrics = {
            k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
        }
        metrics_tensor: torch.FloatTensor = torch.tensor(gen_metrics[self.val_metric]).type_as(loss)
        gen_metrics.update({k: v.item() for k, v in losses.items()})

        # fix for https://github.com/PyTorchLightning/pytorch-lightning/issues/2424
        if dist.is_initialized():
            # Average the validation metric over all processes.
            dist.all_reduce(metrics_tensor, op=dist.ReduceOp.SUM)
            metrics_tensor = metrics_tensor / dist.get_world_size()
            gen_metrics.update({self.val_metric: metrics_tensor.item()})

        losses.update(gen_metrics)
        metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
        metrics["step_count"] = self.step_count
        self.save_metrics(metrics, prefix)  # writes to self.metrics_save_path
        preds = flatten_list([x["preds"] for x in outputs])
        return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_{self.val_metric}": metrics_tensor}

    def save_metrics(self, latest_metrics, type_path) -> None:
        """Append ``latest_metrics`` under ``type_path`` and rewrite the metrics JSON file."""
        self.metrics[type_path].append(latest_metrics)
        save_json(self.metrics, self.metrics_save_path)

    def calc_generative_metrics(self, preds, target) -> Dict:
        """Score predictions against targets with ``calculate_exact_match``."""
        return calculate_exact_match(preds, target)

    def _generative_step(self, batch: dict) -> dict:
        """Generate answers for a batch and collect loss, timing and match metrics.

        Returns:
            A dict holding the loss, ``gen_time`` (seconds per example),
            ``gen_len`` (mean generated length), the decoded ``preds`` and
            ``target`` strings, and the generative metrics.
        """
        start_time = time.time()
        generated_ids = self.model.generate(
            batch["input_ids"],
            do_deduplication=False,  # rag specific parameter
            use_cache=True,
            min_length=1,
            max_length=self.target_lens["val"],
        )

        # Wall-clock generation time divided by the batch size.
        gen_time = (time.time() - start_time) / batch["input_ids"].shape[0]
        preds: List[str] = self.ids_to_clean_text(generated_ids)
        target: List[str] = self.ids_to_clean_text(batch["decoder_input_ids"])
        loss_tensors = self._step(batch)
        base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        gen_metrics: Dict = self.calc_generative_metrics(preds, target)

        summ_len = np.mean(lmap(len, generated_ids))
        base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **gen_metrics)
        return base_metrics

    def test_step(self, batch, batch_idx):
        """Generate and score one test batch (see ``_generative_step``)."""
        return self._generative_step(batch)

    def test_epoch_end(self, outputs):
        """Summarise test results by reusing ``validation_epoch_end`` with the ``"test"`` prefix."""
        return self.validation_epoch_end(outputs, prefix="test")

    def get_dataset(self, type_path) -> Seq2SeqDataset:
        """Build the ``Seq2SeqDataset`` for the given split (``"train"``, ``"val"`` or ``"test"``)."""
        n_obs = self.n_obs[type_path]
        max_target_length = self.target_lens[type_path]
        dataset = Seq2SeqDataset(
            self.tokenizer,
            type_path=type_path,
            n_obs=n_obs,
            max_target_length=max_target_length,
            **self.dataset_kwargs,
        )
        return dataset

    def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
        """Create a ``DataLoader`` for one split.

        Args:
            type_path: Split name passed to ``get_dataset``.
            batch_size: Number of examples per batch.
            shuffle: Whether to shuffle; forced to ``False`` when the sortish
                sampler is used for the training split.
        """
        dataset = self.get_dataset(type_path)
        sampler = None
        if self.hparams.sortish_sampler and type_path == "train":
            assert self.hparams.gpus <= 1  # TODO: assert earlier
            sampler = dataset.make_sortish_sampler(batch_size)
            # DataLoader does not accept a sampler together with shuffle=True.
            shuffle = False

        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=dataset.collate_fn,
            shuffle=shuffle,
            num_workers=self.num_workers,
            sampler=sampler,
        )
        return dataloader

    def train_dataloader(self) -> DataLoader:
        """Return the training dataloader and attach a linear warmup schedule to ``self.opt``.

        The scheduler is stored on ``self.lr_scheduler``.
        """
        dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
        # Total optimiser steps: batches per epoch across devices, divided by
        # gradient accumulation, times the number of epochs.
        t_total = (
            (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus)))
            // self.hparams.accumulate_grad_batches
            * float(self.hparams.max_epochs)
        )
        scheduler = get_linear_schedule_with_warmup(
            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
        )
        # Warn only when no parameter group has a positive learning rate.
        if max(scheduler.get_last_lr()) <= 0:
            warnings.warn("All learning rates are 0")
        self.lr_scheduler = scheduler
        return dataloader

    def val_dataloader(self) -> DataLoader:
        """Return the validation dataloader, batched with ``eval_batch_size``."""
        return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)

    def test_dataloader(self) -> DataLoader:
        """Return the test dataloader, batched with ``eval_batch_size``."""
        return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size)

    @pl.utilities.rank_zero_only
    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Save model and tokenizer in ``save_pretrained`` format next to the Lightning checkpoint.

        The files go to ``<output_dir>/checkpoint<step_count>``. The
        ``checkpoint`` dict itself is not modified.
        """
        save_path = self.output_dir.joinpath("checkpoint{}".format(self.step_count))
        self.model.config.save_step = self.step_count
        self.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        """Register the base, generic and fine-tuning specific arguments on ``parser``.

        Args:
            parser: The ``argparse`` parser to extend in place.
            root_dir: Forwarded to the base-class and generic argument helpers.

        Returns:
            The same ``parser`` object.
        """
        BaseTransformer.add_model_specific_args(parser, root_dir)
        add_generic_args(parser, root_dir)
        parser.add_argument(
            "--max_source_length",
            default=128,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--max_target_length",
            default=25,
            type=int,
            help="The maximum total target sequence length after tokenization for the training set. "
            "Sequences longer than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--val_max_target_length",
            default=25,
            type=int,
            help="The maximum total target sequence length after tokenization for the validation set. "
            "Sequences longer than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--test_max_target_length",
            default=25,
            type=int,
            help="The maximum total target sequence length after tokenization for the test set. "
            "Sequences longer than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument("--sortish_sampler", action="store_true", default=False)
        parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
        parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_val", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
        parser.add_argument(
            "--prefix",
            type=str,
            default=None,
            help="Prefix added at the beginning of each text, typically used with T5-based models.",
        )
        parser.add_argument(
            "--early_stopping_patience",
            type=int,
            default=-1,
            required=False,
            help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.",
        )
        # argparse exposes this flag as ``distributed_port`` on the namespace.
        parser.add_argument(
            "--distributed-port", type=int, default=-1, required=False, help="Port number for distributed training."
        )
        parser.add_argument(
            "--model_type",
            choices=["rag_sequence", "rag_token", "bart", "t5"],
            type=str,
            help="RAG model type: sequence or token, if none specified, the type is inferred from the model_name_or_path",
        )
        return parser
|
|
|
|
|
|
def main(args, model=None) -> GenerativeQAModule:
    """Train a :class:`GenerativeQAModule` and optionally evaluate it on the test set.

    Args:
        args: Parsed command-line namespace. Read here: ``output_dir``,
            ``data_dir``, ``logger_name``, ``fast_dev_run``,
            ``early_stopping_patience`` and ``do_predict``.
        model: Optional pre-built module; a new one is created from ``args``
            when omitted.

    Returns:
        The module that was trained.

    Raises:
        ValueError: If ``args.logger_name`` is not a supported logger name.
    """
    # parents=True lets a nested output directory be created in one go.
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    if model is None:
        model: GenerativeQAModule = GenerativeQAModule(args)

    dataset = Path(args.data_dir).name
    if (
        args.logger_name == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith("/tmp")
        or str(args.output_dir).startswith("/var")
    ):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        project = os.environ.get("WANDB_PROJECT", dataset)
        logger = WandbLogger(name=model.output_dir.name, project=project)

    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
    else:
        # Fail here with a clear message rather than later with an unbound local.
        raise ValueError(f"Unsupported logger_name: {args.logger_name!r}")

    # A negative patience disables early stopping.
    es_callback = (
        get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
        if args.early_stopping_patience >= 0
        else False
    )
    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
        early_stopping_callback=es_callback,
        logger=logger,
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")

    if not args.do_predict:
        return model

    model.hparams.test_checkpoint = ""
    # NOTE(review): without a "**" component in the pattern, recursive=True has
    # no effect, so only checkpoints directly inside output_dir are found.
    checkpoints = sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True))
    if checkpoints:
        # NOTE(review): this is the lexicographically last file name; whether
        # that is the best checkpoint depends on the checkpoint callback's naming.
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]  # best checkpoint
    trainer.logger.log_hyperparams(model.hparams)

    # test() without a model tests using the best checkpoint automatically
    trainer.test()

    return model
|
|
|
|
|
|
if __name__ == "__main__":
    # Trainer flags are registered first, then the flags declared by
    # GenerativeQAModule (which also registers the base and generic ones).
    parser = GenerativeQAModule.add_model_specific_args(
        pl.Trainer.add_argparse_args(argparse.ArgumentParser()),
        os.getcwd(),
    )

    # Bind the parsed namespace to the module-level name ``args`` before
    # handing it to main().
    args = parser.parse_args()
    main(args)
|