mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[Deepspeed Wav2vec2] integration (#11638)
* wip * wip - but working with https://github.com/microsoft/DeepSpeed/pull/1044 * cleanup * workaround * working 5/8 modes * solve fp32 distributed zero3 * style * sync * sync * rework * deprecation * cleanup * https://github.com/microsoft/DeepSpeed/pull/1044 pr was merged * clean up * add a guide * more prose * more prose * fix * more prose * sub_group_size was too big * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * refactor * bug fix * make the true check explicit * new deepspeed release Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
32290d87f6
commit
11d86d3de4
@ -127,3 +127,60 @@ logs references and predictions. Using the Buckwalter format, text is also logge
|
||||
|
||||
`--max_duration_in_seconds="15"` filters out examples whose audio is longer than the specified limit,
|
||||
which helps with capping GPU memory usage.
|
||||
|
||||
|
||||
### DeepSpeed Integration
|
||||
|
||||
To learn how to deploy Deepspeed Integration please refer to [this guide](https://huggingface.co/transformers/master/main_classes/deepspeed.html#deepspeed-trainer-integration).
|
||||
|
||||
But to get started quickly all you need is to install:
|
||||
```
|
||||
pip install deepspeed
|
||||
```
|
||||
and then use the default configuration files in this directory:
|
||||
|
||||
* `ds_config_wav2vec2_zero2.json`
|
||||
* `ds_config_wav2vec2_zero3.json`
|
||||
|
||||
Here are examples of how you can use DeepSpeed:
|
||||
|
||||
(edit the value for `--num_gpus` to match the number of GPUs you have)
|
||||
|
||||
ZeRO-2:
|
||||
|
||||
```
|
||||
PYTHONPATH=../../../src deepspeed --num_gpus 2 \
|
||||
run_asr.py \
|
||||
--output_dir=output_dir --num_train_epochs=2 --per_device_train_batch_size=2 \
|
||||
--per_device_eval_batch_size=2 --evaluation_strategy=steps --save_steps=500 --eval_steps=100 \
|
||||
--logging_steps=5 --learning_rate=5e-4 --warmup_steps=3000 \
|
||||
--model_name_or_path=patrickvonplaten/wav2vec2_tiny_random_robust \
|
||||
--dataset_name=patrickvonplaten/librispeech_asr_dummy --dataset_config_name=clean \
|
||||
--train_split_name=validation --validation_split_name=validation --orthography=timit \
|
||||
--preprocessing_num_workers=1 --group_by_length --freeze_feature_extractor --verbose_logging \
|
||||
--deepspeed ds_config_wav2vec2_zero2.json
|
||||
```
|
||||
|
||||
For ZeRO-2 with more than 1 gpu you need to use (which is already in the example configuration file):
|
||||
```
|
||||
"zero_optimization": {
|
||||
...
|
||||
"find_unused_parameters": true,
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
ZeRO-3:
|
||||
|
||||
```
|
||||
PYTHONPATH=../../../src deepspeed --num_gpus 2 \
|
||||
run_asr.py \
|
||||
--output_dir=output_dir --num_train_epochs=2 --per_device_train_batch_size=2 \
|
||||
--per_device_eval_batch_size=2 --evaluation_strategy=steps --save_steps=500 --eval_steps=100 \
|
||||
--logging_steps=5 --learning_rate=5e-4 --warmup_steps=3000 \
|
||||
--model_name_or_path=patrickvonplaten/wav2vec2_tiny_random_robust \
|
||||
--dataset_name=patrickvonplaten/librispeech_asr_dummy --dataset_config_name=clean \
|
||||
--train_split_name=validation --validation_split_name=validation --orthography=timit \
|
||||
--preprocessing_num_workers=1 --group_by_length --freeze_feature_extractor --verbose_logging \
|
||||
--deepspeed ds_config_wav2vec2_zero3.json
|
||||
```
|
||||
|
@ -0,0 +1,51 @@
|
||||
{
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"loss_scale": 0,
|
||||
"loss_scale_window": 1000,
|
||||
"initial_scale_power": 16,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": "auto",
|
||||
"betas": "auto",
|
||||
"eps": "auto",
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
|
||||
"scheduler": {
|
||||
"type": "WarmupLR",
|
||||
"params": {
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto"
|
||||
}
|
||||
},
|
||||
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"find_unused_parameters": true,
|
||||
"allgather_partitions": true,
|
||||
"allgather_bucket_size": 2e8,
|
||||
"overlap_comm": true,
|
||||
"reduce_scatter": true,
|
||||
"reduce_bucket_size": 2e8,
|
||||
"contiguous_gradients": true
|
||||
},
|
||||
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"steps_per_print": 2000,
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
}
|
@ -0,0 +1,57 @@
|
||||
{
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"loss_scale": 0,
|
||||
"loss_scale_window": 1000,
|
||||
"initial_scale_power": 16,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": "auto",
|
||||
"betas": "auto",
|
||||
"eps": "auto",
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
|
||||
"scheduler": {
|
||||
"type": "WarmupLR",
|
||||
"params": {
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto"
|
||||
}
|
||||
},
|
||||
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"offload_param": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"sub_group_size": 1e9,
|
||||
"reduce_bucket_size": "auto",
|
||||
"stage3_prefetch_bucket_size": "auto",
|
||||
"stage3_param_persistence_threshold": "auto",
|
||||
"stage3_max_live_parameters": 1e9,
|
||||
"stage3_max_reuse_distance": 1e9,
|
||||
"stage3_gather_fp16_weights_on_model_save": true
|
||||
},
|
||||
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"steps_per_print": 2000,
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
}
|
185
examples/research_projects/wav2vec2/test_wav2vec2_deepspeed.py
Normal file
185
examples/research_projects/wav2vec2/test_wav2vec2_deepspeed.py
Normal file
@ -0,0 +1,185 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# XXX: we want transformers master here - in the absense of conftest manipulating sys.path:
|
||||
# hack it in for now:
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
git_repo_path = Path(__file__).resolve().parents[3] / "src"
|
||||
sys.path.insert(1, str(git_repo_path))
|
||||
|
||||
import dataclasses # noqa
|
||||
import io # noqa
|
||||
import json # noqa
|
||||
import os # noqa
|
||||
import unittest # noqa
|
||||
from copy import deepcopy # noqa
|
||||
|
||||
from parameterized import parameterized # noqa
|
||||
from transformers import TrainingArguments, is_torch_available # noqa
|
||||
from transformers.deepspeed import is_deepspeed_available # noqa
|
||||
from transformers.file_utils import WEIGHTS_NAME # noqa
|
||||
from transformers.testing_utils import ( # noqa
|
||||
CaptureLogger,
|
||||
ExtendSysPath,
|
||||
TestCasePlus,
|
||||
execute_subprocess_async,
|
||||
get_gpu_count,
|
||||
mockenv_context,
|
||||
require_deepspeed,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
slow,
|
||||
)
|
||||
from transformers.trainer_utils import set_seed # noqa
|
||||
|
||||
|
||||
set_seed(42)
|
||||
|
||||
WAV2VEC2_TINY = "patrickvonplaten/wav2vec2_tiny_random_robust"
|
||||
|
||||
|
||||
ZERO2 = "zero2"
|
||||
ZERO3 = "zero3"
|
||||
stages = [ZERO2, ZERO3]
|
||||
|
||||
|
||||
@slow
|
||||
@require_deepspeed
|
||||
@require_torch_gpu
|
||||
class TestDeepSpeedWav2Vec2(TestCasePlus):
|
||||
@parameterized.expand(stages)
|
||||
def test_fp32_non_distributed(self, stage):
|
||||
self.run_and_check(
|
||||
stage=stage,
|
||||
distributed=False,
|
||||
fp16=False,
|
||||
)
|
||||
|
||||
@require_torch_multi_gpu
|
||||
@parameterized.expand(stages)
|
||||
def test_fp32_distributed(self, stage):
|
||||
self.run_and_check(
|
||||
stage=stage,
|
||||
distributed=True,
|
||||
fp16=False,
|
||||
)
|
||||
|
||||
@parameterized.expand(stages)
|
||||
def test_fp16_non_distributed(self, stage):
|
||||
self.run_and_check(
|
||||
stage=stage,
|
||||
distributed=False,
|
||||
fp16=True,
|
||||
)
|
||||
|
||||
@require_torch_multi_gpu
|
||||
@parameterized.expand(stages)
|
||||
def test_fp16_distributed(self, stage):
|
||||
self.run_and_check(
|
||||
stage=stage,
|
||||
distributed=True,
|
||||
fp16=True,
|
||||
)
|
||||
|
||||
def do_checks(self, output_dir):
|
||||
# XXX: run_asr is premature and doesn't save any results
|
||||
# so all we check for now is that the process didn't fail
|
||||
pass
|
||||
|
||||
# XXX: need to do better validation beyond just that the run was successful
|
||||
def run_and_check(
|
||||
self,
|
||||
stage,
|
||||
model_name: str = WAV2VEC2_TINY,
|
||||
eval_steps: int = 10,
|
||||
distributed: bool = True,
|
||||
quality_checks: bool = True,
|
||||
fp16: bool = True,
|
||||
):
|
||||
|
||||
output_dir = self.run_trainer(
|
||||
stage=stage,
|
||||
model_name=model_name,
|
||||
eval_steps=eval_steps,
|
||||
num_train_epochs=1,
|
||||
distributed=distributed,
|
||||
fp16=fp16,
|
||||
)
|
||||
|
||||
self.do_checks(output_dir)
|
||||
|
||||
return output_dir
|
||||
|
||||
def run_trainer(
|
||||
self,
|
||||
stage: str,
|
||||
model_name: str,
|
||||
eval_steps: int = 10,
|
||||
num_train_epochs: int = 1,
|
||||
distributed: bool = True,
|
||||
fp16: bool = True,
|
||||
):
|
||||
|
||||
output_dir = self.get_auto_remove_tmp_dir("./xxx", after=False)
|
||||
args = f"""
|
||||
--model_name_or_path {model_name}
|
||||
--dataset_name patrickvonplaten/librispeech_asr_dummy
|
||||
--dataset_config_name clean
|
||||
--train_split_name validation
|
||||
--validation_split_name validation
|
||||
--output_dir {output_dir}
|
||||
--num_train_epochs {str(num_train_epochs)}
|
||||
--per_device_train_batch_size 2
|
||||
--per_device_eval_batch_size 2
|
||||
--evaluation_strategy steps
|
||||
--learning_rate 5e-4
|
||||
--warmup_steps 8
|
||||
--orthography timit
|
||||
--preprocessing_num_workers 1
|
||||
--group_by_length
|
||||
--freeze_feature_extractor
|
||||
--report_to none
|
||||
--logging_steps 0
|
||||
--save_steps 0
|
||||
--eval_steps {eval_steps}
|
||||
--report_to none
|
||||
""".split()
|
||||
|
||||
if fp16:
|
||||
args.extend(["--fp16"])
|
||||
|
||||
# currently ds_config_wav2vec2_zero.json requires "zero_optimization.find_unused_parameters": true,
|
||||
# hence the separate config files
|
||||
ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_wav2vec2_{stage}.json".split()
|
||||
script = [f"{self.examples_dir_str}/research_projects/wav2vec2/run_asr.py"]
|
||||
launcher = self.get_launcher(distributed)
|
||||
|
||||
cmd = launcher + script + args + ds_args
|
||||
# keep for quick debug
|
||||
# print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
|
||||
execute_subprocess_async(cmd, env=self.get_env())
|
||||
|
||||
return output_dir
|
||||
|
||||
def get_launcher(self, distributed=False):
|
||||
# 1. explicitly set --num_nodes=1 just in case these tests end up run on a multi-node setup
|
||||
# - it won't be able to handle that
|
||||
# 2. for now testing with just 2 gpus max (since some quality tests may give different
|
||||
# results with mode gpus because we use very little data)
|
||||
num_gpus = min(2, get_gpu_count()) if distributed else 1
|
||||
return f"deepspeed --num_nodes 1 --num_gpus {num_gpus}".split()
|
2
setup.py
2
setup.py
@ -90,7 +90,7 @@ _deps = [
|
||||
"cookiecutter==1.7.2",
|
||||
"dataclasses",
|
||||
"datasets",
|
||||
"deepspeed>=0.3.16",
|
||||
"deepspeed>=0.4.0",
|
||||
"docutils==0.16.0",
|
||||
"fairscale>0.3",
|
||||
"faiss-cpu",
|
||||
|
@ -23,9 +23,13 @@ from copy import deepcopy
|
||||
from functools import partialmethod
|
||||
|
||||
from .dependency_versions_check import dep_version_check
|
||||
from .file_utils import is_torch_available
|
||||
from .utils import logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -70,46 +74,86 @@ class HfDeepSpeedConfig:
|
||||
# zero stage - this is done as early as possible, before model is created, to allow
|
||||
# ``is_deepspeed_zero3_enabled`` query and getting to the early deepspeed config object
|
||||
# during ``zero.Init()`` which needs whether fp16 is enabled, dtype, etc.
|
||||
config_zero = config.get("zero_optimization", {})
|
||||
self.stage = config_zero.get("stage", 0)
|
||||
self._stage = self.get_value("zero_optimization.stage", -1)
|
||||
|
||||
# offload
|
||||
self.offload = False
|
||||
config_zero = config.get("zero_optimization", {})
|
||||
self._offload = False
|
||||
if self.is_zero2() or self.is_zero3():
|
||||
offload_devices = ["cpu", "nvme"]
|
||||
if config_zero.get("offload_optimizer", {}).get("device") in offload_devices:
|
||||
self.offload = True
|
||||
if config_zero.get("offload_param", {}).get("device") in offload_devices:
|
||||
self.offload = True
|
||||
offload_devices_valid = set(["cpu", "nvme"])
|
||||
offload_devices = set(
|
||||
[
|
||||
self.get_value("zero_optimization.offload_optimizer.device"),
|
||||
self.get_value("zero_optimization.offload_param.device"),
|
||||
]
|
||||
)
|
||||
if len(offload_devices & offload_devices_valid) > 0:
|
||||
self._offload = True
|
||||
|
||||
def find_config_node(self, ds_key_long):
|
||||
config = self.config
|
||||
|
||||
# find the config node of interest if it exists
|
||||
nodes = ds_key_long.split(".")
|
||||
ds_key = nodes.pop()
|
||||
for node in nodes:
|
||||
config = config.get(node)
|
||||
if config is None:
|
||||
return None, ds_key
|
||||
|
||||
return config, ds_key
|
||||
|
||||
def get_value(self, ds_key_long, default=None):
|
||||
"""
|
||||
Returns the set value or ``default`` if no value is set
|
||||
"""
|
||||
config, ds_key = self.find_config_node(ds_key_long)
|
||||
if config is None:
|
||||
return default
|
||||
return config.get(ds_key, default)
|
||||
|
||||
def is_true(self, ds_key_long):
|
||||
"""
|
||||
Returns :obj:`True`/:obj:`False` only if the value is set, always :obj:`False` otherwise. So use this method to
|
||||
ask the very specific question of whether the value is set to :obj:`True` (and it's not set to :obj:`False` or
|
||||
isn't set).
|
||||
|
||||
"""
|
||||
value = self.get_value(ds_key_long)
|
||||
return False if value is None else bool(value)
|
||||
|
||||
def is_false(self, ds_key_long):
|
||||
"""
|
||||
Returns :obj:`True`/:obj:`False` only if the value is set, always :obj:`False` otherwise. So use this method to
|
||||
ask the very specific question of whether the value is set to :obj:`False` (and it's not set to :obj:`True` or
|
||||
isn't set).
|
||||
"""
|
||||
value = self.get_value(ds_key_long)
|
||||
return False if value is None else not bool(value)
|
||||
|
||||
def is_zero2(self):
|
||||
return self.stage == 2
|
||||
return self._stage == 2
|
||||
|
||||
def is_zero3(self):
|
||||
return self.stage == 3
|
||||
return self._stage == 3
|
||||
|
||||
def is_offload(self):
|
||||
return self.offload
|
||||
|
||||
@staticmethod
|
||||
def is_true(config, key):
|
||||
if config is None:
|
||||
return False
|
||||
return bool(config.get(key))
|
||||
return self._offload
|
||||
|
||||
|
||||
class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig):
|
||||
"""
|
||||
The ``HfTrainerDeepSpeedConfig`` object is meant to be created during ``TrainingArguments`` object creation and has
|
||||
the same lifespan as the latter.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config_file_or_dict):
|
||||
super().__init__(config_file_or_dict)
|
||||
self._dtype = torch.float16
|
||||
self.mismatches = []
|
||||
|
||||
def dtype(self):
|
||||
return self._dtype
|
||||
|
||||
def fill_match(self, ds_key_long, hf_val, hf_key=None, must_match=True):
|
||||
"""
|
||||
A utility method that massages the config file and can optionally verify that the values match.
|
||||
@ -121,16 +165,9 @@ class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig):
|
||||
``trainer_config_finalize`` for one or more mismatches.
|
||||
|
||||
"""
|
||||
|
||||
config = self.config
|
||||
|
||||
# find the config node of interest if it exists
|
||||
nodes = ds_key_long.split(".")
|
||||
ds_key = nodes.pop()
|
||||
for node in nodes:
|
||||
config = config.get(node)
|
||||
if config is None:
|
||||
return
|
||||
config, ds_key = self.find_config_node(ds_key_long)
|
||||
if config is None:
|
||||
return
|
||||
|
||||
if config.get(ds_key) == "auto":
|
||||
config[ds_key] = hf_val
|
||||
@ -185,6 +222,13 @@ class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig):
|
||||
self.fill_match("amp.enabled", fp16_backend == "apex", "fp16+fp16_backend(apex)")
|
||||
self.fill_match("amp.opt_level", args.fp16_opt_level, "fp16_opt_level")
|
||||
|
||||
# only if we have an explicit fp16.enabled = False then it's fp32, if it's True or this
|
||||
# whole config section is missing then the fallback is fp16
|
||||
if self.is_false("fp16.enabled"):
|
||||
self._dtype = torch.float32
|
||||
# later there will be other dtypes besides just fp16 and fp32
|
||||
# also not quite sure what dtype should be under apex, defaulting to fp16 for now
|
||||
|
||||
def trainer_config_finalize(self, args, model, num_training_steps):
|
||||
"""
|
||||
This stage is run after we have the model and know num_training_steps.
|
||||
|
@ -7,7 +7,7 @@ deps = {
|
||||
"cookiecutter": "cookiecutter==1.7.2",
|
||||
"dataclasses": "dataclasses",
|
||||
"datasets": "datasets",
|
||||
"deepspeed": "deepspeed>=0.3.16",
|
||||
"deepspeed": "deepspeed>=0.4.0",
|
||||
"docutils": "docutils==0.16.0",
|
||||
"fairscale": "fairscale>0.3",
|
||||
"faiss-cpu": "faiss-cpu",
|
||||
|
@ -23,6 +23,8 @@ import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, MaskedLMOutput
|
||||
@ -193,7 +195,17 @@ class Wav2Vec2PositionalConvEmbedding(nn.Module):
|
||||
padding=config.num_conv_pos_embeddings // 2,
|
||||
groups=config.num_conv_pos_embedding_groups,
|
||||
)
|
||||
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
|
||||
|
||||
if is_deepspeed_zero3_enabled():
|
||||
import deepspeed
|
||||
|
||||
with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
|
||||
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
|
||||
deepspeed.zero.register_external_parameter(self, self.conv.weight_v)
|
||||
deepspeed.zero.register_external_parameter(self, self.conv.weight_g)
|
||||
else:
|
||||
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
|
||||
|
||||
self.padding = Wav2Vec2SamePadLayer(config.num_conv_pos_embeddings)
|
||||
self.activation = ACT2FN[config.feat_extract_activation]
|
||||
|
||||
@ -615,15 +627,19 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
|
||||
hidden_states = hidden_states + position_embeddings
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
||||
|
||||
for layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
dropout_probability = np.random.uniform(0, 1)
|
||||
if self.training and (dropout_probability < self.config.layerdrop): # skip the layer
|
||||
layer_outputs = (None, None)
|
||||
else:
|
||||
|
||||
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
|
||||
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
||||
# under deepspeed zero3 all gpus must run in sync
|
||||
# XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
# create gradient checkpointing function
|
||||
def create_custom_forward(module):
|
||||
@ -643,6 +659,9 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if skip_the_layer:
|
||||
layer_outputs = (None, None)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
|
||||
@ -680,7 +699,18 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, nn.Conv1d):
|
||||
torch.nn.init.kaiming_normal_(module.weight.data)
|
||||
if is_deepspeed_zero3_enabled():
|
||||
import deepspeed
|
||||
|
||||
if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
|
||||
with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
|
||||
torch.nn.init.kaiming_normal_(module.weight.data)
|
||||
else:
|
||||
with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
|
||||
torch.nn.init.kaiming_normal_(module.weight.data)
|
||||
else:
|
||||
torch.nn.init.kaiming_normal_(module.weight.data)
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
@ -1061,7 +1091,8 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
||||
target_lengths = labels_mask.sum(-1)
|
||||
flattened_targets = labels.masked_select(labels_mask)
|
||||
|
||||
log_probs = F.log_softmax(logits, dim=-1).transpose(0, 1)
|
||||
# ctc_loss doesn't support fp16
|
||||
log_probs = F.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
|
||||
|
||||
with torch.backends.cudnn.flags(enabled=False):
|
||||
loss = F.ctc_loss(
|
||||
|
@ -28,6 +28,7 @@ from typing import Iterator, Union
|
||||
|
||||
from transformers import logging as transformers_logging
|
||||
|
||||
from .deepspeed import is_deepspeed_available
|
||||
from .file_utils import (
|
||||
is_datasets_available,
|
||||
is_faiss_available,
|
||||
@ -454,6 +455,16 @@ def require_soundfile(test_case):
|
||||
return test_case
|
||||
|
||||
|
||||
def require_deepspeed(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires deepspeed
|
||||
"""
|
||||
if not is_deepspeed_available():
|
||||
return unittest.skip("test requires deepspeed")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
def get_gpu_count():
|
||||
"""
|
||||
Return the number of available gpus (regardless of whether torch or tf is used)
|
||||
|
@ -1701,7 +1701,14 @@ class Trainer:
|
||||
"""
|
||||
for k, v in inputs.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
inputs[k] = v.to(self.args.device)
|
||||
kwargs = dict(device=self.args.device)
|
||||
if self.deepspeed and inputs[k].dtype != torch.int64:
|
||||
# NLP models inputs are int64 and those get adjusted to the right dtype of the
|
||||
# embedding. Other models such as wav2vec2's inputs are already float and thus
|
||||
# may need special handling to match the dtypes of the model
|
||||
kwargs.update(dict(dtype=self.args.hf_deepspeed_config.dtype()))
|
||||
|
||||
inputs[k] = v.to(**kwargs)
|
||||
|
||||
if self.args.past_index >= 0 and self._past is not None:
|
||||
inputs["mems"] = self._past
|
||||
|
@ -32,6 +32,7 @@ from transformers.testing_utils import (
|
||||
execute_subprocess_async,
|
||||
get_gpu_count,
|
||||
mockenv_context,
|
||||
require_deepspeed,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
slow,
|
||||
@ -58,17 +59,6 @@ def load_json(path):
|
||||
return json.load(f)
|
||||
|
||||
|
||||
# a candidate for testing_utils
|
||||
def require_deepspeed(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires deepspeed
|
||||
"""
|
||||
if not is_deepspeed_available():
|
||||
return unittest.skip("test requires deepspeed")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
def require_deepspeed_aio(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires deepspeed aio (nvme)
|
||||
@ -404,15 +394,19 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
train_len = 64
|
||||
a = b = 0.0
|
||||
|
||||
kwargs = dict(
|
||||
a=a,
|
||||
b=b,
|
||||
local_rank=0,
|
||||
train_len=train_len,
|
||||
fp16=True,
|
||||
deepspeed=self.get_config_dict(stage),
|
||||
)
|
||||
|
||||
with mockenv_context(**self.dist_env_1_gpu):
|
||||
no_grad_accum_trainer = get_regression_trainer(
|
||||
a=a,
|
||||
b=b,
|
||||
local_rank=0,
|
||||
train_len=train_len,
|
||||
fp16=True,
|
||||
deepspeed=self.get_config_dict(stage),
|
||||
per_device_train_batch_size=8,
|
||||
**kwargs,
|
||||
per_device_train_batch_size=16,
|
||||
gradient_accumulation_steps=1,
|
||||
)
|
||||
no_grad_accum_result = no_grad_accum_trainer.train()
|
||||
@ -424,14 +418,9 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
|
||||
with mockenv_context(**self.dist_env_1_gpu):
|
||||
yes_grad_accum_trainer = get_regression_trainer(
|
||||
a=a,
|
||||
b=b,
|
||||
local_rank=0,
|
||||
train_len=train_len,
|
||||
fp16=True,
|
||||
deepspeed=self.get_config_dict(stage),
|
||||
**kwargs,
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=2,
|
||||
gradient_accumulation_steps=4,
|
||||
)
|
||||
yes_grad_accum_result = yes_grad_accum_trainer.train()
|
||||
yes_grad_accum_loss = yes_grad_accum_result.training_loss
|
||||
@ -445,7 +434,7 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertAlmostEqual(no_grad_accum_b, yes_grad_accum_b, places=5)
|
||||
|
||||
# see the note above how to get identical loss on a small bs
|
||||
self.assertAlmostEqual(no_grad_accum_loss, yes_grad_accum_loss, places=5)
|
||||
self.assertAlmostEqual(no_grad_accum_loss, yes_grad_accum_loss, places=2)
|
||||
|
||||
def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage):
|
||||
# adapted from TrainerIntegrationCommon.check_saved_checkpoints
|
||||
|
Loading…
Reference in New Issue
Block a user