mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Add Descript-Audio-Codec model (#31494)
* dac model * original dac works * add dac model * dac can be instatiated * add forward pass * load weights * all weights are used * convert checkpoint script ready * test * add feature extractor * up * make style * apply cookicutter * fix tests * iterate on FeatureExtractor * nit * update dac doc * replace nn.Sequential with nn.ModuleList * nit * apply review suggestions 1/2 * Update src/transformers/models/dac/modeling_dac.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * up * apply review suggestions 2/2 * update padding in FeatureExtractor * apply review suggestions * iterate on design and tests * add integration tests * feature extractor tests * make style * all tests pass * make style * fixup * apply review suggestions * fix-copies * apply review suggestions * apply review suggestions * Update docs/source/en/model_doc/dac.md Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> * Update docs/source/en/model_doc/dac.md Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> * anticipate transfer weights to descript * up * make style * apply review suggestions * update slow test values * update slow tests * update test values * update with CI values * update with vorace values * update test with slice * make style --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>
This commit is contained in:
parent
843e5e20ca
commit
8260cb311e
@ -696,6 +696,8 @@
|
||||
title: Bark
|
||||
- local: model_doc/clap
|
||||
title: CLAP
|
||||
- local: model_doc/dac
|
||||
title: dac
|
||||
- local: model_doc/encodec
|
||||
title: EnCodec
|
||||
- local: model_doc/hiera
|
||||
|
@ -105,6 +105,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
| [CPM-Ant](model_doc/cpmant) | ✅ | ❌ | ❌ |
|
||||
| [CTRL](model_doc/ctrl) | ✅ | ✅ | ❌ |
|
||||
| [CvT](model_doc/cvt) | ✅ | ✅ | ❌ |
|
||||
| [DAC](model_doc/dac) | ✅ | ❌ | ❌ |
|
||||
| [Data2VecAudio](model_doc/data2vec) | ✅ | ❌ | ❌ |
|
||||
| [Data2VecText](model_doc/data2vec) | ✅ | ❌ | ❌ |
|
||||
| [Data2VecVision](model_doc/data2vec) | ✅ | ✅ | ❌ |
|
||||
|
80
docs/source/en/model_doc/dac.md
Normal file
80
docs/source/en/model_doc/dac.md
Normal file
@ -0,0 +1,80 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# DAC
|
||||
|
||||
## Overview
|
||||
|
||||
|
||||
The DAC model was proposed in [Descript Audio Codec: High-Fidelity Audio Compression with Improved RVQGAN](https://arxiv.org/abs/2306.06546) by Rithesh Kumar, Prem Seetharaman, Alejandro Luebs, Ishaan Kumar, Kundan Kumar.
|
||||
|
||||
The Descript Audio Codec (DAC) model is a powerful tool for compressing audio data, making it highly efficient for storage and transmission. By compressing 44.1 KHz audio into tokens at just 8kbps bandwidth, the DAC model enables high-quality audio processing while significantly reducing the data footprint. This is particularly useful in scenarios where bandwidth is limited or storage space is at a premium, such as in streaming applications, remote conferencing, and archiving large audio datasets.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Language models have been successfully used to model natural signals, such as images, speech, and music. A key component of these models is a high quality neural compression model that can compress high-dimensional natural signals into lower dimensional discrete tokens. To that end, we introduce a high-fidelity universal neural audio compression algorithm that achieves ~90x compression of 44.1 KHz audio into tokens at just 8kbps bandwidth. We achieve this by combining advances in high-fidelity audio generation with better vector quantization techniques from the image domain, along with improved adversarial and reconstruction losses. We compress all domains (speech, environment, music, etc.) with a single universal model, making it widely applicable to generative modeling of all audio. We compare with competing audio compression algorithms, and find our method outperforms them significantly. We provide thorough ablations for every design choice, as well as open-source code and trained model weights. We hope our work can lay the foundation for the next generation of high-fidelity audio modeling.*
|
||||
|
||||
This model was contributed by [Kamil Akesbi](https://huggingface.co/kamilakesbi).
|
||||
The original code can be found [here](https://github.com/descriptinc/descript-audio-codec/tree/main?tab=readme-ov-file).
|
||||
|
||||
|
||||
## Model structure
|
||||
|
||||
The Descript Audio Codec (DAC) model is structured into three distinct stages:
|
||||
|
||||
1. Encoder Model: This stage compresses the input audio, reducing its size while retaining essential information.
|
||||
2. Residual Vector Quantizer (RVQ) Model: Working in tandem with the encoder, this model quantizes the latent codes of the audio, refining the compression and ensuring high-quality reconstruction.
|
||||
3. Decoder Model: This final stage reconstructs the audio from its compressed form, restoring it to a state that closely resembles the original input.
|
||||
|
||||
## Usage example
|
||||
|
||||
Here is a quick example of how to encode and decode an audio using this model:
|
||||
|
||||
```python
|
||||
>>> from datasets import load_dataset, Audio
|
||||
>>> from transformers import DacModel, AutoProcessor
|
||||
>>> librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
|
||||
>>> model = DacModel.from_pretrained("descript/dac_16khz")
|
||||
>>> processor = AutoProcessor.from_pretrained("descript/dac_16khz")
|
||||
>>> librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
|
||||
>>> audio_sample = librispeech_dummy[-1]["audio"]["array"]
|
||||
>>> inputs = processor(raw_audio=audio_sample, sampling_rate=processor.sampling_rate, return_tensors="pt")
|
||||
|
||||
>>> encoder_outputs = model.encode(inputs["input_values"])
|
||||
>>> # Get the intermediate audio codes
|
||||
>>> audio_codes = encoder_outputs.audio_codes
|
||||
>>> # Reconstruct the audio from its quantized representation
|
||||
>>> audio_values = model.decode(encoder_outputs.quantized_representation)
|
||||
>>> # or the equivalent with a forward pass
|
||||
>>> audio_values = model(inputs["input_values"]).audio_values
|
||||
```
|
||||
|
||||
## DacConfig
|
||||
|
||||
[[autodoc]] DacConfig
|
||||
|
||||
## DacFeatureExtractor
|
||||
|
||||
[[autodoc]] DacFeatureExtractor
|
||||
- __call__
|
||||
|
||||
## DacModel
|
||||
|
||||
[[autodoc]] DacModel
|
||||
- decode
|
||||
- encode
|
||||
- forward
|
@ -312,6 +312,7 @@ _import_structure = {
|
||||
"CTRLTokenizer",
|
||||
],
|
||||
"models.cvt": ["CvtConfig"],
|
||||
"models.dac": ["DacConfig", "DacFeatureExtractor"],
|
||||
"models.data2vec": [
|
||||
"Data2VecAudioConfig",
|
||||
"Data2VecTextConfig",
|
||||
@ -1757,6 +1758,12 @@ else:
|
||||
"CvtPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.dac"].extend(
|
||||
[
|
||||
"DacModel",
|
||||
"DacPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.data2vec"].extend(
|
||||
[
|
||||
"Data2VecAudioForAudioFrameClassification",
|
||||
@ -5026,6 +5033,10 @@ if TYPE_CHECKING:
|
||||
CTRLTokenizer,
|
||||
)
|
||||
from .models.cvt import CvtConfig
|
||||
from .models.dac import (
|
||||
DacConfig,
|
||||
DacFeatureExtractor,
|
||||
)
|
||||
from .models.data2vec import (
|
||||
Data2VecAudioConfig,
|
||||
Data2VecTextConfig,
|
||||
@ -6450,6 +6461,10 @@ if TYPE_CHECKING:
|
||||
CvtModel,
|
||||
CvtPreTrainedModel,
|
||||
)
|
||||
from .models.dac import (
|
||||
DacModel,
|
||||
DacPreTrainedModel,
|
||||
)
|
||||
from .models.data2vec import (
|
||||
Data2VecAudioForAudioFrameClassification,
|
||||
Data2VecAudioForCTC,
|
||||
|
@ -59,6 +59,7 @@ from . import (
|
||||
cpmant,
|
||||
ctrl,
|
||||
cvt,
|
||||
dac,
|
||||
data2vec,
|
||||
dbrx,
|
||||
deberta,
|
||||
|
@ -73,6 +73,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
("cpmant", "CpmAntConfig"),
|
||||
("ctrl", "CTRLConfig"),
|
||||
("cvt", "CvtConfig"),
|
||||
("dac", "DacConfig"),
|
||||
("data2vec-audio", "Data2VecAudioConfig"),
|
||||
("data2vec-text", "Data2VecTextConfig"),
|
||||
("data2vec-vision", "Data2VecVisionConfig"),
|
||||
@ -354,6 +355,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("cpmant", "CPM-Ant"),
|
||||
("ctrl", "CTRL"),
|
||||
("cvt", "CvT"),
|
||||
("dac", "DAC"),
|
||||
("data2vec-audio", "Data2VecAudio"),
|
||||
("data2vec-text", "Data2VecText"),
|
||||
("data2vec-vision", "Data2VecVision"),
|
||||
|
@ -49,6 +49,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
|
||||
("conditional_detr", "ConditionalDetrFeatureExtractor"),
|
||||
("convnext", "ConvNextFeatureExtractor"),
|
||||
("cvt", "ConvNextFeatureExtractor"),
|
||||
("dac", "DacFeatureExtractor"),
|
||||
("data2vec-audio", "Wav2Vec2FeatureExtractor"),
|
||||
("data2vec-vision", "BeitFeatureExtractor"),
|
||||
("deformable_detr", "DeformableDetrFeatureExtractor"),
|
||||
|
@ -73,6 +73,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("cpmant", "CpmAntModel"),
|
||||
("ctrl", "CTRLModel"),
|
||||
("cvt", "CvtModel"),
|
||||
("dac", "DacModel"),
|
||||
("data2vec-audio", "Data2VecAudioModel"),
|
||||
("data2vec-text", "Data2VecTextModel"),
|
||||
("data2vec-vision", "Data2VecVisionModel"),
|
||||
|
60
src/transformers/models/dac/__init__.py
Normal file
60
src/transformers/models/dac/__init__.py
Normal file
@ -0,0 +1,60 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
is_torch_available,
|
||||
)
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"configuration_dac": ["DacConfig"],
|
||||
"feature_extraction_dac": ["DacFeatureExtractor"],
|
||||
}
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_dac"] = [
|
||||
"DacModel",
|
||||
"DacPreTrainedModel",
|
||||
]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_dac import (
|
||||
DacConfig,
|
||||
)
|
||||
from .feature_extraction_dac import DacFeatureExtractor
|
||||
|
||||
try:
|
||||
if not is_torch_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_dac import (
|
||||
DacModel,
|
||||
DacPreTrainedModel,
|
||||
)
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
111
src/transformers/models/dac/configuration_dac.py
Normal file
111
src/transformers/models/dac/configuration_dac.py
Normal file
@ -0,0 +1,111 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Dac model configuration"""
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class DacConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of an [`DacModel`]. It is used to instantiate a
|
||||
Dac model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of the
|
||||
[descript/dac_16khz](https://huggingface.co/descript/dac_16khz) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
encoder_hidden_size (`int`, *optional*, defaults to 64):
|
||||
Intermediate representation dimension for the encoder.
|
||||
downsampling_ratios (`List[int]`, *optional*, defaults to `[2, 4, 8, 8]`):
|
||||
Ratios for downsampling in the encoder. These are used in reverse order for upsampling in the decoder.
|
||||
decoder_hidden_size (`int`, *optional*, defaults to 1536):
|
||||
Intermediate representation dimension for the decoder.
|
||||
n_codebooks (`int`, *optional*, defaults to 9):
|
||||
Number of codebooks in the VQVAE.
|
||||
codebook_size (`int`, *optional*, defaults to 1024):
|
||||
Number of discrete codes in each codebook.
|
||||
codebook_dim (`int`, *optional*, defaults to 8):
|
||||
Dimension of the codebook vectors. If not defined, uses `encoder_hidden_size`.
|
||||
quantizer_dropout (`bool`, *optional*, defaults to 0):
|
||||
Whether to apply dropout to the quantizer.
|
||||
commitment_loss_weight (float, *optional*, defaults to 0.25):
|
||||
Weight of the commitment loss term in the VQVAE loss function.
|
||||
codebook_loss_weight (float, *optional*, defaults to 1.0):
|
||||
Weight of the codebook loss term in the VQVAE loss function.
|
||||
sampling_rate (`int`, *optional*, defaults to 16000):
|
||||
The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import DacModel, DacConfig
|
||||
|
||||
>>> # Initializing a "descript/dac_16khz" style configuration
|
||||
>>> configuration = DacConfig()
|
||||
|
||||
>>> # Initializing a model (with random weights) from the "descript/dac_16khz" style configuration
|
||||
>>> model = DacModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "dac"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder_hidden_size=64,
|
||||
downsampling_ratios=[2, 4, 8, 8],
|
||||
decoder_hidden_size=1536,
|
||||
n_codebooks=9,
|
||||
codebook_size=1024,
|
||||
codebook_dim=8,
|
||||
quantizer_dropout=0,
|
||||
commitment_loss_weight=0.25,
|
||||
codebook_loss_weight=1.0,
|
||||
sampling_rate=16000,
|
||||
**kwargs,
|
||||
):
|
||||
self.encoder_hidden_size = encoder_hidden_size
|
||||
self.downsampling_ratios = downsampling_ratios
|
||||
self.decoder_hidden_size = decoder_hidden_size
|
||||
self.upsampling_ratios = downsampling_ratios[::-1]
|
||||
self.n_codebooks = n_codebooks
|
||||
self.codebook_size = codebook_size
|
||||
self.codebook_dim = codebook_dim
|
||||
self.quantizer_dropout = quantizer_dropout
|
||||
self.sampling_rate = sampling_rate
|
||||
|
||||
self.hidden_size = encoder_hidden_size * (2 ** len(downsampling_ratios))
|
||||
|
||||
self.hop_length = int(np.prod(downsampling_ratios))
|
||||
self.commitment_loss_weight = commitment_loss_weight
|
||||
self.codebook_loss_weight = codebook_loss_weight
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def frame_rate(self) -> int:
|
||||
hop_length = np.prod(self.upsampling_ratios)
|
||||
return math.ceil(self.sampling_rate / hop_length)
|
261
src/transformers/models/dac/convert_dac_checkpoint.py
Normal file
261
src/transformers/models/dac/convert_dac_checkpoint.py
Normal file
@ -0,0 +1,261 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import fnmatch
|
||||
import re
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
DacConfig,
|
||||
DacFeatureExtractor,
|
||||
DacModel,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
# checkpoints downloaded using:
|
||||
# pip install descript-audio-codec
|
||||
# python3 -m dac download # downloads the default 44kHz variant
|
||||
# python3 -m dac download --model_type 44khz # downloads the 44kHz variant
|
||||
# python3 -m dac download --model_type 24khz # downloads the 24kHz variant
|
||||
# python3 -m dac download --model_type 16khz # downloads the 16kHz variant
|
||||
# More informations: https://github.com/descriptinc/descript-audio-codec/tree/main
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger("transformers.models.dac")
|
||||
|
||||
|
||||
def match_pattern(string, pattern):
|
||||
# Split the pattern into parts
|
||||
pattern_parts = pattern.split(".")
|
||||
string_parts = string.split(".")
|
||||
|
||||
pattern_block_count = string_block_count = 0
|
||||
|
||||
for part in pattern_parts:
|
||||
if part.startswith("block"):
|
||||
pattern_block_count += 1
|
||||
|
||||
for part in string_parts:
|
||||
if part.startswith("block"):
|
||||
string_block_count += 1
|
||||
|
||||
return fnmatch.fnmatch(string, pattern) and string_block_count == pattern_block_count
|
||||
|
||||
|
||||
TOP_LEVEL_KEYS = []
|
||||
IGNORE_KEYS = []
|
||||
|
||||
|
||||
MAPPING_ENCODER = {
|
||||
"encoder.block.0": ["encoder.conv1"],
|
||||
"encoder.block.5": ["encoder.snake1"],
|
||||
"encoder.block.6": ["encoder.conv2"],
|
||||
"encoder.block.*.block.*.block.0".replace("*", r"\d+"): ["encoder.block", "res_unit", "snake1"],
|
||||
"encoder.block.*.block.*.block.1".replace("*", r"\d+"): ["encoder.block", "res_unit", "conv1"],
|
||||
"encoder.block.*.block.*.block.2".replace("*", r"\d+"): ["encoder.block", "res_unit", "snake2"],
|
||||
"encoder.block.*.block.*.block.3".replace("*", r"\d+"): ["encoder.block", "res_unit", "conv2"],
|
||||
"encoder.block.*.block.3".replace("*", r"\d+"): ["encoder.block", "snake1"],
|
||||
"encoder.block.*.block.4".replace("*", r"\d+"): ["encoder.block", "conv1"],
|
||||
}
|
||||
|
||||
MAPPING_QUANTIZER = {
|
||||
"quantizer.quantizers.*": ["quantizer.quantizers.*"],
|
||||
}
|
||||
|
||||
MAPPING_DECODER = {
|
||||
"decoder.model.0": ["decoder.conv1"],
|
||||
"decoder.model.5": ["decoder.snake1"],
|
||||
"decoder.model.6": ["decoder.conv2"],
|
||||
"decoder.model.*.block.0".replace("*", r"\d+"): ["decoder.block", "snake1"],
|
||||
"decoder.model.*.block.1".replace("*", r"\d+"): ["decoder.block", "conv_t1"],
|
||||
"decoder.model.*.block.*.block.0".replace("*", r"\d+"): ["decoder.block", "res_unit", "snake1"],
|
||||
"decoder.model.*.block.*.block.1".replace("*", r"\d+"): ["decoder.block", "res_unit", "conv1"],
|
||||
"decoder.model.*.block.*.block.2".replace("*", r"\d+"): ["decoder.block", "res_unit", "snake2"],
|
||||
"decoder.model.*.block.*.block.3".replace("*", r"\d+"): ["decoder.block", "res_unit", "conv2"],
|
||||
}
|
||||
|
||||
|
||||
MAPPING = {
|
||||
**MAPPING_ENCODER,
|
||||
**MAPPING_QUANTIZER,
|
||||
**MAPPING_DECODER,
|
||||
}
|
||||
|
||||
|
||||
def set_recursively(hf_pointer, key, value, full_name, weight_type):
|
||||
for attribute in key.split("."):
|
||||
hf_pointer = getattr(hf_pointer, attribute)
|
||||
|
||||
if weight_type is not None:
|
||||
hf_shape = getattr(hf_pointer, weight_type).shape
|
||||
else:
|
||||
hf_shape = hf_pointer.shape
|
||||
|
||||
if hf_shape != value.shape:
|
||||
raise ValueError(
|
||||
f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
|
||||
f" {value.shape} for {full_name}"
|
||||
)
|
||||
|
||||
if weight_type == "weight":
|
||||
hf_pointer.weight.data = value
|
||||
elif weight_type == "weight_g":
|
||||
hf_pointer.weight_g.data = value
|
||||
elif weight_type == "weight_v":
|
||||
hf_pointer.weight_v.data = value
|
||||
elif weight_type == "bias":
|
||||
hf_pointer.bias.data = value
|
||||
elif weight_type == "alpha":
|
||||
hf_pointer.alpha.data = value
|
||||
logger.info(f"{key + ('.' + weight_type if weight_type is not None else '')} was initialized from {full_name}.")
|
||||
|
||||
|
||||
def should_ignore(name, ignore_keys):
|
||||
for key in ignore_keys:
|
||||
if key.endswith(".*"):
|
||||
if name.startswith(key[:-1]):
|
||||
return True
|
||||
elif ".*." in key:
|
||||
prefix, suffix = key.split(".*.")
|
||||
if prefix in name and suffix in name:
|
||||
return True
|
||||
elif key in name:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def recursively_load_weights(orig_dict, hf_model, model_name):
|
||||
unused_weights = []
|
||||
|
||||
if model_name not in ["dac_16khz", "dac_24khz", "dac_44khz"]:
|
||||
raise ValueError(f"Unsupported model: {model_name}")
|
||||
|
||||
for name, value in orig_dict.items():
|
||||
is_used = False
|
||||
for key, mapped_key in MAPPING.items():
|
||||
regex = re.compile(key)
|
||||
if regex.search(name):
|
||||
if len(mapped_key) == 1:
|
||||
if mapped_key[0][0] == "q":
|
||||
mapped_key = ".".join(name.split(".")[:-1])
|
||||
else:
|
||||
mapped_key = mapped_key[0]
|
||||
elif len(mapped_key) == 3:
|
||||
integers = re.findall(r"\b\d+\b", name)
|
||||
if mapped_key[0][0] == "d":
|
||||
mapped_key = "{}.{}.{}{}.{}".format(
|
||||
mapped_key[0],
|
||||
str(int(integers[0]) - 1),
|
||||
mapped_key[1],
|
||||
str(int(integers[1]) - 1),
|
||||
mapped_key[2],
|
||||
)
|
||||
else:
|
||||
mapped_key = "{}.{}.{}{}.{}".format(
|
||||
mapped_key[0],
|
||||
str(int(integers[0]) - 1),
|
||||
mapped_key[1],
|
||||
str(int(integers[1]) + 1),
|
||||
mapped_key[2],
|
||||
)
|
||||
elif len(mapped_key) == 2:
|
||||
integers = re.findall(r"\b\d+\b", name)
|
||||
mapped_key = "{}.{}.{}".format(mapped_key[0], str(int(integers[0]) - 1), mapped_key[1])
|
||||
|
||||
is_used = True
|
||||
if "weight_g" in name:
|
||||
weight_type = "weight_g"
|
||||
elif "weight_v" in name:
|
||||
weight_type = "weight_v"
|
||||
elif "bias" in name:
|
||||
weight_type = "bias"
|
||||
elif "alpha" in name:
|
||||
weight_type = "alpha"
|
||||
elif "weight" in name:
|
||||
weight_type = "weight"
|
||||
set_recursively(hf_model, mapped_key, value, name, weight_type)
|
||||
|
||||
if not is_used:
|
||||
unused_weights.append(name)
|
||||
|
||||
print(list(set(unused_weights)))
|
||||
|
||||
logger.warning(f"Unused weights: {unused_weights}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_checkpoint(
|
||||
model_name,
|
||||
checkpoint_path,
|
||||
pytorch_dump_folder_path,
|
||||
sample_rate=16000,
|
||||
repo_id=None,
|
||||
):
|
||||
model_dict = torch.load(checkpoint_path, "cpu")
|
||||
|
||||
config = DacConfig()
|
||||
|
||||
metadata = model_dict["metadata"]["kwargs"]
|
||||
config.encoder_hidden_size = metadata["encoder_dim"]
|
||||
config.downsampling_ratios = metadata["encoder_rates"]
|
||||
config.codebook_size = metadata["codebook_size"]
|
||||
config.n_codebooks = metadata["n_codebooks"]
|
||||
config.codebook_dim = metadata["codebook_dim"]
|
||||
config.decoder_hidden_size = metadata["decoder_dim"]
|
||||
config.upsampling_ratios = metadata["decoder_rates"]
|
||||
config.quantizer_dropout = float(metadata["quantizer_dropout"])
|
||||
config.sampling_rate = sample_rate
|
||||
|
||||
model = DacModel(config)
|
||||
feature_extractor = DacFeatureExtractor()
|
||||
feature_extractor.sampling_rate = sample_rate
|
||||
|
||||
original_checkpoint = model_dict["state_dict"]
|
||||
|
||||
model.apply_weight_norm()
|
||||
recursively_load_weights(original_checkpoint, model, model_name)
|
||||
model.remove_weight_norm()
|
||||
|
||||
model.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
if repo_id:
|
||||
print("Pushing to the hub...")
|
||||
feature_extractor.push_to_hub(repo_id)
|
||||
model.push_to_hub(repo_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default="dac_44khz",
|
||||
type=str,
|
||||
help="The model to convert. Should be one of 'dac_16khz', 'dac_24khz', 'dac_44khz'.",
|
||||
)
|
||||
parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint")
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", required=True, default=None, type=str, help="Path to the output PyTorch model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub."
|
||||
)
|
||||
parser.add_argument("--sample_rate", default=None, type=str, help="Sample rate used by DacFeatureExtractor")
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_checkpoint(
|
||||
args.model, args.checkpoint_path, args.pytorch_dump_folder_path, args.sample_rate, args.push_to_hub
|
||||
)
|
170
src/transformers/models/dac/feature_extraction_dac.py
Normal file
170
src/transformers/models/dac/feature_extraction_dac.py
Normal file
@ -0,0 +1,170 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Feature extractor class for DAC"""
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...utils import PaddingStrategy, TensorType, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class DacFeatureExtractor(SequenceFeatureExtractor):
|
||||
r"""
|
||||
Constructs an Dac feature extractor.
|
||||
|
||||
This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
|
||||
most of the main methods. Users should refer to this superclass for more information regarding those methods.
|
||||
|
||||
Args:
|
||||
feature_size (`int`, *optional*, defaults to 1):
|
||||
The feature dimension of the extracted features. Use 1 for mono, 2 for stereo.
|
||||
sampling_rate (`int`, *optional*, defaults to 16000):
|
||||
The sampling rate at which the audio waveform should be digitalized, expressed in hertz (Hz).
|
||||
padding_value (`float`, *optional*, defaults to 0.0):
|
||||
The value that is used for padding.
|
||||
hop_length (`int`, *optional*, defaults to 512):
|
||||
Overlap length between successive windows.
|
||||
"""
|
||||
|
||||
model_input_names = ["input_values", "n_quantizers"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
feature_size: int = 1,
|
||||
sampling_rate: int = 16000,
|
||||
padding_value: float = 0.0,
|
||||
hop_length: int = 512,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
|
||||
self.hop_length = hop_length
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
raw_audio: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
|
||||
padding: Optional[Union[bool, str, PaddingStrategy]] = None,
|
||||
truncation: Optional[bool] = False,
|
||||
max_length: Optional[int] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
sampling_rate: Optional[int] = None,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to featurize and prepare for the model one or several sequence(s).
|
||||
|
||||
Args:
|
||||
raw_audio (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
|
||||
The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float
|
||||
values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape
|
||||
`(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio
|
||||
(`feature_size = 2`).
|
||||
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
|
||||
Select a strategy to pad the returned sequences (according to the model's padding side and padding
|
||||
index) among:
|
||||
|
||||
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
|
||||
sequence if provided).
|
||||
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
||||
acceptable input length for the model if that argument is not provided.
|
||||
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
|
||||
lengths).
|
||||
truncation (`bool`, *optional*, defaults to `False`):
|
||||
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
|
||||
max_length (`int`, *optional*):
|
||||
Maximum length of the returned list and optionally padding length (see above).
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*, default to 'pt'):
|
||||
If set, will return tensors instead of list of python integers. Acceptable values are:
|
||||
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return Numpy `np.ndarray` objects.
|
||||
sampling_rate (`int`, *optional*):
|
||||
The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass
|
||||
`sampling_rate` at the forward call to prevent silent errors.
|
||||
"""
|
||||
if sampling_rate is not None:
|
||||
if sampling_rate != self.sampling_rate:
|
||||
raise ValueError(
|
||||
f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
|
||||
f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with"
|
||||
f" {self.sampling_rate} and not {sampling_rate}."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"It is strongly recommended to pass the `sampling_rate` argument to this function. "
|
||||
"Failing to do so can result in silent errors that might be hard to debug."
|
||||
)
|
||||
|
||||
if padding and truncation:
|
||||
raise ValueError("Both padding and truncation were set. Make sure you only set one.")
|
||||
elif padding is None:
|
||||
# by default let's pad the inputs
|
||||
padding = True
|
||||
|
||||
is_batched = bool(
|
||||
isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list)))
|
||||
)
|
||||
|
||||
if is_batched:
|
||||
raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio]
|
||||
elif not is_batched and not isinstance(raw_audio, np.ndarray):
|
||||
raw_audio = np.asarray(raw_audio, dtype=np.float32)
|
||||
elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64):
|
||||
raw_audio = raw_audio.astype(np.float32)
|
||||
|
||||
# always return batch
|
||||
if not is_batched:
|
||||
raw_audio = [np.asarray(raw_audio).T]
|
||||
|
||||
# verify inputs are valid
|
||||
for idx, example in enumerate(raw_audio):
|
||||
if example.ndim > 2:
|
||||
raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}")
|
||||
if self.feature_size == 1 and example.ndim != 1:
|
||||
raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels")
|
||||
if self.feature_size == 2:
|
||||
raise ValueError("Stereo audio isn't supported for now")
|
||||
|
||||
input_values = BatchFeature({"input_values": raw_audio})
|
||||
|
||||
# normal padding on batch
|
||||
padded_inputs = self.pad(
|
||||
input_values,
|
||||
max_length=max_length,
|
||||
truncation=truncation,
|
||||
padding=padding,
|
||||
return_attention_mask=False,
|
||||
pad_to_multiple_of=self.hop_length,
|
||||
)
|
||||
|
||||
if padding:
|
||||
padded_inputs.input_values = padded_inputs.input_values[:, np.newaxis, :]
|
||||
|
||||
input_values = []
|
||||
for example in padded_inputs.pop("input_values"):
|
||||
if self.feature_size == 1:
|
||||
example = example[..., None]
|
||||
input_values.append(example.T)
|
||||
|
||||
padded_inputs["input_values"] = input_values
|
||||
if return_tensors is not None:
|
||||
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
|
||||
|
||||
return padded_inputs
|
717
src/transformers/models/dac/modeling_dac.py
Normal file
717
src/transformers/models/dac/modeling_dac.py
Normal file
@ -0,0 +1,717 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Transformers DAC model."""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .configuration_dac import DacConfig
|
||||
|
||||
|
||||
# General docstring
|
||||
_CONFIG_FOR_DOC = "DacConfig"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DacOutput(ModelOutput):
|
||||
"""
|
||||
Args:
|
||||
loss (`torch.Tensor`):
|
||||
Loss from the encoder model, comprising the weighted combination of the commitment and codebook losses.
|
||||
audio_values (`torch.Tensor` of shape `(batch_size, input_length)`):
|
||||
Reconstructed audio data.
|
||||
quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
|
||||
Quantized continuous representation of input.
|
||||
audio_codes (`torch.LongTensor` of shape `(batch_size, num_codebooks, time_steps)`):
|
||||
Codebook indices for each codebook (quantized discrete representation of input).
|
||||
projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`):
|
||||
Projected latents (continuous representation of input before quantization).
|
||||
"""
|
||||
|
||||
loss: torch.FloatTensor = None
|
||||
audio_values: torch.FloatTensor = None
|
||||
quantized_representation: torch.FloatTensor = None
|
||||
audio_codes: torch.LongTensor = None
|
||||
projected_latents: torch.FloatTensor = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DacEncoderOutput(ModelOutput):
|
||||
"""
|
||||
Args:
|
||||
loss (`torch.Tensor`):
|
||||
Loss from the encoder model, comprising the weighted combination of the commitment and codebook losses.
|
||||
quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`, *optional*):
|
||||
Quantized continuous representation of input.
|
||||
audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`, *optional*):
|
||||
Codebook indices for each codebook (quantized discrete representation of input).
|
||||
projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`, *optional*):
|
||||
Projected latents (continuous representation of input before quantization).
|
||||
"""
|
||||
|
||||
loss: torch.FloatTensor = None
|
||||
quantized_representation: torch.FloatTensor = None
|
||||
audio_codes: torch.FloatTensor = None
|
||||
projected_latents: torch.FloatTensor = None
|
||||
|
||||
|
||||
@dataclass
|
||||
# Copied from transformers.models.encodec.modeling_encodec.EncodecDecoderOutput with Encodec->Dac, segment_length->input_length
|
||||
class DacDecoderOutput(ModelOutput):
|
||||
"""
|
||||
Args:
|
||||
audio_values (`torch.FloatTensor` of shape `(batch_size, input_length)`, *optional*):
|
||||
Decoded audio values, obtained using the decoder part of Dac.
|
||||
"""
|
||||
|
||||
audio_values: torch.FloatTensor = None
|
||||
|
||||
|
||||
class Snake1d(nn.Module):
|
||||
"""
|
||||
A 1-dimensional Snake activation function module.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_dim):
|
||||
super().__init__()
|
||||
self.alpha = nn.Parameter(torch.ones(1, hidden_dim, 1))
|
||||
|
||||
def forward(self, hidden_states):
|
||||
shape = hidden_states.shape
|
||||
hidden_states = hidden_states.reshape(shape[0], shape[1], -1)
|
||||
hidden_states = hidden_states + (self.alpha + 1e-9).reciprocal() * torch.sin(self.alpha * hidden_states).pow(2)
|
||||
hidden_states = hidden_states.reshape(shape)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DacVectorQuantize(nn.Module):
|
||||
"""
|
||||
Implementation of VQ similar to Karpathy's repo (https://github.com/karpathy/deep-vector-quantization)
|
||||
|
||||
Additionally uses following tricks from improved VQGAN
|
||||
(https://arxiv.org/pdf/2110.04627.pdf):
|
||||
1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
|
||||
for improved codebook usage
|
||||
2. l2-normalized codes: Converts euclidean distance to cosine similarity which
|
||||
improves training stability
|
||||
"""
|
||||
|
||||
def __init__(self, config: DacConfig):
|
||||
super().__init__()
|
||||
|
||||
self.in_proj = nn.Conv1d(config.hidden_size, config.codebook_dim, kernel_size=1)
|
||||
self.out_proj = nn.Conv1d(config.codebook_dim, config.hidden_size, kernel_size=1)
|
||||
self.codebook = nn.Embedding(config.codebook_size, config.codebook_dim)
|
||||
|
||||
def forward(self, hidden_state):
|
||||
"""
|
||||
Quantizes the input tensor using a fixed codebook and returns the corresponding codebook vectors.
|
||||
|
||||
Args:
|
||||
hidden_state (`torch.FloatTensor` of shape `(batch_size, dimension, time_steps)`):
|
||||
Input tensor.
|
||||
|
||||
Returns:
|
||||
quantized_representation (`torch.Tensor`of shape `(batch_size, dimension, time_steps)`):
|
||||
Quantized continuous representation of input.
|
||||
commitment_loss (`torch.FloatTensor`of shape `(1)`):
|
||||
Commitment loss to train encoder to predict vectors closer to codebook entries.
|
||||
codebook_loss (`torch.FloatTensor`of shape `(1)`):
|
||||
Codebook loss to update the codebook.
|
||||
audio_codes (`torch.LongTensor` of shape `(batch_size, time_steps)`):
|
||||
Codebook indices for each codebook, quantized discrete representation of input.
|
||||
projected_latents (torch.FloatTensor of shape `(batch_size, num_codebooks * dimension, time_steps)`):
|
||||
Projected latents (continuous representation of input before quantization).
|
||||
"""
|
||||
|
||||
projected_latents = self.in_proj(hidden_state)
|
||||
quantized_representation, audio_codes = self.decode_latents(projected_latents)
|
||||
|
||||
commitment_loss = F.mse_loss(projected_latents, quantized_representation.detach(), reduction="mean")
|
||||
codebook_loss = F.mse_loss(quantized_representation, projected_latents.detach(), reduction="mean")
|
||||
# noop in forward pass, straight-through gradient estimator in backward pass
|
||||
quantized_representation = projected_latents + (quantized_representation - projected_latents).detach()
|
||||
quantized_representation = self.out_proj(quantized_representation)
|
||||
|
||||
return quantized_representation, commitment_loss, codebook_loss, audio_codes, projected_latents
|
||||
|
||||
def decode_latents(self, hidden_states):
|
||||
batch_size, hidden_dim, sequence_length = hidden_states.shape
|
||||
encodings = hidden_states.permute(0, 2, 1).reshape(batch_size * sequence_length, hidden_dim)
|
||||
codebook = self.codebook.weight # codebook: (N x D)
|
||||
|
||||
# L2 normalize encodings and codebook (ViT-VQGAN)
|
||||
encodings = F.normalize(encodings)
|
||||
codebook = F.normalize(codebook)
|
||||
|
||||
# Compute euclidean distance with codebook
|
||||
l2_norm = encodings.pow(2).sum(1, keepdim=True)
|
||||
dist = -(l2_norm - 2 * encodings @ codebook.t()) + codebook.pow(2).sum(1, keepdim=True).t()
|
||||
|
||||
indices = dist.max(1)[1]
|
||||
indices = indices.reshape(hidden_states.size(0), -1)
|
||||
quantized_representation = self.codebook(indices).transpose(1, 2)
|
||||
return quantized_representation, indices
|
||||
|
||||
|
||||
class DacResidualUnit(nn.Module):
|
||||
"""
|
||||
A residual unit composed of Snake1d and weight-normalized Conv1d layers with dilations.
|
||||
"""
|
||||
|
||||
def __init__(self, dimension: int = 16, dilation: int = 1):
|
||||
super().__init__()
|
||||
pad = ((7 - 1) * dilation) // 2
|
||||
|
||||
self.snake1 = Snake1d(dimension)
|
||||
self.conv1 = nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad)
|
||||
self.snake2 = Snake1d(dimension)
|
||||
self.conv2 = nn.Conv1d(dimension, dimension, kernel_size=1)
|
||||
|
||||
def forward(self, hidden_state):
|
||||
"""
|
||||
Forward pass through the residual unit.
|
||||
|
||||
Args:
|
||||
hidden_state (`torch.Tensor` of shape `(batch_size, channels, time_steps)`):
|
||||
Input tensor .
|
||||
|
||||
Returns:
|
||||
output_tensor (`torch.Tensor` of shape `(batch_size, channels, time_steps)`):
|
||||
Input tensor after passing through the residual unit.
|
||||
"""
|
||||
output_tensor = hidden_state
|
||||
output_tensor = self.conv1(self.snake1(output_tensor))
|
||||
output_tensor = self.conv2(self.snake2(output_tensor))
|
||||
|
||||
padding = (hidden_state.shape[-1] - output_tensor.shape[-1]) // 2
|
||||
if padding > 0:
|
||||
hidden_state = hidden_state[..., padding:-padding]
|
||||
output_tensor = hidden_state + output_tensor
|
||||
return output_tensor
|
||||
|
||||
|
||||
class DacEncoderBlock(nn.Module):
|
||||
"""Encoder block used in DAC encoder."""
|
||||
|
||||
def __init__(self, config: DacConfig, stride: int = 1, stride_index: int = 1):
|
||||
super().__init__()
|
||||
|
||||
dimension = config.encoder_hidden_size * 2**stride_index
|
||||
self.res_unit1 = DacResidualUnit(dimension // 2, dilation=1)
|
||||
self.res_unit2 = DacResidualUnit(dimension // 2, dilation=3)
|
||||
self.res_unit3 = DacResidualUnit(dimension // 2, dilation=9)
|
||||
self.snake1 = Snake1d(dimension // 2)
|
||||
self.conv1 = nn.Conv1d(
|
||||
dimension // 2, dimension, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)
|
||||
)
|
||||
|
||||
def forward(self, hidden_state):
|
||||
hidden_state = self.res_unit1(hidden_state)
|
||||
hidden_state = self.res_unit2(hidden_state)
|
||||
hidden_state = self.snake1(self.res_unit3(hidden_state))
|
||||
hidden_state = self.conv1(hidden_state)
|
||||
|
||||
return hidden_state
|
||||
|
||||
|
||||
class DacDecoderBlock(nn.Module):
|
||||
"""Decoder block used in DAC decoder."""
|
||||
|
||||
def __init__(self, config: DacConfig, stride: int = 1, stride_index: int = 1):
|
||||
super().__init__()
|
||||
|
||||
input_dim = config.decoder_hidden_size // 2**stride_index
|
||||
output_dim = config.decoder_hidden_size // 2 ** (stride_index + 1)
|
||||
self.snake1 = Snake1d(input_dim)
|
||||
self.conv_t1 = nn.ConvTranspose1d(
|
||||
input_dim,
|
||||
output_dim,
|
||||
kernel_size=2 * stride,
|
||||
stride=stride,
|
||||
padding=math.ceil(stride / 2),
|
||||
)
|
||||
|
||||
self.res_unit1 = DacResidualUnit(output_dim, dilation=1)
|
||||
self.res_unit2 = DacResidualUnit(output_dim, dilation=3)
|
||||
self.res_unit3 = DacResidualUnit(output_dim, dilation=9)
|
||||
|
||||
def forward(self, hidden_state):
|
||||
hidden_state = self.snake1(hidden_state)
|
||||
hidden_state = self.conv_t1(hidden_state)
|
||||
hidden_state = self.res_unit1(hidden_state)
|
||||
hidden_state = self.res_unit2(hidden_state)
|
||||
hidden_state = self.res_unit3(hidden_state)
|
||||
|
||||
return hidden_state
|
||||
|
||||
|
||||
class DacResidualVectorQuantize(nn.Module):
|
||||
"""
|
||||
ResidualVectorQuantize block - Introduced in SoundStream: An end2end neural audio codec (https://arxiv.org/abs/2107.03312)
|
||||
"""
|
||||
|
||||
def __init__(self, config: DacConfig):
|
||||
super().__init__()
|
||||
|
||||
n_codebooks = config.n_codebooks
|
||||
quantizer_dropout = config.quantizer_dropout
|
||||
|
||||
self.n_codebooks = n_codebooks
|
||||
|
||||
self.quantizers = nn.ModuleList([DacVectorQuantize(config) for i in range(config.n_codebooks)])
|
||||
self.quantizer_dropout = quantizer_dropout
|
||||
|
||||
def forward(self, hidden_state, n_quantizers: int = None):
|
||||
"""
|
||||
Quantizes the input tensor using a fixed set of codebooks and returns corresponding codebook vectors.
|
||||
Args:
|
||||
hidden_state (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
|
||||
Input tensor to be quantized.
|
||||
n_quantizers (`int`, *optional*):
|
||||
Number of quantizers to use. If specified and `self.quantizer_dropout` is True,
|
||||
this argument is ignored during training, and a random number of quantizers is used.
|
||||
|
||||
Returns:
|
||||
quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
|
||||
Quantized continuous representation of input.
|
||||
audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`):
|
||||
Codebook indices for each codebook (quantized discrete representation of input).
|
||||
projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`):
|
||||
Projected latents (continuous representation of input before quantization).
|
||||
commitment_loss (`torch.Tensor` of shape `(1)`):
|
||||
Commitment loss to train the encoder to predict vectors closer to codebook entries.
|
||||
codebook_loss (`torch.Tensor` of shape `(1)`):
|
||||
Codebook loss to update the codebook.
|
||||
"""
|
||||
|
||||
quantized_representation = 0
|
||||
residual = hidden_state
|
||||
commitment_loss = 0
|
||||
codebook_loss = 0
|
||||
|
||||
audio_codes = []
|
||||
projected_latents = []
|
||||
|
||||
n_quantizers = n_quantizers if n_quantizers is not None else self.n_codebooks
|
||||
if self.training:
|
||||
n_quantizers = torch.ones((hidden_state.shape[0],)) * self.n_codebooks + 1
|
||||
dropout = torch.randint(1, self.n_codebooks + 1, (hidden_state.shape[0],))
|
||||
n_dropout = int(hidden_state.shape[0] * self.quantizer_dropout)
|
||||
n_quantizers[:n_dropout] = dropout[:n_dropout]
|
||||
n_quantizers = n_quantizers.to(hidden_state.device)
|
||||
|
||||
for i, quantizer in enumerate(self.quantizers):
|
||||
if self.training is False and i >= n_quantizers:
|
||||
break
|
||||
|
||||
quantized_representation_i, commitment_loss_i, codebook_loss_i, indices_i, projected_latents_i = quantizer(
|
||||
residual
|
||||
)
|
||||
|
||||
# Create mask to apply quantizer dropout
|
||||
mask = torch.full((hidden_state.shape[0],), fill_value=i, device=hidden_state.device) < n_quantizers
|
||||
quantized_representation = quantized_representation + quantized_representation_i * mask[:, None, None]
|
||||
residual = residual - quantized_representation_i
|
||||
|
||||
# Sum losses
|
||||
commitment_loss += commitment_loss_i * mask
|
||||
codebook_loss += codebook_loss_i * mask
|
||||
|
||||
audio_codes.append(indices_i)
|
||||
projected_latents.append(projected_latents_i)
|
||||
|
||||
audio_codes = torch.stack(audio_codes, dim=1)
|
||||
projected_latents = torch.cat(projected_latents, dim=1)
|
||||
|
||||
return quantized_representation, audio_codes, projected_latents, commitment_loss, codebook_loss
|
||||
|
||||
def from_codes(self, audio_codes: torch.Tensor):
|
||||
"""
|
||||
Reconstructs the continuous representation from quantized codes.
|
||||
|
||||
Args:
|
||||
audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`):
|
||||
Quantized discrete representation of input.
|
||||
|
||||
Returns:
|
||||
quantized_representation (`torch.Tensor`):
|
||||
Quantized continuous representation of input.
|
||||
projected_latents (`torch.Tensor`):
|
||||
List of projected latents (continuous representations of input before quantization)
|
||||
for each codebook.
|
||||
audio_codes (`torch.Tensor`):
|
||||
Codebook indices for each codebook.
|
||||
"""
|
||||
quantized_representation = 0.0
|
||||
projected_latents = []
|
||||
n_codebooks = audio_codes.shape[1]
|
||||
for i in range(n_codebooks):
|
||||
projected_latents_i = self.quantizers[i].codebook(audio_codes[:, i, :]).transpose(1, 2)
|
||||
projected_latents.append(projected_latents_i)
|
||||
quantized_representation += self.quantizers[i].out_proj(projected_latents_i)
|
||||
return quantized_representation, torch.cat(projected_latents, dim=1), audio_codes
|
||||
|
||||
def from_latents(self, latents: torch.Tensor):
|
||||
"""Reconstructs the quantized representation from unquantized latents.
|
||||
|
||||
Args:
|
||||
latents (`torch.Tensor` of shape `(batch_size, total_latent_dimension, time_steps)`):
|
||||
Continuous representation of input after projection.
|
||||
|
||||
Returns:
|
||||
quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
|
||||
Quantized representation of the full-projected space.
|
||||
quantized_latents (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
|
||||
Quantized representation of the latent space (continuous representation before quantization).
|
||||
"""
|
||||
quantized_representation = 0
|
||||
quantized_latents = []
|
||||
codes = []
|
||||
codebook_dims_tensor = torch.tensor([0] + [q.codebook_dim for q in self.quantizers])
|
||||
dims = torch.cumsum(codebook_dims_tensor, dim=0)
|
||||
|
||||
n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
|
||||
for i in range(n_codebooks):
|
||||
hidden_dim_j, hidden_dim_k = dims[i], dims[i + 1]
|
||||
quantized_latents_i, codes_i = self.quantizers[i].decode_latents(latents[:, hidden_dim_j:hidden_dim_k, :])
|
||||
quantized_latents.append(quantized_latents_i)
|
||||
codes.append(codes_i)
|
||||
|
||||
quantized_representation_i = self.quantizers[i].out_proj(quantized_latents_i)
|
||||
quantized_representation = quantized_representation + quantized_representation_i
|
||||
|
||||
return quantized_representation, torch.cat(quantized_latents, dim=1)
|
||||
|
||||
|
||||
class DacDecoder(nn.Module):
|
||||
"""DAC Decoder"""
|
||||
|
||||
def __init__(self, config: DacConfig):
|
||||
super().__init__()
|
||||
|
||||
input_channel = config.hidden_size
|
||||
channels = config.decoder_hidden_size
|
||||
strides = config.upsampling_ratios
|
||||
|
||||
# Add first conv layer
|
||||
self.conv1 = nn.Conv1d(input_channel, channels, kernel_size=7, padding=3)
|
||||
|
||||
# Add upsampling + MRF blocks
|
||||
block = []
|
||||
for stride_index, stride in enumerate(strides):
|
||||
block += [DacDecoderBlock(config, stride, stride_index)]
|
||||
|
||||
self.block = nn.ModuleList(block)
|
||||
output_dim = config.decoder_hidden_size // 2 ** (stride_index + 1)
|
||||
self.snake1 = Snake1d(output_dim)
|
||||
self.conv2 = nn.Conv1d(output_dim, 1, kernel_size=7, padding=3)
|
||||
self.tanh = nn.Tanh()
|
||||
|
||||
def forward(self, hidden_state):
|
||||
hidden_state = self.conv1(hidden_state)
|
||||
|
||||
for layer in self.block:
|
||||
hidden_state = layer(hidden_state)
|
||||
|
||||
hidden_state = self.snake1(hidden_state)
|
||||
hidden_state = self.conv2(hidden_state)
|
||||
hidden_state = self.tanh(hidden_state)
|
||||
|
||||
return hidden_state
|
||||
|
||||
|
||||
class DacEncoder(nn.Module):
|
||||
"""DAC Encoder"""
|
||||
|
||||
def __init__(self, config: DacConfig):
|
||||
super().__init__()
|
||||
|
||||
strides = config.downsampling_ratios
|
||||
# Create first convolution
|
||||
self.conv1 = nn.Conv1d(1, config.encoder_hidden_size, kernel_size=7, padding=3)
|
||||
|
||||
self.block = []
|
||||
# Create EncoderBlocks that double channels as they downsample by `stride`
|
||||
for stride_index, stride in enumerate(strides):
|
||||
stride_index = stride_index + 1
|
||||
self.block += [DacEncoderBlock(config, stride=stride, stride_index=stride_index)]
|
||||
|
||||
self.block = nn.ModuleList(self.block)
|
||||
d_model = config.encoder_hidden_size * 2**stride_index
|
||||
self.snake1 = Snake1d(d_model)
|
||||
self.conv2 = nn.Conv1d(d_model, config.hidden_size, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, hidden_state):
|
||||
hidden_state = self.conv1(hidden_state)
|
||||
|
||||
for module in self.block:
|
||||
hidden_state = module(hidden_state)
|
||||
|
||||
hidden_state = self.snake1(hidden_state)
|
||||
hidden_state = self.conv2(hidden_state)
|
||||
|
||||
return hidden_state
|
||||
|
||||
|
||||
class DacPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
|
||||
"""
|
||||
|
||||
config_class = DacConfig
|
||||
base_model_prefix = "dac"
|
||||
main_input_name = "input_values"
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, nn.Conv1d):
|
||||
nn.init.trunc_normal_(module.weight, std=0.02)
|
||||
nn.init.constant_(module.bias, 0)
|
||||
|
||||
def apply_weight_norm(self):
|
||||
for layer in self.quantizer.quantizers:
|
||||
nn.utils.weight_norm(layer.in_proj)
|
||||
nn.utils.weight_norm(layer.out_proj)
|
||||
|
||||
nn.utils.weight_norm(self.encoder.conv1)
|
||||
nn.utils.weight_norm(self.encoder.conv2)
|
||||
|
||||
for layer in self.encoder.block:
|
||||
nn.utils.weight_norm(layer.conv1)
|
||||
nn.utils.weight_norm(layer.res_unit1.conv1)
|
||||
nn.utils.weight_norm(layer.res_unit1.conv2)
|
||||
nn.utils.weight_norm(layer.res_unit2.conv1)
|
||||
nn.utils.weight_norm(layer.res_unit2.conv2)
|
||||
nn.utils.weight_norm(layer.res_unit3.conv1)
|
||||
nn.utils.weight_norm(layer.res_unit3.conv2)
|
||||
|
||||
nn.utils.weight_norm(self.decoder.conv1)
|
||||
nn.utils.weight_norm(self.decoder.conv2)
|
||||
|
||||
for layer in self.decoder.block:
|
||||
nn.utils.weight_norm(layer.conv_t1)
|
||||
nn.utils.weight_norm(layer.res_unit1.conv1)
|
||||
nn.utils.weight_norm(layer.res_unit1.conv2)
|
||||
nn.utils.weight_norm(layer.res_unit2.conv1)
|
||||
nn.utils.weight_norm(layer.res_unit2.conv2)
|
||||
nn.utils.weight_norm(layer.res_unit3.conv1)
|
||||
nn.utils.weight_norm(layer.res_unit3.conv2)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for layer in self.quantizer.quantizers:
|
||||
nn.utils.remove_weight_norm(layer.in_proj)
|
||||
nn.utils.remove_weight_norm(layer.out_proj)
|
||||
|
||||
nn.utils.remove_weight_norm(self.encoder.conv1)
|
||||
nn.utils.remove_weight_norm(self.encoder.conv2)
|
||||
|
||||
for layer in self.encoder.block:
|
||||
nn.utils.remove_weight_norm(layer.conv1)
|
||||
nn.utils.remove_weight_norm(layer.res_unit1.conv1)
|
||||
nn.utils.remove_weight_norm(layer.res_unit1.conv2)
|
||||
nn.utils.remove_weight_norm(layer.res_unit2.conv1)
|
||||
nn.utils.remove_weight_norm(layer.res_unit2.conv2)
|
||||
nn.utils.remove_weight_norm(layer.res_unit3.conv1)
|
||||
nn.utils.remove_weight_norm(layer.res_unit3.conv2)
|
||||
|
||||
nn.utils.remove_weight_norm(self.decoder.conv1)
|
||||
nn.utils.remove_weight_norm(self.decoder.conv2)
|
||||
|
||||
for layer in self.decoder.block:
|
||||
nn.utils.remove_weight_norm(layer.conv_t1)
|
||||
nn.utils.remove_weight_norm(layer.res_unit1.conv1)
|
||||
nn.utils.remove_weight_norm(layer.res_unit1.conv2)
|
||||
nn.utils.remove_weight_norm(layer.res_unit2.conv1)
|
||||
nn.utils.remove_weight_norm(layer.res_unit2.conv2)
|
||||
nn.utils.remove_weight_norm(layer.res_unit3.conv1)
|
||||
nn.utils.remove_weight_norm(layer.res_unit3.conv2)
|
||||
|
||||
|
||||
DAC_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`DacConfig`]):
|
||||
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
||||
load the weights associated with the model, only the configuration. Check out the
|
||||
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
DAC_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_values (`torch.Tensor` of shape `(batch_size, 1, time_steps)`).
|
||||
Audio data to encode,
|
||||
n_quantizers (`int`, *optional*):
|
||||
Number of quantizers to use. If `None`, all quantizers are used. Default is `None`.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The DAC (Descript Audio Codec) model.",
|
||||
DAC_START_DOCSTRING,
|
||||
)
|
||||
class DacModel(DacPreTrainedModel):
|
||||
def __init__(self, config: DacConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
self.encoder = DacEncoder(config)
|
||||
self.decoder = DacDecoder(config)
|
||||
|
||||
self.quantizer = DacResidualVectorQuantize(config)
|
||||
|
||||
self.bits_per_codebook = int(math.log2(self.config.codebook_size))
|
||||
if 2**self.bits_per_codebook != self.config.codebook_size:
|
||||
raise ValueError("The codebook_size must be a power of 2.")
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@replace_return_docstrings(output_type=DacEncoderOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def encode(
|
||||
self,
|
||||
input_values: torch.Tensor,
|
||||
n_quantizers: int = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
"""
|
||||
Encode given audio data and return quantized latent codes
|
||||
|
||||
Args:
|
||||
input_values (`torch.Tensor of shape `(batch_size, 1, time_steps)`):
|
||||
Input audio data to encode,
|
||||
n_quantizers (int, *optional*):
|
||||
Number of quantizers to use. If None, all quantizers are used. Default is None.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
Returns:
|
||||
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||
|
||||
quantized_representation = self.encoder(input_values)
|
||||
quantized_representation, audio_codes, projected_latents, commitment_loss, codebook_loss = self.quantizer(
|
||||
quantized_representation, n_quantizers
|
||||
)
|
||||
|
||||
loss = self.config.commitment_loss_weight * commitment_loss + self.config.codebook_loss_weight * codebook_loss
|
||||
|
||||
if not return_dict:
|
||||
return (loss, quantized_representation, audio_codes, projected_latents)
|
||||
|
||||
return DacEncoderOutput(loss, quantized_representation, audio_codes, projected_latents)
|
||||
|
||||
@replace_return_docstrings(output_type=DacDecoderOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def decode(
|
||||
self,
|
||||
quantized_representation: Optional[torch.Tensor],
|
||||
audio_codes: Optional[torch.Tensor] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
"""Decode given latent codes and return audio data
|
||||
|
||||
Args:
|
||||
quantized_representation (torch.Tensor of shape `(batch_size, dimension, time_steps)`):
|
||||
Quantized continuous representation of input.
|
||||
audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`, *optional*):
|
||||
The codebook indices for each codebook, representing the quantized discrete
|
||||
representation of the input. This parameter should be provided if you want
|
||||
to decode directly from the audio codes (it will overwrite quantized_representation).
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
|
||||
if quantized_representation is None and audio_codes is None:
|
||||
raise ValueError("Either `quantized_representation` or `audio_codes` must be provided.")
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||
|
||||
if audio_codes is not None:
|
||||
quantized_representation = self.quantizer.from_codes(audio_codes)[0]
|
||||
|
||||
audio_values = self.decoder(quantized_representation).squeeze(1)
|
||||
|
||||
if not return_dict:
|
||||
return (audio_values,)
|
||||
|
||||
return DacDecoderOutput(audio_values)
|
||||
|
||||
@add_start_docstrings_to_model_forward(DAC_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=DacOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_values: torch.Tensor,
|
||||
n_quantizers: int = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
"""
|
||||
Returns:
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from datasets import load_dataset, Audio
|
||||
>>> from transformers import DacModel, AutoProcessor
|
||||
>>> librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
|
||||
>>> model = DacModel.from_pretrained("descript/dac_16khz")
|
||||
>>> processor = AutoProcessor.from_pretrained("descript/dac_16khz")
|
||||
>>> librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
|
||||
>>> audio_sample = librispeech_dummy[-1]["audio"]["array"]
|
||||
>>> inputs = processor(raw_audio=audio_sample, sampling_rate=processor.sampling_rate, return_tensors="pt")
|
||||
|
||||
>>> encoder_outputs = model.encode(inputs["input_values"])
|
||||
>>> # Get the intermediate audio codes
|
||||
>>> audio_codes = encoder_outputs.audio_codes
|
||||
>>> # Reconstruct the audio from its quantized representation
|
||||
>>> audio_values = model.decode(encoder_outputs.quantized_representation)
|
||||
>>> # or the equivalent with a forward pass
|
||||
>>> audio_values = model(inputs["input_values"]).audio_values
|
||||
```"""
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
||||
length = input_values.shape[-1]
|
||||
loss, quantized_representation, audio_codes, projected_latents = self.encode(
|
||||
input_values, n_quantizers, return_dict=False
|
||||
)
|
||||
audio_values = self.decode(quantized_representation, return_dict=False)[0][..., :length]
|
||||
|
||||
if not return_dict:
|
||||
return (loss, audio_values, quantized_representation, audio_codes, projected_latents)
|
||||
|
||||
return DacOutput(loss, audio_values, quantized_representation, audio_codes, projected_latents)
|
@ -2360,6 +2360,20 @@ class CvtPreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class DacModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class DacPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Data2VecAudioForAudioFrameClassification(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
0
tests/models/dac/__init__.py
Normal file
0
tests/models/dac/__init__.py
Normal file
216
tests/models/dac/test_feature_extraction_dac.py
Normal file
216
tests/models/dac/test_feature_extraction_dac.py
Normal file
@ -0,0 +1,216 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for the dac feature extractor."""
|
||||
|
||||
import itertools
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import DacFeatureExtractor
|
||||
from transformers.testing_utils import require_torch
|
||||
from transformers.utils.import_utils import is_torch_available
|
||||
|
||||
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
|
||||
# Copied from tests.models.whisper.test_feature_extraction_whisper.floats_list
|
||||
def floats_list(shape, scale=1.0, rng=None, name=None):
|
||||
"""Creates a random float32 tensor"""
|
||||
if rng is None:
|
||||
rng = global_rng
|
||||
|
||||
values = []
|
||||
for batch_idx in range(shape[0]):
|
||||
values.append([])
|
||||
for _ in range(shape[1]):
|
||||
values[-1].append(rng.random() * scale)
|
||||
|
||||
return values
|
||||
|
||||
|
||||
@require_torch
|
||||
# Copied from transformers.tests.encodec.test_feature_extraction_dac.EncodecFeatureExtractionTester with Encodec->Dac
|
||||
class DacFeatureExtractionTester(unittest.TestCase):
|
||||
# Ignore copy
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=7,
|
||||
min_seq_length=400,
|
||||
max_seq_length=2000,
|
||||
feature_size=1,
|
||||
padding_value=0.0,
|
||||
sampling_rate=16000,
|
||||
hop_length=512,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.min_seq_length = min_seq_length
|
||||
self.max_seq_length = max_seq_length
|
||||
self.hop_length = hop_length
|
||||
self.seq_length_diff = (self.max_seq_length - self.min_seq_length) // (self.batch_size - 1)
|
||||
self.feature_size = feature_size
|
||||
self.padding_value = padding_value
|
||||
self.sampling_rate = sampling_rate
|
||||
|
||||
# Ignore copy
|
||||
def prepare_feat_extract_dict(self):
|
||||
return {
|
||||
"feature_size": self.feature_size,
|
||||
"padding_value": self.padding_value,
|
||||
"sampling_rate": self.sampling_rate,
|
||||
"hop_length": self.hop_length,
|
||||
}
|
||||
|
||||
def prepare_inputs_for_common(self, equal_length=False, numpify=False):
|
||||
def _flatten(list_of_lists):
|
||||
return list(itertools.chain(*list_of_lists))
|
||||
|
||||
if equal_length:
|
||||
audio_inputs = floats_list((self.batch_size, self.max_seq_length))
|
||||
else:
|
||||
# make sure that inputs increase in size
|
||||
audio_inputs = [
|
||||
_flatten(floats_list((x, self.feature_size)))
|
||||
for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff)
|
||||
]
|
||||
|
||||
if numpify:
|
||||
audio_inputs = [np.asarray(x) for x in audio_inputs]
|
||||
|
||||
return audio_inputs
|
||||
|
||||
|
||||
@require_torch
|
||||
# Copied from transformers.tests.encodec.test_feature_extraction_dac.EnCodecFeatureExtractionTest with Encodec->Dac
|
||||
class DacFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
|
||||
feature_extraction_class = DacFeatureExtractor
|
||||
|
||||
def setUp(self):
|
||||
self.feat_extract_tester = DacFeatureExtractionTester(self)
|
||||
|
||||
def test_call(self):
|
||||
# Tests that all call wrap to encode_plus and batch_encode_plus
|
||||
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
# create three inputs of length 800, 1000, and 1200
|
||||
audio_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
||||
np_audio_inputs = [np.asarray(audio_input) for audio_input in audio_inputs]
|
||||
|
||||
# Test not batched input
|
||||
encoded_sequences_1 = feat_extract(audio_inputs[0], return_tensors="np").input_values
|
||||
encoded_sequences_2 = feat_extract(np_audio_inputs[0], return_tensors="np").input_values
|
||||
self.assertTrue(np.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3))
|
||||
|
||||
# Test batched
|
||||
encoded_sequences_1 = feat_extract(audio_inputs, padding=True, return_tensors="np").input_values
|
||||
encoded_sequences_2 = feat_extract(np_audio_inputs, padding=True, return_tensors="np").input_values
|
||||
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
||||
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
||||
|
||||
def test_double_precision_pad(self):
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
np_audio_inputs = np.random.rand(100).astype(np.float64)
|
||||
py_audio_inputs = np_audio_inputs.tolist()
|
||||
|
||||
for inputs in [py_audio_inputs, np_audio_inputs]:
|
||||
np_processed = feature_extractor.pad([{"input_values": inputs}], return_tensors="np")
|
||||
self.assertTrue(np_processed.input_values.dtype == np.float32)
|
||||
pt_processed = feature_extractor.pad([{"input_values": inputs}], return_tensors="pt")
|
||||
self.assertTrue(pt_processed.input_values.dtype == torch.float32)
|
||||
|
||||
def _load_datasamples(self, num_samples):
|
||||
from datasets import load_dataset
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
# automatic decoding with librispeech
|
||||
audio_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
|
||||
|
||||
return [x["array"] for x in audio_samples]
|
||||
|
||||
def test_integration(self):
|
||||
# fmt: off
|
||||
EXPECTED_INPUT_VALUES = torch.tensor(
|
||||
[ 2.3803711e-03, 2.0751953e-03, 1.9836426e-03, 2.1057129e-03,
|
||||
1.6174316e-03, 3.0517578e-04, 9.1552734e-05, 3.3569336e-04,
|
||||
9.7656250e-04, 1.8310547e-03, 2.0141602e-03, 2.1057129e-03,
|
||||
1.7395020e-03, 4.5776367e-04, -3.9672852e-04, 4.5776367e-04,
|
||||
1.0070801e-03, 9.1552734e-05, 4.8828125e-04, 1.1596680e-03,
|
||||
7.3242188e-04, 9.4604492e-04, 1.8005371e-03, 1.8310547e-03,
|
||||
8.8500977e-04, 4.2724609e-04, 4.8828125e-04, 7.3242188e-04,
|
||||
1.0986328e-03, 2.1057129e-03]
|
||||
)
|
||||
# fmt: on
|
||||
input_audio = self._load_datasamples(1)
|
||||
feature_extractor = DacFeatureExtractor()
|
||||
input_values = feature_extractor(input_audio, return_tensors="pt")["input_values"]
|
||||
self.assertEqual(input_values.shape, (1, 1, 93696))
|
||||
self.assertTrue(torch.allclose(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, atol=1e-4))
|
||||
audio_input_end = torch.tensor(input_audio[0][-30:], dtype=torch.float32)
|
||||
self.assertTrue(torch.allclose(input_values[0, 0, -46:-16], audio_input_end, atol=1e-4))
|
||||
|
||||
# Ignore copy
|
||||
@unittest.skip("The DAC model doesn't support stereo logic")
|
||||
def test_integration_stereo(self):
|
||||
pass
|
||||
|
||||
# Ignore copy
|
||||
def test_truncation_and_padding(self):
|
||||
input_audio = self._load_datasamples(2)
|
||||
# would be easier if the stride was like
|
||||
feature_extractor = DacFeatureExtractor()
|
||||
|
||||
# pad and trunc raise an error ?
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"^Both padding and truncation were set. Make sure you only set one.$",
|
||||
):
|
||||
truncated_outputs = feature_extractor(
|
||||
input_audio, padding="max_length", truncation=True, return_tensors="pt"
|
||||
).input_values
|
||||
|
||||
# force truncate to max_length
|
||||
truncated_outputs = feature_extractor(
|
||||
input_audio, truncation=True, max_length=48000, return_tensors="pt"
|
||||
).input_values
|
||||
self.assertEqual(truncated_outputs.shape, (2, 1, 48128))
|
||||
|
||||
# pad:
|
||||
padded_outputs = feature_extractor(input_audio, padding=True, return_tensors="pt").input_values
|
||||
self.assertEqual(padded_outputs.shape, (2, 1, 93696))
|
||||
|
||||
# force pad to max length
|
||||
truncated_outputs = feature_extractor(
|
||||
input_audio, padding="max_length", max_length=100000, return_tensors="pt"
|
||||
).input_values
|
||||
self.assertEqual(truncated_outputs.shape, (2, 1, 100352))
|
||||
|
||||
# force no pad
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"^Unable to create tensor, you should probably activate padding with 'padding=True' to have batched tensors with the same length.$",
|
||||
):
|
||||
truncated_outputs = feature_extractor(input_audio, padding=False, return_tensors="pt").input_values
|
||||
|
||||
truncated_outputs = feature_extractor(input_audio[0], padding=False, return_tensors="pt").input_values
|
||||
self.assertEqual(truncated_outputs.shape, (1, 1, 93680))
|
749
tests/models/dac/test_modeling_dac.py
Normal file
749
tests/models/dac/test_modeling_dac.py
Normal file
@ -0,0 +1,749 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch Dac model."""
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
from datasets import Audio, load_dataset
|
||||
|
||||
from transformers import AutoProcessor, DacConfig, DacModel
|
||||
from transformers.testing_utils import is_torch_available, require_torch, slow, torch_device
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
@require_torch
|
||||
# Copied from transformers.tests.encodec.test_modeling_encodec.EncodecModelTester with Encodec->Dac
|
||||
class DacModelTester:
|
||||
# Ignore copy
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=3,
|
||||
num_channels=1,
|
||||
is_training=False,
|
||||
intermediate_size=1024,
|
||||
encoder_hidden_size=16,
|
||||
downsampling_ratios=[2, 4, 4],
|
||||
decoder_hidden_size=16,
|
||||
n_codebooks=6,
|
||||
codebook_size=512,
|
||||
codebook_dim=4,
|
||||
quantizer_dropout=0.0,
|
||||
commitment_loss_weight=0.25,
|
||||
codebook_loss_weight=1.0,
|
||||
sample_rate=16000,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.num_channels = num_channels
|
||||
self.is_training = is_training
|
||||
self.intermediate_size = intermediate_size
|
||||
self.sample_rate = sample_rate
|
||||
|
||||
self.encoder_hidden_size = encoder_hidden_size
|
||||
self.downsampling_ratios = downsampling_ratios
|
||||
self.decoder_hidden_size = decoder_hidden_size
|
||||
self.n_codebooks = n_codebooks
|
||||
self.codebook_size = codebook_size
|
||||
self.codebook_dim = codebook_dim
|
||||
self.quantizer_dropout = quantizer_dropout
|
||||
self.commitment_loss_weight = commitment_loss_weight
|
||||
self.codebook_loss_weight = codebook_loss_weight
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_values = floats_tensor([self.batch_size, self.num_channels, self.intermediate_size], scale=1.0)
|
||||
config = self.get_config()
|
||||
inputs_dict = {"input_values": input_values}
|
||||
return config, inputs_dict
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config, inputs_dict = self.prepare_config_and_inputs()
|
||||
return config, inputs_dict
|
||||
|
||||
def prepare_config_and_inputs_for_model_class(self, model_class):
|
||||
input_values = floats_tensor([self.batch_size, self.num_channels, self.intermediate_size], scale=1.0)
|
||||
config = self.get_config()
|
||||
inputs_dict = {"input_values": input_values}
|
||||
|
||||
return config, inputs_dict
|
||||
|
||||
# Ignore copy
|
||||
def get_config(self):
|
||||
return DacConfig(
|
||||
encoder_hidden_size=self.encoder_hidden_size,
|
||||
downsampling_ratios=self.downsampling_ratios,
|
||||
decoder_hidden_size=self.decoder_hidden_size,
|
||||
n_codebooks=self.n_codebooks,
|
||||
codebook_size=self.codebook_size,
|
||||
codebook_dim=self.codebook_dim,
|
||||
quantizer_dropout=self.quantizer_dropout,
|
||||
commitment_loss_weight=self.commitment_loss_weight,
|
||||
codebook_loss_weight=self.codebook_loss_weight,
|
||||
)
|
||||
|
||||
# Ignore copy
|
||||
def create_and_check_model_forward(self, config, inputs_dict):
|
||||
model = DacModel(config=config).to(torch_device).eval()
|
||||
|
||||
input_values = inputs_dict["input_values"]
|
||||
result = model(input_values)
|
||||
self.parent.assertEqual(result.audio_values.shape, (self.batch_size, self.intermediate_size))
|
||||
|
||||
|
||||
@require_torch
|
||||
# Copied from transformers.tests.encodec.test_modeling_encodec.EncodecModelTest with Encodec->Dac
|
||||
class DacModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (DacModel,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
test_resize_embeddings = False
|
||||
pipeline_model_mapping = {"feature-extraction": DacModel} if is_torch_available() else {}
|
||||
input_name = "input_values"
|
||||
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
# model does not have attention and does not support returning hidden states
|
||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||
if "output_attentions" in inputs_dict:
|
||||
inputs_dict.pop("output_attentions")
|
||||
if "output_hidden_states" in inputs_dict:
|
||||
inputs_dict.pop("output_hidden_states")
|
||||
return inputs_dict
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = DacModelTester(self)
|
||||
self.config_tester = ConfigTester(
|
||||
self, config_class=DacConfig, hidden_size=37, common_properties=[], has_text_modality=False
|
||||
)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_model_forward(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model_forward(*config_and_inputs)
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.forward)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
# Ignore copy
|
||||
expected_arg_names = ["input_values", "n_quantizers", "return_dict"]
|
||||
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
||||
|
||||
@unittest.skip("The DacModel is not transformers based, thus it does not have `inputs_embeds` logics")
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("The DacModel is not transformers based, thus it does not have `inputs_embeds` logics")
|
||||
def test_model_get_set_embeddings(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("The DacModel is not transformers based, thus it does not have the usual `attention` logic")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("The DacModel is not transformers based, thus it does not have the usual `attention` logic")
|
||||
def test_torchscript_output_attentions(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("The DacModel is not transformers based, thus it does not have the usual `hidden_states` logic")
|
||||
def test_torchscript_output_hidden_state(self):
|
||||
pass
|
||||
|
||||
def _create_and_check_torchscript(self, config, inputs_dict):
|
||||
if not self.test_torchscript:
|
||||
return
|
||||
|
||||
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
||||
configs_no_init.torchscript = True
|
||||
configs_no_init.return_dict = False
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
main_input_name = model_class.main_input_name
|
||||
|
||||
try:
|
||||
main_input = inputs[main_input_name]
|
||||
model(main_input)
|
||||
traced_model = torch.jit.trace(model, main_input)
|
||||
except RuntimeError:
|
||||
self.fail("Couldn't trace module.")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
|
||||
|
||||
try:
|
||||
torch.jit.save(traced_model, pt_file_name)
|
||||
except Exception:
|
||||
self.fail("Couldn't save module.")
|
||||
|
||||
try:
|
||||
loaded_model = torch.jit.load(pt_file_name)
|
||||
except Exception:
|
||||
self.fail("Couldn't load module.")
|
||||
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
loaded_model.to(torch_device)
|
||||
loaded_model.eval()
|
||||
|
||||
model_state_dict = model.state_dict()
|
||||
loaded_model_state_dict = loaded_model.state_dict()
|
||||
|
||||
non_persistent_buffers = {}
|
||||
for key in loaded_model_state_dict.keys():
|
||||
if key not in model_state_dict.keys():
|
||||
non_persistent_buffers[key] = loaded_model_state_dict[key]
|
||||
|
||||
loaded_model_state_dict = {
|
||||
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
|
||||
}
|
||||
|
||||
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
|
||||
|
||||
model_buffers = list(model.buffers())
|
||||
for non_persistent_buffer in non_persistent_buffers.values():
|
||||
found_buffer = False
|
||||
for i, model_buffer in enumerate(model_buffers):
|
||||
if torch.equal(non_persistent_buffer, model_buffer):
|
||||
found_buffer = True
|
||||
break
|
||||
|
||||
self.assertTrue(found_buffer)
|
||||
model_buffers.pop(i)
|
||||
|
||||
model_buffers = list(model.buffers())
|
||||
for non_persistent_buffer in non_persistent_buffers.values():
|
||||
found_buffer = False
|
||||
for i, model_buffer in enumerate(model_buffers):
|
||||
if torch.equal(non_persistent_buffer, model_buffer):
|
||||
found_buffer = True
|
||||
break
|
||||
|
||||
self.assertTrue(found_buffer)
|
||||
model_buffers.pop(i)
|
||||
|
||||
models_equal = True
|
||||
for layer_name, p1 in model_state_dict.items():
|
||||
if layer_name in loaded_model_state_dict:
|
||||
p2 = loaded_model_state_dict[layer_name]
|
||||
if p1.data.ne(p2.data).sum() > 0:
|
||||
models_equal = False
|
||||
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
# Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
|
||||
# (Even with this call, there are still memory leak by ~0.04MB)
|
||||
self.clear_torch_jit_class_registry()
|
||||
|
||||
@unittest.skip("The DacModel is not transformers based, thus it does not have the usual `attention` logic")
|
||||
def test_attention_outputs(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("The DacModel is not transformers based, thus it does not have the usual `hidden_states` logic")
|
||||
def test_hidden_states_output(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("No support for low_cpu_mem_usage=True.")
|
||||
def test_save_load_low_cpu_mem_usage(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("No support for low_cpu_mem_usage=True.")
|
||||
def test_save_load_low_cpu_mem_usage_checkpoints(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("No support for low_cpu_mem_usage=True.")
|
||||
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
|
||||
pass
|
||||
|
||||
def test_determinism(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
def check_determinism(first, second):
|
||||
# outputs are not tensors but list (since each sequence don't have the same frame_length)
|
||||
out_1 = first.cpu().numpy()
|
||||
out_2 = second.cpu().numpy()
|
||||
out_1 = out_1[~np.isnan(out_1)]
|
||||
out_2 = out_2[~np.isnan(out_2)]
|
||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
first = model(**self._prepare_for_class(inputs_dict, model_class))[0]
|
||||
second = model(**self._prepare_for_class(inputs_dict, model_class))[0]
|
||||
|
||||
if isinstance(first, tuple) and isinstance(second, tuple):
|
||||
for tensor1, tensor2 in zip(first, second):
|
||||
check_determinism(tensor1, tensor2)
|
||||
else:
|
||||
check_determinism(first, second)
|
||||
|
||||
def test_model_outputs_equivalence(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
def set_nan_tensor_to_zero(t):
|
||||
t[t != t] = 0
|
||||
return t
|
||||
|
||||
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
|
||||
with torch.no_grad():
|
||||
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
|
||||
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
|
||||
|
||||
def recursive_check(tuple_object, dict_object):
|
||||
if isinstance(tuple_object, (List, Tuple)):
|
||||
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif isinstance(tuple_object, Dict):
|
||||
for tuple_iterable_value, dict_iterable_value in zip(
|
||||
tuple_object.values(), dict_object.values()
|
||||
):
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif tuple_object is None:
|
||||
return
|
||||
else:
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
|
||||
),
|
||||
msg=(
|
||||
"Tuple and dict output are not equal. Difference:"
|
||||
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
|
||||
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
|
||||
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
|
||||
),
|
||||
)
|
||||
|
||||
recursive_check(tuple_output, dict_output)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
check_equivalence(model, tuple_inputs, dict_inputs)
|
||||
|
||||
# Ignore copy
|
||||
def test_initialization(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
configs_no_init = _config_zero_init(config)
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
for name, param in model.named_parameters():
|
||||
uniform_init_parms = ["conv", "in_proj", "out_proj", "codebook"]
|
||||
if param.requires_grad:
|
||||
if any(x in name for x in uniform_init_parms):
|
||||
self.assertTrue(
|
||||
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
|
||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||
)
|
||||
|
||||
def test_identity_shortcut(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
|
||||
config.use_conv_shortcut = False
|
||||
self.model_tester.create_and_check_model_forward(config, inputs_dict)
|
||||
|
||||
|
||||
def normalize(arr):
|
||||
norm = np.linalg.norm(arr)
|
||||
normalized_arr = arr / norm
|
||||
return normalized_arr
|
||||
|
||||
|
||||
def compute_rmse(arr1, arr2):
|
||||
arr1_normalized = normalize(arr1)
|
||||
arr2_normalized = normalize(arr2)
|
||||
return np.sqrt(((arr1_normalized - arr2_normalized) ** 2).mean())
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
class DacIntegrationTest(unittest.TestCase):
|
||||
def test_integration_16khz(self):
|
||||
expected_rmse = 0.004
|
||||
|
||||
expected_encoder_sums_dict = {
|
||||
"loss": 24.8596,
|
||||
"quantized_representation": -0.0745,
|
||||
"audio_codes": 504.0948,
|
||||
"projected_latents": 0.0682,
|
||||
}
|
||||
|
||||
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
|
||||
model_name = "dac_16khz"
|
||||
|
||||
model_id = "descript/{}".format(model_name)
|
||||
model = DacModel.from_pretrained(model_id, force_download=True).to(torch_device).eval()
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
|
||||
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
|
||||
audio_sample = librispeech_dummy[0]["audio"]["array"]
|
||||
|
||||
inputs = processor(
|
||||
raw_audio=audio_sample,
|
||||
sampling_rate=processor.sampling_rate,
|
||||
return_tensors="pt",
|
||||
).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
encoder_outputs = model.encode(inputs["input_values"])
|
||||
|
||||
expected_encoder_sums = torch.tensor(list(expected_encoder_sums_dict.values()), dtype=torch.float32)
|
||||
encoder_outputs_mean = torch.tensor([v.float().mean().cpu().item() for v in encoder_outputs.to_tuple()])
|
||||
|
||||
# make sure audio encoded codes are correct
|
||||
self.assertTrue(torch.allclose(encoder_outputs_mean, expected_encoder_sums, atol=1e-3))
|
||||
|
||||
_, quantized_representation, _, _ = encoder_outputs.to_tuple()
|
||||
input_values_dec = model.decode(quantized_representation)[0]
|
||||
input_values_enc_dec = model(inputs["input_values"])[1]
|
||||
|
||||
# make sure forward and decode gives same result
|
||||
self.assertTrue(torch.allclose(input_values_dec, input_values_enc_dec, atol=1e-3))
|
||||
|
||||
arr = inputs["input_values"][0].cpu().numpy()
|
||||
arr_enc_dec = input_values_enc_dec[0].cpu().numpy()
|
||||
|
||||
max_length = min(arr_enc_dec.shape[-1], arr.shape[-1])
|
||||
|
||||
arr_cut = arr[0, :max_length].copy()
|
||||
arr_enc_dec_cut = arr_enc_dec[:max_length].copy()
|
||||
|
||||
# make sure audios are more or less equal
|
||||
rmse = compute_rmse(arr_cut, arr_enc_dec_cut)
|
||||
self.assertTrue(rmse < expected_rmse)
|
||||
|
||||
def test_integration_24khz(self):
|
||||
expected_rmse = 0.0039
|
||||
|
||||
expected_encoder_output_dict = {
|
||||
"quantized_representation": torch.tensor([0.9807, 2.8212, 5.2514, 2.7241, 1.0426]),
|
||||
"audio_codes": torch.tensor([919, 919, 234, 777, 234]),
|
||||
"projected_latents": torch.tensor([-4.7822, -5.0046, -4.5574, -5.0363, -5.4271]),
|
||||
}
|
||||
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
|
||||
model_name = "dac_24khz"
|
||||
|
||||
model_id = "descript/{}".format(model_name)
|
||||
model = DacModel.from_pretrained(model_id, force_download=True).to(torch_device).eval()
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
|
||||
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
|
||||
audio_sample = librispeech_dummy[0]["audio"]["array"]
|
||||
|
||||
inputs = processor(
|
||||
raw_audio=audio_sample,
|
||||
sampling_rate=processor.sampling_rate,
|
||||
return_tensors="pt",
|
||||
).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
encoder_outputs = model.encode(inputs["input_values"])
|
||||
|
||||
expected_quantized_representation = encoder_outputs["quantized_representation"][0, 0, :5].cpu()
|
||||
expected_audio_codes = encoder_outputs["audio_codes"][0, 0, :5].cpu()
|
||||
expected_projected_latents = encoder_outputs["projected_latents"][0, 0, :5].cpu()
|
||||
|
||||
# make sure values are correct for audios slices
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
expected_quantized_representation,
|
||||
expected_encoder_output_dict["quantized_representation"],
|
||||
atol=1e-3,
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(expected_audio_codes, expected_encoder_output_dict["audio_codes"], atol=1e-3)
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
expected_projected_latents, expected_encoder_output_dict["projected_latents"], atol=1e-3
|
||||
)
|
||||
)
|
||||
|
||||
_, quantized_representation, _, _ = encoder_outputs.to_tuple()
|
||||
input_values_dec = model.decode(quantized_representation)[0]
|
||||
input_values_enc_dec = model(inputs["input_values"])[1]
|
||||
|
||||
# make sure forward and decode gives same result
|
||||
self.assertTrue(torch.allclose(input_values_dec, input_values_enc_dec, atol=1e-3))
|
||||
|
||||
arr = inputs["input_values"][0].cpu().numpy()
|
||||
arr_enc_dec = input_values_enc_dec[0].cpu().numpy()
|
||||
|
||||
max_length = min(arr_enc_dec.shape[-1], arr.shape[-1])
|
||||
|
||||
arr_cut = arr[0, :max_length].copy()
|
||||
arr_enc_dec_cut = arr_enc_dec[:max_length].copy()
|
||||
|
||||
# make sure audios are more or less equal
|
||||
rmse = compute_rmse(arr_cut, arr_enc_dec_cut)
|
||||
self.assertTrue(rmse < expected_rmse)
|
||||
|
||||
def test_integration_44khz(self):
|
||||
expected_rmse = 0.002
|
||||
|
||||
expected_encoder_sums_dict = {
|
||||
"loss": 34.3612,
|
||||
"quantized_representation": 0.0078,
|
||||
"audio_codes": 509.6812,
|
||||
"projected_latents": -0.1054,
|
||||
}
|
||||
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
|
||||
model_name = "dac_44khz"
|
||||
|
||||
model_id = "descript/{}".format(model_name)
|
||||
model = DacModel.from_pretrained(model_id).to(torch_device).eval()
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
|
||||
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
|
||||
audio_sample = librispeech_dummy[0]["audio"]["array"]
|
||||
|
||||
inputs = processor(
|
||||
raw_audio=audio_sample,
|
||||
sampling_rate=processor.sampling_rate,
|
||||
return_tensors="pt",
|
||||
).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
encoder_outputs = model.encode(inputs["input_values"])
|
||||
|
||||
expected_encoder_sums = torch.tensor(list(expected_encoder_sums_dict.values()), dtype=torch.float32)
|
||||
encoder_outputs_mean = torch.tensor([v.float().mean().cpu().item() for v in encoder_outputs.to_tuple()])
|
||||
|
||||
# make sure audio encoded codes are correct
|
||||
self.assertTrue(torch.allclose(encoder_outputs_mean, expected_encoder_sums, atol=1e-3))
|
||||
|
||||
_, quantized_representation, _, _ = encoder_outputs.to_tuple()
|
||||
input_values_dec = model.decode(quantized_representation)[0]
|
||||
input_values_enc_dec = model(inputs["input_values"])[1]
|
||||
|
||||
# make sure forward and decode gives same result
|
||||
self.assertTrue(torch.allclose(input_values_dec, input_values_enc_dec, atol=1e-3))
|
||||
|
||||
arr = inputs["input_values"][0].cpu().numpy()
|
||||
arr_enc_dec = input_values_enc_dec[0].cpu().numpy()
|
||||
|
||||
max_length = min(arr_enc_dec.shape[-1], arr.shape[-1])
|
||||
|
||||
arr_cut = arr[0, :max_length].copy()
|
||||
arr_enc_dec_cut = arr_enc_dec[:max_length].copy()
|
||||
|
||||
# make sure audios are more or less equal
|
||||
rmse = compute_rmse(arr_cut, arr_enc_dec_cut)
|
||||
self.assertTrue(rmse < expected_rmse)
|
||||
|
||||
def test_integration_batch_16khz(self):
|
||||
expected_rmse = 0.002
|
||||
|
||||
expected_encoder_sums_dict = {
|
||||
"loss": 20.3913,
|
||||
"quantized_representation": -0.0538,
|
||||
"audio_codes": 487.8470,
|
||||
"projected_latents": 0.0237,
|
||||
}
|
||||
|
||||
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
|
||||
model_name = "dac_16khz"
|
||||
|
||||
model_id = "descript/{}".format(model_name)
|
||||
model = DacModel.from_pretrained(model_id).to(torch_device)
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
|
||||
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
|
||||
|
||||
audio_samples = [np.array([audio_sample["array"]])[0] for audio_sample in librispeech_dummy[-2:]["audio"]]
|
||||
|
||||
inputs = processor(
|
||||
raw_audio=audio_samples,
|
||||
sampling_rate=processor.sampling_rate,
|
||||
truncation=False,
|
||||
return_tensors="pt",
|
||||
).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
encoder_outputs = model.encode(inputs["input_values"])
|
||||
|
||||
expected_encoder_sums = torch.tensor(list(expected_encoder_sums_dict.values()), dtype=torch.float32)
|
||||
encoder_outputs_mean = torch.tensor([v.float().mean().item() for v in encoder_outputs.to_tuple()])
|
||||
|
||||
# make sure audio encoded codes are correct
|
||||
self.assertTrue(torch.allclose(encoder_outputs_mean, expected_encoder_sums, atol=1e-3))
|
||||
|
||||
_, quantized_representation, _, _ = encoder_outputs.to_tuple()
|
||||
input_values_dec = model.decode(quantized_representation)[0]
|
||||
input_values_enc_dec = model(inputs["input_values"])[1]
|
||||
|
||||
# make sure forward and decode gives same result
|
||||
self.assertTrue(torch.allclose(input_values_dec, input_values_enc_dec, atol=1e-3))
|
||||
|
||||
arr = inputs["input_values"].cpu().numpy()
|
||||
arr_enc_dec = input_values_enc_dec.cpu().numpy()
|
||||
|
||||
max_length = min(arr_enc_dec.shape[-1], arr.shape[-1])
|
||||
|
||||
arr_cut = arr[:, 0, :max_length].copy()
|
||||
arr_enc_dec_cut = arr_enc_dec[:, :max_length].copy()
|
||||
|
||||
# make sure audios are more or less equal
|
||||
rmse = compute_rmse(arr_cut, arr_enc_dec_cut)
|
||||
self.assertTrue(rmse < expected_rmse)
|
||||
|
||||
def test_integration_batch_24khz(self):
|
||||
expected_rmse = 0.002
|
||||
|
||||
expected_encoder_sums_dict = {
|
||||
"loss": 24.2309,
|
||||
"quantized_representation": 0.0520,
|
||||
"audio_codes": 510.2700,
|
||||
"projected_latents": -0.0076,
|
||||
}
|
||||
|
||||
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
|
||||
model_name = "dac_24khz"
|
||||
|
||||
model_id = "descript/{}".format(model_name)
|
||||
model = DacModel.from_pretrained(model_id).to(torch_device)
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
|
||||
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
|
||||
|
||||
audio_samples = [np.array([audio_sample["array"]])[0] for audio_sample in librispeech_dummy[-2:]["audio"]]
|
||||
|
||||
inputs = processor(
|
||||
raw_audio=audio_samples,
|
||||
sampling_rate=processor.sampling_rate,
|
||||
truncation=False,
|
||||
return_tensors="pt",
|
||||
).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
encoder_outputs = model.encode(inputs["input_values"])
|
||||
|
||||
expected_encoder_sums = torch.tensor(list(expected_encoder_sums_dict.values()), dtype=torch.float32)
|
||||
encoder_outputs_mean = torch.tensor([v.float().mean().cpu().item() for v in encoder_outputs.to_tuple()])
|
||||
|
||||
# make sure audio encoded codes are correct
|
||||
self.assertTrue(torch.allclose(encoder_outputs_mean, expected_encoder_sums, atol=1e-3))
|
||||
|
||||
_, quantized_representation, _, _ = encoder_outputs.to_tuple()
|
||||
input_values_dec = model.decode(quantized_representation)[0]
|
||||
input_values_enc_dec = model(inputs["input_values"])[1]
|
||||
|
||||
# make sure forward and decode gives same result
|
||||
self.assertTrue(torch.allclose(input_values_dec, input_values_enc_dec, atol=1e-3))
|
||||
|
||||
arr = inputs["input_values"].cpu().numpy()
|
||||
arr_enc_dec = input_values_enc_dec.cpu().numpy()
|
||||
|
||||
max_length = min(arr_enc_dec.shape[-1], arr.shape[-1])
|
||||
|
||||
arr_cut = arr[:, 0, :max_length].copy()
|
||||
arr_enc_dec_cut = arr_enc_dec[:, :max_length].copy()
|
||||
|
||||
# make sure audios are more or less equal
|
||||
rmse = compute_rmse(arr_cut, arr_enc_dec_cut)
|
||||
self.assertTrue(rmse < expected_rmse)
|
||||
|
||||
def test_integration_batch_44khz(self):
|
||||
expected_rmse = 0.001
|
||||
|
||||
expected_encoder_sums_dict = {
|
||||
"loss": 25.9233,
|
||||
"quantized_representation": 0.0013,
|
||||
"audio_codes": 528.5620,
|
||||
"projected_latents": -0.1194,
|
||||
}
|
||||
|
||||
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
|
||||
model_name = "dac_44khz"
|
||||
|
||||
model_id = "descript/{}".format(model_name)
|
||||
model = DacModel.from_pretrained(model_id).to(torch_device)
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
|
||||
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
|
||||
|
||||
audio_samples = [np.array([audio_sample["array"]])[0] for audio_sample in librispeech_dummy[-2:]["audio"]]
|
||||
|
||||
inputs = processor(
|
||||
raw_audio=audio_samples,
|
||||
sampling_rate=processor.sampling_rate,
|
||||
truncation=False,
|
||||
return_tensors="pt",
|
||||
).to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
encoder_outputs = model.encode(inputs["input_values"])
|
||||
|
||||
expected_encoder_sums = torch.tensor(list(expected_encoder_sums_dict.values()), dtype=torch.float32)
|
||||
encoder_outputs_mean = torch.tensor([v.float().mean().cpu().item() for v in encoder_outputs.to_tuple()])
|
||||
|
||||
# make sure audio encoded codes are correct
|
||||
self.assertTrue(torch.allclose(encoder_outputs_mean, expected_encoder_sums, atol=1e-3))
|
||||
|
||||
_, quantized_representation, _, _ = encoder_outputs.to_tuple()
|
||||
input_values_dec = model.decode(quantized_representation)[0]
|
||||
input_values_enc_dec = model(inputs["input_values"])[1]
|
||||
|
||||
# make sure forward and decode gives same result
|
||||
self.assertTrue(torch.allclose(input_values_dec, input_values_enc_dec, atol=1e-3))
|
||||
|
||||
arr = inputs["input_values"].cpu().numpy()
|
||||
arr_enc_dec = input_values_enc_dec.cpu().numpy()
|
||||
|
||||
max_length = min(arr_enc_dec.shape[-1], arr.shape[-1])
|
||||
|
||||
arr_cut = arr[:, 0, :max_length].copy()
|
||||
arr_enc_dec_cut = arr_enc_dec[:, :max_length].copy()
|
||||
|
||||
# make sure audios are more or less equal
|
||||
rmse = compute_rmse(arr_cut, arr_enc_dec_cut)
|
||||
self.assertTrue(rmse < expected_rmse)
|
Loading…
Reference in New Issue
Block a user