mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[DeepSpeed] ZeRO Stage 3 (#10753)
* synced gpus * fix * fix * need to use t5-small for quality tests * notes * complete merge * fix a disappearing std stream problem * start zero3 tests * wip * tune params * sorting out the pre-trained model loading * reworking generate loop wip * wip * style * fix tests * split the tests * refactor tests * wip * parameterized * fix * workout the resume from non-ds checkpoint pass + test * cleanup * remove no longer needed code * split getter/setter functions * complete the docs * suggestions * gpus and their compute capabilities link * Apply suggestions from code review Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * style * remove invalid paramgd * automatically configure zero3 params that rely on hidden size * make _get_resized_embeddings zero3-aware * add test exercising resize_token_embeddings() * add docstring Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
parent
acc851e1ff
commit
c6d664849b
@ -134,6 +134,8 @@ Toward Training Trillion Parameter Models, by Samyam Rajbhandari, Jeff Rasley, O
|
||||
|
||||
This provided support is new and experimental as of this writing.
|
||||
|
||||
.. _zero-install-notes:
|
||||
|
||||
Installation Notes
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
@ -156,7 +158,8 @@ please, read the following notes first.
|
||||
In these notes we give examples for what to do when ``pytorch`` has been built with CUDA ``10.2``. If your situation is
|
||||
different remember to adjust the version number to the one you are after.
|
||||
|
||||
**Possible problem #1:**
|
||||
Possible problem #1
|
||||
=======================================================================================================================
|
||||
|
||||
While, Pytorch comes with its own CUDA toolkit, to build these two projects you must have an identical version of CUDA
|
||||
installed system-wide.
|
||||
@ -176,7 +179,8 @@ If you don't have CUDA installed system-wide, install it first. You will find th
|
||||
search engine. For example, if you're on Ubuntu you may want to search for: `ubuntu cuda 10.2 install
|
||||
<https://www.google.com/search?q=ubuntu+cuda+10.2+install>`__.
|
||||
|
||||
**Possible problem #2:**
|
||||
Possible problem #2
|
||||
=======================================================================================================================
|
||||
|
||||
Another possible common problem is that you may have more than one CUDA toolkit installed system-wide. For example you
|
||||
may have:
|
||||
@ -222,7 +226,8 @@ exist. ``lib64`` sub-directory is where the various CUDA ``.so`` objects, like `
|
||||
that your system will have it named differently, but if it is adjust it to reflect your reality.
|
||||
|
||||
|
||||
**Possible problem #3:**
|
||||
Possible problem #3
|
||||
=======================================================================================================================
|
||||
|
||||
Some older CUDA versions may refuse to build with newer compilers. For example, you my have ``gcc-9`` but it wants
|
||||
``gcc-7``.
|
||||
@ -247,13 +252,6 @@ should find ``gcc-7`` (and ``g++7``) and then the build will succeed.
|
||||
|
||||
As always make sure to edit the paths in the example to match your situation.
|
||||
|
||||
**If still unsuccessful:**
|
||||
|
||||
If after addressing these you still encounter build issues, please, proceed with the GitHub Issue of `FairScale
|
||||
<https://github.com/facebookresearch/fairscale/issues>`__ and `Deepspeed
|
||||
<https://github.com/microsoft/DeepSpeed/issues>`__, depending on the project you have the problem with.
|
||||
|
||||
|
||||
FairScale
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
@ -267,20 +265,66 @@ provides support for the following features from `the ZeRO paper <https://arxiv.
|
||||
|
||||
You will need at least two GPUs to use this feature.
|
||||
|
||||
To deploy this feature:
|
||||
|
||||
1. Install the library via pypi:
|
||||
**Installation**:
|
||||
|
||||
.. code-block:: bash
|
||||
Install the library via pypi:
|
||||
|
||||
pip install fairscale
|
||||
.. code-block:: bash
|
||||
|
||||
or find more details on `the FairScale's GitHub page
|
||||
<https://github.com/facebookresearch/fairscale/#installation>`__.
|
||||
pip install fairscale
|
||||
|
||||
2. To use the first version of Sharded data-parallelism, add ``--sharded_ddp simple`` to the command line arguments,
|
||||
and make sure you have added the distributed launcher ``-m torch.distributed.launch
|
||||
--nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE`` if you haven't been using it already.
|
||||
or find more details on `the FairScale's GitHub page <https://github.com/facebookresearch/fairscale/#installation>`__.
|
||||
|
||||
If you're still struggling with the build, first make sure to read :ref:`zero-install-notes`.
|
||||
|
||||
If it's still not resolved the build issue, here are a few more ideas.
|
||||
|
||||
``fairscale`` seems to have an issue with the recently introduced by pip build isolation feature. If you have a problem
|
||||
with it, you may want to try one of:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install fairscale --no-build-isolation .
|
||||
|
||||
or:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
git clone https://github.com/facebookresearch/fairscale/
|
||||
cd fairscale
|
||||
rm -r dist build
|
||||
python setup.py bdist_wheel
|
||||
pip uninstall -y fairscale
|
||||
pip install dist/fairscale-*.whl
|
||||
|
||||
``fairscale`` also has issues with building against pytorch-nightly, so if you use it you may have to try one of:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip uninstall -y fairscale; pip install fairscale --pre \
|
||||
-f https://download.pytorch.org/whl/nightly/cu110/torch_nightly.html \
|
||||
--no-cache --no-build-isolation
|
||||
|
||||
or:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -v --disable-pip-version-check . \
|
||||
-f https://download.pytorch.org/whl/nightly/cu110/torch_nightly.html --pre
|
||||
|
||||
Of course, adjust the urls to match the cuda version you use.
|
||||
|
||||
If after trying everything suggested you still encounter build issues, please, proceed with the GitHub Issue of
|
||||
`FairScale <https://github.com/facebookresearch/fairscale/issues>`__.
|
||||
|
||||
|
||||
|
||||
**Usage**:
|
||||
|
||||
To use the first version of Sharded data-parallelism, add ``--sharded_ddp simple`` to the command line arguments, and
|
||||
make sure you have added the distributed launcher ``-m torch.distributed.launch
|
||||
--nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE`` if you haven't been using it already.
|
||||
|
||||
For example here is how you could use it for ``run_translation.py`` with 2 GPUs:
|
||||
|
||||
@ -346,19 +390,23 @@ DeepSpeed
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
`DeepSpeed <https://github.com/microsoft/DeepSpeed>`__ implements everything described in the `ZeRO paper
|
||||
<https://arxiv.org/abs/1910.02054>`__, except ZeRO's stage 3. "Parameter Partitioning (Pos+g+p)". Currently it provides
|
||||
full support for:
|
||||
<https://arxiv.org/abs/1910.02054>`__. Currently it provides full support for:
|
||||
|
||||
1. Optimizer State Partitioning (ZeRO stage 1)
|
||||
2. Add Gradient Partitioning (ZeRO stage 2)
|
||||
3. Custom fp16 handling
|
||||
4. A range of fast Cuda-extension-based Optimizers
|
||||
5. ZeRO-Offload
|
||||
2. Gradient Partitioning (ZeRO stage 2)
|
||||
3. Param Partitioning (ZeRO stage 3)
|
||||
4. Custom mixed precision training handling
|
||||
5. A range of fast CUDA-extension-based Optimizers
|
||||
6. ZeRO-Offload
|
||||
|
||||
ZeRO-Offload has its own dedicated paper: `ZeRO-Offload: Democratizing Billion-Scale Model Training
|
||||
<https://arxiv.org/abs/2101.06840>`__.
|
||||
|
||||
DeepSpeed is currently used only for training, as all the currently available features are of no use to inference.
|
||||
DeepSpeed ZeRO-2 is currently used only for training, as all the currently available features are of no use to
|
||||
inference.
|
||||
|
||||
DeepSpeed ZeRO-3 can be used for inference as well, since it allows huge models to be loaded on multiple GPUs, which
|
||||
won't be possible on a single GPU.
|
||||
|
||||
|
||||
|
||||
@ -371,7 +419,74 @@ Install the library via pypi:
|
||||
|
||||
pip install deepspeed
|
||||
|
||||
or find more details on `the DeepSpeed's GitHub page <https://github.com/microsoft/deepspeed#installation>`__.
|
||||
or find more details on `the DeepSpeed's GitHub page <https://github.com/microsoft/deepspeed#installation>`__ and
|
||||
`advanced install <https://www.deepspeed.ai/tutorials/advanced-install/>`__.
|
||||
|
||||
If you're still struggling with the build, first make sure to read :ref:`zero-install-notes`.
|
||||
|
||||
If you don't prebuild the extensions and rely on them to be built at run time and you tried all of the above solutions
|
||||
to no avail, the next thing to try is to pre-build the modules before installing them.
|
||||
|
||||
To make a local build for DeepSpeed:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
git clone https://github.com/microsoft/DeepSpeed/
|
||||
cd DeepSpeed
|
||||
rm -rf build
|
||||
TORCH_CUDA_ARCH_LIST="6.1;8.6" DS_BUILD_OPS=1 pip install . \
|
||||
--global-option="build_ext" --global-option="-j8" --no-cache -v \
|
||||
--disable-pip-version-check 2>&1 | tee build.log
|
||||
|
||||
Edit ``TORCH_CUDA_ARCH_LIST`` to insert the code for the architectures of the GPU cards you intend to use.
|
||||
|
||||
Or if you need to use the same setup on multiple machines, make a binary wheel:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
git clone https://github.com/microsoft/DeepSpeed/
|
||||
cd DeepSpeed
|
||||
rm -rf build
|
||||
TORCH_CUDA_ARCH_LIST="6.1;8.6" DS_BUILD_OPS=1 \
|
||||
python setup.py build_ext -j8 bdist_wheel
|
||||
|
||||
it will generate something like ``dist/deepspeed-0.3.13+8cd046f-cp38-cp38-linux_x86_64.whl`` which now you can install
|
||||
as ``pip install deepspeed-0.3.13+8cd046f-cp38-cp38-linux_x86_64.whl`` locally or on any other machine.
|
||||
|
||||
Again, remember to ensure to adjust ``TORCH_CUDA_ARCH_LIST`` to the target architectures.
|
||||
|
||||
You can find the complete list of NVIDIA GPUs and their corresponding **Compute Capabilities** (same as arch in this
|
||||
context) `here <https://developer.nvidia.com/cuda-gpus>`__.
|
||||
|
||||
You can check the archs pytorch was built with using:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python -c "import torch; print(torch.cuda.get_arch_list())"
|
||||
|
||||
Here is how to find out the arch for one of the installed GPU. For example, for GPU 0:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python -c "import torch; \
|
||||
print(torch.cuda.get_device_properties(torch.device('cuda')))"
|
||||
|
||||
If the output is:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
_CudaDeviceProperties(name='GeForce RTX 3090', major=8, minor=6, total_memory=24268MB, multi_processor_count=82)
|
||||
|
||||
then you know that this card's arch is ``8.6``.
|
||||
|
||||
You can also leave ``TORCH_CUDA_ARCH_LIST`` out completely and then the build program will automatically query the
|
||||
architecture of the GPUs the build is made on. This may or may not match the GPUs on the target machines, that's why
|
||||
it's best to specify the desired archs explicitly.
|
||||
|
||||
If after trying everything suggested you still encounter build issues, please, proceed with the GitHub Issue of
|
||||
`Deepspeed <https://github.com/microsoft/DeepSpeed/issues>`__,
|
||||
|
||||
|
||||
|
||||
Deployment with multiple GPUs
|
||||
=======================================================================================================================
|
||||
@ -498,7 +613,7 @@ Deployment in Notebooks
|
||||
The problem with running notebook cells as a script is that there is no normal ``deepspeed`` launcher to rely on, so
|
||||
under certain setups we have to emulate it.
|
||||
|
||||
Here is how you'd have to adjust your training code in the notebook to use DeepSpeed.
|
||||
If you're using only 1 GPU, here is how you'd have to adjust your training code in the notebook to use DeepSpeed.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@ -516,7 +631,11 @@ Here is how you'd have to adjust your training code in the notebook to use DeepS
|
||||
trainer = Trainer(...)
|
||||
trainer.train()
|
||||
|
||||
Note: `...` stands for the normal arguments that you'd pass to the functions.
|
||||
Note: ``...`` stands for the normal arguments that you'd pass to the functions.
|
||||
|
||||
If you want to use more than 1 GPU, you must use a multi-process environment for DeepSpeed to work. That is, you have
|
||||
to use the launcher for that purpose and this cannot be accomplished by emulating the distributed environment presented
|
||||
at the beginning of this section.
|
||||
|
||||
If you want to create the config file on the fly in the notebook in the current directory, you could have a dedicated
|
||||
cell with:
|
||||
@ -570,22 +689,30 @@ cell with:
|
||||
EOT
|
||||
|
||||
|
||||
That's said if the script is not in the notebook cells, you can launch ``deepspeed`` normally via shell from a cell
|
||||
with:
|
||||
If the training script is in a normal file and not in the notebook cells, you can launch ``deepspeed`` normally via
|
||||
shell from a cell. For example, to use ``run_translation.py`` you would launch it with:
|
||||
|
||||
.. code-block::
|
||||
|
||||
!deepspeed examples/seq2seq/run_translation.py ...
|
||||
!git clone https://github.com/huggingface/transformers
|
||||
!cd transformers; deepspeed examples/seq2seq/run_translation.py ...
|
||||
|
||||
or with bash magic, where you can write a multi-line code for the shell to run:
|
||||
or with ``%%bash`` magic, where you can write a multi-line code for the shell program to run:
|
||||
|
||||
.. code-block::
|
||||
|
||||
%%bash
|
||||
|
||||
cd /somewhere
|
||||
git clone https://github.com/huggingface/transformers
|
||||
cd transformers
|
||||
deepspeed examples/seq2seq/run_translation.py ...
|
||||
|
||||
In such case you don't need any of the code presented at the beginning of this section.
|
||||
|
||||
Note: ``%%bash`` magic is neat, but currently it buffers the output so you won't see the logs until the process
|
||||
completes.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@ -717,26 +844,45 @@ Of course, you will need to adjust the values in this example to your situation.
|
||||
ZeRO
|
||||
=======================================================================================================================
|
||||
|
||||
`Zero Redundancy Optimizer (ZeRO) <https://www.deepspeed.ai/tutorials/zero/>`__ is the work horse of DeepSpeed. It
|
||||
support 3 different levels (stages) of optimization. The first one is not quite interesting for scalability purposes,
|
||||
therefore this document focuses on stages 2 and 3. You will find more indepth information in the DeepSpeed
|
||||
documentation.
|
||||
|
||||
The ``zero_optimization`` section of the configuration file is the most important part (`docs
|
||||
<https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training>`__), since that is where you define
|
||||
which ZeRO stages you want to enable and how to configure them.
|
||||
which ZeRO stages you want to enable and how to configure them. You will find the explanation for each parameter in the
|
||||
DeepSpeed docs.
|
||||
|
||||
This section has to be configured exclusively via DeepSpeed configuration - the :class:`~transformers.Trainer` provides
|
||||
no equivalent command line arguments.
|
||||
|
||||
Note: currently DeepSpeed doesn't validate parameter names, so if you misspell any, it'll use the default setting for
|
||||
the parameter that got misspelled. You can watch the DeepSpeed engine start up log messages to see what values it is
|
||||
going to use.
|
||||
|
||||
|
||||
ZeRO-2 Config
|
||||
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
The following is an example configuration for ZeRO stage 2:
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
{
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"allgather_partitions": true,
|
||||
"allgather_bucket_size": 5e8,
|
||||
"overlap_comm": true,
|
||||
"reduce_scatter": true,
|
||||
"reduce_bucket_size": 5e8,
|
||||
"contiguous_gradients": true,
|
||||
"cpu_offload": true
|
||||
}
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"allgather_partitions": true,
|
||||
"allgather_bucket_size": 5e8,
|
||||
"overlap_comm": true,
|
||||
"reduce_scatter": true,
|
||||
"reduce_bucket_size": 5e8,
|
||||
"contiguous_gradients": true,
|
||||
"cpu_offload": true
|
||||
}
|
||||
}
|
||||
|
||||
Notes:
|
||||
**Performance tuning:**
|
||||
|
||||
- enabling ``cpu_offload`` should reduce GPU RAM usage (it requires ``"stage": 2``)
|
||||
- ``"overlap_comm": true`` trades off increased GPU RAM usage to lower all-reduce latency. ``overlap_comm`` uses 4.5x
|
||||
@ -748,9 +894,217 @@ Notes:
|
||||
the slower the communication, and the more GPU RAM will be available to other tasks. So if a bigger batch size is
|
||||
important, getting a slightly slower training time could be a good trade.
|
||||
|
||||
This section has to be configured exclusively via DeepSpeed configuration - the :class:`~transformers.Trainer` provides
|
||||
no equivalent command line arguments.
|
||||
|
||||
ZeRO-3 Config
|
||||
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
The following is an example configuration for ZeRO stage 3:
|
||||
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
{
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"cpu_offload": true,
|
||||
"cpu_offload_params": true,
|
||||
"cpu_offload_use_pin_memory" : true,
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"sub_group_size": 1e14,
|
||||
"reduce_bucket_size": 1e6,
|
||||
"stage3_prefetch_bucket_size": 0.94e6,
|
||||
"stage3_param_persistence_threshold": 1e4,
|
||||
"stage3_max_live_parameters": 1e9,
|
||||
"stage3_max_reuse_distance": 1e9,
|
||||
"stage3_gather_fp16_weights_on_model_save": true
|
||||
}
|
||||
}
|
||||
|
||||
Note: if you're migrating from ZeRO-2 configuration that: ``allgather_partitions``, ``allgather_bucket_size`` and
|
||||
``reduce_scatter`` configuration parameters are not used in ZeRO-3. If you keep these they will just be ignored.
|
||||
|
||||
**Performance tuning:**
|
||||
|
||||
- ``sub_group_size``: ``1e14``
|
||||
- ``reduce_bucket_size``: ``hidden_size*hidden_size``
|
||||
- ``stage3_prefetch_bucket_size``: ``0.9 * hidden_size * hidden_size``
|
||||
- ``stage3_param_persistence_threshold``: ``10 * hidden_size``
|
||||
- ``stage3_max_live_parameters``: ``1e9``
|
||||
- ``stage3_max_reuse_distance``: ``1e9``
|
||||
|
||||
If hitting OOM reduce ``stage3_max_live_parameters`` and ``stage3_max_reuse_distance``. They should have minimal impact
|
||||
on performance unless you are doing activation checkpointing. ``1e9`` would consume ~2GB. The memory is shared by
|
||||
``stage3_max_live_parameters`` and ``stage3_max_reuse_distance``, so its not additive, its just 2GB total.
|
||||
|
||||
``stage3_max_live_parameters`` is the upper limit on how many full parameters you want to keep on the GPU at any given
|
||||
time. "reuse distance" is a metric we are using to figure out when will a parameter be used again in the future, and we
|
||||
use the ``stage3_max_reuse_distance`` to decide whether to throw away the parameter or to keep it. If a parameter is
|
||||
going to be used again in near future (less than ``stage3_max_reuse_distance``) then we keep it to reduce communication
|
||||
overhead. This is super helpful when you have activation checkpointing enabled, where we do a forward recompute and
|
||||
backward passes a a single layer granularity and want to keep the parameter in the forward recompute till the backward
|
||||
|
||||
If you set ``reduce_bucket_size``, ``stage3_prefetch_bucket_size`` and ``stage3_param_persistence_threshold`` as
|
||||
recommended above, they will already be fairly small so you won't have to tune those much.
|
||||
|
||||
Since ``hidden_size`` varies from model to model, the ``Trainer`` will automatically set the needed value for the 3
|
||||
config parameters that contain that variable (using ``model.config.hidden_size``). Just set these values to ``0`` as
|
||||
shown below and the right configuration will be passed to DeepSpeed:
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
{
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"cpu_offload": true,
|
||||
"cpu_offload_params": true,
|
||||
"cpu_offload_use_pin_memory" : true,
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"sub_group_size": 1e14,
|
||||
"reduce_bucket_size": 0,
|
||||
"stage3_prefetch_bucket_size": 0,
|
||||
"stage3_param_persistence_threshold": 0,
|
||||
"stage3_max_live_parameters": 1e9,
|
||||
"stage3_max_reuse_distance": 1e9,
|
||||
"stage3_gather_fp16_weights_on_model_save": true
|
||||
}
|
||||
}
|
||||
|
||||
``stage3_gather_fp16_weights_on_model_save`` enables model fp16 weights consolidation when model gets saved. With large
|
||||
models and multiple GPUs this is an expensive operation both in terms of memory and speed. It's currently required if
|
||||
you plan to resume the training. Watch out for future updates that will remove this limitation and make things more
|
||||
flexible.
|
||||
|
||||
|
||||
ZeRO-2 vs ZeRO-3 Performance
|
||||
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
ZeRO-3 is likely to be slower than ZeRO-2 if everything else is configured the same because the former has to gather
|
||||
model weights in addition to what ZeRO-2 does. If ZeRO-2 meets your needs and you don't need to scale beyond a few GPUs
|
||||
then you may choose to stick to it. It's important to understand that ZeRO-3 enables a much higher scalability capacity
|
||||
at a cost of speed.
|
||||
|
||||
It's possible to adjust ZeRO-3 configuration to make it perform closer to ZeRO-2:
|
||||
|
||||
- set ``stage3_param_persistence_threshold`` to a very large number - larger than the largest parameter, e.g., ``6 *
|
||||
hidden_size * hidden_size``. This will keep the parameters on the GPUs.
|
||||
- turn off ``cpu_offload_params`` since ZeRO-2 doesn't have that option.
|
||||
|
||||
The performance will likely improve significantly with just ``cpu_offload_params`` turned off, even if you don't change
|
||||
``stage3_param_persistence_threshold``. Of course, these changes will impact the size of the model you can train. So
|
||||
these help you to trade scalability for speed depending on your needs.
|
||||
|
||||
|
||||
|
||||
ZeRO-2 Example
|
||||
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
Here is a full ZeRO-2 all-enabled configuration file ``ds_config_zero2.json``:
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
{
|
||||
"fp16": {
|
||||
"enabled": true,
|
||||
"loss_scale": 0,
|
||||
"loss_scale_window": 1000,
|
||||
"initial_scale_power": 16,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"allgather_partitions": true,
|
||||
"allgather_bucket_size": 2e8,
|
||||
"overlap_comm": true,
|
||||
"reduce_scatter": true,
|
||||
"reduce_bucket_size": 2e8,
|
||||
"contiguous_gradients": true,
|
||||
"cpu_offload": true
|
||||
},
|
||||
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": 3e-5,
|
||||
"betas": [0.8, 0.999],
|
||||
"eps": 1e-8,
|
||||
"weight_decay": 3e-7
|
||||
}
|
||||
},
|
||||
|
||||
"scheduler": {
|
||||
"type": "WarmupLR",
|
||||
"params": {
|
||||
"warmup_min_lr": 0,
|
||||
"warmup_max_lr": 3e-5,
|
||||
"warmup_num_steps": 500
|
||||
}
|
||||
},
|
||||
|
||||
"steps_per_print": 2000,
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
|
||||
|
||||
|
||||
ZeRO-3 Example
|
||||
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
Here is a full ZeRO-3 all-enabled configuration file ``ds_config_zero3.json``:
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
{
|
||||
"fp16": {
|
||||
"enabled": true,
|
||||
"loss_scale": 0,
|
||||
"loss_scale_window": 1000,
|
||||
"initial_scale_power": 16,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"cpu_offload": true,
|
||||
"cpu_offload_params": true,
|
||||
"cpu_offload_use_pin_memory" : true,
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"sub_group_size": 1e14,
|
||||
"reduce_bucket_size": 1e6,
|
||||
"stage3_prefetch_bucket_size": 0.94e6,
|
||||
"stage3_param_persistence_threshold": 1e4,
|
||||
"stage3_max_live_parameters": 1e9,
|
||||
"stage3_max_reuse_distance": 1e9,
|
||||
"stage3_gather_fp16_weights_on_model_save": true
|
||||
},
|
||||
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": 3e-5,
|
||||
"betas": [0.8, 0.999],
|
||||
"eps": 1e-8,
|
||||
"weight_decay": 3e-7
|
||||
}
|
||||
},
|
||||
|
||||
"scheduler": {
|
||||
"type": "WarmupLR",
|
||||
"params": {
|
||||
"warmup_min_lr": 0,
|
||||
"warmup_max_lr": 3e-5,
|
||||
"warmup_num_steps": 500
|
||||
}
|
||||
},
|
||||
|
||||
"steps_per_print": 2000,
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
|
||||
|
||||
Optimizer and Scheduler
|
||||
@ -772,7 +1126,7 @@ If ``cpu_offload`` is enabled you must use both DeepSpeed scheduler and DeepSpee
|
||||
|
||||
|
||||
Optimizer
|
||||
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
|
||||
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
|
||||
DeepSpeed's main optimizers are Adam, AdamW, OneBitAdam, and Lamb. These have been thoroughly tested with ZeRO and are
|
||||
@ -818,7 +1172,7 @@ make sure to adjust the values. e.g. if use Adam you will want ``weight_decay``
|
||||
|
||||
|
||||
Scheduler
|
||||
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
|
||||
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
DeepSpeed supports LRRangeTest, OneCycle, WarmupLR and WarmupDecayLR LR schedulers. The full documentation is `here
|
||||
<https://www.deepspeed.ai/docs/config-json/#scheduler-parameters>`__.
|
||||
@ -886,11 +1240,7 @@ and ``warmup_max_lr``, ``warmup_num_steps`` and ``total_num_steps`` will be corr
|
||||
Automatic Mixed Precision
|
||||
=======================================================================================================================
|
||||
|
||||
You can work with FP16 in one of the following ways:
|
||||
|
||||
1. Pytorch native amp, as documented `here <https://www.deepspeed.ai/docs/config-json/#fp16-training-options>`__.
|
||||
2. NVIDIA's apex, as documented `here
|
||||
<https://www.deepspeed.ai/docs/config-json/#automatic-mixed-precision-amp-training-options>`__.
|
||||
You can use automatic mixed precision with either a pytorch-like AMP way or the apex-like way:
|
||||
|
||||
If you want to use an equivalent of the Pytorch native amp, you can either configure the ``fp16`` entry in the
|
||||
configuration file, or use the following command line arguments: ``--fp16 --fp16_backend amp``.
|
||||
@ -909,6 +1259,8 @@ Here is an example of the ``fp16`` configuration:
|
||||
},
|
||||
}
|
||||
|
||||
Here is the `documentation <https://www.deepspeed.ai/docs/config-json/#fp16-training-options>`__.
|
||||
|
||||
If you want to use NVIDIA's apex instead, you can can either configure the ``amp`` entry in the configuration file, or
|
||||
use the following command line arguments: ``--fp16 --fp16_backend apex --fp16_opt_level 01``.
|
||||
|
||||
@ -923,6 +1275,9 @@ Here is an example of the ``amp`` configuration:
|
||||
}
|
||||
}
|
||||
|
||||
Here is the `documentation
|
||||
<https://www.deepspeed.ai/docs/config-json/#automatic-mixed-precision-amp-training-options>`__.
|
||||
|
||||
|
||||
Gradient Accumulation
|
||||
=======================================================================================================================
|
||||
@ -935,12 +1290,12 @@ While normally DeepSpeed gets gradient accumulation configured with:
|
||||
"gradient_accumulation_steps": 3,
|
||||
}
|
||||
|
||||
in this case, to enable gradient accumulation, pass the command line `--gradient_accumulation_steps` argument as normal
|
||||
and it will get injected into the DeepSpeed configuration.
|
||||
in this case, to enable gradient accumulation, pass the command line ``--gradient_accumulation_steps 3`` argument as
|
||||
normal and it will get injected into the DeepSpeed configuration.
|
||||
|
||||
If you try to add it directly to the configuration file, you will receive an error from the Trainer - this is because
|
||||
this setting is needed by the Trainer too, and so this approach ensures that there is a single way of setting this
|
||||
value and thus avoid potential subtle errors.
|
||||
If you try to add it directly to the configuration file, you will receive an error from the ``Trainer`` - this is
|
||||
because this setting is needed by the ``Trainer`` too, and so this approach ensures that there is a single way of
|
||||
setting this value and thus avoid potential subtle errors.
|
||||
|
||||
|
||||
|
||||
@ -963,6 +1318,175 @@ Here is an example of the ``gradient_clipping`` configuration:
|
||||
|
||||
|
||||
|
||||
Getting the model weights out
|
||||
=======================================================================================================================
|
||||
|
||||
As long as you continue training and resuming using DeepSpeed you don't need to worry about anything. DeepSpeed stores
|
||||
fp32 master weights in its custom checkpoint optimizer files, which are ``global_step*/*optim_states.pt`` (this is glob
|
||||
pattern), and are saved under the normal checkpoint.
|
||||
|
||||
**FP16 Weights:**
|
||||
|
||||
When a model is saved under ZeRO-2, you end up having the normal ``pytorch_model.bin`` file with the model weights, but
|
||||
they are only the fp16 version of the weights.
|
||||
|
||||
Under ZeRO-3, things are much more complicated, since the model weights are partitioned out over multiple GPUs,
|
||||
therefore ``"stage3_gather_fp16_weights_on_model_save": true`` is required to get the ``Trainer`` to save the fp16
|
||||
version of the weights. If this setting is ``False`` ``pytorch_model.bin`` won't be created. This is because by default
|
||||
DeepSpeed's ``state_dict`` contains a placeholder and not the real weights. If we were to save this ``state_dict`` it
|
||||
won't be possible to load it back.
|
||||
|
||||
**FP32 Weights:**
|
||||
|
||||
While the fp16 weights are fine for resuming training, if you finished finetuning your model and want to upload it to
|
||||
the `models hub <https://huggingface.co/models>`__ or pass it to someone else you most likely will want to get the fp32
|
||||
weights. This cannot be done during training since this is a process that requires a lot of memory, and therefore this
|
||||
is performed offline.
|
||||
|
||||
DeepSpeed creates a special conversion script ``zero_to_fp32.py`` which it places in the top-level of the checkpoint
|
||||
folder. Using this script you can extract the weights at any point. The script is standalone and you no longer need to
|
||||
have the configuration file or a ``Trainer`` to do the extraction.
|
||||
|
||||
Let's say your checkpoint folder looks like this:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ ls -l output_dir/checkpoint-1/
|
||||
-rw-rw-r-- 1 stas stas 1.4K Mar 27 20:42 config.json
|
||||
drwxrwxr-x 2 stas stas 4.0K Mar 25 19:52 global_step1/
|
||||
-rw-rw-r-- 1 stas stas 12 Mar 27 13:16 latest
|
||||
-rw-rw-r-- 1 stas stas 827K Mar 27 20:42 optimizer.pt
|
||||
-rw-rw-r-- 1 stas stas 231M Mar 27 20:42 pytorch_model.bin
|
||||
-rw-rw-r-- 1 stas stas 623 Mar 27 20:42 scheduler.pt
|
||||
-rw-rw-r-- 1 stas stas 1.8K Mar 27 20:42 special_tokens_map.json
|
||||
-rw-rw-r-- 1 stas stas 774K Mar 27 20:42 spiece.model
|
||||
-rw-rw-r-- 1 stas stas 1.9K Mar 27 20:42 tokenizer_config.json
|
||||
-rw-rw-r-- 1 stas stas 339 Mar 27 20:42 trainer_state.json
|
||||
-rw-rw-r-- 1 stas stas 2.3K Mar 27 20:42 training_args.bin
|
||||
-rwxrw-r-- 1 stas stas 5.5K Mar 27 13:16 zero_to_fp32.py*
|
||||
|
||||
In this example there is just one DeepSpeed checkpoint sub-folder `global_step1`. Therefore to reconstruct the fp32
|
||||
weights just run:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python zero_to_fp32.py global_step1 pytorch_model.bin
|
||||
|
||||
The script will automatically handle either ZeRO-2 or ZeRO-3 checkpoint.
|
||||
|
||||
``python zero_to_fp32.py -h`` will give you usage details.
|
||||
|
||||
If you have multiple DeepSpeed checkpoint sub-folders, pick the one you know to have the desired weights.
|
||||
|
||||
This is it. ``pytorch_model.bin`` will now contain the full fp32 model weights consolidated from multiple GPUs.
|
||||
|
||||
Note: currently the script requires 2x general RAM of the final fp32 model weights.
|
||||
|
||||
ZeRO 3 Nuances
|
||||
=======================================================================================================================
|
||||
|
||||
ZeRO 3 is quite different from ZeRO 2 because of its param sharding feature.
|
||||
|
||||
While all the efforts were made for things to just work without needing any special changes to your models, in certain
|
||||
circumstances you may find the following information to be needed.
|
||||
|
||||
|
||||
Registering External Parameters
|
||||
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
If layer A needs to access weights belonging to layer B, currently layer A needs to tell DeepSpeed about it. This is
|
||||
done with the help of ``deepspeed.zero.register_external_parameter`` that needs to be called in ``A.__init__`` and can
|
||||
be seen in the following example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class ModuleZ3(torch.nn.Module):
|
||||
def __init__(self, *args):
|
||||
super().__init__(self, *args)
|
||||
self.layer1 = SomeLayer()
|
||||
self.layer2 = OtherLayer()
|
||||
deepspeed.zero.register_external_parameter(self, self.layer1.weight)
|
||||
|
||||
def forward(self, input):
|
||||
x = self.layer1(input)
|
||||
# self.layer1.weight is needed in ModuleZ3.forward
|
||||
y = self.layer2(x, self.layer1.weight)
|
||||
return y
|
||||
|
||||
In general ``transformers`` models don't use this style of referring to other layer's weights so most likely you won't
|
||||
need to use it.
|
||||
|
||||
For full details on this method please refer to `Registering External Parameters
|
||||
<https://deepspeed.readthedocs.io/en/latest/zero3.html#registering-external-parameters>`__.
|
||||
|
||||
|
||||
|
||||
Constructing Massive Models
|
||||
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
DeepSpeed/ZeRO-3 can handle models with Trillions of parameters which may not fit onto the existing RAM. In such cases,
|
||||
but also if you want the initialization to happen much faster, initialize the model using `deepspeed.zero.Init()`
|
||||
context manager (which is also a function decorator), like so:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from transformers import T5ForConditionalGeneration, T5Config
|
||||
import deepspeed
|
||||
with deepspeed.zero.Init():
|
||||
config = T5Config.from_pretrained("t5-small")
|
||||
model = T5ForConditionalGeneration(config)
|
||||
|
||||
As you can see this gives you a randomly initialized model.
|
||||
|
||||
If you want to use a pretrained model, ``model_class.from_pretrained`` will activate this feature as long as
|
||||
``is_deepspeed_zero3_enabled()`` returns ``True``, which can be set manually via ``deepspeed_zero3_enable(True)``.
|
||||
Therefore to enable this feature here is the required sequence:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from transformers.integrations import deepspeed_zero3_enable
|
||||
deepspeed_zero3_enable(True)
|
||||
model = T5ForConditionalGeneration.from_pretrained("t5-small")
|
||||
|
||||
If you're using ``Trainer`` command line arguments which include ``--deepspeed ds_config.json`` with ZeRO-3 config
|
||||
enabled, then you can skip ``deepspeed_zero3_enable(True)`` as it will try to discover whether it'll be run under
|
||||
ZeRO-3 and ``from_pretrained`` will automatically activate this feature.
|
||||
|
||||
Note: If the fp16 weights of the model can't fit onto the memory of a single GPU this feature must be used.
|
||||
|
||||
For full details on this method and other related features please refer to `Constructing Massive Models
|
||||
<https://deepspeed.readthedocs.io/en/latest/zero3.html#constructing-massive-models>`__.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Gathering Parameters
|
||||
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
Under ZeRO-3 on multiple GPUs no single GPU has all the parameters unless it's the parameters for the currently
|
||||
executing layer. So if you need to access all parameters from all layers at once there is a specific method to do it.
|
||||
Most likely you won't need it, but if you do please refer to `Gathering Parameters
|
||||
<https://deepspeed.readthedocs.io/en/latest/zero3.html#manual-parameter-coordination>`__
|
||||
|
||||
We do however use it internally in several places, one such example is when loading pretrained model weights in
|
||||
``from_pretrained``. We load one layer at a time and immediately partition it to all participating GPUs, as for very
|
||||
large models it won't be possible to load it on one GPU and then spread it out to multiple GPUs, due to memory
|
||||
limitations.
|
||||
|
||||
Also under ZeRO-3, if you write your own code and run into a model parameter weight that looks like:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
tensor([1.], device='cuda:0', dtype=torch.float16, requires_grad=True)
|
||||
|
||||
stress on ``tensor([1.])``, or if you get an error where it says the parameter is of size ``1``, instead of some much
|
||||
larger multi-dimensional shape, this means that the parameter is partitioned and what you see is a ZeRO-3 placeholder.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Notes
|
||||
=======================================================================================================================
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
"enabled": true,
|
||||
"loss_scale": 0,
|
||||
"loss_scale_window": 1000,
|
||||
"initial_scale_power": 32,
|
||||
"initial_scale_power": 16,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
48
examples/tests/deepspeed/ds_config_zero3.json
Normal file
48
examples/tests/deepspeed/ds_config_zero3.json
Normal file
@ -0,0 +1,48 @@
|
||||
{
|
||||
"fp16": {
|
||||
"enabled": true,
|
||||
"loss_scale": 0,
|
||||
"loss_scale_window": 1000,
|
||||
"initial_scale_power": 16,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"cpu_offload": true,
|
||||
"cpu_offload_params": true,
|
||||
"cpu_offload_use_pin_memory" : true,
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"sub_group_size": 1e14,
|
||||
"reduce_bucket_size": 0,
|
||||
"stage3_prefetch_bucket_size": 0,
|
||||
"stage3_param_persistence_threshold": 0,
|
||||
"stage3_max_live_parameters": 1e9,
|
||||
"stage3_max_reuse_distance": 1e9,
|
||||
"stage3_gather_fp16_weights_on_model_save": true
|
||||
},
|
||||
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": 3e-5,
|
||||
"betas": [0.8, 0.999],
|
||||
"eps": 1e-8,
|
||||
"weight_decay": 3e-7
|
||||
}
|
||||
},
|
||||
|
||||
"scheduler": {
|
||||
"type": "WarmupLR",
|
||||
"params": {
|
||||
"warmup_min_lr": 0,
|
||||
"warmup_max_lr": 3e-5,
|
||||
"warmup_num_steps": 500
|
||||
}
|
||||
},
|
||||
|
||||
"steps_per_print": 2000,
|
||||
"wall_clock_breakdown": false
|
||||
}
|
@ -20,11 +20,12 @@ import sys
|
||||
import unittest
|
||||
from copy import deepcopy
|
||||
|
||||
from parameterized import parameterized
|
||||
from transformers import TrainingArguments
|
||||
from transformers.file_utils import WEIGHTS_NAME
|
||||
from transformers.integrations import is_deepspeed_available
|
||||
from transformers.testing_utils import (
|
||||
CaptureStd,
|
||||
CaptureLogger,
|
||||
TestCasePlus,
|
||||
execute_subprocess_async,
|
||||
get_gpu_count,
|
||||
@ -43,6 +44,7 @@ from test_trainer import TrainerIntegrationCommon, get_regression_trainer # noq
|
||||
|
||||
set_seed(42)
|
||||
MBART_TINY = "sshleifer/tiny-mbart"
|
||||
T5_SMALL = "t5-small"
|
||||
|
||||
|
||||
def load_json(path):
|
||||
@ -61,6 +63,11 @@ def require_deepspeed(test_case):
|
||||
return test_case
|
||||
|
||||
|
||||
ZERO2 = "zero2"
|
||||
ZERO3 = "zero3"
|
||||
stages = [ZERO2, ZERO3]
|
||||
|
||||
|
||||
@require_deepspeed
|
||||
@require_torch_gpu
|
||||
class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
@ -68,7 +75,19 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
|
||||
This class is for testing directly via get_regression_trainer
|
||||
|
||||
It mixes in `TrainerIntegrationCommon` which already has a lot of helper validation methods which we can re-use here.
|
||||
It mixes in `TrainerIntegrationCommon` which already has a lot of helper validation methods
|
||||
which we can re-use here.
|
||||
|
||||
Important: this class' setup can only work with a single gpu because it runs within the current
|
||||
pytest worker. For multi-gpu tests use TestDeepSpeedWithLauncher.
|
||||
|
||||
Note: if any of the tests of this class get run there will be at least one gpu occupied by them
|
||||
until this pytest worker exits. This is because the gpu memory allocated by the cuda-kernels
|
||||
won't be released until this pytest worker exits.
|
||||
|
||||
This may appear as some run-away tests if you watch `nvidia-smi` while other tests that fork new
|
||||
processes are run. So there will be one or two "stale" processes reported in `nvidia-smi`. This
|
||||
is not a bug.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
@ -81,18 +100,28 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.dist_env_1_gpu = dict(
|
||||
MASTER_ADDR="localhost", MASTER_PORT="10999", RANK="0", LOCAL_RANK="0", WORLD_SIZE="1"
|
||||
)
|
||||
self.ds_config_file = f"{self.test_file_dir_str}/ds_config.json"
|
||||
with io.open(self.ds_config_file, "r", encoding="utf-8") as f:
|
||||
self.ds_config_dict = json.load(f)
|
||||
|
||||
def test_fake_notebook_no_launcher(self):
|
||||
# this setup emulates a notebook where a launcher needs to be emulated by hand
|
||||
with CaptureStd() as cs: # noqa
|
||||
with mockenv_context(**self.dist_env_1_gpu):
|
||||
trainer = get_regression_trainer(local_rank=0, deepspeed=self.ds_config_file)
|
||||
trainer.train()
|
||||
# fixme:
|
||||
# assert "DeepSpeed info" in cs.out, "expected DeepSpeed logger output but got none"
|
||||
self.ds_config_file = {}
|
||||
self.ds_config_file[ZERO2] = f"{self.test_file_dir_str}/ds_config_zero2.json"
|
||||
self.ds_config_file[ZERO3] = f"{self.test_file_dir_str}/ds_config_zero3.json"
|
||||
|
||||
# use self.get_config_dict(stage) to use these to ensure the original is not modified
|
||||
self.ds_config_dict = {}
|
||||
with io.open(self.ds_config_file[ZERO2], "r", encoding="utf-8") as f:
|
||||
self.ds_config_dict[ZERO2] = json.load(f)
|
||||
with io.open(self.ds_config_file[ZERO3], "r", encoding="utf-8") as f:
|
||||
self.ds_config_dict[ZERO3] = json.load(f)
|
||||
|
||||
def get_config_dict(self, stage):
|
||||
""" As the tests modify the dict, always make a copy """
|
||||
config = deepcopy(self.ds_config_dict[stage])
|
||||
if stage == ZERO3:
|
||||
# This setting slows things down, so don't enable it by default unless needed by a test.
|
||||
# It's in the file as a demo for users since we want everything to work out of the box even if slower.
|
||||
config["zero_optimization"]["stage3_gather_fp16_weights_on_model_save"] = False
|
||||
return config
|
||||
|
||||
# --- These tests are enough to run on one of zero stages --- #
|
||||
|
||||
# Test various combos
|
||||
# 1. DS scheduler + DS optimizer: this is already tested by most other tests
|
||||
@ -103,12 +132,12 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
def test_hf_scheduler_hf_optimizer(self):
|
||||
a = 0
|
||||
with mockenv_context(**self.dist_env_1_gpu):
|
||||
ds_config_dict = deepcopy(self.ds_config_dict)
|
||||
del ds_config_dict["optimizer"] # force default HF Trainer optimizer
|
||||
del ds_config_dict["scheduler"] # force default HF Trainer scheduler
|
||||
ds_config_dict["zero_optimization"]["cpu_offload"] = False
|
||||
ds_config_dict["fp16"]["initial_scale_power"] = 1 # force optimizer on the first step
|
||||
trainer = get_regression_trainer(a=a, local_rank=0, deepspeed=ds_config_dict)
|
||||
ds_config_zero2_dict = self.get_config_dict(ZERO2)
|
||||
del ds_config_zero2_dict["optimizer"] # force default HF Trainer optimizer
|
||||
del ds_config_zero2_dict["scheduler"] # force default HF Trainer scheduler
|
||||
ds_config_zero2_dict["zero_optimization"]["cpu_offload"] = False
|
||||
ds_config_zero2_dict["fp16"]["initial_scale_power"] = 1 # force optimizer on the first step
|
||||
trainer = get_regression_trainer(a=a, local_rank=0, deepspeed=ds_config_zero2_dict)
|
||||
trainer.train()
|
||||
new_a = trainer.model.a.item()
|
||||
self.assertNotEqual(new_a, a)
|
||||
@ -116,11 +145,11 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
def test_ds_scheduler_hf_optimizer(self):
|
||||
a = 0
|
||||
with mockenv_context(**self.dist_env_1_gpu):
|
||||
ds_config_dict = deepcopy(self.ds_config_dict)
|
||||
del ds_config_dict["optimizer"] # force default HF Trainer optimizer
|
||||
ds_config_dict["zero_optimization"]["cpu_offload"] = False
|
||||
ds_config_dict["fp16"]["initial_scale_power"] = 1 # force optimizer on the first step
|
||||
trainer = get_regression_trainer(a=a, local_rank=0, deepspeed=ds_config_dict)
|
||||
ds_config_zero2_dict = self.get_config_dict(ZERO2)
|
||||
del ds_config_zero2_dict["optimizer"] # force default HF Trainer optimizer
|
||||
ds_config_zero2_dict["zero_optimization"]["cpu_offload"] = False
|
||||
ds_config_zero2_dict["fp16"]["initial_scale_power"] = 1 # force optimizer on the first step
|
||||
trainer = get_regression_trainer(a=a, local_rank=0, deepspeed=ds_config_zero2_dict)
|
||||
trainer.train()
|
||||
new_a = trainer.model.a.item()
|
||||
self.assertNotEqual(new_a, a)
|
||||
@ -128,11 +157,11 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
def test_hf_scheduler_ds_optimizer(self):
|
||||
# this combo is not possible at the moment
|
||||
with mockenv_context(**self.dist_env_1_gpu):
|
||||
ds_config_dict = deepcopy(self.ds_config_dict)
|
||||
del ds_config_dict["scheduler"] # force default HF Trainer scheduler
|
||||
ds_config_dict["zero_optimization"]["cpu_offload"] = False
|
||||
ds_config_dict["fp16"]["initial_scale_power"] = 1 # force optimizer on the first step
|
||||
trainer = get_regression_trainer(local_rank=0, deepspeed=ds_config_dict)
|
||||
ds_config_zero2_dict = self.get_config_dict(ZERO2)
|
||||
del ds_config_zero2_dict["scheduler"] # force default HF Trainer scheduler
|
||||
ds_config_zero2_dict["zero_optimization"]["cpu_offload"] = False
|
||||
ds_config_zero2_dict["fp16"]["initial_scale_power"] = 1 # force optimizer on the first step
|
||||
trainer = get_regression_trainer(local_rank=0, deepspeed=ds_config_zero2_dict)
|
||||
with self.assertRaises(Exception) as context:
|
||||
trainer.train()
|
||||
self.assertTrue("HF scheduler + DeepSpeed optimizer combination is not possible" in str(context.exception))
|
||||
@ -140,20 +169,38 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
def test_hf_optimizer_with_offload(self):
|
||||
# must not allow non-DS optimizer when using ZERO-offload
|
||||
with mockenv_context(**self.dist_env_1_gpu):
|
||||
ds_config_dict = deepcopy(self.ds_config_dict)
|
||||
del ds_config_dict["optimizer"] # force default HF Trainer optimizer
|
||||
ds_config_dict["zero_optimization"]["cpu_offload"] = True
|
||||
ds_config_zero2_dict = self.get_config_dict(ZERO2)
|
||||
del ds_config_zero2_dict["optimizer"] # force default HF Trainer optimizer
|
||||
ds_config_zero2_dict["zero_optimization"]["cpu_offload"] = True
|
||||
# sanity check - should the default config change
|
||||
assert (
|
||||
"cpu_offload" in ds_config_dict["zero_optimization"]
|
||||
and ds_config_dict["zero_optimization"]["cpu_offload"] is True
|
||||
"cpu_offload" in ds_config_zero2_dict["zero_optimization"]
|
||||
and ds_config_zero2_dict["zero_optimization"]["cpu_offload"] is True
|
||||
), "ensure the config is set up correctly"
|
||||
trainer = get_regression_trainer(local_rank=0, deepspeed=ds_config_dict)
|
||||
trainer = get_regression_trainer(local_rank=0, deepspeed=ds_config_zero2_dict)
|
||||
with self.assertRaises(Exception) as context:
|
||||
trainer.train()
|
||||
self.assertTrue("ZeRO Offload can only work with DeepSpeed optimizers" in str(context.exception))
|
||||
|
||||
def test_early_get_last_lr(self):
|
||||
# --- These tests need to run on both zero stages --- #
|
||||
@parameterized.expand(stages)
|
||||
def test_fake_notebook_no_launcher(self, stage):
|
||||
# this setup emulates a notebook where a launcher needs to be emulated by hand
|
||||
|
||||
# note that unittest resets sys.stdout each test, so `CaptureStd` will work here to capture
|
||||
# DeepSpeed log if this test happens to run first in this pytest worker. But it will fail if
|
||||
# it's run not as a first test as `sys.stdout` will no longer be the same. So we either have
|
||||
# to reset `logger.handlers[0].setStream(sys.stdout)` or directly capture from the logger.
|
||||
from deepspeed.utils import logger
|
||||
|
||||
with CaptureLogger(logger) as cs:
|
||||
with mockenv_context(**self.dist_env_1_gpu):
|
||||
trainer = get_regression_trainer(local_rank=0, deepspeed=self.ds_config_file[stage])
|
||||
trainer.train()
|
||||
assert "DeepSpeed info" in cs.out, "expected DeepSpeed logger output but got none"
|
||||
|
||||
@parameterized.expand(stages)
|
||||
def test_early_get_last_lr(self, stage):
|
||||
# with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may
|
||||
# not run for the first few dozen steps while loss scale is too large, and thus during
|
||||
# that time `get_last_lr` will fail if called during that warm up stage,
|
||||
@ -167,19 +214,24 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
b=b,
|
||||
local_rank=0,
|
||||
train_len=8,
|
||||
deepspeed=self.ds_config_file,
|
||||
deepspeed=self.ds_config_file[stage],
|
||||
per_device_train_batch_size=8,
|
||||
logging_steps=1,
|
||||
)
|
||||
trainer.train()
|
||||
no_grad_accum_a = trainer.model.a.item()
|
||||
post_train_a = trainer.model.a.item()
|
||||
|
||||
# XXX: for some reason the following check fails with zero3 - not a broken but a
|
||||
# different qualitative outcome - need to investigate at some point
|
||||
if stage == ZERO3:
|
||||
return
|
||||
|
||||
# it's enough that train didn't fail for this test, but we must check that
|
||||
# optimizer/scheduler didn't run (since if it did this test isn't testing the right thing)
|
||||
self.assertEqual(no_grad_accum_a, a)
|
||||
|
||||
def test_gradient_accumulation(self):
|
||||
self.assertEqual(post_train_a, a)
|
||||
|
||||
@parameterized.expand(stages)
|
||||
def test_gradient_accumulation(self, stage):
|
||||
# this test measures that we get identical weights and similar loss with:
|
||||
# 1. per_device_train_batch_size=8, gradient_accumulation_steps=1
|
||||
# 2. per_device_train_batch_size=4, gradient_accumulation_steps=2
|
||||
@ -201,7 +253,7 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
b=b,
|
||||
local_rank=0,
|
||||
train_len=train_len,
|
||||
deepspeed=self.ds_config_file,
|
||||
deepspeed=self.ds_config_file[stage],
|
||||
per_device_train_batch_size=8,
|
||||
gradient_accumulation_steps=1,
|
||||
)
|
||||
@ -218,7 +270,7 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
b=b,
|
||||
local_rank=0,
|
||||
train_len=train_len,
|
||||
deepspeed=self.ds_config_file,
|
||||
deepspeed=self.ds_config_file[stage],
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=2,
|
||||
)
|
||||
@ -235,34 +287,55 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
# see the note above how to get identical loss on a small bs
|
||||
self.assertAlmostEqual(no_grad_accum_loss, yes_grad_accum_loss, places=5)
|
||||
|
||||
def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, is_pretrained=True):
|
||||
def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage):
|
||||
# adapted from TrainerIntegrationCommon.check_saved_checkpoints
|
||||
|
||||
file_list = [WEIGHTS_NAME, "training_args.bin", "trainer_state.json", "config.json"]
|
||||
ds_file_list = ["mp_rank_00_model_states.pt", "zero_pp_rank_0_mp_rank_00optim_states.pt"]
|
||||
|
||||
if stage == ZERO2:
|
||||
ds_file_list = ["mp_rank_00_model_states.pt"]
|
||||
elif stage == ZERO3:
|
||||
ds_file_list = ["zero_pp_rank_0_mp_rank_00_model_states.pt"]
|
||||
else:
|
||||
raise ValueError(f"unknown stage {stage}")
|
||||
|
||||
# XXX: this can be recoded and then removed once we require deepspeed>0.3.13
|
||||
from packaging import version
|
||||
|
||||
import deepspeed
|
||||
|
||||
if version.parse(deepspeed.__version__) > version.parse("0.3.13"):
|
||||
ds_file_list.append("zero_pp_rank_0_mp_rank_00_optim_states.pt")
|
||||
else:
|
||||
ds_file_list.append("zero_pp_rank_0_mp_rank_00optim_states.pt")
|
||||
|
||||
for step in range(freq, total, freq):
|
||||
checkpoint = os.path.join(output_dir, f"checkpoint-{step}")
|
||||
self.assertTrue(os.path.isdir(checkpoint))
|
||||
self.assertTrue(os.path.isdir(checkpoint), f"[{stage}] {checkpoint} dir is not found")
|
||||
|
||||
# common files
|
||||
for filename in file_list:
|
||||
self.assertTrue(os.path.isfile(os.path.join(checkpoint, filename)))
|
||||
path = os.path.join(checkpoint, filename)
|
||||
self.assertTrue(os.path.isfile(path), f"[{stage}] {path} is not found")
|
||||
|
||||
# ds files
|
||||
ds_path = os.path.join(checkpoint, f"global_step{step}")
|
||||
for filename in ds_file_list:
|
||||
# filename = os.path.join(path, filename)
|
||||
# print(filename)
|
||||
self.assertTrue(os.path.isfile(os.path.join(ds_path, filename)))
|
||||
path = os.path.join(ds_path, filename)
|
||||
self.assertTrue(os.path.isfile(path), f"[{stage}] {path} is not found")
|
||||
|
||||
def test_save_checkpoints(self):
|
||||
@parameterized.expand(stages)
|
||||
def test_save_checkpoints(self, stage):
|
||||
# adapted from TrainerIntegrationTest.test_save_checkpoints
|
||||
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
ds_config_dict = deepcopy(self.ds_config_dict)
|
||||
ds_config_dict["fp16"]["initial_scale_power"] = 1 # force optimizer on the first step
|
||||
freq = 5
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
ds_config_dict = self.get_config_dict(stage)
|
||||
ds_config_dict["fp16"]["initial_scale_power"] = 1 # force optimizer on the first step
|
||||
if stage == ZERO3:
|
||||
ds_config_dict["zero_optimization"]["stage3_gather_fp16_weights_on_model_save"] = True
|
||||
|
||||
# save checkpoints
|
||||
with mockenv_context(**self.dist_env_1_gpu):
|
||||
@ -274,14 +347,42 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
trainer.train()
|
||||
|
||||
total = int(self.n_epochs * 64 / self.batch_size)
|
||||
self.check_saved_checkpoints_deepspeed(output_dir, freq, total)
|
||||
self.check_saved_checkpoints_deepspeed(output_dir, freq, total, stage)
|
||||
|
||||
def test_can_resume_training(self):
|
||||
@parameterized.expand(stages)
|
||||
def test_can_resume_training_errors(self, stage):
|
||||
|
||||
with mockenv_context(**self.dist_env_1_gpu):
|
||||
ds_config_dict = self.get_config_dict(stage)
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
trainer = get_regression_trainer(output_dir=output_dir, deepspeed=ds_config_dict)
|
||||
|
||||
# 1. fail to find any checkpoint - due a fresh output_dir
|
||||
with self.assertRaises(Exception) as context:
|
||||
trainer.train(resume_from_checkpoint=True)
|
||||
self.assertTrue(
|
||||
"No valid checkpoint found in output directory" in str(context.exception),
|
||||
f"got exception: {context.exception}",
|
||||
)
|
||||
|
||||
# 2. fail to find a bogus checkpoint
|
||||
with self.assertRaises(Exception) as context:
|
||||
checkpoint = os.path.join(output_dir, "checkpoint-5")
|
||||
trainer.train(resume_from_checkpoint=f"{checkpoint}-bogus")
|
||||
self.assertTrue(
|
||||
"Can't find a valid checkpoint at" in str(context.exception), f"got exception: {context.exception}"
|
||||
)
|
||||
|
||||
@parameterized.expand(stages)
|
||||
def test_can_resume_training_normal(self, stage):
|
||||
# adapted from TrainerIntegrationTest.test_can_resume_training
|
||||
|
||||
# test normal resume for each stage separately, error-handling is tested in a different test
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
ds_config_dict = deepcopy(self.ds_config_dict)
|
||||
ds_config_dict = self.get_config_dict(stage)
|
||||
ds_config_dict["fp16"]["initial_scale_power"] = 1 # force optimizer on the first step
|
||||
if stage == ZERO3:
|
||||
ds_config_dict["zero_optimization"]["stage3_gather_fp16_weights_on_model_save"] = True
|
||||
|
||||
kwargs = dict(output_dir=output_dir, train_len=128, save_steps=5, learning_rate=0.1, deepspeed=ds_config_dict)
|
||||
|
||||
with mockenv_context(**self.dist_env_1_gpu):
|
||||
@ -315,70 +416,117 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertEqual(b, b1)
|
||||
self.check_trainer_state_are_the_same(state, state1)
|
||||
|
||||
# Now check failures
|
||||
|
||||
# 1. fail to find a bogus checkpoint
|
||||
trainer = get_regression_trainer(**kwargs)
|
||||
with self.assertRaises(Exception) as context:
|
||||
trainer.train(resume_from_checkpoint=f"{checkpoint}-bogus")
|
||||
self.assertTrue("failed to resume from checkpoint" in str(context.exception))
|
||||
|
||||
# 2. fail to find any checkpoint - due a fresh output_dir
|
||||
output_dir2 = self.get_auto_remove_tmp_dir()
|
||||
trainer = get_regression_trainer(output_dir=output_dir2, deepspeed=ds_config_dict)
|
||||
with self.assertRaises(Exception) as context:
|
||||
trainer.train(resume_from_checkpoint=True)
|
||||
self.assertTrue("No valid checkpoint found in output directory" in str(context.exception))
|
||||
|
||||
|
||||
@slow
|
||||
@require_deepspeed
|
||||
@require_torch_gpu
|
||||
class TestDeepSpeed(TestCasePlus):
|
||||
""" This class is for testing via an external script """
|
||||
class TestDeepSpeedWithLauncher(TestCasePlus):
|
||||
""" This class is for testing via an external script - can do multiple gpus """
|
||||
|
||||
# Tests to devise #
|
||||
#
|
||||
# 1. predict_with_generate on multigpu - need to figure out how to give input sequences so that
|
||||
# the 2 gpus will generate prediction sequences that aren't of the same length - this is because
|
||||
# we had to code a special feature to sync the gpus when the predicted sequences aren't of the
|
||||
# same length. In general this will tested as a side-effect through a variety of other tests -
|
||||
# it'll simply hang trying to synchronize with other gpus if this problem is encountered. So as
|
||||
# long as we have a few full tests running on zero3 + predict_with_generate this should be
|
||||
# mostly covered.
|
||||
#
|
||||
# but there are 5 variations on beam search in `generate`- with identical code branched with `if
|
||||
# synced_gpus`
|
||||
#
|
||||
# 2. most tests should probably be run on both: zero2 and zero3 configs
|
||||
#
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_basic_distributed(self):
|
||||
self.run_quick(distributed=True)
|
||||
@parameterized.expand(stages)
|
||||
def test_basic_distributed(self, stage):
|
||||
self.run_and_check(stage=stage, distributed=True)
|
||||
|
||||
def test_do_eval_no_train(self):
|
||||
@parameterized.expand(stages)
|
||||
def test_do_eval_no_train(self, stage):
|
||||
# we should not fail if train is skipped
|
||||
output_dir = self.run_trainer(
|
||||
self.run_and_check(
|
||||
stage=stage,
|
||||
eval_steps=1,
|
||||
max_len=12,
|
||||
model_name=MBART_TINY,
|
||||
num_train_epochs=1,
|
||||
distributed=False,
|
||||
extra_args_str="--do_eval",
|
||||
remove_args_str="--do_train",
|
||||
do_train=False,
|
||||
do_eval=True,
|
||||
)
|
||||
val_metrics = load_json(os.path.join(output_dir, "eval_results.json"))
|
||||
assert "eval_bleu" in val_metrics
|
||||
|
||||
@parameterized.expand(stages)
|
||||
def test_resume_train_not_from_ds_checkpoint(self, stage):
|
||||
# do normal training and then resume not from the deepspeed checkpoint but explicitly from
|
||||
# the saved model dir
|
||||
|
||||
do_train = True
|
||||
do_eval = False
|
||||
kwargs = dict(stage=stage, eval_steps=1, distributed=True, do_train=do_train, do_eval=do_eval)
|
||||
|
||||
# 1. normal training
|
||||
output_dir = self.run_and_check(**kwargs)
|
||||
|
||||
# 2. now resume explicitly from the saved weights, by passing --model_name_or_path output_dir
|
||||
# - i.e. the same path the model was saved to in step 1
|
||||
output_dir = self.run_trainer(**kwargs, model_name=output_dir)
|
||||
|
||||
self.do_checks(output_dir, do_train=do_train, do_eval=do_eval)
|
||||
|
||||
def do_checks(self, output_dir, do_train=True, do_eval=True):
|
||||
|
||||
if do_train:
|
||||
train_metrics = load_json(os.path.join(output_dir, "train_results.json"))
|
||||
self.assertIn("train_samples_per_second", train_metrics)
|
||||
self.assertGreater(train_metrics["train_samples_per_second"], 0.5)
|
||||
|
||||
if do_eval:
|
||||
eval_metrics = load_json(os.path.join(output_dir, "eval_results.json"))
|
||||
self.assertIn("eval_bleu", eval_metrics)
|
||||
self.assertGreater(eval_metrics["eval_bleu"], 0)
|
||||
|
||||
# XXX: need to do better validation beyond just that the run was successful
|
||||
def run_quick(self, distributed=True, extra_args_str=None, remove_args_str=None):
|
||||
def run_and_check(
|
||||
self,
|
||||
stage,
|
||||
eval_steps=10,
|
||||
distributed=True,
|
||||
do_train=True,
|
||||
do_eval=True,
|
||||
extra_args_str=None,
|
||||
remove_args_str=None,
|
||||
):
|
||||
|
||||
# we are doing quality testing so using a small real model
|
||||
output_dir = self.run_trainer(
|
||||
eval_steps=1,
|
||||
max_len=12,
|
||||
model_name=MBART_TINY,
|
||||
stage=stage,
|
||||
model_name=T5_SMALL,
|
||||
eval_steps=eval_steps,
|
||||
num_train_epochs=1,
|
||||
do_train=do_train,
|
||||
do_eval=do_eval,
|
||||
distributed=distributed,
|
||||
extra_args_str=extra_args_str,
|
||||
remove_args_str=remove_args_str,
|
||||
)
|
||||
train_metrics = load_json(os.path.join(output_dir, "train_results.json"))
|
||||
assert "train_runtime" in train_metrics
|
||||
|
||||
self.do_checks(output_dir, do_train=do_train, do_eval=do_eval)
|
||||
|
||||
return output_dir
|
||||
|
||||
def run_trainer(
|
||||
self,
|
||||
eval_steps: int,
|
||||
max_len: str,
|
||||
stage: str,
|
||||
model_name: str,
|
||||
num_train_epochs: int,
|
||||
eval_steps: int = 10,
|
||||
num_train_epochs: int = 1,
|
||||
do_train: bool = False,
|
||||
do_eval: bool = True,
|
||||
distributed: bool = True,
|
||||
extra_args_str: str = None,
|
||||
remove_args_str: str = None,
|
||||
):
|
||||
max_len = 32
|
||||
data_dir = self.examples_dir / "test_data/wmt_en_ro"
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
args = f"""
|
||||
@ -387,41 +535,100 @@ class TestDeepSpeed(TestCasePlus):
|
||||
--validation_file {data_dir}/val.json
|
||||
--output_dir {output_dir}
|
||||
--overwrite_output_dir
|
||||
--max_train_samples 8
|
||||
--max_val_samples 8
|
||||
--max_source_length {max_len}
|
||||
--max_target_length {max_len}
|
||||
--val_max_target_length {max_len}
|
||||
--do_train
|
||||
--num_train_epochs {str(num_train_epochs)}
|
||||
--per_device_train_batch_size 4
|
||||
--learning_rate 3e-3
|
||||
--warmup_steps 8
|
||||
--predict_with_generate
|
||||
--logging_steps 0
|
||||
--save_steps {str(eval_steps)}
|
||||
--save_steps 0
|
||||
--eval_steps {eval_steps}
|
||||
--group_by_length
|
||||
--label_smoothing_factor 0.1
|
||||
--adafactor
|
||||
--target_lang ro_RO
|
||||
--source_lang en_XX
|
||||
--source_lang en
|
||||
--target_lang ro
|
||||
""".split()
|
||||
args.extend(["--source_prefix", '"translate English to Romanian: "'])
|
||||
|
||||
actions = 0
|
||||
if do_train:
|
||||
actions += 1
|
||||
args.extend(
|
||||
f"""
|
||||
--do_train
|
||||
--num_train_epochs {str(num_train_epochs)}
|
||||
--max_train_samples 100
|
||||
--per_device_train_batch_size 2
|
||||
--learning_rate 3e-3
|
||||
""".split()
|
||||
)
|
||||
|
||||
if do_eval:
|
||||
actions += 1
|
||||
args.extend(
|
||||
"""
|
||||
--do_eval
|
||||
--max_val_samples 100
|
||||
--per_device_eval_batch_size 2
|
||||
""".split()
|
||||
)
|
||||
|
||||
assert actions > 0, "need at least do_train or do_eval for the test to run"
|
||||
|
||||
if extra_args_str is not None:
|
||||
args.extend(extra_args_str.split())
|
||||
|
||||
# currently only works for bool args
|
||||
if remove_args_str is not None:
|
||||
remove_args = remove_args_str.split()
|
||||
args = [x for x in args if x not in remove_args]
|
||||
|
||||
ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config.json".split()
|
||||
ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_{stage}.json".split()
|
||||
script = [f"{self.examples_dir_str}/seq2seq/run_translation.py"]
|
||||
num_gpus = get_gpu_count() if distributed else 1
|
||||
launcher = f"deepspeed --num_gpus {num_gpus}".split()
|
||||
|
||||
cmd = launcher + script + args + ds_args
|
||||
# keep for quick debug
|
||||
# print(" ".join([f"PYTHONPATH={self.src_dir_str}"] +cmd)); die
|
||||
# print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
|
||||
execute_subprocess_async(cmd, env=self.get_env())
|
||||
|
||||
return output_dir
|
||||
|
||||
@parameterized.expand(stages)
|
||||
def test_clm(self, stage):
|
||||
# this test exercises model.resize_token_embeddings() which requires param gathering outside
|
||||
# of forward - it's not used by `run_translation.py`, but it is in `run_clm.py`
|
||||
|
||||
data_dir = self.tests_dir / "fixtures"
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
args = f"""
|
||||
--model_name_or_path sshleifer/tiny-gpt2
|
||||
--train_file {data_dir}/sample_text.txt
|
||||
--validation_file {data_dir}/sample_text.txt
|
||||
--output_dir {output_dir}
|
||||
--overwrite_output_dir
|
||||
--do_train
|
||||
--do_eval
|
||||
--max_train_samples 10
|
||||
--max_val_samples 10
|
||||
--per_device_train_batch_size 5
|
||||
--per_device_eval_batch_size 5
|
||||
--num_train_epochs 1
|
||||
--warmup_steps 8
|
||||
--block_size 128
|
||||
""".split()
|
||||
|
||||
distributed = True
|
||||
ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_{stage}.json".split()
|
||||
script = [f"{self.examples_dir_str}/language-modeling/run_clm.py"]
|
||||
num_gpus = get_gpu_count() if distributed else 1
|
||||
launcher = f"deepspeed --num_gpus {num_gpus}".split()
|
||||
|
||||
cmd = launcher + script + args + ds_args
|
||||
# keep for quick debug
|
||||
# print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
|
||||
execute_subprocess_async(cmd, env=self.get_env())
|
||||
|
||||
return output_dir
|
||||
|
@ -18,6 +18,7 @@ from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .file_utils import ModelOutput
|
||||
@ -695,6 +696,7 @@ class GenerationMixin:
|
||||
forced_bos_token_id: Optional[int] = None,
|
||||
forced_eos_token_id: Optional[int] = None,
|
||||
remove_invalid_values: Optional[bool] = None,
|
||||
synced_gpus: Optional[bool] = None,
|
||||
**model_kwargs,
|
||||
) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
|
||||
r"""
|
||||
@ -800,6 +802,8 @@ class GenerationMixin:
|
||||
remove_invalid_values (:obj:`bool`, `optional`):
|
||||
Whether to remove possible `nan` and `inf` outputs of the model to prevent the generation method to
|
||||
crash. Note that using ``remove_invalid_values`` can slow down generation.
|
||||
synced_gpus (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||
|
||||
model_kwargs:
|
||||
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If the
|
||||
@ -1000,6 +1004,7 @@ class GenerationMixin:
|
||||
eos_token_id=eos_token_id,
|
||||
output_scores=output_scores,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
synced_gpus=synced_gpus,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
@ -1028,6 +1033,7 @@ class GenerationMixin:
|
||||
eos_token_id=eos_token_id,
|
||||
output_scores=output_scores,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
synced_gpus=synced_gpus,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
@ -1063,6 +1069,7 @@ class GenerationMixin:
|
||||
eos_token_id=eos_token_id,
|
||||
output_scores=output_scores,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
synced_gpus=synced_gpus,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
@ -1102,6 +1109,7 @@ class GenerationMixin:
|
||||
eos_token_id=eos_token_id,
|
||||
output_scores=output_scores,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
synced_gpus=synced_gpus,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
@ -1141,6 +1149,7 @@ class GenerationMixin:
|
||||
eos_token_id=eos_token_id,
|
||||
output_scores=output_scores,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
synced_gpus=synced_gpus,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
@ -1156,13 +1165,12 @@ class GenerationMixin:
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_scores: Optional[bool] = None,
|
||||
return_dict_in_generate: Optional[bool] = None,
|
||||
synced_gpus: Optional[bool] = None,
|
||||
**model_kwargs,
|
||||
) -> Union[GreedySearchOutput, torch.LongTensor]:
|
||||
r"""
|
||||
Generates sequences for models with a language modeling head using greedy decoding.
|
||||
|
||||
|
||||
|
||||
Parameters:
|
||||
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
@ -1175,6 +1183,7 @@ class GenerationMixin:
|
||||
stopping_criteria (:obj:`StoppingCriteriaList`, `optional`):
|
||||
An instance of :class:`~transformers.StoppingCriteriaList`. List of instances of class derived from
|
||||
:class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop.
|
||||
|
||||
max_length (:obj:`int`, `optional`, defaults to 20):
|
||||
The maximum length of the sequence to be generated.
|
||||
pad_token_id (:obj:`int`, `optional`):
|
||||
@ -1191,6 +1200,8 @@ class GenerationMixin:
|
||||
Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details.
|
||||
return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`):
|
||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||
synced_gpus (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||
model_kwargs:
|
||||
Additional model specific keyword arguments will be forwarded to the :obj:`forward` function of the
|
||||
model. If model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.
|
||||
@ -1265,7 +1276,19 @@ class GenerationMixin:
|
||||
input_ids, max_length
|
||||
)
|
||||
|
||||
this_peer_finished = False # used by synced_gpus only
|
||||
while cur_len < max_length:
|
||||
|
||||
if synced_gpus:
|
||||
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
||||
# The following logic allows an early break if all peers finished generating their sequence
|
||||
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
|
||||
# send 0.0 if we finished, 1.0 otherwise
|
||||
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
|
||||
# did all peers finish? the reduced sum will be 0.0 then
|
||||
if this_peer_finished_flag.item() == 0.0:
|
||||
break
|
||||
|
||||
# prepare model inputs
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
||||
|
||||
@ -1276,6 +1299,11 @@ class GenerationMixin:
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
if synced_gpus and this_peer_finished:
|
||||
cur_len = cur_len + 1
|
||||
continue # don't waste resources running the code we don't need
|
||||
|
||||
next_token_logits = outputs.logits[:, -1, :]
|
||||
|
||||
# Store scores, attentions and hidden_states when required
|
||||
@ -1321,16 +1349,16 @@ class GenerationMixin:
|
||||
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
||||
)
|
||||
|
||||
# stop when there is a </s> in each sentence, or if we exceed the maximum length
|
||||
if unfinished_sequences.max() == 0:
|
||||
break
|
||||
|
||||
if stopping_criteria(input_ids, scores):
|
||||
break
|
||||
|
||||
# increase cur_len
|
||||
cur_len = cur_len + 1
|
||||
|
||||
# stop when there is a </s> in each sentence, or if we exceed the maximum length
|
||||
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
||||
if not synced_gpus:
|
||||
break
|
||||
else:
|
||||
this_peer_finished = True
|
||||
|
||||
if return_dict_in_generate:
|
||||
if self.config.is_encoder_decoder:
|
||||
return GreedySearchEncoderDecoderOutput(
|
||||
@ -1365,6 +1393,7 @@ class GenerationMixin:
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_scores: Optional[bool] = None,
|
||||
return_dict_in_generate: Optional[bool] = None,
|
||||
synced_gpus: Optional[bool] = None,
|
||||
**model_kwargs,
|
||||
) -> Union[SampleOutput, torch.LongTensor]:
|
||||
r"""
|
||||
@ -1402,6 +1431,8 @@ class GenerationMixin:
|
||||
Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details.
|
||||
return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`):
|
||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||
synced_gpus (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||
model_kwargs:
|
||||
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If
|
||||
model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.
|
||||
@ -1485,8 +1516,20 @@ class GenerationMixin:
|
||||
input_ids, max_length
|
||||
)
|
||||
|
||||
this_peer_finished = False # used by synced_gpus only
|
||||
# auto-regressive generation
|
||||
while cur_len < max_length:
|
||||
|
||||
if synced_gpus:
|
||||
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
||||
# The following logic allows an early break if all peers finished generating their sequence
|
||||
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
|
||||
# send 0.0 if we finished, 1.0 otherwise
|
||||
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
|
||||
# did all peers finish? the reduced sum will be 0.0 then
|
||||
if this_peer_finished_flag.item() == 0.0:
|
||||
break
|
||||
|
||||
# prepare model inputs
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
||||
|
||||
@ -1497,6 +1540,11 @@ class GenerationMixin:
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
if synced_gpus and this_peer_finished:
|
||||
cur_len = cur_len + 1
|
||||
continue # don't waste resources running the code we don't need
|
||||
|
||||
next_token_logits = outputs.logits[:, -1, :]
|
||||
|
||||
# pre-process distribution
|
||||
@ -1533,7 +1581,6 @@ class GenerationMixin:
|
||||
|
||||
# add token and increase length by one
|
||||
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
||||
cur_len = cur_len + 1
|
||||
|
||||
# update sequence length
|
||||
if eos_token_id is not None:
|
||||
@ -1541,18 +1588,21 @@ class GenerationMixin:
|
||||
sequence_lengths, unfinished_sequences, cur_len, next_tokens == eos_token_id
|
||||
)
|
||||
|
||||
# stop when there is a </s> in each sentence, or if we exceed the maximum length
|
||||
if unfinished_sequences.max() == 0:
|
||||
break
|
||||
|
||||
if stopping_criteria(input_ids, scores):
|
||||
break
|
||||
|
||||
# update model kwargs
|
||||
model_kwargs = self._update_model_kwargs_for_generation(
|
||||
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
||||
)
|
||||
|
||||
# increase cur_len
|
||||
cur_len = cur_len + 1
|
||||
|
||||
# stop when there is a </s> in each sentence, or if we exceed the maximum length
|
||||
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
||||
if not synced_gpus:
|
||||
break
|
||||
else:
|
||||
this_peer_finished = True
|
||||
|
||||
if return_dict_in_generate:
|
||||
if self.config.is_encoder_decoder:
|
||||
return SampleEncoderDecoderOutput(
|
||||
@ -1587,6 +1637,7 @@ class GenerationMixin:
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_scores: Optional[bool] = None,
|
||||
return_dict_in_generate: Optional[bool] = None,
|
||||
synced_gpus: Optional[bool] = None,
|
||||
**model_kwargs,
|
||||
) -> Union[BeamSearchOutput, torch.LongTensor]:
|
||||
r"""
|
||||
@ -1624,6 +1675,8 @@ class GenerationMixin:
|
||||
Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details.
|
||||
return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`):
|
||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||
synced_gpus (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||
model_kwargs:
|
||||
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If
|
||||
model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.
|
||||
@ -1726,7 +1779,19 @@ class GenerationMixin:
|
||||
beam_scores[:, 1:] = -1e9
|
||||
beam_scores = beam_scores.view((batch_size * num_beams,))
|
||||
|
||||
this_peer_finished = False # used by synced_gpus only
|
||||
while cur_len < max_length:
|
||||
|
||||
if synced_gpus:
|
||||
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
||||
# The following logic allows an early break if all peers finished generating their sequence
|
||||
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
|
||||
# send 0.0 if we finished, 1.0 otherwise
|
||||
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
|
||||
# did all peers finish? the reduced sum will be 0.0 then
|
||||
if this_peer_finished_flag.item() == 0.0:
|
||||
break
|
||||
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
||||
|
||||
outputs = self(
|
||||
@ -1735,6 +1800,11 @@ class GenerationMixin:
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
if synced_gpus and this_peer_finished:
|
||||
cur_len = cur_len + 1
|
||||
continue # don't waste resources running the code we don't need
|
||||
|
||||
next_token_logits = outputs.logits[:, -1, :]
|
||||
|
||||
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
|
||||
@ -1792,19 +1862,20 @@ class GenerationMixin:
|
||||
|
||||
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
|
||||
|
||||
cur_len = cur_len + 1
|
||||
|
||||
model_kwargs = self._update_model_kwargs_for_generation(
|
||||
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
||||
)
|
||||
if model_kwargs["past"] is not None:
|
||||
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
|
||||
|
||||
if beam_scorer.is_done:
|
||||
break
|
||||
# increase cur_len
|
||||
cur_len = cur_len + 1
|
||||
|
||||
if stopping_criteria(input_ids, scores):
|
||||
break
|
||||
if beam_scorer.is_done or stopping_criteria(input_ids, scores):
|
||||
if not synced_gpus:
|
||||
break
|
||||
else:
|
||||
this_peer_finished = True
|
||||
|
||||
sequence_outputs = beam_scorer.finalize(
|
||||
input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
|
||||
@ -1849,6 +1920,7 @@ class GenerationMixin:
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_scores: Optional[bool] = None,
|
||||
return_dict_in_generate: Optional[bool] = None,
|
||||
synced_gpus: Optional[bool] = None,
|
||||
**model_kwargs,
|
||||
) -> Union[BeamSampleOutput, torch.LongTensor]:
|
||||
r"""
|
||||
@ -1890,6 +1962,8 @@ class GenerationMixin:
|
||||
Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details.
|
||||
return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`):
|
||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||
synced_gpus (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||
model_kwargs:
|
||||
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. If
|
||||
model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.
|
||||
@ -1993,7 +2067,19 @@ class GenerationMixin:
|
||||
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
||||
beam_scores = beam_scores.view((batch_size * num_beams,))
|
||||
|
||||
this_peer_finished = False # used by synced_gpus only
|
||||
while cur_len < max_length:
|
||||
|
||||
if synced_gpus:
|
||||
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
||||
# The following logic allows an early break if all peers finished generating their sequence
|
||||
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
|
||||
# send 0.0 if we finished, 1.0 otherwise
|
||||
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
|
||||
# did all peers finish? the reduced sum will be 0.0 then
|
||||
if this_peer_finished_flag.item() == 0.0:
|
||||
break
|
||||
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
||||
|
||||
outputs = self(
|
||||
@ -2002,6 +2088,11 @@ class GenerationMixin:
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
if synced_gpus and this_peer_finished:
|
||||
cur_len = cur_len + 1
|
||||
continue # don't waste resources running the code we don't need
|
||||
|
||||
next_token_logits = outputs.logits[:, -1, :]
|
||||
|
||||
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
|
||||
@ -2063,7 +2154,6 @@ class GenerationMixin:
|
||||
beam_idx = beam_outputs["next_beam_indices"]
|
||||
|
||||
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
|
||||
cur_len = cur_len + 1
|
||||
|
||||
model_kwargs = self._update_model_kwargs_for_generation(
|
||||
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
||||
@ -2071,11 +2161,14 @@ class GenerationMixin:
|
||||
if model_kwargs["past"] is not None:
|
||||
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
|
||||
|
||||
if beam_scorer.is_done:
|
||||
break
|
||||
# increase cur_len
|
||||
cur_len = cur_len + 1
|
||||
|
||||
if stopping_criteria(input_ids, scores):
|
||||
break
|
||||
if beam_scorer.is_done or stopping_criteria(input_ids, scores):
|
||||
if not synced_gpus:
|
||||
break
|
||||
else:
|
||||
this_peer_finished = True
|
||||
|
||||
sequence_outputs = beam_scorer.finalize(
|
||||
input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
|
||||
@ -2119,6 +2212,7 @@ class GenerationMixin:
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_scores: Optional[bool] = None,
|
||||
return_dict_in_generate: Optional[bool] = None,
|
||||
synced_gpus: Optional[bool] = None,
|
||||
**model_kwargs,
|
||||
):
|
||||
r"""
|
||||
@ -2156,6 +2250,9 @@ class GenerationMixin:
|
||||
Whether or not to return the prediction scores. See ``scores`` under returned tensors for more details.
|
||||
return_dict_in_generate (:obj:`bool`, `optional`, defaults to `False`):
|
||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||
synced_gpus (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||
|
||||
model_kwargs:
|
||||
Additional model specific kwargs that will be forwarded to the :obj:`forward` function of the model. If
|
||||
model is an encoder-decoder model the kwargs should include :obj:`encoder_outputs`.
|
||||
@ -2266,7 +2363,19 @@ class GenerationMixin:
|
||||
beam_scores[:, ::num_sub_beams] = 0
|
||||
beam_scores = beam_scores.view((batch_size * num_beams,))
|
||||
|
||||
this_peer_finished = False # used by synced_gpus only
|
||||
while cur_len < max_length:
|
||||
|
||||
if synced_gpus:
|
||||
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
||||
# The following logic allows an early break if all peers finished generating their sequence
|
||||
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
|
||||
# send 0.0 if we finished, 1.0 otherwise
|
||||
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
|
||||
# did all peers finish? the reduced sum will be 0.0 then
|
||||
if this_peer_finished_flag.item() == 0.0:
|
||||
break
|
||||
|
||||
# predicted tokens in cur_len step
|
||||
current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
|
||||
|
||||
@ -2282,6 +2391,10 @@ class GenerationMixin:
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
if synced_gpus and this_peer_finished:
|
||||
cur_len = cur_len + 1
|
||||
continue # don't waste resources running the code we don't need
|
||||
|
||||
for beam_group_idx in range(num_beam_groups):
|
||||
group_start_idx = beam_group_idx * num_sub_beams
|
||||
group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
|
||||
@ -2372,19 +2485,22 @@ class GenerationMixin:
|
||||
else (outputs.hidden_states,)
|
||||
)
|
||||
|
||||
input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
|
||||
|
||||
model_kwargs = self._update_model_kwargs_for_generation(
|
||||
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
||||
)
|
||||
if model_kwargs["past"] is not None:
|
||||
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], reordering_indices)
|
||||
|
||||
input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
|
||||
# increase cur_len
|
||||
cur_len = cur_len + 1
|
||||
if beam_scorer.is_done:
|
||||
break
|
||||
|
||||
if stopping_criteria(input_ids, scores):
|
||||
break
|
||||
if beam_scorer.is_done or stopping_criteria(input_ids, scores):
|
||||
if not synced_gpus:
|
||||
break
|
||||
else:
|
||||
this_peer_finished = True
|
||||
|
||||
sequence_outputs = beam_scorer.finalize(
|
||||
input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
|
||||
|
@ -19,6 +19,7 @@ import io
|
||||
import json
|
||||
import numbers
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
@ -268,7 +269,77 @@ def rewrite_logs(d):
|
||||
return new_d
|
||||
|
||||
|
||||
def init_deepspeed(trainer, num_training_steps, resume_from_checkpoint=None):
|
||||
_is_deepspeed_zero3_enabled = None
|
||||
|
||||
|
||||
def is_deepspeed_zero3_enabled():
|
||||
"""
|
||||
This function answers to the question of whether DeepSpeed is going to be used and run using ZeRO Stage 3.
|
||||
|
||||
It includes an auto-discovery method, see comments in the code for details.
|
||||
|
||||
Returns: ``True`` if either it was explicitly enabled via ``deepspeed_zero3_enable(True)`` or the auto-detector was
|
||||
able to derive that the ``Trainer`` will be running via DeepSpeed ZeRO stage 3.
|
||||
"""
|
||||
global _is_deepspeed_zero3_enabled
|
||||
if _is_deepspeed_zero3_enabled is None:
|
||||
_is_deepspeed_zero3_enabled = False
|
||||
# Try to auto-discover if we are about to use DeepSpeed with ZeRO3 enabled. This will only
|
||||
# work for scripts using cli to pass --deepspeed ds_config.json. If cmd args aren't used,
|
||||
# then to get the model efficiently loaded across multiple-gpus one has to explicitly call
|
||||
# is_deepspeed_zero3_enabled(True) **before** instantiating a model object
|
||||
if "--deepspeed" in sys.argv:
|
||||
idx = sys.argv.index("--deepspeed")
|
||||
ds_config = sys.argv[idx + 1]
|
||||
if not os.path.exists(ds_config):
|
||||
raise ValueError("--deepspeed requires a valid path to a config file")
|
||||
config = deepspeed_parse_config(ds_config)
|
||||
if (
|
||||
"zero_optimization" in config
|
||||
and "stage" in config["zero_optimization"]
|
||||
and config["zero_optimization"]["stage"] == 3
|
||||
):
|
||||
_is_deepspeed_zero3_enabled = True
|
||||
|
||||
return _is_deepspeed_zero3_enabled
|
||||
|
||||
|
||||
def deepspeed_zero3_enable(enable=True):
|
||||
"""
|
||||
``is_deepspeed_zero3_enabled()`` tries to derive automatically if DeepSpeed ZeRO 3 is going to be used by looking
|
||||
at ``sys.argv`` which may or may contain information about where to find the DeepSpeed config if any.
|
||||
|
||||
This function allows for explicit enabling/disabling of this global flag.
|
||||
|
||||
Args:
|
||||
enable: if set to ``True`` will make ``is_deepspeed_zero3_enabled()`` return ``True``
|
||||
"""
|
||||
global _is_deepspeed_zero3_enabled
|
||||
_is_deepspeed_zero3_enabled = enable
|
||||
|
||||
|
||||
def deepspeed_parse_config(ds_config):
|
||||
"""
|
||||
If ``ds_config`` isn't already a dict, read it from the config file.
|
||||
|
||||
If it's already a dict, return a copy of it, so that we can freely modify it.
|
||||
"""
|
||||
require_version("deepspeed>0.3.13")
|
||||
|
||||
if isinstance(ds_config, dict):
|
||||
# Don't modify user's data should they want to reuse it (e.g. in tests), because once we
|
||||
# modified it, it will not be accepted here again, since some config params must be not set by users
|
||||
config = deepcopy(ds_config)
|
||||
elif isinstance(ds_config, str):
|
||||
with io.open(ds_config, "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
else:
|
||||
raise ValueError("expecting either a path to a config file or a pre-populated dict")
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None):
|
||||
"""
|
||||
Init DeepSpeed, after updating the DeepSpeed configuration with any relevant Trainer's args.
|
||||
|
||||
@ -284,21 +355,10 @@ def init_deepspeed(trainer, num_training_steps, resume_from_checkpoint=None):
|
||||
"""
|
||||
import deepspeed
|
||||
|
||||
require_version("deepspeed>0.3.12")
|
||||
|
||||
args = trainer.args
|
||||
ds_config_file = args.deepspeed
|
||||
model = trainer.model
|
||||
|
||||
if isinstance(args.deepspeed, dict):
|
||||
# Don't modify user's data should they want to reuse it (e.g. in tests), because once we
|
||||
# modified it, it will not be accepted here again, since some config params must be not set by users
|
||||
config = deepcopy(args.deepspeed)
|
||||
elif isinstance(args.deepspeed, str):
|
||||
with io.open(ds_config_file, "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
else:
|
||||
raise ValueError("expecting either a path to a config file or a pre-populated dict")
|
||||
config = deepspeed_parse_config(args.deepspeed)
|
||||
|
||||
# The following code translates relevant trainer's cl args into the DS config
|
||||
|
||||
@ -324,9 +384,7 @@ def init_deepspeed(trainer, num_training_steps, resume_from_checkpoint=None):
|
||||
config["gradient_accumulation_steps"] = args.gradient_accumulation_steps
|
||||
|
||||
if "gradient_clipping" in config:
|
||||
logger.info(
|
||||
f"Keeping the `gradient_clipping` config from {ds_config_file} intact, ignoring any gradient clipping-specific cl args"
|
||||
)
|
||||
logger.info("Keeping the `gradient_clipping` config intact, ignoring any gradient clipping-specific cl args")
|
||||
else: # override only if the ds config doesn't already have this section
|
||||
config["gradient_clipping"] = args.max_grad_norm
|
||||
|
||||
@ -336,6 +394,7 @@ def init_deepspeed(trainer, num_training_steps, resume_from_checkpoint=None):
|
||||
# 2. HF scheduler + HF optimizer: Yes
|
||||
# 3. DS scheduler + HF optimizer: Yes
|
||||
# 4. HF scheduler + DS optimizer: No
|
||||
#
|
||||
# Unless Offload is enabled in which case it's:
|
||||
# 1. DS scheduler + DS optimizer: Yes
|
||||
# 2. HF scheduler + HF optimizer: No
|
||||
@ -344,7 +403,7 @@ def init_deepspeed(trainer, num_training_steps, resume_from_checkpoint=None):
|
||||
|
||||
optimizer = None
|
||||
if "optimizer" in config:
|
||||
logger.info(f"Updating the `scheduler` config from {ds_config_file} with other command line arguments")
|
||||
logger.info("Updating the `scheduler` config with other command line arguments")
|
||||
|
||||
# to avoid inconsistent values of lr and warm up steps the command line args override config
|
||||
params = dict(
|
||||
@ -384,7 +443,7 @@ def init_deepspeed(trainer, num_training_steps, resume_from_checkpoint=None):
|
||||
# WarmupDecayLR| linear | get_linear_schedule_with_warmup |
|
||||
lr_scheduler = None
|
||||
if "scheduler" in config:
|
||||
logger.info(f"Updating the `scheduler` config from {ds_config_file} with other command line arguments")
|
||||
logger.info("Updating the `scheduler` config with other command line arguments")
|
||||
# the user won't easily know the correct num_training_steps should they use WarmupDecayLR,
|
||||
# so let's set it to the correct value
|
||||
if config["scheduler"]["type"] == "WarmupDecayLR":
|
||||
@ -417,9 +476,7 @@ def init_deepspeed(trainer, num_training_steps, resume_from_checkpoint=None):
|
||||
# - `amp`: which delegates amp work to apex (which needs to be available), but it cannot be used with any ZeRO features, so probably best to be avoided.
|
||||
if trainer.fp16_backend == "apex":
|
||||
if "amp" in config:
|
||||
logger.info(
|
||||
f"Keeping the `amp` config from {ds_config_file} intact, ignoring any amp-specific cl args"
|
||||
)
|
||||
logger.info("Keeping the `amp` config intact, ignoring any amp-specific cl args")
|
||||
else:
|
||||
config["amp"] = {
|
||||
"enabled": True,
|
||||
@ -427,19 +484,33 @@ def init_deepspeed(trainer, num_training_steps, resume_from_checkpoint=None):
|
||||
}
|
||||
elif trainer.fp16_backend == "amp":
|
||||
if "fp16" in config:
|
||||
logger.info(
|
||||
f"Keeping the `fp16` config from {ds_config_file} intact, ignoring any fp16-specific cl args"
|
||||
)
|
||||
logger.info("Keeping the `fp16` config intact, ignoring any fp16-specific cl args")
|
||||
else:
|
||||
config["fp16"] = {
|
||||
"enabled": True,
|
||||
}
|
||||
|
||||
# zero
|
||||
if "zero_optimization" in config:
|
||||
zero = config["zero_optimization"]
|
||||
|
||||
# now we know for sure if zero3 is enabled
|
||||
deepspeed_zero3_enable(zero.get("stage") == 3)
|
||||
|
||||
# automatically assign the optimal config values based on model config
|
||||
hidden_size = model.config.hidden_size
|
||||
if zero.get("reduce_bucket_size") == 0:
|
||||
zero["reduce_bucket_size"] = hidden_size * hidden_size
|
||||
if zero.get("stage3_prefetch_bucket_size") == 0:
|
||||
zero["stage3_prefetch_bucket_size"] = 0.9 * hidden_size * hidden_size
|
||||
if zero.get("stage3_param_persistence_threshold") == 0:
|
||||
zero["stage3_param_persistence_threshold"] = 10 * hidden_size
|
||||
|
||||
# keep for quick debug:
|
||||
# from pprint import pprint; pprint(config)
|
||||
|
||||
# init that takes part of the config via `args`, and the bulk of it via `config_params`
|
||||
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
||||
|
||||
model, optimizer, _, lr_scheduler = deepspeed.initialize(
|
||||
model=model,
|
||||
model_parameters=model_parameters,
|
||||
@ -448,14 +519,26 @@ def init_deepspeed(trainer, num_training_steps, resume_from_checkpoint=None):
|
||||
lr_scheduler=lr_scheduler,
|
||||
)
|
||||
|
||||
if resume_from_checkpoint is not None: # and os.path.isdir(resume_from_checkpoint):
|
||||
logger.info(f"Attempting to resume from {resume_from_checkpoint}")
|
||||
# this magically updates self.optimizer and self.lr_scheduler
|
||||
load_path, _ = model.load_checkpoint(
|
||||
resume_from_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True
|
||||
)
|
||||
if load_path is None:
|
||||
raise ValueError(f"[deepspeed] failed to resume from checkpoint {resume_from_checkpoint}")
|
||||
if resume_from_checkpoint is not None:
|
||||
|
||||
# it's possible that the user is trying to resume from model_path, which doesn't necessarily
|
||||
# contain a deepspeed checkpoint. e.g. examples just check if the dir exists and assume it's
|
||||
# a resume from a checkpoint and not just a local pretrained weight. So we check here if the
|
||||
# path contains what looks like a deepspeed checkpoint
|
||||
import glob
|
||||
|
||||
deepspeed_checkpoint_dirs = sorted(glob.glob(f"{resume_from_checkpoint}/global_step*"))
|
||||
|
||||
if len(deepspeed_checkpoint_dirs) > 0:
|
||||
logger.info(f"Attempting to resume from {resume_from_checkpoint}")
|
||||
# this magically updates self.optimizer and self.lr_scheduler
|
||||
load_path, _ = model.load_checkpoint(
|
||||
resume_from_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True
|
||||
)
|
||||
if load_path is None:
|
||||
raise ValueError(f"[deepspeed] failed to resume from checkpoint {resume_from_checkpoint}")
|
||||
else:
|
||||
logger.info(f"{resume_from_checkpoint} doesn't have deepspeed checkpoints, doing nothing")
|
||||
|
||||
return model, optimizer, lr_scheduler
|
||||
|
||||
|
@ -41,6 +41,7 @@ from .file_utils import (
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .generation_utils import GenerationMixin
|
||||
from .integrations import is_deepspeed_zero3_enabled
|
||||
from .utils import logging
|
||||
|
||||
|
||||
@ -660,7 +661,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
if new_num_tokens is None:
|
||||
return old_embeddings
|
||||
|
||||
old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
|
||||
if is_deepspeed_zero3_enabled():
|
||||
import deepspeed
|
||||
|
||||
with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None):
|
||||
old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
|
||||
else:
|
||||
old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
|
||||
|
||||
if old_num_tokens == new_num_tokens:
|
||||
return old_embeddings
|
||||
|
||||
@ -677,8 +685,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
self._init_weights(new_embeddings)
|
||||
|
||||
# Copy token embeddings from the previous weights
|
||||
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
|
||||
new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
|
||||
|
||||
# numbers of tokens to copy
|
||||
n = min(old_num_tokens, new_num_tokens)
|
||||
if is_deepspeed_zero3_enabled():
|
||||
import deepspeed
|
||||
|
||||
with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=0):
|
||||
if torch.distributed.get_rank() == 0:
|
||||
new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
|
||||
else:
|
||||
new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
|
||||
|
||||
return new_embeddings
|
||||
|
||||
@ -1056,7 +1073,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
config.name_or_path = pretrained_model_name_or_path
|
||||
|
||||
# Instantiate model.
|
||||
model = cls(config, *model_args, **model_kwargs)
|
||||
|
||||
if is_deepspeed_zero3_enabled():
|
||||
import deepspeed
|
||||
|
||||
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
|
||||
# this immediately partitions the model to avoid the overhead in time and memory copying it on CPU or each GPU first
|
||||
with deepspeed.zero.Init():
|
||||
model = cls(config, *model_args, **model_kwargs)
|
||||
else:
|
||||
model = cls(config, *model_args, **model_kwargs)
|
||||
|
||||
if state_dict is None and not from_tf:
|
||||
try:
|
||||
@ -1114,15 +1140,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
# so we need to apply the function recursively.
|
||||
def load(module: nn.Module, prefix=""):
|
||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||
module._load_from_state_dict(
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
True,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
)
|
||||
args = (state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
|
||||
if is_deepspeed_zero3_enabled():
|
||||
import deepspeed
|
||||
|
||||
# because zero3 puts placeholders in model params, this context
|
||||
# manager gathers (unpartitions) the params of the current layer, then loads from
|
||||
# the state dict and then re-partitions them again
|
||||
with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
|
||||
if torch.distributed.get_rank() == 0:
|
||||
module._load_from_state_dict(*args)
|
||||
else:
|
||||
module._load_from_state_dict(*args)
|
||||
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
load(child, prefix + name + ".")
|
||||
|
@ -17,7 +17,6 @@ The Trainer class, to easily train a 🤗 Transformers from scratch or finetune
|
||||
"""
|
||||
|
||||
import collections
|
||||
import gc
|
||||
import inspect
|
||||
import math
|
||||
import os
|
||||
@ -41,7 +40,8 @@ from .integrations import ( # isort: split
|
||||
is_ray_tune_available,
|
||||
run_hp_search_optuna,
|
||||
run_hp_search_ray,
|
||||
init_deepspeed,
|
||||
deepspeed_init,
|
||||
is_deepspeed_zero3_enabled,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
@ -921,7 +921,7 @@ class Trainer:
|
||||
logger.info(f"Loading model from {resume_from_checkpoint}).")
|
||||
|
||||
if self.deepspeed:
|
||||
# will be resumed in init_deepspeed
|
||||
# will be resumed in deepspeed_init
|
||||
pass
|
||||
elif isinstance(self.model, PreTrainedModel):
|
||||
self.model = self.model.from_pretrained(resume_from_checkpoint)
|
||||
@ -965,12 +965,12 @@ class Trainer:
|
||||
|
||||
delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE
|
||||
if self.args.deepspeed:
|
||||
model, optimizer, lr_scheduler = init_deepspeed(
|
||||
deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
|
||||
self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint
|
||||
)
|
||||
self.model = model.module
|
||||
self.model_wrapped = model
|
||||
self.deepspeed = model # DeepSpeedEngine object
|
||||
self.model = deepspeed_engine.module
|
||||
self.model_wrapped = deepspeed_engine
|
||||
self.deepspeed = deepspeed_engine
|
||||
self.optimizer = optimizer
|
||||
self.lr_scheduler = lr_scheduler
|
||||
elif not delay_optimizer_creation:
|
||||
@ -1227,18 +1227,6 @@ class Trainer:
|
||||
# add remaining tr_loss
|
||||
self._total_loss_scalar += tr_loss.item()
|
||||
|
||||
if self.deepspeed:
|
||||
# free up any memory that might be useful for eval
|
||||
self.deepspeed = None
|
||||
self.optimizer = None
|
||||
self.lr_scheduler = None
|
||||
self.model_wrapped = self.model
|
||||
gc.collect() # force memory release
|
||||
# to restore normal behavior outside of train replay the place_model_on_device logic w/o deepspeed
|
||||
self.place_model_on_device = self.args.place_model_on_device
|
||||
if self.is_model_parallel:
|
||||
self.place_model_on_device = False
|
||||
|
||||
self.is_in_train = False
|
||||
|
||||
self._memory_tracker.stop_and_update_metrics(metrics)
|
||||
@ -1293,6 +1281,8 @@ class Trainer:
|
||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||
self.save_model(output_dir)
|
||||
if self.deepspeed:
|
||||
# under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
|
||||
# config `stage3_gather_fp16_weights_on_model_save` is True
|
||||
self.deepspeed.save_checkpoint(output_dir)
|
||||
|
||||
# Save optimizer and scheduler
|
||||
@ -1351,7 +1341,7 @@ class Trainer:
|
||||
return
|
||||
|
||||
if self.deepspeed:
|
||||
# deepspeed loads optimizer/lr_scheduler together with the model in init_deepspeed
|
||||
# deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
|
||||
return
|
||||
|
||||
if os.path.isfile(os.path.join(checkpoint, "optimizer.pt")) and os.path.isfile(
|
||||
@ -1597,6 +1587,10 @@ class Trainer:
|
||||
|
||||
Will only save from the main process.
|
||||
"""
|
||||
|
||||
if output_dir is None:
|
||||
output_dir = self.args.output_dir
|
||||
|
||||
if is_torch_tpu_available():
|
||||
self._save_tpu(output_dir)
|
||||
elif is_sagemaker_mp_enabled():
|
||||
@ -1608,8 +1602,31 @@ class Trainer:
|
||||
ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
|
||||
):
|
||||
state_dict = self.model.state_dict()
|
||||
|
||||
if self.is_world_process_zero():
|
||||
self._save(output_dir, state_dict=state_dict)
|
||||
elif self.deepspeed:
|
||||
|
||||
# this takes care of everything as long as we aren't under zero3
|
||||
if self.is_world_process_zero():
|
||||
self._save(output_dir)
|
||||
|
||||
if is_deepspeed_zero3_enabled():
|
||||
# It's too complicated to try to override different places where the weights dump gets
|
||||
# saved, so since under zero3 the file is bogus, simply delete it. The user should
|
||||
# either user deepspeed checkpoint to resume or to recover full weights use
|
||||
# zero_to_fp32.py stored in the checkpoint.
|
||||
if self.is_world_process_zero():
|
||||
file = os.path.join(output_dir, WEIGHTS_NAME)
|
||||
if os.path.isfile(file):
|
||||
# logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights")
|
||||
os.remove(file)
|
||||
|
||||
# now save the real model if stage3_gather_fp16_weights_on_model_save=True
|
||||
# if false it will not be saved.
|
||||
# This must be called on all ranks
|
||||
self.deepspeed.save_fp16_model(output_dir, WEIGHTS_NAME)
|
||||
|
||||
elif self.is_world_process_zero():
|
||||
self._save(output_dir)
|
||||
|
||||
@ -1848,10 +1865,20 @@ class Trainer:
|
||||
prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
|
||||
)
|
||||
|
||||
if self.args.deepspeed and not self.args.do_train:
|
||||
# no harm, but flagging to the user that deepspeed config is ignored for eval
|
||||
# flagging only for when --do_train wasn't passed as only then it's redundant
|
||||
logger.info("Detected the deepspeed argument but it will not be used for evaluation")
|
||||
# if eval is called w/o train init deepspeed here
|
||||
if self.args.deepspeed and not self.deepspeed:
|
||||
|
||||
# XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
|
||||
# from the checkpoint eventually
|
||||
deepspeed_engine, _, _ = deepspeed_init(self, num_training_steps=0, resume_from_checkpoint=None)
|
||||
self.model = deepspeed_engine.module
|
||||
self.model_wrapped = deepspeed_engine
|
||||
self.deepspeed = deepspeed_engine
|
||||
# XXX: we don't need optim/sched for inference, but this needs to be sorted out, since
|
||||
# for example the Z3-optimizer is a must for zero3 to work even for inference - what we
|
||||
# don't need is the deepspeed basic optimizer which is self.optimizer.optimizer
|
||||
deepspeed_engine.optimizer.optimizer = None
|
||||
deepspeed_engine.lr_scheduler = None
|
||||
|
||||
model = self._wrap_model(self.model, training=False)
|
||||
|
||||
|
@ -19,6 +19,7 @@ from packaging import version
|
||||
from torch import nn
|
||||
from torch.utils.data.dataset import Dataset
|
||||
|
||||
from .integrations import is_deepspeed_zero3_enabled
|
||||
from .trainer import Trainer
|
||||
from .trainer_utils import PredictionOutput
|
||||
from .utils import logging
|
||||
@ -156,9 +157,11 @@ class Seq2SeqTrainer(Trainer):
|
||||
has_labels = "labels" in inputs
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
|
||||
# XXX: adapt synced_gpus for fairscale as well
|
||||
gen_kwargs = {
|
||||
"max_length": self._max_length if self._max_length is not None else self.model.config.max_length,
|
||||
"num_beams": self._num_beams if self._num_beams is not None else self.model.config.num_beams,
|
||||
"synced_gpus": True if is_deepspeed_zero3_enabled() else False,
|
||||
}
|
||||
|
||||
generated_tokens = self.model.generate(
|
||||
|
@ -132,6 +132,7 @@ class RegressionModelConfig(PretrainedConfig):
|
||||
self.a = a
|
||||
self.b = b
|
||||
self.double_output = double_output
|
||||
self.hidden_size = 1
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
|
Loading…
Reference in New Issue
Block a user