[examples/speech-recognition] Add SpecAugment to run_speech_recognition_seq2seq.py (#21942)

* Add specaugment to run_speech_recognition_seq2seq.py

* Remove useless argument: text_column

* Fix quality

* Update return_attention_mask condition

* Update specaugment arguments only for whisper models

* Remove SpecAugment arguments from ModelArguments, only leave default values for simplicity

* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update apply_spec_augment only for whisper models

* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Rename return_attention_mask to forward_attention_mask to avoid confusion with wav2vec2 models

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
bofeng huang 2023-03-08 17:59:31 +01:00 committed by GitHub
parent b427b263e2
commit 6192549c1f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -113,6 +113,12 @@ class ModelArguments:
suppress_tokens: List[int] = field(
default=None, metadata={"help": "A list of tokens that will be suppressed at generation."}
)
apply_spec_augment: bool = field(
default=False,
metadata={
"help": "Whether to apply *SpecAugment* data augmentation to the input features. This is currently only relevant for Wav2Vec2, HuBERT, WavLM and Whisper models."
},
)
@dataclass
@ -127,10 +133,6 @@ class DataTrainingArguments:
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
text_column: Optional[str] = field(
default=None,
metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
@ -227,10 +229,13 @@ class DataCollatorSpeechSeq2SeqWithPadding:
The processor used for processing the data.
decoder_start_token_id (`int`)
The begin-of-sentence of the decoder.
forward_attention_mask (`bool`)
Whether to return attention_mask.
"""
processor: Any
decoder_start_token_id: int
forward_attention_mask: bool
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# split inputs and labels since they have to be of different lengths and need
@ -241,6 +246,9 @@ class DataCollatorSpeechSeq2SeqWithPadding:
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
if self.forward_attention_mask:
batch["attention_mask"] = torch.LongTensor([feature["attention_mask"] for feature in features])
labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
# replace padding with -100 to ignore loss correctly
@ -367,6 +375,10 @@ def main():
config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens})
# SpecAugment for whisper models
if getattr(config, "model_type", None) == "whisper":
config.update({"apply_spec_augment": model_args.apply_spec_augment})
feature_extractor = AutoFeatureExtractor.from_pretrained(
model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
@ -418,6 +430,12 @@ def main():
text_column_name = data_args.text_column_name
model_input_name = feature_extractor.model_input_names[0]
do_lower_case = data_args.do_lower_case
# if SpecAugment is used for whisper models, return attention_mask to guide the mask along time axis
forward_attention_mask = (
getattr(config, "model_type", None) == "whisper"
and getattr(config, "apply_spec_augment", False)
and getattr(config, "mask_time_prob", 0) > 0
)
if data_args.max_train_samples is not None:
raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
@ -428,10 +446,14 @@ def main():
def prepare_dataset(batch):
# process audio
sample = batch[audio_column_name]
inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
inputs = feature_extractor(
sample["array"], sampling_rate=sample["sampling_rate"], return_attention_mask=forward_attention_mask
)
# process audio length
batch[model_input_name] = inputs.get(model_input_name)[0]
batch["input_length"] = len(sample["array"])
if forward_attention_mask:
batch["attention_mask"] = inputs.get("attention_mask")[0]
# process targets
input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
@ -496,6 +518,7 @@ def main():
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
processor=processor,
decoder_start_token_id=model.config.decoder_start_token_id,
forward_attention_mask=forward_attention_mask,
)
# 11. Initialize Trainer