[Speech Examples] Add pytorch speech pretraining (#13877)
* adapt wav2vec2 * add example * add files * adapt * remove bogus file * Apply suggestions from code review * adapt files more * upload changes * del old files * up * up * up * up * up * correct gradient checkpointing * add readme * finish * finish * up * more fixes * up * up * add demo run to readme * up
parent 3499728dc4
commit d45fc7da3d
@@ -3,6 +3,7 @@ scikit-learn
 seqeval
 psutil
 sacrebleu >= 1.4.12
+accelerate >= 0.5.0
 rouge-score
 tensorflow_datasets
 matplotlib
examples/pytorch/speech-pretraining/README.md (new file, 124 lines)
@@ -0,0 +1,124 @@
<!---
Copyright 2021 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# Speech Recognition Pre-Training

## Wav2Vec2 Speech Pre-Training

The script [`run_wav2vec2_pretraining_no_trainer.py`](https://github.com/huggingface/transformers/blob/master/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py) can be used to pre-train a [Wav2Vec2](https://huggingface.co/transformers/model_doc/wav2vec2.html?highlight=wav2vec2) model from scratch.

In [`run_wav2vec2_pretraining_no_trainer`](https://github.com/huggingface/transformers/blob/master/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py), a Wav2Vec2 model is pre-trained on audio data alone using [Wav2Vec2's contrastive loss objective](https://arxiv.org/abs/2006.11477).

The following examples show how to pre-train a `"base"`-sized Wav2Vec2 model as well as a `"large"`-sized Wav2Vec2 model using [`accelerate`](https://github.com/huggingface/accelerate).

---
**NOTE 1**

Wav2Vec2's pre-training is known to be quite unstable.
It is advised to do a couple of test runs with a smaller dataset,
*i.e.* `--dataset_config_names clean clean`, `--dataset_split_names validation test`,
to find good hyper-parameters for `learning_rate`, `batch_size`, `num_warmup_steps`,
and the optimizer.
A good metric to observe during training is the gradient norm, which should ideally be between 0.5 and 2; the script logs it as `grad_norm`, and a minimal sketch of how it is computed is shown below the note.

---
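For reference, a minimal PyTorch sketch of the gradient-norm computation mentioned above — this mirrors the `get_grad_norm` helper defined in the example script and is not a public API:

```python
import torch

def grad_norm(params) -> float:
    # L2 norm over all parameter gradients, computed after the backward pass
    total = 0.0
    for p in params:
        if p.grad is not None:
            total += p.grad.detach().norm(2).item() ** 2
    return total ** 0.5

# usage: call grad_norm(model.parameters()) right before optimizer.step()
```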

---
**NOTE 2**

When training a model on large datasets it is recommended to run the data preprocessing
in a first, **non-distributed** run via `--preprocessing_only`, so that,
when the model is then run in **distributed** mode in a second step, the preprocessed data
can easily be loaded on each distributed device (a sketch of such a two-step run is shown below).

---
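The two-step workflow described in the note could look as follows. This is only a sketch: the dataset choice, cache and output paths are illustrative, while all flags are the ones defined by the script.

```bash
# Step 1: a single, non-distributed process that only preprocesses and caches the dataset
python run_wav2vec2_pretraining_no_trainer.py \
	--dataset_name="librispeech_asr" \
	--dataset_config_names clean \
	--dataset_split_names validation \
	--model_name_or_path="patrickvonplaten/wav2vec2-base-v2" \
	--cache_dir="./cache" \
	--output_dir="./wav2vec2-pretrained" \
	--preprocessing_only

# Step 2: the distributed training run, which reuses the cached, preprocessed data
accelerate launch run_wav2vec2_pretraining_no_trainer.py \
	--dataset_name="librispeech_asr" \
	--dataset_config_names clean \
	--dataset_split_names validation \
	--model_name_or_path="patrickvonplaten/wav2vec2-base-v2" \
	--cache_dir="./cache" \
	--output_dir="./wav2vec2-pretrained" \
	--max_train_steps="20000" \
	--num_warmup_steps="32000" \
	--learning_rate="0.005" \
	--gradient_checkpointing
```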

### Demo

In this demo run we pre-train a `"base"`-sized Wav2Vec2 model simply on the validation
and test data of [librispeech_asr](https://huggingface.co/datasets/librispeech_asr).

The demo is run on two Titan RTX (24 GB RAM each). In case you have less RAM available
per device, consider reducing `--per_device_train_batch_size` and/or the `--max_duration_in_seconds`.

```bash
accelerate launch run_wav2vec2_pretraining_no_trainer.py \
	--dataset_name="librispeech_asr" \
	--dataset_config_names clean clean \
	--dataset_split_names validation test \
	--model_name_or_path="patrickvonplaten/wav2vec2-base-v2" \
	--output_dir="./wav2vec2-pretrained-demo" \
	--max_train_steps="20000" \
	--num_warmup_steps="32000" \
	--gradient_accumulation_steps="8" \
	--learning_rate="0.005" \
	--weight_decay="0.01" \
	--max_duration_in_seconds="20.0" \
	--min_duration_in_seconds="2.0" \
	--logging_steps="1" \
	--saving_steps="10000" \
	--per_device_train_batch_size="8" \
	--per_device_eval_batch_size="8" \
	--adam_beta1="0.9" \
	--adam_beta2="0.98" \
	--adam_epsilon="1e-06" \
	--gradient_checkpointing
```

The results of this run can be seen [here](https://wandb.ai/patrickvonplaten/wav2vec2-pretrained-demo/reports/Wav2Vec2-PreTraining-Demo-Run--VmlldzoxMDk3MjAw?accessToken=oa05s1y57lizo2ocxy3k01g6db1u4pt8m6ur2n8nl4cb0ug02ms2cw313kb8ruch).

### Base

TODO (currently running...)

### Large

To pre-train a `"large"`-sized Wav2Vec2 model, *e.g.* [facebook/wav2vec2-large-lv60](https://huggingface.co/facebook/wav2vec2-large-lv60),
on [librispeech_asr](https://huggingface.co/datasets/librispeech_asr), the following command can be run:

```bash
accelerate launch run_wav2vec2_pretraining_no_trainer.py \
	--dataset_name=librispeech_asr \
	--dataset_config_names clean clean other \
	--dataset_split_names train.100 train.360 train.500 \
	--output_dir=./test \
	--max_train_steps=200000 \
	--num_warmup_steps=32000 \
	--gradient_accumulation_steps=8 \
	--learning_rate=0.001 \
	--weight_decay=0.01 \
	--max_duration_in_seconds=20.0 \
	--min_duration_in_seconds=2.0 \
	--model_name_or_path=./ \
	--logging_steps=1 \
	--saving_steps=10000 \
	--per_device_train_batch_size=2 \
	--per_device_eval_batch_size=4 \
	--adam_beta1=0.9 \
	--adam_beta2=0.98 \
	--adam_epsilon=1e-06 \
	--gradient_checkpointing
```

The experiment was run on 8 V100 GPUs (16 GB RAM each) for 7 days.
In case you have more than 8 GPUs available for a higher effective `batch_size`,
it is recommended to increase the `learning_rate` to `0.005` for faster convergence.

The results of this run can be seen [here](https://wandb.ai/patrickvonplaten/pretraining-wav2vec2/reports/Wav2Vec2-Large--VmlldzoxMTAwODM4?accessToken=wm3qzcnldrwsa31tkvf2pdmilw3f63d4twtffs86ou016xjbyilh55uoi3mo1qzc) and the checkpoint pre-trained for 120,000 steps can be accessed [here](https://huggingface.co/patrickvonplaten/wav2vec2-large-repro-960h-libri-120k-steps).
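For downstream use, the 120,000-step checkpoint linked above can be loaded directly; a minimal sketch (any other pre-trained Wav2Vec2 checkpoint loads the same way):

```python
from transformers import Wav2Vec2ForPreTraining

# checkpoint referenced above
model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-large-repro-960h-libri-120k-steps")
```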
examples/pytorch/speech-pretraining/requirements.txt (new file, 4 lines)
@@ -0,0 +1,4 @@
datasets >= 1.12.0
torch >= 1.5
torchaudio
accelerate >= 0.5.0
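A minimal environment setup for this example folder might look like the following sketch (paths assume the repository root; `accelerate config` is interactive):

```bash
pip install -r examples/pytorch/speech-pretraining/requirements.txt
accelerate config  # describe your (multi-)GPU setup once, before launching training
```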
examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py (new executable file, 700 lines)
@@ -0,0 +1,700 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" Pre-Training a 🤗 Wav2Vec2 model on unlabeled audio data """

import argparse
import logging
import math
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Union

import datasets
import torch
import torchaudio
from datasets import DatasetDict, concatenate_datasets, load_dataset
from torch.utils.data.dataloader import DataLoader
from tqdm.auto import tqdm

import transformers
from accelerate import Accelerator
from huggingface_hub import Repository
from transformers import (
    AdamW,
    SchedulerType,
    Wav2Vec2Config,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2ForPreTraining,
    get_scheduler,
    is_wandb_available,
    set_seed,
)
from transformers.file_utils import get_full_repo_name
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices


logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(description="Pre-train a Wav2Vec2 model on unlabeled audio data")
    parser.add_argument(
        "--dataset_name", type=str, default=None, help="The name of the dataset to use (via the datasets library)."
    )
    parser.add_argument(
        "--dataset_config_names",
        nargs="+",
        type=str,
        required=True,
        help="The configuration names of the dataset to use (via the datasets library).",
    )
    parser.add_argument(
        "--dataset_split_names",
        nargs="+",
        type=str,
        required=True,
        help="The names of the training data set splits to use (via the datasets library).",
    )
    parser.add_argument(
        "--preprocessing_num_workers", type=int, default=None, help="The number of processes to use for the preprocessing."
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )
    parser.add_argument(
        "--preprocessing_only", action="store_true", help="Only run the preprocessing script to be cached for future use"
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="Where do you want to store the pretrained models downloaded from huggingface.co",
    )
    parser.add_argument(
        "--validation_split_percentage",
        type=int,
        default=1,
        help="Percentage of training data that should be used for validation if no validation is present in dataset.",
    )
    parser.add_argument("--logging_steps", type=int, default=500, help="Number of steps between each logging")
    parser.add_argument("--saving_steps", type=int, default=500, help="Number of steps between each checkpoint save")
    parser.add_argument(
        "--audio_column_name",
        type=str,
        default="file",
        help="Column in the dataset that contains speech file path. Defaults to 'file'",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
        required=True,
    )
    parser.add_argument(
        "--config_name", type=str, default=None, help="Pretrained config name or path if not the same as model_name"
    )
    parser.add_argument(
        "--per_device_train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--per_device_eval_batch_size", type=int, default=8, help="Batch size (per device) for the evaluation dataloader."
    )
    parser.add_argument(
        "--learning_rate", type=float, default=5e-5, help="Initial learning rate (after the potential warmup period) to use."
    )
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
    parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="If True, use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )
    parser.add_argument(
        "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
    parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.")
    parser.add_argument(
        "--max_gumbel_temperature", type=float, default=2.0, help="Maximum temperature for gumbel softmax."
    )
    parser.add_argument(
        "--min_gumbel_temperature", type=float, default=0.5, help="Minimum temperature for gumbel softmax."
    )
    parser.add_argument(
        "--gumbel_temperature_decay", type=float, default=0.999995, help="Decay of gumbel temperature during training."
    )
    parser.add_argument(
        "--max_duration_in_seconds",
        type=float,
        default=5.0,
        help="Filter out audio files that are longer than `max_duration_in_seconds` seconds",
    )
    parser.add_argument(
        "--min_duration_in_seconds",
        type=float,
        default=3.0,
        help="Filter out audio files that are shorter than `min_duration_in_seconds` seconds",
    )
    parser.add_argument(
        "--pad_to_multiple_of",
        type=int,
        default=None,
        help="If set will pad the sequence to a multiple of the provided value. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).",
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="Beta1 for AdamW optimizer")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="Beta2 for AdamW optimizer")
    parser.add_argument("--adam_epsilon", type=float, default=1e-8, help="Epsilon for AdamW optimizer")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument(
        "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
    )
    parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
    args = parser.parse_args()

    if args.push_to_hub:
        assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."

    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)

    return args


@dataclass
class DataCollatorForWav2Vec2Pretraining:
    """
    Data collator that will dynamically pad the inputs received and prepare masked indices
    for self-supervised pretraining.

    Args:
        model (:class:`~transformers.Wav2Vec2ForPreTraining`):
            The Wav2Vec2 model used for pretraining. The data collator needs to have access
            to config and ``_get_feat_extract_output_lengths`` function for correct padding.
        feature_extractor (:class:`~transformers.Wav2Vec2FeatureExtractor`):
            The processor used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
    """

    model: Wav2Vec2ForPreTraining
    feature_extractor: Wav2Vec2FeatureExtractor
    padding: Union[bool, str] = "longest"
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # reformat list to dict and set to pytorch format
        batch = self.feature_extractor.pad(
            features,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        device = batch["input_values"].device
        batch_size = batch["input_values"].shape[0]

        mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])

        # make sure that no loss is computed on padded inputs
        if batch.get("attention_mask") is not None:
            # compute real output lengths according to convolution formula
            batch["sub_attention_mask"] = self.model._get_feature_vector_attention_mask(
                mask_indices_seq_length, batch["attention_mask"]
            )

        features_shape = (batch_size, mask_indices_seq_length)

        # sample randomly masked indices
        mask_time_indices = _compute_mask_indices(
            features_shape,
            self.model.config.mask_time_prob,
            self.model.config.mask_time_length,
            attention_mask=batch.get("sub_attention_mask"),
        )

        # sample negative indices
        sampled_negative_indices = _sample_negative_indices(
            features_shape,
            self.model.config.num_negatives,
            mask_time_indices=mask_time_indices,
        )
        batch["mask_time_indices"] = torch.tensor(mask_time_indices, dtype=torch.long, device=device)
        batch["sampled_negative_indices"] = torch.tensor(sampled_negative_indices, dtype=torch.long, device=device)

        return batch


def multiply_grads(params, c):
    """Multiplies grads by a constant *c*."""
    for p in params:
        if p.grad is not None:
            if torch.is_tensor(c):
                c = c.to(p.grad.device)
            p.grad.data.mul_(c)


def get_grad_norm(params, scale=1):
    """Compute grad norm given a gradient scale."""
    total_norm = 0.0
    for p in params:
        if p.grad is not None:
            param_norm = (p.grad.detach().data / scale).norm(2)
            total_norm += param_norm.item() ** 2
    total_norm = total_norm ** 0.5
    return total_norm


def main():
    # See all possible arguments in src/transformers/args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    args = parse_args()

    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
    accelerator = Accelerator()
    logger.info(accelerator.state)

    # Setup logging, we only want one process per machine to log things on the screen.
    # accelerator.is_local_main_process is only True for one process per machine.
    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()

        # set up weights and biases if available
        if is_wandb_available():
            import wandb

            wandb.init(project=args.output_dir.split("/")[-1])
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub and not args.preprocessing_only:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            repo = Repository(args.output_dir, clone_from=repo_name)
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    # 1. Download and create train, validation dataset
    # We load all dataset configuration and dataset split pairs passed in
    # ``args.dataset_config_names`` and ``args.dataset_split_names``
    datasets_splits = []
    for dataset_config_name, train_split_name in zip(args.dataset_config_names, args.dataset_split_names):
        # load dataset
        dataset_split = load_dataset(
            args.dataset_name, dataset_config_name, split=train_split_name, cache_dir=args.cache_dir
        )
        datasets_splits.append(dataset_split)

    # Next, we concatenate all configurations and splits into a single training dataset
    raw_datasets = DatasetDict()
    if len(datasets_splits) > 1:
        raw_datasets["train"] = concatenate_datasets(datasets_splits).shuffle(seed=args.seed)
    else:
        raw_datasets["train"] = datasets_splits[0]

    # Take ``args.validation_split_percentage`` from the training dataset for the validation split
    num_validation_samples = raw_datasets["train"].num_rows * args.validation_split_percentage // 100

    if num_validation_samples == 0:
        raise ValueError(
            "`args.validation_split_percentage` is less than a single sample "
            f"for {len(raw_datasets['train'])} training samples. Increase "
            "`args.num_validation_split_percentage`. "
        )

    raw_datasets["validation"] = raw_datasets["train"].select(range(num_validation_samples))
    raw_datasets["train"] = raw_datasets["train"].select(range(num_validation_samples, raw_datasets["train"].num_rows))

    # 2. Preprocess audio: load, resample, normalize and truncate
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(args.model_name_or_path)

    # only normalized-inputs-training is supported
    if not feature_extractor.do_normalize:
        raise ValueError(
            "Training is only supported for normalized inputs. Make sure ``feature_extractor.do_normalize == True``"
        )

    # set max & min audio length in number of samples
    max_length = int(args.max_duration_in_seconds * feature_extractor.sampling_rate)
    min_length = int(args.min_duration_in_seconds * feature_extractor.sampling_rate)

    resampler = None
    if raw_datasets["train"][args.audio_column_name][0].split(".")[-1] == "mp3":
        # TODO(PVP) - remove hard-coded 48_000 after audio feature is merged
        resampler = torchaudio.transforms.Resample(48_000, feature_extractor.sampling_rate)

    def prepare_dataset(batch):
        speech_array, sampling_rate = torchaudio.load(batch[args.audio_column_name])
        speech_array = speech_array.squeeze()

        # if necessary resample audio
        if resampler is not None:
            # TODO(PVP) - remove hard-coded 48_000 after audio feature is merged
            speech_array = resampler(speech_array)
            sampling_rate = resampler.new_freq

        speech_array = speech_array.numpy()
        inputs = feature_extractor(speech_array, sampling_rate=sampling_rate, max_length=max_length, truncation=True)
        batch["input_values"] = inputs.input_values[0]
        return batch

    # load audio files into numpy arrays
    with accelerator.main_process_first():
        vectorized_datasets = raw_datasets.map(
            prepare_dataset,
            num_proc=args.preprocessing_num_workers,
            remove_columns=raw_datasets["train"].column_names,
            load_from_cache_file=not args.overwrite_cache,
        )
        vectorized_datasets = vectorized_datasets.filter(
            lambda x: len(x["input_values"]) > min_length, load_from_cache_file=not args.overwrite_cache
        )

    # for large datasets it is advised to run the preprocessing on a
    # single machine first with ``args.preprocessing_only`` since there will most likely
    # be a timeout when running the script in distributed mode.
    # In a second step ``args.preprocessing_only`` can then be set to `False` to load the
    # cached dataset
    if args.preprocessing_only:
        return

    # 3. Load model
    config = Wav2Vec2Config.from_pretrained(args.model_name_or_path)

    # pretraining is only supported for "newer" stable layer norm architecture
    # apply_spec_augment has to be True, mask_feature_prob has to be 0.0
    if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
        raise ValueError(
            "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'``"
        )

    # initialize random model
    model = Wav2Vec2ForPreTraining(config)

    # Activate gradient checkpointing if needed
    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    # 4. Define data collator, optimizer and scheduler
    data_collator = DataCollatorForWav2Vec2Pretraining(
        model=model, feature_extractor=feature_extractor, pad_to_multiple_of=args.pad_to_multiple_of
    )
    train_dataloader = DataLoader(
        vectorized_datasets["train"],
        shuffle=True,
        collate_fn=data_collator,
        batch_size=args.per_device_train_batch_size,
    )
    eval_dataloader = DataLoader(
        vectorized_datasets["validation"], collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
    )

    # Optimizer
    optimizer = AdamW(
        list(model.parameters()),
        lr=args.learning_rate,
        betas=[args.adam_beta1, args.adam_beta2],
        eps=args.adam_epsilon,
    )

    # Prepare everything with our `accelerator`.
    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader
    )

    # Scheduler and math around the number of training steps.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    else:
        args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )

    # 5. Train
    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(vectorized_datasets['train'])}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    completed_steps = 0
    for epoch in range(args.num_train_epochs):
        model.train()
        for step, batch in enumerate(train_dataloader):
            # compute num of losses
            num_losses = batch["mask_time_indices"].sum()
            sub_attention_mask = batch.pop("sub_attention_mask", None)
            sub_attention_mask = (
                sub_attention_mask if sub_attention_mask is not None else torch.ones_like(batch["mask_time_indices"])
            )
            percent_masked = num_losses / sub_attention_mask.sum()

            # forward
            outputs = model(**batch)

            # divide loss by gradient accumulation steps since gradients
            # are accumulated for multiple backward passes in PyTorch
            loss = outputs.loss / args.gradient_accumulation_steps
            accelerator.backward(loss)

            # make sure that `num_losses` is summed for distributed training
            # and average gradients over losses of all devices
            if accelerator.state.num_processes > 1:
                num_losses = accelerator.gather(num_losses).sum()
                gradient_multiplier = accelerator.state.num_processes / num_losses
                multiply_grads(model.module.parameters(), gradient_multiplier)
            else:
                multiply_grads(model.parameters(), 1 / num_losses)

            # update step
            if (step + 1) % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:

                # compute grad norm for monitoring
                scale = (
                    accelerator.scaler._scale.item()
                    if hasattr(accelerator, "scaler") and accelerator.scaler is not None
                    else 1
                )
                if accelerator.state.num_processes > 1:
                    grad_norm = get_grad_norm(model.module.parameters(), scale)
                else:
                    grad_norm = get_grad_norm(model.parameters(), scale)

                # update parameters
                optimizer.step()
                optimizer.zero_grad()

                if not accelerator.optimizer_step_was_skipped:
                    lr_scheduler.step()
                elif accelerator.is_local_main_process:
                    progress_bar.write(
                        "Gradients have overflown - skipping update step... " f"Updating gradient scale to {scale}..."
                    )

                # update gumbel temperature
                gumbel_temperature = max(
                    args.max_gumbel_temperature * args.gumbel_temperature_decay ** completed_steps,
                    args.min_gumbel_temperature,
                )
                if hasattr(model, "module"):
                    model.module.set_gumbel_temperature(gumbel_temperature)
                else:
                    model.set_gumbel_temperature(gumbel_temperature)

                progress_bar.update(1)
                completed_steps += 1

            # 6. Log all results
            if (step + 1) % (args.gradient_accumulation_steps * args.logging_steps) == 0:
                loss.detach()
                outputs.contrastive_loss.detach()
                outputs.diversity_loss.detach()

                if accelerator.state.num_processes > 1:
                    loss = accelerator.gather(loss).sum()
                    outputs.contrastive_loss = accelerator.gather(outputs.contrastive_loss).sum()
                    outputs.diversity_loss = accelerator.gather(outputs.diversity_loss).sum()
                    percent_masked = accelerator.gather(percent_masked).sum()

                train_logs = {
                    "loss": (loss * args.gradient_accumulation_steps) / num_losses,
                    "constrast_loss": outputs.contrastive_loss / num_losses,
                    "div_loss": outputs.diversity_loss / num_losses,
                    "%_mask_idx": percent_masked / accelerator.num_processes,
                    "ppl": outputs.codevector_perplexity,
                    "lr": torch.tensor(optimizer.param_groups[0]["lr"]),
                    "temp": torch.tensor(gumbel_temperature),
                    "grad_norm": torch.tensor(grad_norm),
                }
                log_str = ""
                for k, v in train_logs.items():
                    log_str += "| {}: {:.3e}".format(k, v.item())

                if accelerator.is_local_main_process:
                    progress_bar.write(log_str)
                    if is_wandb_available():
                        wandb.log(train_logs)

            # save model every `args.saving_steps` steps
            if (step + 1) % (args.gradient_accumulation_steps * args.saving_steps) == 0:
                if (args.push_to_hub and epoch < args.num_train_epochs - 1) or args.output_dir is not None:
                    accelerator.wait_for_everyone()
                    unwrapped_model = accelerator.unwrap_model(model)
                    unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)

                if (args.push_to_hub and epoch < args.num_train_epochs - 1) and accelerator.is_main_process:
                    repo.push_to_hub(commit_message=f"Training in progress step {completed_steps}", blocking=False)

            # if completed steps > `args.max_train_steps` stop
            if completed_steps >= args.max_train_steps:
                break

        # 7. Validate!
        model.eval()

        # init logs
        val_logs = {
            "val_loss": 0,
            "val_contrastive_loss": 0,
            "val_diversity_loss": 0,
            "val_num_losses": 0,
        }
        for step, batch in enumerate(eval_dataloader):
            with torch.no_grad():
                batch.pop("sub_attention_mask", None)
                outputs = model(**batch)

            val_logs["val_loss"] += outputs.loss
            val_logs["val_contrastive_loss"] += outputs.contrastive_loss
            val_logs["val_diversity_loss"] += outputs.diversity_loss
            val_logs["val_num_losses"] += batch["mask_time_indices"].sum()

        # sum over devices in multi-processing
        if accelerator.num_processes > 1:
            val_logs = {k: accelerator.gather(v).sum() for k, v in val_logs.items()}

        val_logs = {k: v / val_logs["val_num_losses"] for k, v in val_logs.items()}

        log_str = ""
        for k, v in val_logs.items():
            log_str += "| {}: {:.3e}".format(k, v.item())

        if accelerator.is_local_main_process:
            progress_bar.write(log_str)
            if is_wandb_available():
                wandb.log(val_logs)

        if args.output_dir is not None:
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(args.output_dir, save_function=accelerator.save)
            if accelerator.is_main_process:
                if args.push_to_hub:
                    repo.push_to_hub(commit_message="End of training")


if __name__ == "__main__":
    main()
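To make the collator's expected inputs concrete, here is a hedged, self-contained sketch of wiring up `DataCollatorForWav2Vec2Pretraining` outside of the script. The dummy dataset and default-sized model are illustrative only; the import of the collator assumes the script's directory is on `PYTHONPATH`.

```python
import numpy as np
from torch.utils.data import DataLoader
from transformers import Wav2Vec2Config, Wav2Vec2FeatureExtractor, Wav2Vec2ForPreTraining

from run_wav2vec2_pretraining_no_trainer import DataCollatorForWav2Vec2Pretraining

feature_extractor = Wav2Vec2FeatureExtractor()  # do_normalize=True by default
model = Wav2Vec2ForPreTraining(Wav2Vec2Config(do_stable_layer_norm=True, feat_extract_norm="layer"))
collator = DataCollatorForWav2Vec2Pretraining(model=model, feature_extractor=feature_extractor)

# each example only needs raw, normalized audio under "input_values"
dummy_dataset = [{"input_values": np.random.randn(16_000).tolist()} for _ in range(4)]
loader = DataLoader(dummy_dataset, collate_fn=collator, batch_size=2)
batch = next(iter(loader))
print(batch["input_values"].shape, batch["mask_time_indices"].shape)
```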
@@ -23,6 +23,7 @@ from unittest.mock import patch
 
 import torch
 
+from transformers import Wav2Vec2ForPreTraining
 from transformers.file_utils import is_apex_available
 from transformers.testing_utils import TestCasePlus, get_gpu_count, slow, torch_device
 
@@ -41,6 +42,7 @@ SRC_DIRS = [
     "image-classification",
     "speech-recognition",
     "audio-classification",
+    "speech-pretraining",
 ]
 sys.path.extend(SRC_DIRS)
 
@@ -59,6 +61,7 @@ if SRC_DIRS is not None:
     import run_summarization
     import run_swag
     import run_translation
+    import run_wav2vec2_pretraining_no_trainer
 
 
 logging.basicConfig(level=logging.DEBUG)
@@ -447,3 +450,32 @@ class ExamplesTests(TestCasePlus):
             run_audio_classification.main()
             result = get_results(tmp_dir)
             self.assertLess(result["eval_loss"], result["train_loss"])
+
+    def test_run_wav2vec2_pretraining(self):
+        stream_handler = logging.StreamHandler(sys.stdout)
+        logger.addHandler(stream_handler)
+
+        tmp_dir = self.get_auto_remove_tmp_dir()
+        testargs = f"""
+            run_wav2vec2_pretraining_no_trainer.py
+            --output_dir {tmp_dir}
+            --model_name_or_path hf-internal-testing/tiny-random-wav2vec2
+            --dataset_name patrickvonplaten/librispeech_asr_dummy
+            --dataset_config_names clean
+            --dataset_split_names validation
+            --learning_rate 1e-4
+            --per_device_train_batch_size 2
+            --per_device_eval_batch_size 2
+            --preprocessing_num_workers 16
+            --max_train_steps 5
+            --validation_split_percentage 5
+            --seed 42
+            """.split()
+
+        if is_cuda_and_apex_available():
+            testargs.append("--fp16")
+
+        with patch.object(sys, "argv", testargs):
+            run_wav2vec2_pretraining_no_trainer.main()
+            model = Wav2Vec2ForPreTraining.from_pretrained(tmp_dir)
+            self.assertIsNotNone(model)
@@ -48,13 +48,13 @@ def _compute_mask_indices(
     shape: Tuple[int, int],
     mask_prob: float,
     mask_length: int,
-    device: torch.device,
-    attention_mask: Optional[torch.tensor] = None,
+    attention_mask: Optional[torch.LongTensor] = None,
     min_masks: int = 0,
-) -> torch.tensor:
+) -> np.ndarray:
     """
     Computes random mask spans for a given shape. Used to implement `SpecAugment: A Simple Data Augmentation Method for
-    ASR <https://arxiv.org/abs/1904.08779>`__.
+    ASR <https://arxiv.org/abs/1904.08779>`__. Note that this method is not optimized to run on TPU and should be run
+    on CPU as part of the preprocessing during training.
 
     Args:
         shape: the the shape for which to compute masks.
@@ -64,7 +64,6 @@ def _compute_mask_indices(
                however due to overlaps, the actual number will be smaller (unless no_overlap is True)
         mask_length: size of the mask
         min_masks: minimum number of masked spans
-
     """
     batch_size, sequence_length = shape
 
@@ -76,42 +75,64 @@ def _compute_mask_indices(
             f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
         )
 
-    # compute number of masked spans in batch
-    num_masked_spans = int(mask_prob * sequence_length / mask_length + torch.rand((1,)).item())
-    num_masked_spans = max(num_masked_spans, min_masks)
+    epsilon = np.random.rand(1).item()
 
-    # make sure num masked indices <= sequence_length
-    if num_masked_spans * mask_length > sequence_length:
-        num_masked_spans = sequence_length // mask_length
+    def compute_num_masked_span(input_length):
+        """Given input length, compute how many spans should be masked"""
+        num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
+        num_masked_span = max(num_masked_span, min_masks)
+
+        # make sure num masked indices <= sequence_length
+        if num_masked_span * mask_length > sequence_length:
+            num_masked_span = sequence_length // mask_length
+
+        return num_masked_span
+
+    # compute number of masked spans in batch
+    input_lengths = (
+        attention_mask.sum(-1).detach().tolist()
+        if attention_mask is not None
+        else [sequence_length for _ in range(batch_size)]
+    )
 
     # SpecAugment mask to fill
-    spec_aug_mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool)
+    spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=np.bool)
+    spec_aug_mask_idxs = []
 
-    # uniform distribution to sample from, make sure that offset samples are < sequence_length
-    uniform_dist = torch.ones((batch_size, sequence_length - (mask_length - 1)), device=device)
+    max_num_masked_span = compute_num_masked_span(sequence_length)
 
-    # get random indices to mask
-    spec_aug_mask_idxs = torch.multinomial(uniform_dist, num_masked_spans)
+    for input_length in input_lengths:
+        # compute num of masked spans for this input
+        num_masked_span = compute_num_masked_span(input_length)
+
+        # get random indices to mask
+        spec_aug_mask_idx = np.random.choice(
+            np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
+        )
+
+        # pick first sampled index that will serve as a dummy index to pad vector
+        dummy_mask_idx = spec_aug_mask_idx[0]
+
+        spec_aug_mask_idx = np.concatenate(
+            [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
+        )
+        spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+    spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
 
     # expand masked indices to masked spans
-    spec_aug_mask_idxs = (
-        spec_aug_mask_idxs.unsqueeze(dim=-1)
-        .expand((batch_size, num_masked_spans, mask_length))
-        .reshape(batch_size, num_masked_spans * mask_length)
+    spec_aug_mask_idxs = np.broadcast_to(
+        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
     )
-    offsets = (
-        torch.arange(mask_length, device=device)[None, None, :]
-        .expand((batch_size, num_masked_spans, mask_length))
-        .reshape(batch_size, num_masked_spans * mask_length)
+    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
+
+    offsets = np.arange(mask_length)[None, None, :]
+    offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
+        batch_size, max_num_masked_span * mask_length
    )
     spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
 
     # scatter indices to mask
-    spec_aug_mask = spec_aug_mask.scatter(1, spec_aug_mask_idxs, True)
-
-    if attention_mask is not None:
-        # make sure padded input ids cannot be masked
-        spec_aug_mask = torch.where(attention_mask.bool(), spec_aug_mask, False)
+    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
 
     return spec_aug_mask
 
@@ -257,6 +278,7 @@ class HubertFeatureExtractor(nn.Module):
                 f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
             )
         self.conv_layers = nn.ModuleList(conv_layers)
+        self.gradient_checkpointing = False
 
     def _freeze_parameters(self):
         for param in self.parameters():
@@ -264,8 +286,26 @@ class HubertFeatureExtractor(nn.Module):
 
     def forward(self, input_values):
         hidden_states = input_values[:, None]
+
+        # make sure hidden_states require grad for gradient_checkpointing
+        if self.training:
+            hidden_states.requires_grad = True
+
         for conv_layer in self.conv_layers:
-            hidden_states = conv_layer(hidden_states)
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(conv_layer),
+                    hidden_states,
+                )
+            else:
+                hidden_states = conv_layer(hidden_states)
 
         return hidden_states
 
@@ -864,10 +904,10 @@ class HubertModel(HubertPreTrainedModel):
                 (batch_size, sequence_length),
                 mask_prob=self.config.mask_time_prob,
                 mask_length=self.config.mask_time_length,
-                device=hidden_states.device,
                 attention_mask=attention_mask,
                 min_masks=2,
             )
+            mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.long)
             hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
 
         if self.config.mask_feature_prob > 0 and self.training:
@@ -876,9 +916,11 @@ class HubertModel(HubertPreTrainedModel):
                 (batch_size, hidden_size),
                 mask_prob=self.config.mask_feature_prob,
                 mask_length=self.config.mask_feature_length,
-                device=hidden_states.device,
             )
-            hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
+            mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.long)[
+                :, None
+            ].expand(-1, sequence_length, -1)
+            hidden_states[mask_feature_indices] = 0
 
         return hidden_states
 
@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" PyTorch Wav2Vec2 model. """
|
""" PyTorch Wav2Vec2 model. """
|
||||||
|
|
||||||
|
import math
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
@ -87,7 +88,7 @@ class Wav2Vec2ForPreTrainingOutput(ModelOutput):
|
|||||||
Output type of :class:`~transformers.Wav2Vec2ForPreTrainingOutput`, with potential hidden states and attentions.
|
Output type of :class:`~transformers.Wav2Vec2ForPreTrainingOutput`, with potential hidden states and attentions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
loss (`optional`, returned when model is in train mode, ``torch.FloatTensor`` of shape :obj:`(1,)`):
|
loss (`optional`, returned when :obj:`sample_negative_indices` are passed, ``torch.FloatTensor`` of shape :obj:`(1,)`):
|
||||||
Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the `official
|
Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the `official
|
||||||
paper <https://arxiv.org/pdf/2006.11477.pdf>`__ . (classification) loss.
|
paper <https://arxiv.org/pdf/2006.11477.pdf>`__ . (classification) loss.
|
||||||
projected_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.proj_codevector_dim)`):
|
projected_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.proj_codevector_dim)`):
|
||||||
@ -107,6 +108,10 @@ class Wav2Vec2ForPreTrainingOutput(ModelOutput):
|
|||||||
|
|
||||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||||
heads.
|
heads.
|
||||||
|
contrastive_loss (`optional`, returned when :obj:`sample_negative_indices` are passed, ``torch.FloatTensor`` of shape :obj:`(1,)`):
|
||||||
|
The contrastive loss (L_m) as stated in the `official paper <https://arxiv.org/pdf/2006.11477.pdf>`__ .
|
||||||
|
diversity_loss (`optional`, returned when :obj:`sample_negative_indices` are passed, ``torch.FloatTensor`` of shape :obj:`(1,)`):
|
||||||
|
The diversity loss (L_d) as stated in the `official paper <https://arxiv.org/pdf/2006.11477.pdf>`__ .
|
||||||
"""
|
"""
|
||||||
|
|
||||||
loss: Optional[torch.FloatTensor] = None
|
loss: Optional[torch.FloatTensor] = None
|
||||||
@ -115,19 +120,21 @@ class Wav2Vec2ForPreTrainingOutput(ModelOutput):
|
|||||||
codevector_perplexity: torch.FloatTensor = None
|
codevector_perplexity: torch.FloatTensor = None
|
||||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
contrastive_loss: Optional[torch.FloatTensor] = None
|
||||||
|
diversity_loss: Optional[torch.FloatTensor] = None
|
||||||
|
|
||||||
|
|
||||||
def _compute_mask_indices(
|
def _compute_mask_indices(
|
||||||
shape: Tuple[int, int],
|
shape: Tuple[int, int],
|
||||||
mask_prob: float,
|
mask_prob: float,
|
||||||
mask_length: int,
|
mask_length: int,
|
||||||
device: torch.device,
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
attention_mask: Optional[torch.tensor] = None,
|
|
||||||
min_masks: int = 0,
|
min_masks: int = 0,
|
||||||
) -> torch.tensor:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Computes random mask spans for a given shape. Used to implement `SpecAugment: A Simple Data Augmentation Method for
|
Computes random mask spans for a given shape. Used to implement `SpecAugment: A Simple Data Augmentation Method for
|
||||||
ASR <https://arxiv.org/abs/1904.08779>`__.
|
ASR <https://arxiv.org/abs/1904.08779>`__. Note that this method is not optimized to run on TPU and should be run
|
||||||
|
on CPU as part of the preprocessing during training.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
shape: the the shape for which to compute masks.
|
shape: the the shape for which to compute masks.
|
||||||
@ -137,7 +144,6 @@ def _compute_mask_indices(
|
|||||||
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
|
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
|
||||||
mask_length: size of the mask
|
mask_length: size of the mask
|
||||||
min_masks: minimum number of masked spans
|
min_masks: minimum number of masked spans
|
||||||
|
|
||||||
"""
|
"""
|
||||||
batch_size, sequence_length = shape
|
batch_size, sequence_length = shape
|
||||||
|
|
||||||
@@ -149,46 +155,104 @@ def _compute_mask_indices(
             f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
         )

-    # compute number of masked spans in batch
-    num_masked_spans = int(mask_prob * sequence_length / mask_length + torch.rand((1,)).item())
-    num_masked_spans = max(num_masked_spans, min_masks)
+    epsilon = np.random.rand(1).item()

-    # make sure num masked indices <= sequence_length
-    if num_masked_spans * mask_length > sequence_length:
-        num_masked_spans = sequence_length // mask_length
+    def compute_num_masked_span(input_length):
+        """Given input length, compute how many spans should be masked"""
+        num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
+        num_masked_span = max(num_masked_span, min_masks)
+
+        # make sure num masked indices <= sequence_length
+        if num_masked_span * mask_length > sequence_length:
+            num_masked_span = sequence_length // mask_length
+
+        return num_masked_span
+
+    # compute number of masked spans in batch
+    input_lengths = (
+        attention_mask.sum(-1).detach().tolist()
+        if attention_mask is not None
+        else [sequence_length for _ in range(batch_size)]
+    )
+
     # SpecAugment mask to fill
-    spec_aug_mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool)
+    spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=np.bool)
+    spec_aug_mask_idxs = []

-    # uniform distribution to sample from, make sure that offset samples are < sequence_length
-    uniform_dist = torch.ones((batch_size, sequence_length - (mask_length - 1)), device=device)
+    max_num_masked_span = compute_num_masked_span(sequence_length)

-    # get random indices to mask
-    spec_aug_mask_idxs = torch.multinomial(uniform_dist, num_masked_spans)
+    for input_length in input_lengths:
+        # compute num of masked spans for this input
+        num_masked_span = compute_num_masked_span(input_length)
+
+        # get random indices to mask
+        spec_aug_mask_idx = np.random.choice(
+            np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
+        )
+
+        # pick first sampled index that will serve as a dummy index to pad vector
+        dummy_mask_idx = spec_aug_mask_idx[0]
+
+        spec_aug_mask_idx = np.concatenate(
+            [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
+        )
+        spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+    spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)

     # expand masked indices to masked spans
-    spec_aug_mask_idxs = (
-        spec_aug_mask_idxs.unsqueeze(dim=-1)
-        .expand((batch_size, num_masked_spans, mask_length))
-        .reshape(batch_size, num_masked_spans * mask_length)
+    spec_aug_mask_idxs = np.broadcast_to(
+        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
     )
-    offsets = (
-        torch.arange(mask_length, device=device)[None, None, :]
-        .expand((batch_size, num_masked_spans, mask_length))
-        .reshape(batch_size, num_masked_spans * mask_length)
+    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
+
+    offsets = np.arange(mask_length)[None, None, :]
+    offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
+        batch_size, max_num_masked_span * mask_length
     )
     spec_aug_mask_idxs = spec_aug_mask_idxs + offsets

     # scatter indices to mask
-    spec_aug_mask = spec_aug_mask.scatter(1, spec_aug_mask_idxs, True)
-
-    if attention_mask is not None:
-        # make sure padded input ids cannot be masked
-        spec_aug_mask = torch.where(attention_mask.bool(), spec_aug_mask, False)
+    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)

     return spec_aug_mask


+def _sample_negative_indices(
+    features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None
+):
+    """
+    Sample `num_negatives` vectors from feature vectors.
+    """
+    batch_size, sequence_length = features_shape
+
+    # generate indices of the positive vectors themselves, repeat them `num_negatives` times
+    sequence_length_range = np.arange(sequence_length)
+
+    # get `num_negatives` random vector indices from the same utterance
+    sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
+
+    mask_time_indices = (
+        mask_time_indices.astype(np.bool) if mask_time_indices is not None else np.ones(features_shape, dtype=np.bool)
+    )
+
+    for batch_idx in range(batch_size):
+        high = mask_time_indices[batch_idx].sum() - 1
+        mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
+
+        feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
+        sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
+        # avoid sampling the same positive vector, but keep the distribution uniform
+        sampled_indices[sampled_indices >= feature_indices] += 1
+
+        # remap to actual indices
+        sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
+
+        # correct for batch size
+        sampled_negative_indices[batch_idx] += batch_idx * sequence_length
+
+    return sampled_negative_indices
+
+
 class Wav2Vec2NoLayerNormConvLayer(nn.Module):
     def __init__(self, config, layer_id=0):
         super().__init__()
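Both helpers above now work purely in NumPy and return arrays rather than tensors, so callers are expected to convert the results themselves. A minimal sketch of how they are meant to be combined (the shapes, `mask_prob`, and `num_negatives` below are illustrative values, not taken from the diff):

```python
import torch

from transformers.models.wav2vec2.modeling_wav2vec2 import (
    _compute_mask_indices,
    _sample_negative_indices,
)

batch_size, sequence_length = 2, 100  # hypothetical number of feature frames

# boolean numpy mask marking the spans the model has to predict
mask_time_indices = _compute_mask_indices(
    (batch_size, sequence_length), mask_prob=0.2, mask_length=2, min_masks=2
)

# indices of distractor vectors, drawn only from the masked positions
sampled_negative_indices = _sample_negative_indices(
    (batch_size, sequence_length), num_negatives=10, mask_time_indices=mask_time_indices
)

# both results are numpy arrays; convert to tensors before feeding the model
mask_time_indices = torch.from_numpy(mask_time_indices)
sampled_negative_indices = torch.from_numpy(sampled_negative_indices)
```

The per-batch loop in `_sample_negative_indices` keeps negatives within the same utterance and shifts every index by `batch_idx * sequence_length`, so the flattened `(batch * time, hidden)` view used later in the pre-training forward pass can be indexed directly.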
@@ -326,6 +390,7 @@ class Wav2Vec2FeatureExtractor(nn.Module):
                 f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
             )
         self.conv_layers = nn.ModuleList(conv_layers)
+        self.gradient_checkpointing = False

     def _freeze_parameters(self):
         for param in self.parameters():
@@ -333,8 +398,26 @@ class Wav2Vec2FeatureExtractor(nn.Module):

     def forward(self, input_values):
         hidden_states = input_values[:, None]
+
+        # make sure hidden_states require grad for gradient_checkpointing
+        if self.training:
+            hidden_states.requires_grad = True
+
         for conv_layer in self.conv_layers:
-            hidden_states = conv_layer(hidden_states)
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(conv_layer),
+                    hidden_states,
+                )
+            else:
+                hidden_states = conv_layer(hidden_states)

         return hidden_states

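The feature extractor's forward now optionally wraps each conv layer in `torch.utils.checkpoint.checkpoint`, trading recomputation for memory. The closure-based wrapper it uses is the standard pattern; a self-contained sketch with a toy layer (the module and shapes are made up):

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint

conv = nn.Conv1d(1, 4, kernel_size=3)           # stand-in for one conv layer
x = torch.randn(2, 1, 16, requires_grad=True)   # checkpointing needs a grad-requiring input

def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward

# activations of `conv` are not kept; they are recomputed during the backward pass
out = torch.utils.checkpoint.checkpoint(create_custom_forward(conv), x)
out.sum().backward()
```

This is also why the diff forces `hidden_states.requires_grad = True` in training mode: a checkpointed segment needs at least one input that requires grad, otherwise nothing flows back through the recomputed section.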
@@ -778,7 +861,7 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
         self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)

         # can be decayed for training
-        self.temperature = 1
+        self.temperature = 2

     def set_temperature(self, temperature: int):
         self.temperature = temperature
@@ -844,8 +927,8 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):

     config_class = Wav2Vec2Config
     base_model_prefix = "wav2vec2"
-    supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_missing = [r"position_ids"]
+    supports_gradient_checkpointing = True

     def _init_weights(self, module):
         """Initialize the weights"""
@@ -854,22 +937,31 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
             module.weight_proj.weight.data.normal_(mean=0.0, std=1)
             module.weight_proj.bias.data.zero_()
             nn.init.uniform_(module.codevectors)
+        elif isinstance(module, Wav2Vec2PositionalConvEmbedding):
+            nn.init.normal_(
+                module.conv.weight,
+                mean=0,
+                std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
+            )
+            nn.init.constant_(module.conv.bias, 0)
+        elif isinstance(module, Wav2Vec2FeatureProjection):
+            k = math.sqrt(1 / module.projection.in_features)
+            nn.init.uniform_(module.projection.weight, a=-k, b=k)
+            nn.init.uniform_(module.projection.bias, a=-k, b=k)
         elif isinstance(module, nn.Linear):
-            # Slightly different from the TF version which uses truncated_normal for initialization
-            # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+            if module.bias is not None:
+                module.bias.data.zero_()
         elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
         elif isinstance(module, nn.Conv1d):
-            nn.init.kaiming_normal_(module.weight.data)
+            nn.init.kaiming_normal_(module.weight)

-        if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
-            module.bias.data.zero_()
+            if module.bias is not None:
+                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
+                nn.init.uniform_(module.bias, a=-k, b=k)
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm)):
-            module.gradient_checkpointing = value

     def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
         """
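The new `_init_weights` branches spell out uniform bounds that appear to mirror PyTorch's own `reset_parameters` defaults: `k = sqrt(1 / in_features)` for the feature projection, and `k = sqrt(groups / (in_channels * kernel_size))` for a conv bias. A small sketch of the conv case, with made-up layer sizes:

```python
import math

import torch.nn as nn

conv = nn.Conv1d(in_channels=512, out_channels=512, kernel_size=3, groups=1)  # illustrative sizes

nn.init.kaiming_normal_(conv.weight)
if conv.bias is not None:
    # same bound the diff above computes for Conv1d biases
    k = math.sqrt(conv.groups / (conv.in_channels * conv.kernel_size[0]))
    nn.init.uniform_(conv.bias, a=-k, b=k)
```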
@@ -898,6 +990,10 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
         attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
         return attention_mask

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm, Wav2Vec2FeatureExtractor)):
+            module.gradient_checkpointing = value
+

 WAV_2_VEC_2_START_DOCSTRING = r"""
     Wav2Vec2 was proposed in `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations
@@ -1001,10 +1097,10 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
                 (batch_size, sequence_length),
                 mask_prob=self.config.mask_time_prob,
                 mask_length=self.config.mask_time_length,
-                device=hidden_states.device,
                 attention_mask=attention_mask,
                 min_masks=2,
             )
+            mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.long)
             hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)

         if self.config.mask_feature_prob > 0 and self.training:
@@ -1013,9 +1109,11 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
                 (batch_size, hidden_size),
                 mask_prob=self.config.mask_feature_prob,
                 mask_length=self.config.mask_feature_length,
-                device=hidden_states.device,
             )
-            hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
+            mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.long)[
+                :, None
+            ].expand(-1, sequence_length, -1)
+            hidden_states[mask_feature_indices] = 0

         return hidden_states

@@ -1101,11 +1199,13 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
         self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)

         self.quantizer = Wav2Vec2GumbelVectorQuantizer(config)
-        self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
-        self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)

         self.init_weights()

+        # make sure that project_hid & project_q are initialized like normal linear layers
+        self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
+        self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
+
     def set_gumbel_temperature(self, temperature: int):
         """
         Set the Gumbel softmax temperature to a given value. Only necessary for training
@@ -1119,61 +1219,12 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
         """
         self.wav2vec2.feature_extractor._freeze_parameters()

-    @staticmethod
-    def _sample_negatives(
-        features: torch.FloatTensor, num_negatives: int, attention_mask: Optional[torch.LongTensor] = None
-    ):
-        """
-        Sample `num_negatives` vectors from feature vectors.
-        """
-        batch_size, sequence_length, hidden_size = features.shape
-        if sequence_length <= 1:
-            raise ValueError(
-                f"`features should have `sequence_length` > 1, but are of shape (batch_size, sequence_length, hidden_size) = ({batch_size, sequence_length, hidden_size})."
-            )
-
-        features = features.view(-1, hidden_size)  # BTC => (BxT)C
-
-        with torch.no_grad():
-            # get `num_negatives` random vector indices from the same utterance
-            sampled_negative_indices = []
-            for batch_idx in range(batch_size):
-                high = attention_mask[batch_idx].sum() - 1 if attention_mask is not None else sequence_length - 1
-                sampled_indices_slice = torch.randint(
-                    0, high, size=(num_negatives * sequence_length,), device=features.device
-                )
-                sampled_negative_indices.append(sampled_indices_slice)
-
-            sampled_negative_indices = torch.stack(sampled_negative_indices)
-
-            # generate indices of the positive vectors themselves, repeat them `num_negatives` times
-            feature_indices = (
-                torch.arange(sequence_length, device=features.device)[:, None]
-                .expand(sequence_length, num_negatives)
-                .flatten()
-            )
-
-            # avoid sampling the same positive vector, but keep the distribution uniform
-            sampled_negative_indices[sampled_negative_indices >= feature_indices] += 1
-
-        # correct for batch size
-        for batch_idx in range(1, batch_size):
-            sampled_negative_indices[batch_idx] += batch_idx * sequence_length
-
-        # take negative vectors from sampled indices
-        sampled_negatives = features[sampled_negative_indices.view(-1)]
-        sampled_negatives = sampled_negatives.view(batch_size, sequence_length, num_negatives, hidden_size).permute(
-            2, 0, 1, 3
-        )
-
-        return sampled_negatives
-
     @staticmethod
     def compute_contrastive_logits(
         target_features: torch.FloatTensor,
         negative_features: torch.FloatTensor,
         predicted_features: torch.FloatTensor,
-        temperature: int = 1,
+        temperature: int = 0.1,
     ):
         """
         Compute logits for contrastive loss based using cosine similarity as the distance measure between
@@ -1196,6 +1247,7 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
         input_values,
         attention_mask=None,
         mask_time_indices=None,
+        sampled_negative_indices=None,
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
@@ -1204,6 +1256,9 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
             mask_time_indices (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
                 Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
                 masked extracted features in `config.proj_codevector_dim` space.
+            sampled_negative_indices (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, sequence_length, num_negatives)`, `optional`):
+                Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss.
+                Required input for pre-training.

         Returns:

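Putting the new `sampled_negative_indices` argument together with the mask utilities, a pre-training forward pass could look roughly like the sketch below (the checkpoint name, fake audio, and `min_masks=2` are illustrative; the conversion steps mirror the tests further down in this diff):

```python
import numpy as np
import torch

from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForPreTraining
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    _compute_mask_indices,
    _sample_negative_indices,
)

model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")

raw_speech = [np.random.randn(16000).astype(np.float32)]  # one second of fake 16 kHz audio
inputs = feature_extractor(raw_speech, sampling_rate=16000, return_tensors="pt")

batch_size = inputs.input_values.shape[0]
sequence_length = int(model._get_feat_extract_output_lengths(inputs.input_values.shape[1]))

# numpy mask over feature frames and matching negative (distractor) indices
mask_time_indices = _compute_mask_indices(
    (batch_size, sequence_length),
    mask_prob=model.config.mask_time_prob,
    mask_length=model.config.mask_time_length,
    min_masks=2,
)
sampled_negative_indices = _sample_negative_indices(
    (batch_size, sequence_length), model.config.num_negatives, mask_time_indices
)

outputs = model(
    inputs.input_values,
    mask_time_indices=torch.from_numpy(mask_time_indices),
    sampled_negative_indices=torch.from_numpy(sampled_negative_indices),
)
print(outputs.loss, outputs.contrastive_loss, outputs.diversity_loss)
```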
@@ -1270,21 +1325,30 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):

         # 2. quantize all (unmasked) extracted features and project to final vq dim
         extract_features = self.dropout_features(outputs[1])
-        quantized_features, codevector_perplexity = self.quantizer(extract_features, mask_time_indices)
+
+        if attention_mask is not None:
+            # compute reduced attention_mask correponding to feature vectors
+            attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
+
+        quantized_features, codevector_perplexity = self.quantizer(
+            extract_features, mask_time_indices=mask_time_indices
+        )
         quantized_features = self.project_q(quantized_features)

-        loss = None
-        if self.training:
+        loss = contrastive_loss = diversity_loss = None
+        if sampled_negative_indices is not None:
+            batch_size, sequence_length, hidden_size = quantized_features.shape
+
             # for training, we sample negatives
             # 3. sample K negatives (distractors) quantized states for contrastive loss
             # if attention_mask is passed, make sure that padded feature vectors cannot be sampled
-            if attention_mask is not None:
-                # compute reduced attention_mask correponding to feature vectors
-                attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
-
-            negative_quantized_features = self._sample_negatives(
-                quantized_features, self.config.num_negatives, attention_mask=attention_mask
-            )
+            # sample negative quantized vectors BTC => (BxT)C
+            negative_quantized_features = quantized_features.view(-1, hidden_size)[
+                sampled_negative_indices.long().view(-1)
+            ]
+            negative_quantized_features = negative_quantized_features.view(
+                batch_size, sequence_length, -1, hidden_size
+            ).permute(2, 0, 1, 3)

             # 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
             # of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
@@ -1298,18 +1362,19 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
             # 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
             # its cosine similarity will be masked
             neg_is_pos = (quantized_features == negative_quantized_features).all(-1)

             if neg_is_pos.any():
                 logits[1:][neg_is_pos] = float("-inf")

             # 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logs) =
             # -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
-            preds = logits.transpose(0, 2).reshape(-1, logits.size(0))
+            logits = logits.transpose(0, 2).reshape(-1, logits.size(0))
             target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
-            contrastive_loss = nn.functional.cross_entropy(preds.float(), target, reduction="sum")
+
+            contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="sum")
             # 7. compute diversity loss: \mathbf{L}_d
             num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
-            diversity_loss = (num_codevectors - codevector_perplexity) / num_codevectors
+            diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()

             # 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d
             loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss
@@ -1326,6 +1391,8 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
             codevector_perplexity=codevector_perplexity,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
+            contrastive_loss=contrastive_loss,
+            diversity_loss=diversity_loss,
         )


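One detail of step 6 above that is easy to miss: unmasked positions get the label -100, which `cross_entropy` ignores by default, while masked positions get label 0 because the true quantized vector always occupies the first row of the logits. A small sketch with made-up shapes:

```python
import torch
import torch.nn as nn

num_negatives, batch_size, sequence_length = 3, 2, 5
# logits over [positive, negative_1, ..., negative_K] for every time step
logits = torch.randn(num_negatives + 1, batch_size, sequence_length)
mask_time_indices = torch.zeros(batch_size, sequence_length, dtype=torch.bool)
mask_time_indices[:, :2] = True  # pretend the first two steps were masked

# masked steps -> class 0 (the positive), unmasked steps -> -100 (ignored by cross_entropy)
target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
preds = logits.transpose(0, 2).reshape(-1, logits.size(0))
contrastive_loss = nn.functional.cross_entropy(preds.float(), target, reduction="sum")
```

Since the contrastive term uses `reduction="sum"`, scaling the diversity term by `mask_time_indices.sum()` presumably keeps the two losses on a comparable scale before they are combined in step 8.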
@@ -586,7 +586,8 @@ class HubertUtilsTest(unittest.TestCase):
         mask_prob = 0.5
         mask_length = 1

-        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
+        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
+        mask = torch.from_numpy(mask).to(torch_device)

         self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])

@@ -596,7 +597,8 @@ class HubertUtilsTest(unittest.TestCase):
         mask_prob = 0.5
         mask_length = 4

-        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
+        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
+        mask = torch.from_numpy(mask).to(torch_device)

         # because of overlap mask don't have to add up exactly to `mask_prob * sequence_length`, but have to be smaller or equal
         for batch_sum in mask.sum(axis=-1):
@@ -40,7 +40,11 @@ if is_torch_available():
         Wav2Vec2Model,
         Wav2Vec2Processor,
     )
-    from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2GumbelVectorQuantizer, _compute_mask_indices
+    from transformers.models.wav2vec2.modeling_wav2vec2 import (
+        Wav2Vec2GumbelVectorQuantizer,
+        _compute_mask_indices,
+        _sample_negative_indices,
+    )


 class Wav2Vec2ModelTester:
@@ -405,6 +409,12 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
                     "masked_spec_embed",
                     "codevectors",
                     "quantizer.weight_proj.weight",
+                    "project_hid.weight",
+                    "project_hid.bias",
+                    "project_q.weight",
+                    "project_q.bias",
+                    "feature_projection.projection.weight",
+                    "feature_projection.projection.bias",
                 ]
                 if param.requires_grad:
                     if any([x in name for x in uniform_init_parms]):
@@ -605,6 +615,12 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
                     "masked_spec_embed",
                     "codevectors",
                     "quantizer.weight_proj.weight",
+                    "project_hid.weight",
+                    "project_hid.bias",
+                    "project_q.weight",
+                    "project_q.bias",
+                    "feature_projection.projection.weight",
+                    "feature_projection.projection.bias",
                 ]
                 if param.requires_grad:
                     if any([x in name for x in uniform_init_parms]):
@@ -640,28 +656,37 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):

         features_shape = (
             inputs_dict["input_values"].shape[0],
-            model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
+            model._get_feat_extract_output_lengths(inputs_dict["input_values"].shape[1]),
         )

         mask_time_indices = _compute_mask_indices(
             features_shape,
             model.config.mask_time_prob,
             model.config.mask_time_length,
-            device=inputs_dict["input_values"].device,
             min_masks=2,
-        ).to(torch_device)
+        )
+        sampled_negative_indices = _sample_negative_indices(features_shape, 10, mask_time_indices)
+
+        mask_time_indices = torch.from_numpy(mask_time_indices).to(torch_device)
+        sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
+
         loss = model(
             inputs_dict["input_values"],
             attention_mask=inputs_dict["attention_mask"],
             mask_time_indices=mask_time_indices,
+            sampled_negative_indices=sampled_negative_indices,
         ).loss

+        # more losses
         mask_time_indices[:, : mask_time_indices.shape[-1] // 2] = True

+        sampled_negative_indices = _sample_negative_indices(features_shape, 10, mask_time_indices.cpu().numpy())
+        sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
         loss_more_masked = model(
             inputs_dict["input_values"],
             attention_mask=inputs_dict["attention_mask"],
             mask_time_indices=mask_time_indices,
+            sampled_negative_indices=sampled_negative_indices,
         ).loss

         # loss_more_masked has to be bigger or equal loss since more masked inputs have to be predicted
@@ -727,7 +752,8 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
         mask_prob = 0.5
         mask_length = 1

-        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
+        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
+        mask = torch.from_numpy(mask).to(torch_device)

         self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])

@@ -737,7 +763,8 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
         mask_prob = 0.5
         mask_length = 4

-        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
+        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
+        mask = torch.from_numpy(mask).to(torch_device)

         # because of overlap mask don't have to add up exactly to `mask_prob * sequence_length`, but have to be smaller or equal
         for batch_sum in mask.sum(axis=-1):
@@ -753,8 +780,9 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
         attention_mask[:2, sequence_length // 2 :] = 0

         mask = _compute_mask_indices(
-            (batch_size, sequence_length), mask_prob, mask_length, device=torch_device, attention_mask=attention_mask
+            (batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
         )
+        mask = torch.from_numpy(mask).to(torch_device)

         for batch_sum in mask.sum(axis=-1):
             self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
@@ -785,8 +813,11 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
         )  # each value in vector consits of same value
         features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()

-        negatives = Wav2Vec2ForPreTraining._sample_negatives(features, num_negatives)
+        # sample negative indices
+        sampled_negative_indices = _sample_negative_indices((batch_size, sequence_length), num_negatives, None)
+        sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
+        negatives = features.view(-1, hidden_size)[sampled_negative_indices.long().view(-1)]
+        negatives = negatives.view(batch_size, sequence_length, -1, hidden_size).permute(2, 0, 1, 3)
         self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))

         # make sure no negatively sampled vector is actually a positive one
@@ -796,15 +827,15 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
         # make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
         self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))

-    def test_sample_negatives_with_attn_mask(self):
+    def test_sample_negatives_with_mask(self):
         batch_size = 2
         sequence_length = 10
         hidden_size = 4
         num_negatives = 3

         # second half of last input tensor is padded
-        attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
-        attention_mask[-1, sequence_length // 2 :] = 0
+        mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
+        mask[-1, sequence_length // 2 :] = 0

         features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
             sequence_length, hidden_size
@@ -812,9 +843,15 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
         features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()

         # replace masked feature vectors with -100 to test that those are not sampled
-        features = torch.where(attention_mask[:, :, None].expand(features.shape).bool(), features, -100)
+        features = torch.where(mask[:, :, None].expand(features.shape).bool(), features, -100)

-        negatives = Wav2Vec2ForPreTraining._sample_negatives(features, num_negatives, attention_mask=attention_mask)
+        # sample negative indices
+        sampled_negative_indices = _sample_negative_indices(
+            (batch_size, sequence_length), num_negatives, mask.cpu().numpy()
+        )
+        sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
+        negatives = features.view(-1, hidden_size)[sampled_negative_indices.long().view(-1)]
+        negatives = negatives.view(batch_size, sequence_length, -1, hidden_size).permute(2, 0, 1, 3)
+
         self.assertTrue((negatives >= 0).all().item())

@@ -924,16 +961,11 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
         ]
         self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)

-    # Wav2Vec2 pretraining seems to be broken. TODO(PVP) - reenable test once pretraining works
-    # correctly
+    @unittest.skipIf(torch_device != "cpu", "cannot make deterministic on GPU")
     def test_inference_integration(self):
-        return

         model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
         model.to(torch_device)
-        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
-            "facebook/wav2vec2-base", return_attention_mask=True
-        )
+        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
         input_speech = self._load_datasamples(2)

         inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)
@@ -943,19 +975,18 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
             model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
         )

-        torch.manual_seed(0)
+        np.random.seed(4)
         mask_time_indices = _compute_mask_indices(
             features_shape,
             model.config.mask_time_prob,
             model.config.mask_time_length,
-            device=inputs_dict["input_values"].device,
             min_masks=2,
-        ).to(torch_device)
+        )
+        mask_time_indices = torch.from_numpy(mask_time_indices).to(torch_device)

         with torch.no_grad():
             outputs = model(
                 inputs_dict.input_values.to(torch_device),
-                attention_mask=inputs_dict.attention_mask.to(torch_device),
                 mask_time_indices=mask_time_indices,
             )

@@ -965,14 +996,16 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
         # retrieve cosine sim of masked features
         cosine_sim_masked = cosine_sim[mask_time_indices]

+        # cosine similarity of model is all > 0.5 as model is
+        # pre-trained on contrastive loss
         # fmt: off
-        expected_cosine_sim_masked = torch.tensor(
-            [0.7458, 0.7188, 0.6418, 0.3729, 0.3741, 0.3694, 0.3110, 0.2257, 0.4403, 0.5415, 0.3950, 0.3701, 0.8831,
-             0.8613, 0.5229, 0.6696, 0.7206, 0.7877, 0.6758, 0.8746, 0.6596, 0.6282, 0.6178, 0.5839, 0.5926, 0.6651,
-             0.4635, 0.6332, 0.6572, 0.8776, 0.4999, 0.7001, 0.7257, 0.5098, 0.6229, 0.4566, 0.5261, 0.6363, 0.5371,
-             0.6997],
-            device=torch_device,
-        )
+        expected_cosine_sim_masked = torch.tensor([
+            0.8523, 0.5860, 0.6905, 0.5557, 0.7456, 0.5249, 0.6639, 0.7654, 0.7565,
+            0.8167, 0.8222, 0.7960, 0.8034, 0.8166, 0.8310, 0.8263, 0.8274, 0.8258,
+            0.8179, 0.8412, 0.8536, 0.5098, 0.4728, 0.6461, 0.4498, 0.6002, 0.5774,
+            0.6457, 0.7123, 0.5668, 0.6866, 0.4960, 0.6293, 0.7423, 0.7419, 0.7526,
+            0.7768, 0.4898, 0.5393, 0.8183
+        ], device=torch_device)
         # fmt: on

         self.assertTrue(torch.allclose(cosine_sim_masked, expected_cosine_sim_masked, atol=1e-3))
@@ -997,9 +1030,9 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
             features_shape,
             model.config.mask_time_prob,
             model.config.mask_time_length,
-            device=inputs_dict["input_values"].device,
             min_masks=2,
-        ).to(torch_device)
+        )
+        mask_time_indices = torch.from_numpy(mask_time_indices).to(torch_device)

         with torch.no_grad():
             outputs = model(
@@ -1064,28 +1097,36 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
         )

         torch.manual_seed(0)
+        np.random.seed(0)

         mask_time_indices = _compute_mask_indices(
             features_shape,
             model.config.mask_time_prob,
             model.config.mask_time_length,
-            device=inputs_dict["input_values"].device,
             min_masks=2,
-        ).to(torch_device)
+        )
+        sampled_negative_indices = _sample_negative_indices(
+            mask_time_indices.shape, model.config.num_negatives, mask_time_indices
+        )
+
+        mask_time_indices = torch.from_numpy(mask_time_indices).to(torch_device)
+        sampled_negative_indices = torch.from_numpy(sampled_negative_indices).to(torch_device)
+
         with torch.no_grad():
             outputs = model(
                 inputs_dict.input_values.to(torch_device),
                 attention_mask=inputs_dict.attention_mask.to(torch_device),
                 mask_time_indices=mask_time_indices,
+                sampled_negative_indices=sampled_negative_indices,
             )

         # check diversity loss
         num_codevectors = model.config.num_codevectors_per_group * model.config.num_codevector_groups
         diversity_loss = (num_codevectors - outputs.codevector_perplexity) / num_codevectors
-        self.assertTrue(abs(diversity_loss.item() - 0.8859) < 1e-3)
+        self.assertTrue(abs(diversity_loss.item() - 0.9538) < 1e-3)

         # check overall loss (contrastive loss + diversity loss)
-        expected_loss = 62.5170
+        expected_loss = 116.7094

         self.assertTrue(abs(outputs.loss.item() - expected_loss) < 1e-3)
