mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-30 17:52:35 +06:00
[Speech Examples] Add pytorch speech pretraining (#13877)
* adapt wav2vec2
* add example
* add files
* adapt
* remove bogus file
* Apply suggestions from code review
* adapt files more
* upload changes
* del old files
* up
* up
* up
* up
* up
* correct gradient checkpointing
* add readme
* finish
* finish
* up
* more fixes
* up
* up
* add demo run to readme
* up
This commit is contained in:
parent
3499728dc4
commit
d45fc7da3d
@ -3,6 +3,7 @@ scikit-learn
seqeval
psutil
sacrebleu >= 1.4.12
accelerate >= 0.5.0
rouge-score
tensorflow_datasets
matplotlib
124
examples/pytorch/speech-pretraining/README.md
Normal file
@ -0,0 +1,124 @@
<!---
Copyright 2021 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# Speech Recognition Pre-Training

## Wav2Vec2 Speech Pre-Training

The script [`run_wav2vec2_pretraining_no_trainer.py`](https://github.com/huggingface/transformers/blob/master/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py) can be used to pre-train a [Wav2Vec2](https://huggingface.co/transformers/model_doc/wav2vec2.html?highlight=wav2vec2) model from scratch.

In the script [`run_wav2vec2_pretraining_no_trainer.py`](https://github.com/huggingface/transformers/blob/master/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py), a Wav2Vec2 model is pre-trained on audio data alone using [Wav2Vec2's contrastive loss objective](https://arxiv.org/abs/2006.11477).
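
For intuition, the objective can be summarized as follows (notation follows the wav2vec 2.0 paper; this is a high-level reminder, not a description of every implementation detail). For each masked time step $t$, the model has to identify the true quantized latent $\mathbf{q}_t$ among a set $\mathbf{Q}_t$ consisting of $\mathbf{q}_t$ and $K$ sampled distractors, using the context network output $\mathbf{c}_t$:

$$\mathcal{L}_m = -\log \frac{\exp\big(\mathrm{sim}(\mathbf{c}_t, \mathbf{q}_t)/\kappa\big)}{\sum_{\tilde{\mathbf{q}} \in \mathbf{Q}_t} \exp\big(\mathrm{sim}(\mathbf{c}_t, \tilde{\mathbf{q}})/\kappa\big)}$$

where $\mathrm{sim}(\cdot,\cdot)$ is the cosine similarity and $\kappa$ a temperature. The total training loss adds a weighted codebook-diversity term, $\mathcal{L} = \mathcal{L}_m + \alpha \mathcal{L}_d$.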

The following examples show how to pre-train a `"base"`-sized Wav2Vec2 model as well as a `"large"`-sized Wav2Vec2 model using [`accelerate`](https://github.com/huggingface/accelerate).

---

**NOTE 1**

Wav2Vec2's pre-training is known to be quite unstable.
It is advised to do a couple of test runs with a smaller dataset,
*i.e.* `--dataset_config_names clean clean`, `--dataset_split_names validation test`,
to find good hyper-parameters for `learning_rate`, `batch_size`, `num_warmup_steps`,
and the optimizer.
A good metric to observe during training is the gradient norm, which should ideally be between 0.5 and 2.

---

---

**NOTE 2**

When training a model on large datasets it is recommended to run the data preprocessing
in a first, **non-distributed** run via `--preprocessing_only`, so that,
when the script is run in **distributed** mode in a second step, the preprocessed data
can easily be loaded on each distributed device (see the sketch below).

---

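A minimal sketch of this two-step workflow is shown below. The argument values are illustrative only (they mirror the demo run below) and are not tuned recommendations:

```bash
# Step 1: preprocess in a single, non-distributed process so that the
# processed dataset ends up in the datasets cache.
python run_wav2vec2_pretraining_no_trainer.py \
    --dataset_name="librispeech_asr" \
    --dataset_config_names clean clean \
    --dataset_split_names validation test \
    --model_name_or_path="patrickvonplaten/wav2vec2-base-v2" \
    --output_dir="./wav2vec2-pretrained-demo" \
    --max_duration_in_seconds="20.0" \
    --min_duration_in_seconds="2.0" \
    --preprocessing_only

# Step 2: launch distributed training with the same data arguments;
# every process now loads the cached features instead of re-preprocessing.
accelerate launch run_wav2vec2_pretraining_no_trainer.py \
    --dataset_name="librispeech_asr" \
    --dataset_config_names clean clean \
    --dataset_split_names validation test \
    --model_name_or_path="patrickvonplaten/wav2vec2-base-v2" \
    --output_dir="./wav2vec2-pretrained-demo" \
    --max_duration_in_seconds="20.0" \
    --min_duration_in_seconds="2.0" \
    --max_train_steps="20000" \
    --per_device_train_batch_size="8" \
    --gradient_checkpointing
```
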
### Demo

In this demo run we pre-train a `"base"`-sized Wav2Vec2 model only on the validation
and test data of [librispeech_asr](https://huggingface.co/datasets/librispeech_asr).

The demo is run on two Titan RTX GPUs (24 GB RAM each). In case you have less RAM available
per device, consider reducing `--per_device_train_batch_size` and/or `--max_duration_in_seconds`.

```bash
accelerate launch run_wav2vec2_pretraining_no_trainer.py \
    --dataset_name="librispeech_asr" \
    --dataset_config_names clean clean \
    --dataset_split_names validation test \
    --model_name_or_path="patrickvonplaten/wav2vec2-base-v2" \
    --output_dir="./wav2vec2-pretrained-demo" \
    --max_train_steps="20000" \
    --num_warmup_steps="32000" \
    --gradient_accumulation_steps="8" \
    --learning_rate="0.005" \
    --weight_decay="0.01" \
    --max_duration_in_seconds="20.0" \
    --min_duration_in_seconds="2.0" \
    --logging_steps="1" \
    --saving_steps="10000" \
    --per_device_train_batch_size="8" \
    --per_device_eval_batch_size="8" \
    --adam_beta1="0.9" \
    --adam_beta2="0.98" \
    --adam_epsilon="1e-06" \
    --gradient_checkpointing
```

The results of this run can be seen [here](https://wandb.ai/patrickvonplaten/wav2vec2-pretrained-demo/reports/Wav2Vec2-PreTraining-Demo-Run--VmlldzoxMDk3MjAw?accessToken=oa05s1y57lizo2ocxy3k01g6db1u4pt8m6ur2n8nl4cb0ug02ms2cw313kb8ruch).

### Base

TODO (currently running...)

### Large

To pre-train a `"large"`-sized Wav2Vec2 model, *e.g.* [facebook/wav2vec2-large-lv60](https://huggingface.co/facebook/wav2vec2-large-lv60),
on [librispeech_asr](https://huggingface.co/datasets/librispeech_asr), the following command can be run:

```bash
accelerate launch run_wav2vec2_pretraining_no_trainer.py \
    --dataset_name=librispeech_asr \
    --dataset_config_names clean clean other \
    --dataset_split_names train.100 train.360 train.500 \
    --output_dir=./test \
    --max_train_steps=200000 \
    --num_warmup_steps=32000 \
    --gradient_accumulation_steps=8 \
    --learning_rate=0.001 \
    --weight_decay=0.01 \
    --max_duration_in_seconds=20.0 \
    --min_duration_in_seconds=2.0 \
    --model_name_or_path=./ \
    --logging_steps=1 \
    --saving_steps=10000 \
    --per_device_train_batch_size=2 \
    --per_device_eval_batch_size=4 \
    --adam_beta1=0.9 \
    --adam_beta2=0.98 \
    --adam_epsilon=1e-06 \
    --gradient_checkpointing
```

The experiment was run on 8 V100 GPUs (16 GB RAM each) for 7 days. With the command above, the
effective batch size is 2 samples per device × 8 GPUs × 8 gradient accumulation steps = 128 audio samples per update.
In case you have more than 8 GPUs available for a higher effective `batch_size`,
it is recommended to increase the `learning_rate` to `0.005` for faster convergence.

The results of this run can be seen [here](https://wandb.ai/patrickvonplaten/pretraining-wav2vec2/reports/Wav2Vec2-Large--VmlldzoxMTAwODM4?accessToken=wm3qzcnldrwsa31tkvf2pdmilw3f63d4twtffs86ou016xjbyilh55uoi3mo1qzc) and the checkpoint pretrained for 120,000 steps can be accessed [here](https://huggingface.co/patrickvonplaten/wav2vec2-large-repro-960h-libri-120k-steps).
4
examples/pytorch/speech-pretraining/requirements.txt
Normal file
@ -0,0 +1,4 @@
datasets >= 1.12.0
torch >= 1.5
torchaudio
accelerate >= 0.5.0
700
examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py
Executable file
@ -0,0 +1,700 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
# limitations under the License.
|
||||
|
||||
""" Pre-Training a 🤗 Wav2Vec2 model on unlabeled audio data """
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
import torchaudio
|
||||
from datasets import DatasetDict, concatenate_datasets, load_dataset
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from huggingface_hub import Repository
|
||||
from transformers import (
|
||||
AdamW,
|
||||
SchedulerType,
|
||||
Wav2Vec2Config,
|
||||
Wav2Vec2FeatureExtractor,
|
||||
Wav2Vec2ForPreTraining,
|
||||
get_scheduler,
|
||||
is_wandb_available,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.file_utils import get_full_repo_name
|
||||
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Pre-train a Wav2Vec2 model on unlabeled audio data")
|
||||
parser.add_argument(
|
||||
"--dataset_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of the dataset to use (via the datasets library).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_config_names",
|
||||
nargs="+",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The configuration names of the dataset to use (via the datasets library).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_split_names",
|
||||
nargs="+",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The names of the training data set splits to use (via the datasets library).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--preprocessing_num_workers",
|
||||
type=int,
|
||||
default=None,
|
||||
help="The number of processes to use for the preprocessing.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--preprocessing_only",
|
||||
action="store_true",
|
||||
help="Only run the preprocessing script to be cached for future use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Where do you want to store the pretrained models downloaded from huggingface.co",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_split_percentage",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Percentage of training data that should be used for validation if no validation is present in dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logging_steps",
|
||||
type=int,
|
||||
default=500,
|
||||
help="Number of steps between each logging",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--saving_steps",
|
||||
type=int,
|
||||
default=500,
|
||||
help="Number of steps between each logging",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--audio_column_name",
|
||||
type=str,
|
||||
default="file",
|
||||
help="Column in the dataset that contains speech file path. Defaults to 'file'",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
type=str,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
required=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Pretrained config name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per_device_train_batch_size",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Batch size (per device) for the training dataloader.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per_device_eval_batch_size",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Batch size (per device) for the evaluation dataloader.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=5e-5,
|
||||
help="Initial learning rate (after the potential warmup period) to use.",
|
||||
)
|
||||
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
|
||||
parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
|
||||
parser.add_argument(
|
||||
"--max_train_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_checkpointing",
|
||||
action="store_true",
|
||||
help="If True, use gradient checkpointing to save memory at the expense of slower backward pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_scheduler_type",
|
||||
type=SchedulerType,
|
||||
default="linear",
|
||||
help="The scheduler type to use.",
|
||||
choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
|
||||
)
|
||||
parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
|
||||
parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.")
|
||||
parser.add_argument(
|
||||
"--max_gumbel_temperature",
|
||||
type=float,
|
||||
default=2.0,
|
||||
help="Maximum temperature for gumbel softmax.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min_gumbel_temperature",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="Minimum temperature for gumbel softmax.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gumbel_temperature_decay", type=float, default=0.999995, help="Decay of gumbel temperature during training."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_duration_in_seconds",
|
||||
type=float,
|
||||
default=5.0,
|
||||
help="Filter out audio files that are longer than `max_duration_in_seconds` seconds",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min_duration_in_seconds",
|
||||
type=float,
|
||||
default=3.0,
|
||||
help="Filter out audio files that are shorter than `min_duration_in_seconds` seconds",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pad_to_multiple_of",
|
||||
type=int,
|
||||
default=None,
|
||||
help="If set will pad the sequence to a multiple of the provided value. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adam_beta1",
|
||||
type=float,
|
||||
default=0.9,
|
||||
help="Beta1 for AdamW optimizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adam_beta2",
|
||||
type=float,
|
||||
default=0.999,
|
||||
help="Beta2 for AdamW optimizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adam_epsilon",
|
||||
type=float,
|
||||
default=1e-8,
|
||||
help="Epsilon for AdamW optimizer",
|
||||
)
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||
parser.add_argument(
|
||||
"--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
|
||||
)
|
||||
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.push_to_hub:
|
||||
assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."
|
||||
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForWav2Vec2Pretraining:
|
||||
"""
|
||||
Data collator that will dynamically pad the inputs received and prepare masked indices
|
||||
for self-supervised pretraining.
|
||||
|
||||
Args:
|
||||
model (:class:`~transformers.Wav2Vec2ForPreTraining`):
|
||||
The Wav2Vec2 model used for pretraining. The data collator needs to have access
|
||||
to config and ``_get_feat_extract_output_lengths`` function for correct padding.
|
||||
feature_extractor (:class:`~transformers.Wav2Vec2FeatureExtractor`):
|
||||
The processor used for processing the data.
|
||||
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
|
||||
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
|
||||
among:
|
||||
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
|
||||
sequence is provided).
|
||||
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
||||
maximum acceptable input length for the model if that argument is not provided.
|
||||
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
|
||||
different lengths).
|
||||
max_length (:obj:`int`, `optional`):
|
||||
Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
|
||||
pad_to_multiple_of (:obj:`int`, `optional`):
|
||||
If set will pad the sequence to a multiple of the provided value.
|
||||
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
|
||||
7.5 (Volta).
|
||||
"""
|
||||
|
||||
model: Wav2Vec2ForPreTraining
|
||||
feature_extractor: Wav2Vec2FeatureExtractor
|
||||
padding: Union[bool, str] = "longest"
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
|
||||
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
||||
# reformat list to dict and set to pytorch format
|
||||
batch = self.feature_extractor.pad(
|
||||
features,
|
||||
padding=self.padding,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
device = batch["input_values"].device
|
||||
batch_size = batch["input_values"].shape[0]
|
||||
|
||||
mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])
|
||||
|
||||
# make sure that no loss is computed on padded inputs
|
||||
if batch.get("attention_mask") is not None:
|
||||
# compute real output lengths according to convolution formula
|
||||
batch["sub_attention_mask"] = self.model._get_feature_vector_attention_mask(
|
||||
mask_indices_seq_length, batch["attention_mask"]
|
||||
)
|
||||
|
||||
features_shape = (batch_size, mask_indices_seq_length)
|
||||
|
||||
# sample randomly masked indices
|
||||
mask_time_indices = _compute_mask_indices(
|
||||
features_shape,
|
||||
self.model.config.mask_time_prob,
|
||||
self.model.config.mask_time_length,
|
||||
attention_mask=batch.get("sub_attention_mask"),
|
||||
)
|
||||
|
||||
# sample negative indices
|
||||
sampled_negative_indices = _sample_negative_indices(
|
||||
features_shape,
|
||||
self.model.config.num_negatives,
|
||||
mask_time_indices=mask_time_indices,
|
||||
)
|
||||
batch["mask_time_indices"] = torch.tensor(mask_time_indices, dtype=torch.long, device=device)
|
||||
batch["sampled_negative_indices"] = torch.tensor(sampled_negative_indices, dtype=torch.long, device=device)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def multiply_grads(params, c):
|
||||
"""Multiplies grads by a constant *c*."""
|
||||
for p in params:
|
||||
if p.grad is not None:
|
||||
if torch.is_tensor(c):
|
||||
c = c.to(p.grad.device)
|
||||
p.grad.data.mul_(c)
|
||||
|
||||
|
||||
def get_grad_norm(params, scale=1):
|
||||
"""Compute grad norm given a gradient scale."""
|
||||
total_norm = 0.0
|
||||
for p in params:
|
||||
if p.grad is not None:
|
||||
param_norm = (p.grad.detach().data / scale).norm(2)
|
||||
total_norm += param_norm.item() ** 2
|
||||
total_norm = total_norm ** 0.5
|
||||
return total_norm
|
||||
|
||||
|
||||
def main():
|
||||
# See all possible arguments in src/transformers/args.py
|
||||
# or by passing the --help flag to this script.
|
||||
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
||||
args = parse_args()
|
||||
|
||||
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
|
||||
accelerator = Accelerator()
|
||||
logger.info(accelerator.state)
|
||||
|
||||
# Setup logging, we only want one process per machine to log things on the screen.
|
||||
# accelerator.is_local_main_process is only True for one process per machine.
|
||||
logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
|
||||
if accelerator.is_local_main_process:
|
||||
datasets.utils.logging.set_verbosity_warning()
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
|
||||
# set up weights and biases if available
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
wandb.init(project=args.output_dir.split("/")[-1])
|
||||
else:
|
||||
datasets.utils.logging.set_verbosity_error()
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
|
||||
# If passed along, set the training seed now.
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
if args.push_to_hub and not args.preprocessing_only:
|
||||
if args.hub_model_id is None:
|
||||
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
|
||||
else:
|
||||
repo_name = args.hub_model_id
|
||||
repo = Repository(args.output_dir, clone_from=repo_name)
|
||||
elif args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# 1. Download and create train, validation dataset
|
||||
# We load all dataset configuration and dataset split pairs passed in
|
||||
# ``args.dataset_config_names`` and ``args.dataset_split_names``
|
||||
datasets_splits = []
|
||||
for dataset_config_name, train_split_name in zip(args.dataset_config_names, args.dataset_split_names):
|
||||
# load dataset
|
||||
dataset_split = load_dataset(
|
||||
args.dataset_name, dataset_config_name, split=train_split_name, cache_dir=args.cache_dir
|
||||
)
|
||||
datasets_splits.append(dataset_split)
|
||||
|
||||
# Next, we concatenate all configurations and splits into a single training dataset
|
||||
raw_datasets = DatasetDict()
|
||||
if len(datasets_splits) > 1:
|
||||
raw_datasets["train"] = concatenate_datasets(datasets_splits).shuffle(seed=args.seed)
|
||||
else:
|
||||
raw_datasets["train"] = datasets_splits[0]
|
||||
|
||||
# Take ``args.validation_split_percentage`` from the training dataset for the validation_split_percentage
|
||||
num_validation_samples = raw_datasets["train"].num_rows * args.validation_split_percentage // 100
|
||||
|
||||
if num_validation_samples == 0:
|
||||
raise ValueError(
|
||||
"`args.validation_split_percentage` is less than a single sample "
|
||||
f"for {len(raw_datasets['train'])} training samples. Increase "
|
||||
"`args.num_validation_split_percentage`. "
|
||||
)
|
||||
|
||||
raw_datasets["validation"] = raw_datasets["train"].select(range(num_validation_samples))
|
||||
raw_datasets["train"] = raw_datasets["train"].select(range(num_validation_samples, raw_datasets["train"].num_rows))
|
||||
|
||||
# 2. Preprocess audio: load, resample, normalize and truncate
|
||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(args.model_name_or_path)
|
||||
|
||||
# only normalized-inputs-training is supported
|
||||
if not feature_extractor.do_normalize:
|
||||
raise ValueError(
|
||||
"Training is only supported for normalized inputs. " "Make sure ``feature_extractor.do_normalize == True``"
|
||||
)
|
||||
|
||||
# set max & min audio length in number of samples
|
||||
max_length = int(args.max_duration_in_seconds * feature_extractor.sampling_rate)
|
||||
min_length = int(args.min_duration_in_seconds * feature_extractor.sampling_rate)
|
||||
|
||||
resampler = None
|
||||
if raw_datasets["train"][args.audio_column_name][0].split(".")[-1] == "mp3":
|
||||
# TODO(PVP) - remove hard-coded 48_000 after audio feature is merged
|
||||
resampler = torchaudio.transforms.Resample(48_000, feature_extractor.sampling_rate)
|
||||
|
||||
def prepare_dataset(batch):
|
||||
speech_array, sampling_rate = torchaudio.load(batch[args.audio_column_name])
|
||||
speech_array = speech_array.squeeze()
|
||||
|
||||
# if necessary resample audio
|
||||
if resampler is not None:
|
||||
# TODO(PVP) - remove hard-coded 48_000 after audio feature is merged
|
||||
speech_array = resampler(speech_array)
|
||||
sampling_rate = resampler.new_freq
|
||||
|
||||
speech_array = speech_array.numpy()
|
||||
inputs = feature_extractor(speech_array, sampling_rate=sampling_rate, max_length=max_length, truncation=True)
|
||||
batch["input_values"] = inputs.input_values[0]
|
||||
return batch
|
||||
|
||||
# load audio files into numpy arrays
|
||||
with accelerator.main_process_first():
|
||||
vectorized_datasets = raw_datasets.map(
|
||||
prepare_dataset,
|
||||
num_proc=args.preprocessing_num_workers,
|
||||
remove_columns=raw_datasets["train"].column_names,
|
||||
load_from_cache_file=not args.overwrite_cache,
|
||||
)
|
||||
vectorized_datasets = vectorized_datasets.filter(
|
||||
lambda x: len(x["input_values"]) > min_length, load_from_cache_file=not args.overwrite_cache
|
||||
)
|
||||
|
||||
# for large datasets it is advised to run the preprocessing on a
|
||||
# single machine first with ``args.preprocessing_only`` since there will most likely
|
||||
# be a timeout when running the script in distributed mode.
|
||||
# In a second step ``args.preprocessing_only`` can then be set to `False` to load the
|
||||
# cached dataset
|
||||
if args.preprocessing_only:
|
||||
return
|
||||
|
||||
# 3. Load model
|
||||
config = Wav2Vec2Config.from_pretrained(args.model_name_or_path)
|
||||
|
||||
# pretraining is only supported for "newer" stable layer norm architecture
|
||||
# apply_spec_augment has to be True, mask_feature_prob has to be 0.0
|
||||
if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
|
||||
raise ValueError(
|
||||
"PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'"
|
||||
)
|
||||
|
||||
# initialize random model
|
||||
model = Wav2Vec2ForPreTraining(config)
|
||||
|
||||
# Activate gradient checkpointing if needed
|
||||
if args.gradient_checkpointing:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
# 4. Define data collator, optimizer and scheduler
|
||||
data_collator = DataCollatorForWav2Vec2Pretraining(
|
||||
model=model, feature_extractor=feature_extractor, pad_to_multiple_of=args.pad_to_multiple_of
|
||||
)
|
||||
train_dataloader = DataLoader(
|
||||
vectorized_datasets["train"],
|
||||
shuffle=True,
|
||||
collate_fn=data_collator,
|
||||
batch_size=args.per_device_train_batch_size,
|
||||
)
|
||||
eval_dataloader = DataLoader(
|
||||
vectorized_datasets["validation"], collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
|
||||
)
|
||||
|
||||
# Optimizer
|
||||
optimizer = AdamW(
|
||||
list(model.parameters()),
|
||||
lr=args.learning_rate,
|
||||
betas=[args.adam_beta1, args.adam_beta2],
|
||||
eps=args.adam_epsilon,
|
||||
)
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
|
||||
model, optimizer, train_dataloader, eval_dataloader
|
||||
)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
else:
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
name=args.lr_scheduler_type,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.num_warmup_steps,
|
||||
num_training_steps=args.max_train_steps,
|
||||
)
|
||||
|
||||
# 5. Train
|
||||
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(vectorized_datasets['train'])}")
|
||||
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
||||
logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
||||
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
||||
completed_steps = 0
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
completed_steps = 0
|
||||
for epoch in range(args.num_train_epochs):
|
||||
model.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# compute num of losses
|
||||
num_losses = batch["mask_time_indices"].sum()
|
||||
sub_attention_mask = batch.pop("sub_attention_mask", None)
|
||||
sub_attention_mask = (
|
||||
sub_attention_mask if sub_attention_mask is not None else torch.ones_like(batch["mask_time_indices"])
|
||||
)
|
||||
percent_masked = num_losses / sub_attention_mask.sum()
|
||||
|
||||
# forward
|
||||
outputs = model(**batch)
|
||||
|
||||
# divide loss by gradient accumulation steps since gradients
|
||||
# are accumulated for multiple backward passes in PyTorch
|
||||
loss = outputs.loss / args.gradient_accumulation_steps
|
||||
accelerator.backward(loss)
|
||||
|
||||
# make sure that `num_losses` is summed for distributed training
|
||||
# and average gradients over losses of all devices
|
||||
if accelerator.state.num_processes > 1:
|
||||
num_losses = accelerator.gather(num_losses).sum()
|
||||
gradient_multiplier = accelerator.state.num_processes / num_losses
|
||||
multiply_grads(model.module.parameters(), gradient_multiplier)
|
||||
else:
|
||||
multiply_grads(model.parameters(), 1 / num_losses)
|
||||
|
||||
# update step
|
||||
if (step + 1) % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
|
||||
|
||||
# compute grad norm for monitoring
|
||||
scale = (
|
||||
accelerator.scaler._scale.item()
|
||||
if hasattr(accelerator, "scaler") and accelerator.scaler is not None
|
||||
else 1
|
||||
)
|
||||
if accelerator.state.num_processes > 1:
|
||||
grad_norm = get_grad_norm(model.module.parameters(), scale)
|
||||
else:
|
||||
grad_norm = get_grad_norm(model.parameters(), scale)
|
||||
|
||||
# update parameters
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if not accelerator.optimizer_step_was_skipped:
|
||||
lr_scheduler.step()
|
||||
elif accelerator.is_local_main_process:
|
||||
progress_bar.write(
|
||||
"Gradients have overflown - skipping update step... " f"Updating gradient scale to {scale}..."
|
||||
)
|
||||
|
||||
# update gumbel temperature
|
||||
gumbel_temperature = max(
|
||||
args.max_gumbel_temperature * args.gumbel_temperature_decay ** completed_steps,
|
||||
args.min_gumbel_temperature,
|
||||
)
|
||||
if hasattr(model, "module"):
|
||||
model.module.set_gumbel_temperature(gumbel_temperature)
|
||||
else:
|
||||
model.set_gumbel_temperature(gumbel_temperature)
|
||||
|
||||
progress_bar.update(1)
|
||||
completed_steps += 1
|
||||
|
||||
# 6. Log all results
|
||||
if (step + 1) % (args.gradient_accumulation_steps * args.logging_steps) == 0:
|
||||
loss.detach()
|
||||
outputs.contrastive_loss.detach()
|
||||
outputs.diversity_loss.detach()
|
||||
|
||||
if accelerator.state.num_processes > 1:
|
||||
loss = accelerator.gather(loss).sum()
|
||||
outputs.contrastive_loss = accelerator.gather(outputs.contrastive_loss).sum()
|
||||
outputs.diversity_loss = accelerator.gather(outputs.diversity_loss).sum()
|
||||
percent_masked = accelerator.gather(percent_masked).sum()
|
||||
|
||||
train_logs = {
|
||||
"loss": (loss * args.gradient_accumulation_steps) / num_losses,
|
||||
"constrast_loss": outputs.contrastive_loss / num_losses,
|
||||
"div_loss": outputs.diversity_loss / num_losses,
|
||||
"%_mask_idx": percent_masked / accelerator.num_processes,
|
||||
"ppl": outputs.codevector_perplexity,
|
||||
"lr": torch.tensor(optimizer.param_groups[0]["lr"]),
|
||||
"temp": torch.tensor(gumbel_temperature),
|
||||
"grad_norm": torch.tensor(grad_norm),
|
||||
}
|
||||
log_str = ""
|
||||
for k, v in train_logs.items():
|
||||
log_str += "| {}: {:.3e}".format(k, v.item())
|
||||
|
||||
if accelerator.is_local_main_process:
|
||||
progress_bar.write(log_str)
|
||||
if is_wandb_available():
|
||||
wandb.log(train_logs)
|
||||
|
||||
# save model every `args.saving_steps` steps
|
||||
if (step + 1) % (args.gradient_accumulation_steps * args.saving_steps) == 0:
|
||||
if (args.push_to_hub and epoch < args.num_train_epochs - 1) or args.output_dir is not None:
|
||||
accelerator.wait_for_everyone()
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
|
||||
|
||||
if (args.push_to_hub and epoch < args.num_train_epochs - 1) and accelerator.is_main_process:
|
||||
repo.push_to_hub(commit_message=f"Training in progress step {completed_steps}", blocking=False)
|
||||
|
||||
# if completed steps > `args.max_train_steps` stop
|
||||
if completed_steps >= args.max_train_steps:
|
||||
break
|
||||
|
||||
# 7. Validate!
|
||||
model.eval()
|
||||
|
||||
# init logs
|
||||
val_logs = {
|
||||
"val_loss": 0,
|
||||
"val_contrastive_loss": 0,
|
||||
"val_diversity_loss": 0,
|
||||
"val_num_losses": 0,
|
||||
}
|
||||
for step, batch in enumerate(eval_dataloader):
|
||||
with torch.no_grad():
|
||||
batch.pop("sub_attention_mask", None)
|
||||
outputs = model(**batch)
|
||||
|
||||
val_logs["val_loss"] += outputs.loss
|
||||
val_logs["val_contrastive_loss"] += outputs.contrastive_loss
|
||||
val_logs["val_diversity_loss"] += outputs.diversity_loss
|
||||
val_logs["val_num_losses"] += batch["mask_time_indices"].sum()
|
||||
|
||||
# sum over devices in multi-processing
|
||||
if accelerator.num_processes > 1:
|
||||
val_logs = {k: accelerator.gather(v).sum() for k, v in val_logs.items()}
|
||||
|
||||
val_logs = {k: v / val_logs["val_num_losses"] for k, v in val_logs.items()}
|
||||
|
||||
log_str = ""
|
||||
for k, v in val_logs.items():
|
||||
log_str += "| {}: {:.3e}".format(k, v.item())
|
||||
|
||||
if accelerator.is_local_main_process:
|
||||
progress_bar.write(log_str)
|
||||
if is_wandb_available():
|
||||
wandb.log(val_logs)
|
||||
|
||||
if args.output_dir is not None:
|
||||
accelerator.wait_for_everyone()
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
|
||||
if accelerator.is_main_process:
|
||||
if args.push_to_hub:
|
||||
repo.push_to_hub(commit_message="End of training")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -23,6 +23,7 @@ from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import Wav2Vec2ForPreTraining
|
||||
from transformers.file_utils import is_apex_available
|
||||
from transformers.testing_utils import TestCasePlus, get_gpu_count, slow, torch_device
|
||||
|
||||
@ -41,6 +42,7 @@ SRC_DIRS = [
|
||||
"image-classification",
|
||||
"speech-recognition",
|
||||
"audio-classification",
|
||||
"speech-pretraining",
|
||||
]
|
||||
]
|
||||
sys.path.extend(SRC_DIRS)
|
||||
@ -59,6 +61,7 @@ if SRC_DIRS is not None:
|
||||
import run_summarization
|
||||
import run_swag
|
||||
import run_translation
|
||||
import run_wav2vec2_pretraining_no_trainer
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
@ -447,3 +450,32 @@ class ExamplesTests(TestCasePlus):
|
||||
run_audio_classification.main()
|
||||
result = get_results(tmp_dir)
|
||||
self.assertLess(result["eval_loss"], result["train_loss"])
|
||||
|
||||
def test_run_wav2vec2_pretraining(self):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_wav2vec2_pretraining_no_trainer.py
|
||||
--output_dir {tmp_dir}
|
||||
--model_name_or_path hf-internal-testing/tiny-random-wav2vec2
|
||||
--dataset_name patrickvonplaten/librispeech_asr_dummy
|
||||
--dataset_config_names clean
|
||||
--dataset_split_names validation
|
||||
--learning_rate 1e-4
|
||||
--per_device_train_batch_size 2
|
||||
--per_device_eval_batch_size 2
|
||||
--preprocessing_num_workers 16
|
||||
--max_train_steps 5
|
||||
--validation_split_percentage 5
|
||||
--seed 42
|
||||
""".split()
|
||||
|
||||
if is_cuda_and_apex_available():
|
||||
testargs.append("--fp16")
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_wav2vec2_pretraining_no_trainer.main()
|
||||
model = Wav2Vec2ForPreTraining.from_pretrained(tmp_dir)
|
||||
self.assertIsNotNone(model)
|
||||
|
@ -48,13 +48,13 @@ def _compute_mask_indices(
|
||||
shape: Tuple[int, int],
|
||||
mask_prob: float,
|
||||
mask_length: int,
|
||||
device: torch.device,
|
||||
attention_mask: Optional[torch.tensor] = None,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
min_masks: int = 0,
|
||||
) -> torch.tensor:
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Computes random mask spans for a given shape. Used to implement `SpecAugment: A Simple Data Augmentation Method for
|
||||
ASR <https://arxiv.org/abs/1904.08779>`__.
|
||||
ASR <https://arxiv.org/abs/1904.08779>`__. Note that this method is not optimized to run on TPU and should be run
|
||||
on CPU as part of the preprocessing during training.
|
||||
|
||||
Args:
|
||||
shape: the shape for which to compute masks.
|
||||
@ -64,7 +64,6 @@ def _compute_mask_indices(
|
||||
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
|
||||
mask_length: size of the mask
|
||||
min_masks: minimum number of masked spans
|
||||
|
||||
"""
|
||||
batch_size, sequence_length = shape
|
||||
|
||||
@ -76,42 +75,64 @@ def _compute_mask_indices(
|
||||
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
|
||||
)
|
||||
|
||||
# compute number of masked spans in batch
|
||||
num_masked_spans = int(mask_prob * sequence_length / mask_length + torch.rand((1,)).item())
|
||||
num_masked_spans = max(num_masked_spans, min_masks)
|
||||
epsilon = np.random.rand(1).item()
|
||||
|
||||
# make sure num masked indices <= sequence_length
|
||||
if num_masked_spans * mask_length > sequence_length:
|
||||
num_masked_spans = sequence_length // mask_length
|
||||
def compute_num_masked_span(input_length):
|
||||
"""Given input length, compute how many spans should be masked"""
|
||||
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
|
||||
num_masked_span = max(num_masked_span, min_masks)
|
||||
|
||||
# make sure num masked indices <= sequence_length
|
||||
if num_masked_span * mask_length > sequence_length:
|
||||
num_masked_span = sequence_length // mask_length
|
||||
|
||||
return num_masked_span
|
||||
|
||||
# compute number of masked spans in batch
|
||||
input_lengths = (
|
||||
attention_mask.sum(-1).detach().tolist()
|
||||
if attention_mask is not None
|
||||
else [sequence_length for _ in range(batch_size)]
|
||||
)
|
||||
|
||||
# SpecAugment mask to fill
|
||||
spec_aug_mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool)
|
||||
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=np.bool)
|
||||
spec_aug_mask_idxs = []
|
||||
|
||||
# uniform distribution to sample from, make sure that offset samples are < sequence_length
|
||||
uniform_dist = torch.ones((batch_size, sequence_length - (mask_length - 1)), device=device)
|
||||
max_num_masked_span = compute_num_masked_span(sequence_length)
|
||||
|
||||
# get random indices to mask
|
||||
spec_aug_mask_idxs = torch.multinomial(uniform_dist, num_masked_spans)
|
||||
for input_length in input_lengths:
|
||||
# compute num of masked spans for this input
|
||||
num_masked_span = compute_num_masked_span(input_length)
|
||||
# get random indices to mask
|
||||
spec_aug_mask_idx = np.random.choice(
|
||||
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
|
||||
)
|
||||
|
||||
# pick first sampled index that will serve as a dummy index to pad vector
|
||||
dummy_mask_idx = spec_aug_mask_idx[0]
|
||||
|
||||
spec_aug_mask_idx = np.concatenate(
|
||||
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
|
||||
)
|
||||
spec_aug_mask_idxs.append(spec_aug_mask_idx)
|
||||
|
||||
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
|
||||
|
||||
# expand masked indices to masked spans
|
||||
spec_aug_mask_idxs = (
|
||||
spec_aug_mask_idxs.unsqueeze(dim=-1)
|
||||
.expand((batch_size, num_masked_spans, mask_length))
|
||||
.reshape(batch_size, num_masked_spans * mask_length)
|
||||
spec_aug_mask_idxs = np.broadcast_to(
|
||||
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
|
||||
)
|
||||
offsets = (
|
||||
torch.arange(mask_length, device=device)[None, None, :]
|
||||
.expand((batch_size, num_masked_spans, mask_length))
|
||||
.reshape(batch_size, num_masked_spans * mask_length)
|
||||
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
|
||||
|
||||
offsets = np.arange(mask_length)[None, None, :]
|
||||
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
|
||||
batch_size, max_num_masked_span * mask_length
|
||||
)
|
||||
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
|
||||
|
||||
# scatter indices to mask
|
||||
spec_aug_mask = spec_aug_mask.scatter(1, spec_aug_mask_idxs, True)
|
||||
|
||||
if attention_mask is not None:
|
||||
# make sure padded input ids cannot be masked
|
||||
spec_aug_mask = torch.where(attention_mask.bool(), spec_aug_mask, False)
|
||||
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
|
||||
|
||||
return spec_aug_mask
|
||||
|
||||
@ -257,6 +278,7 @@ class HubertFeatureExtractor(nn.Module):
|
||||
f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
|
||||
)
|
||||
self.conv_layers = nn.ModuleList(conv_layers)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _freeze_parameters(self):
|
||||
for param in self.parameters():
|
||||
@ -264,8 +286,26 @@ class HubertFeatureExtractor(nn.Module):
|
||||
|
||||
def forward(self, input_values):
|
||||
hidden_states = input_values[:, None]
|
||||
|
||||
# make sure hidden_states require grad for gradient_checkpointing
|
||||
if self.training:
|
||||
hidden_states.requires_grad = True
|
||||
|
||||
for conv_layer in self.conv_layers:
|
||||
hidden_states = conv_layer(hidden_states)
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(conv_layer),
|
||||
hidden_states,
|
||||
)
|
||||
else:
|
||||
hidden_states = conv_layer(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@ -864,10 +904,10 @@ class HubertModel(HubertPreTrainedModel):
|
||||
(batch_size, sequence_length),
|
||||
mask_prob=self.config.mask_time_prob,
|
||||
mask_length=self.config.mask_time_length,
|
||||
device=hidden_states.device,
|
||||
attention_mask=attention_mask,
|
||||
min_masks=2,
|
||||
)
|
||||
mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.long)
|
||||
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
|
||||
|
||||
if self.config.mask_feature_prob > 0 and self.training:
|
||||
@ -876,9 +916,11 @@ class HubertModel(HubertPreTrainedModel):
|
||||
(batch_size, hidden_size),
|
||||
mask_prob=self.config.mask_feature_prob,
|
||||
mask_length=self.config.mask_feature_length,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
|
||||
mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.long)[
|
||||
:, None
|
||||
].expand(-1, sequence_length, -1)
|
||||
hidden_states[mask_feature_indices] = 0
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
""" PyTorch Wav2Vec2 model. """
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
@ -87,7 +88,7 @@ class Wav2Vec2ForPreTrainingOutput(ModelOutput):
|
||||
Output type of :class:`~transformers.Wav2Vec2ForPreTrainingOutput`, with potential hidden states and attentions.
|
||||
|
||||
Args:
|
||||
loss (`optional`, returned when model is in train mode, ``torch.FloatTensor`` of shape :obj:`(1,)`):
|
||||
loss (`optional`, returned when :obj:`sample_negative_indices` are passed, ``torch.FloatTensor`` of shape :obj:`(1,)`):
|
||||
Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the `official
|
||||
paper <https://arxiv.org/pdf/2006.11477.pdf>`__ . (classification) loss.
|
||||
projected_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.proj_codevector_dim)`):
|
||||
@ -107,6 +108,10 @@ class Wav2Vec2ForPreTrainingOutput(ModelOutput):
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
contrastive_loss (`optional`, returned when :obj:`sample_negative_indices` are passed, ``torch.FloatTensor`` of shape :obj:`(1,)`):
|
||||
The contrastive loss (L_m) as stated in the `official paper <https://arxiv.org/pdf/2006.11477.pdf>`__ .
|
||||
diversity_loss (`optional`, returned when :obj:`sample_negative_indices` are passed, ``torch.FloatTensor`` of shape :obj:`(1,)`):
|
||||
The diversity loss (L_d) as stated in the `official paper <https://arxiv.org/pdf/2006.11477.pdf>`__ .
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
@ -115,19 +120,21 @@ class Wav2Vec2ForPreTrainingOutput(ModelOutput):
|
||||
codevector_perplexity: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
contrastive_loss: Optional[torch.FloatTensor] = None
|
||||
diversity_loss: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
def _compute_mask_indices(
|
||||
shape: Tuple[int, int],
|
||||
mask_prob: float,
|
||||
mask_length: int,
|
||||
device: torch.device,
|
||||
attention_mask: Optional[torch.tensor] = None,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
min_masks: int = 0,
|
||||
) -> torch.tensor:
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Computes random mask spans for a given shape. Used to implement `SpecAugment: A Simple Data Augmentation Method for
|
||||
ASR <https://arxiv.org/abs/1904.08779>`__.
|
||||
ASR <https://arxiv.org/abs/1904.08779>`__. Note that this method is not optimized to run on TPU and should be run
|
||||
on CPU as part of the preprocessing during training.
|
||||
|
||||
Args:
|
||||
shape: the shape for which to compute masks.
|
||||
@ -137,7 +144,6 @@ def _compute_mask_indices(
|
||||
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
|
||||
mask_length: size of the mask
|
||||
min_masks: minimum number of masked spans
|
||||
|
||||
"""
|
||||
batch_size, sequence_length = shape
|
||||
|
||||
@ -149,46 +155,104 @@ def _compute_mask_indices(
|
||||
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
|
||||
)
|
||||
|
||||
# compute number of masked spans in batch
|
||||
num_masked_spans = int(mask_prob * sequence_length / mask_length + torch.rand((1,)).item())
|
||||
num_masked_spans = max(num_masked_spans, min_masks)
|
||||
epsilon = np.random.rand(1).item()
|
||||
|
||||
# make sure num masked indices <= sequence_length
|
||||
if num_masked_spans * mask_length > sequence_length:
|
||||
num_masked_spans = sequence_length // mask_length
|
||||
def compute_num_masked_span(input_length):
|
||||
"""Given input length, compute how many spans should be masked"""
|
||||
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
|
||||
num_masked_span = max(num_masked_span, min_masks)
|
||||
|
||||
# make sure num masked indices <= sequence_length
|
||||
if num_masked_span * mask_length > sequence_length:
|
||||
num_masked_span = sequence_length // mask_length
|
||||
|
||||
return num_masked_span
|
||||
|
||||
# compute number of masked spans in batch
|
||||
input_lengths = (
|
||||
attention_mask.sum(-1).detach().tolist()
|
||||
if attention_mask is not None
|
||||
else [sequence_length for _ in range(batch_size)]
|
||||
)
|
||||
|
||||
# SpecAugment mask to fill
|
||||
spec_aug_mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool)
|
||||
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=np.bool)
|
||||
spec_aug_mask_idxs = []
|
||||
|
||||
# uniform distribution to sample from, make sure that offset samples are < sequence_length
|
||||
uniform_dist = torch.ones((batch_size, sequence_length - (mask_length - 1)), device=device)
|
||||
max_num_masked_span = compute_num_masked_span(sequence_length)
|
||||
|
||||
# get random indices to mask
|
||||
spec_aug_mask_idxs = torch.multinomial(uniform_dist, num_masked_spans)
|
||||
for input_length in input_lengths:
|
||||
# compute num of masked spans for this input
|
||||
num_masked_span = compute_num_masked_span(input_length)
|
||||
# get random indices to mask
|
||||
spec_aug_mask_idx = np.random.choice(
|
||||
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
|
||||
)
|
||||
|
||||
# pick first sampled index that will serve as a dummy index to pad vector
|
||||
dummy_mask_idx = spec_aug_mask_idx[0]
|
||||
|
||||
spec_aug_mask_idx = np.concatenate(
|
||||
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
|
||||
)
|
||||
spec_aug_mask_idxs.append(spec_aug_mask_idx)
|
||||
|
||||
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
|
||||
|
||||
# expand masked indices to masked spans
|
||||
spec_aug_mask_idxs = (
|
||||
spec_aug_mask_idxs.unsqueeze(dim=-1)
|
||||
.expand((batch_size, num_masked_spans, mask_length))
|
||||
.reshape(batch_size, num_masked_spans * mask_length)
|
||||
spec_aug_mask_idxs = np.broadcast_to(
|
||||
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
|
||||
)
|
||||
offsets = (
|
||||
torch.arange(mask_length, device=device)[None, None, :]
|
||||
.expand((batch_size, num_masked_spans, mask_length))
|
||||
.reshape(batch_size, num_masked_spans * mask_length)
|
||||
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
|
||||
|
||||
offsets = np.arange(mask_length)[None, None, :]
|
||||
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
|
||||
batch_size, max_num_masked_span * mask_length
|
||||
)
|
||||
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
|
||||
|
||||
# scatter indices to mask
|
||||
spec_aug_mask = spec_aug_mask.scatter(1, spec_aug_mask_idxs, True)
|
||||
|
||||
if attention_mask is not None:
|
||||
# make sure padded input ids cannot be masked
|
||||
spec_aug_mask = torch.where(attention_mask.bool(), spec_aug_mask, False)
|
||||
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
|
||||
|
||||
return spec_aug_mask
|
||||
|
||||
|
||||
def _sample_negative_indices(
|
||||
features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None
|
||||
):
|
||||
"""
|
||||
Sample `num_negatives` vectors from feature vectors.
|
||||
"""
|
||||
batch_size, sequence_length = features_shape
|
||||
|
||||
# generate indices of the positive vectors themselves, repeat them `num_negatives` times
|
||||
sequence_length_range = np.arange(sequence_length)
|
||||
|
||||
# get `num_negatives` random vector indices from the same utterance
|
||||
sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
|
||||
|
||||
mask_time_indices = (
|
||||
mask_time_indices.astype(np.bool) if mask_time_indices is not None else np.ones(features_shape, dtype=np.bool)
|
||||
)
|
||||
|
||||
for batch_idx in range(batch_size):
|
||||
high = mask_time_indices[batch_idx].sum() - 1
|
||||
mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
|
||||
|
||||
feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
|
||||
sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
|
||||
# avoid sampling the same positive vector, but keep the distribution uniform
|
||||
sampled_indices[sampled_indices >= feature_indices] += 1
|
||||
|
||||
# remap to actual indices
|
||||
sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
|
||||
|
||||
# correct for batch size
|
||||
sampled_negative_indices[batch_idx] += batch_idx * sequence_length
|
||||
|
||||
return sampled_negative_indices
|
||||
|
||||
|
||||
class Wav2Vec2NoLayerNormConvLayer(nn.Module):
|
||||
def __init__(self, config, layer_id=0):
|
||||
super().__init__()
|
||||
@ -326,6 +390,7 @@ class Wav2Vec2FeatureExtractor(nn.Module):
|
||||
f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
|
||||
)
|
||||
self.conv_layers = nn.ModuleList(conv_layers)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _freeze_parameters(self):
|
||||
for param in self.parameters():
|
||||
@ -333,8 +398,26 @@ class Wav2Vec2FeatureExtractor(nn.Module):
|
||||
|
||||
def forward(self, input_values):
|
||||
hidden_states = input_values[:, None]
|
||||
|
||||
# make sure hidden_states require grad for gradient_checkpointing
|
||||
if self.training:
|
||||
hidden_states.requires_grad = True
|
||||
|
||||
for conv_layer in self.conv_layers:
|
||||
hidden_states = conv_layer(hidden_states)
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(conv_layer),
|
||||
hidden_states,
|
||||
)
|
||||
else:
|
||||
hidden_states = conv_layer(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@ -778,7 +861,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
||||
self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)
|
||||
|
||||
# can be decayed for training
|
||||
self.temperature = 1
|
||||
self.temperature = 2
|
||||
|
||||
def set_temperature(self, temperature: int):
|
||||
self.temperature = temperature
@ -844,8 +927,8 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):

    config_class = Wav2Vec2Config
    base_model_prefix = "wav2vec2"
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [r"position_ids"]
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
@ -854,22 +937,31 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
            module.weight_proj.weight.data.normal_(mean=0.0, std=1)
            module.weight_proj.bias.data.zero_()
            nn.init.uniform_(module.codevectors)
        elif isinstance(module, Wav2Vec2PositionalConvEmbedding):
            nn.init.normal_(
                module.conv.weight,
                mean=0,
                std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
            )
            nn.init.constant_(module.conv.bias, 0)
        elif isinstance(module, Wav2Vec2FeatureProjection):
            k = math.sqrt(1 / module.projection.in_features)
            nn.init.uniform_(module.projection.weight, a=-k, b=k)
            nn.init.uniform_(module.projection.bias, a=-k, b=k)
        elif isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight.data)
            nn.init.kaiming_normal_(module.weight)

        if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
            module.bias.data.zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm)):
            module.gradient_checkpointing = value
            if module.bias is not None:
                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                nn.init.uniform_(module.bias, a=-k, b=k)

    def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
        """
@ -898,6 +990,10 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
        attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
        return attention_mask

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm, Wav2Vec2FeatureExtractor)):
            module.gradient_checkpointing = value
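With `supports_gradient_checkpointing = True` and the hook above, checkpointing can now also be switched on for the convolutional feature extractor, not just the transformer encoder. Assuming the usual `gradient_checkpointing_enable()` helper inherited from `PreTrainedModel`, usage would look roughly like this sketch:

```python
from transformers import Wav2Vec2ForPreTraining

# sketch: enable activation checkpointing before training to cut memory at the cost of extra compute
model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
model.gradient_checkpointing_enable()  # flips `gradient_checkpointing` on the supported sub-modules
model.train()
```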
WAV_2_VEC_2_START_DOCSTRING = r"""
    Wav2Vec2 was proposed in `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations
@ -1001,10 +1097,10 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
                (batch_size, sequence_length),
                mask_prob=self.config.mask_time_prob,
                mask_length=self.config.mask_time_length,
                device=hidden_states.device,
                attention_mask=attention_mask,
                min_masks=2,
            )
            mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.long)
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)

        if self.config.mask_feature_prob > 0 and self.training:
@ -1013,9 +1109,11 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
                (batch_size, hidden_size),
                mask_prob=self.config.mask_feature_prob,
                mask_length=self.config.mask_feature_length,
                device=hidden_states.device,
            )
            hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
            mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.long)[
                :, None
            ].expand(-1, sequence_length, -1)
            hidden_states[mask_feature_indices] = 0

        return hidden_states
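For context, a small self-contained sketch (toy sizes are assumptions) of the SpecAugment-style time masking driven by `_compute_mask_indices`. In the model the masked frames are replaced by the learned `masked_spec_embed`; here they are simply zeroed:

```python
import torch
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices

batch_size, sequence_length, hidden_size = 2, 50, 8
hidden_states = torch.randn(batch_size, sequence_length, hidden_size)

# boolean numpy mask; roughly `mask_prob` of the frames end up covered by spans of `mask_length`
mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
mask_time_indices = torch.from_numpy(mask_time_indices)

hidden_states[mask_time_indices] = 0.0  # stand-in for `self.masked_spec_embed`
```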
@ -1101,11 +1199,13 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
        self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)

        self.quantizer = Wav2Vec2GumbelVectorQuantizer(config)
        self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
        self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)

        self.init_weights()

        # make sure that project_hid & project_q are initialized like normal linear layers
        self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
        self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)

    def set_gumbel_temperature(self, temperature: int):
        """
        Set the Gumbel softmax temperature to a given value. Only necessary for training
@ -1119,61 +1219,12 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
        """
        self.wav2vec2.feature_extractor._freeze_parameters()

    @staticmethod
    def _sample_negatives(
        features: torch.FloatTensor, num_negatives: int, attention_mask: Optional[torch.LongTensor] = None
    ):
        """
        Sample `num_negatives` vectors from feature vectors.
        """
        batch_size, sequence_length, hidden_size = features.shape
        if sequence_length <= 1:
            raise ValueError(
                f"`features should have `sequence_length` > 1, but are of shape (batch_size, sequence_length, hidden_size) = ({batch_size, sequence_length, hidden_size})."
            )

        features = features.view(-1, hidden_size)  # BTC => (BxT)C

        with torch.no_grad():
            # get `num_negatives` random vector indices from the same utterance
            sampled_negative_indices = []
            for batch_idx in range(batch_size):
                high = attention_mask[batch_idx].sum() - 1 if attention_mask is not None else sequence_length - 1
                sampled_indices_slice = torch.randint(
                    0, high, size=(num_negatives * sequence_length,), device=features.device
                )
                sampled_negative_indices.append(sampled_indices_slice)

            sampled_negative_indices = torch.stack(sampled_negative_indices)

        # generate indices of the positive vectors themselves, repeat them `num_negatives` times
        feature_indices = (
            torch.arange(sequence_length, device=features.device)[:, None]
            .expand(sequence_length, num_negatives)
            .flatten()
        )

        # avoid sampling the same positive vector, but keep the distribution uniform
        sampled_negative_indices[sampled_negative_indices >= feature_indices] += 1

        # correct for batch size
        for batch_idx in range(1, batch_size):
            sampled_negative_indices[batch_idx] += batch_idx * sequence_length

        # take negative vectors from sampled indices
        sampled_negatives = features[sampled_negative_indices.view(-1)]
        sampled_negatives = sampled_negatives.view(batch_size, sequence_length, num_negatives, hidden_size).permute(
            2, 0, 1, 3
        )

        return sampled_negatives

    @staticmethod
    def compute_contrastive_logits(
        target_features: torch.FloatTensor,
        negative_features: torch.FloatTensor,
        predicted_features: torch.FloatTensor,
        temperature: int = 1,
        temperature: int = 0.1,
    ):
        """
        Compute logits for contrastive loss based on cosine similarity as the distance measure between
@ -1196,6 +1247,7 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        sampled_negative_indices=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
@ -1204,6 +1256,9 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
            mask_time_indices (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
                Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
                masked extracted features in `config.proj_codevector_dim` space.
            sampled_negative_indices (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, sequence_length, num_negatives)`, `optional`):
                Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss.
                Required input for pre-training.

        Returns:

@ -1270,21 +1325,30 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):

        # 2. quantize all (unmasked) extracted features and project to final vq dim
        extract_features = self.dropout_features(outputs[1])
        quantized_features, codevector_perplexity = self.quantizer(extract_features, mask_time_indices)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)

        quantized_features, codevector_perplexity = self.quantizer(
            extract_features, mask_time_indices=mask_time_indices
        )
        quantized_features = self.project_q(quantized_features)

        loss = None
        if self.training:
        loss = contrastive_loss = diversity_loss = None
        if sampled_negative_indices is not None:
            batch_size, sequence_length, hidden_size = quantized_features.shape

            # for training, we sample negatives
            # 3. sample K negatives (distractors) quantized states for contrastive loss
            # if attention_mask is passed, make sure that padded feature vectors cannot be sampled
            if attention_mask is not None:
                # compute reduced attention_mask corresponding to feature vectors
                attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)

            negative_quantized_features = self._sample_negatives(
                quantized_features, self.config.num_negatives, attention_mask=attention_mask
            )
            # sample negative quantized vectors BTC => (BxT)C
            negative_quantized_features = quantized_features.view(-1, hidden_size)[
                sampled_negative_indices.long().view(-1)
            ]
            negative_quantized_features = negative_quantized_features.view(
                batch_size, sequence_length, -1, hidden_size
            ).permute(2, 0, 1, 3)

            # 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
            # of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
@ -1298,18 +1362,19 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
            # 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
            # its cosine similarity will be masked
            neg_is_pos = (quantized_features == negative_quantized_features).all(-1)

            if neg_is_pos.any():
                logits[1:][neg_is_pos] = float("-inf")

            # 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logs) =
            # -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
            preds = logits.transpose(0, 2).reshape(-1, logits.size(0))
            logits = logits.transpose(0, 2).reshape(-1, logits.size(0))
            target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
            contrastive_loss = nn.functional.cross_entropy(preds.float(), target, reduction="sum")

            contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="sum")
            # 7. compute diversity loss: \mathbf{L}_d
            num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
            diversity_loss = (num_codevectors - codevector_perplexity) / num_codevectors
            diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()

            # 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d
            loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss
@ -1326,6 +1391,8 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
            codevector_perplexity=codevector_perplexity,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            contrastive_loss=contrastive_loss,
            diversity_loss=diversity_loss,
        )
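To make steps 4-8 above concrete, here is a compact, self-contained sketch of the same loss arithmetic on random tensors. All shapes and constants are toy assumptions; in the model the temperature comes from `config.contrastive_logits_temperature` and the weight from `config.diversity_loss_weight`.

```python
import torch
import torch.nn.functional as F

num_negatives, batch_size, seq_len, dim = 4, 2, 10, 8
kappa = 0.1  # temperature, stands in for `contrastive_logits_temperature`

context = torch.randn(batch_size, seq_len, dim)            # transformer outputs c_t (projected)
positives = torch.randn(batch_size, seq_len, dim)           # quantized targets q_t (projected)
negatives = torch.randn(num_negatives, batch_size, seq_len, dim)
mask_time_indices = torch.zeros(batch_size, seq_len, dtype=torch.bool)
mask_time_indices[:, :4] = True                              # pretend the first 4 frames were masked

# logits = sim(c_t, [q_t, ~q_t]) / kappa  -> shape (1 + K, B, T)
targets = torch.cat([positives[None], negatives], dim=0)
logits = F.cosine_similarity(context[None], targets, dim=-1) / kappa

# cross-entropy with the positive at index 0; unmasked frames are ignored via label -100
logits = logits.transpose(0, 2).reshape(-1, logits.size(0))
labels = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
contrastive_loss = F.cross_entropy(logits.float(), labels, reduction="sum")

# diversity loss pushes the quantizer to use all G * V codebook entries
num_codevectors, perplexity = 640, torch.tensor(320.0)      # toy numbers
diversity_loss = (num_codevectors - perplexity) / num_codevectors * mask_time_indices.sum()

loss = contrastive_loss + 0.1 * diversity_loss               # 0.1 stands in for `diversity_loss_weight`
print(loss.item())
```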
@ -586,7 +586,8 @@ class HubertUtilsTest(unittest.TestCase):
        mask_prob = 0.5
        mask_length = 1

        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
        mask = torch.from_numpy(mask).to(torch_device)

        self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])

@ -596,7 +597,8 @@ class HubertUtilsTest(unittest.TestCase):
        mask_prob = 0.5
        mask_length = 4

        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
        mask = torch.from_numpy(mask).to(torch_device)

        # because of overlap, individual masks don't have to add up exactly to `mask_prob * sequence_length`, but have to be smaller or equal
        for batch_sum in mask.sum(axis=-1):

@ -40,7 +40,11 @@ if is_torch_available():
        Wav2Vec2Model,
        Wav2Vec2Processor,
    )
    from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2GumbelVectorQuantizer, _compute_mask_indices
    from transformers.models.wav2vec2.modeling_wav2vec2 import (
        Wav2Vec2GumbelVectorQuantizer,
        _compute_mask_indices,
        _sample_negative_indices,
    )


class Wav2Vec2ModelTester:
@ -405,6 +409,12 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
                    "masked_spec_embed",
                    "codevectors",
                    "quantizer.weight_proj.weight",
                    "project_hid.weight",
                    "project_hid.bias",
                    "project_q.weight",
                    "project_q.bias",
                    "feature_projection.projection.weight",
                    "feature_projection.projection.bias",
                ]
                if param.requires_grad:
                    if any([x in name for x in uniform_init_parms]):
@ -605,6 +615,12 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
                    "masked_spec_embed",
                    "codevectors",
                    "quantizer.weight_proj.weight",
                    "project_hid.weight",
                    "project_hid.bias",
                    "project_q.weight",
                    "project_q.bias",
                    "feature_projection.projection.weight",
                    "feature_projection.projection.bias",
                ]
                if param.requires_grad:
                    if any([x in name for x in uniform_init_parms]):
@ -640,28 +656,37 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):

        features_shape = (
            inputs_dict["input_values"].shape[0],
            model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
            model._get_feat_extract_output_lengths(inputs_dict["input_values"].shape[1]),
        )

        mask_time_indices = _compute_mask_indices(
            features_shape,
            model.config.mask_time_prob,
            model.config.mask_time_length,
            device=inputs_dict["input_values"].device,
            min_masks=2,
        ).to(torch_device)
        )
        sampled_negative_indices = _sample_negative_indices(features_shape, 10, mask_time_indices)

        mask_time_indices = torch.from_numpy(mask_time_indices).to(torch_device)
        sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)

        loss = model(
            inputs_dict["input_values"],
            attention_mask=inputs_dict["attention_mask"],
            mask_time_indices=mask_time_indices,
            sampled_negative_indices=sampled_negative_indices,
        ).loss

        # more losses
        mask_time_indices[:, : mask_time_indices.shape[-1] // 2] = True

        sampled_negative_indices = _sample_negative_indices(features_shape, 10, mask_time_indices.cpu().numpy())
        sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
        loss_more_masked = model(
            inputs_dict["input_values"],
            attention_mask=inputs_dict["attention_mask"],
            mask_time_indices=mask_time_indices,
            sampled_negative_indices=sampled_negative_indices,
        ).loss

        # loss_more_masked has to be bigger than or equal to loss since more masked inputs have to be predicted
@ -727,7 +752,8 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
        mask_prob = 0.5
        mask_length = 1

        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
        mask = torch.from_numpy(mask).to(torch_device)

        self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])

@ -737,7 +763,8 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
        mask_prob = 0.5
        mask_length = 4

        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
        mask = torch.from_numpy(mask).to(torch_device)

        # because of overlap, individual masks don't have to add up exactly to `mask_prob * sequence_length`, but have to be smaller or equal
        for batch_sum in mask.sum(axis=-1):
@ -753,8 +780,9 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
        attention_mask[:2, sequence_length // 2 :] = 0

        mask = _compute_mask_indices(
            (batch_size, sequence_length), mask_prob, mask_length, device=torch_device, attention_mask=attention_mask
            (batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
        )
        mask = torch.from_numpy(mask).to(torch_device)

        for batch_sum in mask.sum(axis=-1):
            self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
@ -785,8 +813,11 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
        )  # each value in the vector consists of the same value
        features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()

        negatives = Wav2Vec2ForPreTraining._sample_negatives(features, num_negatives)

        # sample negative indices
        sampled_negative_indices = _sample_negative_indices((batch_size, sequence_length), num_negatives, None)
        sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
        negatives = features.view(-1, hidden_size)[sampled_negative_indices.long().view(-1)]
        negatives = negatives.view(batch_size, sequence_length, -1, hidden_size).permute(2, 0, 1, 3)
        self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))

        # make sure no negatively sampled vector is actually a positive one
@ -796,15 +827,15 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
        # make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
        self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))

    def test_sample_negatives_with_attn_mask(self):
    def test_sample_negatives_with_mask(self):
        batch_size = 2
        sequence_length = 10
        hidden_size = 4
        num_negatives = 3

        # second half of last input tensor is padded
        attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
        attention_mask[-1, sequence_length // 2 :] = 0
        mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
        mask[-1, sequence_length // 2 :] = 0

        features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
            sequence_length, hidden_size
@ -812,9 +843,15 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
        features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()

        # replace masked feature vectors with -100 to test that those are not sampled
        features = torch.where(attention_mask[:, :, None].expand(features.shape).bool(), features, -100)
        features = torch.where(mask[:, :, None].expand(features.shape).bool(), features, -100)

        negatives = Wav2Vec2ForPreTraining._sample_negatives(features, num_negatives, attention_mask=attention_mask)
        # sample negative indices
        sampled_negative_indices = _sample_negative_indices(
            (batch_size, sequence_length), num_negatives, mask.cpu().numpy()
        )
        sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
        negatives = features.view(-1, hidden_size)[sampled_negative_indices.long().view(-1)]
        negatives = negatives.view(batch_size, sequence_length, -1, hidden_size).permute(2, 0, 1, 3)

        self.assertTrue((negatives >= 0).all().item())

@ -924,16 +961,11 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
        ]
        self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)

    # Wav2Vec2 pretraining seems to be broken. TODO(PVP) - reenable test once pretraining works
    # correctly
    @unittest.skipIf(torch_device != "cpu", "cannot make deterministic on GPU")
    def test_inference_integration(self):
        return

        model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
        model.to(torch_device)
        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            "facebook/wav2vec2-base", return_attention_mask=True
        )
        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
        input_speech = self._load_datasamples(2)

        inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)
@ -943,19 +975,18 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
            model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
        )

        torch.manual_seed(0)
        np.random.seed(4)
        mask_time_indices = _compute_mask_indices(
            features_shape,
            model.config.mask_time_prob,
            model.config.mask_time_length,
            device=inputs_dict["input_values"].device,
            min_masks=2,
        ).to(torch_device)
        )
        mask_time_indices = torch.from_numpy(mask_time_indices).to(torch_device)

        with torch.no_grad():
            outputs = model(
                inputs_dict.input_values.to(torch_device),
                attention_mask=inputs_dict.attention_mask.to(torch_device),
                mask_time_indices=mask_time_indices,
            )

@ -965,14 +996,16 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
        # retrieve cosine sim of masked features
        cosine_sim_masked = cosine_sim[mask_time_indices]

        # cosine similarity of model is all > 0.5 as model is
        # pre-trained on contrastive loss
        # fmt: off
        expected_cosine_sim_masked = torch.tensor(
            [0.7458, 0.7188, 0.6418, 0.3729, 0.3741, 0.3694, 0.3110, 0.2257, 0.4403, 0.5415, 0.3950, 0.3701, 0.8831,
             0.8613, 0.5229, 0.6696, 0.7206, 0.7877, 0.6758, 0.8746, 0.6596, 0.6282, 0.6178, 0.5839, 0.5926, 0.6651,
             0.4635, 0.6332, 0.6572, 0.8776, 0.4999, 0.7001, 0.7257, 0.5098, 0.6229, 0.4566, 0.5261, 0.6363, 0.5371,
             0.6997],
            device=torch_device,
        )
        expected_cosine_sim_masked = torch.tensor([
            0.8523, 0.5860, 0.6905, 0.5557, 0.7456, 0.5249, 0.6639, 0.7654, 0.7565,
            0.8167, 0.8222, 0.7960, 0.8034, 0.8166, 0.8310, 0.8263, 0.8274, 0.8258,
            0.8179, 0.8412, 0.8536, 0.5098, 0.4728, 0.6461, 0.4498, 0.6002, 0.5774,
            0.6457, 0.7123, 0.5668, 0.6866, 0.4960, 0.6293, 0.7423, 0.7419, 0.7526,
            0.7768, 0.4898, 0.5393, 0.8183
        ], device=torch_device)
        # fmt: on

        self.assertTrue(torch.allclose(cosine_sim_masked, expected_cosine_sim_masked, atol=1e-3))
@ -997,9 +1030,9 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
            features_shape,
            model.config.mask_time_prob,
            model.config.mask_time_length,
            device=inputs_dict["input_values"].device,
            min_masks=2,
        ).to(torch_device)
        )
        mask_time_indices = torch.from_numpy(mask_time_indices).to(torch_device)

        with torch.no_grad():
            outputs = model(
@ -1064,28 +1097,36 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
        )

        torch.manual_seed(0)
        np.random.seed(0)

        mask_time_indices = _compute_mask_indices(
            features_shape,
            model.config.mask_time_prob,
            model.config.mask_time_length,
            device=inputs_dict["input_values"].device,
            min_masks=2,
        ).to(torch_device)
        )
        sampled_negative_indices = _sample_negative_indices(
            mask_time_indices.shape, model.config.num_negatives, mask_time_indices
        )

        mask_time_indices = torch.from_numpy(mask_time_indices).to(torch_device)
        sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)

        with torch.no_grad():
            outputs = model(
                inputs_dict.input_values.to(torch_device),
                attention_mask=inputs_dict.attention_mask.to(torch_device),
                mask_time_indices=mask_time_indices,
                sampled_negative_indices=sampled_negative_indices,
            )

        # check diversity loss
        num_codevectors = model.config.num_codevectors_per_group * model.config.num_codevector_groups
        diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors
        self.assertTrue(abs(diversity_loss.item() - 0.8859) < 1e-3)
        self.assertTrue(abs(diversity_loss.item() - 0.9538) < 1e-3)

        # check overall loss (contrastive loss + diversity loss)
        expected_loss = 62.5170
        expected_loss = 116.7094

        self.assertTrue(abs(outputs.loss.item() - expected_loss) < 1e-3)