mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[examples/speech-recognition] Add SpecAugment to run_speech_recognition_seq2seq.py (#21942)
* Add specaugment to run_speech_recognition_seq2seq.py
* Remove useless argument: text_column
* Fix quality
* Update return_attention_mask condition
* Update specaugment arguments only for whisper models
* Remove SpecAugment arguments from ModelArguments, only leave default values for simplicity
* Apply suggestions from code review
  Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* Update apply_spec_augment only for whisper models
* Apply suggestions from code review
  Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* Rename return_attention_mask to forward_attention_mask to avoid confusion with wav2vec2 models
---------
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
parent
b427b263e2
commit
6192549c1f
@ -113,6 +113,12 @@ class ModelArguments:
|
||||
suppress_tokens: List[int] = field(
|
||||
default=None, metadata={"help": "A list of tokens that will be suppressed at generation."}
|
||||
)
|
||||
apply_spec_augment: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to apply *SpecAugment* data augmentation to the input features. This is currently only relevant for Wav2Vec2, HuBERT, WavLM and Whisper models."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -127,10 +133,6 @@ class DataTrainingArguments:
|
||||
dataset_config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
text_column: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||
)
|
||||
@ -227,10 +229,13 @@ class DataCollatorSpeechSeq2SeqWithPadding:
|
||||
The processor used for processing the data.
|
||||
decoder_start_token_id (`int`)
|
||||
The begin-of-sentence of the decoder.
|
||||
forward_attention_mask (`bool`)
|
||||
Whether to return attention_mask.
|
||||
"""
|
||||
|
||||
processor: Any
|
||||
decoder_start_token_id: int
|
||||
forward_attention_mask: bool
|
||||
|
||||
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
||||
# split inputs and labels since they have to be of different lengths and need
|
||||
@ -241,6 +246,9 @@ class DataCollatorSpeechSeq2SeqWithPadding:
|
||||
|
||||
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
|
||||
|
||||
if self.forward_attention_mask:
|
||||
batch["attention_mask"] = torch.LongTensor([feature["attention_mask"] for feature in features])
|
||||
|
||||
labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
|
||||
|
||||
# replace padding with -100 to ignore loss correctly
|
||||
@ -367,6 +375,10 @@ def main():
|
||||
|
||||
config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens})
|
||||
|
||||
# SpecAugment for whisper models
|
||||
if getattr(config, "model_type", None) == "whisper":
|
||||
config.update({"apply_spec_augment": model_args.apply_spec_augment})
|
||||
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||
model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
@ -418,6 +430,12 @@ def main():
|
||||
text_column_name = data_args.text_column_name
|
||||
model_input_name = feature_extractor.model_input_names[0]
|
||||
do_lower_case = data_args.do_lower_case
|
||||
# if SpecAugment is used for whisper models, return attention_mask to guide the mask along time axis
|
||||
forward_attention_mask = (
|
||||
getattr(config, "model_type", None) == "whisper"
|
||||
and getattr(config, "apply_spec_augment", False)
|
||||
and getattr(config, "mask_time_prob", 0) > 0
|
||||
)
|
||||
|
||||
if data_args.max_train_samples is not None:
|
||||
raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
|
||||
@ -428,10 +446,14 @@ def main():
|
||||
def prepare_dataset(batch):
|
||||
# process audio
|
||||
sample = batch[audio_column_name]
|
||||
inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
|
||||
inputs = feature_extractor(
|
||||
sample["array"], sampling_rate=sample["sampling_rate"], return_attention_mask=forward_attention_mask
|
||||
)
|
||||
# process audio length
|
||||
batch[model_input_name] = inputs.get(model_input_name)[0]
|
||||
batch["input_length"] = len(sample["array"])
|
||||
if forward_attention_mask:
|
||||
batch["attention_mask"] = inputs.get("attention_mask")[0]
|
||||
|
||||
# process targets
|
||||
input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
|
||||
@ -496,6 +518,7 @@ def main():
|
||||
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
|
||||
processor=processor,
|
||||
decoder_start_token_id=model.config.decoder_start_token_id,
|
||||
forward_attention_mask=forward_attention_mask,
|
||||
)
|
||||
|
||||
# 11. Initialize Trainer
|
||||
|
Loading…
Reference in New Issue
Block a user