[Deepspeed Wav2vec2] integration (#11638)

* wip

* wip - but working with https://github.com/microsoft/DeepSpeed/pull/1044

* cleanup

* workaround

* working 5/8 modes

* solve fp32 distributed zero3

* style

* sync

* sync

* rework

* deprecation

* cleanup

* https://github.com/microsoft/DeepSpeed/pull/1044 pr was merged

* clean up

* add a guide

* more prose

* more prose

* fix

* more prose

* sub_group_size was too big

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* refactor

* bug fix

* make the true check explicit

* new deepspeed release

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Stas Bekman 2021-06-08 12:32:03 -07:00 committed by GitHub
parent 32290d87f6
commit 11d86d3de4
11 changed files with 496 additions and 64 deletions

@@ -127,3 +127,60 @@ logs references and predictions. Using the Buckwalter format, text is also logge
`--max_duration_in_seconds="15"` filters out examples whose audio is longer than the specified limit,
which helps with capping GPU memory usage.
### DeepSpeed Integration
To learn how to deploy the DeepSpeed integration please refer to [this guide](https://huggingface.co/transformers/master/main_classes/deepspeed.html#deepspeed-trainer-integration).
But to get started quickly, all you need to do is install:
```
pip install deepspeed
```
and then use the default configuration files in this directory:
* `ds_config_wav2vec2_zero2.json`
* `ds_config_wav2vec2_zero3.json`
Here are examples of how you can use DeepSpeed:
(edit the value for `--num_gpus` to match the number of GPUs you have)
ZeRO-2:
```
PYTHONPATH=../../../src deepspeed --num_gpus 2 \
run_asr.py \
--output_dir=output_dir --num_train_epochs=2 --per_device_train_batch_size=2 \
--per_device_eval_batch_size=2 --evaluation_strategy=steps --save_steps=500 --eval_steps=100 \
--logging_steps=5 --learning_rate=5e-4 --warmup_steps=3000 \
--model_name_or_path=patrickvonplaten/wav2vec2_tiny_random_robust \
--dataset_name=patrickvonplaten/librispeech_asr_dummy --dataset_config_name=clean \
--train_split_name=validation --validation_split_name=validation --orthography=timit \
--preprocessing_num_workers=1 --group_by_length --freeze_feature_extractor --verbose_logging \
--deepspeed ds_config_wav2vec2_zero2.json
```
For ZeRO-2 with more than one GPU you need to use the following setting (it is already included in the example configuration file):
```
"zero_optimization": {
...
"find_unused_parameters": true,
...
}
```
ZeRO-3:
```
PYTHONPATH=../../../src deepspeed --num_gpus 2 \
run_asr.py \
--output_dir=output_dir --num_train_epochs=2 --per_device_train_batch_size=2 \
--per_device_eval_batch_size=2 --evaluation_strategy=steps --save_steps=500 --eval_steps=100 \
--logging_steps=5 --learning_rate=5e-4 --warmup_steps=3000 \
--model_name_or_path=patrickvonplaten/wav2vec2_tiny_random_robust \
--dataset_name=patrickvonplaten/librispeech_asr_dummy --dataset_config_name=clean \
--train_split_name=validation --validation_split_name=validation --orthography=timit \
--preprocessing_num_workers=1 --group_by_length --freeze_feature_extractor --verbose_logging \
--deepspeed ds_config_wav2vec2_zero3.json
```
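
Both default configuration files make heavy use of `auto` placeholders, which the `Trainer` fills in with the matching command-line values (e.g. `--learning_rate`, `--warmup_steps`, the batch sizes) before handing the config over to DeepSpeed. A minimal sketch of the idea, with illustrative values only (this is not the actual substitution code):

```
# hedged sketch: how "auto" entries get resolved from the HF Trainer arguments
training_args = {"learning_rate": 5e-4, "warmup_steps": 3000}

ds_scheduler_params = {"warmup_max_lr": "auto", "warmup_num_steps": "auto"}
hf_to_ds = {"warmup_max_lr": "learning_rate", "warmup_num_steps": "warmup_steps"}

for ds_key, hf_key in hf_to_ds.items():
    if ds_scheduler_params[ds_key] == "auto":
        ds_scheduler_params[ds_key] = training_args[hf_key]

print(ds_scheduler_params)  # {'warmup_max_lr': 0.0005, 'warmup_num_steps': 3000}
```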

@@ -0,0 +1,51 @@
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"find_unused_parameters": true,
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"contiguous_gradients": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}

@@ -0,0 +1,57 @@
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_fp16_weights_on_model_save": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}

@@ -0,0 +1,185 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# XXX: we want transformers master here - in the absence of conftest manipulating sys.path:
# hack it in for now:
import sys
from pathlib import Path
git_repo_path = Path(__file__).resolve().parents[3] / "src"
sys.path.insert(1, str(git_repo_path))
import dataclasses # noqa
import io # noqa
import json # noqa
import os # noqa
import unittest # noqa
from copy import deepcopy # noqa
from parameterized import parameterized # noqa
from transformers import TrainingArguments, is_torch_available # noqa
from transformers.deepspeed import is_deepspeed_available # noqa
from transformers.file_utils import WEIGHTS_NAME # noqa
from transformers.testing_utils import ( # noqa
CaptureLogger,
ExtendSysPath,
TestCasePlus,
execute_subprocess_async,
get_gpu_count,
mockenv_context,
require_deepspeed,
require_torch_gpu,
require_torch_multi_gpu,
slow,
)
from transformers.trainer_utils import set_seed # noqa
set_seed(42)
WAV2VEC2_TINY = "patrickvonplaten/wav2vec2_tiny_random_robust"
ZERO2 = "zero2"
ZERO3 = "zero3"
stages = [ZERO2, ZERO3]
@slow
@require_deepspeed
@require_torch_gpu
class TestDeepSpeedWav2Vec2(TestCasePlus):
@parameterized.expand(stages)
def test_fp32_non_distributed(self, stage):
self.run_and_check(
stage=stage,
distributed=False,
fp16=False,
)
@require_torch_multi_gpu
@parameterized.expand(stages)
def test_fp32_distributed(self, stage):
self.run_and_check(
stage=stage,
distributed=True,
fp16=False,
)
@parameterized.expand(stages)
def test_fp16_non_distributed(self, stage):
self.run_and_check(
stage=stage,
distributed=False,
fp16=True,
)
@require_torch_multi_gpu
@parameterized.expand(stages)
def test_fp16_distributed(self, stage):
self.run_and_check(
stage=stage,
distributed=True,
fp16=True,
)
def do_checks(self, output_dir):
# XXX: run_asr is premature and doesn't save any results
# so all we check for now is that the process didn't fail
pass
# XXX: need to do better validation beyond just that the run was successful
def run_and_check(
self,
stage,
model_name: str = WAV2VEC2_TINY,
eval_steps: int = 10,
distributed: bool = True,
quality_checks: bool = True,
fp16: bool = True,
):
output_dir = self.run_trainer(
stage=stage,
model_name=model_name,
eval_steps=eval_steps,
num_train_epochs=1,
distributed=distributed,
fp16=fp16,
)
self.do_checks(output_dir)
return output_dir
def run_trainer(
self,
stage: str,
model_name: str,
eval_steps: int = 10,
num_train_epochs: int = 1,
distributed: bool = True,
fp16: bool = True,
):
output_dir = self.get_auto_remove_tmp_dir("./xxx", after=False)
args = f"""
--model_name_or_path {model_name}
--dataset_name patrickvonplaten/librispeech_asr_dummy
--dataset_config_name clean
--train_split_name validation
--validation_split_name validation
--output_dir {output_dir}
--num_train_epochs {str(num_train_epochs)}
--per_device_train_batch_size 2
--per_device_eval_batch_size 2
--evaluation_strategy steps
--learning_rate 5e-4
--warmup_steps 8
--orthography timit
--preprocessing_num_workers 1
--group_by_length
--freeze_feature_extractor
--report_to none
--logging_steps 0
--save_steps 0
--eval_steps {eval_steps}
--report_to none
""".split()
if fp16:
args.extend(["--fp16"])
# currently ds_config_wav2vec2_zero2.json requires "zero_optimization.find_unused_parameters": true,
# hence the separate config files
ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_wav2vec2_{stage}.json".split()
script = [f"{self.examples_dir_str}/research_projects/wav2vec2/run_asr.py"]
launcher = self.get_launcher(distributed)
cmd = launcher + script + args + ds_args
# keep for quick debug
# print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
execute_subprocess_async(cmd, env=self.get_env())
return output_dir
def get_launcher(self, distributed=False):
# 1. explicitly set --num_nodes=1 just in case these tests end up being run on a multi-node setup
# - it won't be able to handle that
# 2. for now testing with just 2 gpus max (since some quality tests may give different
# results with more gpus because we use very little data)
num_gpus = min(2, get_gpu_count()) if distributed else 1
return f"deepspeed --num_nodes 1 --num_gpus {num_gpus}".split()

@@ -90,7 +90,7 @@ _deps = [
"cookiecutter==1.7.2",
"dataclasses",
"datasets",
"deepspeed>=0.3.16",
"deepspeed>=0.4.0",
"docutils==0.16.0",
"fairscale>0.3",
"faiss-cpu",

@@ -23,9 +23,13 @@ from copy import deepcopy
from functools import partialmethod
from .dependency_versions_check import dep_version_check
from .file_utils import is_torch_available
from .utils import logging
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
@@ -70,46 +74,86 @@ class HfDeepSpeedConfig:
# zero stage - this is done as early as possible, before model is created, to allow
# ``is_deepspeed_zero3_enabled`` query and getting to the early deepspeed config object
# during ``zero.Init()`` which needs whether fp16 is enabled, dtype, etc.
config_zero = config.get("zero_optimization", {})
self.stage = config_zero.get("stage", 0)
self._stage = self.get_value("zero_optimization.stage", -1)
# offload
self.offload = False
config_zero = config.get("zero_optimization", {})
self._offload = False
if self.is_zero2() or self.is_zero3():
offload_devices = ["cpu", "nvme"]
if config_zero.get("offload_optimizer", {}).get("device") in offload_devices:
self.offload = True
if config_zero.get("offload_param", {}).get("device") in offload_devices:
self.offload = True
offload_devices_valid = set(["cpu", "nvme"])
offload_devices = set(
[
self.get_value("zero_optimization.offload_optimizer.device"),
self.get_value("zero_optimization.offload_param.device"),
]
)
if len(offload_devices & offload_devices_valid) > 0:
self._offload = True
def find_config_node(self, ds_key_long):
config = self.config
# find the config node of interest if it exists
nodes = ds_key_long.split(".")
ds_key = nodes.pop()
for node in nodes:
config = config.get(node)
if config is None:
return None, ds_key
return config, ds_key
def get_value(self, ds_key_long, default=None):
"""
Returns the set value or ``default`` if no value is set
"""
config, ds_key = self.find_config_node(ds_key_long)
if config is None:
return default
return config.get(ds_key, default)
def is_true(self, ds_key_long):
"""
Returns :obj:`True` only if the value is set to a truthy value, and :obj:`False` otherwise (whether the value is
set to :obj:`False` or not set at all). Use this method to ask the very specific question of whether the value is
explicitly set to :obj:`True`.
"""
value = self.get_value(ds_key_long)
return False if value is None else bool(value)
def is_false(self, ds_key_long):
"""
Returns :obj:`True` only if the value is set to a falsy value, and :obj:`False` otherwise (including when the
value is not set at all). Use this method to ask the very specific question of whether the value is explicitly
set to :obj:`False`.
"""
value = self.get_value(ds_key_long)
return False if value is None else not bool(value)
def is_zero2(self):
return self.stage == 2
return self._stage == 2
def is_zero3(self):
return self.stage == 3
return self._stage == 3
def is_offload(self):
return self.offload
@staticmethod
def is_true(config, key):
if config is None:
return False
return bool(config.get(key))
return self._offload
class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig):
"""
The ``HfTrainerDeepSpeedConfig`` object is meant to be created during ``TrainingArguments`` object creation and has
the same lifespan as the latter.
"""
def __init__(self, config_file_or_dict):
super().__init__(config_file_or_dict)
self._dtype = torch.float16
self.mismatches = []
def dtype(self):
return self._dtype
def fill_match(self, ds_key_long, hf_val, hf_key=None, must_match=True):
"""
A utility method that massages the config file and can optionally verify that the values match.
@@ -121,16 +165,9 @@ class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig):
``trainer_config_finalize`` for one or more mismatches.
"""
config = self.config
# find the config node of interest if it exists
nodes = ds_key_long.split(".")
ds_key = nodes.pop()
for node in nodes:
config = config.get(node)
if config is None:
return
config, ds_key = self.find_config_node(ds_key_long)
if config is None:
return
if config.get(ds_key) == "auto":
config[ds_key] = hf_val
@@ -185,6 +222,13 @@ class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig):
self.fill_match("amp.enabled", fp16_backend == "apex", "fp16+fp16_backend(apex)")
self.fill_match("amp.opt_level", args.fp16_opt_level, "fp16_opt_level")
# only if there is an explicit fp16.enabled = False is it fp32; if it's True or the whole
# config section is missing, then the fallback is fp16
if self.is_false("fp16.enabled"):
self._dtype = torch.float32
# later there will be other dtypes besides just fp16 and fp32
# also not quite sure what dtype should be under apex, defaulting to fp16 for now
def trainer_config_finalize(self, args, model, num_training_steps):
"""
This stage is run after we have the model and know num_training_steps.
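
The dotted-key helpers added above (`find_config_node`, `get_value`, `is_true`, `is_false`) replace the previous ad-hoc nested `dict.get` chains. A minimal usage sketch (not part of the diff; it assumes deepspeed itself is installed and passes the config as a plain dict, the same way `HfTrainerDeepSpeedConfig` does):

```
from transformers.deepspeed import HfDeepSpeedConfig

ds_config = {
    "fp16": {"enabled": False},
    "zero_optimization": {
        "stage": 3,
        "offload_param": {"device": "cpu", "pin_memory": True},
    },
}

hf_ds = HfDeepSpeedConfig(ds_config)
print(hf_ds.get_value("zero_optimization.stage"))       # 3
print(hf_ds.get_value("zero_optimization.missing", 7))  # 7 - key not set, default returned
print(hf_ds.is_true("fp16.enabled"))                    # False - key is set, but to False
print(hf_ds.is_false("fp16.enabled"))                   # True  - explicitly disabled, i.e. fp32
print(hf_ds.is_zero3())                                 # True
print(hf_ds.is_offload())                               # True  - offload_param.device is "cpu"
```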

@@ -7,7 +7,7 @@ deps = {
"cookiecutter": "cookiecutter==1.7.2",
"dataclasses": "dataclasses",
"datasets": "datasets",
"deepspeed": "deepspeed>=0.3.16",
"deepspeed": "deepspeed>=0.4.0",
"docutils": "docutils==0.16.0",
"fairscale": "fairscale>0.3",
"faiss-cpu": "faiss-cpu",

@@ -23,6 +23,8 @@ import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from transformers.deepspeed import is_deepspeed_zero3_enabled
from ...activations import ACT2FN
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, MaskedLMOutput
@@ -193,7 +195,17 @@ class Wav2Vec2PositionalConvEmbedding(nn.Module):
padding=config.num_conv_pos_embeddings // 2,
groups=config.num_conv_pos_embedding_groups,
)
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
if is_deepspeed_zero3_enabled():
import deepspeed
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
else:
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
self.padding = Wav2Vec2SamePadLayer(config.num_conv_pos_embeddings)
self.activation = ACT2FN[config.feat_extract_activation]
@@ -615,15 +627,19 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
hidden_states = hidden_states + position_embeddings
hidden_states = self.dropout(hidden_states)
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
for layer in self.layers:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = np.random.uniform(0, 1)
if self.training and (dropout_probability < self.config.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
# XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
if getattr(self.config, "gradient_checkpointing", False) and self.training:
# create gradient checkpointing function
def create_custom_forward(module):
@@ -643,6 +659,9 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
)
hidden_states = layer_outputs[0]
if skip_the_layer:
layer_outputs = (None, None)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
@@ -680,7 +699,18 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Conv1d):
torch.nn.init.kaiming_normal_(module.weight.data)
if is_deepspeed_zero3_enabled():
import deepspeed
if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
torch.nn.init.kaiming_normal_(module.weight.data)
else:
with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
torch.nn.init.kaiming_normal_(module.weight.data)
else:
torch.nn.init.kaiming_normal_(module.weight.data)
if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
module.bias.data.zero_()
@@ -1061,7 +1091,8 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
target_lengths = labels_mask.sum(-1)
flattened_targets = labels.masked_select(labels_mask)
log_probs = F.log_softmax(logits, dim=-1).transpose(0, 1)
# ctc_loss doesn't support fp16
log_probs = F.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
with torch.backends.cudnn.flags(enabled=False):
loss = F.ctc_loss(
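
The last hunk above upcasts to fp32 before computing the loss because `F.ctc_loss` does not support fp16. A standalone sketch of the same trick with made-up shapes (not taken from the diff):

```
import torch
import torch.nn.functional as F

batch, time, vocab = 2, 50, 32
logits = torch.randn(batch, time, vocab, dtype=torch.float16)   # fp16 model output
targets = torch.randint(1, vocab, (batch, 10))                  # blank token id 0 excluded
input_lengths = torch.full((batch,), time, dtype=torch.long)
target_lengths = torch.full((batch,), 10, dtype=torch.long)

# cast to fp32 inside log_softmax, then move to (time, batch, vocab) as ctc_loss expects
log_probs = F.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0)
print(log_probs.dtype, loss.item())  # torch.float32 and a finite loss value
```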

@@ -28,6 +28,7 @@ from typing import Iterator, Union
from transformers import logging as transformers_logging
from .deepspeed import is_deepspeed_available
from .file_utils import (
is_datasets_available,
is_faiss_available,
@@ -454,6 +455,16 @@ def require_soundfile(test_case):
return test_case
def require_deepspeed(test_case):
"""
Decorator marking a test that requires deepspeed
"""
if not is_deepspeed_available():
return unittest.skip("test requires deepspeed")(test_case)
else:
return test_case
def get_gpu_count():
"""
Return the number of available gpus (regardless of whether torch or tf is used)

@@ -1701,7 +1701,14 @@ class Trainer:
"""
for k, v in inputs.items():
if isinstance(v, torch.Tensor):
inputs[k] = v.to(self.args.device)
kwargs = dict(device=self.args.device)
if self.deepspeed and inputs[k].dtype != torch.int64:
# NLP models' inputs are int64 and those get adjusted to the right dtype of the
# embedding. Other models such as wav2vec2's inputs are already float and thus
# may need special handling to match the dtypes of the model
kwargs.update(dict(dtype=self.args.hf_deepspeed_config.dtype()))
inputs[k] = v.to(**kwargs)
if self.args.past_index >= 0 and self._past is not None:
inputs["mems"] = self._past

@@ -32,6 +32,7 @@ from transformers.testing_utils import (
execute_subprocess_async,
get_gpu_count,
mockenv_context,
require_deepspeed,
require_torch_gpu,
require_torch_multi_gpu,
slow,
@@ -58,17 +59,6 @@ def load_json(path):
return json.load(f)
# a candidate for testing_utils
def require_deepspeed(test_case):
"""
Decorator marking a test that requires deepspeed
"""
if not is_deepspeed_available():
return unittest.skip("test requires deepspeed")(test_case)
else:
return test_case
def require_deepspeed_aio(test_case):
"""
Decorator marking a test that requires deepspeed aio (nvme)
@@ -404,15 +394,19 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
train_len = 64
a = b = 0.0
kwargs = dict(
a=a,
b=b,
local_rank=0,
train_len=train_len,
fp16=True,
deepspeed=self.get_config_dict(stage),
)
with mockenv_context(**self.dist_env_1_gpu):
no_grad_accum_trainer = get_regression_trainer(
a=a,
b=b,
local_rank=0,
train_len=train_len,
fp16=True,
deepspeed=self.get_config_dict(stage),
per_device_train_batch_size=8,
**kwargs,
per_device_train_batch_size=16,
gradient_accumulation_steps=1,
)
no_grad_accum_result = no_grad_accum_trainer.train()
@@ -424,14 +418,9 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
with mockenv_context(**self.dist_env_1_gpu):
yes_grad_accum_trainer = get_regression_trainer(
a=a,
b=b,
local_rank=0,
train_len=train_len,
fp16=True,
deepspeed=self.get_config_dict(stage),
**kwargs,
per_device_train_batch_size=4,
gradient_accumulation_steps=2,
gradient_accumulation_steps=4,
)
yes_grad_accum_result = yes_grad_accum_trainer.train()
yes_grad_accum_loss = yes_grad_accum_result.training_loss
@@ -445,7 +434,7 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
self.assertAlmostEqual(no_grad_accum_b, yes_grad_accum_b, places=5)
# see the note above how to get identical loss on a small bs
self.assertAlmostEqual(no_grad_accum_loss, yes_grad_accum_loss, places=5)
self.assertAlmostEqual(no_grad_accum_loss, yes_grad_accum_loss, places=2)
def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage):
# adapted from TrainerIntegrationCommon.check_saved_checkpoints