diff --git a/docs/source/en/main_classes/deepspeed.mdx b/docs/source/en/main_classes/deepspeed.mdx
index 7926ddb5c68..ba1caee1934 100644
--- a/docs/source/en/main_classes/deepspeed.mdx
+++ b/docs/source/en/main_classes/deepspeed.mdx
@@ -162,33 +162,24 @@ If after trying everything suggested you still encounter build issues, please, p
 
 ### Deployment with multiple GPUs
 
-To deploy this feature with multiple GPUs adjust the [`Trainer`] command line arguments as
-following:
-
-1. replace `python -m torch.distributed.launch` with `deepspeed`.
-2. add a new argument `--deepspeed ds_config.json`, where `ds_config.json` is the DeepSpeed configuration file as
+To deploy the DeepSpeed integration, adjust the [`Trainer`] command line arguments to include a new argument, `--deepspeed ds_config.json`, where `ds_config.json` is the DeepSpeed configuration file as
 documented [here](https://www.deepspeed.ai/docs/config-json/). The file naming is up to you.
 
-Therefore, if your original command line looked as follows:
+You can use a launcher of your choice here. For example, you can continue using the PyTorch launcher:
 
 ```bash
-python -m torch.distributed.launch --nproc_per_node=2 your_program.py
+python -m torch.distributed.run --nproc_per_node=2 your_program.py --deepspeed ds_config.json
 ```
-
-Now it should be:
+or use the launcher provided by `deepspeed`:
 
 ```bash
 deepspeed --num_gpus=2 your_program.py --deepspeed ds_config.json
 ```
 
-Unlike, `torch.distributed.launch` where you have to specify how many GPUs to use with `--nproc_per_node`, with the
-`deepspeed` launcher you don't have to use the corresponding `--num_gpus` if you want all of your GPUs used. The
+As you can see, the arguments aren't the same, but for most needs either launcher works. The
 full details on how to configure various nodes and GPUs can be found [here](https://www.deepspeed.ai/getting-started/#resource-configuration-multi-node).
 
-In fact, you can continue using `-m torch.distributed.launch` with DeepSpeed as long as you don't need to use
-`deepspeed` launcher-specific arguments. Typically if you don't need a multi-node setup you're not required to use
-the `deepspeed` launcher. But since in the DeepSpeed documentation it'll be used everywhere, for consistency we will
-use it here as well.
+When you use the `deepspeed` launcher and you want to use all available GPUs you can just omit the `--num_gpus` flag.
 
 Here is an example of running `run_translation.py` under DeepSpeed deploying all available GPUs:
 
@@ -282,6 +273,95 @@ Notes:
 
 
 
+
+
+### Deployment with multiple Nodes
+
+The information in this section isn't specific to the DeepSpeed integration and is applicable to any multi-node program. But DeepSpeed provides a `deepspeed` launcher that is easier to use than other launchers unless you are in a SLURM environment.
+
+For the duration of this section let's assume that you have 2 nodes with 8 GPUs each. You can reach the first node with `ssh hostname1` and the second node with `ssh hostname2`, and both must be able to reach each other via ssh locally without a password. Of course, you will need to replace these host (node) names with the actual host names you are working with.
+
+#### The torch.distributed.run launcher
+
+For example, to use `torch.distributed.run`, you could do:
+
+```bash
+python -m torch.distributed.run --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=hostname1 \
+--master_port=9901 your_program.py --deepspeed ds_config.json
+```
+
+You have to ssh to each node and run this same command on each one of them, adjusting `--node_rank` to match the node (see the example below). There is no rush, the launcher will wait until both nodes synchronize.
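+
+For example, under the 2-node setup assumed above, the command on the second node would look like this - the only change is the `--node_rank` value:
+
+```bash
+# on hostname2: same command as above, but with --node_rank=1
+python -m torch.distributed.run --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=hostname1 \
+--master_port=9901 your_program.py --deepspeed ds_config.json
+```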
+
+For more information please see [torchrun](https://pytorch.org/docs/stable/elastic/run.html). Incidentally, this is also the launcher that replaced `torch.distributed.launch` a few PyTorch versions back.
+
+
+#### The deepspeed launcher
+
+To use the `deepspeed` launcher instead, you first have to create a `hostfile` file:
+
+```
+hostname1 slots=8
+hostname2 slots=8
+```
+and then you can launch it as:
+
+```bash
+deepspeed --num_gpus 8 --num_nodes 2 --hostfile hostfile --master_addr hostname1 --master_port=9901 \
+your_program.py --deepspeed ds_config.json
+```
+
+Unlike the `torch.distributed.run` launcher, `deepspeed` will automatically launch this command on both nodes!
+
+For more information please see [Resource Configuration (multi-node)](https://www.deepspeed.ai/getting-started/#resource-configuration-multi-node).
+
+
+#### Launching in a SLURM environment
+
+In a SLURM environment the following approach can be used. The following is a SLURM script, `launch.slurm`, which you will need to adapt to your specific SLURM environment.
+
+```bash
+#!/bin/bash
+#SBATCH --job-name=test-nodes        # name
+#SBATCH --nodes=2                    # nodes
+#SBATCH --ntasks-per-node=1          # crucial - only 1 task per node!
+#SBATCH --cpus-per-task=10           # number of cores per task
+#SBATCH --gres=gpu:8                 # number of gpus per node
+#SBATCH --time 20:00:00              # maximum execution time (HH:MM:SS)
+#SBATCH --output=%x-%j.out           # output file name
+
+export GPUS_PER_NODE=8
+export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
+export MASTER_PORT=9901
+
+srun --jobid $SLURM_JOBID bash -c 'python -m torch.distributed.run \
+ --nproc_per_node $GPUS_PER_NODE --nnodes $SLURM_NNODES --node_rank $SLURM_PROCID \
+ --master_addr $MASTER_ADDR --master_port $MASTER_PORT \
+your_program.py --deepspeed ds_config.json'
+```
+
+All that is left is to schedule it to run:
+```bash
+sbatch launch.slurm
+```
+
+`srun` will take care of launching the program simultaneously on all nodes.
+
+
+#### Use of Non-shared filesystem
+
+By default DeepSpeed expects that a multi-node environment uses shared storage. If this is not the case and each node can only see the local filesystem, you need to adjust the config file to include a [`checkpoint` section](https://www.deepspeed.ai/docs/config-json/#checkpoint-options) with the following setting:
+
+```json
+{
+  "checkpoint": {
+    "use_node_local_storage": true
+  }
+}
+```
+
+Alternatively, you can also use the [`Trainer`]'s `--save_on_each_node` argument, and the above config will be added automatically for you.
+
+
 ### Deployment in Notebooks
 
diff --git a/src/transformers/deepspeed.py b/src/transformers/deepspeed.py
index 037c5985f88..01987fc758a 100644
--- a/src/transformers/deepspeed.py
+++ b/src/transformers/deepspeed.py
@@ -141,6 +141,11 @@ class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig):
         else:
             fp16_backend = None
 
+        if args.save_on_each_node:
+            # deepspeed uses shared storage by default. Let's override this setting if save_on_each_node == True
+            self.config["checkpoint"] = self.config.get("checkpoint", {})
+            self.config["checkpoint"]["use_node_local_storage"] = args.save_on_each_node
+
         # amp: similar to the pytorch native amp - it has a bunch of optional params but we won't set
         # any here unless the user did the work
         self.fill_match(