diff --git a/examples/research_projects/rag/callbacks_rag.py b/examples/research_projects/rag/callbacks_rag.py
index ce30db88cdd..3d8425e612e 100644
--- a/examples/research_projects/rag/callbacks_rag.py
+++ b/examples/research_projects/rag/callbacks_rag.py
@@ -1,5 +1,4 @@
 import logging
-import os
 from pathlib import Path
 
 import numpy as np
@@ -34,9 +33,10 @@ def get_checkpoint_callback(output_dir, metric):
     )
 
     checkpoint_callback = ModelCheckpoint(
-        filepath=os.path.join(output_dir, exp),
+        dirpath=output_dir,
+        filename=exp,
         monitor=f"val_{metric}",
-        mode="max",
+        mode="min",
         save_top_k=3,
         period=1,  # maybe save a checkpoint every time val is run, not just end of epoch.
     )
diff --git a/examples/research_projects/rag/distributed_ray_retriever.py b/examples/research_projects/rag/distributed_ray_retriever.py
index 4ee4f963f9a..9ffc1b1e384 100644
--- a/examples/research_projects/rag/distributed_ray_retriever.py
+++ b/examples/research_projects/rag/distributed_ray_retriever.py
@@ -3,7 +3,6 @@ import random
 
 import ray
 from transformers import RagConfig, RagRetriever, RagTokenizer
-from transformers.file_utils import requires_datasets, requires_faiss
 from transformers.models.rag.retrieval_rag import CustomHFIndex
 
 
@@ -134,8 +133,6 @@ class RagRayDistributedRetriever(RagRetriever):
 
     @classmethod
     def from_pretrained(cls, retriever_name_or_path, actor_handles, indexed_dataset=None, **kwargs):
-        requires_datasets(cls)
-        requires_faiss(cls)
         config = kwargs.pop("config", None) or RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
         rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
         question_encoder_tokenizer = rag_tokenizer.question_encoder
diff --git a/examples/research_projects/rag/finetune_rag.py b/examples/research_projects/rag/finetune_rag.py
index 1a1f6772ecb..e048153c988 100644
--- a/examples/research_projects/rag/finetune_rag.py
+++ b/examples/research_projects/rag/finetune_rag.py
@@ -13,8 +13,8 @@ import numpy as np
 import pytorch_lightning as pl
 import torch
 import torch.distributed as dist
-from pytorch_lightning.accelerators.ddp_accelerator import DDPAccelerator
-from pytorch_lightning.cluster_environments import TorchElasticEnvironment
+import torch.distributed as torch_distrib
+from pytorch_lightning.plugins.training_type import DDPPlugin
 from torch.utils.data import DataLoader
 
 from transformers import (
@@ -36,7 +36,6 @@ if is_ray_available():
     import ray
     from distributed_ray_retriever import RagRayDistributedRetriever, RayRetriever
 
-
 from callbacks_rag import (  # noqa: E402 # isort:skipq
     get_checkpoint_callback,
     get_early_stopping_callback,
@@ -74,27 +73,19 @@ class AttrDict(dict):
         self.__dict__ = self
 
 
-# In PTL >v1.0, `init_ddp_connection` method in the `LightningModule`
-# is no longer used, and is moved into DDPAccelerator instead.
-# We override DDPAccelerator to add our custom logic for initializing the
-# retriever.
-# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/backends/test_accelerator_connector.py
+class CustomDDP(DDPPlugin):
+    def init_ddp_connection(self, global_rank=None, world_size=None) -> None:
+        module = self.model
+        global_rank = global_rank if global_rank is not None else self.cluster_environment.global_rank()
+        world_size = world_size if world_size is not None else self.cluster_environment.world_size()
+        os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
+        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
+        if not torch.distributed.is_initialized():
+            logger.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
+            torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)
-
-class CustomAccel(DDPAccelerator):
-    def __init__(self, trainer=None, **kwargs):
-        # Trainer is set later.
-        super().__init__(trainer, **kwargs)
-
-    def init_ddp_connection(self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True):
-        logger.info("Custom init_ddp_connection.")
-        module = self.trainer.model
-        if self.cluster_environment is None:
-            self.cluster_environment = TorchElasticEnvironment()
-        self.distributed_port = module.hparams.distributed_port
-        os.environ["MASTER_PORT"] = str(self.distributed_port)
-        super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
 
         if module.is_rag_model:
+            self.distributed_port = module.hparams.distributed_port
             if module.distributed_retriever == "pytorch":
                 module.model.rag.retriever.init_retrieval(self.distributed_port)
             elif module.distributed_retriever == "ray" and global_rank == 0:
@@ -594,7 +585,7 @@ def main(args=None, model=None) -> GenerativeQAModule:
         checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
         early_stopping_callback=es_callback,
         logger=training_logger,
-        accelerator=CustomAccel() if args.gpus > 1 else None,
+        custom_ddp_plugin=CustomDDP() if args.gpus > 1 else None,
         profiler=pl.profiler.AdvancedProfiler() if args.profile else None,
     )
     pickle_save(model.hparams, model.output_dir / "hparams.pkl")
diff --git a/examples/research_projects/rag/lightning_base.py b/examples/research_projects/rag/lightning_base.py
index a9a05fbf960..04f82eb9e16 100644
--- a/examples/research_projects/rag/lightning_base.py
+++ b/examples/research_projects/rag/lightning_base.py
@@ -167,8 +167,8 @@ class BaseTransformer(pl.LightningModule):
         effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
         return (self.dataset_size / effective_batch_size) * self.hparams.max_epochs
 
-    def setup(self, mode):
-        if mode == "test":
+    def setup(self, stage):
+        if stage == "test":
             self.dataset_size = len(self.test_dataloader().dataset)
         else:
             self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)
@@ -341,6 +341,7 @@ def generic_train(
     args: argparse.Namespace,
     early_stopping_callback=None,
     logger=True,  # can pass WandbLogger() here
+    custom_ddp_plugin=None,
     extra_callbacks=[],
     checkpoint_callback=None,
     logging_callback=None,
@@ -370,18 +371,17 @@ def generic_train(
         train_params["amp_level"] = args.fp16_opt_level
 
     if args.gpus > 1:
-        train_params["distributed_backend"] = "ddp"
+        train_params["accelerator"] = "ddp"
 
     train_params["accumulate_grad_batches"] = args.accumulate_grad_batches
-    train_params["accelerator"] = extra_train_kwargs.get("accelerator", None)
-    train_params["profiler"] = extra_train_kwargs.get("profiler", None)
+    train_params["profiler"] = None  # extra_train_kwargs.get("profiler", None)  #get unwanted logs
 
     trainer = pl.Trainer.from_argparse_args(
         args,
         weights_summary=None,
-        callbacks=[logging_callback] + extra_callbacks,
+        callbacks=[logging_callback] + extra_callbacks + [checkpoint_callback],
+        plugins=[custom_ddp_plugin],
         logger=logger,
-        checkpoint_callback=checkpoint_callback,
         **train_params,
     )
diff --git a/examples/research_projects/rag/requirements.txt b/examples/research_projects/rag/requirements.txt
index 639ebf12d27..ef065e36e1c 100644
--- a/examples/research_projects/rag/requirements.txt
+++ b/examples/research_projects/rag/requirements.txt
@@ -3,5 +3,5 @@ datasets >= 1.0.1
 psutil >= 5.7.0
 torch >= 1.4.0
 transformers
-pytorch-lightning==1.0.4
-GitPython
+pytorch-lightning==1.3.1
+GitPython
\ No newline at end of file
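
Editor's note, not part of the patch: a minimal sketch of the ModelCheckpoint change the first hunk tracks. In pytorch-lightning >= 1.1 the single `filepath` argument is split into `dirpath` and `filename`, and (as the lightning_base.py hunks show) the callback is now passed through the Trainer's `callbacks` list rather than the `checkpoint_callback` argument. The directory, filename pattern, and metric name below are illustrative assumptions, not values taken from this repo.

    from pytorch_lightning.callbacks import ModelCheckpoint

    # PL 1.0.x: ModelCheckpoint(filepath=os.path.join(output_dir, exp), ...)
    # PL 1.3.x: the directory and the filename pattern are separate arguments.
    checkpoint_callback = ModelCheckpoint(
        dirpath="output_dir",               # checkpoint directory (assumed)
        filename="checkpoint-{epoch:02d}",  # filename pattern (assumed)
        monitor="val_em",                   # monitored metric name (assumed)
        mode="min",
        save_top_k=3,
        period=1,  # still accepted in PL 1.3; deprecated in later releases
    )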
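
Editor's note, not part of the patch: a minimal sketch of how the new plugin is attached under pytorch-lightning 1.3.1, where DDP process-group setup lives in a training-type plugin (`DDPPlugin.init_ddp_connection`) rather than in a `DDPAccelerator` subclass. `CustomDDP` is the class added in finetune_rag.py above; the `gpus` value is an illustrative assumption.

    import pytorch_lightning as pl

    trainer = pl.Trainer(
        gpus=2,                           # the plugin is only attached when args.gpus > 1
        accelerator="ddp",                # replaces the removed distributed_backend="ddp"
        plugins=[CustomDDP()],            # init_ddp_connection also boots the RAG retriever
        callbacks=[checkpoint_callback],  # checkpointing now goes through callbacks
    )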