updated the original RAG implementation to be compatible with the latest PyTorch Lightning (#11806)

* updated the original RAG implementation to be compatible with the latest PL version

* updated the requirements.txt file

* execute make style

* code quality test

* code quality

* conflict resolved in requirements.txt

* code quality

* changed the MyDDP class name to CustomDDP
Shamane Siri 2021-06-09 00:42:49 +12:00 committed by GitHub
parent 70f88eeccc
commit e33085d648
5 changed files with 26 additions and 38 deletions


@@ -1,5 +1,4 @@
import logging
import os
from pathlib import Path
import numpy as np
@@ -34,9 +33,10 @@ def get_checkpoint_callback(output_dir, metric):
)
checkpoint_callback = ModelCheckpoint(
filepath=os.path.join(output_dir, exp),
dirpath=output_dir,
filename=exp,
monitor=f"val_{metric}",
mode="max",
mode="min",
save_top_k=3,
period=1, # maybe save a checkpoint every time val is run, not just end of epoch.
)
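For reference, a hedged sketch of the same callback under PyTorch Lightning 1.3: the old single `filepath` argument is split into `dirpath` and `filename`, while `monitor` and `save_top_k` carry over unchanged. The `mode` value below is illustrative only, since the hunk shows both "max" and "min", and the comment about `every_n_val_epochs` is an assumption rather than part of this commit.

    from pytorch_lightning.callbacks import ModelCheckpoint

    # Minimal sketch (PL >= 1.3): directory and file-name pattern are passed separately.
    checkpoint_callback = ModelCheckpoint(
        dirpath=output_dir,       # directory where checkpoints are written
        filename=exp,             # file-name pattern used inside that directory
        monitor=f"val_{metric}",  # metric logged by the LightningModule
        mode="max",               # maximize the monitored metric (use "min" for losses)
        save_top_k=3,             # keep the three best checkpoints
        # the old `period=1` argument is deprecated in PL 1.3 in favour of
        # `every_n_val_epochs` (assumption, not shown in the hunk)
    )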


@@ -3,7 +3,6 @@ import random
import ray
from transformers import RagConfig, RagRetriever, RagTokenizer
from transformers.file_utils import requires_datasets, requires_faiss
from transformers.models.rag.retrieval_rag import CustomHFIndex
@@ -134,8 +133,6 @@ class RagRayDistributedRetriever(RagRetriever):
@classmethod
def from_pretrained(cls, retriever_name_or_path, actor_handles, indexed_dataset=None, **kwargs):
requires_datasets(cls)
requires_faiss(cls)
config = kwargs.pop("config", None) or RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
question_encoder_tokenizer = rag_tokenizer.question_encoder
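The per-backend helpers dropped here (`requires_datasets`, `requires_faiss`) no longer exist in recent transformers releases. If an explicit dependency guard is still wanted inside `from_pretrained`, a hedged sketch of the consolidated replacement would be (assumption: transformers >= 4.6, with `cls` being the retriever class as in the hunk):

    from transformers.file_utils import requires_backends

    # One call checks both optional backends the Ray retriever needs.
    requires_backends(cls, ["datasets", "faiss"])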


@@ -13,8 +13,8 @@ import numpy as np
import pytorch_lightning as pl
import torch
import torch.distributed as dist
from pytorch_lightning.accelerators.ddp_accelerator import DDPAccelerator
from pytorch_lightning.cluster_environments import TorchElasticEnvironment
import torch.distributed as torch_distrib
from pytorch_lightning.plugins.training_type import DDPPlugin
from torch.utils.data import DataLoader
from transformers import (
@@ -36,7 +36,6 @@ if is_ray_available():
import ray
from distributed_ray_retriever import RagRayDistributedRetriever, RayRetriever
from callbacks_rag import ( # noqa: E402 # isort:skipq
get_checkpoint_callback,
get_early_stopping_callback,
@@ -74,27 +73,19 @@ class AttrDict(dict):
self.__dict__ = self
# In PTL >v1.0, `init_ddp_connection` method in the `LightningModule`
# is no longer used, and is moved into DDPAccelerator instead.
# We override DDPAccelerator to add our custom logic for initializing the
# retriever.
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/backends/test_accelerator_connector.py
class CustomDDP(DDPPlugin):
def init_ddp_connection(self, global_rank=None, world_size=None) -> None:
module = self.model
global_rank = global_rank if global_rank is not None else self.cluster_environment.global_rank()
world_size = world_size if world_size is not None else self.cluster_environment.world_size()
os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
if not torch.distributed.is_initialized():
logger.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)
class CustomAccel(DDPAccelerator):
def __init__(self, trainer=None, **kwargs):
# Trainer is set later.
super().__init__(trainer, **kwargs)
def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True):
logger.info("Custom init_ddp_connection.")
module = self.trainer.model
if self.cluster_environment is None:
self.cluster_environment = TorchElasticEnvironment()
self.distributed_port = module.hparams.distributed_port
os.environ["MASTER_PORT"] = str(self.distributed_port)
super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
if module.is_rag_model:
self.distributed_port = module.hparams.distributed_port
if module.distributed_retriever == "pytorch":
module.model.rag.retriever.init_retrieval(self.distributed_port)
elif module.distributed_retriever == "ray" and global_rank == 0:
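The comment in this hunk summarises the motivation: in PL > 1.0 `init_ddp_connection` no longer lives on the `LightningModule`, so the retriever-aware setup moves into a `DDPPlugin` subclass instead of a `DDPAccelerator` subclass. A hedged sketch of how such a plugin is attached to a PL 1.3 trainer (the `CustomDDP` class is the one defined in the hunk above; the `gpus=2` value is only an example):

    import pytorch_lightning as pl

    # Hedged sketch: with PL 1.3 the custom behaviour travels as a plugin, not an accelerator.
    trainer = pl.Trainer(
        gpus=2,
        accelerator="ddp",      # selects distributed data parallel
        plugins=[CustomDDP()],  # injects the retriever-aware init_ddp_connection
    )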
@@ -594,7 +585,7 @@ def main(args=None, model=None) -> GenerativeQAModule:
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
early_stopping_callback=es_callback,
logger=training_logger,
accelerator=CustomAccel() if args.gpus > 1 else None,
custom_ddp_plugin=CustomDDP() if args.gpus > 1 else None,
profiler=pl.profiler.AdvancedProfiler() if args.profile else None,
)
pickle_save(model.hparams, model.output_dir / "hparams.pkl")


@@ -167,8 +167,8 @@ class BaseTransformer(pl.LightningModule):
effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
return (self.dataset_size / effective_batch_size) * self.hparams.max_epochs
def setup(self, mode):
if mode == "test":
def setup(self, stage):
if stage == "test":
self.dataset_size = len(self.test_dataloader().dataset)
else:
self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
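The rename tracks the LightningModule hook itself: newer PL passes the stage name rather than the old `mode` argument. A hedged sketch of the hook signature as PL 1.3 calls it, mirroring the hunk (the dataloader helpers are the class's own methods shown above):

    from typing import Optional

    def setup(self, stage: Optional[str] = None) -> None:
        # PL 1.3 calls this hook with stage="fit" before training and stage="test" before testing.
        if stage == "test":
            self.dataset_size = len(self.test_dataloader().dataset)
        else:
            self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)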
@@ -341,6 +341,7 @@ def generic_train(
args: argparse.Namespace,
early_stopping_callback=None,
logger=True, # can pass WandbLogger() here
custom_ddp_plugin=None,
extra_callbacks=[],
checkpoint_callback=None,
logging_callback=None,
@@ -370,18 +371,17 @@ def generic_train(
train_params["amp_level"] = args.fp16_opt_level
if args.gpus > 1:
train_params["distributed_backend"] = "ddp"
train_params["accelerator"] = "ddp"
train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
train_params["accelerator"] = extra_train_kwargs.get("accelerator", None)
train_params["profiler"] = extra_train_kwargs.get("profiler", None)
train_params["profiler"] = None # extra_train_kwargs.get("profiler", None) #get unwanted logs
trainer = pl.Trainer.from_argparse_args(
args,
weights_summary=None,
callbacks=[logging_callback] + extra_callbacks,
callbacks=[logging_callback] + extra_callbacks + [checkpoint_callback],
plugins=[custom_ddp_plugin],
logger=logger,
checkpoint_callback=checkpoint_callback,
**train_params,
)
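Put together, the PL 1.3 trainer built by `generic_train` now receives the checkpoint callback through `callbacks` and the custom DDP plugin through `plugins`. A hedged sketch of the equivalent explicit construction, assuming the same names as in the hunk (`args`, `custom_ddp_plugin`, `logging_callback`, `extra_callbacks`, `checkpoint_callback`, `logger`):

    import pytorch_lightning as pl

    # Hedged sketch of what Trainer.from_argparse_args ends up building under PL 1.3.
    trainer = pl.Trainer(
        gpus=args.gpus,
        accelerator="ddp" if args.gpus > 1 else None,
        plugins=[custom_ddp_plugin] if args.gpus > 1 else None,
        callbacks=[logging_callback] + extra_callbacks + [checkpoint_callback],
        logger=logger,
        weights_summary=None,
        accumulate_grad_batches=args.accumulate_grad_batches,
    )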

requirements.txt

@@ -3,5 +3,5 @@ datasets >= 1.0.1
psutil >= 5.7.0
torch >= 1.4.0
transformers
pytorch-lightning==1.0.4
GitPython
pytorch-lightning==1.3.1
GitPython
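A hedged sanity check that an environment actually picked up the new pin (plain Python, not part of the example scripts): the plugin API used above does not exist in the old 1.0.x releases.

    import pytorch_lightning as pl

    # The example now assumes PL 1.3.x.
    assert pl.__version__.startswith("1.3"), f"expected pytorch-lightning 1.3.x, got {pl.__version__}"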