[Speech Examples] Add pytorch speech pretraining (#13877)

* adapt wav2vec2

* add example

* add files

* adapt

* remove bogus file

* Apply suggestions from code review

* adapt files more

* upload changes

* del old files

* up

* up

* up

* up

* up

* correct gradient checkpointing

* add readme

* finish

* finish

* up

* more fixes

* up

* up

* add demo run to readme

* up
Patrick von Platen 2021-10-12 00:46:32 +02:00 committed by GitHub
parent 3499728dc4
commit d45fc7da3d
9 changed files with 1196 additions and 183 deletions

View File

@ -3,6 +3,7 @@ scikit-learn
seqeval
psutil
sacrebleu >= 1.4.12
accelerate >= 0.5.0
rouge-score
tensorflow_datasets
matplotlib

View File

@ -0,0 +1,124 @@
<!---
Copyright 2021 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Speech Recognition Pre-Training
## Wav2Vec2 Speech Pre-Training
The script [`run_wav2vec2_pretraining_no_trainer.py`](https://github.com/huggingface/transformers/blob/master/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py) can be used to pre-train a [Wav2Vec2](https://huggingface.co/transformers/model_doc/wav2vec2.html?highlight=wav2vec2) model from scratch.
In the script [`run_wav2vec2_pretraining_no_trainer`](https://github.com/huggingface/transformers/blob/master/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py), a Wav2Vec2 model is pre-trained on audio data alone using [Wav2Vec2's contrastive loss objective](https://arxiv.org/abs/2006.11477).
The following examples show how to pre-train a `"base"`-sized Wav2Vec2 model as well as a `"large"`-sized Wav2Vec2 model using [`accelerate`](https://github.com/huggingface/accelerate).
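For orientation, the short sketch below shows what a single pre-training step boils down to: the feature extractor prepares the raw audio, masked time indices and negative (distractor) indices are sampled, and the model returns the contrastive and diversity losses that make up the training objective. It mirrors what the script's data collator and training loop do; the dummy audio is purely illustrative and the repository name is the same configuration checkpoint used in the demo command further below.
```python
import numpy as np
import torch

from transformers import Wav2Vec2Config, Wav2Vec2FeatureExtractor, Wav2Vec2ForPreTraining
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    _compute_mask_indices,
    _sample_negative_indices,
)

# config and feature extractor come from the demo repository; the model itself is
# randomly initialized, exactly as in the pre-training script
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("patrickvonplaten/wav2vec2-base-v2")
config = Wav2Vec2Config.from_pretrained("patrickvonplaten/wav2vec2-base-v2")
model = Wav2Vec2ForPreTraining(config)

# one second of dummy audio stands in for a real utterance
raw_audio = [np.random.uniform(-1, 1, size=16_000).astype(np.float32)]
inputs = feature_extractor(raw_audio, sampling_rate=16_000, return_tensors="pt")

# length of the feature encoder output for this input
batch_size = inputs.input_values.shape[0]
seq_len = int(model._get_feat_extract_output_lengths(inputs.input_values.shape[-1]))

# sample masked time steps and negative indices, as the data collator does
mask_time_indices = _compute_mask_indices(
    (batch_size, seq_len), mask_prob=config.mask_time_prob, mask_length=config.mask_time_length
)
sampled_negative_indices = _sample_negative_indices(
    (batch_size, seq_len), config.num_negatives, mask_time_indices=mask_time_indices
)

# the returned loss is the contrastive loss plus the weighted diversity loss
outputs = model(
    inputs.input_values,
    mask_time_indices=torch.from_numpy(mask_time_indices),
    sampled_negative_indices=torch.from_numpy(sampled_negative_indices),
)
print(outputs.loss, outputs.contrastive_loss, outputs.diversity_loss)
```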
---
**NOTE 1**
Wav2Vec2's pre-training is known to be quite unstable.
It is advised to do a couple of test runs with a smaller dataset,
*e.g.* `--dataset_config_names clean clean`, `--dataset_split_names validation test`,
to find good hyper-parameters for `learning_rate`, `batch_size`, `num_warmup_steps`,
and the optimizer.
A good metric to observe during training is the gradient norm which should ideally be between 0.5 and 2.
---
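As a reference for the gradient norm mentioned in the note above, this is a minimal sketch of the quantity the script logs as `grad_norm` after each backward pass (the script's `get_grad_norm` helper is equivalent, with an additional division by the fp16 loss scale):
```python
def grad_norm(parameters):
    """Total L2 norm over all parameter gradients, to be called after `backward()`."""
    total_norm = 0.0
    for p in parameters:
        if p.grad is not None:
            total_norm += p.grad.detach().norm(2).item() ** 2
    return total_norm ** 0.5
```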
---
**NOTE 2**
When training a model on large datasets it is recommended to run the data preprocessing
in a first run in **non-distributed** mode via `--preprocessing_only` so that,
when the script is run in **distributed** mode in a second step, the cached, preprocessed data
can easily be loaded on each distributed device.
---
### Demo
In this demo run, we pre-train a `"base"`-sized Wav2Vec2 model on just the validation
and test data of [librispeech_asr](https://huggingface.co/datasets/librispeech_asr).
The demo is run on two Titan RTX GPUs (24 GB RAM each). If you have less RAM available
per device, consider reducing `--per_device_train_batch_size` and/or `--max_duration_in_seconds`.
```bash
accelerate launch run_wav2vec2_pretraining_no_trainer.py \
--dataset_name="librispeech_asr" \
--dataset_config_names clean clean \
--dataset_split_names validation test \
--model_name_or_path="patrickvonplaten/wav2vec2-base-v2" \
--output_dir="./wav2vec2-pretrained-demo" \
--max_train_steps="20000" \
--num_warmup_steps="32000" \
--gradient_accumulation_steps="8" \
--learning_rate="0.005" \
--weight_decay="0.01" \
--max_duration_in_seconds="20.0" \
--min_duration_in_seconds="2.0" \
--logging_steps="1" \
--saving_steps="10000" \
--per_device_train_batch_size="8" \
--per_device_eval_batch_size="8" \
--adam_beta1="0.9" \
--adam_beta2="0.98" \
--adam_epsilon="1e-06" \
	--gradient_checkpointing
```
The results of this run can be seen [here](https://wandb.ai/patrickvonplaten/wav2vec2-pretrained-demo/reports/Wav2Vec2-PreTraining-Demo-Run--VmlldzoxMDk3MjAw?accessToken=oa05s1y57lizo2ocxy3k01g6db1u4pt8m6ur2n8nl4cb0ug02ms2cw313kb8ruch).
### Base
TODO (currently running...)
### Large
To pre-train `"large-sized"` Wav2Vec2 model, *e.g.* [facebook/wav2vec2-large-lv60](https://huggingface.co/facebook/wav2vec2-large-lv60),
on [librispeech_asr](https://huggingface.co/datasets/librispeech_asr), the following command can be run:
```bash
accelerate launch run_wav2vec2_pretraining_no_trainer.py \
--dataset_name=librispeech_asr \
--dataset_config_names clean clean other \
--dataset_split_names train.100 train.360 train.500 \
--output_dir=./test \
--max_train_steps=200000 \
--num_warmup_steps=32000 \
--gradient_accumulation_steps=8 \
--learning_rate=0.001 \
--weight_decay=0.01 \
--max_duration_in_seconds=20.0 \
--min_duration_in_seconds=2.0 \
	--model_name_or_path=./ \
--logging_steps=1 \
--saving_steps=10000 \
--per_device_train_batch_size=2 \
--per_device_eval_batch_size=4 \
--adam_beta1=0.9 \
--adam_beta2=0.98 \
--adam_epsilon=1e-06 \
	--gradient_checkpointing
```
The experiment was run on 8 V100 GPUs (16 GB RAM each) for 7 days.
In case you have more than 8 GPUs available for a higher effective `batch_size`,
it is recommended to increase the `learning_rate` to `0.005` for faster convergence.
The results of this run can be seen [here](https://wandb.ai/patrickvonplaten/pretraining-wav2vec2/reports/Wav2Vec2-Large--VmlldzoxMTAwODM4?accessToken=wm3qzcnldrwsa31tkvf2pdmilw3f63d4twtffs86ou016xjbyilh55uoi3mo1qzc) and the checkpoint pretrained for 120,000 steps can be accessed [here](https://huggingface.co/patrickvonplaten/wav2vec2-large-repro-960h-libri-120k-steps).
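Both a locally saved run (the `--output_dir` folder) and the published checkpoint above can be loaded back like any other Wav2Vec2 checkpoint, for example to sanity-check the weights or to continue with ASR fine-tuning via `Wav2Vec2ForCTC`, which reuses the pre-trained encoder and only adds a freshly initialized CTC head:
```python
from transformers import Wav2Vec2ForPreTraining

# the checkpoint referenced above, pre-trained for 120,000 steps
model = Wav2Vec2ForPreTraining.from_pretrained(
    "patrickvonplaten/wav2vec2-large-repro-960h-libri-120k-steps"
)

# a local run can be loaded the same way, e.g.
# model = Wav2Vec2ForPreTraining.from_pretrained("./wav2vec2-pretrained-demo")
```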

View File

@ -0,0 +1,4 @@
datasets >= 1.12.0
torch >= 1.5
torchaudio
accelerate >= 0.5.0

View File

@ -0,0 +1,700 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Pre-Training a 🤗 Wav2Vec2 model on unlabeled audio data """
import argparse
import logging
import math
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Union
import datasets
import torch
import torchaudio
from datasets import DatasetDict, concatenate_datasets, load_dataset
from torch.utils.data.dataloader import DataLoader
from tqdm.auto import tqdm
import transformers
from accelerate import Accelerator
from huggingface_hub import Repository
from transformers import (
AdamW,
SchedulerType,
Wav2Vec2Config,
Wav2Vec2FeatureExtractor,
Wav2Vec2ForPreTraining,
get_scheduler,
is_wandb_available,
set_seed,
)
from transformers.file_utils import get_full_repo_name
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser(description="Pre-train a Wav2Vec2 model on unlabeled audio data")
parser.add_argument(
"--dataset_name",
type=str,
default=None,
help="The name of the dataset to use (via the datasets library).",
)
parser.add_argument(
"--dataset_config_names",
nargs="+",
type=str,
required=True,
help="The configuration names of the dataset to use (via the datasets library).",
)
parser.add_argument(
"--dataset_split_names",
nargs="+",
type=str,
required=True,
help="The names of the training data set splits to use (via the datasets library).",
)
parser.add_argument(
"--preprocessing_num_workers",
type=int,
default=None,
help="The number of processes to use for the preprocessing.",
)
parser.add_argument(
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
)
parser.add_argument(
"--preprocessing_only",
action="store_true",
help="Only run the preprocessing script to be cached for future use",
)
parser.add_argument(
"--cache_dir",
type=str,
default=None,
help="Where do you want to store the pretrained models downloaded from huggingface.co",
)
parser.add_argument(
"--validation_split_percentage",
type=int,
default=1,
help="Percentage of training data that should be used for validation if no validation is present in dataset.",
)
parser.add_argument(
"--logging_steps",
type=int,
default=500,
help="Number of steps between each logging",
)
parser.add_argument(
"--saving_steps",
type=int,
default=500,
help="Number of steps between each logging",
)
parser.add_argument(
"--audio_column_name",
type=str,
default="file",
help="Column in the dataset that contains speech file path. Defaults to 'file'",
)
parser.add_argument(
"--model_name_or_path",
type=str,
help="Path to pretrained model or model identifier from huggingface.co/models.",
required=True,
)
parser.add_argument(
"--config_name",
type=str,
default=None,
help="Pretrained config name or path if not the same as model_name",
)
parser.add_argument(
"--per_device_train_batch_size",
type=int,
default=8,
help="Batch size (per device) for the training dataloader.",
)
parser.add_argument(
"--per_device_eval_batch_size",
type=int,
default=8,
help="Batch size (per device) for the evaluation dataloader.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=5e-5,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="If True, use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument(
"--lr_scheduler_type",
type=SchedulerType,
default="linear",
help="The scheduler type to use.",
choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
)
parser.add_argument(
"--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.")
parser.add_argument(
"--max_gumbel_temperature",
type=float,
default=2.0,
help="Maximum temperature for gumbel softmax.",
)
parser.add_argument(
"--min_gumbel_temperature",
type=float,
default=0.5,
help="Minimum temperature for gumbel softmax.",
)
parser.add_argument(
"--gumbel_temperature_decay", type=float, default=0.999995, help="Decay of gumbel temperature during training."
)
parser.add_argument(
"--max_duration_in_seconds",
type=float,
default=5.0,
help="Filter out audio files that are longer than `max_duration_in_seconds` seconds",
)
parser.add_argument(
"--min_duration_in_seconds",
type=float,
default=3.0,
help="Filter out audio files that are shorter than `min_duration_in_seconds` seconds",
)
parser.add_argument(
"--pad_to_multiple_of",
type=int,
default=None,
help="If set will pad the sequence to a multiple of the provided value. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).",
)
parser.add_argument(
"--adam_beta1",
type=float,
default=0.9,
help="Beta1 for AdamW optimizer",
)
parser.add_argument(
"--adam_beta2",
type=float,
default=0.999,
help="Beta2 for AdamW optimizer",
)
parser.add_argument(
"--adam_epsilon",
type=float,
default=1e-8,
help="Epsilon for AdamW optimizer",
)
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument(
"--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
)
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
args = parser.parse_args()
if args.push_to_hub:
assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
return args
@dataclass
class DataCollatorForWav2Vec2Pretraining:
"""
Data collator that will dynamically pad the inputs received and prepare masked indices
for self-supervised pretraining.
Args:
model (:class:`~transformers.Wav2Vec2ForPreTraining`):
The Wav2Vec2 model used for pretraining. The data collator needs to have access
to config and ``_get_feat_extract_output_lengths`` function for correct padding.
feature_extractor (:class:`~transformers.Wav2Vec2FeatureExtractor`):
The processor used for processing the data.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
among:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence is provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
different lengths).
max_length (:obj:`int`, `optional`):
Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
pad_to_multiple_of (:obj:`int`, `optional`):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.5 (Volta).
"""
model: Wav2Vec2ForPreTraining
feature_extractor: Wav2Vec2FeatureExtractor
padding: Union[bool, str] = "longest"
pad_to_multiple_of: Optional[int] = None
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# reformat list to dict and set to pytorch format
batch = self.feature_extractor.pad(
features,
padding=self.padding,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors="pt",
)
device = batch["input_values"].device
batch_size = batch["input_values"].shape[0]
mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])
# make sure that no loss is computed on padded inputs
if batch.get("attention_mask") is not None:
# compute real output lengths according to convolution formula
batch["sub_attention_mask"] = self.model._get_feature_vector_attention_mask(
mask_indices_seq_length, batch["attention_mask"]
)
features_shape = (batch_size, mask_indices_seq_length)
# sample randomly masked indices
mask_time_indices = _compute_mask_indices(
features_shape,
self.model.config.mask_time_prob,
self.model.config.mask_time_length,
attention_mask=batch.get("sub_attention_mask"),
)
# sample negative indices
sampled_negative_indices = _sample_negative_indices(
features_shape,
self.model.config.num_negatives,
mask_time_indices=mask_time_indices,
)
batch["mask_time_indices"] = torch.tensor(mask_time_indices, dtype=torch.long, device=device)
batch["sampled_negative_indices"] = torch.tensor(sampled_negative_indices, dtype=torch.long, device=device)
return batch
def multiply_grads(params, c):
"""Multiplies grads by a constant *c*."""
for p in params:
if p.grad is not None:
if torch.is_tensor(c):
c = c.to(p.grad.device)
p.grad.data.mul_(c)
def get_grad_norm(params, scale=1):
"""Compute grad norm given a gradient scale."""
total_norm = 0.0
for p in params:
if p.grad is not None:
param_norm = (p.grad.detach().data / scale).norm(2)
total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5
return total_norm
def main():
# See all possible arguments in src/transformers/args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
args = parse_args()
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
accelerator = Accelerator()
logger.info(accelerator.state)
# Setup logging, we only want one process per machine to log things on the screen.
# accelerator.is_local_main_process is only True for one process per machine.
logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
# set up weights and biases if available
if is_wandb_available():
import wandb
wandb.init(project=args.output_dir.split("/")[-1])
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
# Handle the repository creation
if accelerator.is_main_process:
if args.push_to_hub and not args.preprocessing_only:
if args.hub_model_id is None:
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else:
repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name)
elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
accelerator.wait_for_everyone()
# 1. Download and create train, validation dataset
# We load all dataset configuration and dataset split pairs passed in
# ``args.dataset_config_names`` and ``args.dataset_split_names``
datasets_splits = []
for dataset_config_name, train_split_name in zip(args.dataset_config_names, args.dataset_split_names):
# load dataset
dataset_split = load_dataset(
args.dataset_name, dataset_config_name, split=train_split_name, cache_dir=args.cache_dir
)
datasets_splits.append(dataset_split)
# Next, we concatenate all configurations and splits into a single training dataset
raw_datasets = DatasetDict()
if len(datasets_splits) > 1:
raw_datasets["train"] = concatenate_datasets(datasets_splits).shuffle(seed=args.seed)
else:
raw_datasets["train"] = datasets_splits[0]
# Take ``args.validation_split_percentage`` from the training dataset for the validation set
num_validation_samples = raw_datasets["train"].num_rows * args.validation_split_percentage // 100
if num_validation_samples == 0:
raise ValueError(
"`args.validation_split_percentage` is less than a single sample "
f"for {len(raw_datasets['train'])} training samples. Increase "
"`args.num_validation_split_percentage`. "
)
raw_datasets["validation"] = raw_datasets["train"].select(range(num_validation_samples))
raw_datasets["train"] = raw_datasets["train"].select(range(num_validation_samples, raw_datasets["train"].num_rows))
# 2. Preprocess audio: load, resample, normalize and truncate
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(args.model_name_or_path)
# only normalized-inputs-training is supported
if not feature_extractor.do_normalize:
raise ValueError(
"Training is only supported for normalized inputs. " "Make sure ``feature_extractor.do_normalize == True``"
)
# set max & min audio length in number of samples
max_length = int(args.max_duration_in_seconds * feature_extractor.sampling_rate)
min_length = int(args.min_duration_in_seconds * feature_extractor.sampling_rate)
resampler = None
if raw_datasets["train"][args.audio_column_name][0].split(".")[-1] == "mp3":
# TODO(PVP) - remove hard-coded 48_000 after audio feature is merged
resampler = torchaudio.transforms.Resample(48_000, feature_extractor.sampling_rate)
def prepare_dataset(batch):
speech_array, sampling_rate = torchaudio.load(batch[args.audio_column_name])
speech_array = speech_array.squeeze()
# if necessary resample audio
if resampler is not None:
# TODO(PVP) - remove hard-coded 48_000 after audio feature is merged
speech_array = resampler(speech_array)
sampling_rate = resampler.new_freq
speech_array = speech_array.numpy()
inputs = feature_extractor(speech_array, sampling_rate=sampling_rate, max_length=max_length, truncation=True)
batch["input_values"] = inputs.input_values[0]
return batch
# load audio files into numpy arrays
with accelerator.main_process_first():
vectorized_datasets = raw_datasets.map(
prepare_dataset,
num_proc=args.preprocessing_num_workers,
remove_columns=raw_datasets["train"].column_names,
load_from_cache_file=not args.overwrite_cache,
)
vectorized_datasets = vectorized_datasets.filter(
lambda x: len(x["input_values"]) > min_length, load_from_cache_file=not args.overwrite_cache
)
# for large datasets it is advised to run the preprocessing on a
# single machine first with ``args.preprocessing_only`` since there will most likely
# be a timeout when running the script in distributed mode.
# In a second step ``args.preprocessing_only`` can then be set to `False` to load the
# cached dataset
if args.preprocessing_only:
return
# 3. Load model
config = Wav2Vec2Config.from_pretrained(args.model_name_or_path)
# pretraining is only supported for "newer" stable layer norm architecture
# apply_spec_augment has to be True, mask_feature_prob has to be 0.0
if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
raise ValueError(
"PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'"
)
# initialize random model
model = Wav2Vec2ForPreTraining(config)
# Activate gradient checkpointing if needed
if args.gradient_checkpointing:
model.gradient_checkpointing_enable()
# 4. Define data collator, optimizer and scheduler
data_collator = DataCollatorForWav2Vec2Pretraining(
model=model, feature_extractor=feature_extractor, pad_to_multiple_of=args.pad_to_multiple_of
)
train_dataloader = DataLoader(
vectorized_datasets["train"],
shuffle=True,
collate_fn=data_collator,
batch_size=args.per_device_train_batch_size,
)
eval_dataloader = DataLoader(
vectorized_datasets["validation"], collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
)
# Optimizer
optimizer = AdamW(
list(model.parameters()),
lr=args.learning_rate,
betas=[args.adam_beta1, args.adam_beta2],
eps=args.adam_epsilon,
)
# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
model, optimizer, train_dataloader, eval_dataloader
)
# Scheduler and math around the number of training steps.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
else:
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
lr_scheduler = get_scheduler(
name=args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=args.num_warmup_steps,
num_training_steps=args.max_train_steps,
)
# 5. Train
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(vectorized_datasets['train'])}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
completed_steps = 0
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0
for epoch in range(args.num_train_epochs):
model.train()
for step, batch in enumerate(train_dataloader):
# compute num of losses
num_losses = batch["mask_time_indices"].sum()
sub_attention_mask = batch.pop("sub_attention_mask", None)
sub_attention_mask = (
sub_attention_mask if sub_attention_mask is not None else torch.ones_like(batch["mask_time_indices"])
)
percent_masked = num_losses / sub_attention_mask.sum()
# forward
outputs = model(**batch)
# divide loss by gradient accumulation steps since gradients
# are accumulated for multiple backward passes in PyTorch
loss = outputs.loss / args.gradient_accumulation_steps
accelerator.backward(loss)
# make sure that `num_losses` is summed for distributed training
# and average gradients over losses of all devices
if accelerator.state.num_processes > 1:
num_losses = accelerator.gather(num_losses).sum()
gradient_multiplier = accelerator.state.num_processes / num_losses
multiply_grads(model.module.parameters(), gradient_multiplier)
else:
multiply_grads(model.parameters(), 1 / num_losses)
# update step
if (step + 1) % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
# compute grad norm for monitoring
scale = (
accelerator.scaler._scale.item()
if hasattr(accelerator, "scaler") and accelerator.scaler is not None
else 1
)
if accelerator.state.num_processes > 1:
grad_norm = get_grad_norm(model.module.parameters(), scale)
else:
grad_norm = get_grad_norm(model.parameters(), scale)
# update parameters
optimizer.step()
optimizer.zero_grad()
if not accelerator.optimizer_step_was_skipped:
lr_scheduler.step()
elif accelerator.is_local_main_process:
progress_bar.write(
"Gradients have overflown - skipping update step... " f"Updating gradient scale to {scale}..."
)
# update gumbel temperature
gumbel_temperature = max(
args.max_gumbel_temperature * args.gumbel_temperature_decay ** completed_steps,
args.min_gumbel_temperature,
)
if hasattr(model, "module"):
model.module.set_gumbel_temperature(gumbel_temperature)
else:
model.set_gumbel_temperature(gumbel_temperature)
progress_bar.update(1)
completed_steps += 1
# 6. Log all results
if (step + 1) % (args.gradient_accumulation_steps * args.logging_steps) == 0:
loss.detach()
outputs.contrastive_loss.detach()
outputs.diversity_loss.detach()
if accelerator.state.num_processes > 1:
loss = accelerator.gather(loss).sum()
outputs.contrastive_loss = accelerator.gather(outputs.contrastive_loss).sum()
outputs.diversity_loss = accelerator.gather(outputs.diversity_loss).sum()
percent_masked = accelerator.gather(percent_masked).sum()
train_logs = {
"loss": (loss * args.gradient_accumulation_steps) / num_losses,
"constrast_loss": outputs.contrastive_loss / num_losses,
"div_loss": outputs.diversity_loss / num_losses,
"%_mask_idx": percent_masked / accelerator.num_processes,
"ppl": outputs.codevector_perplexity,
"lr": torch.tensor(optimizer.param_groups[0]["lr"]),
"temp": torch.tensor(gumbel_temperature),
"grad_norm": torch.tensor(grad_norm),
}
log_str = ""
for k, v in train_logs.items():
log_str += "| {}: {:.3e}".format(k, v.item())
if accelerator.is_local_main_process:
progress_bar.write(log_str)
if is_wandb_available():
wandb.log(train_logs)
# save model every `args.saving_steps` steps
if (step + 1) % (args.gradient_accumulation_steps * args.saving_steps) == 0:
if (args.push_to_hub and epoch < args.num_train_epochs - 1) or args.output_dir is not None:
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
if (args.push_to_hub and epoch < args.num_train_epochs - 1) and accelerator.is_main_process:
repo.push_to_hub(commit_message=f"Training in progress step {completed_steps}", blocking=False)
# if completed steps > `args.max_train_steps` stop
if completed_steps >= args.max_train_steps:
break
# 7. Validate!
model.eval()
# init logs
val_logs = {
"val_loss": 0,
"val_contrastive_loss": 0,
"val_diversity_loss": 0,
"val_num_losses": 0,
}
for step, batch in enumerate(eval_dataloader):
with torch.no_grad():
batch.pop("sub_attention_mask", None)
outputs = model(**batch)
val_logs["val_loss"] += outputs.loss
val_logs["val_contrastive_loss"] += outputs.contrastive_loss
val_logs["val_diversity_loss"] += outputs.diversity_loss
val_logs["val_num_losses"] += batch["mask_time_indices"].sum()
# sum over devices in multi-processing
if accelerator.num_processes > 1:
val_logs = {k: accelerator.gather(v).sum() for k, v in val_logs.items()}
val_logs = {k: v / val_logs["val_num_losses"] for k, v in val_logs.items()}
log_str = ""
for k, v in val_logs.items():
log_str += "| {}: {:.3e}".format(k, v.item())
if accelerator.is_local_main_process:
progress_bar.write(log_str)
if is_wandb_available():
wandb.log(val_logs)
if args.output_dir is not None:
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
if accelerator.is_main_process:
if args.push_to_hub:
repo.push_to_hub(commit_message="End of training")
if __name__ == "__main__":
main()

View File

@ -23,6 +23,7 @@ from unittest.mock import patch
import torch
from transformers import Wav2Vec2ForPreTraining
from transformers.file_utils import is_apex_available
from transformers.testing_utils import TestCasePlus, get_gpu_count, slow, torch_device
@ -41,6 +42,7 @@ SRC_DIRS = [
"image-classification",
"speech-recognition",
"audio-classification",
"speech-pretraining",
]
]
sys.path.extend(SRC_DIRS)
@ -59,6 +61,7 @@ if SRC_DIRS is not None:
import run_summarization
import run_swag
import run_translation
import run_wav2vec2_pretraining_no_trainer
logging.basicConfig(level=logging.DEBUG)
@ -447,3 +450,32 @@ class ExamplesTests(TestCasePlus):
run_audio_classification.main()
result = get_results(tmp_dir)
self.assertLess(result["eval_loss"], result["train_loss"])
def test_run_wav2vec2_pretraining(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_wav2vec2_pretraining_no_trainer.py
--output_dir {tmp_dir}
--model_name_or_path hf-internal-testing/tiny-random-wav2vec2
--dataset_name patrickvonplaten/librispeech_asr_dummy
--dataset_config_names clean
--dataset_split_names validation
--learning_rate 1e-4
--per_device_train_batch_size 2
--per_device_eval_batch_size 2
--preprocessing_num_workers 16
--max_train_steps 5
--validation_split_percentage 5
--seed 42
""".split()
if is_cuda_and_apex_available():
testargs.append("--fp16")
with patch.object(sys, "argv", testargs):
run_wav2vec2_pretraining_no_trainer.main()
model = Wav2Vec2ForPreTraining.from_pretrained(tmp_dir)
self.assertIsNotNone(model)

View File

@ -48,13 +48,13 @@ def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
device: torch.device,
attention_mask: Optional[torch.tensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
min_masks: int = 0,
) -> torch.tensor:
) -> np.ndarray:
"""
Computes random mask spans for a given shape. Used to implement `SpecAugment: A Simple Data Augmentation Method for
ASR <https://arxiv.org/abs/1904.08779>`__.
ASR <https://arxiv.org/abs/1904.08779>`__. Note that this method is not optimized to run on TPU and should be run
on CPU as part of the preprocessing during training.
Args:
shape: the shape for which to compute masks.
@ -64,7 +64,6 @@ def _compute_mask_indices(
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
mask_length: size of the mask
min_masks: minimum number of masked spans
"""
batch_size, sequence_length = shape
@ -76,42 +75,64 @@ def _compute_mask_indices(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
)
# compute number of masked spans in batch
num_masked_spans = int(mask_prob * sequence_length / mask_length + torch.rand((1,)).item())
num_masked_spans = max(num_masked_spans, min_masks)
epsilon = np.random.rand(1).item()
# make sure num masked indices <= sequence_length
if num_masked_spans * mask_length > sequence_length:
num_masked_spans = sequence_length // mask_length
def compute_num_masked_span(input_length):
"""Given input length, compute how many spans should be masked"""
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
num_masked_span = max(num_masked_span, min_masks)
# make sure num masked indices <= sequence_length
if num_masked_span * mask_length > sequence_length:
num_masked_span = sequence_length // mask_length
return num_masked_span
# compute number of masked spans in batch
input_lengths = (
attention_mask.sum(-1).detach().tolist()
if attention_mask is not None
else [sequence_length for _ in range(batch_size)]
)
# SpecAugment mask to fill
spec_aug_mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool)
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=np.bool)
spec_aug_mask_idxs = []
# uniform distribution to sample from, make sure that offset samples are < sequence_length
uniform_dist = torch.ones((batch_size, sequence_length - (mask_length - 1)), device=device)
max_num_masked_span = compute_num_masked_span(sequence_length)
# get random indices to mask
spec_aug_mask_idxs = torch.multinomial(uniform_dist, num_masked_spans)
for input_length in input_lengths:
# compute num of masked spans for this input
num_masked_span = compute_num_masked_span(input_length)
# get random indices to mask
spec_aug_mask_idx = np.random.choice(
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
)
# pick first sampled index that will serve as a dummy index to pad vector
dummy_mask_idx = spec_aug_mask_idx[0]
spec_aug_mask_idx = np.concatenate(
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
)
spec_aug_mask_idxs.append(spec_aug_mask_idx)
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
# expand masked indices to masked spans
spec_aug_mask_idxs = (
spec_aug_mask_idxs.unsqueeze(dim=-1)
.expand((batch_size, num_masked_spans, mask_length))
.reshape(batch_size, num_masked_spans * mask_length)
spec_aug_mask_idxs = np.broadcast_to(
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
)
offsets = (
torch.arange(mask_length, device=device)[None, None, :]
.expand((batch_size, num_masked_spans, mask_length))
.reshape(batch_size, num_masked_spans * mask_length)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
)
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
# scatter indices to mask
spec_aug_mask = spec_aug_mask.scatter(1, spec_aug_mask_idxs, True)
if attention_mask is not None:
# make sure padded input ids cannot be masked
spec_aug_mask = torch.where(attention_mask.bool(), spec_aug_mask, False)
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
return spec_aug_mask
@ -257,6 +278,7 @@ class HubertFeatureExtractor(nn.Module):
f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
)
self.conv_layers = nn.ModuleList(conv_layers)
self.gradient_checkpointing = False
def _freeze_parameters(self):
for param in self.parameters():
@ -264,8 +286,26 @@ class HubertFeatureExtractor(nn.Module):
def forward(self, input_values):
hidden_states = input_values[:, None]
# make sure hidden_states require grad for gradient_checkpointing
if self.training:
hidden_states.requires_grad = True
for conv_layer in self.conv_layers:
hidden_states = conv_layer(hidden_states)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(conv_layer),
hidden_states,
)
else:
hidden_states = conv_layer(hidden_states)
return hidden_states
@ -864,10 +904,10 @@ class HubertModel(HubertPreTrainedModel):
(batch_size, sequence_length),
mask_prob=self.config.mask_time_prob,
mask_length=self.config.mask_time_length,
device=hidden_states.device,
attention_mask=attention_mask,
min_masks=2,
)
mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.long)
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
if self.config.mask_feature_prob > 0 and self.training:
@ -876,9 +916,11 @@ class HubertModel(HubertPreTrainedModel):
(batch_size, hidden_size),
mask_prob=self.config.mask_feature_prob,
mask_length=self.config.mask_feature_length,
device=hidden_states.device,
)
hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.long)[
:, None
].expand(-1, sequence_length, -1)
hidden_states[mask_feature_indices] = 0
return hidden_states

View File

@ -14,6 +14,7 @@
# limitations under the License.
""" PyTorch Wav2Vec2 model. """
import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
@ -87,7 +88,7 @@ class Wav2Vec2ForPreTrainingOutput(ModelOutput):
Output type of :class:`~transformers.Wav2Vec2ForPreTrainingOutput`, with potential hidden states and attentions.
Args:
loss (`optional`, returned when model is in train mode, ``torch.FloatTensor`` of shape :obj:`(1,)`):
loss (`optional`, returned when :obj:`sample_negative_indices` are passed, ``torch.FloatTensor`` of shape :obj:`(1,)`):
Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the `official
paper <https://arxiv.org/pdf/2006.11477.pdf>`__ . (classification) loss.
projected_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.proj_codevector_dim)`):
@ -107,6 +108,10 @@ class Wav2Vec2ForPreTrainingOutput(ModelOutput):
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
contrastive_loss (`optional`, returned when :obj:`sample_negative_indices` are passed, ``torch.FloatTensor`` of shape :obj:`(1,)`):
The contrastive loss (L_m) as stated in the `official paper <https://arxiv.org/pdf/2006.11477.pdf>`__ .
diversity_loss (`optional`, returned when :obj:`sample_negative_indices` are passed, ``torch.FloatTensor`` of shape :obj:`(1,)`):
The diversity loss (L_d) as stated in the `official paper <https://arxiv.org/pdf/2006.11477.pdf>`__ .
"""
loss: Optional[torch.FloatTensor] = None
@ -115,19 +120,21 @@ class Wav2Vec2ForPreTrainingOutput(ModelOutput):
codevector_perplexity: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
contrastive_loss: Optional[torch.FloatTensor] = None
diversity_loss: Optional[torch.FloatTensor] = None
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
device: torch.device,
attention_mask: Optional[torch.tensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
min_masks: int = 0,
) -> torch.tensor:
) -> np.ndarray:
"""
Computes random mask spans for a given shape. Used to implement `SpecAugment: A Simple Data Augmentation Method for
ASR <https://arxiv.org/abs/1904.08779>`__.
ASR <https://arxiv.org/abs/1904.08779>`__. Note that this method is not optimized to run on TPU and should be run
on CPU as part of the preprocessing during training.
Args:
shape: the shape for which to compute masks.
@ -137,7 +144,6 @@ def _compute_mask_indices(
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
mask_length: size of the mask
min_masks: minimum number of masked spans
"""
batch_size, sequence_length = shape
@ -149,46 +155,104 @@ def _compute_mask_indices(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
)
# compute number of masked spans in batch
num_masked_spans = int(mask_prob * sequence_length / mask_length + torch.rand((1,)).item())
num_masked_spans = max(num_masked_spans, min_masks)
epsilon = np.random.rand(1).item()
# make sure num masked indices <= sequence_length
if num_masked_spans * mask_length > sequence_length:
num_masked_spans = sequence_length // mask_length
def compute_num_masked_span(input_length):
"""Given input length, compute how many spans should be masked"""
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
num_masked_span = max(num_masked_span, min_masks)
# make sure num masked indices <= sequence_length
if num_masked_span * mask_length > sequence_length:
num_masked_span = sequence_length // mask_length
return num_masked_span
# compute number of masked spans in batch
input_lengths = (
attention_mask.sum(-1).detach().tolist()
if attention_mask is not None
else [sequence_length for _ in range(batch_size)]
)
# SpecAugment mask to fill
spec_aug_mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool)
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=np.bool)
spec_aug_mask_idxs = []
# uniform distribution to sample from, make sure that offset samples are < sequence_length
uniform_dist = torch.ones((batch_size, sequence_length - (mask_length - 1)), device=device)
max_num_masked_span = compute_num_masked_span(sequence_length)
# get random indices to mask
spec_aug_mask_idxs = torch.multinomial(uniform_dist, num_masked_spans)
for input_length in input_lengths:
# compute num of masked spans for this input
num_masked_span = compute_num_masked_span(input_length)
# get random indices to mask
spec_aug_mask_idx = np.random.choice(
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
)
# pick first sampled index that will serve as a dummy index to pad vector
dummy_mask_idx = spec_aug_mask_idx[0]
spec_aug_mask_idx = np.concatenate(
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
)
spec_aug_mask_idxs.append(spec_aug_mask_idx)
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
# expand masked indices to masked spans
spec_aug_mask_idxs = (
spec_aug_mask_idxs.unsqueeze(dim=-1)
.expand((batch_size, num_masked_spans, mask_length))
.reshape(batch_size, num_masked_spans * mask_length)
spec_aug_mask_idxs = np.broadcast_to(
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
)
offsets = (
torch.arange(mask_length, device=device)[None, None, :]
.expand((batch_size, num_masked_spans, mask_length))
.reshape(batch_size, num_masked_spans * mask_length)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
)
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
# scatter indices to mask
spec_aug_mask = spec_aug_mask.scatter(1, spec_aug_mask_idxs, True)
if attention_mask is not None:
# make sure padded input ids cannot be masked
spec_aug_mask = torch.where(attention_mask.bool(), spec_aug_mask, False)
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
return spec_aug_mask
def _sample_negative_indices(
features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None
):
"""
Sample `num_negatives` vectors from feature vectors.
"""
batch_size, sequence_length = features_shape
# generate indices of the positive vectors themselves, repeat them `num_negatives` times
sequence_length_range = np.arange(sequence_length)
# get `num_negatives` random vector indices from the same utterance
sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
mask_time_indices = (
mask_time_indices.astype(np.bool) if mask_time_indices is not None else np.ones(features_shape, dtype=np.bool)
)
for batch_idx in range(batch_size):
high = mask_time_indices[batch_idx].sum() - 1
mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
# avoid sampling the same positive vector, but keep the distribution uniform
sampled_indices[sampled_indices >= feature_indices] += 1
# remap to actual indices
sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
# correct for batch size
sampled_negative_indices[batch_idx] += batch_idx * sequence_length
return sampled_negative_indices
class Wav2Vec2NoLayerNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
@ -326,6 +390,7 @@ class Wav2Vec2FeatureExtractor(nn.Module):
f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
)
self.conv_layers = nn.ModuleList(conv_layers)
self.gradient_checkpointing = False
def _freeze_parameters(self):
for param in self.parameters():
@ -333,8 +398,26 @@ class Wav2Vec2FeatureExtractor(nn.Module):
def forward(self, input_values):
hidden_states = input_values[:, None]
# make sure hidden_states require grad for gradient_checkpointing
if self.training:
hidden_states.requires_grad = True
for conv_layer in self.conv_layers:
hidden_states = conv_layer(hidden_states)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(conv_layer),
hidden_states,
)
else:
hidden_states = conv_layer(hidden_states)
return hidden_states
@ -778,7 +861,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)
# can be decayed for training
self.temperature = 1
self.temperature = 2
def set_temperature(self, temperature: int):
self.temperature = temperature
@ -844,8 +927,8 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
config_class = Wav2Vec2Config
base_model_prefix = "wav2vec2"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
@ -854,22 +937,31 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
module.weight_proj.weight.data.normal_(mean=0.0, std=1)
module.weight_proj.bias.data.zero_()
nn.init.uniform_(module.codevectors)
elif isinstance(module, Wav2Vec2PositionalConvEmbedding):
nn.init.normal_(
module.conv.weight,
mean=0,
std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
)
nn.init.constant_(module.conv.bias, 0)
elif isinstance(module, Wav2Vec2FeatureProjection):
k = math.sqrt(1 / module.projection.in_features)
nn.init.uniform_(module.projection.weight, a=-k, b=k)
nn.init.uniform_(module.projection.bias, a=-k, b=k)
elif isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Conv1d):
nn.init.kaiming_normal_(module.weight.data)
nn.init.kaiming_normal_(module.weight)
if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm)):
module.gradient_checkpointing = value
if module.bias is not None:
k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
nn.init.uniform_(module.bias, a=-k, b=k)
def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
"""
@ -898,6 +990,10 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
return attention_mask
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm, Wav2Vec2FeatureExtractor)):
module.gradient_checkpointing = value
WAV_2_VEC_2_START_DOCSTRING = r"""
Wav2Vec2 was proposed in `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations
@ -1001,10 +1097,10 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
(batch_size, sequence_length),
mask_prob=self.config.mask_time_prob,
mask_length=self.config.mask_time_length,
device=hidden_states.device,
attention_mask=attention_mask,
min_masks=2,
)
mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.long)
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
if self.config.mask_feature_prob > 0 and self.training:
@ -1013,9 +1109,11 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
(batch_size, hidden_size),
mask_prob=self.config.mask_feature_prob,
mask_length=self.config.mask_feature_length,
device=hidden_states.device,
)
hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.long)[
:, None
].expand(-1, sequence_length, -1)
hidden_states[mask_feature_indices] = 0
return hidden_states
@ -1101,11 +1199,13 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)
self.quantizer = Wav2Vec2GumbelVectorQuantizer(config)
self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
self.init_weights()
# make sure that project_hid & project_q are initialized like normal linear layers
self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
def set_gumbel_temperature(self, temperature: int):
"""
Set the Gumbel softmax temperature to a given value. Only necessary for training
@ -1119,61 +1219,12 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
"""
self.wav2vec2.feature_extractor._freeze_parameters()
@staticmethod
def _sample_negatives(
features: torch.FloatTensor, num_negatives: int, attention_mask: Optional[torch.LongTensor] = None
):
"""
Sample `num_negatives` vectors from feature vectors.
"""
batch_size, sequence_length, hidden_size = features.shape
if sequence_length <= 1:
raise ValueError(
f"`features should have `sequence_length` > 1, but are of shape (batch_size, sequence_length, hidden_size) = ({batch_size, sequence_length, hidden_size})."
)
features = features.view(-1, hidden_size) # BTC => (BxT)C
with torch.no_grad():
# get `num_negatives` random vector indices from the same utterance
sampled_negative_indices = []
for batch_idx in range(batch_size):
high = attention_mask[batch_idx].sum() - 1 if attention_mask is not None else sequence_length - 1
sampled_indices_slice = torch.randint(
0, high, size=(num_negatives * sequence_length,), device=features.device
)
sampled_negative_indices.append(sampled_indices_slice)
sampled_negative_indices = torch.stack(sampled_negative_indices)
# generate indices of the positive vectors themselves, repeat them `num_negatives` times
feature_indices = (
torch.arange(sequence_length, device=features.device)[:, None]
.expand(sequence_length, num_negatives)
.flatten()
)
# avoid sampling the same positive vector, but keep the distribution uniform
sampled_negative_indices[sampled_negative_indices >= feature_indices] += 1
# correct for batch size
for batch_idx in range(1, batch_size):
sampled_negative_indices[batch_idx] += batch_idx * sequence_length
# take negative vectors from sampled indices
sampled_negatives = features[sampled_negative_indices.view(-1)]
sampled_negatives = sampled_negatives.view(batch_size, sequence_length, num_negatives, hidden_size).permute(
2, 0, 1, 3
)
return sampled_negatives
@staticmethod
def compute_contrastive_logits(
target_features: torch.FloatTensor,
negative_features: torch.FloatTensor,
predicted_features: torch.FloatTensor,
temperature: int = 1,
temperature: int = 0.1,
):
"""
Compute logits for contrastive loss based using cosine similarity as the distance measure between
@ -1196,6 +1247,7 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
input_values,
attention_mask=None,
mask_time_indices=None,
sampled_negative_indices=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
@ -1204,6 +1256,9 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
mask_time_indices (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
masked extracted features in `config.proj_codevector_dim` space.
sampled_negative_indices (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, sequence_length, num_negatives)`, `optional`):
Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss.
Required input for pre-training.
Returns:
@ -1270,21 +1325,30 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
# 2. quantize all (unmasked) extracted features and project to final vq dim
extract_features = self.dropout_features(outputs[1])
quantized_features, codevector_perplexity = self.quantizer(extract_features, mask_time_indices)
if attention_mask is not None:
# compute reduced attention_mask corresponding to feature vectors
attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
quantized_features, codevector_perplexity = self.quantizer(
extract_features, mask_time_indices=mask_time_indices
)
quantized_features = self.project_q(quantized_features)
loss = None
if self.training:
loss = contrastive_loss = diversity_loss = None
if sampled_negative_indices is not None:
batch_size, sequence_length, hidden_size = quantized_features.shape
# for training, we sample negatives
# 3. sample K negatives (distractors) quantized states for contrastive loss
# if attention_mask is passed, make sure that padded feature vectors cannot be sampled
if attention_mask is not None:
# compute reduced attention_mask corresponding to feature vectors
attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
negative_quantized_features = self._sample_negatives(
quantized_features, self.config.num_negatives, attention_mask=attention_mask
)
# sample negative quantized vectors BTC => (BxT)C
negative_quantized_features = quantized_features.view(-1, hidden_size)[
sampled_negative_indices.long().view(-1)
]
negative_quantized_features = negative_quantized_features.view(
batch_size, sequence_length, -1, hidden_size
).permute(2, 0, 1, 3)
# 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
# of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
@ -1298,18 +1362,19 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
# 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
# its cosine similarity will be masked
neg_is_pos = (quantized_features == negative_quantized_features).all(-1)
if neg_is_pos.any():
logits[1:][neg_is_pos] = float("-inf")
# 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logs) =
# -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
preds = logits.transpose(0, 2).reshape(-1, logits.size(0))
logits = logits.transpose(0, 2).reshape(-1, logits.size(0))
target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
contrastive_loss = nn.functional.cross_entropy(preds.float(), target, reduction="sum")
contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="sum")
# 7. compute diversity loss: \mathbf{L}_d
num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
diversity_loss = (num_codevectors - codevector_perplexity) / num_codevectors
diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()
# 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d
loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss
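For reference, the diversity term above follows the perplexity-based formulation of the same paper: with G codebook groups of V entries each and \bar{p}_{g,v} the soft probability of picking entry v in group g averaged over the batch, the two terms combine as below, where \alpha is `config.diversity_loss_weight`. Note that the new code additionally scales \mathcal{L}_d by the number of masked time steps so that both terms are summed over the same positions:

```latex
\mathcal{L}_d = \frac{GV - \sum_{g=1}^{G} \exp\!\left(-\sum_{v=1}^{V} \bar{p}_{g,v} \log \bar{p}_{g,v}\right)}{GV},
\qquad
\mathcal{L} = \mathcal{L}_m + \alpha \, \mathcal{L}_d
```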
@ -1326,6 +1391,8 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
codevector_perplexity=codevector_perplexity,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
contrastive_loss=contrastive_loss,
diversity_loss=diversity_loss,
)

View File

@ -586,7 +586,8 @@ class HubertUtilsTest(unittest.TestCase):
mask_prob = 0.5
mask_length = 1
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
mask = torch.from_numpy(mask).to(torch_device)
self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])
@ -596,7 +597,8 @@ class HubertUtilsTest(unittest.TestCase):
mask_prob = 0.5
mask_length = 4
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
mask = torch.from_numpy(mask).to(torch_device)
# because of overlap, masks don't have to add up exactly to `mask_prob * sequence_length`, but have to be smaller or equal
for batch_sum in mask.sum(axis=-1):

View File

@ -40,7 +40,11 @@ if is_torch_available():
Wav2Vec2Model,
Wav2Vec2Processor,
)
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2GumbelVectorQuantizer, _compute_mask_indices
from transformers.models.wav2vec2.modeling_wav2vec2 import (
Wav2Vec2GumbelVectorQuantizer,
_compute_mask_indices,
_sample_negative_indices,
)
class Wav2Vec2ModelTester:
@ -405,6 +409,12 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
"masked_spec_embed",
"codevectors",
"quantizer.weight_proj.weight",
"project_hid.weight",
"project_hid.bias",
"project_q.weight",
"project_q.bias",
"feature_projection.projection.weight",
"feature_projection.projection.bias",
]
if param.requires_grad:
if any([x in name for x in uniform_init_parms]):
@ -605,6 +615,12 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
"masked_spec_embed",
"codevectors",
"quantizer.weight_proj.weight",
"project_hid.weight",
"project_hid.bias",
"project_q.weight",
"project_q.bias",
"feature_projection.projection.weight",
"feature_projection.projection.bias",
]
if param.requires_grad:
if any([x in name for x in uniform_init_parms]):
@ -640,28 +656,37 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
features_shape = (
inputs_dict["input_values"].shape[0],
model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
model._get_feat_extract_output_lengths(inputs_dict["input_values"].shape[1]),
)
mask_time_indices = _compute_mask_indices(
features_shape,
model.config.mask_time_prob,
model.config.mask_time_length,
device=inputs_dict["input_values"].device,
min_masks=2,
).to(torch_device)
)
sampled_negative_indices = _sample_negative_indices(features_shape, 10, mask_time_indices)
mask_time_indices = torch.from_numpy(mask_time_indices).to(torch_device)
sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
loss = model(
inputs_dict["input_values"],
attention_mask=inputs_dict["attention_mask"],
mask_time_indices=mask_time_indices,
sampled_negative_indices=sampled_negative_indices,
).loss
# more losses
mask_time_indices[:, : mask_time_indices.shape[-1] // 2] = True
sampled_negative_indices = _sample_negative_indices(features_shape, 10, mask_time_indices.cpu().numpy())
sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
loss_more_masked = model(
inputs_dict["input_values"],
attention_mask=inputs_dict["attention_mask"],
mask_time_indices=mask_time_indices,
sampled_negative_indices=sampled_negative_indices,
).loss
# loss_more_masked has to be bigger than or equal to loss since more masked inputs have to be predicted
@ -727,7 +752,8 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
mask_prob = 0.5
mask_length = 1
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
mask = torch.from_numpy(mask).to(torch_device)
self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])
@ -737,7 +763,8 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
mask_prob = 0.5
mask_length = 4
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
mask = torch.from_numpy(mask).to(torch_device)
# because of overlap, masks don't have to add up exactly to `mask_prob * sequence_length`, but have to be smaller or equal
for batch_sum in mask.sum(axis=-1):
@ -753,8 +780,9 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
attention_mask[:2, sequence_length // 2 :] = 0
mask = _compute_mask_indices(
(batch_size, sequence_length), mask_prob, mask_length, device=torch_device, attention_mask=attention_mask
(batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
)
mask = torch.from_numpy(mask).to(torch_device)
for batch_sum in mask.sum(axis=-1):
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
@ -785,8 +813,11 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
) # each value in vector consists of the same value
features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
negatives = Wav2Vec2ForPreTraining._sample_negatives(features, num_negatives)
# sample negative indices
sampled_negative_indices = _sample_negative_indices((batch_size, sequence_length), num_negatives, None)
sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
negatives = features.view(-1, hidden_size)[sampled_negative_indices.long().view(-1)]
negatives = negatives.view(batch_size, sequence_length, -1, hidden_size).permute(2, 0, 1, 3)
self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))
# make sure no negatively sampled vector is actually a positive one
@ -796,15 +827,15 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
# make sure that full vectors are sampled and not just individual values of vectors => this means that `unique()` yields a single value for the `hidden_size` dim
self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
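The gather used here (and in the model code earlier in this diff) flattens batch and time, indexes with the flattened negative indices, and restores a `(num_negatives, batch, time, hidden)` layout; a toy-sized sketch of the shape bookkeeping, with random indices standing in for `_sample_negative_indices`:

```python
# Toy shapes only (batch=2, time=3, hidden=4, num_negatives=5) to illustrate
# the reshape/gather/permute used for negative sampling.
import torch

batch_size, sequence_length, hidden_size, num_negatives = 2, 3, 4, 5
features = torch.randn(batch_size, sequence_length, hidden_size)

# stand-in indices into the flattened (batch * time) dimension
sampled_negative_indices = torch.randint(
    0, batch_size * sequence_length, (batch_size, sequence_length, num_negatives)
)

negatives = features.view(-1, hidden_size)[sampled_negative_indices.long().view(-1)]
negatives = negatives.view(batch_size, sequence_length, -1, hidden_size).permute(2, 0, 1, 3)
assert negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size)
```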
def test_sample_negatives_with_attn_mask(self):
def test_sample_negatives_with_mask(self):
batch_size = 2
sequence_length = 10
hidden_size = 4
num_negatives = 3
# second half of last input tensor is padded
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
attention_mask[-1, sequence_length // 2 :] = 0
mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
mask[-1, sequence_length // 2 :] = 0
features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
sequence_length, hidden_size
@ -812,9 +843,15 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
# replace masked feature vectors with -100 to test that those are not sampled
features = torch.where(attention_mask[:, :, None].expand(features.shape).bool(), features, -100)
features = torch.where(mask[:, :, None].expand(features.shape).bool(), features, -100)
negatives = Wav2Vec2ForPreTraining._sample_negatives(features, num_negatives, attention_mask=attention_mask)
# sample negative indices
sampled_negative_indices = _sample_negative_indices(
(batch_size, sequence_length), num_negatives, mask.cpu().numpy()
)
sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
negatives = features.view(-1, hidden_size)[sampled_negative_indices.long().view(-1)]
negatives = negatives.view(batch_size, sequence_length, -1, hidden_size).permute(2, 0, 1, 3)
self.assertTrue((negatives >= 0).all().item())
@ -924,16 +961,11 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
# Wav2Vec2 pretraining seems to be broken. TODO(PVP) - reenable test once pretraining works
# correctly
@unittest.skipIf(torch_device != "cpu", "cannot make deterministic on GPU")
def test_inference_integration(self):
return
model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
model.to(torch_device)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
"facebook/wav2vec2-base", return_attention_mask=True
)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
input_speech = self._load_datasamples(2)
inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)
@ -943,19 +975,18 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
)
torch.manual_seed(0)
np.random.seed(4)
mask_time_indices = _compute_mask_indices(
features_shape,
model.config.mask_time_prob,
model.config.mask_time_length,
device=inputs_dict["input_values"].device,
min_masks=2,
).to(torch_device)
)
mask_time_indices = torch.from_numpy(mask_time_indices).to(torch_device)
with torch.no_grad():
outputs = model(
inputs_dict.input_values.to(torch_device),
attention_mask=inputs_dict.attention_mask.to(torch_device),
mask_time_indices=mask_time_indices,
)
@ -965,14 +996,16 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
# retrieve cosine sim of masked features
cosine_sim_masked = cosine_sim[mask_time_indices]
# cosine similarity of the masked features is all > 0.5 as the model is
# pre-trained on the contrastive loss
# fmt: off
expected_cosine_sim_masked = torch.tensor(
[0.7458, 0.7188, 0.6418, 0.3729, 0.3741, 0.3694, 0.3110, 0.2257, 0.4403, 0.5415, 0.3950, 0.3701, 0.8831,
0.8613, 0.5229, 0.6696, 0.7206, 0.7877, 0.6758, 0.8746, 0.6596, 0.6282, 0.6178, 0.5839, 0.5926, 0.6651,
0.4635, 0.6332, 0.6572, 0.8776, 0.4999, 0.7001, 0.7257, 0.5098, 0.6229, 0.4566, 0.5261, 0.6363, 0.5371,
0.6997],
device=torch_device,
)
expected_cosine_sim_masked = torch.tensor([
0.8523, 0.5860, 0.6905, 0.5557, 0.7456, 0.5249, 0.6639, 0.7654, 0.7565,
0.8167, 0.8222, 0.7960, 0.8034, 0.8166, 0.8310, 0.8263, 0.8274, 0.8258,
0.8179, 0.8412, 0.8536, 0.5098, 0.4728, 0.6461, 0.4498, 0.6002, 0.5774,
0.6457, 0.7123, 0.5668, 0.6866, 0.4960, 0.6293, 0.7423, 0.7419, 0.7526,
0.7768, 0.4898, 0.5393, 0.8183
], device=torch_device)
# fmt: on
self.assertTrue(torch.allclose(cosine_sim_masked, expected_cosine_sim_masked, atol=1e-3))
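As a minimal sketch, the masked cosine similarities checked here can be obtained from the pre-training outputs roughly as follows, assuming `outputs` and `mask_time_indices` as in the surrounding test:

```python
# cosine similarity between the transformer's projected states and the
# projected quantized targets, one value per time step
cosine_sim = torch.cosine_similarity(
    outputs.projected_states, outputs.projected_quantized_states, dim=-1
)
# keep only the masked positions; for a pre-trained checkpoint these are
# expected to be relatively high (> 0.5 in this test)
cosine_sim_masked = cosine_sim[mask_time_indices]
```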
@ -997,9 +1030,9 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
features_shape,
model.config.mask_time_prob,
model.config.mask_time_length,
device=inputs_dict["input_values"].device,
min_masks=2,
).to(torch_device)
)
mask_time_indices = torch.from_numpy(mask_time_indices).to(torch_device)
with torch.no_grad():
outputs = model(
@ -1064,28 +1097,36 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
)
torch.manual_seed(0)
np.random.seed(0)
mask_time_indices = _compute_mask_indices(
features_shape,
model.config.mask_time_prob,
model.config.mask_time_length,
device=inputs_dict["input_values"].device,
min_masks=2,
).to(torch_device)
)
sampled_negative_indices = _sample_negative_indices(
mask_time_indices.shape, model.config.num_negatives, mask_time_indices
)
mask_time_indices = torch.from_numpy(mask_time_indices).to(torch_device)
sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
with torch.no_grad():
outputs = model(
inputs_dict.input_values.to(torch_device),
attention_mask=inputs_dict.attention_mask.to(torch_device),
mask_time_indices=mask_time_indices,
sampled_negative_indices=sampled_negative_indices,
)
# check diversity loss
num_codevectors = model.config.num_codevectors_per_group * model.config.num_codevector_groups
diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors
self.assertTrue(abs(diversity_loss.item() - 0.8859) < 1e-3)
self.assertTrue(abs(diversity_loss.item() - 0.9538) < 1e-3)
# check overall loss (contrastive loss + diversity loss)
expected_loss = 62.5170
expected_loss = 116.7094
self.assertTrue(abs(outputs.loss.item() - expected_loss) < 1e-3)
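Since the pre-training output now also exposes `contrastive_loss` and `diversity_loss` (see the model changes above), the overall loss can, as a sketch, additionally be cross-checked against the weighted sum of its two components:

```python
# cross-check the total loss against its two components, which are now
# returned on the pre-training output
weighted_sum = (
    outputs.contrastive_loss
    + model.config.diversity_loss_weight * outputs.diversity_loss
)
assert torch.allclose(outputs.loss, weighted_sum, atol=1e-4)
```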