diff --git a/examples/research_projects/wav2vec2/README.md b/examples/research_projects/wav2vec2/README.md
index c1b9f8a6adf..39bbda38118 100644
--- a/examples/research_projects/wav2vec2/README.md
+++ b/examples/research_projects/wav2vec2/README.md
@@ -127,3 +127,60 @@ logs references and predictions. Using the Buckwalter format, text is also logge
 
 `--max_duration_in_seconds="15"` filters out examples whose audio is longer than the specified limit,
 which helps with capping GPU memory usage.
+
+
+### DeepSpeed Integration
+
+To learn how to deploy the DeepSpeed integration, please refer to [this guide](https://huggingface.co/transformers/master/main_classes/deepspeed.html#deepspeed-trainer-integration).
+
+To get started quickly, however, all you need to do is install:
+```
+pip install deepspeed
+```
+and then use the default configuration files in this directory:
+
+* `ds_config_wav2vec2_zero2.json`
+* `ds_config_wav2vec2_zero3.json`
+
+Here are examples of how you can use DeepSpeed:
+
+(edit the value for `--num_gpus` to match the number of GPUs you have)
+
+ZeRO-2:
+
+```
+PYTHONPATH=../../../src deepspeed --num_gpus 2 \
+run_asr.py \
+--output_dir=output_dir --num_train_epochs=2 --per_device_train_batch_size=2 \
+--per_device_eval_batch_size=2 --evaluation_strategy=steps --save_steps=500 --eval_steps=100 \
+--logging_steps=5 --learning_rate=5e-4 --warmup_steps=3000 \
+--model_name_or_path=patrickvonplaten/wav2vec2_tiny_random_robust \
+--dataset_name=patrickvonplaten/librispeech_asr_dummy --dataset_config_name=clean \
+--train_split_name=validation --validation_split_name=validation --orthography=timit \
+--preprocessing_num_workers=1 --group_by_length --freeze_feature_extractor --verbose_logging \
+--deepspeed ds_config_wav2vec2_zero2.json
+```
+
+For ZeRO-2 with more than one GPU you need the following setting, which is already included in the example configuration file:
+```
+    "zero_optimization": {
+        ...
+        "find_unused_parameters": true,
+        ...
+    }
+```
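+
+This is needed because wav2vec2 uses LayerDrop: during training some encoder layers are randomly
+skipped, so their parameters receive no gradient on that step, and the gradient reduction across
+GPUs has to be told not to wait for them.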
+
+ZeRO-3:
+
+```
+PYTHONPATH=../../../src deepspeed --num_gpus 2 \
+run_asr.py \
+--output_dir=output_dir --num_train_epochs=2 --per_device_train_batch_size=2 \
+--per_device_eval_batch_size=2 --evaluation_strategy=steps --save_steps=500 --eval_steps=100 \
+--logging_steps=5 --learning_rate=5e-4 --warmup_steps=3000 \
+--model_name_or_path=patrickvonplaten/wav2vec2_tiny_random_robust \
+--dataset_name=patrickvonplaten/librispeech_asr_dummy --dataset_config_name=clean \
+--train_split_name=validation --validation_split_name=validation --orthography=timit \
+--preprocessing_num_workers=1 --group_by_length --freeze_feature_extractor --verbose_logging \
+--deepspeed ds_config_wav2vec2_zero3.json
+```
diff --git a/examples/research_projects/wav2vec2/ds_config_wav2vec2_zero2.json b/examples/research_projects/wav2vec2/ds_config_wav2vec2_zero2.json
new file mode 100644
index 00000000000..6745e9917a3
--- /dev/null
+++ b/examples/research_projects/wav2vec2/ds_config_wav2vec2_zero2.json
@@ -0,0 +1,51 @@
+{
+    "fp16": {
+        "enabled": "auto",
+        "loss_scale": 0,
+        "loss_scale_window": 1000,
+        "initial_scale_power": 16,
+        "hysteresis": 2,
+        "min_loss_scale": 1
+    },
+
+    "optimizer": {
+        "type": "AdamW",
+        "params": {
+            "lr": "auto",
+            "betas": "auto",
+            "eps": "auto",
+            "weight_decay": "auto"
+        }
+    },
+
+    "scheduler": {
+        "type": "WarmupLR",
+        "params": {
+            "warmup_min_lr": "auto",
+            "warmup_max_lr": "auto",
+            "warmup_num_steps": "auto"
+        }
+    },
+
+    "zero_optimization": {
+        "stage": 2,
+        "offload_optimizer": {
+            "device": "cpu",
+            "pin_memory": true
+        },
+        "find_unused_parameters": true,
+        "allgather_partitions": true,
+        "allgather_bucket_size": 2e8,
+        "overlap_comm": true,
+        "reduce_scatter": true,
+        "reduce_bucket_size": 2e8,
+        "contiguous_gradients": true
+    },
+
+    "gradient_accumulation_steps": "auto",
+    "gradient_clipping": "auto",
+    "steps_per_print": 2000,
+    "train_batch_size": "auto",
+    "train_micro_batch_size_per_gpu": "auto",
+    "wall_clock_breakdown": false
+}
diff --git a/examples/research_projects/wav2vec2/ds_config_wav2vec2_zero3.json b/examples/research_projects/wav2vec2/ds_config_wav2vec2_zero3.json
new file mode 100644
index 00000000000..a80a173b7a9
--- /dev/null
+++ b/examples/research_projects/wav2vec2/ds_config_wav2vec2_zero3.json
@@ -0,0 +1,57 @@
+{
+    "fp16": {
+        "enabled": "auto",
+        "loss_scale": 0,
+        "loss_scale_window": 1000,
+        "initial_scale_power": 16,
+        "hysteresis": 2,
+        "min_loss_scale": 1
+    },
+
+    "optimizer": {
+        "type": "AdamW",
+        "params": {
+            "lr": "auto",
+            "betas": "auto",
+            "eps": "auto",
+            "weight_decay": "auto"
+        }
+    },
+
+    "scheduler": {
+        "type": "WarmupLR",
+        "params": {
+            "warmup_min_lr": "auto",
+            "warmup_max_lr": "auto",
+            "warmup_num_steps": "auto"
+        }
+    },
+
+    "zero_optimization": {
+        "stage": 3,
+        "offload_optimizer": {
+            "device": "cpu",
+            "pin_memory": true
+        },
+        "offload_param": {
+            "device": "cpu",
+            "pin_memory": true
+        },
+        "overlap_comm": true,
+        "contiguous_gradients": true,
+        "sub_group_size": 1e9,
+        "reduce_bucket_size": "auto",
+        "stage3_prefetch_bucket_size": "auto",
+        "stage3_param_persistence_threshold": "auto",
+        "stage3_max_live_parameters": 1e9,
+        "stage3_max_reuse_distance": 1e9,
+        "stage3_gather_fp16_weights_on_model_save": true
+    },
+
+    "gradient_accumulation_steps": "auto",
+    "gradient_clipping": "auto",
+    "steps_per_print": 2000,
+    "train_batch_size": "auto",
+    "train_micro_batch_size_per_gpu": "auto",
+    "wall_clock_breakdown": false
+}
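Note that most values in the two config files are set to `"auto"`. These placeholders are resolved at run time: `HfTrainerDeepSpeedConfig.fill_match` (see the `src/transformers/deepspeed.py` changes below) replaces each `"auto"` with the matching `TrainingArguments` value, addressing nested keys with dotted paths, so that, for example, `optimizer.params.lr` is filled from `--learning_rate` and `train_micro_batch_size_per_gpu` from `--per_device_train_batch_size`. The command line thus remains the single source of truth and cannot silently disagree with the config file.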
diff --git a/examples/research_projects/wav2vec2/test_wav2vec2_deepspeed.py b/examples/research_projects/wav2vec2/test_wav2vec2_deepspeed.py
new file mode 100644
index 00000000000..0580d1c4b12
--- /dev/null
+++ b/examples/research_projects/wav2vec2/test_wav2vec2_deepspeed.py
@@ -0,0 +1,185 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# XXX: we want transformers master here - in the absence of a conftest manipulating sys.path:
+# hack it in for now:
+import sys
+from pathlib import Path
+
+
+git_repo_path = Path(__file__).resolve().parents[3] / "src"
+sys.path.insert(1, str(git_repo_path))
+
+import dataclasses  # noqa
+import io  # noqa
+import json  # noqa
+import os  # noqa
+import unittest  # noqa
+from copy import deepcopy  # noqa
+
+from parameterized import parameterized  # noqa
+from transformers import TrainingArguments, is_torch_available  # noqa
+from transformers.deepspeed import is_deepspeed_available  # noqa
+from transformers.file_utils import WEIGHTS_NAME  # noqa
+from transformers.testing_utils import (  # noqa
+    CaptureLogger,
+    ExtendSysPath,
+    TestCasePlus,
+    execute_subprocess_async,
+    get_gpu_count,
+    mockenv_context,
+    require_deepspeed,
+    require_torch_gpu,
+    require_torch_multi_gpu,
+    slow,
+)
+from transformers.trainer_utils import set_seed  # noqa
+
+
+set_seed(42)
+
+WAV2VEC2_TINY = "patrickvonplaten/wav2vec2_tiny_random_robust"
+
+
+ZERO2 = "zero2"
+ZERO3 = "zero3"
+stages = [ZERO2, ZERO3]
+
+
+@slow
+@require_deepspeed
+@require_torch_gpu
+class TestDeepSpeedWav2Vec2(TestCasePlus):
+    @parameterized.expand(stages)
+    def test_fp32_non_distributed(self, stage):
+        self.run_and_check(
+            stage=stage,
+            distributed=False,
+            fp16=False,
+        )
+
+    @require_torch_multi_gpu
+    @parameterized.expand(stages)
+    def test_fp32_distributed(self, stage):
+        self.run_and_check(
+            stage=stage,
+            distributed=True,
+            fp16=False,
+        )
+
+    @parameterized.expand(stages)
+    def test_fp16_non_distributed(self, stage):
+        self.run_and_check(
+            stage=stage,
+            distributed=False,
+            fp16=True,
+        )
+
+    @require_torch_multi_gpu
+    @parameterized.expand(stages)
+    def test_fp16_distributed(self, stage):
+        self.run_and_check(
+            stage=stage,
+            distributed=True,
+            fp16=True,
+        )
+
+    def do_checks(self, output_dir):
+        # XXX: run_asr is premature and doesn't save any results
+        # so all we check for now is that the process didn't fail
+        pass
+
+    # XXX: need to do better validation beyond just that the run was successful
+    def run_and_check(
+        self,
+        stage,
+        model_name: str = WAV2VEC2_TINY,
+        eval_steps: int = 10,
+        distributed: bool = True,
+        quality_checks: bool = True,
+        fp16: bool = True,
+    ):
+
+        output_dir = self.run_trainer(
+            stage=stage,
+            model_name=model_name,
+            eval_steps=eval_steps,
+            num_train_epochs=1,
+            distributed=distributed,
+            fp16=fp16,
+        )
+
+        self.do_checks(output_dir)
+
+        return output_dir
+
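+    # run_trainer builds the complete invocation: the deepspeed launcher from get_launcher(),
+    # the run_asr.py script, the HF Trainer arguments, and the --deepspeed config file for the
+    # requested ZeRO stage, then executes it as an external subprocess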
+    def run_trainer(
+        self,
+        stage: str,
+        model_name: str,
+        eval_steps: int = 10,
+        num_train_epochs: int = 1,
+        distributed: bool = True,
+        fp16: bool = True,
+    ):
+
+        output_dir = self.get_auto_remove_tmp_dir()
+        args = f"""
+            --model_name_or_path {model_name}
+            --dataset_name patrickvonplaten/librispeech_asr_dummy
+            --dataset_config_name clean
+            --train_split_name validation
+            --validation_split_name validation
+            --output_dir {output_dir}
+            --num_train_epochs {str(num_train_epochs)}
+            --per_device_train_batch_size 2
+            --per_device_eval_batch_size 2
+            --evaluation_strategy steps
+            --learning_rate 5e-4
+            --warmup_steps 8
+            --orthography timit
+            --preprocessing_num_workers 1
+            --group_by_length
+            --freeze_feature_extractor
+            --report_to none
+            --logging_steps 0
+            --save_steps 0
+            --eval_steps {eval_steps}
+        """.split()
+
+        if fp16:
+            args.extend(["--fp16"])
+
+        # currently ds_config_wav2vec2_zero2.json requires "zero_optimization.find_unused_parameters": true,
+        # hence the separate config files
+        ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_wav2vec2_{stage}.json".split()
+        script = [f"{self.examples_dir_str}/research_projects/wav2vec2/run_asr.py"]
+        launcher = self.get_launcher(distributed)
+
+        cmd = launcher + script + args + ds_args
+        # keep for quick debug
+        # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] + cmd)); die
+        execute_subprocess_async(cmd, env=self.get_env())
+
+        return output_dir
+
+    def get_launcher(self, distributed=False):
+        # 1. explicitly set --num_nodes=1 just in case these tests end up run on a multi-node setup
+        #    - it won't be able to handle that
+        # 2. for now testing with just 2 gpus max (since some quality tests may give different
+        #    results with more gpus because we use very little data)
+        num_gpus = min(2, get_gpu_count()) if distributed else 1
+        return f"deepspeed --num_nodes 1 --num_gpus {num_gpus}".split()
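Note that the test class is decorated with `@slow`, so these tests are skipped unless `RUN_SLOW=1` is set in the environment, e.g. `RUN_SLOW=1 pytest examples/research_projects/wav2vec2/test_wav2vec2_deepspeed.py -sv`.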
diff --git a/setup.py b/setup.py
index 4fe1672e0a1..b8ed916b0e0 100644
--- a/setup.py
+++ b/setup.py
@@ -90,7 +90,7 @@ _deps = [
     "cookiecutter==1.7.2",
     "dataclasses",
     "datasets",
-    "deepspeed>=0.3.16",
+    "deepspeed>=0.4.0",
     "docutils==0.16.0",
     "fairscale>0.3",
     "faiss-cpu",
diff --git a/src/transformers/deepspeed.py b/src/transformers/deepspeed.py
index 7e91fc6d08d..4fe293dad76 100644
--- a/src/transformers/deepspeed.py
+++ b/src/transformers/deepspeed.py
@@ -23,9 +23,13 @@ from copy import deepcopy
 from functools import partialmethod
 
 from .dependency_versions_check import dep_version_check
+from .file_utils import is_torch_available
 from .utils import logging
 
 
+if is_torch_available():
+    import torch
+
 logger = logging.get_logger(__name__)
 
 
@@ -70,46 +74,86 @@ class HfDeepSpeedConfig:
         # zero stage - this is done as early as possible, before model is created, to allow
         # ``is_deepspeed_zero3_enabled`` query and getting to the early deepspeed config object
         # during ``zero.Init()`` which needs whether fp16 is enabled, dtype, etc.
-        config_zero = config.get("zero_optimization", {})
-        self.stage = config_zero.get("stage", 0)
+        self._stage = self.get_value("zero_optimization.stage", -1)
 
         # offload
-        self.offload = False
-        config_zero = config.get("zero_optimization", {})
+        self._offload = False
         if self.is_zero2() or self.is_zero3():
-            offload_devices = ["cpu", "nvme"]
-            if config_zero.get("offload_optimizer", {}).get("device") in offload_devices:
-                self.offload = True
-            if config_zero.get("offload_param", {}).get("device") in offload_devices:
-                self.offload = True
+            offload_devices_valid = set(["cpu", "nvme"])
+            offload_devices = set(
+                [
+                    self.get_value("zero_optimization.offload_optimizer.device"),
+                    self.get_value("zero_optimization.offload_param.device"),
+                ]
+            )
+            if len(offload_devices & offload_devices_valid) > 0:
+                self._offload = True
+
+    def find_config_node(self, ds_key_long):
+        config = self.config
+
+        # find the config node of interest if it exists
+        nodes = ds_key_long.split(".")
+        ds_key = nodes.pop()
+        for node in nodes:
+            config = config.get(node)
+            if config is None:
+                return None, ds_key
+
+        return config, ds_key
+
+    def get_value(self, ds_key_long, default=None):
+        """
+        Returns the set value or ``default`` if no value is set
+        """
+        config, ds_key = self.find_config_node(ds_key_long)
+        if config is None:
+            return default
+        return config.get(ds_key, default)
+
+    # for example, given a config of
+    #   {"zero_optimization": {"stage": 3, "offload_param": {"device": "cpu"}}}
+    # the dotted-path accessors behave as follows:
+    #   get_value("zero_optimization.stage")                     -> 3
+    #   get_value("zero_optimization.offload_optimizer.device")  -> None (node absent)
+    #   is_true("zero_optimization.offload_param.pin_memory")    -> False (key not set)
+
+    def is_true(self, ds_key_long):
+        """
+        Returns :obj:`True`/:obj:`False` only if the value is set, always :obj:`False` otherwise. So use this method to
+        ask the very specific question of whether the value is set to :obj:`True` (and it's not set to :obj:`False` or
+        isn't set).
+        """
+        value = self.get_value(ds_key_long)
+        return False if value is None else bool(value)
+
+    def is_false(self, ds_key_long):
+        """
+        Returns :obj:`True`/:obj:`False` only if the value is set, always :obj:`False` otherwise. So use this method to
+        ask the very specific question of whether the value is set to :obj:`False` (and it's not set to :obj:`True` or
+        isn't set).
+        """
+        value = self.get_value(ds_key_long)
+        return False if value is None else not bool(value)
 
     def is_zero2(self):
-        return self.stage == 2
+        return self._stage == 2
 
     def is_zero3(self):
-        return self.stage == 3
+        return self._stage == 3
 
     def is_offload(self):
-        return self.offload
-
-    @staticmethod
-    def is_true(config, key):
-        if config is None:
-            return False
-        return bool(config.get(key))
+        return self._offload
 
 
 class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig):
     """
     The ``HfTrainerDeepSpeedConfig`` object is meant to be created during ``TrainingArguments`` object creation and
     has the same lifespan as the latter.
-
     """
 
     def __init__(self, config_file_or_dict):
         super().__init__(config_file_or_dict)
+        self._dtype = torch.float16
         self.mismatches = []
 
+    def dtype(self):
+        return self._dtype
+
     def fill_match(self, ds_key_long, hf_val, hf_key=None, must_match=True):
         """
         A utility method that massages the config file and can optionally verify that the values match.
@@ -121,16 +165,9 @@ class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig):
         ``trainer_config_finalize`` for one or more mismatches.
         """
-
-        config = self.config
-
-        # find the config node of interest if it exists
-        nodes = ds_key_long.split(".")
-        ds_key = nodes.pop()
-        for node in nodes:
-            config = config.get(node)
-            if config is None:
-                return
+        config, ds_key = self.find_config_node(ds_key_long)
+        if config is None:
+            return
 
         if config.get(ds_key) == "auto":
             config[ds_key] = hf_val
@@ -185,6 +222,13 @@ class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig):
         self.fill_match("amp.enabled", fp16_backend == "apex", "fp16+fp16_backend(apex)")
         self.fill_match("amp.opt_level", args.fp16_opt_level, "fp16_opt_level")
 
+        # fp32 is used only if fp16.enabled is explicitly set to False; if it is True, or the
+        # whole fp16 config section is missing, the fallback is fp16
+        if self.is_false("fp16.enabled"):
+            self._dtype = torch.float32
+        # later there will be other dtypes besides just fp16 and fp32
+        # also not quite sure what dtype should be under apex, defaulting to fp16 for now
+
     def trainer_config_finalize(self, args, model, num_training_steps):
         """
         This stage is run after we have the model and know num_training_steps.
diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py
index 98267b3a3e7..ec055db25bd 100644
--- a/src/transformers/dependency_versions_table.py
+++ b/src/transformers/dependency_versions_table.py
@@ -7,7 +7,7 @@ deps = {
     "cookiecutter": "cookiecutter==1.7.2",
     "dataclasses": "dataclasses",
     "datasets": "datasets",
-    "deepspeed": "deepspeed>=0.3.16",
+    "deepspeed": "deepspeed>=0.4.0",
     "docutils": "docutils==0.16.0",
     "fairscale": "fairscale>0.3",
     "faiss-cpu": "faiss-cpu",
diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
index 73165c8fb67..894a49cd8c0 100755
--- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -23,6 +23,8 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 
+from transformers.deepspeed import is_deepspeed_zero3_enabled
+
 from ...activations import ACT2FN
 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
 from ...modeling_outputs import BaseModelOutput, CausalLMOutput, MaskedLMOutput
@@ -193,7 +195,17 @@ class Wav2Vec2PositionalConvEmbedding(nn.Module):
             padding=config.num_conv_pos_embeddings // 2,
             groups=config.num_conv_pos_embedding_groups,
         )
-        self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+
+        if is_deepspeed_zero3_enabled():
+            import deepspeed
+
+            with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
+                self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+            # weight_norm replaces ``weight`` with the derived ``weight_v``/``weight_g`` pair;
+            # registering those as external parameters tells ZeRO-3 to gather them for this
+            # module's forward pass
+            deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
+            deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
+        else:
+            self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+
         self.padding = Wav2Vec2SamePadLayer(config.num_conv_pos_embeddings)
         self.activation = ACT2FN[config.feat_extract_activation]
 
@@ -615,15 +627,19 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
         hidden_states = hidden_states + position_embeddings
         hidden_states = self.dropout(hidden_states)
 
+        deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
+
         for layer in self.layers:
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             dropout_probability = np.random.uniform(0, 1)
-            if self.training and (dropout_probability < self.config.layerdrop):  # skip the layer
-                layer_outputs = (None, None)
-            else:
+
+            skip_the_layer = self.training and (dropout_probability < self.config.layerdrop)
+            if not skip_the_layer or deepspeed_zero3_is_enabled:
+                # under deepspeed zero3 all gpus must run in sync
+                # XXX: could optimize this like synced_gpus in generation_utils, but not sure if it's worth the code complication
                 if getattr(self.config, "gradient_checkpointing", False) and self.training:
                     # create gradient checkpointing function
                     def create_custom_forward(module):
@@ -643,6 +659,9 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
                 )
                 hidden_states = layer_outputs[0]
 
+            if skip_the_layer:
+                layer_outputs = (None, None)
+
             if output_attentions:
                 all_self_attentions = all_self_attentions + (layer_outputs[1],)
 
@@ -680,7 +699,18 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
         elif isinstance(module, nn.Conv1d):
-            torch.nn.init.kaiming_normal_(module.weight.data)
+            if is_deepspeed_zero3_enabled():
+                import deepspeed
+
+                if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
+                    with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
+                        torch.nn.init.kaiming_normal_(module.weight.data)
+                else:
+                    with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
+                        torch.nn.init.kaiming_normal_(module.weight.data)
+            else:
+                torch.nn.init.kaiming_normal_(module.weight.data)
+
         if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
             module.bias.data.zero_()
 
@@ -1061,7 +1091,8 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
             target_lengths = labels_mask.sum(-1)
             flattened_targets = labels.masked_select(labels_mask)
 
-            log_probs = F.log_softmax(logits, dim=-1).transpose(0, 1)
+            # ctc_loss doesn't support fp16
+            log_probs = F.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
 
             with torch.backends.cudnn.flags(enabled=False):
                 loss = F.ctc_loss(
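The `dtype=torch.float32` argument to `log_softmax` is what makes the CTC loss usable under fp16 (and therefore under the fp16 DeepSpeed configs above): the logits arrive in half precision, which `F.ctc_loss` does not support, so the log-probabilities are computed directly in fp32; this also sidesteps the reduced numerical range of a half-precision softmax.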
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 8cd90ad5736..1e586729615 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -28,6 +28,7 @@ from typing import Iterator, Union
 
 from transformers import logging as transformers_logging
 
+from .deepspeed import is_deepspeed_available
 from .file_utils import (
     is_datasets_available,
     is_faiss_available,
@@ -454,6 +455,16 @@ def require_soundfile(test_case):
     return test_case
 
 
+def require_deepspeed(test_case):
+    """
+    Decorator marking a test that requires deepspeed
+    """
+    if not is_deepspeed_available():
+        return unittest.skip("test requires deepspeed")(test_case)
+    else:
+        return test_case
+
+
 def get_gpu_count():
     """
     Return the number of available gpus (regardless of whether torch or tf is used)
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 69fb09b9988..8303fef2d2a 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1701,7 +1701,14 @@ class Trainer:
         """
         for k, v in inputs.items():
             if isinstance(v, torch.Tensor):
-                inputs[k] = v.to(self.args.device)
+                kwargs = dict(device=self.args.device)
+                if self.deepspeed and inputs[k].dtype != torch.int64:
+                    # NLP models' inputs are int64 and those get adjusted to the right dtype of the
+                    # embedding. Other models such as wav2vec2 have float inputs, which may need
+                    # special handling to match the dtype of the model
+                    kwargs.update(dict(dtype=self.args.hf_deepspeed_config.dtype()))
+
+                inputs[k] = v.to(**kwargs)
 
         if self.args.past_index >= 0 and self._past is not None:
             inputs["mems"] = self._past
diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py
index 5f8cab68003..74a2928c3ec 100644
--- a/tests/deepspeed/test_deepspeed.py
+++ b/tests/deepspeed/test_deepspeed.py
@@ -32,6 +32,7 @@ from transformers.testing_utils import (
     execute_subprocess_async,
     get_gpu_count,
     mockenv_context,
+    require_deepspeed,
     require_torch_gpu,
     require_torch_multi_gpu,
     slow,
@@ -58,17 +59,6 @@ def load_json(path):
         return json.load(f)
 
 
-# a candidate for testing_utils
-def require_deepspeed(test_case):
-    """
-    Decorator marking a test that requires deepspeed
-    """
-    if not is_deepspeed_available():
-        return unittest.skip("test requires deepspeed")(test_case)
-    else:
-        return test_case
-
-
 def require_deepspeed_aio(test_case):
     """
     Decorator marking a test that requires deepspeed aio (nvme)
@@ -404,15 +394,19 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
         train_len = 64
         a = b = 0.0
 
+        kwargs = dict(
+            a=a,
+            b=b,
+            local_rank=0,
+            train_len=train_len,
+            fp16=True,
+            deepspeed=self.get_config_dict(stage),
+        )
+
         with mockenv_context(**self.dist_env_1_gpu):
             no_grad_accum_trainer = get_regression_trainer(
-                a=a,
-                b=b,
-                local_rank=0,
-                train_len=train_len,
-                fp16=True,
-                deepspeed=self.get_config_dict(stage),
-                per_device_train_batch_size=8,
+                **kwargs,
+                per_device_train_batch_size=16,
                 gradient_accumulation_steps=1,
             )
            no_grad_accum_result = no_grad_accum_trainer.train()
@@ -424,14 +418,9 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
 
         with mockenv_context(**self.dist_env_1_gpu):
             yes_grad_accum_trainer = get_regression_trainer(
-                a=a,
-                b=b,
-                local_rank=0,
-                train_len=train_len,
-                fp16=True,
-                deepspeed=self.get_config_dict(stage),
+                **kwargs,
                 per_device_train_batch_size=4,
-                gradient_accumulation_steps=2,
+                gradient_accumulation_steps=4,
             )
             yes_grad_accum_result = yes_grad_accum_trainer.train()
             yes_grad_accum_loss = yes_grad_accum_result.training_loss
@@ -445,7 +434,7 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
         self.assertAlmostEqual(no_grad_accum_b, yes_grad_accum_b, places=5)
 
         # see the note above how to get identical loss on a small bs
-        self.assertAlmostEqual(no_grad_accum_loss, yes_grad_accum_loss, places=5)
+        self.assertAlmostEqual(no_grad_accum_loss, yes_grad_accum_loss, places=2)
 
     def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage):
         # adapted from TrainerIntegrationCommon.check_saved_checkpoints
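A note on the loss comparison in the last hunk: both trainers now see an effective batch of 16 samples per step (16 x 1 versus 4 x 4 with accumulation), so their losses are directly comparable, while the tolerance is loosened from 5 to 2 decimal places, presumably because fp16 training accumulates enough rounding differences between the two code paths that closer agreement cannot be expected.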