mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-06 06:10:04 +06:00
282 lines
10 KiB
Python
Executable File
282 lines
10 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Dict, List, Optional, Union
|
|
|
|
import datasets
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
from packaging import version
|
|
|
|
import soundfile as sf
|
|
from transformers import (
|
|
HfArgumentParser,
|
|
Trainer,
|
|
TrainingArguments,
|
|
Wav2Vec2ForCTC,
|
|
Wav2Vec2Processor,
|
|
is_apex_available,
|
|
)
|
|
|
|
|
|
if is_apex_available():
    # NVIDIA apex mixed precision; used by CTCTrainer.training_step when ``self.use_apex`` is set.
    from apex import amp

# Native automatic mixed precision (``torch.cuda.amp``) requires PyTorch >= 1.6. The flag is always bound
# (True or False) so that code inspecting it never hits a NameError on older PyTorch versions.
_is_native_amp_available = version.parse(torch.__version__) >= version.parse("1.6")
if _is_native_amp_available:
    from torch.cuda.amp import autocast
|
|
|
|
|
|
@dataclass
class ModelArguments:
    """
    Command-line options describing which pretrained checkpoint (model, config and processor) fine-tuning starts from.
    """

    # No default value: this must always be supplied.
    model_name_or_path: str = field(
        metadata=dict(help="Path to pretrained model or model identifier from huggingface.co/models"),
    )
    # Optional download cache directory for the pretrained files.
    cache_dir: Optional[str] = field(
        default=None,
        metadata=dict(help="Where do you want to store the pretrained models downloaded from huggingface.co"),
    )
    # When true, the feature-extractor layers are frozen before training starts.
    freeze_feature_extractor: Optional[bool] = field(
        default=True,
        metadata=dict(help="Whether to freeze the feature extractor layers of the model."),
    )
