mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
updated the original RAG implementation to be compatible with latest Pytorch-Lightning (#11806)
* updated the original RAG implementation to be compatible with the latest PL version * updated the requirements.txt file * execute make style * code quality test * code quality * conflix resolved in requirement.txt * code quality * changed the MyDDP class name to CustomDDP
This commit is contained in:
parent
70f88eeccc
commit
e33085d648
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
@ -34,9 +33,10 @@ def get_checkpoint_callback(output_dir, metric):
|
||||
)
|
||||
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
filepath=os.path.join(output_dir, exp),
|
||||
dirpath=output_dir,
|
||||
filename=exp,
|
||||
monitor=f"val_{metric}",
|
||||
mode="max",
|
||||
mode="min",
|
||||
save_top_k=3,
|
||||
period=1, # maybe save a checkpoint every time val is run, not just end of epoch.
|
||||
)
|
||||
|
@ -3,7 +3,6 @@ import random
|
||||
|
||||
import ray
|
||||
from transformers import RagConfig, RagRetriever, RagTokenizer
|
||||
from transformers.file_utils import requires_datasets, requires_faiss
|
||||
from transformers.models.rag.retrieval_rag import CustomHFIndex
|
||||
|
||||
|
||||
@ -134,8 +133,6 @@ class RagRayDistributedRetriever(RagRetriever):
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, retriever_name_or_path, actor_handles, indexed_dataset=None, **kwargs):
|
||||
requires_datasets(cls)
|
||||
requires_faiss(cls)
|
||||
config = kwargs.pop("config", None) or RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
|
||||
rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
|
||||
question_encoder_tokenizer = rag_tokenizer.question_encoder
|
||||
|
@ -13,8 +13,8 @@ import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from pytorch_lightning.accelerators.ddp_accelerator import DDPAccelerator
|
||||
from pytorch_lightning.cluster_environments import TorchElasticEnvironment
|
||||
import torch.distributed as torch_distrib
|
||||
from pytorch_lightning.plugins.training_type import DDPPlugin
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from transformers import (
|
||||
@ -36,7 +36,6 @@ if is_ray_available():
|
||||
import ray
|
||||
from distributed_ray_retriever import RagRayDistributedRetriever, RayRetriever
|
||||
|
||||
|
||||
from callbacks_rag import ( # noqa: E402 # isort:skipq
|
||||
get_checkpoint_callback,
|
||||
get_early_stopping_callback,
|
||||
@ -74,27 +73,19 @@ class AttrDict(dict):
|
||||
self.__dict__ = self
|
||||
|
||||
|
||||
# In PTL >v1.0, `init_ddp_connection` method in the `LightningModule`
|
||||
# is no longer used, and is moved into DDPAccelerator instead.
|
||||
# We override DDPAccelerator to add our custom logic for initializing the
|
||||
# retriever.
|
||||
# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/backends/test_accelerator_connector.py
|
||||
class CustomDDP(DDPPlugin):
|
||||
def init_ddp_connection(self, global_rank=None, world_size=None) -> None:
|
||||
module = self.model
|
||||
global_rank = global_rank if global_rank is not None else self.cluster_environment.global_rank()
|
||||
world_size = world_size if world_size is not None else self.cluster_environment.world_size()
|
||||
os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
|
||||
os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
|
||||
if not torch.distributed.is_initialized():
|
||||
logger.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
|
||||
torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)
|
||||
|
||||
|
||||
class CustomAccel(DDPAccelerator):
|
||||
def __init__(self, trainer=None, **kwargs):
|
||||
# Trainer is set later.
|
||||
super().__init__(trainer, **kwargs)
|
||||
|
||||
def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True):
|
||||
logger.info("Custom init_ddp_connection.")
|
||||
module = self.trainer.model
|
||||
if self.cluster_environment is None:
|
||||
self.cluster_environment = TorchElasticEnvironment()
|
||||
self.distributed_port = module.hparams.distributed_port
|
||||
os.environ["MASTER_PORT"] = str(self.distributed_port)
|
||||
super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
|
||||
if module.is_rag_model:
|
||||
self.distributed_port = module.hparams.distributed_port
|
||||
if module.distributed_retriever == "pytorch":
|
||||
module.model.rag.retriever.init_retrieval(self.distributed_port)
|
||||
elif module.distributed_retriever == "ray" and global_rank == 0:
|
||||
@ -594,7 +585,7 @@ def main(args=None, model=None) -> GenerativeQAModule:
|
||||
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
|
||||
early_stopping_callback=es_callback,
|
||||
logger=training_logger,
|
||||
accelerator=CustomAccel() if args.gpus > 1 else None,
|
||||
custom_ddp_plugin=CustomDDP() if args.gpus > 1 else None,
|
||||
profiler=pl.profiler.AdvancedProfiler() if args.profile else None,
|
||||
)
|
||||
pickle_save(model.hparams, model.output_dir / "hparams.pkl")
|
||||
|
@ -167,8 +167,8 @@ class BaseTransformer(pl.LightningModule):
|
||||
effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
|
||||
return (self.dataset_size / effective_batch_size) * self.hparams.max_epochs
|
||||
|
||||
def setup(self, mode):
|
||||
if mode == "test":
|
||||
def setup(self, stage):
|
||||
if stage == "test":
|
||||
self.dataset_size = len(self.test_dataloader().dataset)
|
||||
else:
|
||||
self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
|
||||
@ -341,6 +341,7 @@ def generic_train(
|
||||
args: argparse.Namespace,
|
||||
early_stopping_callback=None,
|
||||
logger=True, # can pass WandbLogger() here
|
||||
custom_ddp_plugin=None,
|
||||
extra_callbacks=[],
|
||||
checkpoint_callback=None,
|
||||
logging_callback=None,
|
||||
@ -370,18 +371,17 @@ def generic_train(
|
||||
train_params["amp_level"] = args.fp16_opt_level
|
||||
|
||||
if args.gpus > 1:
|
||||
train_params["distributed_backend"] = "ddp"
|
||||
train_params["accelerator"] = "ddp"
|
||||
|
||||
train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
|
||||
train_params["accelerator"] = extra_train_kwargs.get("accelerator", None)
|
||||
train_params["profiler"] = extra_train_kwargs.get("profiler", None)
|
||||
train_params["profiler"] = None # extra_train_kwargs.get("profiler", None) #get unwanted logs
|
||||
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
weights_summary=None,
|
||||
callbacks=[logging_callback] + extra_callbacks,
|
||||
callbacks=[logging_callback] + extra_callbacks + [checkpoint_callback],
|
||||
plugins=[custom_ddp_plugin],
|
||||
logger=logger,
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
**train_params,
|
||||
)
|
||||
|
||||
|
@ -3,5 +3,5 @@ datasets >= 1.0.1
|
||||
psutil >= 5.7.0
|
||||
torch >= 1.4.0
|
||||
transformers
|
||||
pytorch-lightning==1.0.4
|
||||
GitPython
|
||||
pytorch-lightning==1.3.1
|
||||
GitPython
|
Loading…
Reference in New Issue
Block a user