[Whisper] Add SpecAugment (#21298)

* Return and rescale attention_mask

* Add SpecAugment to Whisper modeling

* Fix test

* Update docstring

* Add SpecAug related parameters to model config

* Add the _mask_input_features function to doc

* Fix quality

* Apply suggestions from code review

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Remove dev comments

* Add test

* Resolve conflict

* feat: mask {feature, time} prob fast tests

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
bofeng huang, 2023-02-24 11:07:52 +01:00 (committed by GitHub)
commit c8545d2a9c · parent 75bd49ff88
5 changed files with 293 additions and 2 deletions

docs/source/en/model_doc/whisper.mdx

@@ -72,6 +72,7 @@ The original code can be found [here](https://github.com/openai/whisper).
[[autodoc]] WhisperModel
- forward
- _mask_input_features

## WhisperForConditionalGeneration

src/transformers/models/whisper/configuration_whisper.py

@@ -136,6 +136,35 @@ class WhisperConfig(PretrainedConfig):
begin_suppress_tokens (`List[int]`, *optional*, defaults to `[220,50256]`):
A list containing tokens that will be suppressed at the beginning of the sampling process. Initialized as
the token for `" "` (`blank_token_id`) and the `eos_token_id`
apply_spec_augment (`bool`, *optional*, defaults to `False`):
Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference, see
[SpecAugment: A Simple Data Augmentation Method for Automatic Speech
Recognition](https://arxiv.org/abs/1904.08779).
mask_time_prob (`float`, *optional*, defaults to 0.05):
Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
procedure generates `mask_time_prob*len(time_axis)/mask_time_length` independent masks over the axis.
Reasoning from the probability of each feature vector being chosen as the start of the span to be masked,
*mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
actual percentage of masked vectors. Only relevant if `apply_spec_augment == True`.
mask_time_length (`int`, *optional*, defaults to 10):
Length of each vector span along the time axis.
mask_time_min_masks (`int`, *optional*, defaults to 2):
The minimum number of masks of length `mask_time_length` generated along the time axis, irrespective of
`mask_time_prob`. Only relevant if `mask_time_prob*len(time_axis)/mask_time_length < mask_time_min_masks`.
mask_feature_prob (`float`, *optional*, defaults to 0.0):
Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
masking procedure generates `mask_feature_prob*len(feature_axis)/mask_feature_length` independent masks
over the axis. Reasoning from the probability of each feature vector being chosen as the start of the span
to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap may
decrease the actual percentage of masked vectors. Only relevant if `apply_spec_augment == True`.
mask_feature_length (`int`, *optional*, defaults to 10):
Length of each vector span along the feature axis.
mask_feature_min_masks (`int`, *optional*, defaults to 0):
The minimum number of masks of length `mask_feature_length` generated along the feature axis, irrespective
of `mask_feature_prob`. Only relevant if
`mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks`.
Example:
@@ -185,6 +214,13 @@ class WhisperConfig(PretrainedConfig):
eos_token_id=50256,
suppress_tokens=None,
begin_suppress_tokens=[220, 50256],
apply_spec_augment=False,
mask_time_prob=0.05,
mask_time_length=10,
mask_time_min_masks=2,
mask_feature_prob=0.0,
mask_feature_length=10,
mask_feature_min_masks=0,
**kwargs,
):
self.vocab_size = vocab_size
@@ -208,6 +244,14 @@ class WhisperConfig(PretrainedConfig):
self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
self.max_source_positions = max_source_positions
self.max_target_positions = max_target_positions
# fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
self.apply_spec_augment = apply_spec_augment
self.mask_time_prob = mask_time_prob
self.mask_time_length = mask_time_length
self.mask_time_min_masks = mask_time_min_masks
self.mask_feature_prob = mask_feature_prob
self.mask_feature_length = mask_feature_length
self.mask_feature_min_masks = mask_feature_min_masks
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
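For context, a minimal usage sketch (not part of this diff; the values are illustrative, not recommendations) of enabling the new options when loading a model for fine-tuning:

```python
from transformers import WhisperForConditionalGeneration

# Illustrative values; the default apply_spec_augment=False leaves masking off.
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny",
    apply_spec_augment=True,
    mask_time_prob=0.05,    # expected spans per 3000-frame input: 0.05 * 3000 / 10 = 15
    mask_time_length=10,
    mask_time_min_masks=2,
    mask_feature_prob=0.1,  # also mask spans of mel channels along the feature axis
    mask_feature_length=10,
)
model.train()  # SpecAugment is only applied in training mode
```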

src/transformers/models/whisper/feature_extraction_whisper.py

@@ -307,6 +307,7 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
max_length=max_length if max_length else self.n_samples,
truncation=truncation,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
)
# make sure list is in array format
input_features = padded_inputs.get("input_features").transpose(2, 0, 1)
@@ -318,6 +319,10 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
else:
padded_inputs["input_features"] = input_features
if return_attention_mask:
# rescale from the sample level (480000 samples) to the feature level (3000 frames)
padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length]
if return_tensors is not None:
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
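To illustrate the rescaling (a sketch assuming the default `hop_length=160`): padding to 30 s at 16 kHz gives 480000 samples, and slicing the sample-level mask with `[:, ::hop_length]` leaves one entry per mel frame, i.e. 3000:

```python
import numpy as np
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor()  # hop_length defaults to 160

# Two clips of different lengths; padding to 30 s makes the mask informative.
audio = [np.zeros(16000 * 5), np.zeros(16000 * 8)]  # 5 s and 8 s at 16 kHz
inputs = feature_extractor(audio, return_tensors="np", return_attention_mask=True)

print(inputs.input_features.shape)  # (2, 80, 3000): mel features
print(inputs.attention_mask.shape)  # (2, 3000): one mask entry per feature frame
```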

src/transformers/models/whisper/modeling_whisper.py

@@ -19,6 +19,7 @@ import math
import random
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
@@ -97,6 +98,126 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[torch.LongTensor] = None,
min_masks: int = 0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
CPU as part of the preprocessing during training.
Args:
shape: The shape for which to compute masks. This should be a tuple of size 2 where
the first element is the batch size and the second element is the length of the axis to span.
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
independently generated mask spans of length `mask_length` is computed by
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
actual percentage will be smaller.
mask_length: size of the mask
min_masks: minimum number of masked spans
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
each batch dimension.
"""
batch_size, sequence_length = shape
if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.")
if mask_length > sequence_length:
raise ValueError(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
f" and `sequence_length`: {sequence_length}"
)
# epsilon is used for probabilistic rounding
epsilon = np.random.rand(1).item()
def compute_num_masked_span(input_length):
"""Given input length, compute how many spans should be masked"""
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
num_masked_span = max(num_masked_span, min_masks)
# make sure num masked span <= sequence_length
if num_masked_span * mask_length > sequence_length:
num_masked_span = sequence_length // mask_length
# make sure num_masked span is also <= input_length - (mask_length - 1)
if input_length - (mask_length - 1) < num_masked_span:
num_masked_span = max(input_length - (mask_length - 1), 0)
return num_masked_span
# compute number of masked spans in batch
input_lengths = (
attention_mask.sum(-1).detach().tolist()
if attention_mask is not None
else [sequence_length for _ in range(batch_size)]
)
# SpecAugment mask to fill
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
spec_aug_mask_idxs = []
max_num_masked_span = compute_num_masked_span(sequence_length)
if max_num_masked_span == 0:
return spec_aug_mask
for input_length in input_lengths:
# compute num of masked spans for this input
num_masked_span = compute_num_masked_span(input_length)
# get random indices to mask
spec_aug_mask_idx = np.random.choice(
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
)
# pick first sampled index that will serve as a dummy index to pad vector
# to ensure same dimension for all batches due to probabilistic rounding
# Picking first sample just pads those vectors twice.
if len(spec_aug_mask_idx) == 0:
# this case can only happen if `input_length` is strictly smaller than
# `sequence_length` in which case the last token has to be a padding
# token which we can use as a dummy mask id
dummy_mask_idx = sequence_length - 1
else:
dummy_mask_idx = spec_aug_mask_idx[0]
spec_aug_mask_idx = np.concatenate(
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
)
spec_aug_mask_idxs.append(spec_aug_mask_idx)
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
# expand masked indices to masked spans
spec_aug_mask_idxs = np.broadcast_to(
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
# add offset to the starting indexes so that indexes now create a span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
)
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
# ensure that we cannot have indices larger than sequence_length
if spec_aug_mask_idxs.max() > sequence_length - 1:
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
# scatter indices to mask
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
return spec_aug_mask
class WhisperPositionalEmbedding(nn.Embedding):
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
super().__init__(num_positions, embedding_dim)
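A toy call (a sketch assuming `_compute_mask_indices` from this file is in scope; the numbers are illustrative) shows the contract: a boolean `(batch, length)` array in which roughly `mask_prob` of each row is masked, less where spans overlap:

```python
import numpy as np

np.random.seed(0)
# Two rows of 100 frames each; with no attention_mask, both rows use the full length.
mask = _compute_mask_indices((2, 100), mask_prob=0.2, mask_length=10, min_masks=1)

print(mask.shape)               # (2, 100), dtype=bool
print(mask.sum(axis=-1))        # about 0.2 * 100 = 20 frames per row, fewer if spans overlap
print(np.flatnonzero(mask[0]))  # indices come in contiguous runs of mask_length
```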
@@ -503,6 +624,14 @@ WHISPER_INPUTS_DOCSTRING = r"""
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing *SpecAugment* data augmentation on padding token indices. Mask values selected in
`[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.
@@ -999,11 +1128,55 @@ class WhisperModel(WhisperPreTrainedModel):
"""
self.encoder._freeze_parameters()
def _mask_input_features(
self,
input_features: torch.FloatTensor,
attention_mask: Optional[torch.LongTensor] = None,
):
"""
Masks extracted features along time axis and/or along feature axis according to
[SpecAugment](https://arxiv.org/abs/1904.08779).
"""
# `config.apply_spec_augment` can set masking to False
if not getattr(self.config, "apply_spec_augment", True):
return input_features
batch_size, hidden_size, sequence_length = input_features.size()
if self.config.mask_time_prob > 0 and self.training:
# generate indices & apply SpecAugment along time axis
mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length),
mask_prob=self.config.mask_time_prob,
mask_length=self.config.mask_time_length,
attention_mask=attention_mask,
min_masks=self.config.mask_time_min_masks,
)
mask_time_indices = torch.tensor(mask_time_indices, device=input_features.device, dtype=torch.bool)
mask_time_indices = mask_time_indices[:, None].expand(-1, hidden_size, -1)
input_features[mask_time_indices] = 0
if self.config.mask_feature_prob > 0 and self.training:
# generate indices & apply SpecAugment along feature axis
mask_feature_indices = _compute_mask_indices(
(batch_size, hidden_size),
mask_prob=self.config.mask_feature_prob,
mask_length=self.config.mask_feature_length,
min_masks=self.config.mask_feature_min_masks,
)
mask_feature_indices = torch.tensor(mask_feature_indices, device=input_features.device, dtype=torch.bool)
input_features[mask_feature_indices] = 0
return input_features
@add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_features: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
@@ -1044,6 +1217,8 @@ class WhisperModel(WhisperPreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if encoder_outputs is None:
input_features = self._mask_input_features(input_features, attention_mask=attention_mask)
encoder_outputs = self.encoder(
input_features,
head_mask=head_mask,
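To see the new masking end to end, a minimal sketch (illustrative config values; a randomly initialized model suffices, and the private helper is called here only to illustrate):

```python
import torch
from transformers import WhisperConfig, WhisperModel

config = WhisperConfig(apply_spec_augment=True, mask_time_prob=0.2, mask_feature_prob=0.1)
model = WhisperModel(config)
model.train()  # masks are only generated when self.training is True

# (batch, mel bins, frames); clone because masking writes zeros in place
features = torch.randn(2, config.num_mel_bins, 3000)
masked = model._mask_input_features(features.clone(), attention_mask=None)
print((masked == 0).float().mean().item())  # fraction of zeroed entries, > 0 here

model.eval()
untouched = model._mask_input_features(features.clone(), attention_mask=None)
print(torch.equal(untouched, features))     # True: no masking outside training
```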
@@ -1139,7 +1314,8 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_features: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
@@ -1193,6 +1369,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
outputs = self.model(
input_features,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
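Putting it together, a hedged sketch of a fine-tuning step that threads the new `attention_mask` through to SpecAugment (the placeholder audio and transcripts below are stand-ins, not part of the diff):

```python
import numpy as np
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny", apply_spec_augment=True)
model.train()

raw_speech = [np.zeros(16000 * 3), np.zeros(16000 * 7)]  # placeholder 16 kHz waveforms
texts = ["first transcript", "second transcript"]        # placeholder targets

inputs = processor(raw_speech, sampling_rate=16000, return_tensors="pt", return_attention_mask=True)
labels = processor.tokenizer(texts, return_tensors="pt", padding=True).input_ids

outputs = model(
    input_features=inputs.input_features,
    attention_mask=inputs.attention_mask,  # keeps SpecAugment off the padded frames
    labels=labels,
)
outputs.loss.backward()
```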

tests/models/whisper/test_modeling_whisper.py

@@ -383,6 +383,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
expected_arg_names = [
"input_features",
"attention_mask",
"decoder_input_ids",
"decoder_attention_mask",
]
@@ -909,6 +910,34 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
self.assertEqual(fx_keys, pt_keys)
self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class)
def test_mask_feature_prob(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.mask_feature_prob = 0.2
config.mask_feature_length = 2
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.train()
# forward pass
encoder_last_hidden_state = model(**input_dict).encoder_last_hidden_state
self.assertEqual(encoder_last_hidden_state.shape, (13, 30, 16))
def test_mask_time_prob(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.mask_time_prob = 0.2
config.mask_time_length = 2
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.train()
# forward pass
encoder_last_hidden_state = model(**input_dict).encoder_last_hidden_state
self.assertEqual(encoder_last_hidden_state.shape, (13, 30, 16))
@require_torch
@require_torchaudio
@@ -1289,3 +1318,38 @@ class WhisperModelIntegrationTests(unittest.TestCase):
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
@slow
def test_tiny_specaugment_librispeech(self):
torch_device = "cpu"
set_seed(0)
# Apply SpecAugment
model = WhisperModel.from_pretrained("openai/whisper-tiny", apply_spec_augment=True)
# Set model to training mode to enable SpecAugment
model.train()
model.to(torch_device)
input_speech = self._load_datasamples(1)
feature_extractor = WhisperFeatureExtractor()
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
with torch.no_grad():
logits = model(
input_features,
decoder_input_ids=torch.tensor([[50258, 50259, 50359]]),
output_hidden_states=False,
output_attentions=False,
return_dict=False,
use_cache=False,
)
# fmt: off
EXPECTED_LOGITS = torch.tensor(
[
0.9362, -4.7105, 5.0879, 3.9642, 1.0013, -6.0096, 4.7285, -3.1847,
-0.8648, 1.9631, 6.2653, 3.6936, 0.3575, -4.5818, 3.0564, 7.8712,
2.9951, 0.6848, 9.9497, -2.6638, 1.1571, -6.8546, -1.4333, -7.7584,
1.1200, 3.9030, 4.4655, -4.4919, -1.1703, 9.6241
]
)
# fmt: on
self.assertTrue(torch.allclose(logits[0][0, 0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))