|
|
|
|
|
|
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class into argparse arguments to be able to specify them on the
    command line.
    """

    # Annotated Optional because the default is None (the previous bare ``str`` annotation contradicted it).
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    # Name of the split loaded as the training set.
    train_split_name: Optional[str] = field(
        default="train",
        metadata={
            "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
        },
    )
    # Intended to force preprocessing to be recomputed instead of reusing cached results.
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
    )
    # Number of worker processes for dataset preprocessing; None means no multiprocessing.
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
|
|
|
|
|
|
@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received, padding the audio inputs and the labels separately.
    Padded label positions are set to -100 in the returned ``labels``.

    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for processing the data. Its ``pad`` method pads the ``input_values``, and (inside
            ``processor.as_target_processor()``) the ``labels``.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:

            * :obj:`True` or :obj:`'longest'` (default): Pad to the longest sequence in the batch (or no padding if
              only a single sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
        max_length_labels (:obj:`int`, `optional`):
            Maximum length of the ``labels`` returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the ``input_values`` to a multiple of the provided value.

            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.0 (Volta).
        pad_to_multiple_of_labels (:obj:`int`, `optional`):
            If set will pad the ``labels`` to a multiple of the provided value.
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """
        Collate a list of examples into a padded batch.

        Args:
            features: One dict per example with an ``"input_values"`` entry and a ``"labels"`` entry.

        Returns:
            The padded batch returned by ``processor.pad`` for the inputs, with an added ``"labels"`` tensor in
            which padded positions are -100.
        """
        # Split inputs and labels since they have to be of different lengths and need
        # different padding methods.
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        # NOTE(review): presumably as_target_processor() makes ``pad`` use the tokenizer for the label ids —
        # confirm against Wav2Vec2Processor.
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                max_length=self.max_length_labels,
                pad_to_multiple_of=self.pad_to_multiple_of_labels,
                return_tensors="pt",
            )

        # Replace padding with -100 (wherever the attention mask is not 1) so the loss ignores those positions.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch
|
|
|
|
|
|
class CTCTrainer(Trainer):
    """
    :class:`~transformers.Trainer` subclass whose training step reduces multi-GPU CTC losses according to the
    model's ``config.ctc_loss_reduction``.
    """

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to train.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.

        Return:
            :obj:`torch.Tensor`: The detached training loss on this batch (already divided by
            ``gradient_accumulation_steps`` when that is greater than 1).

        Raises:
            ValueError: When training on more than one GPU and ``config.ctc_loss_reduction`` is neither ``"mean"``
            nor ``"sum"``.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        if self.use_amp:
            with autocast():
                loss = self.compute_loss(model, inputs)
        else:
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            # With several GPUs the model is wrapped and exposes the real model as ``.module``; ``model.config``
            # does not exist on the wrapper, so the config must always be read through ``.module``.
            # NOTE(review): loss presumably holds one value per replica here, hence the explicit reduction.
            reduction = model.module.config.ctc_loss_reduction
            if reduction == "mean":
                loss = loss.mean()
            elif reduction == "sum":
                # Normalise the summed loss by the number of real (non -100 padding) label tokens.
                loss = loss.sum() / (inputs["labels"] >= 0).sum()
            else:
                raise ValueError(f"{reduction} is not valid. Choose one of ['mean', 'sum']")

        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        # Backpropagate through whichever mixed-precision / distributed backend is active.
        if self.use_amp:
            self.scaler.scale(loss).backward()
        elif self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        elif self.deepspeed:
            self.deepspeed.backward(loss)
        else:
            loss.backward()

        return loss.detach()
|
|
|
|
|
|
def main():
    """
    Fine-tune a pretrained Wav2Vec2 CTC model on a speech dataset loaded with the `datasets` library.

    Parses ``ModelArguments``, ``DataTrainingArguments`` and ``TrainingArguments`` from the command line, reads the
    audio referenced by each example's ``file`` column, converts audio and ``text`` into model inputs and labels with
    the pretrained processor, and trains with :class:`CTCTrainer`, using word error rate (WER) as the metric.

    Raises:
        ValueError: If a preprocessing batch contains audio whose sampling rate differs from the one expected by the
            processor's feature extractor.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    model = Wav2Vec2ForCTC.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
    processor = Wav2Vec2Processor.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)

    train_dataset = datasets.load_dataset(
        data_args.dataset_name, data_args.dataset_config_name, split=data_args.train_split_name
    )
    # The evaluation split name is fixed: the dataset must provide a "validation" split.
    val_dataset = datasets.load_dataset(data_args.dataset_name, data_args.dataset_config_name, split="validation")

    wer_metric = datasets.load_metric("wer")

    # Honour --overwrite_cache: when set, every preprocessing step is recomputed instead of read from the cache.
    load_from_cache_file = not data_args.overwrite_cache

    def map_to_array(batch):
        # Decode one example's audio file into its sample array and sampling rate.
        speech_array, sampling_rate = sf.read(batch["file"])
        batch["speech"] = speech_array
        batch["sampling_rate"] = sampling_rate
        return batch

    train_dataset = train_dataset.map(
        map_to_array, remove_columns=["file"], load_from_cache_file=load_from_cache_file
    )
    val_dataset = val_dataset.map(map_to_array, remove_columns=["file"], load_from_cache_file=load_from_cache_file)

    expected_sampling_rate = processor.feature_extractor.sampling_rate

    def prepare_dataset(batch):
        # Explicit check (``assert`` is stripped under ``python -O``): every clip in the batch must use exactly the
        # sampling rate the feature extractor expects, not merely agree with the other clips.
        found_rates = set(batch["sampling_rate"])
        if found_rates != {expected_sampling_rate}:
            raise ValueError(
                f"Make sure all inputs have the same sampling rate of {expected_sampling_rate}. "
                f"Found sampling rates: {sorted(found_rates)}."
            )

        batch["input_values"] = processor(batch["speech"], sampling_rate=expected_sampling_rate).input_values
        with processor.as_target_processor():
            batch["labels"] = processor(batch["text"]).input_ids
        return batch

    train_dataset = train_dataset.map(
        prepare_dataset,
        batch_size=training_args.per_device_train_batch_size,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=load_from_cache_file,
    )
    val_dataset = val_dataset.map(
        prepare_dataset,
        batch_size=training_args.per_device_eval_batch_size,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=load_from_cache_file,
    )

    data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

    def compute_metrics(pred):
        # Greedy decoding: take the arg-max token at every frame of the logits.
        pred_ids = np.argmax(pred.predictions, axis=-1)

        # -100 marks padded label positions; map them to the tokenizer's pad token id before decoding.
        # np.where builds a new array instead of overwriting pred.label_ids in place.
        label_ids = np.where(pred.label_ids == -100, processor.tokenizer.pad_token_id, pred.label_ids)

        pred_str = processor.batch_decode(pred_ids)
        # we do not want to group tokens when computing the metrics
        label_str = processor.batch_decode(label_ids, group_tokens=False)

        wer = wer_metric.compute(predictions=pred_str, references=label_str)

        return {"wer": wer}

    if model_args.freeze_feature_extractor:
        model.freeze_feature_extractor()

    trainer = CTCTrainer(
        model=model,
        data_collator=data_collator,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=processor.feature_extractor,
    )

    trainer.train()
|
|
|
|
|
|
if __name__ == "__main__":
    # Run fine-tuning only when executed as a script, not when this module is imported.
    main()
|