Add Descript-Audio-Codec model (#31494)

* dac model

* original dac works

* add dac model

* dac can be instantiated

* add forward pass

* load weights

* all weights are used

* convert checkpoint script ready

* test

* add feature extractor

* up

* make style

* apply cookiecutter

* fix tests

* iterate on FeatureExtractor

* nit

* update dac doc

* replace nn.Sequential with nn.ModuleList

* nit

* apply review suggestions 1/2

* Update src/transformers/models/dac/modeling_dac.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* up

* apply review suggestions 2/2

* update padding in FeatureExtractor

* apply review suggestions

* iterate on design and tests

* add integration tests

* feature extractor tests

* make style

* all tests pass

* make style

* fixup

* apply review suggestions

* fix-copies

* apply review suggestions

* apply review suggestions

* Update docs/source/en/model_doc/dac.md

Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>

* Update docs/source/en/model_doc/dac.md

Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>

* anticipate transfer weights to descript

* up

* make style

* apply review suggestions

* update slow test values

* update slow tests

* update test values

* update with CI values

* update with vorace values

* update test with slice

* make style

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>
Kamil Akesbi 2024-08-19 10:21:51 +01:00 committed by GitHub
parent 843e5e20ca
commit 8260cb311e
17 changed files with 2401 additions and 0 deletions

View File

@@ -696,6 +696,8 @@
title: Bark
- local: model_doc/clap
title: CLAP
- local: model_doc/dac
title: dac
- local: model_doc/encodec
title: EnCodec
- local: model_doc/hiera

View File

@@ -105,6 +105,7 @@ Flax), PyTorch, and/or TensorFlow.
| [CPM-Ant](model_doc/cpmant) | ✅ | ❌ | ❌ |
| [CTRL](model_doc/ctrl) | ✅ | ✅ | ❌ |
| [CvT](model_doc/cvt) | ✅ | ✅ | ❌ |
| [DAC](model_doc/dac) | ✅ | ❌ | ❌ |
| [Data2VecAudio](model_doc/data2vec) | ✅ | ❌ | ❌ |
| [Data2VecText](model_doc/data2vec) | ✅ | ❌ | ❌ |
| [Data2VecVision](model_doc/data2vec) | ✅ | ✅ | ❌ |

View File

@@ -0,0 +1,80 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# DAC
## Overview
The DAC model was proposed in [Descript Audio Codec: High-Fidelity Audio Compression with Improved RVQGAN](https://arxiv.org/abs/2306.06546) by Rithesh Kumar, Prem Seetharaman, Alejandro Luebs, Ishaan Kumar, Kundan Kumar.
The Descript Audio Codec (DAC) model is a powerful tool for compressing audio data, making it highly efficient for storage and transmission. By compressing 44.1 kHz audio into tokens at just 8 kbps bandwidth, the DAC model enables high-quality audio processing while significantly reducing the data footprint. This is particularly useful in scenarios where bandwidth is limited or storage space is at a premium, such as in streaming applications, remote conferencing, and archiving large audio datasets.
The abstract from the paper is the following:
*Language models have been successfully used to model natural signals, such as images, speech, and music. A key component of these models is a high quality neural compression model that can compress high-dimensional natural signals into lower dimensional discrete tokens. To that end, we introduce a high-fidelity universal neural audio compression algorithm that achieves ~90x compression of 44.1 KHz audio into tokens at just 8kbps bandwidth. We achieve this by combining advances in high-fidelity audio generation with better vector quantization techniques from the image domain, along with improved adversarial and reconstruction losses. We compress all domains (speech, environment, music, etc.) with a single universal model, making it widely applicable to generative modeling of all audio. We compare with competing audio compression algorithms, and find our method outperforms them significantly. We provide thorough ablations for every design choice, as well as open-source code and trained model weights. We hope our work can lay the foundation for the next generation of high-fidelity audio modeling.*
This model was contributed by [Kamil Akesbi](https://huggingface.co/kamilakesbi).
The original code can be found [here](https://github.com/descriptinc/descript-audio-codec/tree/main?tab=readme-ov-file).
## Model structure
The Descript Audio Codec (DAC) model is structured into three distinct stages:
1. Encoder Model: This stage compresses the input audio, reducing its size while retaining essential information.
2. Residual Vector Quantizer (RVQ) Model: Working in tandem with the encoder, this model quantizes the latent codes of the audio, refining the compression and ensuring high-quality reconstruction.
3. Decoder Model: This final stage reconstructs the audio from its compressed form, restoring it to a state that closely resembles the original input.
## Usage example
Here is a quick example of how to encode and decode an audio using this model:
```python
>>> from datasets import load_dataset, Audio
>>> from transformers import DacModel, AutoProcessor
>>> librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> model = DacModel.from_pretrained("descript/dac_16khz")
>>> processor = AutoProcessor.from_pretrained("descript/dac_16khz")
>>> librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
>>> audio_sample = librispeech_dummy[-1]["audio"]["array"]
>>> inputs = processor(raw_audio=audio_sample, sampling_rate=processor.sampling_rate, return_tensors="pt")
>>> encoder_outputs = model.encode(inputs["input_values"])
>>> # Get the intermediate audio codes
>>> audio_codes = encoder_outputs.audio_codes
>>> # Reconstruct the audio from its quantized representation
>>> audio_values = model.decode(encoder_outputs.quantized_representation)
>>> # or the equivalent with a forward pass
>>> audio_values = model(inputs["input_values"]).audio_values
```
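If only the discrete `audio_codes` are kept (for example after storage or transmission), the waveform can also be reconstructed directly from them. A minimal sketch reusing the variables from the example above:
```python
>>> # Decode directly from the discrete codebook indices
>>> audio_values = model.decode(None, audio_codes=audio_codes).audio_values
```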
## DacConfig
[[autodoc]] DacConfig
## DacFeatureExtractor
[[autodoc]] DacFeatureExtractor
- __call__
## DacModel
[[autodoc]] DacModel
- decode
- encode
- forward

View File

@@ -312,6 +312,7 @@ _import_structure = {
"CTRLTokenizer",
],
"models.cvt": ["CvtConfig"],
"models.dac": ["DacConfig", "DacFeatureExtractor"],
"models.data2vec": [
"Data2VecAudioConfig",
"Data2VecTextConfig",
@@ -1757,6 +1758,12 @@ else:
"CvtPreTrainedModel",
]
)
_import_structure["models.dac"].extend(
[
"DacModel",
"DacPreTrainedModel",
]
)
_import_structure["models.data2vec"].extend(
[
"Data2VecAudioForAudioFrameClassification",
@@ -5026,6 +5033,10 @@ if TYPE_CHECKING:
CTRLTokenizer,
)
from .models.cvt import CvtConfig
from .models.dac import (
DacConfig,
DacFeatureExtractor,
)
from .models.data2vec import (
Data2VecAudioConfig,
Data2VecTextConfig,
@@ -6450,6 +6461,10 @@ if TYPE_CHECKING:
CvtModel,
CvtPreTrainedModel,
)
from .models.dac import (
DacModel,
DacPreTrainedModel,
)
from .models.data2vec import (
Data2VecAudioForAudioFrameClassification,
Data2VecAudioForCTC,

View File

@@ -59,6 +59,7 @@ from . import (
cpmant,
ctrl,
cvt,
dac,
data2vec,
dbrx,
deberta,

View File

@@ -73,6 +73,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("cpmant", "CpmAntConfig"),
("ctrl", "CTRLConfig"),
("cvt", "CvtConfig"),
("dac", "DacConfig"),
("data2vec-audio", "Data2VecAudioConfig"),
("data2vec-text", "Data2VecTextConfig"),
("data2vec-vision", "Data2VecVisionConfig"),
@@ -354,6 +355,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("cpmant", "CPM-Ant"),
("ctrl", "CTRL"),
("cvt", "CvT"),
("dac", "DAC"),
("data2vec-audio", "Data2VecAudio"),
("data2vec-text", "Data2VecText"),
("data2vec-vision", "Data2VecVision"),

View File

@@ -49,6 +49,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
("conditional_detr", "ConditionalDetrFeatureExtractor"),
("convnext", "ConvNextFeatureExtractor"),
("cvt", "ConvNextFeatureExtractor"),
("dac", "DacFeatureExtractor"),
("data2vec-audio", "Wav2Vec2FeatureExtractor"),
("data2vec-vision", "BeitFeatureExtractor"),
("deformable_detr", "DeformableDetrFeatureExtractor"),

View File

@@ -73,6 +73,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("cpmant", "CpmAntModel"),
("ctrl", "CTRLModel"),
("cvt", "CvtModel"),
("dac", "DacModel"),
("data2vec-audio", "Data2VecAudioModel"),
("data2vec-text", "Data2VecTextModel"),
("data2vec-vision", "Data2VecVisionModel"),

View File

@@ -0,0 +1,60 @@
# coding=utf-8
# Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)
_import_structure = {
"configuration_dac": ["DacConfig"],
"feature_extraction_dac": ["DacFeatureExtractor"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_dac"] = [
"DacModel",
"DacPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_dac import (
DacConfig,
)
from .feature_extraction_dac import DacFeatureExtractor
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_dac import (
DacModel,
DacPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

View File

@@ -0,0 +1,111 @@
# coding=utf-8
# Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dac model configuration"""
import math
import numpy as np
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
class DacConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`DacModel`]. It is used to instantiate a
Dac model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the
[descript/dac_16khz](https://huggingface.co/descript/dac_16khz) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
encoder_hidden_size (`int`, *optional*, defaults to 64):
Intermediate representation dimension for the encoder.
downsampling_ratios (`List[int]`, *optional*, defaults to `[2, 4, 8, 8]`):
Ratios for downsampling in the encoder. These are used in reverse order for upsampling in the decoder.
decoder_hidden_size (`int`, *optional*, defaults to 1536):
Intermediate representation dimension for the decoder.
n_codebooks (`int`, *optional*, defaults to 9):
Number of codebooks in the VQVAE.
codebook_size (`int`, *optional*, defaults to 1024):
Number of discrete codes in each codebook.
codebook_dim (`int`, *optional*, defaults to 8):
Dimension of the codebook vectors.
quantizer_dropout (`float`, *optional*, defaults to 0):
Fraction of the batch for which a random number of quantizers is sampled during training (quantizer dropout).
commitment_loss_weight (`float`, *optional*, defaults to 0.25):
Weight of the commitment loss term in the VQVAE loss function.
codebook_loss_weight (`float`, *optional*, defaults to 1.0):
Weight of the codebook loss term in the VQVAE loss function.
sampling_rate (`int`, *optional*, defaults to 16000):
The sampling rate at which the audio waveform should be digitalized, expressed in hertz (Hz).
Example:
```python
>>> from transformers import DacModel, DacConfig
>>> # Initializing a "descript/dac_16khz" style configuration
>>> configuration = DacConfig()
>>> # Initializing a model (with random weights) from the "descript/dac_16khz" style configuration
>>> model = DacModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "dac"
def __init__(
self,
encoder_hidden_size=64,
downsampling_ratios=[2, 4, 8, 8],
decoder_hidden_size=1536,
n_codebooks=9,
codebook_size=1024,
codebook_dim=8,
quantizer_dropout=0,
commitment_loss_weight=0.25,
codebook_loss_weight=1.0,
sampling_rate=16000,
**kwargs,
):
self.encoder_hidden_size = encoder_hidden_size
self.downsampling_ratios = downsampling_ratios
self.decoder_hidden_size = decoder_hidden_size
self.upsampling_ratios = downsampling_ratios[::-1]
self.n_codebooks = n_codebooks
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.quantizer_dropout = quantizer_dropout
self.sampling_rate = sampling_rate
self.hidden_size = encoder_hidden_size * (2 ** len(downsampling_ratios))
self.hop_length = int(np.prod(downsampling_ratios))
self.commitment_loss_weight = commitment_loss_weight
self.codebook_loss_weight = codebook_loss_weight
super().__init__(**kwargs)
@property
def frame_rate(self) -> int:
hop_length = np.prod(self.upsampling_ratios)
return math.ceil(self.sampling_rate / hop_length)
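To make these defaults concrete, here is a small sketch (illustrative only, mirroring the computations in `DacConfig` above) of how the hop length, latent dimension and frame rate fall out of `downsampling_ratios`, `encoder_hidden_size` and `sampling_rate`:
```python
import math

import numpy as np

# DacConfig defaults
downsampling_ratios = [2, 4, 8, 8]
encoder_hidden_size = 64
sampling_rate = 16000

hop_length = int(np.prod(downsampling_ratios))  # 512 audio samples per frame of codes
hidden_size = encoder_hidden_size * (2 ** len(downsampling_ratios))  # 1024, the latent dimension
frame_rate = math.ceil(sampling_rate / hop_length)  # 32 frames of codes per second
print(hop_length, hidden_size, frame_rate)  # 512 1024 32
```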

View File

@@ -0,0 +1,261 @@
# coding=utf-8
# Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import fnmatch
import re
import torch
from transformers import (
DacConfig,
DacFeatureExtractor,
DacModel,
logging,
)
# checkpoints downloaded using:
# pip install descript-audio-codec
# python3 -m dac download # downloads the default 44kHz variant
# python3 -m dac download --model_type 44khz # downloads the 44kHz variant
# python3 -m dac download --model_type 24khz # downloads the 24kHz variant
# python3 -m dac download --model_type 16khz # downloads the 16kHz variant
# More information: https://github.com/descriptinc/descript-audio-codec/tree/main
logging.set_verbosity_info()
logger = logging.get_logger("transformers.models.dac")
def match_pattern(string, pattern):
# Split the pattern into parts
pattern_parts = pattern.split(".")
string_parts = string.split(".")
pattern_block_count = string_block_count = 0
for part in pattern_parts:
if part.startswith("block"):
pattern_block_count += 1
for part in string_parts:
if part.startswith("block"):
string_block_count += 1
return fnmatch.fnmatch(string, pattern) and string_block_count == pattern_block_count
TOP_LEVEL_KEYS = []
IGNORE_KEYS = []
MAPPING_ENCODER = {
"encoder.block.0": ["encoder.conv1"],
"encoder.block.5": ["encoder.snake1"],
"encoder.block.6": ["encoder.conv2"],
"encoder.block.*.block.*.block.0".replace("*", r"\d+"): ["encoder.block", "res_unit", "snake1"],
"encoder.block.*.block.*.block.1".replace("*", r"\d+"): ["encoder.block", "res_unit", "conv1"],
"encoder.block.*.block.*.block.2".replace("*", r"\d+"): ["encoder.block", "res_unit", "snake2"],
"encoder.block.*.block.*.block.3".replace("*", r"\d+"): ["encoder.block", "res_unit", "conv2"],
"encoder.block.*.block.3".replace("*", r"\d+"): ["encoder.block", "snake1"],
"encoder.block.*.block.4".replace("*", r"\d+"): ["encoder.block", "conv1"],
}
MAPPING_QUANTIZER = {
"quantizer.quantizers.*": ["quantizer.quantizers.*"],
}
MAPPING_DECODER = {
"decoder.model.0": ["decoder.conv1"],
"decoder.model.5": ["decoder.snake1"],
"decoder.model.6": ["decoder.conv2"],
"decoder.model.*.block.0".replace("*", r"\d+"): ["decoder.block", "snake1"],
"decoder.model.*.block.1".replace("*", r"\d+"): ["decoder.block", "conv_t1"],
"decoder.model.*.block.*.block.0".replace("*", r"\d+"): ["decoder.block", "res_unit", "snake1"],
"decoder.model.*.block.*.block.1".replace("*", r"\d+"): ["decoder.block", "res_unit", "conv1"],
"decoder.model.*.block.*.block.2".replace("*", r"\d+"): ["decoder.block", "res_unit", "snake2"],
"decoder.model.*.block.*.block.3".replace("*", r"\d+"): ["decoder.block", "res_unit", "conv2"],
}
MAPPING = {
**MAPPING_ENCODER,
**MAPPING_QUANTIZER,
**MAPPING_DECODER,
}
def set_recursively(hf_pointer, key, value, full_name, weight_type):
for attribute in key.split("."):
hf_pointer = getattr(hf_pointer, attribute)
if weight_type is not None:
hf_shape = getattr(hf_pointer, weight_type).shape
else:
hf_shape = hf_pointer.shape
if hf_shape != value.shape:
raise ValueError(
f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
f" {value.shape} for {full_name}"
)
if weight_type == "weight":
hf_pointer.weight.data = value
elif weight_type == "weight_g":
hf_pointer.weight_g.data = value
elif weight_type == "weight_v":
hf_pointer.weight_v.data = value
elif weight_type == "bias":
hf_pointer.bias.data = value
elif weight_type == "alpha":
hf_pointer.alpha.data = value
logger.info(f"{key + ('.' + weight_type if weight_type is not None else '')} was initialized from {full_name}.")
def should_ignore(name, ignore_keys):
for key in ignore_keys:
if key.endswith(".*"):
if name.startswith(key[:-1]):
return True
elif ".*." in key:
prefix, suffix = key.split(".*.")
if prefix in name and suffix in name:
return True
elif key in name:
return True
return False
def recursively_load_weights(orig_dict, hf_model, model_name):
unused_weights = []
if model_name not in ["dac_16khz", "dac_24khz", "dac_44khz"]:
raise ValueError(f"Unsupported model: {model_name}")
for name, value in orig_dict.items():
is_used = False
for key, mapped_key in MAPPING.items():
regex = re.compile(key)
if regex.search(name):
if len(mapped_key) == 1:
if mapped_key[0][0] == "q":
mapped_key = ".".join(name.split(".")[:-1])
else:
mapped_key = mapped_key[0]
elif len(mapped_key) == 3:
integers = re.findall(r"\b\d+\b", name)
if mapped_key[0][0] == "d":
mapped_key = "{}.{}.{}{}.{}".format(
mapped_key[0],
str(int(integers[0]) - 1),
mapped_key[1],
str(int(integers[1]) - 1),
mapped_key[2],
)
else:
mapped_key = "{}.{}.{}{}.{}".format(
mapped_key[0],
str(int(integers[0]) - 1),
mapped_key[1],
str(int(integers[1]) + 1),
mapped_key[2],
)
elif len(mapped_key) == 2:
integers = re.findall(r"\b\d+\b", name)
mapped_key = "{}.{}.{}".format(mapped_key[0], str(int(integers[0]) - 1), mapped_key[1])
is_used = True
if "weight_g" in name:
weight_type = "weight_g"
elif "weight_v" in name:
weight_type = "weight_v"
elif "bias" in name:
weight_type = "bias"
elif "alpha" in name:
weight_type = "alpha"
elif "weight" in name:
weight_type = "weight"
set_recursively(hf_model, mapped_key, value, name, weight_type)
if not is_used:
unused_weights.append(name)
print(list(set(unused_weights)))
logger.warning(f"Unused weights: {unused_weights}")
@torch.no_grad()
def convert_checkpoint(
model_name,
checkpoint_path,
pytorch_dump_folder_path,
sample_rate=16000,
repo_id=None,
):
model_dict = torch.load(checkpoint_path, "cpu")
config = DacConfig()
metadata = model_dict["metadata"]["kwargs"]
config.encoder_hidden_size = metadata["encoder_dim"]
config.downsampling_ratios = metadata["encoder_rates"]
config.codebook_size = metadata["codebook_size"]
config.n_codebooks = metadata["n_codebooks"]
config.codebook_dim = metadata["codebook_dim"]
config.decoder_hidden_size = metadata["decoder_dim"]
config.upsampling_ratios = metadata["decoder_rates"]
config.quantizer_dropout = float(metadata["quantizer_dropout"])
config.sampling_rate = sample_rate
model = DacModel(config)
feature_extractor = DacFeatureExtractor()
feature_extractor.sampling_rate = sample_rate
original_checkpoint = model_dict["state_dict"]
model.apply_weight_norm()
recursively_load_weights(original_checkpoint, model, model_name)
model.remove_weight_norm()
model.save_pretrained(pytorch_dump_folder_path)
if repo_id:
print("Pushing to the hub...")
feature_extractor.push_to_hub(repo_id)
model.push_to_hub(repo_id)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model",
default="dac_44khz",
type=str,
help="The model to convert. Should be one of 'dac_16khz', 'dac_24khz', 'dac_44khz'.",
)
parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint")
parser.add_argument(
"--pytorch_dump_folder_path", required=True, default=None, type=str, help="Path to the output PyTorch model."
)
parser.add_argument(
"--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub."
)
parser.add_argument("--sample_rate", default=None, type=str, help="Sample rate used by DacFeatureExtractor")
args = parser.parse_args()
convert_checkpoint(
args.model, args.checkpoint_path, args.pytorch_dump_folder_path, args.sample_rate, args.push_to_hub
)
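For completeness, a hedged sketch of driving the conversion programmatically rather than through argparse, using the `convert_checkpoint` function defined above (the paths below are placeholders, not part of the original script):
```python
# Illustrative only: point checkpoint_path at the file produced by `python3 -m dac download --model_type 16khz`.
convert_checkpoint(
    model_name="dac_16khz",
    checkpoint_path="path/to/weights_16khz.pth",  # placeholder: downloaded original checkpoint
    pytorch_dump_folder_path="./dac_16khz_hf",  # where the converted model is saved
    sample_rate=16000,
    repo_id=None,  # set to a repo name to push the converted model to the Hub
)
```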

View File

@@ -0,0 +1,170 @@
# coding=utf-8
# Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature extractor class for DAC"""
from typing import List, Optional, Union
import numpy as np
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...utils import PaddingStrategy, TensorType, logging
logger = logging.get_logger(__name__)
class DacFeatureExtractor(SequenceFeatureExtractor):
r"""
Constructs a Dac feature extractor.
This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
most of the main methods. Users should refer to this superclass for more information regarding those methods.
Args:
feature_size (`int`, *optional*, defaults to 1):
The feature dimension of the extracted features. Use 1 for mono, 2 for stereo.
sampling_rate (`int`, *optional*, defaults to 16000):
The sampling rate at which the audio waveform should be digitalized, expressed in hertz (Hz).
padding_value (`float`, *optional*, defaults to 0.0):
The value that is used for padding.
hop_length (`int`, *optional*, defaults to 512):
Hop length of the model, i.e. the number of audio samples per frame of codes. Inputs are padded to a multiple of this value.
"""
model_input_names = ["input_values", "n_quantizers"]
def __init__(
self,
feature_size: int = 1,
sampling_rate: int = 16000,
padding_value: float = 0.0,
hop_length: int = 512,
**kwargs,
):
super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
self.hop_length = hop_length
def __call__(
self,
raw_audio: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
padding: Optional[Union[bool, str, PaddingStrategy]] = None,
truncation: Optional[bool] = False,
max_length: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
sampling_rate: Optional[int] = None,
) -> BatchFeature:
"""
Main method to featurize and prepare for the model one or several sequence(s).
Args:
raw_audio (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float
values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape
`(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio
(`feature_size = 2`).
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among:
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence is provided).
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
acceptable input length for the model if that argument is not provided.
- `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different
lengths).
truncation (`bool`, *optional*, defaults to `False`):
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
max_length (`int`, *optional*):
Maximum length of the returned list and optionally padding length (see above).
return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'pt'`):
If set, will return tensors instead of list of python integers. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return Numpy `np.ndarray` objects.
sampling_rate (`int`, *optional*):
The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass
`sampling_rate` at the forward call to prevent silent errors.
"""
if sampling_rate is not None:
if sampling_rate != self.sampling_rate:
raise ValueError(
f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with"
f" {self.sampling_rate} and not {sampling_rate}."
)
else:
logger.warning(
"It is strongly recommended to pass the `sampling_rate` argument to this function. "
"Failing to do so can result in silent errors that might be hard to debug."
)
if padding and truncation:
raise ValueError("Both padding and truncation were set. Make sure you only set one.")
elif padding is None:
# by default let's pad the inputs
padding = True
is_batched = bool(
isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list)))
)
if is_batched:
raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio]
elif not is_batched and not isinstance(raw_audio, np.ndarray):
raw_audio = np.asarray(raw_audio, dtype=np.float32)
elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64):
raw_audio = raw_audio.astype(np.float32)
# always return batch
if not is_batched:
raw_audio = [np.asarray(raw_audio).T]
# verify inputs are valid
for idx, example in enumerate(raw_audio):
if example.ndim > 2:
raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}")
if self.feature_size == 1 and example.ndim != 1:
raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels")
if self.feature_size == 2:
raise ValueError("Stereo audio isn't supported for now")
input_values = BatchFeature({"input_values": raw_audio})
# normal padding on batch
padded_inputs = self.pad(
input_values,
max_length=max_length,
truncation=truncation,
padding=padding,
return_attention_mask=False,
pad_to_multiple_of=self.hop_length,
)
if padding:
padded_inputs.input_values = padded_inputs.input_values[:, np.newaxis, :]
input_values = []
for example in padded_inputs.pop("input_values"):
if self.feature_size == 1:
example = example[..., None]
input_values.append(example.T)
padded_inputs["input_values"] = input_values
if return_tensors is not None:
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
return padded_inputs
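A minimal usage sketch for the feature extractor (the random array below is only a stand-in for a real 16 kHz mono waveform):
```python
import numpy as np

from transformers import DacFeatureExtractor

feature_extractor = DacFeatureExtractor(sampling_rate=16000, hop_length=512)

# Stand-in for a real mono waveform sampled at 16 kHz
raw_audio = np.random.randn(10_000).astype(np.float32)

inputs = feature_extractor(raw_audio=raw_audio, sampling_rate=16000, return_tensors="pt")
# With the default padding, the waveform is padded to a multiple of `hop_length`
# and returned in the layout expected by `DacModel.encode`.
print(inputs["input_values"].shape)
```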

View File

@@ -0,0 +1,717 @@
# coding=utf-8
# Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformers DAC model."""
import math
from dataclasses import dataclass
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...modeling_utils import PreTrainedModel
from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from .configuration_dac import DacConfig
# General docstring
_CONFIG_FOR_DOC = "DacConfig"
@dataclass
class DacOutput(ModelOutput):
"""
Args:
loss (`torch.Tensor`):
Loss from the encoder model, comprising the weighted combination of the commitment and codebook losses.
audio_values (`torch.Tensor` of shape `(batch_size, input_length)`):
Reconstructed audio data.
quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
Quantized continuous representation of input.
audio_codes (`torch.LongTensor` of shape `(batch_size, num_codebooks, time_steps)`):
Codebook indices for each codebook (quantized discrete representation of input).
projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`):
Projected latents (continuous representation of input before quantization).
"""
loss: torch.FloatTensor = None
audio_values: torch.FloatTensor = None
quantized_representation: torch.FloatTensor = None
audio_codes: torch.LongTensor = None
projected_latents: torch.FloatTensor = None
@dataclass
class DacEncoderOutput(ModelOutput):
"""
Args:
loss (`torch.Tensor`):
Loss from the encoder model, comprising the weighted combination of the commitment and codebook losses.
quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`, *optional*):
Quantized continuous representation of input.
audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`, *optional*):
Codebook indices for each codebook (quantized discrete representation of input).
projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`, *optional*):
Projected latents (continuous representation of input before quantization).
"""
loss: torch.FloatTensor = None
quantized_representation: torch.FloatTensor = None
audio_codes: torch.FloatTensor = None
projected_latents: torch.FloatTensor = None
@dataclass
# Copied from transformers.models.encodec.modeling_encodec.EncodecDecoderOutput with Encodec->Dac, segment_length->input_length
class DacDecoderOutput(ModelOutput):
"""
Args:
audio_values (`torch.FloatTensor` of shape `(batch_size, input_length)`, *optional*):
Decoded audio values, obtained using the decoder part of Dac.
"""
audio_values: torch.FloatTensor = None
class Snake1d(nn.Module):
"""
A 1-dimensional Snake activation function module.
"""
def __init__(self, hidden_dim):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1, hidden_dim, 1))
def forward(self, hidden_states):
shape = hidden_states.shape
hidden_states = hidden_states.reshape(shape[0], shape[1], -1)
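# Snake activation: x + sin^2(alpha * x) / alpha (alpha is offset by 1e-9 to avoid division by zero)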
hidden_states = hidden_states + (self.alpha + 1e-9).reciprocal() * torch.sin(self.alpha * hidden_states).pow(2)
hidden_states = hidden_states.reshape(shape)
return hidden_states
class DacVectorQuantize(nn.Module):
"""
Implementation of VQ similar to Karpathy's repo (https://github.com/karpathy/deep-vector-quantization)
Additionally uses following tricks from improved VQGAN
(https://arxiv.org/pdf/2110.04627.pdf):
1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
for improved codebook usage
2. l2-normalized codes: Converts euclidean distance to cosine similarity which
improves training stability
"""
def __init__(self, config: DacConfig):
super().__init__()
self.in_proj = nn.Conv1d(config.hidden_size, config.codebook_dim, kernel_size=1)
self.out_proj = nn.Conv1d(config.codebook_dim, config.hidden_size, kernel_size=1)
self.codebook = nn.Embedding(config.codebook_size, config.codebook_dim)
def forward(self, hidden_state):
"""
Quantizes the input tensor using a fixed codebook and returns the corresponding codebook vectors.
Args:
hidden_state (`torch.FloatTensor` of shape `(batch_size, dimension, time_steps)`):
Input tensor.
Returns:
quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
Quantized continuous representation of input.
commitment_loss (`torch.FloatTensor` of shape `(1)`):
Commitment loss to train the encoder to predict vectors closer to codebook entries.
codebook_loss (`torch.FloatTensor` of shape `(1)`):
Codebook loss to update the codebook.
audio_codes (`torch.LongTensor` of shape `(batch_size, time_steps)`):
Codebook indices for each codebook, quantized discrete representation of input.
projected_latents (`torch.FloatTensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`):
Projected latents (continuous representation of input before quantization).
"""
projected_latents = self.in_proj(hidden_state)
quantized_representation, audio_codes = self.decode_latents(projected_latents)
commitment_loss = F.mse_loss(projected_latents, quantized_representation.detach(), reduction="mean")
codebook_loss = F.mse_loss(quantized_representation, projected_latents.detach(), reduction="mean")
# noop in forward pass, straight-through gradient estimator in backward pass
quantized_representation = projected_latents + (quantized_representation - projected_latents).detach()
quantized_representation = self.out_proj(quantized_representation)
return quantized_representation, commitment_loss, codebook_loss, audio_codes, projected_latents
def decode_latents(self, hidden_states):
batch_size, hidden_dim, sequence_length = hidden_states.shape
encodings = hidden_states.permute(0, 2, 1).reshape(batch_size * sequence_length, hidden_dim)
codebook = self.codebook.weight # codebook: (N x D)
# L2 normalize encodings and codebook (ViT-VQGAN)
encodings = F.normalize(encodings)
codebook = F.normalize(codebook)
# Compute euclidean distance with codebook
l2_norm = encodings.pow(2).sum(1, keepdim=True)
dist = -(l2_norm - 2 * encodings @ codebook.t()) + codebook.pow(2).sum(1, keepdim=True).t()
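# After L2 normalization, the entry with the largest score is the most similar (nearest) codebook vector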
indices = dist.max(1)[1]
indices = indices.reshape(hidden_states.size(0), -1)
quantized_representation = self.codebook(indices).transpose(1, 2)
return quantized_representation, indices
class DacResidualUnit(nn.Module):
"""
A residual unit composed of Snake1d and weight-normalized Conv1d layers with dilations.
"""
def __init__(self, dimension: int = 16, dilation: int = 1):
super().__init__()
pad = ((7 - 1) * dilation) // 2
self.snake1 = Snake1d(dimension)
self.conv1 = nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad)
self.snake2 = Snake1d(dimension)
self.conv2 = nn.Conv1d(dimension, dimension, kernel_size=1)
def forward(self, hidden_state):
"""
Forward pass through the residual unit.
Args:
hidden_state (`torch.Tensor` of shape `(batch_size, channels, time_steps)`):
Input tensor.
Returns:
output_tensor (`torch.Tensor` of shape `(batch_size, channels, time_steps)`):
Input tensor after passing through the residual unit.
"""
output_tensor = hidden_state
output_tensor = self.conv1(self.snake1(output_tensor))
output_tensor = self.conv2(self.snake2(output_tensor))
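# Center-crop the residual input if the convolutions shortened the output, so the skip connection shapes match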
padding = (hidden_state.shape[-1] - output_tensor.shape[-1]) // 2
if padding > 0:
hidden_state = hidden_state[..., padding:-padding]
output_tensor = hidden_state + output_tensor
return output_tensor
class DacEncoderBlock(nn.Module):
"""Encoder block used in DAC encoder."""
def __init__(self, config: DacConfig, stride: int = 1, stride_index: int = 1):
super().__init__()
dimension = config.encoder_hidden_size * 2**stride_index
self.res_unit1 = DacResidualUnit(dimension // 2, dilation=1)
self.res_unit2 = DacResidualUnit(dimension // 2, dilation=3)
self.res_unit3 = DacResidualUnit(dimension // 2, dilation=9)
self.snake1 = Snake1d(dimension // 2)
self.conv1 = nn.Conv1d(
dimension // 2, dimension, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)
)
def forward(self, hidden_state):
hidden_state = self.res_unit1(hidden_state)
hidden_state = self.res_unit2(hidden_state)
hidden_state = self.snake1(self.res_unit3(hidden_state))
hidden_state = self.conv1(hidden_state)
return hidden_state
class DacDecoderBlock(nn.Module):
"""Decoder block used in DAC decoder."""
def __init__(self, config: DacConfig, stride: int = 1, stride_index: int = 1):
super().__init__()
input_dim = config.decoder_hidden_size // 2**stride_index
output_dim = config.decoder_hidden_size // 2 ** (stride_index + 1)
self.snake1 = Snake1d(input_dim)
self.conv_t1 = nn.ConvTranspose1d(
input_dim,
output_dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
)
self.res_unit1 = DacResidualUnit(output_dim, dilation=1)
self.res_unit2 = DacResidualUnit(output_dim, dilation=3)
self.res_unit3 = DacResidualUnit(output_dim, dilation=9)
def forward(self, hidden_state):
hidden_state = self.snake1(hidden_state)
hidden_state = self.conv_t1(hidden_state)
hidden_state = self.res_unit1(hidden_state)
hidden_state = self.res_unit2(hidden_state)
hidden_state = self.res_unit3(hidden_state)
return hidden_state
class DacResidualVectorQuantize(nn.Module):
"""
ResidualVectorQuantize block - Introduced in SoundStream: An End-to-End Neural Audio Codec (https://arxiv.org/abs/2107.03312)
"""
def __init__(self, config: DacConfig):
super().__init__()
n_codebooks = config.n_codebooks
quantizer_dropout = config.quantizer_dropout
self.n_codebooks = n_codebooks
self.quantizers = nn.ModuleList([DacVectorQuantize(config) for i in range(config.n_codebooks)])
self.quantizer_dropout = quantizer_dropout
def forward(self, hidden_state, n_quantizers: int = None):
"""
Quantizes the input tensor using a fixed set of codebooks and returns corresponding codebook vectors.
Args:
hidden_state (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
Input tensor to be quantized.
n_quantizers (`int`, *optional*):
Number of quantizers to use. If specified and `self.quantizer_dropout` is True,
this argument is ignored during training, and a random number of quantizers is used.
Returns:
quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
Quantized continuous representation of input.
audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`):
Codebook indices for each codebook (quantized discrete representation of input).
projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`):
Projected latents (continuous representation of input before quantization).
commitment_loss (`torch.Tensor` of shape `(1)`):
Commitment loss to train the encoder to predict vectors closer to codebook entries.
codebook_loss (`torch.Tensor` of shape `(1)`):
Codebook loss to update the codebook.
"""
quantized_representation = 0
residual = hidden_state
commitment_loss = 0
codebook_loss = 0
audio_codes = []
projected_latents = []
n_quantizers = n_quantizers if n_quantizers is not None else self.n_codebooks
if self.training:
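# Quantizer dropout: for a random subset of the batch, train with a random number of codebooks so the model can also reconstruct from fewer codes at inference time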
n_quantizers = torch.ones((hidden_state.shape[0],)) * self.n_codebooks + 1
dropout = torch.randint(1, self.n_codebooks + 1, (hidden_state.shape[0],))
n_dropout = int(hidden_state.shape[0] * self.quantizer_dropout)
n_quantizers[:n_dropout] = dropout[:n_dropout]
n_quantizers = n_quantizers.to(hidden_state.device)
for i, quantizer in enumerate(self.quantizers):
if self.training is False and i >= n_quantizers:
break
quantized_representation_i, commitment_loss_i, codebook_loss_i, indices_i, projected_latents_i = quantizer(
residual
)
# Create mask to apply quantizer dropout
mask = torch.full((hidden_state.shape[0],), fill_value=i, device=hidden_state.device) < n_quantizers
quantized_representation = quantized_representation + quantized_representation_i * mask[:, None, None]
residual = residual - quantized_representation_i
# Sum losses
commitment_loss += commitment_loss_i * mask
codebook_loss += codebook_loss_i * mask
audio_codes.append(indices_i)
projected_latents.append(projected_latents_i)
audio_codes = torch.stack(audio_codes, dim=1)
projected_latents = torch.cat(projected_latents, dim=1)
return quantized_representation, audio_codes, projected_latents, commitment_loss, codebook_loss
def from_codes(self, audio_codes: torch.Tensor):
"""
Reconstructs the continuous representation from quantized codes.
Args:
audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`):
Quantized discrete representation of input.
Returns:
quantized_representation (`torch.Tensor`):
Quantized continuous representation of input.
projected_latents (`torch.Tensor`):
List of projected latents (continuous representations of input before quantization)
for each codebook.
audio_codes (`torch.Tensor`):
Codebook indices for each codebook.
"""
quantized_representation = 0.0
projected_latents = []
n_codebooks = audio_codes.shape[1]
for i in range(n_codebooks):
projected_latents_i = self.quantizers[i].codebook(audio_codes[:, i, :]).transpose(1, 2)
projected_latents.append(projected_latents_i)
quantized_representation += self.quantizers[i].out_proj(projected_latents_i)
return quantized_representation, torch.cat(projected_latents, dim=1), audio_codes
def from_latents(self, latents: torch.Tensor):
"""Reconstructs the quantized representation from unquantized latents.
Args:
latents (`torch.Tensor` of shape `(batch_size, total_latent_dimension, time_steps)`):
Continuous representation of input after projection.
Returns:
quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
Quantized representation of the full-projected space.
quantized_latents (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
Quantized representation of the latent space (continuous representation before quantization).
"""
quantized_representation = 0
quantized_latents = []
codes = []
codebook_dims_tensor = torch.tensor([0] + [q.codebook_dim for q in self.quantizers])
dims = torch.cumsum(codebook_dims_tensor, dim=0)
n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
for i in range(n_codebooks):
hidden_dim_j, hidden_dim_k = dims[i], dims[i + 1]
quantized_latents_i, codes_i = self.quantizers[i].decode_latents(latents[:, hidden_dim_j:hidden_dim_k, :])
quantized_latents.append(quantized_latents_i)
codes.append(codes_i)
quantized_representation_i = self.quantizers[i].out_proj(quantized_latents_i)
quantized_representation = quantized_representation + quantized_representation_i
return quantized_representation, torch.cat(quantized_latents, dim=1)
class DacDecoder(nn.Module):
"""DAC Decoder"""
def __init__(self, config: DacConfig):
super().__init__()
input_channel = config.hidden_size
channels = config.decoder_hidden_size
strides = config.upsampling_ratios
# Add first conv layer
self.conv1 = nn.Conv1d(input_channel, channels, kernel_size=7, padding=3)
# Add upsampling + MRF blocks
block = []
for stride_index, stride in enumerate(strides):
block += [DacDecoderBlock(config, stride, stride_index)]
self.block = nn.ModuleList(block)
output_dim = config.decoder_hidden_size // 2 ** (stride_index + 1)
self.snake1 = Snake1d(output_dim)
self.conv2 = nn.Conv1d(output_dim, 1, kernel_size=7, padding=3)
self.tanh = nn.Tanh()
def forward(self, hidden_state):
hidden_state = self.conv1(hidden_state)
for layer in self.block:
hidden_state = layer(hidden_state)
hidden_state = self.snake1(hidden_state)
hidden_state = self.conv2(hidden_state)
hidden_state = self.tanh(hidden_state)
return hidden_state
class DacEncoder(nn.Module):
"""DAC Encoder"""
def __init__(self, config: DacConfig):
super().__init__()
strides = config.downsampling_ratios
# Create first convolution
self.conv1 = nn.Conv1d(1, config.encoder_hidden_size, kernel_size=7, padding=3)
self.block = []
# Create EncoderBlocks that double channels as they downsample by `stride`
for stride_index, stride in enumerate(strides):
stride_index = stride_index + 1
self.block += [DacEncoderBlock(config, stride=stride, stride_index=stride_index)]
self.block = nn.ModuleList(self.block)
d_model = config.encoder_hidden_size * 2**stride_index
self.snake1 = Snake1d(d_model)
self.conv2 = nn.Conv1d(d_model, config.hidden_size, kernel_size=3, padding=1)
def forward(self, hidden_state):
hidden_state = self.conv1(hidden_state)
for module in self.block:
hidden_state = module(hidden_state)
hidden_state = self.snake1(hidden_state)
hidden_state = self.conv2(hidden_state)
return hidden_state
class DacPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
"""
config_class = DacConfig
base_model_prefix = "dac"
main_input_name = "input_values"
def _init_weights(self, module):
if isinstance(module, nn.Conv1d):
nn.init.trunc_normal_(module.weight, std=0.02)
nn.init.constant_(module.bias, 0)
def apply_weight_norm(self):
for layer in self.quantizer.quantizers:
nn.utils.weight_norm(layer.in_proj)
nn.utils.weight_norm(layer.out_proj)
nn.utils.weight_norm(self.encoder.conv1)
nn.utils.weight_norm(self.encoder.conv2)
for layer in self.encoder.block:
nn.utils.weight_norm(layer.conv1)
nn.utils.weight_norm(layer.res_unit1.conv1)
nn.utils.weight_norm(layer.res_unit1.conv2)
nn.utils.weight_norm(layer.res_unit2.conv1)
nn.utils.weight_norm(layer.res_unit2.conv2)
nn.utils.weight_norm(layer.res_unit3.conv1)
nn.utils.weight_norm(layer.res_unit3.conv2)
nn.utils.weight_norm(self.decoder.conv1)
nn.utils.weight_norm(self.decoder.conv2)
for layer in self.decoder.block:
nn.utils.weight_norm(layer.conv_t1)
nn.utils.weight_norm(layer.res_unit1.conv1)
nn.utils.weight_norm(layer.res_unit1.conv2)
nn.utils.weight_norm(layer.res_unit2.conv1)
nn.utils.weight_norm(layer.res_unit2.conv2)
nn.utils.weight_norm(layer.res_unit3.conv1)
nn.utils.weight_norm(layer.res_unit3.conv2)
def remove_weight_norm(self):
for layer in self.quantizer.quantizers:
nn.utils.remove_weight_norm(layer.in_proj)
nn.utils.remove_weight_norm(layer.out_proj)
nn.utils.remove_weight_norm(self.encoder.conv1)
nn.utils.remove_weight_norm(self.encoder.conv2)
for layer in self.encoder.block:
nn.utils.remove_weight_norm(layer.conv1)
nn.utils.remove_weight_norm(layer.res_unit1.conv1)
nn.utils.remove_weight_norm(layer.res_unit1.conv2)
nn.utils.remove_weight_norm(layer.res_unit2.conv1)
nn.utils.remove_weight_norm(layer.res_unit2.conv2)
nn.utils.remove_weight_norm(layer.res_unit3.conv1)
nn.utils.remove_weight_norm(layer.res_unit3.conv2)
nn.utils.remove_weight_norm(self.decoder.conv1)
nn.utils.remove_weight_norm(self.decoder.conv2)
for layer in self.decoder.block:
nn.utils.remove_weight_norm(layer.conv_t1)
nn.utils.remove_weight_norm(layer.res_unit1.conv1)
nn.utils.remove_weight_norm(layer.res_unit1.conv2)
nn.utils.remove_weight_norm(layer.res_unit2.conv1)
nn.utils.remove_weight_norm(layer.res_unit2.conv2)
nn.utils.remove_weight_norm(layer.res_unit3.conv1)
nn.utils.remove_weight_norm(layer.res_unit3.conv2)
DAC_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`DacConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
DAC_INPUTS_DOCSTRING = r"""
Args:
input_values (`torch.Tensor` of shape `(batch_size, 1, time_steps)`):
Audio data to encode.
n_quantizers (`int`, *optional*):
Number of quantizers to use. If `None`, all quantizers are used. Default is `None`.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"The DAC (Descript Audio Codec) model.",
DAC_START_DOCSTRING,
)
class DacModel(DacPreTrainedModel):
def __init__(self, config: DacConfig):
super().__init__(config)
self.config = config
self.encoder = DacEncoder(config)
self.decoder = DacDecoder(config)
self.quantizer = DacResidualVectorQuantize(config)
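# Each codebook index is representable with log2(codebook_size) bits, so codebook_size must be a power of two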
self.bits_per_codebook = int(math.log2(self.config.codebook_size))
if 2**self.bits_per_codebook != self.config.codebook_size:
raise ValueError("The codebook_size must be a power of 2.")
# Initialize weights and apply final processing
self.post_init()
@replace_return_docstrings(output_type=DacEncoderOutput, config_class=_CONFIG_FOR_DOC)
def encode(
self,
input_values: torch.Tensor,
n_quantizers: int = None,
return_dict: Optional[bool] = None,
):
"""
Encode given audio data and return quantized latent codes
Args:
input_values (`torch.Tensor` of shape `(batch_size, 1, time_steps)`):
Input audio data to encode.
n_quantizers (`int`, *optional*):
Number of quantizers to use. If `None`, all quantizers are used. Default is `None`.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
Returns:
"""
return_dict = return_dict if return_dict is not None else self.config.return_dict
quantized_representation = self.encoder(input_values)
quantized_representation, audio_codes, projected_latents, commitment_loss, codebook_loss = self.quantizer(
quantized_representation, n_quantizers
)
loss = self.config.commitment_loss_weight * commitment_loss + self.config.codebook_loss_weight * codebook_loss
if not return_dict:
return (loss, quantized_representation, audio_codes, projected_latents)
return DacEncoderOutput(loss, quantized_representation, audio_codes, projected_latents)
@replace_return_docstrings(output_type=DacDecoderOutput, config_class=_CONFIG_FOR_DOC)
def decode(
self,
quantized_representation: Optional[torch.Tensor],
audio_codes: Optional[torch.Tensor] = None,
return_dict: Optional[bool] = None,
):
"""Decode given latent codes and return audio data
Args:
quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
Quantized continuous representation of input.
audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`, *optional*):
The codebook indices for each codebook, representing the quantized discrete
representation of the input. This parameter should be provided if you want
to decode directly from the audio codes (it will overwrite quantized_representation).
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
Returns:
"""
if quantized_representation is None and audio_codes is None:
raise ValueError("Either `quantized_representation` or `audio_codes` must be provided.")
return_dict = return_dict if return_dict is not None else self.config.return_dict
if audio_codes is not None:
quantized_representation = self.quantizer.from_codes(audio_codes)[0]
audio_values = self.decoder(quantized_representation).squeeze(1)
if not return_dict:
return (audio_values,)
return DacDecoderOutput(audio_values)
@add_start_docstrings_to_model_forward(DAC_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=DacOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_values: torch.Tensor,
n_quantizers: int = None,
return_dict: Optional[bool] = None,
):
"""
Returns:
Examples:
```python
>>> from datasets import load_dataset, Audio
>>> from transformers import DacModel, AutoProcessor
>>> librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> model = DacModel.from_pretrained("descript/dac_16khz")
>>> processor = AutoProcessor.from_pretrained("descript/dac_16khz")
>>> librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
>>> audio_sample = librispeech_dummy[-1]["audio"]["array"]
>>> inputs = processor(raw_audio=audio_sample, sampling_rate=processor.sampling_rate, return_tensors="pt")
>>> encoder_outputs = model.encode(inputs["input_values"])
>>> # Get the intermediate audio codes
>>> audio_codes = encoder_outputs.audio_codes
>>> # Reconstruct the audio from its quantized representation
>>> audio_values = model.decode(encoder_outputs.quantized_representation)
>>> # or the equivalent with a forward pass
>>> audio_values = model(inputs["input_values"]).audio_values
```"""
return_dict = return_dict if return_dict is not None else self.config.return_dict
length = input_values.shape[-1]
loss, quantized_representation, audio_codes, projected_latents = self.encode(
input_values, n_quantizers, return_dict=False
)
audio_values = self.decode(quantized_representation, return_dict=False)[0][..., :length]
if not return_dict:
return (loss, audio_values, quantized_representation, audio_codes, projected_latents)
return DacOutput(loss, audio_values, quantized_representation, audio_codes, projected_latents)

View File

@@ -2360,6 +2360,20 @@ class CvtPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
class DacModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class DacPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Data2VecAudioForAudioFrameClassification(metaclass=DummyObject):
_backends = ["torch"]

View File

View File

@@ -0,0 +1,216 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the dac feature extractor."""
import itertools
import random
import unittest
import numpy as np
from transformers import DacFeatureExtractor
from transformers.testing_utils import require_torch
from transformers.utils.import_utils import is_torch_available
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
if is_torch_available():
import torch
global_rng = random.Random()
# Copied from tests.models.whisper.test_feature_extraction_whisper.floats_list
def floats_list(shape, scale=1.0, rng=None, name=None):
"""Creates a random float32 tensor"""
if rng is None:
rng = global_rng
values = []
for batch_idx in range(shape[0]):
values.append([])
for _ in range(shape[1]):
values[-1].append(rng.random() * scale)
return values
@require_torch
# Copied from transformers.tests.encodec.test_feature_extraction_dac.EncodecFeatureExtractionTester with Encodec->Dac
class DacFeatureExtractionTester(unittest.TestCase):
# Ignore copy
def __init__(
self,
parent,
batch_size=7,
min_seq_length=400,
max_seq_length=2000,
feature_size=1,
padding_value=0.0,
sampling_rate=16000,
hop_length=512,
):
self.parent = parent
self.batch_size = batch_size
self.min_seq_length = min_seq_length
self.max_seq_length = max_seq_length
self.hop_length = hop_length
self.seq_length_diff = (self.max_seq_length - self.min_seq_length) // (self.batch_size - 1)
self.feature_size = feature_size
self.padding_value = padding_value
self.sampling_rate = sampling_rate
# Ignore copy
def prepare_feat_extract_dict(self):
return {
"feature_size": self.feature_size,
"padding_value": self.padding_value,
"sampling_rate": self.sampling_rate,
"hop_length": self.hop_length,
}
def prepare_inputs_for_common(self, equal_length=False, numpify=False):
def _flatten(list_of_lists):
return list(itertools.chain(*list_of_lists))
if equal_length:
audio_inputs = floats_list((self.batch_size, self.max_seq_length))
else:
# make sure that inputs increase in size
audio_inputs = [
_flatten(floats_list((x, self.feature_size)))
for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff)
]
if numpify:
audio_inputs = [np.asarray(x) for x in audio_inputs]
return audio_inputs
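The tester above only assembles raw float lists plus the four constructor kwargs returned by `prepare_feat_extract_dict`; the extractor under test is built directly from that dict. A standalone sketch using the same values (mirroring the tester's defaults):

```python
import numpy as np
from transformers import DacFeatureExtractor

feature_extractor = DacFeatureExtractor(
    feature_size=1,      # mono audio
    sampling_rate=16000,
    padding_value=0.0,
    hop_length=512,
)

raw_audio = np.random.rand(800).astype(np.float32)  # a single 800-sample waveform
inputs = feature_extractor(raw_audio, sampling_rate=16000, return_tensors="pt")
# input_values has shape (batch, channels, length padded up to a multiple of hop_length)
print(inputs["input_values"].shape)
```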
@require_torch
# Copied from transformers.tests.encodec.test_feature_extraction_encodec.EncodecFeatureExtractionTest with Encodec->Dac
class DacFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
feature_extraction_class = DacFeatureExtractor
def setUp(self):
self.feat_extract_tester = DacFeatureExtractionTester(self)
def test_call(self):
# Tests that all calls wrap to encode_plus and batch_encode_plus
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
# create three inputs of length 800, 1000, and 1200
audio_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
np_audio_inputs = [np.asarray(audio_input) for audio_input in audio_inputs]
# Test not batched input
encoded_sequences_1 = feat_extract(audio_inputs[0], return_tensors="np").input_values
encoded_sequences_2 = feat_extract(np_audio_inputs[0], return_tensors="np").input_values
self.assertTrue(np.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3))
# Test batched
encoded_sequences_1 = feat_extract(audio_inputs, padding=True, return_tensors="np").input_values
encoded_sequences_2 = feat_extract(np_audio_inputs, padding=True, return_tensors="np").input_values
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
def test_double_precision_pad(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
np_audio_inputs = np.random.rand(100).astype(np.float64)
py_audio_inputs = np_audio_inputs.tolist()
for inputs in [py_audio_inputs, np_audio_inputs]:
np_processed = feature_extractor.pad([{"input_values": inputs}], return_tensors="np")
self.assertTrue(np_processed.input_values.dtype == np.float32)
pt_processed = feature_extractor.pad([{"input_values": inputs}], return_tensors="pt")
self.assertTrue(pt_processed.input_values.dtype == torch.float32)
def _load_datasamples(self, num_samples):
from datasets import load_dataset
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
audio_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
return [x["array"] for x in audio_samples]
def test_integration(self):
# fmt: off
EXPECTED_INPUT_VALUES = torch.tensor(
[ 2.3803711e-03, 2.0751953e-03, 1.9836426e-03, 2.1057129e-03,
1.6174316e-03, 3.0517578e-04, 9.1552734e-05, 3.3569336e-04,
9.7656250e-04, 1.8310547e-03, 2.0141602e-03, 2.1057129e-03,
1.7395020e-03, 4.5776367e-04, -3.9672852e-04, 4.5776367e-04,
1.0070801e-03, 9.1552734e-05, 4.8828125e-04, 1.1596680e-03,
7.3242188e-04, 9.4604492e-04, 1.8005371e-03, 1.8310547e-03,
8.8500977e-04, 4.2724609e-04, 4.8828125e-04, 7.3242188e-04,
1.0986328e-03, 2.1057129e-03]
)
# fmt: on
input_audio = self._load_datasamples(1)
feature_extractor = DacFeatureExtractor()
input_values = feature_extractor(input_audio, return_tensors="pt")["input_values"]
self.assertEqual(input_values.shape, (1, 1, 93696))
self.assertTrue(torch.allclose(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, atol=1e-4))
audio_input_end = torch.tensor(input_audio[0][-30:], dtype=torch.float32)
self.assertTrue(torch.allclose(input_values[0, 0, -46:-16], audio_input_end, atol=1e-4))
# Ignore copy
@unittest.skip("The DAC model doesn't support stereo logic")
def test_integration_stereo(self):
pass
# Ignore copy
def test_truncation_and_padding(self):
input_audio = self._load_datasamples(2)
feature_extractor = DacFeatureExtractor()
# setting padding and truncation together should raise an error
with self.assertRaisesRegex(
ValueError,
"^Both padding and truncation were set. Make sure you only set one.$",
):
truncated_outputs = feature_extractor(
input_audio, padding="max_length", truncation=True, return_tensors="pt"
).input_values
# force truncate to max_length
truncated_outputs = feature_extractor(
input_audio, truncation=True, max_length=48000, return_tensors="pt"
).input_values
self.assertEqual(truncated_outputs.shape, (2, 1, 48128))
# pad:
padded_outputs = feature_extractor(input_audio, padding=True, return_tensors="pt").input_values
self.assertEqual(padded_outputs.shape, (2, 1, 93696))
# force pad to max length
truncated_outputs = feature_extractor(
input_audio, padding="max_length", max_length=100000, return_tensors="pt"
).input_values
self.assertEqual(truncated_outputs.shape, (2, 1, 100352))
# force no pad
with self.assertRaisesRegex(
ValueError,
"^Unable to create tensor, you should probably activate padding with 'padding=True' to have batched tensors with the same length.$",
):
truncated_outputs = feature_extractor(input_audio, padding=False, return_tensors="pt").input_values
truncated_outputs = feature_extractor(input_audio[0], padding=False, return_tensors="pt").input_values
self.assertEqual(truncated_outputs.shape, (1, 1, 93680))
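Every shape asserted in the two tests above follows the same rule: the output length is rounded up to the next multiple of `hop_length` (512). The first LibriSpeech clip has 93680 samples (the `padding=False` case), which rounds up to 183 * 512 = 93696; the 16 extra zeros at the end are also what the `[-46:-16]` slice in `test_integration` above accounts for. Likewise, `max_length=48000` yields 94 * 512 = 48128 and `max_length=100000` yields 196 * 512 = 100352. A small helper capturing the rule these shapes appear to follow:

```python
def padded_length(num_samples: int, hop_length: int = 512) -> int:
    """Round a sample count up to the next multiple of hop_length."""
    return -(-num_samples // hop_length) * hop_length

assert padded_length(93680) == 93696    # default padding in test_integration and padding=True here
assert padded_length(48000) == 48128    # truncation=True with max_length=48000
assert padded_length(100000) == 100352  # padding="max_length" with max_length=100000
```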

View File

@ -0,0 +1,749 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Dac model."""
import inspect
import os
import tempfile
import unittest
from typing import Dict, List, Tuple
import numpy as np
from datasets import Audio, load_dataset
from transformers import AutoProcessor, DacConfig, DacModel
from transformers.testing_utils import is_torch_available, require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
@require_torch
# Copied from transformers.tests.encodec.test_modeling_encodec.EncodecModelTester with Encodec->Dac
class DacModelTester:
# Ignore copy
def __init__(
self,
parent,
batch_size=3,
num_channels=1,
is_training=False,
intermediate_size=1024,
encoder_hidden_size=16,
downsampling_ratios=[2, 4, 4],
decoder_hidden_size=16,
n_codebooks=6,
codebook_size=512,
codebook_dim=4,
quantizer_dropout=0.0,
commitment_loss_weight=0.25,
codebook_loss_weight=1.0,
sample_rate=16000,
):
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.is_training = is_training
self.intermediate_size = intermediate_size
self.sample_rate = sample_rate
self.encoder_hidden_size = encoder_hidden_size
self.downsampling_ratios = downsampling_ratios
self.decoder_hidden_size = decoder_hidden_size
self.n_codebooks = n_codebooks
self.codebook_size = codebook_size
self.codebook_dim = codebook_dim
self.quantizer_dropout = quantizer_dropout
self.commitment_loss_weight = commitment_loss_weight
self.codebook_loss_weight = codebook_loss_weight
def prepare_config_and_inputs(self):
input_values = floats_tensor([self.batch_size, self.num_channels, self.intermediate_size], scale=1.0)
config = self.get_config()
inputs_dict = {"input_values": input_values}
return config, inputs_dict
def prepare_config_and_inputs_for_common(self):
config, inputs_dict = self.prepare_config_and_inputs()
return config, inputs_dict
def prepare_config_and_inputs_for_model_class(self, model_class):
input_values = floats_tensor([self.batch_size, self.num_channels, self.intermediate_size], scale=1.0)
config = self.get_config()
inputs_dict = {"input_values": input_values}
return config, inputs_dict
# Ignore copy
def get_config(self):
return DacConfig(
encoder_hidden_size=self.encoder_hidden_size,
downsampling_ratios=self.downsampling_ratios,
decoder_hidden_size=self.decoder_hidden_size,
n_codebooks=self.n_codebooks,
codebook_size=self.codebook_size,
codebook_dim=self.codebook_dim,
quantizer_dropout=self.quantizer_dropout,
commitment_loss_weight=self.commitment_loss_weight,
codebook_loss_weight=self.codebook_loss_weight,
)
# Ignore copy
def create_and_check_model_forward(self, config, inputs_dict):
model = DacModel(config=config).to(torch_device).eval()
input_values = inputs_dict["input_values"]
result = model(input_values)
self.parent.assertEqual(result.audio_values.shape, (self.batch_size, self.intermediate_size))
@require_torch
# Copied from transformers.tests.encodec.test_modeling_encodec.EncodecModelTest with Encodec->Dac
class DacModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (DacModel,) if is_torch_available() else ()
is_encoder_decoder = True
test_pruning = False
test_headmasking = False
test_resize_embeddings = False
pipeline_model_mapping = {"feature-extraction": DacModel} if is_torch_available() else {}
input_name = "input_values"
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
# model does not have attention and does not support returning hidden states
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if "output_attentions" in inputs_dict:
inputs_dict.pop("output_attentions")
if "output_hidden_states" in inputs_dict:
inputs_dict.pop("output_hidden_states")
return inputs_dict
def setUp(self):
self.model_tester = DacModelTester(self)
self.config_tester = ConfigTester(
self, config_class=DacConfig, hidden_size=37, common_properties=[], has_text_modality=False
)
def test_config(self):
self.config_tester.run_common_tests()
def test_model_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_forward(*config_and_inputs)
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
# Ignore copy
expected_arg_names = ["input_values", "n_quantizers", "return_dict"]
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
@unittest.skip("The DacModel is not transformers based, thus it does not have `inputs_embeds` logic")
def test_inputs_embeds(self):
pass
@unittest.skip("The DacModel is not transformers based, thus it does not have `inputs_embeds` logic")
def test_model_get_set_embeddings(self):
pass
@unittest.skip("The DacModel is not transformers based, thus it does not have the usual `attention` logic")
def test_retain_grad_hidden_states_attentions(self):
pass
@unittest.skip("The DacModel is not transformers based, thus it does not have the usual `attention` logic")
def test_torchscript_output_attentions(self):
pass
@unittest.skip("The DacModel is not transformers based, thus it does not have the usual `hidden_states` logic")
def test_torchscript_output_hidden_state(self):
pass
def _create_and_check_torchscript(self, config, inputs_dict):
if not self.test_torchscript:
return
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
configs_no_init.torchscript = True
configs_no_init.return_dict = False
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
main_input_name = model_class.main_input_name
try:
main_input = inputs[main_input_name]
model(main_input)
traced_model = torch.jit.trace(model, main_input)
except RuntimeError:
self.fail("Couldn't trace module.")
with tempfile.TemporaryDirectory() as tmp_dir_name:
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
try:
torch.jit.save(traced_model, pt_file_name)
except Exception:
self.fail("Couldn't save module.")
try:
loaded_model = torch.jit.load(pt_file_name)
except Exception:
self.fail("Couldn't load module.")
model.to(torch_device)
model.eval()
loaded_model.to(torch_device)
loaded_model.eval()
model_state_dict = model.state_dict()
loaded_model_state_dict = loaded_model.state_dict()
non_persistent_buffers = {}
for key in loaded_model_state_dict.keys():
if key not in model_state_dict.keys():
non_persistent_buffers[key] = loaded_model_state_dict[key]
loaded_model_state_dict = {
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
}
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
model_buffers = list(model.buffers())
for non_persistent_buffer in non_persistent_buffers.values():
found_buffer = False
for i, model_buffer in enumerate(model_buffers):
if torch.equal(non_persistent_buffer, model_buffer):
found_buffer = True
break
self.assertTrue(found_buffer)
model_buffers.pop(i)
models_equal = True
for layer_name, p1 in model_state_dict.items():
if layer_name in loaded_model_state_dict:
p2 = loaded_model_state_dict[layer_name]
if p1.data.ne(p2.data).sum() > 0:
models_equal = False
self.assertTrue(models_equal)
# Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
# (Even with this call, there are still memory leak by ~0.04MB)
self.clear_torch_jit_class_registry()
@unittest.skip("The DacModel is not transformers based, thus it does not have the usual `attention` logic")
def test_attention_outputs(self):
pass
@unittest.skip("The DacModel is not transformers based, thus it does not have the usual `hidden_states` logic")
def test_hidden_states_output(self):
pass
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage(self):
pass
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage_checkpoints(self):
pass
@unittest.skip("No support for low_cpu_mem_usage=True.")
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
pass
def test_determinism(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def check_determinism(first, second):
# outputs are not tensors but lists (since sequences don't all have the same frame_length)
out_1 = first.cpu().numpy()
out_2 = second.cpu().numpy()
out_1 = out_1[~np.isnan(out_1)]
out_2 = out_2[~np.isnan(out_2)]
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
first = model(**self._prepare_for_class(inputs_dict, model_class))[0]
second = model(**self._prepare_for_class(inputs_dict, model_class))[0]
if isinstance(first, tuple) and isinstance(second, tuple):
for tensor1, tensor2 in zip(first, second):
check_determinism(tensor1, tensor2)
else:
check_determinism(first, second)
def test_model_outputs_equivalence(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def set_nan_tensor_to_zero(t):
t[t != t] = 0
return t
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
with torch.no_grad():
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif isinstance(tuple_object, Dict):
for tuple_iterable_value, dict_iterable_value in zip(
tuple_object.values(), dict_object.values()
):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
self.assertTrue(
torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
),
msg=(
"Tuple and dict output are not equal. Difference:"
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object).any()}. Dict has"
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object).any()}."
),
)
recursive_check(tuple_output, dict_output)
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs)
# Ignore copy
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
uniform_init_params = ["conv", "in_proj", "out_proj", "codebook"]
if param.requires_grad:
if any(x in name for x in uniform_init_params):
self.assertTrue(
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
def test_identity_shortcut(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
config.use_conv_shortcut = False
self.model_tester.create_and_check_model_forward(config, inputs_dict)
def normalize(arr):
norm = np.linalg.norm(arr)
normalized_arr = arr / norm
return normalized_arr
def compute_rmse(arr1, arr2):
arr1_normalized = normalize(arr1)
arr2_normalized = normalize(arr2)
return np.sqrt(((arr1_normalized - arr2_normalized) ** 2).mean())
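`compute_rmse` normalises both arrays to unit norm before taking the root-mean-square error, so the thresholds used in the integration tests below (roughly 0.001 to 0.004) measure scale-invariant reconstruction error rather than absolute amplitude differences. A quick illustration of that property, using the helper defined just above on hypothetical signals:

```python
import numpy as np

rng = np.random.default_rng(0)
reference = rng.standard_normal(16000).astype(np.float32)

# identical and merely rescaled signals both score (near) zero thanks to the normalisation
assert compute_rmse(reference, reference) == 0.0
assert np.isclose(compute_rmse(reference, 2.0 * reference), 0.0)

# a lightly corrupted "reconstruction" scores small but non-zero
noisy = reference + 0.01 * rng.standard_normal(16000).astype(np.float32)
print(compute_rmse(reference, noisy))
```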
@slow
@require_torch
class DacIntegrationTest(unittest.TestCase):
def test_integration_16khz(self):
expected_rmse = 0.004
expected_encoder_sums_dict = {
"loss": 24.8596,
"quantized_representation": -0.0745,
"audio_codes": 504.0948,
"projected_latents": 0.0682,
}
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
model_name = "dac_16khz"
model_id = "descript/{}".format(model_name)
model = DacModel.from_pretrained(model_id, force_download=True).to(torch_device).eval()
processor = AutoProcessor.from_pretrained(model_id)
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
audio_sample = librispeech_dummy[0]["audio"]["array"]
inputs = processor(
raw_audio=audio_sample,
sampling_rate=processor.sampling_rate,
return_tensors="pt",
).to(torch_device)
with torch.no_grad():
encoder_outputs = model.encode(inputs["input_values"])
expected_encoder_sums = torch.tensor(list(expected_encoder_sums_dict.values()), dtype=torch.float32)
encoder_outputs_mean = torch.tensor([v.float().mean().cpu().item() for v in encoder_outputs.to_tuple()])
# make sure the encoded audio outputs are correct
self.assertTrue(torch.allclose(encoder_outputs_mean, expected_encoder_sums, atol=1e-3))
_, quantized_representation, _, _ = encoder_outputs.to_tuple()
input_values_dec = model.decode(quantized_representation)[0]
input_values_enc_dec = model(inputs["input_values"])[1]
# make sure forward and decode give the same result
self.assertTrue(torch.allclose(input_values_dec, input_values_enc_dec, atol=1e-3))
arr = inputs["input_values"][0].cpu().numpy()
arr_enc_dec = input_values_enc_dec[0].cpu().numpy()
max_length = min(arr_enc_dec.shape[-1], arr.shape[-1])
arr_cut = arr[0, :max_length].copy()
arr_enc_dec_cut = arr_enc_dec[:max_length].copy()
# make sure audios are more or less equal
rmse = compute_rmse(arr_cut, arr_enc_dec_cut)
self.assertTrue(rmse < expected_rmse)
def test_integration_24khz(self):
expected_rmse = 0.0039
expected_encoder_output_dict = {
"quantized_representation": torch.tensor([0.9807, 2.8212, 5.2514, 2.7241, 1.0426]),
"audio_codes": torch.tensor([919, 919, 234, 777, 234]),
"projected_latents": torch.tensor([-4.7822, -5.0046, -4.5574, -5.0363, -5.4271]),
}
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
model_name = "dac_24khz"
model_id = "descript/{}".format(model_name)
model = DacModel.from_pretrained(model_id, force_download=True).to(torch_device).eval()
processor = AutoProcessor.from_pretrained(model_id)
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
audio_sample = librispeech_dummy[0]["audio"]["array"]
inputs = processor(
raw_audio=audio_sample,
sampling_rate=processor.sampling_rate,
return_tensors="pt",
).to(torch_device)
with torch.no_grad():
encoder_outputs = model.encode(inputs["input_values"])
expected_quantized_representation = encoder_outputs["quantized_representation"][0, 0, :5].cpu()
expected_audio_codes = encoder_outputs["audio_codes"][0, 0, :5].cpu()
expected_projected_latents = encoder_outputs["projected_latents"][0, 0, :5].cpu()
# make sure the values are correct for the audio slices
self.assertTrue(
torch.allclose(
expected_quantized_representation,
expected_encoder_output_dict["quantized_representation"],
atol=1e-3,
)
)
self.assertTrue(
torch.allclose(expected_audio_codes, expected_encoder_output_dict["audio_codes"], atol=1e-3)
)
self.assertTrue(
torch.allclose(
expected_projected_latents, expected_encoder_output_dict["projected_latents"], atol=1e-3
)
)
_, quantized_representation, _, _ = encoder_outputs.to_tuple()
input_values_dec = model.decode(quantized_representation)[0]
input_values_enc_dec = model(inputs["input_values"])[1]
# make sure forward and decode give the same result
self.assertTrue(torch.allclose(input_values_dec, input_values_enc_dec, atol=1e-3))
arr = inputs["input_values"][0].cpu().numpy()
arr_enc_dec = input_values_enc_dec[0].cpu().numpy()
max_length = min(arr_enc_dec.shape[-1], arr.shape[-1])
arr_cut = arr[0, :max_length].copy()
arr_enc_dec_cut = arr_enc_dec[:max_length].copy()
# make sure audios are more or less equal
rmse = compute_rmse(arr_cut, arr_enc_dec_cut)
self.assertTrue(rmse < expected_rmse)
def test_integration_44khz(self):
expected_rmse = 0.002
expected_encoder_sums_dict = {
"loss": 34.3612,
"quantized_representation": 0.0078,
"audio_codes": 509.6812,
"projected_latents": -0.1054,
}
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
model_name = "dac_44khz"
model_id = "descript/{}".format(model_name)
model = DacModel.from_pretrained(model_id).to(torch_device).eval()
processor = AutoProcessor.from_pretrained(model_id)
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
audio_sample = librispeech_dummy[0]["audio"]["array"]
inputs = processor(
raw_audio=audio_sample,
sampling_rate=processor.sampling_rate,
return_tensors="pt",
).to(torch_device)
with torch.no_grad():
encoder_outputs = model.encode(inputs["input_values"])
expected_encoder_sums = torch.tensor(list(expected_encoder_sums_dict.values()), dtype=torch.float32)
encoder_outputs_mean = torch.tensor([v.float().mean().cpu().item() for v in encoder_outputs.to_tuple()])
# make sure the encoded audio outputs are correct
self.assertTrue(torch.allclose(encoder_outputs_mean, expected_encoder_sums, atol=1e-3))
_, quantized_representation, _, _ = encoder_outputs.to_tuple()
input_values_dec = model.decode(quantized_representation)[0]
input_values_enc_dec = model(inputs["input_values"])[1]
# make sure forward and decode give the same result
self.assertTrue(torch.allclose(input_values_dec, input_values_enc_dec, atol=1e-3))
arr = inputs["input_values"][0].cpu().numpy()
arr_enc_dec = input_values_enc_dec[0].cpu().numpy()
max_length = min(arr_enc_dec.shape[-1], arr.shape[-1])
arr_cut = arr[0, :max_length].copy()
arr_enc_dec_cut = arr_enc_dec[:max_length].copy()
# make sure audios are more or less equal
rmse = compute_rmse(arr_cut, arr_enc_dec_cut)
self.assertTrue(rmse < expected_rmse)
def test_integration_batch_16khz(self):
expected_rmse = 0.002
expected_encoder_sums_dict = {
"loss": 20.3913,
"quantized_representation": -0.0538,
"audio_codes": 487.8470,
"projected_latents": 0.0237,
}
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
model_name = "dac_16khz"
model_id = "descript/{}".format(model_name)
model = DacModel.from_pretrained(model_id).to(torch_device)
processor = AutoProcessor.from_pretrained(model_id)
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
audio_samples = [np.array([audio_sample["array"]])[0] for audio_sample in librispeech_dummy[-2:]["audio"]]
inputs = processor(
raw_audio=audio_samples,
sampling_rate=processor.sampling_rate,
truncation=False,
return_tensors="pt",
).to(torch_device)
with torch.no_grad():
encoder_outputs = model.encode(inputs["input_values"])
expected_encoder_sums = torch.tensor(list(expected_encoder_sums_dict.values()), dtype=torch.float32)
encoder_outputs_mean = torch.tensor([v.float().mean().item() for v in encoder_outputs.to_tuple()])
# make sure the encoded audio outputs are correct
self.assertTrue(torch.allclose(encoder_outputs_mean, expected_encoder_sums, atol=1e-3))
_, quantized_representation, _, _ = encoder_outputs.to_tuple()
input_values_dec = model.decode(quantized_representation)[0]
input_values_enc_dec = model(inputs["input_values"])[1]
# make sure forward and decode give the same result
self.assertTrue(torch.allclose(input_values_dec, input_values_enc_dec, atol=1e-3))
arr = inputs["input_values"].cpu().numpy()
arr_enc_dec = input_values_enc_dec.cpu().numpy()
max_length = min(arr_enc_dec.shape[-1], arr.shape[-1])
arr_cut = arr[:, 0, :max_length].copy()
arr_enc_dec_cut = arr_enc_dec[:, :max_length].copy()
# make sure audios are more or less equal
rmse = compute_rmse(arr_cut, arr_enc_dec_cut)
self.assertTrue(rmse < expected_rmse)
def test_integration_batch_24khz(self):
expected_rmse = 0.002
expected_encoder_sums_dict = {
"loss": 24.2309,
"quantized_representation": 0.0520,
"audio_codes": 510.2700,
"projected_latents": -0.0076,
}
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
model_name = "dac_24khz"
model_id = "descript/{}".format(model_name)
model = DacModel.from_pretrained(model_id).to(torch_device)
processor = AutoProcessor.from_pretrained(model_id)
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
audio_samples = [np.array([audio_sample["array"]])[0] for audio_sample in librispeech_dummy[-2:]["audio"]]
inputs = processor(
raw_audio=audio_samples,
sampling_rate=processor.sampling_rate,
truncation=False,
return_tensors="pt",
).to(torch_device)
with torch.no_grad():
encoder_outputs = model.encode(inputs["input_values"])
expected_encoder_sums = torch.tensor(list(expected_encoder_sums_dict.values()), dtype=torch.float32)
encoder_outputs_mean = torch.tensor([v.float().mean().cpu().item() for v in encoder_outputs.to_tuple()])
# make sure the encoded audio outputs are correct
self.assertTrue(torch.allclose(encoder_outputs_mean, expected_encoder_sums, atol=1e-3))
_, quantized_representation, _, _ = encoder_outputs.to_tuple()
input_values_dec = model.decode(quantized_representation)[0]
input_values_enc_dec = model(inputs["input_values"])[1]
# make sure forward and decode give the same result
self.assertTrue(torch.allclose(input_values_dec, input_values_enc_dec, atol=1e-3))
arr = inputs["input_values"].cpu().numpy()
arr_enc_dec = input_values_enc_dec.cpu().numpy()
max_length = min(arr_enc_dec.shape[-1], arr.shape[-1])
arr_cut = arr[:, 0, :max_length].copy()
arr_enc_dec_cut = arr_enc_dec[:, :max_length].copy()
# make sure audios are more or less equal
rmse = compute_rmse(arr_cut, arr_enc_dec_cut)
self.assertTrue(rmse < expected_rmse)
def test_integration_batch_44khz(self):
expected_rmse = 0.001
expected_encoder_sums_dict = {
"loss": 25.9233,
"quantized_representation": 0.0013,
"audio_codes": 528.5620,
"projected_latents": -0.1194,
}
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
model_name = "dac_44khz"
model_id = "descript/{}".format(model_name)
model = DacModel.from_pretrained(model_id).to(torch_device)
processor = AutoProcessor.from_pretrained(model_id)
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
audio_samples = [np.array([audio_sample["array"]])[0] for audio_sample in librispeech_dummy[-2:]["audio"]]
inputs = processor(
raw_audio=audio_samples,
sampling_rate=processor.sampling_rate,
truncation=False,
return_tensors="pt",
).to(torch_device)
with torch.no_grad():
encoder_outputs = model.encode(inputs["input_values"])
expected_encoder_sums = torch.tensor(list(expected_encoder_sums_dict.values()), dtype=torch.float32)
encoder_outputs_mean = torch.tensor([v.float().mean().cpu().item() for v in encoder_outputs.to_tuple()])
# make sure the encoded audio outputs are correct
self.assertTrue(torch.allclose(encoder_outputs_mean, expected_encoder_sums, atol=1e-3))
_, quantized_representation, _, _ = encoder_outputs.to_tuple()
input_values_dec = model.decode(quantized_representation)[0]
input_values_enc_dec = model(inputs["input_values"])[1]
# make sure forward and decode give the same result
self.assertTrue(torch.allclose(input_values_dec, input_values_enc_dec, atol=1e-3))
arr = inputs["input_values"].cpu().numpy()
arr_enc_dec = input_values_enc_dec.cpu().numpy()
max_length = min(arr_enc_dec.shape[-1], arr.shape[-1])
arr_cut = arr[:, 0, :max_length].copy()
arr_enc_dec_cut = arr_enc_dec[:, :max_length].copy()
# make sure audios are more or less equal
rmse = compute_rmse(arr_cut, arr_enc_dec_cut)
self.assertTrue(rmse < expected_rmse)