sync LayerDrop for Wav2Vec2Encoder + tests (#12076)

Stas Bekman 2021-06-09 05:21:03 -07:00 committed by GitHub
parent 82a2b76c95
commit d14e0af274
2 changed files with 38 additions and 15 deletions


@@ -24,6 +24,7 @@ sys.path.insert(1, str(git_repo_path))
 import dataclasses  # noqa
 import io  # noqa
+import itertools  # noqa
 import json  # noqa
 import os  # noqa
 import unittest  # noqa
@@ -50,48 +51,62 @@ from transformers.trainer_utils import set_seed  # noqa
 set_seed(42)

-WAV2VEC2_TINY = "patrickvonplaten/wav2vec2_tiny_random_robust"
+models = dict(base="patrickvonplaten/wav2vec2_tiny_random", robust="patrickvonplaten/wav2vec2_tiny_random_robust")

 ZERO2 = "zero2"
 ZERO3 = "zero3"
 stages = [ZERO2, ZERO3]

+def custom_name_func(func, param_num, param):
+    # customize the test name generator function as we want both params to appear in the sub-test
+    # name, as by default it shows only the first param
+    param_based_name = parameterized.to_safe_name("_".join(str(x) for x in param.args))
+    return f"{func.__name__}_{param_based_name}"

+# Cartesian-product of zero stages with models to test
+params = list(itertools.product(stages, models.keys()))

 @slow
 @require_deepspeed
 @require_torch_gpu
 class TestDeepSpeedWav2Vec2(TestCasePlus):
-    @parameterized.expand(stages)
-    def test_fp32_non_distributed(self, stage):
+    @parameterized.expand(params, name_func=custom_name_func)
+    def test_fp32_non_distributed(self, stage, model):
         self.run_and_check(
             stage=stage,
+            model=model,
             distributed=False,
             fp16=False,
         )

     @require_torch_multi_gpu
-    @parameterized.expand(stages)
-    def test_fp32_distributed(self, stage):
+    @parameterized.expand(params, name_func=custom_name_func)
+    def test_fp32_distributed(self, stage, model):
         self.run_and_check(
             stage=stage,
+            model=model,
             distributed=True,
             fp16=False,
         )

-    @parameterized.expand(stages)
-    def test_fp16_non_distributed(self, stage):
+    @parameterized.expand(params, name_func=custom_name_func)
+    def test_fp16_non_distributed(self, stage, model):
         self.run_and_check(
             stage=stage,
+            model=model,
             distributed=False,
             fp16=True,
         )

     @require_torch_multi_gpu
-    @parameterized.expand(stages)
-    def test_fp16_distributed(self, stage):
+    @parameterized.expand(params, name_func=custom_name_func)
+    def test_fp16_distributed(self, stage, model):
         self.run_and_check(
             stage=stage,
+            model=model,
             distributed=True,
             fp16=True,
         )
@@ -104,14 +119,16 @@ class TestDeepSpeedWav2Vec2(TestCasePlus):
     # XXX: need to do better validation beyond just that the run was successful
     def run_and_check(
         self,
-        stage,
-        model_name: str = WAV2VEC2_TINY,
+        stage: str,
+        model: str,
         eval_steps: int = 10,
         distributed: bool = True,
         quality_checks: bool = True,
         fp16: bool = True,
     ):
+        model_name = models[model]

         output_dir = self.run_trainer(
             stage=stage,
             model_name=model_name,
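
Below is a minimal, self-contained sketch (not part of the commit) of the parameterization pattern the test diff above relies on: itertools.product() builds every (stage, model) pair and a custom name_func feeds both values into the generated sub-test names, since the parameterized library's default naming only shows the first argument. The DummyTest class, the placeholder model names, and the assertions are illustrative only.

import itertools
import unittest

from parameterized import parameterized

stages = ["zero2", "zero3"]
models = dict(base="tiny-base-model", robust="tiny-robust-model")  # placeholder names

# one sub-test per (stage, model) combination
params = list(itertools.product(stages, models.keys()))


def custom_name_func(func, param_num, param):
    # join all parameters so both stage and model show up in the sub-test name
    param_based_name = parameterized.to_safe_name("_".join(str(x) for x in param.args))
    return f"{func.__name__}_{param_based_name}"


class DummyTest(unittest.TestCase):
    @parameterized.expand(params, name_func=custom_name_func)
    def test_pair(self, stage, model):
        # generated names look like test_pair_zero2_base, test_pair_zero3_robust, ...
        self.assertIn(stage, stages)
        self.assertIn(model, models)


if __name__ == "__main__":
    unittest.main(verbosity=2)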


@@ -548,15 +548,18 @@ class Wav2Vec2Encoder(nn.Module):
         hidden_states = self.layer_norm(hidden_states)
         hidden_states = self.dropout(hidden_states)

+        deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()

         for layer in self.layers:
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             dropout_probability = np.random.uniform(0, 1)
-            if self.training and (dropout_probability < self.config.layerdrop):  # skip the layer
-                layer_outputs = (None, None)
-            else:
+
+            skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
+            if not skip_the_layer or deepspeed_zero3_is_enabled:
+                # under deepspeed zero3 all gpus must run in sync
                 if getattr(self.config, "gradient_checkpointing", False) and self.training:
                     # create gradient checkpointing function
                     def create_custom_forward(module):
@@ -576,6 +579,9 @@ class Wav2Vec2Encoder(nn.Module):
                     )
                 hidden_states = layer_outputs[0]

+            if skip_the_layer:
+                layer_outputs = (None, None)

             if output_attentions:
                 all_self_attentions = all_self_attentions + (layer_outputs[1],)
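
The change above makes LayerDrop safe under DeepSpeed ZeRO-3: with ZeRO-3 each layer's parameters are partitioned across GPUs and gathered collectively during that layer's forward pass, so if one rank randomly skips a layer while another runs it, the collective gather can hang. The fix is to always execute the layer when ZeRO-3 is enabled and simply discard its output afterwards whenever the layer was "dropped". Here is a minimal, framework-free sketch of that control flow; ToyEncoder, its sizes, and the zero3_enabled flag are illustrative stand-ins, not the commit's code.

import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    def __init__(self, num_layers=4, hidden=8, layerdrop=0.5, zero3_enabled=False):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_layers))
        self.layerdrop = layerdrop
        self.zero3_enabled = zero3_enabled  # stand-in for is_deepspeed_zero3_enabled()

    def forward(self, hidden_states):
        for layer in self.layers:
            skip_the_layer = self.training and torch.rand(1).item() < self.layerdrop
            if not skip_the_layer or self.zero3_enabled:
                # every rank runs the layer, so ZeRO-3's collective parameter gather stays in sync
                layer_outputs = layer(hidden_states)
            if skip_the_layer:
                # LayerDrop: throw the result away; hidden_states passes through unchanged
                layer_outputs = None
            else:
                hidden_states = layer_outputs
        return hidden_states


encoder = ToyEncoder(zero3_enabled=True).train()
print(encoder(torch.randn(2, 8)).shape)  # torch.Size([2, 8])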