mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
Making CTC training example more general (#28582)
* add w2v2bert compatibility * Update examples/pytorch/speech-recognition/run_speech_recognition_ctc.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
186aa6befe
commit
772307be76
@ -132,10 +132,17 @@ class ModelArguments:
|
||||
ctc_loss_reduction: Optional[str] = field(
|
||||
default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
|
||||
)
|
||||
ctc_zero_infinity: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly"
|
||||
" occur when the inputs are too short to be aligned to the targets."
|
||||
},
|
||||
)
|
||||
add_adapter: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether a convolutional attention network should be stacked on top of the Wav2Vec2BERT Encoder. Can be very"
|
||||
"help": "Whether a convolutional attention network should be stacked on top of the Wav2Vec2Bert Encoder. Can be very"
|
||||
"useful to downsample the output length."
|
||||
},
|
||||
)
|
||||
@ -316,11 +323,14 @@ class DataCollatorCTCWithPadding:
|
||||
padding: Union[bool, str] = "longest"
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
pad_to_multiple_of_labels: Optional[int] = None
|
||||
feature_extractor_input_name: Optional[str] = "input_values"
|
||||
|
||||
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
||||
# split inputs and labels since they have to be of different lengths and need
|
||||
# different padding methods
|
||||
input_features = [{"input_values": feature["input_values"]} for feature in features]
|
||||
input_features = [
|
||||
{self.feature_extractor_input_name: feature[self.feature_extractor_input_name]} for feature in features
|
||||
]
|
||||
label_features = [{"input_ids": feature["labels"]} for feature in features]
|
||||
|
||||
batch = self.processor.pad(
|
||||
@ -606,6 +616,7 @@ def main():
|
||||
"gradient_checkpointing": training_args.gradient_checkpointing,
|
||||
"layerdrop": model_args.layerdrop,
|
||||
"ctc_loss_reduction": model_args.ctc_loss_reduction,
|
||||
"ctc_zero_infinity": model_args.ctc_zero_infinity,
|
||||
"pad_token_id": tokenizer.pad_token_id,
|
||||
"vocab_size": len(tokenizer),
|
||||
"activation_dropout": model_args.activation_dropout,
|
||||
@ -643,6 +654,7 @@ def main():
|
||||
min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
|
||||
audio_column_name = data_args.audio_column_name
|
||||
num_workers = data_args.preprocessing_num_workers
|
||||
feature_extractor_input_name = feature_extractor.model_input_names[0]
|
||||
|
||||
# `phoneme_language` is only relevant if the model is fine-tuned on phoneme classification
|
||||
phoneme_language = data_args.phoneme_language
|
||||
@ -654,8 +666,9 @@ def main():
|
||||
sample = batch[audio_column_name]
|
||||
|
||||
inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
|
||||
batch["input_values"] = inputs.input_values[0]
|
||||
batch["input_length"] = len(batch["input_values"])
|
||||
batch[feature_extractor_input_name] = getattr(inputs, feature_extractor_input_name)[0]
|
||||
# take length of raw audio waveform
|
||||
batch["input_length"] = len(sample["array"].squeeze())
|
||||
|
||||
# encode targets
|
||||
additional_kwargs = {}
|
||||
@ -736,7 +749,9 @@ def main():
|
||||
processor = Wav2Vec2Processor.from_pretrained(training_args.output_dir)
|
||||
|
||||
# Instantiate custom data collator
|
||||
data_collator = DataCollatorCTCWithPadding(processor=processor)
|
||||
data_collator = DataCollatorCTCWithPadding(
|
||||
processor=processor, feature_extractor_input_name=feature_extractor_input_name
|
||||
)
|
||||
|
||||
# Initialize Trainer
|
||||
trainer = Trainer(
|
||||
|
Loading…
Reference in New Issue
Block a user