Add Dia model (#38405)

* add dia model

* add tokenizer files

* cleanup some stuff

* brute copy paste code

* rough cleanup of the modeling code

* nuke some stuff

* more nuking

* more cleanups

* updates

* add multiLayerEmbedding vectorization

* nits

* more modeling simplifications

* updates

* update rope

* update rope

* just fixup

* update configuration files

* more cleanup!

* default config values

* update

* forgotten comma

* another comma!

* update, more cleanups

* just more nits

* more config cleanups

* time for the encoder

* fix

* small nit

* nits

* n

* refacto a bit

* cleanup

* update cv script

* fix last issues

* fix last nits

* styling

* small fixes

* just run 1 generation

* fixes

* nits

* fix conversion

* fix

* more fixes

* full generate

* ouf!

* fixes!

* updates

* fix

* fix conversion

* fixup

* nits

* delete wrong test

* update

* update

* test tokenization

* let's start changing things bit by bit - fix encoder step

* removing custom generation, moving to GenerationMixin

* add encoder decoder attention masks for generation

* mask changes, correctness checked against ad29837 in dia repo

* refactor a bit already --> next cache

* too important not to push :)

* minimal cleanup + more todos

* make main overwrite modeling utils

* add cfg filter & eos filter

* add eos countdown & delay pattern

* update eos countdown

* add max step eos countdown

* fix tests

* fix some things

* fix generation with testing

* move cfg & eos stuff to logits processor

* make RepetitionPenaltyLogitsProcessor flexible

- can accept 3D scores like (batch_size, channel, vocab)

* fix input_ids concatenation dimension in GenerationMixin for flexibility

* Add DiaHangoverLogitsProcessor and DiaExponentialDecayLengthPenalty classes; refactor logits processing in DiaForConditionalGeneration to utilize new configurations and improve flexibility.

* Add stopping criteria

* refactor

* move delay pattern from processor to modeling like musicgen.

- add docs
- change eos countdown to eos delay pattern

* fix processor & fix tests

* refactor types

* refactor imports

* format code

* fix docstring to pass ci

* add docstring to DiaConfig & add DiaModel to test

* fix docstring

* add docstring

* fix some bugs

* check

* porting / merging results from other branch - IMPORTANT: it very likely breaks generation, the goal is to have a proper forward path first

* experimental testing of left padding for first channel

* whoops

* Fix merge to make generation work

* fix cfg filter

* add position ids

* add todos, break things

* revert changes to generation --> we will force 2d but go 3d on custom stuff

* refactor a lot, change prepare decoder ids to work with left padding (needs testing), add todos

* some first fixes to get to 10. in generation

* some more generation fixes / adjustment

* style + rope fixes

* move cfg out, simplify a few things, more todos

* nit

* start working on custom logit processors

* nit

* quick fixes

* cfg top k

* more refactor of logits processing, needs a decision if gen config gets the new attributes or if we move it to config or similar

* lets keep changes to core code minimal, only eos scaling is questionable atm

* simpler eos delay logits processor

* that was for debugging :D

* proof of concept rope

* small fix on device mismatch

* cfg fixes + delay logits max len

* transformers rope

* modular dia

* more cleanup

* keep modeling consistently 3D, generate handles 2D internally

* decoder starts with bos if nothing

* post processing prototype

* style

* lol

* force sample / greedy + fixes on padding

* style

* fixup tokenization

* nits

* revert

* start working on dia tests

* fix a lot of tests

* more test fixes

* nit

* more test fixes + some features to simplify code more

* more cleanup

* forgot that one

* autodocs

* small consistency fixes

* fix regression

* small fixes

* dia feature extraction

* docs

* wip processor

* fix processor order

* processing goes brrr

* transpose before

* small fix

* fix major bug but needs now a closer look into the custom processors esp cfg

* small thing on logits

* nits

* simplify indices and shifts

* add simpler version of padding tests back (temporarily)

* add logit processor tests

* starting tests on processor

* fix mask application during generation

* some fixes on the weights conversion

* style + fixup logits order

* simplify conversion

* nit

* remove padding tests

* nits on modeling

* hmm

* fix tests

* trigger

* probably gonna be reverted, just a quick design around audio tokenizer

* fixup typing

* post merge + more typing

* initial design for audio tokenizer

* more design changes

* nit

* more processor tests and style related things

* add to init

* protect import

* not sure why tbh

* add another protect

* more fixes

* wow

* it aint stopping :D

* another missed type issue

* ...

* change design around audio tokenizer to prioritize init and go for auto - in regards to the review

* change to new causal mask function + docstrings

* change ternary

* docs

* remove todo, i dont think its essential tbh

* remove pipeline as current pipelines do not fit in the current scheme, same as csm

* closer to wrapping up the processor

* text to audio, just for demo purposes (will likely be reverted)

* check if it's this

* save audio function

* ensure no grad

* fixes on prefixed audio, hop length is used via preprocess dac, device fixes

* integration tests (tested locally on a100) + some processor utils / fixes

* style

* nits

* another round of smaller things

* docs + some fixes (generate one might be big)

* mystery solved

* small fix on conversion

* add abstract audio tokenizer, change init check to abstract class

* nits

* update docs + fix some processing :D

* change inheritance scheme for audio tokenizer

* delete dead / unnecessary code in copied generate loop

* last nits on new pipeline behavior (+ todo on tests) + style

* trigger

---------

Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Vasqu <antonprogamer@gmail.com>
Jaeyong Sung 2025-06-26 20:04:23 +09:00 committed by GitHub
parent 5995cfa0a0
commit 583db52bc6
34 changed files with 5733 additions and 29 deletions


@ -839,6 +839,8 @@
title: CSM
- local: model_doc/dac
title: dac
- local: model_doc/dia
title: Dia
- local: model_doc/encodec
title: EnCodec
- local: model_doc/fastspeech2_conformer


@ -350,6 +350,10 @@ The following auto classes are available for the following audio tasks.
[[autodoc]] AutoModelForTextToWaveform
### AutoModelForAudioTokenization
[[autodoc]] AutoModelForAudioTokenization
## Multimodal
The following auto classes are available for the following multimodal tasks.


@ -0,0 +1,162 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Dia
<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>
</div>
## Overview
Dia is an open-source text-to-speech (TTS) model (1.6B parameters) developed by [Nari Labs](https://huggingface.co/nari-labs).
It can generate highly realistic dialogue from a transcript, including nonverbal communication such as laughter and coughing.
Furthermore, emotion and tone control is also possible via audio conditioning (voice cloning).
**Model Architecture:**
Dia is an encoder-decoder transformer based on the original transformer architecture, extended with more modern features
such as rotary position embeddings (RoPE). Its text portion (encoder) uses a byte tokenizer, while its audio portion (decoder)
relies on a pretrained codec model, [DAC](./dac.md), which encodes speech into discrete codebook tokens and decodes them back
into audio.
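For intuition, the codec round trip can also be run in isolation. This is a minimal sketch (the `descript/dac_44khz` checkpoint id is taken from the conversion script; regular Dia usage does not require this step, since `DiaProcessor` wraps the codec internally):
```python
import numpy as np
from transformers import AutoFeatureExtractor, AutoModelForAudioTokenization

feature_extractor = AutoFeatureExtractor.from_pretrained("descript/dac_44khz")
audio_tokenizer = AutoModelForAudioTokenization.from_pretrained("descript/dac_44khz")

waveform = np.zeros(44100, dtype=np.float32)  # 1 second of dummy mono audio
inputs = feature_extractor(waveform, sampling_rate=44100, return_tensors="pt")

audio_codes = audio_tokenizer.encode(inputs["input_values"]).audio_codes  # discrete codebook tokens
audio = audio_tokenizer.decode(audio_codes=audio_codes).audio_values      # back to a raw waveform
```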
## Usage Tips
### Generation with Text
```python
from transformers import AutoProcessor, DiaForConditionalGeneration
torch_device = "cuda"
model_checkpoint = "buttercrab/dia-v1-1.6b"
text = ["[S1] Dia is an open weights text to dialogue model."]
processor = AutoProcessor.from_pretrained(model_checkpoint)
inputs = processor(text=text, padding=True, return_tensors="pt").to(torch_device)
model = DiaForConditionalGeneration.from_pretrained(model_checkpoint).to(torch_device)
outputs = model.generate(**inputs, max_new_tokens=256)  # corresponds to around 2s of audio
# save audio to a file
outputs = processor.batch_decode(outputs)
processor.save_audio(outputs, "example.wav")
```
### Generation with Text and Audio (Voice Cloning)
```python
from datasets import load_dataset, Audio
from transformers import AutoProcessor, DiaForConditionalGeneration
torch_device = "cuda"
model_checkpoint = "buttercrab/dia-v1-1.6b"
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
ds = ds.cast_column("audio", Audio(sampling_rate=44100))
audio = ds[-1]["audio"]["array"]
# text is a transcript of the audio + additional text you want as new audio
text = ["[S1] I know. It's going to save me a lot of money, I hope. [S2] I sure hope so for you."]
processor = AutoProcessor.from_pretrained(model_checkpoint)
inputs = processor(text=text, audio=audio, padding=True, return_tensors="pt").to(torch_device)
prompt_len = processor.get_audio_prompt_len(inputs["decoder_attention_mask"])
model = DiaForConditionalGeneration.from_pretrained(model_checkpoint).to(torch_device)
outputs = model.generate(**inputs, max_new_tokens=256)  # corresponds to around 2s of audio
# retrieve actually generated audio and save to a file
outputs = processor.batch_decode(outputs, audio_prompt_len=prompt_len)
processor.save_audio(outputs, "example_with_audio.wav")
```
### Training
```python
from datasets import load_dataset, Audio
from transformers import AutoProcessor, DiaForConditionalGeneration
torch_device = "cuda"
model_checkpoint = "buttercrab/dia-v1-1.6b"
ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
ds = ds.cast_column("audio", Audio(sampling_rate=44100))
audio = ds[-1]["audio"]["array"]
# text is a transcript of the audio
text = ["[S1] I know. It's going to save me a lot of money, I hope."]
processor = AutoProcessor.from_pretrained(model_checkpoint)
inputs = processor(
text=text,
audio=audio,
generation=False,
output_labels=True,
padding=True,
return_tensors="pt"
).to(torch_device)
model = DiaForConditionalGeneration.from_pretrained(model_checkpoint).to(torch_device)
out = model(**inputs)
out.loss.backward()
```
This model was contributed by [Jaeyong Sung](https://huggingface.co/buttercrab), [Arthur Zucker](https://huggingface.co/ArthurZ),
and [Anton Vlasjuk](https://huggingface.co/AntonV). The original code can be found [here](https://github.com/nari-labs/dia/).
## DiaConfig
[[autodoc]] DiaConfig
## DiaDecoderConfig
[[autodoc]] DiaDecoderConfig
## DiaEncoderConfig
[[autodoc]] DiaEncoderConfig
## DiaTokenizer
[[autodoc]] DiaTokenizer
- __call__
## DiaFeatureExtractor
[[autodoc]] DiaFeatureExtractor
- __call__
## DiaProcessor
[[autodoc]] DiaProcessor
- __call__
- batch_decode
- decode
## DiaModel
[[autodoc]] DiaModel
- forward
## DiaForConditionalGeneration
[[autodoc]] DiaForConditionalGeneration
- forward
- generate


@ -271,7 +271,6 @@ class PretrainedConfig(PushToHubMixin):
self.pad_token_id = kwargs.pop("pad_token_id", None)
self.eos_token_id = kwargs.pop("eos_token_id", None)
self.sep_token_id = kwargs.pop("sep_token_id", None)
self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
# task specific arguments


@ -2975,3 +2975,224 @@ class SynthIDTextWatermarkLogitsProcessor(LogitsProcessor):
The expected mean g-value for watermarked text.
"""
return coinflip_prob + coinflip_prob * (1 - coinflip_prob) * (1 - (1 / vocab_size))
class DiaClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] for classifier free guidance (CFG). Similar to the original
`ClassifierFreeGuidanceLogitsProcessor` with some modifications on the overall
calculation, e.g. conditioned logits centered, and an additional top k selection
option.
<Tip warning={true}>
This logits processor is exclusively compatible with
[Dia](https://huggingface.co/docs/transformers/main/en/model_doc/dia)
</Tip>
Args:
guidance_scale (float):
The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
Higher guidance scale encourages the model to generate samples that are more closely linked to the input
prompt, usually at the expense of poorer quality.
guidance_top_k (int, *optional*):
The number of highest probability vocabulary tokens to keep for top-k filtering. The top-k indices are
computed on the combined CFG output, but the returned logits are taken from the conditioned output only.
"""
def __init__(self, guidance_scale: float, guidance_top_k: Optional[int] = None):
if guidance_scale > 1:
self.guidance_scale = guidance_scale
else:
raise ValueError(
"Require guidance scale >1 to use the classifier free guidance processor, got guidance scale "
f"{guidance_scale}."
)
self.guidance_top_k = guidance_top_k
if self.guidance_top_k is not None and self.guidance_top_k < 1:
raise ValueError(
f"`guidance_top_k` has to be a strictly positive integer if given, but is {self.guidance_top_k}"
)
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# simple check to make sure we have compatible batch sizes between our
# logits scores (cond + uncond) and input ids (cond only)
if scores.shape[0] != 2 * input_ids.shape[0]:
raise ValueError(
f"Logits should have twice the batch size of the input ids, the first half of batches corresponding to "
f"the conditional inputs, and the second half of batches corresponding to the unconditional inputs. Got "
f"batch size {scores.shape[0]} for the logits and {input_ids.shape[0]} for the input ids."
)
# Base CFG with center on cond_logits
unguided_bsz = scores.shape[0] // 2
cond_logits, uncond_logits = scores.split(unguided_bsz, dim=0)
scores_processed = cond_logits + (cond_logits - uncond_logits) * self.guidance_scale
# Optional CFG top k filtering
if self.guidance_top_k is not None:
# Create top k based on the combined CFG output
_, top_k_indices = torch.topk(scores_processed, k=self.guidance_top_k, dim=-1)
top_k_mask = torch.ones_like(scores_processed, dtype=torch.bool)
top_k_mask = top_k_mask.scatter(dim=-1, index=top_k_indices, value=False)
# Only return conditioned logits with top k
scores_processed = cond_logits.masked_fill(top_k_mask, -float("inf"))
return scores_processed
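As a quick usage sketch of this processor in isolation (batch size, vocabulary size and the guidance values below are illustrative assumptions, not Dia's actual generation setup):
```python
import torch
from transformers.generation.logits_process import DiaClassifierFreeGuidanceLogitsProcessor

processor = DiaClassifierFreeGuidanceLogitsProcessor(guidance_scale=3.0, guidance_top_k=45)

input_ids = torch.zeros((2, 1), dtype=torch.long)  # conditional half of the batch only
scores = torch.randn(4, 1028)                      # [cond; uncond] stacked along the batch dimension
guided = processor(input_ids, scores)              # (2, 1028): conditioned logits, restricted to the top-45 CFG positions
```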
class DiaEOSChannelFilterLogitsProcessor(LogitsProcessor):
r"""Specialized processor that ensures certain properties around EOS sampling:
1. Only channel 0 can generate EOS
2. If EOS has the highest logit in channel 0, it will be the only candidate
3. If EOS does not have the highest logit in channel 0, it will be suppressed
Conditions 2. and 3. are especially important in contexts where we allow sampling, to guarantee that the
respective tokens are (not) sampled.
<Tip warning={true}>
This logits processor is exclusively compatible with
[Dia](https://huggingface.co/docs/transformers/en/model_doc/dia).
</Tip>
Args:
num_channels (`int`):
Number of audio codebooks. Simplifies access to the first channel on the logits.
eos_token_id (`int`):
The id of *end-of-sequence* token.
"""
def __init__(self, num_channels: int, eos_token_id: int):
if num_channels < 1:
raise ValueError(f"Audio codebooks need at least one channel, but found {num_channels} channels.")
if eos_token_id < 1:
raise ValueError(f"Expected `eos_token_id` to be a positive integer, found {eos_token_id} instead.")
self.num_channels = num_channels
self.eos_id = eos_token_id
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Reshape for easier channel indexing [B, C, V]
scores = scores.reshape(-1, self.num_channels, scores.shape[-1])
# EOS filter
# 1. Condition: Only the first channel can generate the EOS token
# Side condition of disabling generation of special tokens (e.g. audio pad, bos, ...)
# (Assumes them to be greater than audio eos token position)
scores[:, 1:, self.eos_id :] = torch.full_like(
scores[:, 1:, self.eos_id :],
fill_value=-float("inf"),
)
scores[:, 0, self.eos_id + 1 :] = torch.full_like(
scores[:, 0, self.eos_id + 1 :],
fill_value=-float("inf"),
)
# 2+3 Conditions: Force/Suppress EOS if (not) highest logit
# Reshape back to original shape
scores = scores.view(-1, scores.shape[-1])
# Get the highest-scoring token per row
top_logit_indices = torch.argmax(scores, dim=-1)
# 2. Force EOS
eos_highest_mask = top_logit_indices == self.eos_id
mask_eos_highest = torch.zeros_like(scores, dtype=torch.bool)
mask_eos_highest[eos_highest_mask, : self.eos_id] = True
scores = scores.masked_fill(mask_eos_highest, -float("inf"))
# 3. Suppress EOS
eos_not_highest_mask = top_logit_indices != self.eos_id
mask_eos_unless_highest = torch.zeros_like(scores, dtype=torch.bool)
mask_eos_unless_highest[eos_not_highest_mask, self.eos_id] = True
scores = scores.masked_fill(mask_eos_unless_highest, -float("inf"))
return scores
class DiaEOSDelayPatternLogitsProcessor(LogitsProcessor):
r"""Special logits processor to handle the generation of the EOS token in Dia.
This is necessary because Dia only allows the EOS token to be generated in the
first channel (C0).
Hence, based on the delay pattern, an EOS is forced after the respective delays
in the channels. For example, if the delay pattern is [0, 2, 3, 4]:
s s+1 s+2 s+3 s+4 s+5 ...
| | | | | |
C0: EOS PAD PAD PAD PAD PAD ...
C1: x x EOS PAD PAD PAD ...
C2: x x x EOS PAD PAD ...
C3: x x x x EOS PAD ...
If the first channel generated EOS at step s, channels Cx are forced to generate
theirs at the respective delays (s+2, s+3, s+4). Subsequent padding tokens are
handled by the `EosTokenCriteria` when an EOS has been detected.
<Tip warning={true}>
This logits processor is exclusively compatible with
[Dia](https://huggingface.co/docs/transformers/en/model_doc/dia).
</Tip>
Args:
delay_pattern (`List[int]`):
The delays per channel in the audio codebooks.
eos_token_id (`int`):
The id of *end-of-sequence* token.
max_generation_len (`int`):
The max sequence length that can be generated.
device (`str`, *optional*, defaults to `"cpu"`):
The device to allocate the tensors on.
"""
def __init__(self, delay_pattern: list[int], eos_token_id: int, max_generation_len: int, device: str = "cpu"):
self.num_channels = len(delay_pattern)
# Update during first iteration
self.active_batches = None
self.delay_pattern = torch.tensor(delay_pattern, device=device, dtype=torch.int)[None, :]
self.eos_token_id = eos_token_id
self.max_generation_len = max_generation_len - max(delay_pattern) - 1
self.device = device
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Reshape for easier channel indexing [B, C, V]
scores = scores.reshape(-1, self.num_channels, scores.shape[-1])
# Initialize / expand values on first iteration
if self.active_batches is None:
self.delay_pattern = self.delay_pattern.repeat(scores.shape[0], 1)
self.active_batches = torch.zeros(size=(scores.shape[0],), device=self.device, dtype=torch.bool)
# Check if eos has been generated in any batch
channel_generated_eos = torch.argmax(scores, dim=-1)[:, 0] == self.eos_token_id
# Check if max len has been reached
reached_max_len = input_ids.shape[1] == self.max_generation_len
# Update active batches
self.active_batches |= channel_generated_eos
self.active_batches |= reached_max_len
# Find channels that need to force eos
forced_eos_channels = self.active_batches[:, None] & (self.delay_pattern == 0)
# Use indexing to avoid issues on all `False` by having empty tensors in that case
idx_bsz, idx_channel = forced_eos_channels.nonzero(as_tuple=True)
# Force eos if delay is kicking in
scores[idx_bsz, idx_channel, :] = -float("inf")
scores[idx_bsz, idx_channel, self.eos_token_id] = 0.0
# Reshape back to [B * C, V]
scores = scores.reshape(-1, scores.shape[-1])
# Update amount of delay left for each channel
self.delay_pattern -= self.active_batches[:, None].int()
return scores
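To illustrate the forcing behavior above, a minimal sketch with toy values (the delays, EOS id, vocabulary size and shapes are assumptions, not the Dia defaults):
```python
import torch
from transformers.generation.logits_process import DiaEOSDelayPatternLogitsProcessor

processor = DiaEOSDelayPatternLogitsProcessor(delay_pattern=[0, 2, 3, 4], eos_token_id=1024, max_generation_len=64)

input_ids = torch.zeros((1, 10), dtype=torch.long)  # 1 batch, 10 decoding steps so far
scores = torch.randn(4, 1028)                       # [batch * channels, vocab] layout expected by the processor
scores[0, 1024] = 100.0                             # channel 0 is about to pick EOS this step

scores = processor(input_ids, scores)               # channel 0 is now forced to EOS (all other logits are -inf)
```
On this first call only channel 0 (delay 0) is forced; each subsequent call decrements the remaining delays, so the other channels are forced to EOS at the offsets shown in the class docstring.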


@ -26,6 +26,7 @@ import re
import shutil
import tempfile
import warnings
from abc import abstractmethod
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import contextmanager
@ -5884,3 +5885,26 @@ class AttentionInterface(GeneralInterface):
# Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones
ALL_ATTENTION_FUNCTIONS: AttentionInterface = AttentionInterface()
class PreTrainedAudioTokenizerBase(PreTrainedModel):
"""
Class that additionally defines the behavior of any `audio_tokenizer` to be added.
Characteristic for any of them:
1. Encode raw audio into discrete audio codebooks (with x channels)
2. Decode from discrete audio codebooks back to raw audio
They may also be able to decode from other representations,
but they must support 2. nonetheless, e.g. see `DAC`.
"""
@abstractmethod
def encode(self, input_values: torch.Tensor, *args, **kwargs):
"""
Encode raw audio retrieved from a respective `FeatureExtractor` into discrete audio codebooks (with x channels)
"""
pass
@abstractmethod
def decode(self, audio_codes: torch.Tensor, *args, **kwargs):
"""Decode from discrete audio codebooks back to raw audio"""
pass
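For intuition, a hypothetical minimal subclass (the class name and the toy quantization scheme are invented for illustration; the concrete case in this PR is `DacPreTrainedModel`, which is switched to this base class further below):
```python
import torch
from transformers import PretrainedConfig
from transformers.modeling_utils import PreTrainedAudioTokenizerBase


class ToyAudioTokenizer(PreTrainedAudioTokenizerBase):
    """Hypothetical audio tokenizer that quantizes raw audio into a single codebook channel."""

    config_class = PretrainedConfig
    main_input_name = "input_values"

    def __init__(self, config):
        super().__init__(config)
        self.codebook_size = 1024

    def encode(self, input_values: torch.Tensor, *args, **kwargs):
        # toy quantization: bucket samples in [-1, 1] into `codebook_size` discrete codes
        return ((input_values.clamp(-1.0, 1.0) + 1.0) / 2.0 * (self.codebook_size - 1)).long()

    def decode(self, audio_codes: torch.Tensor, *args, **kwargs):
        # inverse of the toy quantization above, back to a waveform in [-1, 1]
        return audio_codes.float() / (self.codebook_size - 1) * 2.0 - 1.0
```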


@ -88,6 +88,7 @@ if TYPE_CHECKING:
from .depth_anything import *
from .depth_pro import *
from .detr import *
from .dia import *
from .dialogpt import *
from .diffllama import *
from .dinat import *


@ -106,6 +106,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
("depth_pro", "DepthProConfig"),
("deta", "DetaConfig"),
("detr", "DetrConfig"),
("dia", "DiaConfig"),
("diffllama", "DiffLlamaConfig"),
("dinat", "DinatConfig"),
("dinov2", "Dinov2Config"),
@ -478,6 +479,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
("depth_pro", "DepthPro"),
("deta", "DETA"),
("detr", "DETR"),
("dia", "Dia"),
("dialogpt", "DialoGPT"),
("diffllama", "DiffLlama"),
("dinat", "DiNAT"),


@ -55,6 +55,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
("deformable_detr", "DeformableDetrFeatureExtractor"),
("deit", "DeiTFeatureExtractor"),
("detr", "DetrFeatureExtractor"),
("dia", "DiaFeatureExtractor"),
("dinat", "ViTFeatureExtractor"),
("donut-swin", "DonutFeatureExtractor"),
("dpt", "DPTFeatureExtractor"),


@ -99,6 +99,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("depth_pro", "DepthProModel"),
("deta", "DetaModel"),
("detr", "DetrModel"),
("dia", "DiaModel"),
("diffllama", "DiffLlamaModel"),
("dinat", "DinatModel"),
("dinov2", "Dinov2Model"),
@ -472,6 +473,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
("data2vec-text", "Data2VecTextForMaskedLM"),
("deberta", "DebertaForMaskedLM"),
("deberta-v2", "DebertaV2ForMaskedLM"),
("dia", "DiaForConditionalGeneration"),
("distilbert", "DistilBertForMaskedLM"),
("electra", "ElectraForMaskedLM"),
("encoder-decoder", "EncoderDecoderModel"),
@ -1059,6 +1061,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
[
("dia", "DiaForConditionalGeneration"),
("granite_speech", "GraniteSpeechForConditionalGeneration"),
("kyutai_speech_to_text", "KyutaiSpeechToTextForConditionalGeneration"),
("moonshine", "MoonshineForConditionalGeneration"),
@ -1629,6 +1632,12 @@ MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = OrderedDict(
]
)
MODEL_FOR_AUDIO_TOKENIZATION_NAMES = OrderedDict(
[
("dac", "DacModel"),
]
)
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
@ -1737,6 +1746,8 @@ MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING = _LazyAutoMapping(
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES)
MODEL_FOR_AUDIO_TOKENIZATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_TOKENIZATION_NAMES)
class AutoModelForMaskGeneration(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING
@ -2034,6 +2045,15 @@ class AutoModelForMaskedImageModeling(_BaseAutoModelClass):
AutoModelForMaskedImageModeling = auto_class_update(AutoModelForMaskedImageModeling, head_doc="masked image modeling")
class AutoModelForAudioTokenization(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_AUDIO_TOKENIZATION_MAPPING
AutoModelForAudioTokenization = auto_class_update(
AutoModelForAudioTokenization, head_doc="audio tokenization through codebooks"
)
class AutoModelWithLMHead(_AutoModelWithLMHead):
@classmethod
def from_config(cls, config):
@ -2059,6 +2079,7 @@ class AutoModelWithLMHead(_AutoModelWithLMHead):
__all__ = [
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
"MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING",
"MODEL_FOR_AUDIO_TOKENIZATION_MAPPING",
"MODEL_FOR_AUDIO_XVECTOR_MAPPING",
"MODEL_FOR_BACKBONE_MAPPING",
"MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING",
@ -2106,6 +2127,7 @@ __all__ = [
"AutoBackbone",
"AutoModelForAudioClassification",
"AutoModelForAudioFrameClassification",
"AutoModelForAudioTokenization",
"AutoModelForAudioXVector",
"AutoModelForCausalLM",
"AutoModelForCTC",


@ -61,6 +61,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("clvp", "ClvpProcessor"),
("colpali", "ColPaliProcessor"),
("colqwen2", "ColQwen2Processor"),
("dia", "DiaProcessor"),
("emu3", "Emu3Processor"),
("flava", "FlavaProcessor"),
("fuyu", "FuyuProcessor"),


@ -177,6 +177,7 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
("dia", ("DiaTokenizer", None)),
(
"diffllama",
(


@ -23,7 +23,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import PreTrainedAudioTokenizerBase
from ...utils import ModelOutput, auto_docstring
from .configuration_dac import DacConfig
@ -471,7 +471,7 @@ class DacEncoder(nn.Module):
@auto_docstring
class DacPreTrainedModel(PreTrainedModel):
class DacPreTrainedModel(PreTrainedAudioTokenizerBase):
config_class = DacConfig
base_model_prefix = "dac"
main_input_name = "input_values"


@ -0,0 +1,31 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_dia import *
from .feature_extraction_dia import *
from .generation_dia import *
from .modeling_dia import *
from .processing_dia import *
from .tokenization_dia import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)


@ -0,0 +1,376 @@
# coding=utf-8
# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dia model configuration"""
from typing import Optional
from ...configuration_utils import PretrainedConfig
from ...modeling_rope_utils import rope_config_validation
from ...utils import logging
logger = logging.get_logger(__name__)
class DiaEncoderConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`DiaEncoder`]. It is used to instantiate a Dia
encoder according to the specified arguments, defining the encoder architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
max_position_embeddings (`int`, *optional*, defaults to 1024):
The maximum sequence length that this model might ever be used with.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
hidden_size (`int`, *optional*, defaults to 1024):
Dimensionality of the encoder layers and the pooler layer.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 16):
Number of key and value heads for each attention layer in the Transformer encoder.
head_dim (`int`, *optional*, defaults to 128):
Dimensionality of the attention head.
intermediate_size (`int`, *optional*, defaults to 4096):
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the normalization layers.
vocab_size (`int`, *optional*, defaults to 256):
Vocabulary size of the Dia model. Defines the number of different tokens that can be represented by the
`input_ids` passed when calling [`DiaModel`].
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"swish"` and `"gelu_new"` are supported.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
"""
model_type = "dia_encoder"
def __init__(
self,
max_position_embeddings: int = 1024,
num_hidden_layers: int = 12,
hidden_size: int = 1024,
num_attention_heads: int = 16,
num_key_value_heads: int = 16,
head_dim: int = 128,
intermediate_size: int = 4096,
norm_eps: float = 1e-5,
vocab_size: int = 256,
hidden_act: str = "silu",
rope_theta: float = 10000.0,
rope_scaling: Optional[dict] = None,
initializer_range: float = 0.02,
**kwargs,
):
self.max_position_embeddings = max_position_embeddings
self.num_hidden_layers = num_hidden_layers
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim
self.norm_eps = norm_eps
self.vocab_size = vocab_size
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, copy it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
self.initializer_range = initializer_range
super().__init__(**kwargs)
class DiaDecoderConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`DiaDecoder`]. It is used to instantiate a Dia
decoder according to the specified arguments, defining the decoder architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
max_position_embeddings (`int`, *optional*, defaults to 3072):
The maximum sequence length that this model might ever be used with.
num_hidden_layers (`int`, *optional*, defaults to 18):
Number of hidden layers in the Transformer decoder.
hidden_size (`int`, *optional*, defaults to 2048):
Dimensionality of the decoder layers and the pooler layer.
intermediate_size (`int`, *optional*, defaults to 8192):
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*, defaults to 4):
Number of key and value heads for each attention layer in the Transformer decoder.
head_dim (`int`, *optional*, defaults to 128):
Dimensionality of the attention head.
cross_num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each cross-attention layer in the Transformer decoder.
cross_head_dim (`int`, *optional*, defaults to 128):
Dimensionality of the cross-attention head.
cross_num_key_value_heads (`int`, *optional*, defaults to 16):
Number of key and value heads for each cross-attention layer in the Transformer decoder.
cross_hidden_size (`int`, *optional*, defaults to 1024):
Dimensionality of the cross-attention layers.
norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the normalization layers.
vocab_size (`int`, *optional*, defaults to 1028):
Vocabulary size of the Dia model. Defines the number of different tokens that can be represented by the
`input_ids` passed when calling [`DiaModel`].
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder. If string, `"gelu"`, `"relu"`,
`"swish"` and `"gelu_new"` are supported.
num_channels (`int`, *optional*, defaults to 9):
Number of channels for the Dia decoder.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models).
is_encoder_decoder (`bool`, *optional*, defaults to `True`):
Indicating that this model is part of an encoder-decoder architecture.
"""
model_type = "dia_decoder"
def __init__(
self,
max_position_embeddings: int = 3072,
num_hidden_layers: int = 18,
hidden_size: int = 2048,
intermediate_size: int = 8192,
num_attention_heads: int = 16,
num_key_value_heads: int = 4,
head_dim: int = 128,
cross_num_attention_heads: int = 16,
cross_head_dim: int = 128,
cross_num_key_value_heads: int = 16,
cross_hidden_size: int = 1024,
norm_eps: float = 1e-5,
vocab_size: int = 1028,
hidden_act: str = "silu",
num_channels: int = 9,
rope_theta: float = 10000.0,
rope_scaling: Optional[dict] = None,
initializer_range: float = 0.02,
use_cache: bool = True,
is_encoder_decoder: bool = True,
**kwargs,
):
self.max_position_embeddings = max_position_embeddings
self.num_hidden_layers = num_hidden_layers
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = head_dim
self.cross_num_key_value_heads = cross_num_key_value_heads
self.cross_num_attention_heads = cross_num_attention_heads
self.cross_head_dim = cross_head_dim
self.cross_hidden_size = cross_hidden_size
self.norm_eps = norm_eps
self.vocab_size = vocab_size
self.hidden_act = hidden_act
self.num_channels = num_channels
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, copy it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
self.initializer_range = initializer_range
self.use_cache = use_cache
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
class DiaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`DiaModel`]. It is used to instantiate a
Dia model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the
[nari-labs/Dia-1.6B](https://huggingface.co/nari-labs/Dia-1.6B) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
encoder_config (`DiaEncoderConfig`, *optional*):
Configuration for the encoder part of the model. If not provided, a default `DiaEncoderConfig` will be used.
decoder_config (`DiaDecoderConfig`, *optional*):
Configuration for the decoder part of the model. If not provided, a default `DiaDecoderConfig` will be used.
norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the normalization layers.
is_encoder_decoder (`bool`, *optional*, defaults to `True`):
Indicating that this model uses an encoder-decoder architecture.
pad_token_id (`int`, *optional*, defaults to 1025):
Padding token id.
eos_token_id (`int`, *optional*, defaults to 1024):
End of stream token id.
bos_token_id (`int`, *optional*, defaults to 1026):
Beginning of stream token id.
delay_pattern (`list[int]`, *optional*, defaults to `[0, 8, 9, 10, 11, 12, 13, 14, 15]`):
The delay pattern for the decoder. The length of this list must match `decoder_config.num_channels`.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models).
Example:
```python
>>> from transformers import DiaConfig, DiaModel
>>> # Initializing a DiaConfig with default values
>>> configuration = DiaConfig()
>>> # Initializing a DiaModel (with random weights) from the configuration
>>> model = DiaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "dia"
keys_to_ignore_at_inference = ["past_key_values"]
sub_configs = {"encoder_config": DiaEncoderConfig, "decoder_config": DiaDecoderConfig}
def __init__(
self,
encoder_config: Optional[DiaEncoderConfig] = None,
decoder_config: Optional[DiaDecoderConfig] = None,
norm_eps: float = 1e-5,
is_encoder_decoder: bool = True,
pad_token_id: int = 1025,
eos_token_id: int = 1024,
bos_token_id: int = 1026,
delay_pattern: Optional[list[int]] = None,
initializer_range: float = 0.02,
use_cache: bool = True,
**kwargs,
):
if isinstance(encoder_config, dict):
encoder_config = DiaEncoderConfig(**encoder_config)
if isinstance(decoder_config, dict):
decoder_config = DiaDecoderConfig(**decoder_config)
self.encoder_config = encoder_config if encoder_config is not None else DiaEncoderConfig()
self.decoder_config = decoder_config if decoder_config is not None else DiaDecoderConfig()
self.norm_eps = norm_eps
self.delay_pattern = delay_pattern if delay_pattern is not None else [0, 8, 9, 10, 11, 12, 13, 14, 15]
self.initializer_range = initializer_range
self.use_cache = use_cache
assert self.decoder_config.num_channels == len(self.delay_pattern), (
"Number of channels must match delay pattern length."
)
super().__init__(
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
bos_token_id=bos_token_id,
is_encoder_decoder=is_encoder_decoder,
**kwargs,
)
def get_text_config(self, decoder=False):
"""Defaulting to audio config as it's the decoder in this case which is usually the text backbone"""
return self.decoder_config
__all__ = ["DiaConfig", "DiaEncoderConfig", "DiaDecoderConfig"]


@ -0,0 +1,199 @@
# coding=utf-8
# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Converts a Dia model in Nari Labs format to Hugging Face format."""
import argparse
import os
import re
import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from transformers import (
DacModel,
DiaConfig,
DiaFeatureExtractor,
DiaForConditionalGeneration,
DiaProcessor,
DiaTokenizer,
GenerationConfig,
)
from transformers.utils.import_utils import _is_package_available
# Provide just the list of layer keys you want to fix
shape_mappings = [
"encoder.layers.*.mlp.gate_up_proj.weight",
"encoder.layers.*.mlp.down_proj.weight",
"encoder.layers.*.self_attention.q_proj.weight",
"encoder.layers.*.self_attention.k_proj.weight",
"encoder.layers.*.self_attention.v_proj.weight",
"encoder.layers.*.self_attention.o_proj.weight",
"decoder.layers.*.mlp.gate_up_proj.weight",
"decoder.layers.*.mlp.down_proj.weight",
"decoder.layers.*.self_attention.q_proj.weight",
"decoder.layers.*.self_attention.k_proj.weight",
"decoder.layers.*.self_attention.v_proj.weight",
"decoder.layers.*.self_attention.o_proj.weight",
"decoder.layers.*.cross_attention.q_proj.weight",
"decoder.layers.*.cross_attention.k_proj.weight",
"decoder.layers.*.cross_attention.v_proj.weight",
"decoder.layers.*.cross_attention.o_proj.weight",
"decoder.logits_dense.weight",
]
# Provide renamings here
rename_mapping = {
"mlp.wo": "mlp.down_proj",
"mlp.wi_fused": "mlp.gate_up_proj",
}
def get_generation_config(config):
model_generation_config = GenerationConfig.from_model_config(config)
model_generation_config._from_model_config = False
model_generation_config.do_sample = True
model_generation_config.top_k = 45
model_generation_config.top_p = 0.95
model_generation_config.temperature = 1.2
model_generation_config.guidance_scale = 3.0
model_generation_config.max_length = 3072 # Decoder max length
return model_generation_config
def convert_dia_model_to_hf(checkpoint_path, verbose=False):
"""
Converts a Dia model in Nari Labs format to Hugging Face format.
Args:
checkpoint_path (`str`):
Path to the downloaded checkpoints.
verbose (`bool`, *optional*):
Whether to print information during conversion.
"""
# Download from HF Hub if checkpoint_path is None
checkpoint_path = snapshot_download(repo_id=checkpoint_path, allow_patterns=["*.pth", "*.safetensors"])
print(f"Downloaded checkpoint from Hugging Face Hub: {checkpoint_path}")
# Initialize base model with default config == 1.6B model
with torch.device("meta"):
hf_model = DiaForConditionalGeneration(config=DiaConfig())
hf_model_dict = hf_model.state_dict()
hf_model_keys = hf_model_dict.keys()
# Iterate through dir to catch the weight file - prefers safetensors but allows pt
files = os.listdir(checkpoint_path)
weight_file = None
for file in files:
if file.endswith(".safetensors"):
load_function = load_file
weight_file = file
break
elif file.endswith(".pth"):
load_function = torch.load
weight_file = file
checkpoint_path = os.path.join(checkpoint_path, weight_file)
nari_state_dict = load_function(checkpoint_path, "cpu")
# Conversion starts here
converted_state_dict = {}
embeddings = {}
for key, tensor in nari_state_dict.items():
# add prefix
key = "model." + key
# rename some weights
for original, rename in rename_mapping.items():
if original in key:
key = re.sub(original, rename, key)
# decoder multi channel
if "embeddings" in key:
embeddings_key = key.rsplit(".", 2)[0] + ".embed.weight"
if embeddings_key in embeddings:
embeddings[embeddings_key] += [tensor]
else:
embeddings[embeddings_key] = [tensor]
continue
elif re.sub(r"\d+", "*", key).removeprefix("model.") in shape_mappings:
# add exception to the head
if "logits_dense" in key:
key = re.sub("decoder.logits_dense", "logits_dense", key).removeprefix("model.")
# dense general
if key in hf_model_keys:
tensor_shape = tensor.shape
target_shape = hf_model_dict[key].shape
try:
tensor = tensor.reshape(target_shape[1], target_shape[0]).T
if verbose:
print(f"{key}: transpose reshaped from {tensor_shape} to {target_shape}")
except Exception as e:
print(f"WARNING: Could not reshape {key}: {e}")
converted_state_dict[key] = tensor
# Combining the embeddings as last step
embeddings = {k: torch.cat(v, dim=0) for k, v in embeddings.items()}
converted_state_dict.update(embeddings)
# Load converted weights into HF model
hf_model.load_state_dict(converted_state_dict, assign=True)
# Overwrite generation config
hf_model.generation_config = get_generation_config(DiaConfig())
return hf_model
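To make the "dense general" branch above concrete, a toy sketch of the reshape-and-transpose it performs (the sizes are made up and only the plain 2-D case is shown):
```python
import torch

# hypothetical Nari Labs "dense general" weight stored as (in_features, out_features)
nari_weight = torch.arange(6.0).reshape(2, 3)
target_shape = (3, 2)  # hf_model_dict[key].shape for the matching nn.Linear, i.e. (out_features, in_features)

hf_weight = nari_weight.reshape(target_shape[1], target_shape[0]).T
assert hf_weight.shape == torch.Size(target_shape)
assert torch.equal(hf_weight, nari_weight.T)  # in the 2-D case this reduces to a plain transpose
```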
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--checkpoint_path", type=str, default="nari-labs/Dia-1.6B", help="Path to the downloaded checkpoints"
)
parser.add_argument(
"--pytorch_dump_folder_path", default="AntonV/Dia-1.6B", type=str, help="Path to the output PyTorch model."
)
parser.add_argument(
"--convert_preprocessor",
type=bool,
default=True,
help="Whether or not the preprocessor (tokenizer + feature extractor) should be converted along with the model.",
)
parser.add_argument(
"--verbose",
type=bool,
default=True,
help="Whether or not to log information during conversion.",
)
args = parser.parse_args()
model = convert_dia_model_to_hf(args.checkpoint_path, args.verbose)
if args.convert_preprocessor:
try:
if not _is_package_available("tiktoken"):
raise ModuleNotFoundError(
"""`tiktoken` is not installed, use `pip install tiktoken` to convert the tokenizer"""
)
except Exception as e:
print(e)
else:
processor = DiaProcessor(
DiaFeatureExtractor(sampling_rate=44100, hop_length=512),
DiaTokenizer(),
DacModel.from_pretrained("descript/dac_44khz"),
)
processor.save_pretrained(args.pytorch_dump_folder_path)
model.save_pretrained(args.pytorch_dump_folder_path)
print(f"Saved converted checkpoint to {args.pytorch_dump_folder_path}")


@ -0,0 +1,183 @@
# coding=utf-8
# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature extractor class for Dia"""
from typing import Optional, Union
import numpy as np
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...utils import PaddingStrategy, TensorType, logging
logger = logging.get_logger(__name__)
class DiaFeatureExtractor(SequenceFeatureExtractor):
r"""
Constructs a Dia feature extractor.
This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
most of the main methods. Users should refer to this superclass for more information regarding those methods.
Args:
feature_size (`int`, *optional*, defaults to 1):
The feature dimension of the extracted features. Use 1 for mono, 2 for stereo.
sampling_rate (`int`, *optional*, defaults to 16000):
The sampling rate at which the audio waveform should be digitized, expressed in hertz (Hz).
padding_value (`float`, *optional*, defaults to 0.0):
The value that is used for padding.
hop_length (`int`, *optional*, defaults to 512):
Hop (stride) length between successive analysis windows; inputs are padded to a multiple of this value.
"""
model_input_names = ["input_values", "n_quantizers"]
def __init__(
self,
feature_size: int = 1,
sampling_rate: int = 16000,
padding_value: float = 0.0,
hop_length: int = 512,
**kwargs,
):
super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
self.hop_length = hop_length
def __call__(
self,
raw_audio: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
padding: Optional[Union[bool, str, PaddingStrategy]] = None,
truncation: Optional[bool] = False,
max_length: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
sampling_rate: Optional[int] = None,
) -> BatchFeature:
"""
Main method to featurize and prepare for the model one or several sequence(s).
Args:
raw_audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float
values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape
`(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio
(`feature_size = 2`).
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among:
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence is provided).
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
acceptable input length for the model if that argument is not provided.
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
lengths).
truncation (`bool`, *optional*, defaults to `False`):
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
max_length (`int`, *optional*):
Maximum length of the returned list and optionally padding length (see above).
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors instead of list of python integers. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return Numpy `np.ndarray` objects.
sampling_rate (`int`, *optional*):
The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass
`sampling_rate` at the forward call to prevent silent errors.
"""
if sampling_rate is not None:
if sampling_rate != self.sampling_rate:
raise ValueError(
f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with"
f" {self.sampling_rate} and not {sampling_rate}."
)
else:
logger.warning(
f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
"Failing to do so can result in silent errors that might be hard to debug."
)
if padding and truncation:
raise ValueError("Both padding and truncation were set. Make sure you only set one.")
elif padding is None:
# by default let's pad the inputs
padding = True
is_batched = bool(
isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list)))
)
if is_batched:
raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio]
elif not is_batched and not isinstance(raw_audio, np.ndarray):
raw_audio = np.asarray(raw_audio, dtype=np.float32)
elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64):
raw_audio = raw_audio.astype(np.float32)
# always return batch
if not is_batched:
raw_audio = [np.asarray(raw_audio).T]
# convert stereo to mono if necessary, unique to Dia
for idx, example in enumerate(raw_audio):
if self.feature_size == 2 and example.ndim == 2:
raw_audio[idx] = np.mean(example, -1)
# verify inputs are valid
for idx, example in enumerate(raw_audio):
if example.ndim > 2:
raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}")
if self.feature_size == 1 and example.ndim != 1:
raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels")
if self.feature_size == 2 and example.ndim != 1: # note the conversion before
raise ValueError(f"Expected stereo audio but example has {example.shape[-1]} channels")
input_values = BatchFeature({"input_values": raw_audio})
        # temporarily treat the input as mono since any stereo input has been converted to mono above
        original_feature_size = self.feature_size
self.feature_size = 1
# normal padding on batch
padded_inputs = self.pad(
input_values,
max_length=max_length,
truncation=truncation,
padding=padding,
return_attention_mask=True,
pad_to_multiple_of=self.hop_length,
)
padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask")
input_values = []
for example in padded_inputs.pop("input_values"):
if self.feature_size == 1:
example = example[..., None]
input_values.append(example.T)
padded_inputs["input_values"] = input_values
if return_tensors is not None:
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
        # restore the original feature size
        self.feature_size = original_feature_size
return padded_inputs
__all__ = ["DiaFeatureExtractor"]


@@ -0,0 +1,464 @@
# coding=utf-8
# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, Union
import torch
import torch.distributed as dist
from ...generation.logits_process import (
DiaClassifierFreeGuidanceLogitsProcessor,
DiaEOSChannelFilterLogitsProcessor,
DiaEOSDelayPatternLogitsProcessor,
LogitsProcessorList,
TemperatureLogitsWarper,
)
from ...generation.stopping_criteria import StoppingCriteriaList
from ...generation.streamers import BaseStreamer
from ...generation.utils import GenerateOutput, GenerationConfig, GenerationMixin, GenerationMode
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_utils import PreTrainedModel
from ...utils import logging
logger = logging.get_logger(__name__)
class DiaGenerationMixin(GenerationMixin):
    # Indicates whether CFG is used; it needs preparation (input repeats) to be handled properly
_uses_cfg = None
def _get_logits_processor(
self,
generation_config: GenerationConfig,
input_ids_seq_length: Optional[int] = None,
encoder_input_ids: torch.LongTensor = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
logits_processor: Optional[LogitsProcessorList] = None,
device: Optional[str] = None,
model_kwargs: Optional[dict[str, Any]] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
) -> LogitsProcessorList:
        # These need either a custom ordering or a custom processor,
        # so we temporarily disable them for the super() call
original_guidance_scale = generation_config.guidance_scale
original_temperature = generation_config.temperature
generation_config.guidance_scale = None
generation_config.temperature = None
# Get base processors and those we can integrate easily
custom_processors = LogitsProcessorList()
if original_temperature is not None and original_temperature != 1.0:
custom_processors.append(TemperatureLogitsWarper(original_temperature))
custom_processors.append(
DiaEOSChannelFilterLogitsProcessor(
num_channels=len(self.config.delay_pattern),
eos_token_id=self.config.eos_token_id,
)
)
merged_processors = super()._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=encoder_input_ids,
prefix_allowed_tokens_fn=None,
logits_processor=custom_processors,
device=device,
model_kwargs=model_kwargs,
negative_prompt_ids=negative_prompt_ids,
negative_prompt_attention_mask=negative_prompt_attention_mask,
)
# Custom processors we need at specific positions
if original_guidance_scale is not None and original_guidance_scale != 1:
cfg_processor = DiaClassifierFreeGuidanceLogitsProcessor(
guidance_scale=original_guidance_scale,
guidance_top_k=generation_config.top_k,
)
merged_processors.insert(0, cfg_processor)
merged_processors.append(
DiaEOSDelayPatternLogitsProcessor(
delay_pattern=self.config.delay_pattern,
eos_token_id=self.config.eos_token_id,
max_generation_len=generation_config.max_length,
device=device,
)
)
        # Restore the temporarily disabled values
generation_config.guidance_scale = original_guidance_scale
generation_config.temperature = original_temperature
return merged_processors
def _prepare_generation_config(
self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: dict
) -> tuple[GenerationConfig, dict]:
generation_config, model_kwargs = super()._prepare_generation_config(
generation_config, use_model_defaults, **kwargs
)
# We allow generation up to max length + max delay pattern
# (will revert back to max length after generation)
generation_config.max_length += max(self.config.delay_pattern)
# Internal flag to indicate CFG that needs to prepare unconditioned input
self._uses_cfg = generation_config.guidance_scale is not None and generation_config.guidance_scale != 1
return generation_config, model_kwargs
def _prepare_model_inputs(
self,
inputs: Optional[torch.Tensor] = None,
bos_token_id: Optional[torch.Tensor] = None,
model_kwargs: Optional[dict[str, torch.Tensor]] = None,
) -> tuple[torch.Tensor, Optional[str], dict[str, torch.Tensor]]:
inputs, input_name, model_kwargs = super()._prepare_model_inputs(
inputs=inputs,
bos_token_id=bos_token_id,
model_kwargs=model_kwargs,
)
# If CFG is requested we fill in the unconditioned parts
if self._uses_cfg:
unconditioned_inputs = torch.zeros_like(inputs)
inputs = torch.cat([inputs, unconditioned_inputs], dim=0)
if model_kwargs.get("attention_mask", None) is not None:
model_kwargs["attention_mask"] = model_kwargs["attention_mask"].repeat(2, 1)
return inputs, input_name, model_kwargs
def _prepare_decoder_input_ids_for_generation(
self,
batch_size: int,
model_input_name: str,
model_kwargs: dict[str, torch.Tensor],
decoder_start_token_id: torch.Tensor,
device: Optional[torch.device] = None,
) -> tuple[torch.LongTensor, dict[str, torch.Tensor]]:
"""Prepares `decoder_input_ids` for generation with encoder-decoder models"""
# 1. Check whether the user has defined `decoder_input_ids` and `decoder_attention_mask`; if not error out
decoder_input_ids = decoder_attention_mask = None
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
decoder_input_ids = model_kwargs.pop("decoder_input_ids")
if model_kwargs is not None and "decoder_attention_mask" in model_kwargs:
decoder_attention_mask = model_kwargs.pop("decoder_attention_mask")
# We allow generating without preparation (no proper delay) but discourage it
if decoder_input_ids is None or decoder_attention_mask is None:
logger.warning_once(
"In order to generate with Dia, we need the processed audio input: Got `decoder_input_ids`:"
f" {decoder_input_ids is not None} and got `decoder_attention_mask`={decoder_attention_mask is not None}."
f" This can be achieved via the [`DiaProcessor`] but now defaulting to non-delayed generation."
)
num_channels = self.config.decoder_config.num_channels
real_batch_size = batch_size // 2 if self._uses_cfg else batch_size
if decoder_input_ids is None:
decoder_input_ids = torch.full(
(real_batch_size, 1, num_channels), decoder_start_token_id, dtype=torch.long, device=device
)
decoder_attention_mask = torch.ones(
size=(real_batch_size, decoder_input_ids.shape[1]), dtype=torch.long, device=device
)
        # 2. Determine the valid (non-padded) input and what serves as the delay mask within it
delay_mask = decoder_input_ids.long()
valid_input_size = (
decoder_input_ids.shape[1] - (decoder_input_ids[:, :, 0] == self.config.pad_token_id).sum(dim=-1).max()
)
decoder_input_ids = delay_mask[:, :valid_input_size].transpose(1, 2).long()
decoder_attention_mask = decoder_attention_mask[:, :valid_input_size].long()
# 3. Overwrite into model kwargs
model_kwargs["decoder_attention_mask"] = decoder_attention_mask
model_kwargs["decoder_delay_mask"] = delay_mask
return decoder_input_ids, model_kwargs
def prepare_inputs_for_generation(
self,
input_ids,
encoder_outputs=None, # Using this to easily get the batch size
decoder_delay_mask=None,
**kwargs,
):
# Reshape decoder input_ids to 3D to be compile friendly and to fit the expected model input shape
batch_size = encoder_outputs[0].shape[0] // 2 if self._uses_cfg else encoder_outputs[0].shape[0]
input_ids = input_ids.reshape(batch_size, self.config.decoder_config.num_channels, -1).transpose(1, 2)
# Base method handles most things except CFG and the delay pattern mask
model_inputs = super().prepare_inputs_for_generation(input_ids, encoder_outputs=encoder_outputs, **kwargs)
# Post processing for CFG and overwriting via delay pattern mask
        # 1. Delay pattern mask -- force tokens at positions where we are not allowed to predict (mask != pad_token)
model_inputs["decoder_input_ids"] = self.apply_delay_mask(
input_ids, self.config.pad_token_id, decoder_delay_mask
)
# Depending on cache usage we need to pass all or just one
if model_inputs.get("use_cache", False) and model_inputs["cache_position"][0] > 0:
model_inputs["decoder_input_ids"] = model_inputs["decoder_input_ids"][:, -1, :][:, None, :]
# Be compile friendly
model_inputs["decoder_input_ids"] = model_inputs["decoder_input_ids"].contiguous()
# 2. Apply CFG duplication if needed
if self._uses_cfg:
for key in ["decoder_input_ids", "decoder_attention_mask", "decoder_position_ids"]:
if model_inputs.get(key, None) is not None:
# double first dimension and keep everything else the same
repeat_pattern = tuple([2] + [1] * (model_inputs[key].ndim - 1))
model_inputs[key] = model_inputs[key].repeat(*repeat_pattern)
return model_inputs
@staticmethod
def apply_delay_mask(input_ids: torch.Tensor, pad_id: int, delay_mask: Optional[torch.Tensor]) -> torch.Tensor:
if delay_mask is None:
return input_ids
mask_len = min(input_ids.shape[1], delay_mask.shape[1])
valid_mask = delay_mask[:, :mask_len, :]
valid_input = input_ids[:, :mask_len, :]
# Overwrite the respective parts of the input
input_ids[:, :mask_len, :] = torch.where(valid_mask == pad_id, valid_input, valid_mask)
return input_ids
def _main_generate_loop(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
use_model_defaults: Optional[bool] = None,
custom_generate: Optional[str] = None,
**kwargs,
):
# ********** mostly taken from main generate function up to calling the different methods (see NOTE) **********
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation
generation_config, model_kwargs = self._prepare_generation_config(
generation_config, use_model_defaults, **kwargs
)
self._validate_model_kwargs(model_kwargs.copy())
self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)
# 2. Set generation parameters if not already defined
if synced_gpus is None:
synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
# 3. Define model inputs
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]
device = inputs_tensor.device
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
# 4. Define other model kwargs
if "encoder_outputs" not in model_kwargs:
# if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
inputs_tensor, model_kwargs, model_input_name, generation_config
)
# 5. Prepare `input_ids` which will be used for auto-regressive generation
input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
batch_size=batch_size,
model_input_name=model_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=generation_config._decoder_start_token_tensor,
device=inputs_tensor.device,
)
if generation_config.token_healing:
input_ids = self.heal_tokens(input_ids, tokenizer)
if streamer is not None:
streamer.put(input_ids.cpu())
# 6. Prepare `max_length` depending on other stopping criteria.
# NOTE: incorrect `input_ids.shape[1]` previously
input_ids_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
generation_config = self._prepare_generated_length(
generation_config=generation_config,
has_default_max_length=has_default_max_length,
has_default_min_length=has_default_min_length,
model_input_name=model_input_name,
inputs_tensor=inputs_tensor,
input_ids_length=input_ids_length,
)
# If the model supports `logits_to_keep` in forward(), set it to 1 to avoid computing the whole
# logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding
# dynamically overrides this value as it can need more than the last token logits
if self._supports_logits_to_keep() and "logits_to_keep" not in model_kwargs:
model_kwargs["logits_to_keep"] = 1
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
# 7. Prepare the cache.
# - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
# - different models have a different cache name expected by the model (default = "past_key_values")
# - `max_length`, prepared above, is used to determine the maximum cache length
max_cache_length = generation_config.max_length - 1
if (
inputs_tensor.shape[1] != input_ids_length
and model_input_name == "inputs_embeds"
and not self.config.is_encoder_decoder
):
max_cache_length += inputs_tensor.shape[1]
self._prepare_cache_for_generation(
generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device
)
# 8. determine generation mode
generation_mode = generation_config.get_generation_mode(assistant_model)
if streamer is not None and (generation_config.num_beams > 1):
raise ValueError(
"`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
)
# 9. prepare logits processors and stopping criteria
prepared_logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
logits_processor=logits_processor,
device=inputs_tensor.device,
model_kwargs=model_kwargs,
negative_prompt_ids=negative_prompt_ids,
negative_prompt_attention_mask=negative_prompt_attention_mask,
)
prepared_stopping_criteria = self._get_stopping_criteria(
generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
)
# Set model_kwargs `use_cache` so we can use it later in forward runs
model_kwargs["use_cache"] = generation_config.use_cache
# ******************* taken from main generate function up to calling the different methods *******************
# Prepare inner 2D logic in generation loop
input_ids = input_ids.reshape(-1, input_ids.shape[-1])
# 10. go into different generation modes
if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
# 11. expand input_ids with `num_return_sequences` additional sequences per batch
if generation_config.num_return_sequences > 1:
raise ValueError("`num_return_sequences>1` is incompatible with Dia.")
# 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
return self._sample(
input_ids,
logits_processor=prepared_logits_processor,
stopping_criteria=prepared_stopping_criteria,
generation_config=generation_config,
synced_gpus=synced_gpus,
streamer=streamer,
**model_kwargs,
)
else:
raise ValueError(
"Got incompatible mode for generation, should be one of greedy or sampling. "
"Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`."
)
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
use_model_defaults: Optional[bool] = None,
custom_generate: Optional[str] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
# We expect the initial input ids to be the complete mask (delayed input)
delay_mask = kwargs.get("decoder_input_ids", None)
if delay_mask is not None:
delay_mask = delay_mask.clone()
output = self._main_generate_loop(
inputs=inputs,
generation_config=generation_config,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
synced_gpus=synced_gpus,
assistant_model=assistant_model,
streamer=streamer,
negative_prompt_ids=negative_prompt_ids,
negative_prompt_attention_mask=negative_prompt_attention_mask,
use_model_defaults=use_model_defaults,
custom_generate=custom_generate,
**kwargs,
)
return_dict_in_generate = not isinstance(output, torch.Tensor)
if return_dict_in_generate:
output_sequences = output.sequences
else:
output_sequences = output
# Reshape from 2D (bsz * channels, seq_len) to 3D (bsz, seq_len, channels)
num_channels = self.config.decoder_config.num_channels
bsz = output_sequences.shape[0] // num_channels
output_sequences = output_sequences.reshape(bsz, num_channels, -1).transpose(1, 2)
# Apply delay mask
output_sequences = self.apply_delay_mask(output_sequences, self.config.pad_token_id, delay_mask)
if return_dict_in_generate:
output.sequences = output_sequences
else:
output = output_sequences
return output
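# A minimal illustrative sketch of the overwrite performed by `apply_delay_mask` (the helper name and
# token ids are hypothetical; `pad_id=1025` mirrors Dia's audio pad token only by assumption).
def _demo_apply_delay_mask():
    pad_id = 1025
    # (batch=1, seq_len=3, channels=2): positions equal to `pad_id` are free, everything else is forced
    delay_mask = torch.tensor([[[1026, pad_id], [pad_id, 1026], [pad_id, pad_id]]])
    generated = torch.tensor([[[7, 8], [9, 10], [11, 12]]])
    forced = DiaGenerationMixin.apply_delay_mask(generated, pad_id, delay_mask)
    # forced == [[[1026, 8], [9, 1026], [11, 12]]]: non-pad mask entries overwrite the generated tokens
    return forced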


@@ -0,0 +1,963 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/dia/modular_dia.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_dia.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, Union
import torch
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...integrations import use_kernel_forward_from_hub
from ...masking_utils import create_causal_mask
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, is_torchdynamo_compiling, logging
from .configuration_dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig
from .generation_dia import DiaGenerationMixin
if is_torch_flex_attn_available():
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@auto_docstring
class DiaPreTrainedModel(PreTrainedModel):
config_class = DiaConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
_supports_static_cache = True
main_input_name = "input_ids"
_no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"]
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, DiaRMSNorm):
module.weight.data.fill_(1.0)
class DiaMultiChannelEmbedding(nn.Module):
"""In order to efficiently compute the audio embedding from the 9 different channels,
we vectorize the embedding process by using a single embedding layer and an offset.
Example:
- num_embeds = 4
- vocab_size = 8
- num_channels = 3
We would have offsets = [0, 8, 16]
    If audio_codes = [[0, 1, 2, 3], [1, 3, 4, 7], [5, 6, 7, 7]] (one row per channel),
    then tokens = audio_codes + offsets
                = [0, 1, 2, 3, 9, 11, 12, 15, 21, 22, 23, 23]
This allows us to use a single embedding layer for all channels.
"""
def __init__(self, config: DiaDecoderConfig):
super().__init__()
self.embed = nn.Embedding(config.vocab_size * config.num_channels, config.hidden_size)
self.hidden_size = config.hidden_size
self.num_channels = config.num_channels
offsets = torch.arange(config.num_channels, dtype=torch.long) * config.vocab_size # (C,)
self.register_buffer("offsets", offsets, persistent=False)
def forward(self, audio_codes: torch.Tensor) -> torch.Tensor:
tokens = (audio_codes + self.offsets.to(audio_codes.device)).squeeze(1)
embeds = self.embed(tokens).view(tokens.shape[0], audio_codes.shape[1], -1, self.hidden_size)
return embeds.sum(dim=2)
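# A minimal illustrative sketch of the offset trick above with plain tensors (the helper name is
# hypothetical): channel c indexes rows [c * vocab_size, (c + 1) * vocab_size) of one shared table,
# so a single lookup covers every channel and the per-channel embeddings are summed afterwards.
def _demo_multi_channel_embedding_offsets():
    vocab_size, num_channels, hidden_size = 8, 3, 4
    embed = nn.Embedding(vocab_size * num_channels, hidden_size)
    audio_codes = torch.tensor([[[0, 1, 7], [3, 4, 5]]])  # (batch=1, seq_len=2, channels=3)
    offsets = torch.arange(num_channels) * vocab_size  # [0, 8, 16]
    summed = embed(audio_codes + offsets).sum(dim=2)  # (batch, seq_len, hidden_size) == (1, 2, 4)
    return summed.shape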
class DiaMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
self.activation_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
up_states = self.gate_up_proj(hidden_states)
gate, up_states = up_states.chunk(2, dim=-1)
up_states = up_states * self.activation_fn(gate)
return self.down_proj(up_states)
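# A minimal illustrative sketch of the fused gate/up projection above (the helper name is hypothetical
# and `silu` is an assumed activation): the single `gate_up_proj` output is chunked into a gate and an
# up branch, combined as `up * act(gate)` before the down projection.
def _demo_gated_mlp_split():
    hidden_size, intermediate_size = 8, 16
    gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
    hidden = torch.randn(2, hidden_size)
    gate, up = gate_up_proj(hidden).chunk(2, dim=-1)
    return (up * nn.functional.silu(gate)).shape  # (2, intermediate_size)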
@use_kernel_forward_from_hub("RMSNorm")
class DiaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
DiaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class DiaRotaryEmbedding(nn.Module):
def __init__(self, config: DiaConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
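# A minimal illustrative sketch (the helper name is hypothetical; the base 10000.0 is an example value,
# not necessarily Dia's `rope_theta`): build cos/sin from concatenated frequencies, as the rotary
# embedding above does, and rotate toy query/key tensors with `apply_rotary_pos_emb`.
def _demo_apply_rotary_pos_emb():
    head_dim, seq_len = 16, 10
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))  # (head_dim / 2,)
    freqs = torch.arange(seq_len).float()[:, None] * inv_freq[None, :]  # (seq_len, head_dim / 2)
    emb = torch.cat((freqs, freqs), dim=-1)[None]  # (1, seq_len, head_dim)
    query = torch.randn(1, 4, seq_len, head_dim)  # (batch, num_heads, seq_len, head_dim)
    key = torch.randn(1, 4, seq_len, head_dim)
    query_rot, key_rot = apply_rotary_pos_emb(query, key, emb.cos(), emb.sin())
    return query_rot.shape, key_rot.shape  # shapes are unchanged; only pairs of channels get mixed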
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
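# A minimal illustrative sketch (the helper name is hypothetical): `repeat_kv` matches
# `torch.repeat_interleave` along the key/value head dimension, as stated in the docstring above.
def _demo_repeat_kv_equivalence():
    key_value = torch.randn(2, 2, 5, 8)  # (batch, num_key_value_heads, seq_len, head_dim)
    return torch.equal(repeat_kv(key_value, 3), torch.repeat_interleave(key_value, repeats=3, dim=1))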
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class DiaSelfAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: Union[DiaEncoderConfig, DiaDecoderConfig], layer_idx: int, is_causal: bool = False):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
self.num_heads = self.config.num_attention_heads
self.num_key_value_heads = self.config.num_key_value_heads or self.num_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.head_dim = getattr(config, "head_dim", config.hidden_size // self.num_heads)
self.scaling = 1
self.attention_dropout = 0.0
self.is_causal = is_causal
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class DiaCrossAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: DiaDecoderConfig, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
self.cross_hidden_size = config.cross_hidden_size
self.num_heads = self.config.cross_num_attention_heads
self.num_key_value_heads = self.config.cross_num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.head_dim = config.cross_head_dim
self.scaling = 1
self.attention_dropout = 0.0
self.is_causal = False
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
def forward(
self,
hidden_states: torch.Tensor,
cross_attention_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[EncoderDecoderCache] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
cross_shape = (*cross_attention_states.shape[:-1], -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
if past_key_values is not None and is_updated:
# reuse k,v, cross_attentions
key_states = past_key_values.cross_attention_cache.key_cache[self.layer_idx]
value_states = past_key_values.cross_attention_cache.value_cache[self.layer_idx]
else:
key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
if past_key_values is not None:
# save all states to the cache
key_states, value_states = past_key_values.cross_attention_cache.update(
key_states,
value_states,
self.layer_idx,
)
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
past_key_values.is_updated[self.layer_idx] = True
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape((*input_shape, -1)).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class DiaEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: DiaEncoderConfig, layer_idx: int):
super().__init__()
self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=False)
self.post_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
self.mlp = DiaMLP(config)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
attention_mask: Optional[torch.Tensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
residual = hidden_states
normed_states = self.pre_sa_norm(hidden_states)
self_attn_output, self_attn_weights = self.self_attention(
normed_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = residual + self_attn_output
residual = hidden_states
normed_states = self.post_sa_norm(hidden_states)
mlp_out = self.mlp(normed_states)
hidden_states = residual + mlp_out
return hidden_states, self_attn_weights
class DiaEncoder(DiaPreTrainedModel):
def __init__(self, config: DiaEncoderConfig):
super().__init__(config)
self.config = config
self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList(
[DiaEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
self.rotary_embeddings = DiaRotaryEmbedding(config)
@auto_docstring
@can_return_tuple
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[BaseModelOutput, tuple]:
hidden_states = self.embedding(input_ids)
# RoPE
# Note: We expect right padding and hence always generate
# the position ids on the fly to reduce preparation overhead
position_ids = torch.arange(input_ids.shape[-1], device=input_ids.device)[None, :]
position_embeddings = self.rotary_embeddings(hidden_states, position_ids)
attention_mask = self._update_full_mask(
attention_mask,
hidden_states,
)
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
for encoder_layer in self.layers:
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
layer_outputs = encoder_layer(
hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
if output_hidden_states:
encoder_states += (hidden_states,)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)
# Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
def _update_full_mask(
self,
attention_mask: Union[torch.Tensor, None],
inputs_embeds: torch.Tensor,
):
if attention_mask is not None:
if self.config._attn_implementation == "flash_attention_2":
attention_mask = attention_mask if 0 in attention_mask else None
elif self.config._attn_implementation == "sdpa":
# output_attentions=True & head_mask can not be supported when using SDPA, fall back to
# the manual implementation that requires a 4D causal mask in all cases.
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
elif self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
return attention_mask
class DiaDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: DiaDecoderConfig, layer_idx: int):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=True)
self.cross_attention = DiaCrossAttention(config, layer_idx)
self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
self.pre_ca_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
self.pre_mlp_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
self.mlp = DiaMLP(config)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[EncoderDecoderCache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
self_attn_cache = past_key_values
if isinstance(self_attn_cache, EncoderDecoderCache):
self_attn_cache = self_attn_cache.self_attention_cache
residual = hidden_states
normed_states = self.pre_sa_norm(hidden_states)
self_attn_output, self_attn_weights = self.self_attention(
normed_states,
position_embeddings,
attention_mask,
                # Needs to be passed as an arg in order for in-place
                # operations (e.g. under compile) to be carried out properly
self_attn_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = residual + self_attn_output
residual = hidden_states
normed_states = self.pre_ca_norm(hidden_states)
cross_states, cross_attn_weights = self.cross_attention(
normed_states,
encoder_hidden_states,
attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
**kwargs,
)
hidden_states = residual + cross_states
residual = hidden_states
normed_states = self.pre_mlp_norm(hidden_states)
mlp_out = self.mlp(normed_states)
hidden_states = residual + mlp_out
return hidden_states, self_attn_weights, cross_attn_weights
class DiaDecoder(DiaPreTrainedModel):
"""Transformer Decoder Stack using DenseGeneral."""
def __init__(self, config: DiaDecoderConfig):
super().__init__(config)
self.num_channels = config.num_channels
self.vocab_size = config.vocab_size
self.embeddings = DiaMultiChannelEmbedding(config)
self.rotary_embeddings = DiaRotaryEmbedding(config)
self.layers = nn.ModuleList(
[DiaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
@auto_docstring
@can_return_tuple
def forward(
self,
input_ids: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[EncoderDecoderCache] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[BaseModelOutputWithPastAndCrossAttentions, tuple]:
r"""
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`):
The original `decoder_input_ids` in 3D shape to facilitate more efficient computations.
[What are input IDs?](../glossary#input-ids)
"""
batch_size, seq_length = input_ids.size()[:-1]
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
if cache_position is None:
cache_position = torch.arange(
past_key_values_length, past_key_values_length + seq_length, device=input_ids.device
)
if position_ids is None:
position_ids = cache_position[None, :]
# RoPE
hidden_states = self.embeddings(input_ids)
position_embeddings = self.rotary_embeddings(hidden_states, position_ids)
if attention_mask is None and not is_torchdynamo_compiling():
# required mask seq length can be calculated via length of past cache
mask_seq_length = past_key_values_length + seq_length
attention_mask = torch.ones(batch_size, mask_seq_length, device=input_ids.device)
attention_mask = create_causal_mask(
config=self.config,
input_embeds=hidden_states,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
)
encoder_attention_mask = self._update_cross_attn_mask(
encoder_hidden_states,
encoder_attention_mask,
hidden_states.shape[:2],
hidden_states,
)
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
for layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = layer(
hidden_states,
position_embeddings,
attention_mask,
encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
cache_position=cache_position,
**kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns = all_self_attns + (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
hidden_states = self.norm(hidden_states)
if output_hidden_states:
all_hidden_states += (hidden_states,)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
)
# Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
def _update_cross_attn_mask(
self,
encoder_hidden_states: Union[torch.Tensor, None],
encoder_attention_mask: Union[torch.Tensor, None],
input_shape: torch.Size,
inputs_embeds: torch.Tensor,
):
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
if self.config._attn_implementation == "flash_attention_2":
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
elif self.config._attn_implementation == "sdpa":
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
encoder_attention_mask,
inputs_embeds.dtype,
tgt_len=input_shape[-1],
)
elif self.config._attn_implementation == "flex_attention":
if isinstance(encoder_attention_mask, torch.Tensor):
encoder_attention_mask = make_flex_block_causal_mask(
encoder_attention_mask,
query_length=input_shape[-1],
is_causal=False,
)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
)
return encoder_attention_mask
@auto_docstring(
custom_intro="""
The bare Dia model outputting raw hidden-states without any specific head on top.
"""
)
class DiaModel(DiaPreTrainedModel):
def __init__(self, config: DiaConfig):
super().__init__(config)
self.config = config
self.encoder = DiaEncoder(config.encoder_config)
self.decoder = DiaDecoder(config.decoder_config)
self.post_init()
def get_encoder(self):
return self.encoder
def get_decoder(self):
return self.decoder
@auto_docstring
@can_return_tuple
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_position_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None,
past_key_values: Optional[EncoderDecoderCache] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[tuple, Seq2SeqModelOutput]:
r"""
decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
            the audio input codebooks are flattened into the batch dimension. This also aligns with the
            flattened audio logits which are used to calculate the loss.
2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of
Dia to calculate embeddings and subsequent steps more efficiently.
If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
`(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
[`DiaProcessor.__call__`] for more details.
[What are decoder input IDs?](../glossary#decoder-input-ids)
decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
Indices of positions of each input sequence tokens in the position embeddings.
Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
[What are position IDs?](../glossary#position-ids)
"""
if input_ids is None and encoder_outputs is None:
raise ValueError(
"You should either provide text ids or the cached text encodings. Neither has been found."
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if self.is_gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if use_cache and past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)
# If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
elif not isinstance(encoder_outputs, BaseModelOutput):
encoder_outputs = BaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
        # By default we initialize the decoder with bos tokens if nothing has been provided
bsz, seq_len, channels = (encoder_outputs[0].shape[0], -1, self.config.decoder_config.num_channels)
if decoder_input_ids is None:
decoder_input_ids = torch.full(
size=(bsz, 1, channels), fill_value=self.config.bos_token_id, device=self.device
)
# Ensure 3D
if decoder_input_ids.ndim == 2:
decoder_input_ids = decoder_input_ids.reshape(bsz, channels, seq_len).transpose(1, 2)
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
position_ids=decoder_position_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask,
past_key_values=past_key_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
return Seq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs[0],
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
@auto_docstring(
custom_intro="""
The Dia model consisting of a (byte) text encoder and audio decoder with a prediction head on top.
"""
)
class DiaForConditionalGeneration(DiaPreTrainedModel, DiaGenerationMixin):
base_model_prefix = "model"
def __init__(self, config: DiaConfig):
super().__init__(config)
self.config = config
self.model = DiaModel(config)
self.num_channels = config.decoder_config.num_channels
self.vocab_size = config.decoder_config.vocab_size
self.logits_dense = nn.Linear(
config.decoder_config.hidden_size, (self.num_channels * self.vocab_size), bias=False
)
self.loss_type = "ForMaskedLM"
# Initialize weights and apply final processing
self.post_init()
def get_encoder(self):
return self.model.get_encoder()
def get_decoder(self):
return self.model.get_decoder()
@auto_docstring
@can_return_tuple
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_position_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None,
past_key_values: Optional[EncoderDecoderCache] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[tuple, Seq2SeqLMOutput]:
r"""
decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
            the audio input codebooks are flattened into the batch dimension. This also aligns with the
            flattened audio logits which are used to calculate the loss.
2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of
Dia to calculate embeddings and subsequent steps more efficiently.
If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
`(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
[`DiaProcessor.__call__`] for more details.
[What are decoder input IDs?](../glossary#decoder-input-ids)
decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
Indices of positions of each input sequence tokens in the position embeddings.
Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
[What are position IDs?](../glossary#position-ids)
        labels (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in
`[0, ..., config.decoder_config.vocab_size - 1]` or -100. Tokens with indices set to `-100`
are ignored (masked).
"""
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_position_ids=decoder_position_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
last_hidden_state = outputs[0]
batch_size = last_hidden_state.shape[0]
# 3D <-> 2D makes it necessary to prioritize channel dim
audio_logits = (
self.logits_dense(last_hidden_state)
.view((batch_size, -1, self.num_channels, self.vocab_size))
.transpose(1, 2)
.contiguous()
.view(batch_size * self.num_channels, -1, self.vocab_size)
)
loss = None
if labels is not None:
loss = self.loss_function(logits=audio_logits, labels=labels, vocab_size=self.vocab_size, **kwargs)
return Seq2SeqLMOutput(
loss=loss,
logits=audio_logits,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
__all__ = ["DiaModel", "DiaPreTrainedModel", "DiaForConditionalGeneration"]


@@ -0,0 +1,789 @@
# coding=utf-8
# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Dia model."""
from typing import Callable, Optional, Union
import torch
from torch import nn
from ...cache_utils import DynamicCache, EncoderDecoderCache
from ...masking_utils import create_causal_mask
from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask,
_prepare_4d_attention_mask_for_sdpa,
)
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, is_torchdynamo_compiling, logging
from ..llama.modeling_llama import (
LlamaAttention,
LlamaRMSNorm,
LlamaRotaryEmbedding,
eager_attention_forward,
)
from ..phi3.modeling_phi3 import Phi3MLP
from .configuration_dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig
from .generation_dia import DiaGenerationMixin
if is_torch_flex_attn_available():
from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@auto_docstring
class DiaPreTrainedModel(PreTrainedModel):
config_class = DiaConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
_supports_static_cache = True
main_input_name = "input_ids"
_no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"]
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, DiaRMSNorm):
module.weight.data.fill_(1.0)
class DiaMultiChannelEmbedding(nn.Module):
"""In order to efficiently compute the audio embedding from the 9 different channels,
we vectorize the embedding process by using a single embedding layer and an offset.
Example:
- num_embeds = 4
- vocab_size = 8
- num_channels = 3
We would have offsets = [0, 8, 16]
If audio_codes = [0, 1, 2, 3], [1, 3, 4, 7], [5, 6, 7, 8],
then tokens = audio_codes + offsets
= [0, 1, 2, 3, 9, 11, 12, 15, 21, 22, 23, 24]
This allows us to use a single embedding layer for all channels.
"""
def __init__(self, config: DiaDecoderConfig):
super().__init__()
self.embed = nn.Embedding(config.vocab_size * config.num_channels, config.hidden_size)
self.hidden_size = config.hidden_size
self.num_channels = config.num_channels
offsets = torch.arange(config.num_channels, dtype=torch.long) * config.vocab_size # (C,)
self.register_buffer("offsets", offsets, persistent=False)
def forward(self, audio_codes: torch.Tensor) -> torch.Tensor:
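# audio_codes: (batch_size, seq_len, num_channels); each channel is offset into its own
# slice of the shared embedding table, looked up in a single call, and the per-channel
# embeddings are summed into one hidden state per position.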
tokens = (audio_codes + self.offsets.to(audio_codes.device)).squeeze(1)
embeds = self.embed(tokens).view(tokens.shape[0], audio_codes.shape[1], -1, self.hidden_size)
return embeds.sum(dim=2)
class DiaMLP(Phi3MLP):
pass
class DiaRMSNorm(LlamaRMSNorm):
pass
class DiaRotaryEmbedding(LlamaRotaryEmbedding):
pass
class DiaSelfAttention(LlamaAttention, nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: Union[DiaEncoderConfig, DiaDecoderConfig], layer_idx: int, is_causal: bool = False):
nn.Module.__init__(self)
self.config = config
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
self.num_heads = self.config.num_attention_heads
self.num_key_value_heads = self.config.num_key_value_heads or self.num_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.head_dim = getattr(config, "head_dim", config.hidden_size // self.num_heads)
self.scaling = 1
self.attention_dropout = 0.0
self.is_causal = is_causal
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
class DiaCrossAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: DiaDecoderConfig, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
self.cross_hidden_size = config.cross_hidden_size
self.num_heads = self.config.cross_num_attention_heads
self.num_key_value_heads = self.config.cross_num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.head_dim = config.cross_head_dim
self.scaling = 1
self.attention_dropout = 0.0
self.is_causal = False
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
def forward(
self,
hidden_states: torch.Tensor,
cross_attention_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[EncoderDecoderCache] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
cross_shape = (*cross_attention_states.shape[:-1], -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
if past_key_values is not None and is_updated:
# reuse k,v, cross_attentions
key_states = past_key_values.cross_attention_cache.key_cache[self.layer_idx]
value_states = past_key_values.cross_attention_cache.value_cache[self.layer_idx]
else:
key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
if past_key_values is not None:
# save all states to the cache
key_states, value_states = past_key_values.cross_attention_cache.update(
key_states,
value_states,
self.layer_idx,
)
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
past_key_values.is_updated[self.layer_idx] = True
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape((*input_shape, -1)).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class DiaEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: DiaEncoderConfig, layer_idx: int):
super().__init__()
self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=False)
self.post_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
self.mlp = DiaMLP(config)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
attention_mask: Optional[torch.Tensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
residual = hidden_states
normed_states = self.pre_sa_norm(hidden_states)
self_attn_output, self_attn_weights = self.self_attention(
normed_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = residual + self_attn_output
residual = hidden_states
normed_states = self.post_sa_norm(hidden_states)
mlp_out = self.mlp(normed_states)
hidden_states = residual + mlp_out
return hidden_states, self_attn_weights
class DiaEncoder(DiaPreTrainedModel):
def __init__(self, config: DiaEncoderConfig):
super().__init__(config)
self.config = config
self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList(
[DiaEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
self.rotary_embeddings = DiaRotaryEmbedding(config)
@auto_docstring
@can_return_tuple
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[BaseModelOutput, tuple]:
hidden_states = self.embedding(input_ids)
# RoPE
# Note: We expect right padding and hence always generate
# the position ids on the fly to reduce preparation overhead
position_ids = torch.arange(input_ids.shape[-1], device=input_ids.device)[None, :]
position_embeddings = self.rotary_embeddings(hidden_states, position_ids)
attention_mask = self._update_full_mask(
attention_mask,
hidden_states,
)
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
for encoder_layer in self.layers:
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
layer_outputs = encoder_layer(
hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
if output_hidden_states:
encoder_states += (hidden_states,)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)
# Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
def _update_full_mask(
self,
attention_mask: Union[torch.Tensor, None],
inputs_embeds: torch.Tensor,
):
if attention_mask is not None:
if self.config._attn_implementation == "flash_attention_2":
attention_mask = attention_mask if 0 in attention_mask else None
elif self.config._attn_implementation == "sdpa":
# output_attentions=True & head_mask can not be supported when using SDPA, fall back to
# the manual implementation that requires a 4D causal mask in all cases.
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
elif self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
return attention_mask
class DiaDecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: DiaDecoderConfig, layer_idx: int):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=True)
self.cross_attention = DiaCrossAttention(config, layer_idx)
self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
self.pre_ca_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
self.pre_mlp_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
self.mlp = DiaMLP(config)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[EncoderDecoderCache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
self_attn_cache = past_key_values
if isinstance(self_attn_cache, EncoderDecoderCache):
self_attn_cache = self_attn_cache.self_attention_cache
residual = hidden_states
normed_states = self.pre_sa_norm(hidden_states)
self_attn_output, self_attn_weights = self.self_attention(
normed_states,
position_embeddings,
attention_mask,
# The cache needs to be passed as a positional arg so that in-place
# updates are carried through properly (e.g. under compile)
self_attn_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = residual + self_attn_output
residual = hidden_states
normed_states = self.pre_ca_norm(hidden_states)
cross_states, cross_attn_weights = self.cross_attention(
normed_states,
encoder_hidden_states,
attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
**kwargs,
)
hidden_states = residual + cross_states
residual = hidden_states
normed_states = self.pre_mlp_norm(hidden_states)
mlp_out = self.mlp(normed_states)
hidden_states = residual + mlp_out
return hidden_states, self_attn_weights, cross_attn_weights
class DiaDecoder(DiaPreTrainedModel):
"""Transformer Decoder Stack using DenseGeneral."""
def __init__(self, config: DiaDecoderConfig):
super().__init__(config)
self.num_channels = config.num_channels
self.vocab_size = config.vocab_size
self.embeddings = DiaMultiChannelEmbedding(config)
self.rotary_embeddings = DiaRotaryEmbedding(config)
self.layers = nn.ModuleList(
[DiaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
@auto_docstring
@can_return_tuple
def forward(
self,
input_ids: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[EncoderDecoderCache] = None,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[BaseModelOutputWithPastAndCrossAttentions, tuple]:
r"""
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`):
The original `decoder_input_ids` in 3D shape to facilitate more efficient computations.
[What are input IDs?](../glossary#input-ids)
"""
batch_size, seq_length = input_ids.size()[:-1]
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
if cache_position is None:
cache_position = torch.arange(
past_key_values_length, past_key_values_length + seq_length, device=input_ids.device
)
if position_ids is None:
position_ids = cache_position[None, :]
# RoPE
hidden_states = self.embeddings(input_ids)
position_embeddings = self.rotary_embeddings(hidden_states, position_ids)
if attention_mask is None and not is_torchdynamo_compiling():
# required mask seq length can be calculated via length of past cache
mask_seq_length = past_key_values_length + seq_length
attention_mask = torch.ones(batch_size, mask_seq_length, device=input_ids.device)
attention_mask = create_causal_mask(
config=self.config,
input_embeds=hidden_states,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
)
encoder_attention_mask = self._update_cross_attn_mask(
encoder_hidden_states,
encoder_attention_mask,
hidden_states.shape[:2],
hidden_states,
)
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
for layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = layer(
hidden_states,
position_embeddings,
attention_mask,
encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
cache_position=cache_position,
**kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns = all_self_attns + (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
hidden_states = self.norm(hidden_states)
if output_hidden_states:
all_hidden_states += (hidden_states,)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
)
# Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
def _update_cross_attn_mask(
self,
encoder_hidden_states: Union[torch.Tensor, None],
encoder_attention_mask: Union[torch.Tensor, None],
input_shape: torch.Size,
inputs_embeds: torch.Tensor,
):
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
if self.config._attn_implementation == "flash_attention_2":
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
elif self.config._attn_implementation == "sdpa":
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
encoder_attention_mask,
inputs_embeds.dtype,
tgt_len=input_shape[-1],
)
elif self.config._attn_implementation == "flex_attention":
if isinstance(encoder_attention_mask, torch.Tensor):
encoder_attention_mask = make_flex_block_causal_mask(
encoder_attention_mask,
query_length=input_shape[-1],
is_causal=False,
)
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
)
return encoder_attention_mask
@auto_docstring(
custom_intro="""
The bare Dia model outputting raw hidden-states without any specific head on top.
"""
)
class DiaModel(DiaPreTrainedModel):
def __init__(self, config: DiaConfig):
super().__init__(config)
self.config = config
self.encoder = DiaEncoder(config.encoder_config)
self.decoder = DiaDecoder(config.decoder_config)
self.post_init()
def get_encoder(self):
return self.encoder
def get_decoder(self):
return self.decoder
@auto_docstring
@can_return_tuple
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_position_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None,
past_key_values: Optional[EncoderDecoderCache] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[tuple, Seq2SeqModelOutput]:
r"""
decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
the audio input codebooks are flattened into the batch dimension. This also aligns with the
flattened audio logits which are used to calculate the loss.
2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of
Dia to calculate embeddings and subsequent steps more efficiently.
If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
`(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
[`DiaProcessor.__call__`] for more details.
[What are decoder input IDs?](../glossary#decoder-input-ids)
decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
Indices of positions of each input sequence tokens in the position embeddings.
Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
[What are position IDs?](../glossary#position-ids)
"""
if input_ids is None and encoder_outputs is None:
raise ValueError(
"You should either provide text ids or the cached text encodings. Neither has been found."
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if self.is_gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if use_cache and past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
**kwargs,
)
# If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
elif not isinstance(encoder_outputs, BaseModelOutput):
encoder_outputs = BaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
# On default we initialize the decoder with bos tokens if nothing has been provided
bsz, seq_len, channels = (encoder_outputs[0].shape[0], -1, self.config.decoder_config.num_channels)
if decoder_input_ids is None:
decoder_input_ids = torch.full(
size=(bsz, 1, channels), fill_value=self.config.bos_token_id, device=self.device
)
# Ensure 3D
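# 2D inputs are laid out as (batch_size * num_channels, seq_len) and are restored
# to (batch_size, seq_len, num_channels) before embedding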
if decoder_input_ids.ndim == 2:
decoder_input_ids = decoder_input_ids.reshape(bsz, channels, seq_len).transpose(1, 2)
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
position_ids=decoder_position_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask,
past_key_values=past_key_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
return Seq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs[0],
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
@auto_docstring(
custom_intro="""
The Dia model consisting of a (byte) text encoder and audio decoder with a prediction head on top.
"""
)
class DiaForConditionalGeneration(DiaPreTrainedModel, DiaGenerationMixin):
base_model_prefix = "model"
def __init__(self, config: DiaConfig):
super().__init__(config)
self.config = config
self.model = DiaModel(config)
self.num_channels = config.decoder_config.num_channels
self.vocab_size = config.decoder_config.vocab_size
self.logits_dense = nn.Linear(
config.decoder_config.hidden_size, (self.num_channels * self.vocab_size), bias=False
)
self.loss_type = "ForMaskedLM"
# Initialize weights and apply final processing
self.post_init()
def get_encoder(self):
return self.model.get_encoder()
def get_decoder(self):
return self.model.get_decoder()
@auto_docstring
@can_return_tuple
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_position_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None,
past_key_values: Optional[EncoderDecoderCache] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Union[tuple, Seq2SeqLMOutput]:
r"""
decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
the audio input codebooks are flattened into the batch dimension. This also aligns with the
flattened audio logits which are used to calculate the loss.
2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of
Dia to calculate embeddings and subsequent steps more efficiently.
If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
`(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
[`DiaProcessor.__call__`] for more details.
[What are decoder input IDs?](../glossary#decoder-input-ids)
decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
Indices of positions of each input sequence tokens in the position embeddings.
Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
[What are position IDs?](../glossary#position-ids)
labels (`torch.LongTensor` of shape `(batch_size * num_codebooks,)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in
`[0, ..., config.decoder_config.vocab_size - 1]` or -100. Tokens with indices set to `-100`
are ignored (masked).
"""
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_position_ids=decoder_position_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
last_hidden_state = outputs[0]
batch_size = last_hidden_state.shape[0]
# 3D <-> 2D makes it necessary to prioritize channel dim
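# (bsz, seq, channels * vocab) -> (bsz, seq, channels, vocab) -> (bsz, channels, seq, vocab)
# -> (bsz * channels, seq, vocab), matching the flattened layout expected by the labels/loss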
audio_logits = (
self.logits_dense(last_hidden_state)
.view((batch_size, -1, self.num_channels, self.vocab_size))
.transpose(1, 2)
.contiguous()
.view(batch_size * self.num_channels, -1, self.vocab_size)
)
loss = None
if labels is not None:
loss = self.loss_function(logits=audio_logits, labels=labels, vocab_size=self.vocab_size, **kwargs)
return Seq2SeqLMOutput(
loss=loss,
logits=audio_logits,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
__all__ = ["DiaModel", "DiaPreTrainedModel", "DiaForConditionalGeneration"]

View File

@ -0,0 +1,484 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Processor class for Dia"""
import math
from pathlib import Path
from typing import Optional, Union
from ...audio_utils import AudioInput, make_list_of_audio
from ...feature_extraction_utils import BatchFeature
from ...processing_utils import AudioKwargs, ProcessingKwargs, ProcessorMixin, Unpack
from ...utils import is_soundfile_available, is_torch_available
if is_torch_available():
import torch
if is_soundfile_available():
import soundfile as sf
class DiaAudioKwargs(AudioKwargs, total=False):
bos_token_id: int
eos_token_id: int
pad_token_id: int
delay_pattern: list[int]
generation: bool
class DiaProcessorKwargs(ProcessingKwargs, total=False):
audio_kwargs: DiaAudioKwargs
_defaults = {
"text_kwargs": {
"padding": True,
"padding_side": "right",
"add_special_tokens": False,
},
"audio_kwargs": {
"eos_token_id": 1024,
"pad_token_id": 1025,
"bos_token_id": 1026,
"delay_pattern": [0, 8, 9, 10, 11, 12, 13, 14, 15],
"generation": True,
"sampling_rate": 44100,
},
"common_kwargs": {"return_tensors": "pt"},
}
class DiaProcessor(ProcessorMixin):
r"""
Constructs a Dia processor which wraps a [`DiaFeatureExtractor`], [`DiaTokenizer`], and a [`DacModel`] into
a single processor. It inherits the audio feature extraction, tokenization, and audio encode/decode
functionalities. See [`~DiaProcessor.__call__`], [`~DiaProcessor.encode`], and [`~DiaProcessor.decode`] for more
information.
Args:
feature_extractor (`DiaFeatureExtractor`):
An instance of [`DiaFeatureExtractor`]. The feature extractor is a required input.
tokenizer (`DiaTokenizer`):
An instance of [`DiaTokenizer`]. The tokenizer is a required input.
audio_tokenizer (`DacModel`):
An instance of [`DacModel`] used to encode/decode audio into/from codebooks. It is a required input.
"""
feature_extractor_class = "DiaFeatureExtractor"
tokenizer_class = "DiaTokenizer"
audio_tokenizer_class = "DacModel"
def __init__(self, feature_extractor, tokenizer, audio_tokenizer):
super().__init__(feature_extractor, tokenizer, audio_tokenizer=audio_tokenizer)
@property
def model_input_names(self):
"""
We no longer pass the raw audio values but the codebooks encoded by the `audio_tokenizer`.
Conventions may differ between audio models due to architectural choices.
"""
tokenizer_input_names = self.tokenizer.model_input_names
audio_tokenizer_input_names = ["decoder_input_ids", "decoder_attention_mask"]
return list(dict.fromkeys(tokenizer_input_names + audio_tokenizer_input_names))
def __call__(
self,
text: Union[str, list[str]],
audio: Optional[AudioInput] = None,
output_labels: Optional[bool] = False,
**kwargs: Unpack[DiaProcessorKwargs],
):
"""
Main method to prepare text(s) and audio to be fed as input to the model. The `audio` argument is
forwarded to the DiaFeatureExtractor's [`~DiaFeatureExtractor.__call__`] and subsequently to the
DacModel's [`~DacModel.encode`]. The `text` argument is forwarded to [`~DiaTokenizer.__call__`]. Please
refer to the docstrings of the above methods for more information.
"""
if not is_torch_available():
raise ValueError(
"The `DiaProcessor` relies on the `audio_tokenizer` which requires `torch` but we couldn't "
"find it in your environment. You can install torch via `pip install torch`."
)
if text is None:
raise ValueError("You need to specify the `text` input to process.")
output_kwargs = self._merge_kwargs(
DiaProcessorKwargs,
**kwargs,
)
text_kwargs = output_kwargs["text_kwargs"]
audio_kwargs = output_kwargs["audio_kwargs"]
common_kwargs = output_kwargs["common_kwargs"]
return_tensors = common_kwargs.pop("return_tensors", None)
if return_tensors != "pt":
raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.")
data = {}
# Text
if isinstance(text, str):
text = [text]
elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
encodings = self.tokenizer(text, **text_kwargs)
data.update(encodings)
# Audio
delay_pattern = audio_kwargs.pop("delay_pattern", None)
audio_bos_token_id = audio_kwargs.pop("bos_token_id", None)
audio_eos_token_id = audio_kwargs.pop("eos_token_id", None)
audio_pad_token_id = audio_kwargs.pop("pad_token_id", None)
generation = audio_kwargs.pop("generation", True)
if (
audio_bos_token_id is None
or audio_eos_token_id is None
or audio_pad_token_id is None
or delay_pattern is None
):
raise ValueError(
"To enable processing for Dia, we need the `bos_token_id`, `eos_token_id`, "
"`pad_token_id`, and `delay_pattern`. You may have accidentally overwritten one of those."
)
if generation and output_labels:
raise ValueError(
f"Labels with `generation` is incompatible, got generation={generation}, output_labels={output_labels}."
)
batch_size = data["input_ids"].shape[0]
num_channels = len(delay_pattern)
max_delay = max(delay_pattern)
# Voice cloning generation / general training
if audio is not None:
audio = make_list_of_audio(audio)
input_audios = self.feature_extractor(audio, **audio_kwargs)
compression_rate = math.prod(self.audio_tokenizer.config.downsampling_ratios)
max_encoded_sequence_len = input_audios["padding_mask"][0].shape[-1] // compression_rate
decoder_input_ids = []
decoder_attention_mask = []
# TODO: dac with batching is currently broken, but non-batch is working
# refer to https://gist.github.com/vasqu/643a45b680cf39fd7467271ee2eb6f80 for a validation script
for padding_mask, audio in zip(input_audios["padding_mask"], input_audios["input_values"]):
# get current length with hop length in mind (as if it were sampled as a single audio)
base_pad_len = self.feature_extractor.hop_length
current_audio_len = math.ceil(padding_mask.sum(dim=-1) / base_pad_len) * base_pad_len
encoded_sequence_len = current_audio_len // compression_rate
padding_len = max_encoded_sequence_len - encoded_sequence_len
# compute non-padded forward pass; one extra bos (and eos if training) is added
with torch.no_grad():
audio = audio[None, ..., :current_audio_len].to(self.audio_tokenizer.device)
input_ids = self.audio_tokenizer.encode(audio).audio_codes.transpose(1, 2)
if not generation:
input_ids = torch.nn.functional.pad(
input_ids, pad=(0, 0, 0, 1, 0, 0), mode="constant", value=audio_eos_token_id
)
# apply padding
# +1 for the bos within the real sequence
input_ids = torch.nn.functional.pad(
input_ids, pad=(0, 0, padding_len + 1, 0, 0, 0), mode="constant", value=audio_bos_token_id
)
num_valid_inputs = encoded_sequence_len + 1 + max_delay # sequence + bos + delay
num_valid_inputs += 0 if generation else 1 # eos if training
attention_mask = torch.tensor([0] * padding_len + [1] * num_valid_inputs, dtype=torch.long)[None, :]
decoder_input_ids.append(input_ids)
decoder_attention_mask.append(attention_mask)
decoder_input_ids = torch.cat(decoder_input_ids, dim=0)
decoder_attention_mask = torch.cat(decoder_attention_mask, dim=0)
# TTS generation
elif generation:
# all bos to start with TTS
decoder_input_ids = torch.full((batch_size, 1, num_channels), audio_bos_token_id, dtype=torch.long)
# we preemptively add the delay
decoder_attention_mask = torch.ones(size=(batch_size, 1 + max_delay), dtype=torch.long)
else:
raise ValueError("If you try to train, you should provide audio data as well.")
if batch_size != decoder_input_ids.shape[0]:
raise ValueError(
f"Need the same amount of samples for both text and audio, but got text samples={batch_size} and "
f"audio samples = {decoder_input_ids.shape[0]} instead."
)
# prepare shift indices per delay
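# each channel c is shifted right by delay_pattern[c] steps: BOS fills the positions exposed
# at the start and PAD fills positions pushed past the end of the sequence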
max_seq_len = decoder_attention_mask.shape[-1]
max_audio_len = max_seq_len - max_delay
precomputed_idx = self.build_indices(
bsz=batch_size,
seq_len=max_seq_len,
num_channels=num_channels,
delay_pattern=delay_pattern,
revert=False,
)
# create delay pattern input
# the pad token will be used for masking which input is valid for prediction during generation
prefill = torch.full(
(batch_size, max_seq_len, num_channels),
fill_value=audio_pad_token_id,
dtype=torch.int,
)
prefill[:, :max_audio_len] = decoder_input_ids
delayed_decoder_input_ids = self.apply_audio_delay(
audio=prefill,
pad_token_id=audio_pad_token_id,
bos_token_id=audio_bos_token_id,
precomputed_idx=precomputed_idx,
)
data.update({"decoder_input_ids": delayed_decoder_input_ids, "decoder_attention_mask": decoder_attention_mask})
if output_labels:
# Base idea is to shift on the sequence dim
labels = data["decoder_input_ids"].clone()[:, 1:]
labels[labels == audio_pad_token_id] = -100
labels[labels == audio_bos_token_id] = -100
data["labels"] = labels.transpose(1, 2).reshape(batch_size * num_channels, -1).contiguous().long()
data["decoder_input_ids"] = data["decoder_input_ids"][:, :-1]
data["decoder_attention_mask"] = data["decoder_attention_mask"][:, :-1]
return BatchFeature(data=data, tensor_type=return_tensors)
def batch_decode(
self,
decoder_input_ids: "torch.Tensor",
audio_prompt_len: Optional[int] = None,
**kwargs: Unpack[DiaProcessorKwargs],
) -> list["torch.Tensor"]:
"""
Decodes a batch of audio codebook sequences into their respective audio waveforms via the
`audio_tokenizer`. See [`~DacModel.decode`] for more information.
Args:
decoder_input_ids (`torch.Tensor`): The complete output sequence of the decoder.
audio_prompt_len (`int`): The audio prefix length (e.g. when using voice cloning).
"""
output_kwargs = self._merge_kwargs(
DiaProcessorKwargs,
**kwargs,
)
audio_kwargs = output_kwargs["audio_kwargs"]
delay_pattern = audio_kwargs.pop("delay_pattern", None)
audio_bos_token_id = audio_kwargs.pop("bos_token_id", None)
audio_pad_token_id = audio_kwargs.pop("pad_token_id", None)
if audio_bos_token_id is None or audio_pad_token_id is None or delay_pattern is None:
raise ValueError(
"To enable decoding for Dia, we need the `bos_token_id`, `pad_token_id`, "
"and `delay_pattern`. You may have accidentally overwritten one of those."
)
# either decode the whole audio sequence or only the generated parts
if audio_prompt_len is not None:
audio_prompt_len = torch.tensor(audio_prompt_len, device=decoder_input_ids.device, dtype=torch.long)
start_of_generation_idx = audio_prompt_len[None].expand(decoder_input_ids.shape[0])
else:
start_of_generation_idx = (decoder_input_ids[:, :, 0] == audio_bos_token_id).sum(dim=-1)
# -1 for the eos token
end_of_generation_idx = (
decoder_input_ids.shape[1] - (decoder_input_ids[:, :, 0] == audio_pad_token_id).sum(dim=-1) - 1
)
# revert delay
bsz, seq_len, num_channels = decoder_input_ids.shape
precomputed_idx = self.build_indices(
bsz=bsz,
seq_len=seq_len,
num_channels=num_channels,
delay_pattern=delay_pattern,
revert=True,
)
output_sequences = self.apply_audio_delay(
audio=decoder_input_ids,
# We do not care about these values as we cut them out
# with `start_of_generation_idx` and `end_of_generation_idx`
pad_token_id=-1,
bos_token_id=-1,
precomputed_idx=precomputed_idx,
).transpose(1, 2)
# retrieve the correct sequences each
audios = []
# TODO: see above, dac doesn't work in batches yet
with torch.no_grad():
for i in range(start_of_generation_idx.shape[0]):
output_i = output_sequences[i, :, start_of_generation_idx[i] : end_of_generation_idx[i]][None, ...]
output_i = output_i.to(self.audio_tokenizer.device)
audio_i = self.audio_tokenizer.decode(audio_codes=output_i).audio_values.cpu().squeeze()
audios.append(audio_i)
return audios
def decode(
self,
decoder_input_ids: "torch.Tensor",
audio_prompt_len: Optional[int] = None,
**kwargs: Unpack[DiaProcessorKwargs],
) -> "torch.Tensor":
"""
Decodes a single sequence of audio codebooks into the respective audio waveform via the
`audio_tokenizer`. See [`~DacModel.decode`] and [`~DiaProcessor.batch_decode`] for more information.
"""
if decoder_input_ids.shape[0] != 1:
raise ValueError(
f"Expecting a single output to be decoded but received {decoder_input_ids.shape[0]} samples instead."
)
return self.batch_decode(decoder_input_ids, audio_prompt_len, **kwargs)[0]
def get_audio_prompt_len(
self,
decoder_attention_mask: "torch.Tensor",
**kwargs: Unpack[DiaProcessorKwargs],
) -> int:
"""Utility function to get the audio prompt length."""
output_kwargs = self._merge_kwargs(
DiaProcessorKwargs,
**kwargs,
)
audio_kwargs = output_kwargs["audio_kwargs"]
delay_pattern = audio_kwargs.pop("delay_pattern", None)
if delay_pattern is None:
raise ValueError(
"To enable the utility of retrieving the prompt length for Dia, we need the "
"`delay_pattern`. You may have accidentally overwritten this."
)
return decoder_attention_mask.shape[1] - max(delay_pattern)
# Copied from transformers.models.csm.processing_csm.CsmProcessor.save_audio with Csm->Dia
def save_audio(
self,
audio: AudioInput,
saving_path: Union[str, Path, list[Union[str, Path]]],
**kwargs: Unpack[DiaProcessorKwargs],
):
# TODO: @eustlb, this should be in AudioProcessor
if not is_soundfile_available():
raise ImportError("Please install `soundfile` to save audio files.")
# ensure correct audio input
audio = make_list_of_audio(audio)
# ensure correct saving path
if isinstance(saving_path, (str, Path)):
saving_path = [saving_path]
elif not (isinstance(saving_path, (list, tuple)) and all(isinstance(p, (str, Path)) for p in saving_path)):
raise ValueError("Invalid input path. Please provide a string, or a list of strings")
if len(audio) != len(saving_path):
raise ValueError("The number of audio and saving paths must be the same")
output_kwargs = self._merge_kwargs(
DiaProcessorKwargs,
**kwargs,
)
audio_kwargs = output_kwargs["audio_kwargs"]
sampling_rate = audio_kwargs["sampling_rate"]
for audio_value, p in zip(audio, saving_path):
if isinstance(audio_value, torch.Tensor):
audio_value = audio_value.cpu().float().numpy()
sf.write(p, audio_value, sampling_rate)
@staticmethod
def build_indices(
bsz: int,
seq_len: int,
num_channels: int,
delay_pattern: list[int],
revert: bool = False,
) -> tuple["torch.Tensor", "torch.Tensor"]:
"""
Precompute (sequence_idx, all_idx) so that out[seq, channel] = in[seq - delay[channel], channel]
or in[seq, channel] = out[seq + delay[channel], channel] if `revert`.
Negative sequence_idx => BOS; sequence_idx >= seq_len => PAD.
"""
delay_array = torch.tensor(delay_pattern, dtype=torch.int32)
# (0..seq_len-1)
sequence_idx = torch.arange(seq_len, dtype=torch.int32)[None, :].expand(bsz, seq_len)[..., None]
# + or - delay depending if we delay or revert the delay
if not revert:
sequence_idx = sequence_idx - delay_array[None, None, :]
else:
sequence_idx = sequence_idx + delay_array[None, None, :]
# if delay goes over the range we clamp back to valid values
valid_sequence_idx = torch.clamp(sequence_idx, 0, seq_len - 1)
batch_idx = torch.arange(bsz, dtype=torch.int32)[:, None, None].expand(bsz, seq_len, num_channels)
channel_idx = torch.arange(num_channels, dtype=torch.int32)[None, None, :].expand(bsz, seq_len, num_channels)
all_idx = torch.stack(
[batch_idx.reshape(-1), valid_sequence_idx.reshape(-1), channel_idx.reshape(-1)],
dim=1,
).long()
return sequence_idx, all_idx
@staticmethod
def apply_audio_delay(
audio: "torch.Tensor",
pad_token_id: int,
bos_token_id: int,
precomputed_idx: tuple["torch.Tensor", "torch.Tensor"],
) -> "torch.Tensor":
"""
Applies or reverts the delay pattern to batched audio tokens using precomputed indices,
inserting BOS where sequence_idx < 0 and PAD where sequence_idx >= seq_len.
Args:
audio: audio tokens of shape [bsz, seq_len, num_channels]
pad_token_id: the PAD token
bos_token_id: the BOS token
precomputed_idx: from `build_indices`
Returns:
final_audio: delayed or reverted audio tokens of shape [bsz, seq_len, num_channels]
"""
# Move everything to the same device
device = audio.device
sequence_idx, all_idx = precomputed_idx
sequence_idx = sequence_idx.to(device)
all_idx = all_idx.to(device)
# Gather per precomputed indices
batch_idx, valid_sequence_idx, channel_idx = torch.unbind(all_idx, dim=-1)
gathered_audio = audio[batch_idx, valid_sequence_idx, channel_idx].view(audio.size())
# Mask according to negative sequence_idx => BOS; sequence_idx >= seq_len => PAD
mask_bos = sequence_idx < 0
mask_pad = sequence_idx >= audio.shape[1]
final_audio = torch.where(mask_bos, bos_token_id, torch.where(mask_pad, pad_token_id, gathered_audio))
return final_audio
__all__ = ["DiaProcessor"]

View File

@ -0,0 +1,118 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization class for Dia."""
from typing import Optional
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import logging
logger = logging.get_logger(__name__)
class DiaTokenizer(PreTrainedTokenizer):
"""
Construct a Dia tokenizer. Dia simply uses raw bytes utf-8 encoding except for special tokens `[S1]` and `[S2]`.
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should
refer to this superclass for more information regarding those methods.
Args:
pad_token (`str`, *optional*, defaults to `"<pad>"`):
The token used for padding, for example when batching sequences of different lengths.
unk_token (`str`, *optional*, defaults to `"<pad>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
max_length (`int`, *optional*, defaults to 1024):
The maximum length of the sequences when encoding. Sequences longer than this will be truncated.
offset (`int`, *optional*, defaults to 0):
An offset added to the byte values when converting tokens to ids (e.g. to reserve the lowest ids for special tokens).
"""
model_input_names = ["input_ids", "attention_mask"]
def __init__(
self,
pad_token: Optional[str] = "<pad>",
unk_token: Optional[str] = "<pad>",
max_length: Optional[int] = 1024,
offset: int = 0,
**kwargs,
):
# We have no eos/bos tokens but allow padding -- no l/r strip as we treat them as tokens as well
pad_token = AddedToken(pad_token) if isinstance(pad_token, str) else pad_token
unk_token = AddedToken(unk_token) if isinstance(unk_token, str) else unk_token
self._utf_vocab_size = 2**8 # utf is 8 bits
self._added_tokens_decoder = {0: pad_token, 1: AddedToken("[S1]"), 2: AddedToken("[S2]")}
self.offset = offset
super().__init__(
unk_token=unk_token,
pad_token=pad_token,
max_length=max_length,
**kwargs,
)
@property
def vocab_size(self):
return self._utf_vocab_size
def get_vocab(self):
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size + self.offset)}
vocab.update(self.added_tokens_encoder)
return vocab
def _tokenize(self, text: str) -> list[str]:
"""Take as input a string and return a list of strings (tokens) for words/sub-words"""
tokens = [chr(i) for i in text.encode("utf-8")]
return tokens
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
if len(token) != 1:
token_id = None
else:
token_id = ord(token) + self.offset
return token_id
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
token = chr(index - self.offset)
return token
def convert_tokens_to_string(self, tokens: list[str]) -> str:
"""Converts a sequence of tokens (string) in a single string."""
bstring = b""
for token in tokens:
if token in self.added_tokens_decoder:
added_token_obj = self.added_tokens_decoder[token]
tok_string = str(added_token_obj).encode("utf-8")
elif token in self.added_tokens_encoder:
tok_string = token.encode("utf-8")
else:
tok_string = token.encode("utf-8") # Assume general string token
bstring += tok_string
string = bstring.decode("utf-8", errors="ignore")
return string
# No vocab file
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
return ()
__all__ = ["DiaTokenizer"]

View File

@ -80,15 +80,21 @@ class TextToAudioPipeline(Pipeline):
See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=text-to-speech).
"""
# Introducing the processor at load time for new behaviour
_load_processor = True
_pipeline_calls_generate = True
# Make sure the docstring is updated when the default generation config is changed
_default_generation_config = GenerationConfig(
max_new_tokens=256,
)
def __init__(self, *args, vocoder=None, sampling_rate=None, **kwargs):
def __init__(self, *args, vocoder=None, sampling_rate=None, no_processor=True, **kwargs):
super().__init__(*args, **kwargs)
# Legacy behaviour only uses the tokenizer, while newer models rely on the processor as a whole
self.no_processor = no_processor
if self.framework == "tf":
raise ValueError("The TextToAudioPipeline is only available in PyTorch.")
@ -117,6 +123,10 @@ class TextToAudioPipeline(Pipeline):
if sampling_rate is not None:
self.sampling_rate = sampling_rate
# last fallback to get the sampling rate based on processor
if self.sampling_rate is None and not self.no_processor and hasattr(self.processor, "feature_extractor"):
self.sampling_rate = self.processor.feature_extractor.sampling_rate
def preprocess(self, text, **kwargs):
if isinstance(text, str):
text = [text]
@ -136,7 +146,8 @@ class TextToAudioPipeline(Pipeline):
kwargs = new_kwargs
output = self.tokenizer(text, **kwargs, return_tensors="pt")
preprocessor = self.tokenizer if self.no_processor else self.processor
output = preprocessor(text, **kwargs, return_tensors="pt")
return output
@ -228,12 +239,21 @@ class TextToAudioPipeline(Pipeline):
return preprocess_params, params, postprocess_params
def postprocess(self, waveform):
def postprocess(self, audio):
output_dict = {}
if isinstance(waveform, dict):
waveform = waveform["waveform"]
elif isinstance(waveform, tuple):
waveform = waveform[0]
# We directly get the waveform
if self.no_processor:
if isinstance(audio, dict):
waveform = audio["waveform"]
elif isinstance(audio, tuple):
waveform = audio[0]
else:
waveform = audio
# Or we need to postprocess to get the waveform
else:
waveform = self.processor.decode(audio)
output_dict["audio"] = waveform.to(device="cpu", dtype=torch.float).numpy()
output_dict["sampling_rate"] = self.sampling_rate

View File

@ -49,6 +49,7 @@ from .tokenization_utils_base import (
TruncationStrategy,
)
from .utils import (
AUDIO_TOKENIZER_NAME,
CHAT_TEMPLATE_DIR,
CHAT_TEMPLATE_FILE,
LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE,
@ -61,12 +62,17 @@ from .utils import (
download_url,
is_offline_mode,
is_remote_url,
is_torch_available,
list_repo_templates,
logging,
)
from .utils.deprecation import deprecate_kwarg
if is_torch_available():
from .modeling_utils import PreTrainedAudioTokenizerBase
logger = logging.get_logger(__name__)
# Dynamically import the Transformers module to grab the attribute classes of the processor from their names.
@ -499,7 +505,7 @@ class ProcessorMixin(PushToHubMixin):
"""
attributes = ["feature_extractor", "tokenizer"]
optional_attributes = ["chat_template"]
optional_attributes = ["chat_template", "audio_tokenizer"]
optional_call_args: list[str] = []
# Names need to be attr_class for attr in attributes
feature_extractor_class = None
@ -511,7 +517,19 @@ class ProcessorMixin(PushToHubMixin):
# First, extract optional attributes from kwargs if present
# Optional attributes can never be positional arguments
for optional_attribute in self.optional_attributes:
setattr(self, optional_attribute, kwargs.pop(optional_attribute, None))
optional_attribute_value = kwargs.pop(optional_attribute, None)
setattr(self, optional_attribute, optional_attribute_value)
# Check audio tokenizer for its class but do not treat it as attr to avoid saving weights
if optional_attribute == "audio_tokenizer" and optional_attribute_value is not None:
proper_class = self.check_argument_for_proper_class(optional_attribute, optional_attribute_value)
if not (is_torch_available() and isinstance(optional_attribute_value, PreTrainedAudioTokenizerBase)):
raise ValueError(
f"Tried to use `{proper_class}` for audio tokenization. However, this class is not"
" registered for audio tokenization."
)
# Sanitize args and kwargs
for key in kwargs:
if key not in self.attributes:
@ -530,21 +548,30 @@ class ProcessorMixin(PushToHubMixin):
# Check each arg is of the proper class (this will also catch a user initializing in the wrong order)
for attribute_name, arg in kwargs.items():
class_name = getattr(self, f"{attribute_name}_class")
# Nothing is ever going to be an instance of "AutoXxx", in that case we check the base class.
class_name = AUTO_TO_BASE_CLASS_MAPPING.get(class_name, class_name)
if isinstance(class_name, tuple):
proper_class = tuple(self.get_possibly_dynamic_module(n) for n in class_name if n is not None)
else:
proper_class = self.get_possibly_dynamic_module(class_name)
if not isinstance(arg, proper_class):
raise TypeError(
f"Received a {type(arg).__name__} for argument {attribute_name}, but a {class_name} was expected."
)
self.check_argument_for_proper_class(attribute_name, arg)
setattr(self, attribute_name, arg)
def check_argument_for_proper_class(self, argument_name, argument):
"""
Checks the passed argument's class against the expected transformers class. In case of an unexpected
mismatch between the expected and actual class, an error is raised. Otherwise, the properly retrieved
class is returned.
"""
class_name = getattr(self, f"{argument_name}_class")
# Nothing is ever going to be an instance of "AutoXxx", in that case we check the base class.
class_name = AUTO_TO_BASE_CLASS_MAPPING.get(class_name, class_name)
if isinstance(class_name, tuple):
proper_class = tuple(self.get_possibly_dynamic_module(n) for n in class_name if n is not None)
else:
proper_class = self.get_possibly_dynamic_module(class_name)
if not isinstance(argument, proper_class):
raise TypeError(
f"Received a {type(argument).__name__} for argument {argument_name}, but a {class_name} was expected."
)
return proper_class
def to_dict(self) -> dict[str, Any]:
"""
Serializes this instance to a Python dictionary.
@ -577,6 +604,8 @@ class ProcessorMixin(PushToHubMixin):
del output["feature_extractor"]
if "chat_template" in output:
del output["chat_template"]
if "audio_tokenizer" in output:
del output["audio_tokenizer"]
# Some attributes have different names but containing objects that are not simple strings
output = {
@ -695,6 +724,7 @@ class ProcessorMixin(PushToHubMixin):
save_directory, LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE
) # Legacy filename
chat_template_dir = os.path.join(save_directory, CHAT_TEMPLATE_DIR)
output_audio_tokenizer_file = os.path.join(save_directory, AUDIO_TOKENIZER_NAME)
processor_dict = self.to_dict()
# Save `chat_template` in its own file. We can't get it from `processor_dict` as we popped it in `to_dict`
@ -737,6 +767,19 @@ class ProcessorMixin(PushToHubMixin):
"separate files using the `save_jinja_files` argument."
)
if self.audio_tokenizer is not None:
audio_tokenizer_class = self.audio_tokenizer.__class__.__name__
audio_tokenizer_name_or_path = self.audio_tokenizer.name_or_path
audio_tokenizer_dict = {
"audio_tokenizer_class": audio_tokenizer_class,
"audio_tokenizer_name_or_path": audio_tokenizer_name_or_path,
}
audio_tokenizer_json = json.dumps(audio_tokenizer_dict, indent=2, sort_keys=True) + "\n"
with open(output_audio_tokenizer_file, "w", encoding="utf-8") as writer:
writer.write(audio_tokenizer_json)
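# Illustrative contents of the resulting audio_tokenizer_config.json (the name_or_path value is hypothetical):
# {
#   "audio_tokenizer_class": "DacModel",
#   "audio_tokenizer_name_or_path": "descript/dac_44khz"
# }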
# For now, let's not save to `processor_config.json` if the processor doesn't have extra attributes and
# `auto_map` is not specified.
if set(processor_dict.keys()) != {"processor_class"}:
@ -774,6 +817,9 @@ class ProcessorMixin(PushToHubMixin):
Returns:
`tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the processor object.
"""
# holding a copy for optionally loading the audio tokenizer (if available)
audio_tokenizer_kwargs = copy.deepcopy(kwargs)
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", None)
@ -803,16 +849,18 @@ class ProcessorMixin(PushToHubMixin):
resolved_additional_chat_template_files = {}
if os.path.isfile(pretrained_model_name_or_path):
resolved_processor_file = pretrained_model_name_or_path
# can't load chat-template when given a file as pretrained_model_name_or_path
# can't load chat-template and audio tokenizer when given a file as pretrained_model_name_or_path
resolved_chat_template_file = None
resolved_raw_chat_template_file = None
resolved_audio_tokenizer_file = None
is_local = True
elif is_remote_url(pretrained_model_name_or_path):
processor_file = pretrained_model_name_or_path
resolved_processor_file = download_url(pretrained_model_name_or_path)
# can't load chat-template when given a file url as pretrained_model_name_or_path
# can't load chat-template and audio tokenizer when given a file url as pretrained_model_name_or_path
resolved_chat_template_file = None
resolved_raw_chat_template_file = None
resolved_audio_tokenizer_file = None
else:
if is_local:
template_dir = Path(pretrained_model_name_or_path, CHAT_TEMPLATE_DIR)
@ -899,6 +947,21 @@ class ProcessorMixin(PushToHubMixin):
)
for template_name, template_file in additional_chat_template_files.items()
}
resolved_audio_tokenizer_file = cached_file(
pretrained_model_name_or_path,
AUDIO_TOKENIZER_NAME,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
token=token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
)
except OSError:
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
# the original exception.
@ -939,6 +1002,22 @@ class ProcessorMixin(PushToHubMixin):
if chat_templates:
kwargs["chat_template"] = chat_templates
# Same as chat template, adding as kwarg after loading the model
audio_tokenizer = None
if resolved_audio_tokenizer_file is not None:
with open(resolved_audio_tokenizer_file, "r", encoding="utf-8") as reader:
# The json contains the references we need to init the correct model
audio_tokenizer_references = json.load(reader)
audio_tokenizer_class = cls.get_possibly_dynamic_module(
audio_tokenizer_references["audio_tokenizer_class"]
)
audio_tokenizer_path = audio_tokenizer_references["audio_tokenizer_name_or_path"]
audio_tokenizer = audio_tokenizer_class.from_pretrained(audio_tokenizer_path, **audio_tokenizer_kwargs)
if audio_tokenizer is not None:
kwargs["audio_tokenizer"] = audio_tokenizer
# Existing processors on the Hub created before #27761 was merged don't have `processor_config.json` (if not
# updated afterward), and we need to keep `from_pretrained` working. So here it falls back to an empty dict.
# (`cached_file` called using `_raise_exceptions_for_missing_entries=False` to avoid exception)
@ -947,7 +1026,9 @@ class ProcessorMixin(PushToHubMixin):
# In any case we need to pass `chat_template` if it is available
processor_dict = {}
if "chat_template" in kwargs:
processor_dict = {"chat_template": kwargs.pop("chat_template")}
processor_dict["chat_template"] = kwargs.pop("chat_template")
if "audio_tokenizer" in kwargs:
processor_dict["audio_tokenizer"] = kwargs.pop("audio_tokenizer")
return processor_dict, kwargs
try:
@ -972,6 +1053,8 @@ class ProcessorMixin(PushToHubMixin):
if "chat_template" in kwargs:
processor_dict["chat_template"] = kwargs.pop("chat_template")
if "audio_tokenizer" in kwargs:
processor_dict["audio_tokenizer"] = kwargs.pop("audio_tokenizer")
return processor_dict, kwargs
@ -1276,6 +1359,7 @@ class ProcessorMixin(PushToHubMixin):
attribute_class = cls.get_possibly_dynamic_module(class_name)
args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
return args
@staticmethod
@ -1287,6 +1371,7 @@ class ProcessorMixin(PushToHubMixin):
transformers_module.VIDEO_PROCESSOR_MAPPING,
transformers_module.TOKENIZER_MAPPING,
transformers_module.FEATURE_EXTRACTOR_MAPPING,
transformers_module.MODEL_FOR_AUDIO_TOKENIZATION_MAPPING,
]
for lookup_location in lookup_locations:
for custom_class in lookup_location._extra_content.values():
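Taken together, the hunks above wire an optional `audio_tokenizer` into the processor save/load cycle via a small JSON reference file. A minimal sketch of the resulting round trip, assuming the Dia/DAC checkpoints used in the tests further below (the local `dia-processor` directory is only an illustrative path):

```python
from transformers import DacModel, DiaFeatureExtractor, DiaProcessor, DiaTokenizer

processor = DiaProcessor(
    tokenizer=DiaTokenizer.from_pretrained("AntonV/Dia-1.6B"),
    feature_extractor=DiaFeatureExtractor.from_pretrained("AntonV/Dia-1.6B"),
    audio_tokenizer=DacModel.from_pretrained("descript/dac_44khz"),
)
processor.save_pretrained("dia-processor")
# "dia-processor/audio_tokenizer_config.json" should now hold roughly:
# {
#   "audio_tokenizer_class": "DacModel",
#   "audio_tokenizer_name_or_path": "descript/dac_44khz"
# }
# from_pretrained reads that reference and re-instantiates the audio tokenizer
processor = DiaProcessor.from_pretrained("dia-processor")
```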

View File

@ -292,6 +292,7 @@ CONFIG_NAME = "config.json"
FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
IMAGE_PROCESSOR_NAME = "preprocessor_config.json"
VIDEO_PROCESSOR_NAME = "video_preprocessor_config.json"
AUDIO_TOKENIZER_NAME = "audio_tokenizer_config.json"
PROCESSOR_NAME = "processor_config.json"
GENERATION_CONFIG_NAME = "generation_config.json"
MODEL_CARD_NAME = "modelcard.json"

View File

@ -56,7 +56,12 @@ if is_torch_available():
UnbatchedClassifierFreeGuidanceLogitsProcessor,
WatermarkLogitsProcessor,
)
from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor
from transformers.generation.logits_process import (
BarkEosPrioritizerLogitsProcessor,
DiaClassifierFreeGuidanceLogitsProcessor,
DiaEOSChannelFilterLogitsProcessor,
DiaEOSDelayPatternLogitsProcessor,
)
@require_torch
@ -1211,3 +1216,145 @@ class LogitsProcessorTest(unittest.TestCase):
)
)
self.assertTrue(is_close)
def test_dia_classifier_free_guidance(self):
input_ids = torch.LongTensor([[0]])
logits_uncond = torch.tensor([[1.0, 0, 1.5]])
logits_cond = torch.tensor([[1.0, 1.0, 1.0]])
# base cfg with conditioned as center
cfg = DiaClassifierFreeGuidanceLogitsProcessor(guidance_scale=1.5)
out = cfg(input_ids, torch.cat([logits_cond, logits_uncond], dim=0))
res = logits_cond + 1.5 * (logits_cond - logits_uncond)
self.assertAlmostEqual(out[0, 0].item(), res[0, 0].item())
self.assertAlmostEqual(out[0, 1].item(), res[0, 1].item())
self.assertAlmostEqual(out[0, 2].item(), res[0, 2].item())
# additional top k (on cond logits)
cfg = DiaClassifierFreeGuidanceLogitsProcessor(guidance_scale=1.5, guidance_top_k=1)
out = cfg(input_ids, torch.cat([logits_cond, logits_uncond], dim=0))
res = logits_cond + 1.5 * (logits_cond - logits_uncond)
mask = res == res.max()
res = logits_cond.clone()
res[~mask.bool()] = -float("inf")
self.assertAlmostEqual(out[0, 0].item(), res[0, 0].item())
self.assertAlmostEqual(out[0, 1].item(), res[0, 1].item())
self.assertAlmostEqual(out[0, 2].item(), res[0, 2].item())
def test_dia_channel_filter(self):
eos = 2
bsz, channels, vocab = 2, 2, 4
input_ids = torch.LongTensor([[0]])
logits = torch.zeros(size=(bsz, channels, vocab)).view(bsz * channels, vocab)
logits[0, eos] = 1 # Eos max (forced)
logits[1, eos] = 1 # Eos max (forced) but not channel 0
channel_filter = DiaEOSChannelFilterLogitsProcessor(num_channels=channels, eos_token_id=eos)
out = channel_filter(input_ids, logits).view(bsz, channels, vocab)
for i in range(vocab):
if i > eos:
# special tokens are not to be predicted
self.assertTrue((out[:, :, i] == -float("inf")).all())
elif i == eos:
# Eos forced on channel 0
self.assertTrue(out[0, 0, i] == 1)
# Eos suppressed on everything else (even if max before)
self.assertTrue(out[0, 1, i] == -float("inf"))
self.assertTrue((out[1, :, i] == -float("inf")).all())
else:
# Eos forced on channel 0
self.assertTrue(out[0, 0, i] == -float("inf"))
# previous values
self.assertTrue(out[0, 1, i] == 0)
self.assertTrue((out[1, :, i] == 0).all())
def test_dia_delay_pattern(self):
def check_eos_logits(out, logits, batch, channel, eos):
for i in range(vocab):
if i == eos:
self.assertTrue(out[batch, channel, i] == 0)
else:
self.assertTrue(out[batch, channel, i] == -float("inf"))
for c in range(channel):
if c != channel:
self.assertTrue((out[batch, c] == logits[batch, c]).all())
eos = 2
delay_pattern = [0, 2, 3]
max_generation_len = 10
bsz, channels, vocab = 2, 3, 4
input_ids = torch.LongTensor([[0]])
logits = torch.zeros(size=(bsz, channels, vocab))
# Ensure that argmax can not result in eos
logits[:, :, eos] = -1
delay_pattern_processor = DiaEOSDelayPatternLogitsProcessor(
delay_pattern=delay_pattern, eos_token_id=eos, max_generation_len=max_generation_len
)
out = delay_pattern_processor(input_ids, logits.clone()).view(bsz, channels, vocab)
# Nothing should happen except for init of some attributes
self.assertTrue((out == logits).all())
self.assertTrue((~delay_pattern_processor.active_batches).all())
self.assertTrue(
(delay_pattern_processor.delay_pattern == torch.tensor([delay_pattern for _ in range(bsz)])).all()
)
# Make first batch end
logits[0, 0, eos] = 1
# Go through the complete delay pattern
for i in range(max(delay_pattern) + 1):
out = delay_pattern_processor(input_ids, logits.clone()).view(bsz, channels, vocab)
# at i == 1 no further delay expires (channel 0 was forced at i == 0, channel 1 only fires at i == 2)
if i == 1:
self.assertTrue((out == logits).all())
else:
j = i if i == 0 else i - 1
check_eos_logits(out=out, logits=logits, batch=0, channel=j, eos=eos)
self.assertTrue((out[1] == logits[1]).all())
self.assertTrue(delay_pattern_processor.active_batches[0])
self.assertFalse(delay_pattern_processor.active_batches[1])
self.assertTrue(
(
delay_pattern_processor.delay_pattern[0]
== torch.tensor([delay - (i + 1) for delay in delay_pattern])
).all()
)
self.assertTrue((delay_pattern_processor.delay_pattern[1] == torch.tensor(delay_pattern)).all())
# Make second batch end
logits[1, 0, eos] = 1
# Just to check that other batches can end independently
out = delay_pattern_processor(input_ids, logits.clone()).view(bsz, channels, vocab)
self.assertTrue((out[0] == logits[0]).all())
self.assertTrue(delay_pattern_processor.active_batches.all())
self.assertTrue(
(delay_pattern_processor.delay_pattern[0] == torch.tensor([delay - 5 for delay in delay_pattern])).all()
)
self.assertTrue(
(delay_pattern_processor.delay_pattern[1] == torch.tensor([delay - 1 for delay in delay_pattern])).all()
)
# Last check: max generation length reached (accounting for the delay until the last channel produces eos)
input_ids = torch.LongTensor([[0] * (max_generation_len - max(delay_pattern) - 1)])
delay_pattern_processor = DiaEOSDelayPatternLogitsProcessor(
delay_pattern=delay_pattern, eos_token_id=eos, max_generation_len=max_generation_len
)
out = delay_pattern_processor(input_ids, logits.clone()).view(bsz, channels, vocab)
check_eos_logits(out=out, logits=logits, batch=0, channel=0, eos=eos)
check_eos_logits(out=out, logits=logits, batch=1, channel=0, eos=eos)
self.assertTrue(delay_pattern_processor.active_batches.all())
self.assertTrue((delay_pattern_processor.delay_pattern == torch.tensor(delay_pattern) - 1).all())
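To make the expectations in `test_dia_delay_pattern` easier to follow, here is a rough numeric illustration (not the processor implementation itself): once channel 0 emits EOS, each remaining channel is expected to be forced to EOS after its own entry in the delay pattern, which is what the loop above checks for `[0, 2, 3]`.

```python
# Illustration only: with delay_pattern = [0, 2, 3], if channel 0 hits EOS at step t,
# channel c should be forced to EOS at step t + delay_pattern[c].
delay_pattern = [0, 2, 3]
eos_step_channel_0 = 4  # hypothetical step at which channel 0 first emits EOS

forced_eos_steps = [eos_step_channel_0 + delay for delay in delay_pattern]
print(forced_eos_steps)  # [4, 6, 7]
```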

View File

@ -26,6 +26,7 @@ import transformers
from transformers import (
CONFIG_MAPPING,
FEATURE_EXTRACTOR_MAPPING,
MODEL_FOR_AUDIO_TOKENIZATION_MAPPING,
PROCESSOR_MAPPING,
TOKENIZER_MAPPING,
AutoConfig,
@ -265,6 +266,8 @@ class AutoFeatureExtractorTest(unittest.TestCase):
del TOKENIZER_MAPPING._extra_content[CustomConfig]
if CustomConfig in PROCESSOR_MAPPING._extra_content:
del PROCESSOR_MAPPING._extra_content[CustomConfig]
if CustomConfig in MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content:
del MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content[CustomConfig]
def test_from_pretrained_dynamic_processor_conflict(self):
class NewFeatureExtractor(Wav2Vec2FeatureExtractor):
@ -317,6 +320,8 @@ class AutoFeatureExtractorTest(unittest.TestCase):
del TOKENIZER_MAPPING._extra_content[CustomConfig]
if CustomConfig in PROCESSOR_MAPPING._extra_content:
del PROCESSOR_MAPPING._extra_content[CustomConfig]
if CustomConfig in MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content:
del MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content[CustomConfig]
def test_from_pretrained_dynamic_processor_with_extra_attributes(self):
class NewFeatureExtractor(Wav2Vec2FeatureExtractor):
@ -356,6 +361,8 @@ class AutoFeatureExtractorTest(unittest.TestCase):
del TOKENIZER_MAPPING._extra_content[CustomConfig]
if CustomConfig in PROCESSOR_MAPPING._extra_content:
del PROCESSOR_MAPPING._extra_content[CustomConfig]
if CustomConfig in MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content:
del MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content[CustomConfig]
def test_dynamic_processor_with_specific_dynamic_subcomponents(self):
class NewFeatureExtractor(Wav2Vec2FeatureExtractor):
@ -390,6 +397,8 @@ class AutoFeatureExtractorTest(unittest.TestCase):
del TOKENIZER_MAPPING._extra_content[CustomConfig]
if CustomConfig in PROCESSOR_MAPPING._extra_content:
del PROCESSOR_MAPPING._extra_content[CustomConfig]
if CustomConfig in MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content:
del MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content[CustomConfig]
def test_auto_processor_creates_tokenizer(self):
processor = AutoProcessor.from_pretrained("hf-internal-testing/tiny-random-bert")

View File

View File

@ -0,0 +1,231 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the Dia feature extractor."""
import itertools
import random
import unittest
import numpy as np
from transformers import DiaFeatureExtractor
from transformers.testing_utils import require_torch
from transformers.utils.import_utils import is_torch_available
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
if is_torch_available():
import torch
global_rng = random.Random()
# Copied from tests.models.whisper.test_feature_extraction_whisper.floats_list
def floats_list(shape, scale=1.0, rng=None, name=None):
"""Creates a random float32 tensor"""
if rng is None:
rng = global_rng
values = []
for batch_idx in range(shape[0]):
values.append([])
for _ in range(shape[1]):
values[-1].append(rng.random() * scale)
return values
@require_torch
class DiaFeatureExtractionTester:
# Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTester.__init__
def __init__(
self,
parent,
batch_size=7,
min_seq_length=400,
max_seq_length=2000,
feature_size=1,
padding_value=0.0,
sampling_rate=16000,
hop_length=512,
):
self.parent = parent
self.batch_size = batch_size
self.min_seq_length = min_seq_length
self.max_seq_length = max_seq_length
self.hop_length = hop_length
self.seq_length_diff = (self.max_seq_length - self.min_seq_length) // (self.batch_size - 1)
self.feature_size = feature_size
self.padding_value = padding_value
self.sampling_rate = sampling_rate
# Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTester.prepare_feat_extract_dict
def prepare_feat_extract_dict(self):
return {
"feature_size": self.feature_size,
"padding_value": self.padding_value,
"sampling_rate": self.sampling_rate,
"hop_length": self.hop_length,
}
# Copied from tests.models.encodec.test_feature_extraction_encodec.EnCodecFeatureExtractionTester.prepare_inputs_for_common
def prepare_inputs_for_common(self, equal_length=False, numpify=False):
def _flatten(list_of_lists):
return list(itertools.chain(*list_of_lists))
if equal_length:
audio_inputs = floats_list((self.batch_size, self.max_seq_length))
else:
# make sure that inputs increase in size
audio_inputs = [
_flatten(floats_list((x, self.feature_size)))
for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff)
]
if numpify:
audio_inputs = [np.asarray(x) for x in audio_inputs]
return audio_inputs
@require_torch
class DiaFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
feature_extraction_class = DiaFeatureExtractor
def setUp(self):
self.feat_extract_tester = DiaFeatureExtractionTester(self)
# Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTest.test_call
def test_call(self):
# Tests that all call wrap to encode_plus and batch_encode_plus
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
# create three inputs of length 800, 1000, and 1200
audio_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
np_audio_inputs = [np.asarray(audio_input) for audio_input in audio_inputs]
# Test not batched input
encoded_sequences_1 = feat_extract(audio_inputs[0], return_tensors="np").input_values
encoded_sequences_2 = feat_extract(np_audio_inputs[0], return_tensors="np").input_values
self.assertTrue(np.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3))
# Test batched
encoded_sequences_1 = feat_extract(audio_inputs, padding=True, return_tensors="np").input_values
encoded_sequences_2 = feat_extract(np_audio_inputs, padding=True, return_tensors="np").input_values
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
# Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTest.test_double_precision_pad
def test_double_precision_pad(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
np_audio_inputs = np.random.rand(100).astype(np.float64)
py_audio_inputs = np_audio_inputs.tolist()
for inputs in [py_audio_inputs, np_audio_inputs]:
np_processed = feature_extractor.pad([{"input_values": inputs}], return_tensors="np")
self.assertTrue(np_processed.input_values.dtype == np.float32)
pt_processed = feature_extractor.pad([{"input_values": inputs}], return_tensors="pt")
self.assertTrue(pt_processed.input_values.dtype == torch.float32)
# Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTest._load_datasamples
def _load_datasamples(self, num_samples):
from datasets import load_dataset
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
audio_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
return [x["array"] for x in audio_samples]
# Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTest.test_integration with Dac->Dia
def test_integration(self):
# fmt: off
EXPECTED_INPUT_VALUES = torch.tensor(
[ 2.3803711e-03, 2.0751953e-03, 1.9836426e-03, 2.1057129e-03,
1.6174316e-03, 3.0517578e-04, 9.1552734e-05, 3.3569336e-04,
9.7656250e-04, 1.8310547e-03, 2.0141602e-03, 2.1057129e-03,
1.7395020e-03, 4.5776367e-04, -3.9672852e-04, 4.5776367e-04,
1.0070801e-03, 9.1552734e-05, 4.8828125e-04, 1.1596680e-03,
7.3242188e-04, 9.4604492e-04, 1.8005371e-03, 1.8310547e-03,
8.8500977e-04, 4.2724609e-04, 4.8828125e-04, 7.3242188e-04,
1.0986328e-03, 2.1057129e-03]
)
# fmt: on
input_audio = self._load_datasamples(1)
feature_extractor = DiaFeatureExtractor()
input_values = feature_extractor(input_audio, return_tensors="pt")["input_values"]
self.assertEqual(input_values.shape, (1, 1, 93696))
torch.testing.assert_close(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, rtol=1e-4, atol=1e-4)
audio_input_end = torch.tensor(input_audio[0][-30:], dtype=torch.float32)
torch.testing.assert_close(input_values[0, 0, -46:-16], audio_input_end, rtol=1e-4, atol=1e-4)
def test_integration_stereo(self):
# fmt: off
EXPECTED_INPUT_VALUES = torch.tensor(
[2.3804e-03, 2.0752e-03, 1.9836e-03, 2.1057e-03, 1.6174e-03,
3.0518e-04, 9.1553e-05, 3.3569e-04, 9.7656e-04, 1.8311e-03,
2.0142e-03, 2.1057e-03, 1.7395e-03, 4.5776e-04, -3.9673e-04,
4.5776e-04, 1.0071e-03, 9.1553e-05, 4.8828e-04, 1.1597e-03,
7.3242e-04, 9.4604e-04, 1.8005e-03, 1.8311e-03, 8.8501e-04,
4.2725e-04, 4.8828e-04, 7.3242e-04, 1.0986e-03, 2.1057e-03]
)
# fmt: on
input_audio = self._load_datasamples(1)
input_audio = [np.tile(input_audio[0][None], reps=(2, 1))]
feature_extractor = DiaFeatureExtractor(feature_size=2)
input_values = feature_extractor(input_audio, return_tensors="pt").input_values
self.assertEqual(input_values.shape, (1, 1, 93696))
torch.testing.assert_close(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, rtol=1e-4, atol=1e-4)
# Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTest.test_truncation_and_padding with Dac->Dia
def test_truncation_and_padding(self):
input_audio = self._load_datasamples(2)
# would be easier if the stride was like
feature_extractor = DiaFeatureExtractor()
# pad and trunc raise an error ?
with self.assertRaisesRegex(
ValueError,
"^Both padding and truncation were set. Make sure you only set one.$",
):
truncated_outputs = feature_extractor(
input_audio, padding="max_length", truncation=True, return_tensors="pt"
).input_values
# force truncate to max_length
truncated_outputs = feature_extractor(
input_audio, truncation=True, max_length=48000, return_tensors="pt"
).input_values
self.assertEqual(truncated_outputs.shape, (2, 1, 48128))
# pad:
padded_outputs = feature_extractor(input_audio, padding=True, return_tensors="pt").input_values
self.assertEqual(padded_outputs.shape, (2, 1, 93696))
# force pad to max length
truncated_outputs = feature_extractor(
input_audio, padding="max_length", max_length=100000, return_tensors="pt"
).input_values
self.assertEqual(truncated_outputs.shape, (2, 1, 100352))
# force no pad
with self.assertRaisesRegex(
ValueError,
"^Unable to create tensor, you should probably activate padding with 'padding=True' to have batched tensors with the same length.$",
):
truncated_outputs = feature_extractor(input_audio, padding=False, return_tensors="pt").input_values
truncated_outputs = feature_extractor(input_audio[0], padding=False, return_tensors="pt").input_values
self.assertEqual(truncated_outputs.shape, (1, 1, 93680))
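The shape assertions above are consistent with the extractor padding each output up to the next multiple of the hop length (512, the value also used by the tester above); a quick sanity check of the test figures, assuming that rounding rule:

```python
import math

hop_length = 512  # hop length the expected shapes above imply

def rounded_up(length, hop=hop_length):
    # round a length up to the next multiple of the hop length
    return math.ceil(length / hop) * hop

print(rounded_up(48000))   # 48128  -> matches the truncate-to-max_length case
print(rounded_up(100000))  # 100352 -> matches the pad-to-max_length case
```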

View File

@ -0,0 +1,752 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Dia model."""
import copy
import pathlib
import tempfile
import unittest
import pytest
from transformers.models.dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig
from transformers.testing_utils import (
cleanup,
is_flaky,
require_torch,
require_torch_accelerator,
require_torch_sdpa,
slow,
torch_device,
)
from transformers.utils import is_soundfile_available, is_torch_available, is_torchaudio_available
from transformers.utils.import_utils import is_datasets_available
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_datasets_available():
from datasets import Audio, load_dataset
if is_torch_available():
import torch
from transformers import (
DiaForConditionalGeneration,
DiaModel,
DiaProcessor,
PretrainedConfig,
PreTrainedModel,
)
from transformers.cache_utils import (
Cache,
StaticCache,
)
from transformers.models.dia.modeling_dia import DiaDecoder, DiaEncoder
if is_torchaudio_available():
import torchaudio
if is_soundfile_available():
import soundfile as sf
@require_torch
class DiaModelTester:
def __init__(
self,
parent,
batch_size=3, # need batch_size != num_hidden_layers
seq_length=7,
max_length=50,
is_training=True,
vocab_size=100,
hidden_size=16,
intermediate_size=37,
num_hidden_layers=2,
num_attention_heads=2,
head_dim=8,
decoder_hidden_size=32, # typically larger than encoder
hidden_act="silu",
eos_token_id=97, # special tokens all occur after eos
pad_token_id=98,
bos_token_id=99,
delay_pattern=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.max_length = max_length
self.is_training = is_training
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim
self.decoder_hidden_size = decoder_hidden_size
self.hidden_act = hidden_act
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
# Set default delay pattern if not provided
self.delay_pattern = delay_pattern if delay_pattern is not None else [0, 1, 2]
self.num_channels = len(self.delay_pattern)
def get_config(self):
encoder_config = DiaEncoderConfig(
max_position_embeddings=self.max_length,
num_hidden_layers=self.num_hidden_layers,
hidden_size=self.hidden_size,
num_attention_heads=self.num_attention_heads,
num_key_value_heads=self.num_attention_heads, # same as num_attention_heads for testing
head_dim=self.head_dim,
intermediate_size=self.intermediate_size,
vocab_size=self.vocab_size,
hidden_act=self.hidden_act,
)
decoder_config = DiaDecoderConfig(
max_position_embeddings=self.max_length,
num_hidden_layers=self.num_hidden_layers,
hidden_size=self.decoder_hidden_size,
intermediate_size=self.intermediate_size,
num_attention_heads=self.num_attention_heads,
num_key_value_heads=1, # GQA
head_dim=self.head_dim,
cross_num_attention_heads=self.num_attention_heads,
cross_head_dim=self.head_dim,
cross_num_key_value_heads=1, # GQA
cross_hidden_size=self.hidden_size, # match encoder hidden size
vocab_size=self.vocab_size,
hidden_act=self.hidden_act,
num_channels=self.num_channels,
)
config = DiaConfig(
encoder_config=encoder_config,
decoder_config=decoder_config,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
bos_token_id=self.bos_token_id,
delay_pattern=self.delay_pattern,
)
return config
def prepare_config_and_inputs(self) -> tuple[DiaConfig, dict]:
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
attention_mask = input_ids.ne(self.pad_token_id)
decoder_input_ids = ids_tensor([self.batch_size, self.seq_length, self.num_channels], self.vocab_size)
decoder_attention_mask = decoder_input_ids[..., 0].ne(self.pad_token_id)
config = self.get_config()
inputs_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
}
return config, inputs_dict
def prepare_config_and_inputs_for_common(self) -> tuple[DiaConfig, dict]:
config, inputs_dict = self.prepare_config_and_inputs()
return config, inputs_dict
def create_and_check_model_forward(self, config, inputs_dict):
model = DiaModel(config=config).to(torch_device).eval()
input_ids = inputs_dict["input_ids"]
decoder_input_ids = inputs_dict["decoder_input_ids"]
# first forward pass
last_hidden_state = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state
self.parent.assertEqual(
last_hidden_state.shape, (self.batch_size, self.seq_length, config.decoder_config.hidden_size)
)
def check_encoder_decoder_model_standalone(self, config, inputs_dict):
model = DiaModel(config=config).to(torch_device).eval()
outputs = model(**inputs_dict)
encoder_last_hidden_state = outputs.encoder_last_hidden_state
last_hidden_state = outputs.last_hidden_state
with tempfile.TemporaryDirectory() as tmpdirname:
encoder = model.get_encoder()
encoder.save_pretrained(tmpdirname)
encoder = DiaEncoder.from_pretrained(tmpdirname).to(torch_device)
encoder_last_hidden_state_2 = encoder(
input_ids=inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"]
)[0]
self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 3e-3)
with tempfile.TemporaryDirectory() as tmpdirname:
decoder = model.get_decoder()
decoder.save_pretrained(tmpdirname)
decoder = DiaDecoder.from_pretrained(tmpdirname).to(torch_device)
last_hidden_state_2 = decoder(
input_ids=inputs_dict["decoder_input_ids"],
attention_mask=inputs_dict["decoder_attention_mask"],
encoder_hidden_states=encoder_last_hidden_state,
)[0]
self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 3e-3)
@require_torch
class DiaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (DiaModel, DiaForConditionalGeneration) if is_torch_available() else ()
# We only allow greedy search / sampling with one sequence; see `skip_non_greedy_generate`
all_generative_model_classes = (DiaForConditionalGeneration,)
# TODO: support new pipeline behavior in tests
pipeline_model_mapping = {}
# pipeline_model_mapping = {"text-to-audio": DiaForConditionalGeneration} if is_torch_available() else {}
test_pruning = False
test_head_masking = False
test_resize_embeddings = False
is_encoder_decoder = True
# Usually indicates VLMs, but many audio models are composite as well
_is_composite = True
def setUp(self):
self.model_tester = DiaModelTester(self)
# Skipping `has_text_modality` but manually testing down below
self.config_tester = ConfigTester(self, has_text_modality=False, config_class=DiaConfig)
self.skip_non_greedy_generate()
def skip_non_greedy_generate(self):
skippable_tests = [
"test_sample_generate_dict_output", # return sequences > 1
"test_beam",
"test_group_beam",
"test_constrained_beam",
"test_contrastive",
"test_assisted",
"test_dola",
"test_prompt_lookup",
"test_model_parallel_beam_search",
"test_generate_without_input_ids",
"test_generate_with_head_masking",
]
for test in skippable_tests:
if self._testMethodName.startswith(test):
self.skipTest(reason="Dia only supports greedy search / sampling with one sequence.")
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
"""Overriden to account for the 2D flattened structure"""
inputs_dict = copy.deepcopy(inputs_dict)
if return_labels:
inputs_dict["labels"] = torch.ones(
(
self.model_tester.batch_size * self.model_tester.num_channels,
self.model_tester.seq_length,
),
dtype=torch.long,
device=torch_device,
)
return inputs_dict
def test_config(self):
self.config_tester.run_common_tests()
# Manual testing because of composite configs
config = self.model_tester.prepare_config_and_inputs()[0]
self.assertTrue(hasattr(config.encoder_config, "vocab_size"), msg="Encoder `vocab_size` does not exist")
self.assertTrue(hasattr(config.decoder_config, "vocab_size"), msg="Decoder `vocab_size` does not exist")
def test_model_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_forward(*config_and_inputs)
@is_flaky
def test_encoder_decoder_model_standalone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)
# Overriding shape checks as Dia uses different encoder/decoder shapes via its composite config,
# plus additional special cases where the mixed 3D / 2D layouts change the expected shapes
def _check_logits(self, batch_size, logits, config):
batch_size *= len(config.delay_pattern) # Account for flattening
vocab_size = config.decoder_config.vocab_size
self.assertIsInstance(logits, tuple)
self.assertListEqual([iter_logits.shape[0] for iter_logits in logits], [batch_size] * len(logits))
# vocabulary difference equal to one (imagegptmodel?) or zero (all other models)
vocab_diff = vocab_size - logits[0].shape[-1]
self.assertTrue(vocab_diff in [0, 1])
self.assertListEqual([vocab_size - score.shape[-1] for score in logits], [vocab_diff] * len(logits))
def _check_attentions_for_generate(
self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
):
self.assertIsInstance(attentions, tuple)
self.assertListEqual(
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
)
self.assertEqual(len(attentions), (output_length - prompt_length))
use_cache = decoder_past_key_values is not None
has_static_cache = isinstance(decoder_past_key_values, StaticCache)
# When `output_attentions=True`, each iteration of generate appends the attentions corresponding to the new
# token(s)
for generated_length, iter_attentions in enumerate(attentions):
# regardless of using cache, the first forward pass will have the full prompt as input
if use_cache and generated_length > 0:
model_input_length = 1
else:
model_input_length = prompt_length + generated_length
query_length = (
prompt_length + generated_length
if not has_static_cache
else decoder_past_key_values.get_max_cache_shape()
)
expected_shape = (
batch_size,
config.decoder_config.num_attention_heads, # Decoder config
model_input_length,
query_length,
)
# check attn size
self.assertListEqual(
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
)
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length):
# Encoder config
encoder_expected_shape = (batch_size, config.encoder_config.num_attention_heads, prompt_length, prompt_length)
self.assertIsInstance(attentions, tuple)
self.assertListEqual(
[layer_attentions.shape for layer_attentions in attentions],
[encoder_expected_shape] * len(attentions),
)
def _check_hidden_states_for_generate(
self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False
):
self.assertIsInstance(hidden_states, tuple)
self.assertListEqual(
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
[True] * len(hidden_states),
)
self.assertEqual(len(hidden_states), (output_length - prompt_length))
# When `output_hidden_states=True`, each iteration of generate appends the hidden states corresponding to the
# new token(s)
for generated_length, iter_hidden_states in enumerate(hidden_states):
# regardless of using cache, the first forward pass will have the full prompt as input
if use_cache and generated_length > 0:
model_input_length = 1
else:
model_input_length = prompt_length + generated_length
# check hidden size
# we can have different hidden sizes between encoder and decoder --> check both
expected_shape_encoder = (batch_size, model_input_length, config.encoder_config.hidden_size)
expected_shape_decoder = (batch_size, model_input_length, config.decoder_config.hidden_size)
self.assertTrue(
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states]
== [expected_shape_encoder] * len(iter_hidden_states)
or [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states]
== [expected_shape_decoder] * len(iter_hidden_states)
)
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, prompt_length):
# Encoder config
encoder_expected_shape = (batch_size, prompt_length, config.encoder_config.hidden_size)
self.assertIsInstance(hidden_states, tuple)
self.assertListEqual(
[layer_hidden_states.shape for layer_hidden_states in hidden_states],
[encoder_expected_shape] * len(hidden_states),
)
def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config):
self.assertIsInstance(decoder_past_key_values, (tuple, Cache))
# we need the decoder config here
config = config.decoder_config
# (batch, head, seq_length, head_features)
expected_shape = (
batch_size,
config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads,
cache_length,
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads,
)
if isinstance(decoder_past_key_values, Cache):
self.assertListEqual(
[key_tensor.shape for key_tensor in decoder_past_key_values.key_cache],
[expected_shape] * len(decoder_past_key_values.key_cache),
)
self.assertListEqual(
[value_tensor.shape for value_tensor in decoder_past_key_values.value_cache],
[expected_shape] * len(decoder_past_key_values.value_cache),
)
def _check_scores(self, batch_size, scores, generated_length, config):
# Special case where Dia keeps score in a 2D mesh of (bsz * channels, vocab)
vocab_size = config.decoder_config.vocab_size
expected_shape = (batch_size * len(config.delay_pattern), vocab_size)
self.assertIsInstance(scores, tuple)
self.assertEqual(len(scores), generated_length)
self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))
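As a quick illustration of the bookkeeping these overrides account for: Dia keeps decoder inputs 3D as `(batch, seq, channels)`, while per-channel tensors such as labels and scores are flattened to `(batch * channels, ...)`. A small sketch using the `DiaModelTester` defaults:

```python
import torch

batch, seq, channels, vocab = 3, 7, 3, 100  # mirrors the tester defaults above

decoder_input_ids = torch.zeros(batch, seq, channels, dtype=torch.long)  # 3D decoder inputs
labels = torch.ones(batch * channels, seq, dtype=torch.long)             # flattened labels
flat_scores = torch.randn(batch * channels, vocab)                       # 2D mesh of scores
scores_3d = flat_scores.view(batch, channels, vocab)                     # how a test would un-flatten them
```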
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
"""
Overwritten as the common test relies on hardcoded names at the moment; checking our specific case here
"""
for model_class in self.all_model_classes:
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname)
sub_models_supporting_sdpa = [
(module._supports_sdpa or module._supports_attention_backend)
for name, module in model.named_modules()
if isinstance(module, PreTrainedModel) and name != ""
]
supports_sdpa_all_modules = (
all(sub_models_supporting_sdpa)
if len(sub_models_supporting_sdpa) > 0
else (model._supports_sdpa or model._supports_attention_backend)
)
if not supports_sdpa_all_modules:
with self.assertRaises(ValueError):
model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa")
else:
model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa")
for key in model_sdpa.config:
if isinstance(getattr(model_sdpa.config, key), PretrainedConfig):
sub_config = getattr(model_sdpa.config, key)
self.assertTrue(sub_config._attn_implementation == "sdpa")
@pytest.mark.generate
@unittest.skip(reason="Custom processor `DiaEOSDelayPatternLogitsProcessor` forces eos token.")
def test_generate_continue_from_past_key_values(self):
"""Only a small change due to the expected shapes"""
# Tests that we can continue generating from past key values, returned from a previous `generate` call
for model_class in self.all_generative_model_classes:
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
# Let's make it always:
# 1. use cache (for obvious reasons)
# 2. generate to max length (which can be achieved by setting the eos token to an invalid value), which
# would make the test flaky (e.g. EOS is generated on iteration 1 on both generations, but the
# continuation would force it to generate beyond an EOS token)
# 3. ignore `token_type_ids` for simplicity
# 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is
# active by default on some models
# 5. ignore `encoder_no_repeat_ngram_size`, which is set by default in some encoder-decoder models. When
# we use their decoder as a stand-alone model, `encoder_no_repeat_ngram_size` actually prevents
# repetition exclusively from the prompt. This test relies on comparing one call vs 2 calls
# with cache, what is considered a prompt is different in the two cases.
if "token_type_ids" in inputs:
del inputs["token_type_ids"]
model = model_class(config).to(torch_device)
model.eval()
generate_kwargs = {
"pad_token_id": -1,
"eos_token_id": -1,
"forced_eos_token_id": None,
"encoder_no_repeat_ngram_size": 0,
"use_cache": True,
"do_sample": False,
"return_dict_in_generate": True,
"output_scores": True,
}
# Traditional way of generating text, with `return_dict_in_generate` to return the past key values
outputs = model.generate(**inputs, **generate_kwargs, max_new_tokens=4)
# Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
# inputs may need to be tweaked across `generate` calls (like the attention mask).
outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=3)
# Continue from the tokens generated above, preparing the inputs accordingly
inputs["past_key_values"] = outputs_cached.past_key_values
new_attention_len = outputs_cached.sequences.shape[1] # the only real modification in this test
inputs["decoder_input_ids"] = outputs_cached.sequences
if "decoder_attention_mask" in inputs:
inputs["decoder_attention_mask"] = torch.nn.functional.pad(
inputs["decoder_attention_mask"],
(0, new_attention_len - inputs["decoder_attention_mask"].shape[1]),
mode="constant",
value=1,
)
first_caches_scores = outputs_cached.scores
outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=1)
full_cached_scores = first_caches_scores + outputs_cached.scores
outputs_cached.scores = full_cached_scores
# The two sets of generated text and past kv should be equal to each other
self._check_similar_generate_outputs(outputs, outputs_cached)
for layer_idx in range(len(outputs_cached.past_key_values)):
for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
self.assertTrue(
torch.allclose(
outputs.past_key_values[layer_idx][kv_idx],
outputs_cached.past_key_values[layer_idx][kv_idx],
)
)
@unittest.skip(reason="Indirectly checked in Dia through the generate methods.")
def test_past_key_values_format(self, custom_all_cache_shapes=None):
pass
@unittest.skip(reason="Indirectly checked in Dia through the generate methods.")
def test_hidden_states_output(self):
pass
@unittest.skip(
reason="Dia has too many mixed embedding types which would cause unintentional side effects, e.g. attempts at tying embeddings"
)
def test_model_get_set_embeddings(self):
pass
@unittest.skip(reason="Theoretically works but kernel library causes issues.")
def test_torchscript_output_hidden_state(self):
pass
@unittest.skip(reason="Theoretically works but kernel library causes issues.")
def test_torchscript_simple(self):
pass
@unittest.skip(reason="Encoder-Decoder cache can not be initialized.")
def test_multi_gpu_data_parallel_forward(self):
pass
class DiaForConditionalGenerationIntegrationTest(unittest.TestCase):
"""
See https://gist.github.com/vasqu/0e3b06360373a4e612aa3b9a7c09185e for generating the integration tests
NOTE: We add a single `eos` line for the last channel which is skipped in the original Dia
(It doesn't change the behaviour as we cut by the eos token position)
"""
def setUp(self):
# it's a dummy ckpt but should suffice for testing purposes
self.model_checkpoint = "AntonV/Dia-1.6B"
self.sampling_rate = 44100
# prepare audio
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=self.sampling_rate))
audio_sample_1 = librispeech_dummy[-1]["audio"]["array"]
audio_sample_2 = librispeech_dummy[-2]["audio"]["array"]
# 10 and 5 codebooks as prefix - saved as files as we need wav files for the original Dia
dac_chunk_len = 512
self.audio_prompt_1_path = "/tmp/dia_test_sample_1.mp3"
self.audio_prompt_2_path = "/tmp/dia_test_sample_2.mp3"
sf.write(self.audio_prompt_1_path, audio_sample_1[: (dac_chunk_len * 10)], self.sampling_rate)
sf.write(self.audio_prompt_2_path, audio_sample_2[: (dac_chunk_len * 5)], self.sampling_rate)
def tearDown(self):
pathlib.Path(self.audio_prompt_1_path).unlink()
pathlib.Path(self.audio_prompt_2_path).unlink()
cleanup(torch_device, gc_collect=True)
@slow
@require_torch_accelerator
def test_dia_model_integration_generate_tts(self):
text = ["[S1] Dia is an open weights text to dialogue model.", "This is a test"]
processor = DiaProcessor.from_pretrained(self.model_checkpoint)
inputs = processor(text=text, padding=True, return_tensors="pt").to(torch_device)
model = DiaForConditionalGeneration.from_pretrained(self.model_checkpoint).to(torch_device)
outputs = model.generate(**inputs, max_new_tokens=32, do_sample=False)
# fmt: off
EXPECTED_OUTPUT_TOKENS = torch.tensor([[[1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 778, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 778, 338, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 804, 10, 524, 1026, 1026, 1026, 1026, 1026],
[ 568, 804, 10, 674, 967, 1026, 1026, 1026, 1026],
[ 568, 804, 10, 674, 364, 360, 1026, 1026, 1026],
[ 568, 804, 10, 674, 364, 981, 728, 1026, 1026],
[ 568, 804, 10, 674, 364, 981, 741, 550, 1026],
[ 568, 804, 10, 674, 364, 981, 568, 378, 90],
[1024, 804, 10, 674, 364, 981, 568, 378, 731],
[1025, 804, 10, 674, 364, 981, 568, 378, 731],
[1025, 804, 10, 674, 364, 981, 568, 378, 731],
[1025, 804, 10, 674, 364, 981, 568, 378, 731],
[1025, 804, 10, 674, 364, 981, 568, 378, 731],
[1025, 804, 10, 674, 364, 981, 568, 378, 731],
[1025, 804, 10, 674, 364, 981, 568, 378, 731],
[1025, 804, 10, 674, 364, 981, 568, 378, 731],
[1025, 1024, 10, 674, 364, 981, 568, 378, 731],
[1025, 1025, 1024, 674, 364, 981, 568, 378, 731],
[1025, 1025, 1025, 1024, 364, 981, 568, 378, 731],
[1025, 1025, 1025, 1025, 1024, 981, 568, 378, 731],
[1025, 1025, 1025, 1025, 1025, 1024, 568, 378, 731],
[1025, 1025, 1025, 1025, 1025, 1025, 1024, 378, 731],
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024, 731],
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024]],
[[1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 698, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 592, 778, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 592, 778, 338, 1026, 1026, 1026, 1026, 1026, 1026],
[ 592, 697, 10, 524, 1026, 1026, 1026, 1026, 1026],
[ 592, 288, 476, 649, 967, 1026, 1026, 1026, 1026],
[ 592, 740, 386, 674, 364, 360, 1026, 1026, 1026],
[ 592, 402, 386, 347, 362, 981, 728, 1026, 1026],
[ 592, 402, 721, 728, 327, 981, 741, 550, 1026],
[ 592, 402, 721, 728, 460, 62, 676, 378, 90],
[1024, 402, 721, 728, 837, 595, 195, 982, 784],
[1025, 402, 721, 677, 497, 102, 692, 24, 330],
[1025, 402, 721, 677, 511, 102, 503, 871, 609],
[1025, 402, 721, 677, 511, 96, 801, 871, 894],
[1025, 402, 721, 677, 511, 745, 314, 498, 775],
[1025, 402, 721, 677, 511, 745, 314, 498, 105],
[1025, 402, 721, 677, 511, 745, 314, 861, 889],
[1025, 893, 721, 677, 511, 744, 314, 871, 353],
[1025, 1024, 888, 677, 511, 744, 314, 871, 332],
[1025, 1025, 1024, 518, 511, 744, 314, 871, 366],
[1025, 1025, 1025, 1024, 611, 744, 314, 871, 366],
[1025, 1025, 1025, 1025, 1024, 980, 314, 871, 366],
[1025, 1025, 1025, 1025, 1025, 1024, 45, 124, 366],
[1025, 1025, 1025, 1025, 1025, 1025, 1024, 871, 366],
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024, 719],
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024]]])
# fmt: on
torch.testing.assert_close(outputs.cpu(), EXPECTED_OUTPUT_TOKENS)
@slow
@require_torch_accelerator
def test_dia_model_integration_generate_audio_context(self):
text = ["[S1] Dia is an open weights text to dialogue model.", "This is a test"]
audio_sample_1 = torchaudio.load(self.audio_prompt_1_path, channels_first=True)[0].squeeze().numpy()
audio_sample_2 = torchaudio.load(self.audio_prompt_2_path, channels_first=True)[0].squeeze().numpy()
audio = [audio_sample_1, audio_sample_2]
processor = DiaProcessor.from_pretrained(self.model_checkpoint)
inputs = processor(text=text, audio=audio, padding=True, return_tensors="pt").to(torch_device)
model = DiaForConditionalGeneration.from_pretrained(self.model_checkpoint).to(torch_device)
# Dia uses right padding while we use left padding (for faster prefill)
# additionally we count new tokens vs Dia's max tokens (hence we compare each output in its respective setting)
outputs_1 = model.generate(**inputs, max_new_tokens=22, do_sample=False)
outputs_2 = model.generate(**inputs, max_new_tokens=27, do_sample=False)
# fmt: off
EXPECTED_OUTPUT_TOKENS_1 = torch.tensor([[1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 578, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 494, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 330, 501, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 330, 204, 34, 1026, 1026, 1026, 1026, 1026, 1026],
[ 330, 254, 915, 863, 1026, 1026, 1026, 1026, 1026],
[ 330, 215, 458, 313, 50, 1026, 1026, 1026, 1026],
[ 330, 615, 529, 216, 801, 237, 1026, 1026, 1026],
[ 330, 580, 563, 233, 337, 37, 1018, 1026, 1026],
[ 330, 567, 530, 753, 607, 179, 954, 242, 1026],
[ 330, 627, 6, 1010, 500, 189, 598, 858, 247],
[1024, 432, 480, 530, 122, 3, 788, 149, 814],
[1025, 875, 826, 458, 98, 540, 181, 122, 608],
[1025, 495, 840, 413, 337, 784, 591, 150, 1017],
[1025, 808, 189, 137, 445, 0, 227, 658, 345],
[1025, 397, 89, 753, 1016, 173, 984, 0, 910],
[1025, 875, 460, 934, 50, 335, 670, 818, 722],
[1025, 875, 460, 762, 119, 372, 503, 858, 584],
[1025, 348, 555, 475, 469, 458, 963, 41, 664],
[1025, 1024, 852, 683, 761, 193, 595, 895, 885],
[1025, 1025, 1024, 135, 761, 902, 163, 623, 385],
[1025, 1025, 1025, 1024, 852, 282, 581, 623, 70],
[1025, 1025, 1025, 1025, 1024, 41, 661, 790, 977],
[1025, 1025, 1025, 1025, 1025, 1024, 580, 401, 464],
[1025, 1025, 1025, 1025, 1025, 1025, 1024, 756, 61],
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024, 752],
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024]])
EXPECTED_OUTPUT_TOKENS_2 = torch.tensor([[1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 619, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 315, 968, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 315, 1007, 458, 1026, 1026, 1026, 1026, 1026, 1026],
[ 315, 35, 266, 68, 1026, 1026, 1026, 1026, 1026],
[ 315, 359, 285, 811, 154, 1026, 1026, 1026, 1026],
[ 315, 906, 407, 297, 785, 649, 1026, 1026, 1026],
[ 315, 249, 678, 868, 899, 257, 950, 1026, 1026],
[ 315, 249, 217, 471, 292, 908, 196, 469, 1026],
[ 315, 249, 825, 771, 839, 802, 633, 590, 531],
[1024, 249, 150, 53, 126, 76, 794, 626, 442],
[1025, 249, 825, 218, 359, 864, 526, 626, 770],
[1025, 249, 150, 137, 530, 845, 877, 600, 111],
[1025, 249, 150, 287, 730, 991, 135, 259, 39],
[1025, 249, 825, 104, 198, 1020, 719, 625, 208],
[1025, 249, 825, 997, 602, 256, 859, 322, 518],
[1025, 668, 825, 979, 584, 256, 98, 665, 589],
[1025, 954, 458, 54, 206, 52, 244, 822, 599],
[1025, 1024, 104, 914, 435, 579, 860, 92, 661],
[1025, 1025, 1024, 848, 126, 74, 304, 92, 753],
[1025, 1025, 1025, 1024, 362, 376, 304, 586, 753],
[1025, 1025, 1025, 1025, 1024, 633, 996, 586, 83],
[1025, 1025, 1025, 1025, 1025, 1024, 179, 898, 928],
[1025, 1025, 1025, 1025, 1025, 1025, 1024, 506, 102],
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024, 79],
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024]])
# fmt: on
torch.testing.assert_close(outputs_1[0].cpu(), EXPECTED_OUTPUT_TOKENS_1)
torch.testing.assert_close(outputs_2[1, 5:].cpu(), EXPECTED_OUTPUT_TOKENS_2) # left padding
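For context, a minimal usage sketch in the spirit of the integration tests above (checkpoint and greedy settings taken from the tests; feeding the raw `generate` output straight into `batch_decode` is an assumption based on the processor tests below):

```python
from transformers import DiaForConditionalGeneration, DiaProcessor

checkpoint = "AntonV/Dia-1.6B"  # checkpoint used by the integration tests above
processor = DiaProcessor.from_pretrained(checkpoint)
model = DiaForConditionalGeneration.from_pretrained(checkpoint)

text = ["[S1] Dia is an open weights text to dialogue model."]
inputs = processor(text=text, padding=True, return_tensors="pt")

# greedy generation as in the integration tests; tokens come back per channel
tokens = model.generate(**inputs, max_new_tokens=32, do_sample=False)
# decoding the audio codes back to waveforms goes through the processor / audio tokenizer
audio = processor.batch_decode(tokens)
```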

View File

@ -0,0 +1,269 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import tempfile
import unittest
import numpy as np
from parameterized import parameterized
from transformers import DacModel, DiaFeatureExtractor, DiaProcessor, DiaTokenizer
from transformers.testing_utils import require_torch
from transformers.utils import is_torch_available
if is_torch_available():
import torch
# Copied from tests.utils.test_modeling_utils.check_models_equal
def check_models_equal(model1, model2):
models_are_equal = True
for model1_p, model2_p in zip(model1.parameters(), model2.parameters()):
if model1_p.data.ne(model2_p.data).sum() > 0:
models_are_equal = False
return models_are_equal
@require_torch
class DiaProcessorTest(unittest.TestCase):
def setUp(self):
self.checkpoint = "AntonV/Dia-1.6B"
self.audio_tokenizer_checkpoint = "descript/dac_44khz"
self.tmpdirname = tempfile.mkdtemp()
# Audio tokenizer is a bigger model so we will reuse this if possible
self.processor = DiaProcessor(
tokenizer=self.get_tokenizer(),
feature_extractor=self.get_feature_extractor(),
audio_tokenizer=self.get_audio_tokenizer(),
)
# Default audio values based on Dia and Dac
self.pad_id = 1025
self.bos_id = 1026
self.dac_chunk_len = 512
self.delay_pattern = [0, 8, 9, 10, 11, 12, 13, 14, 15]
def get_tokenizer(self, **kwargs):
return DiaTokenizer.from_pretrained(self.checkpoint, **kwargs)
def get_feature_extractor(self, **kwargs):
return DiaFeatureExtractor.from_pretrained(self.checkpoint, **kwargs)
def get_audio_tokenizer(self, **kwargs):
return DacModel.from_pretrained(self.audio_tokenizer_checkpoint, **kwargs)
def tearDown(self):
shutil.rmtree(self.tmpdirname)
del self.processor
def test_save_load_pretrained_default(self):
tokenizer = self.get_tokenizer()
feature_extractor = self.get_feature_extractor()
audio_tokenizer = self.get_audio_tokenizer()
processor = DiaProcessor(
tokenizer=tokenizer, feature_extractor=feature_extractor, audio_tokenizer=audio_tokenizer
)
processor.save_pretrained(self.tmpdirname)
processor = DiaProcessor.from_pretrained(self.tmpdirname)
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
self.assertIsInstance(processor.tokenizer, DiaTokenizer)
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())
self.assertIsInstance(processor.feature_extractor, DiaFeatureExtractor)
self.assertEqual(processor.audio_tokenizer.__class__.__name__, audio_tokenizer.__class__.__name__)
self.assertEqual(processor.audio_tokenizer.name_or_path, audio_tokenizer.name_or_path)
self.assertTrue(check_models_equal(processor.audio_tokenizer, audio_tokenizer))
self.assertIsInstance(processor.audio_tokenizer, DacModel)
def test_save_load_pretrained_additional_features(self):
processor = DiaProcessor(
tokenizer=self.get_tokenizer(),
feature_extractor=self.get_feature_extractor(),
audio_tokenizer=self.get_audio_tokenizer(),
)
processor.save_pretrained(self.tmpdirname)
tokenizer_add_kwargs = self.get_tokenizer()
feature_extractor_add_kwargs = self.get_feature_extractor()
audio_tokenizer_add_kwargs = self.get_audio_tokenizer()
processor = DiaProcessor.from_pretrained(self.tmpdirname)
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
self.assertIsInstance(processor.tokenizer, DiaTokenizer)
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
self.assertIsInstance(processor.feature_extractor, DiaFeatureExtractor)
self.assertEqual(processor.audio_tokenizer.__class__.__name__, audio_tokenizer_add_kwargs.__class__.__name__)
self.assertEqual(processor.audio_tokenizer.name_or_path, audio_tokenizer_add_kwargs.name_or_path)
self.assertTrue(check_models_equal(processor.audio_tokenizer, audio_tokenizer_add_kwargs))
self.assertIsInstance(processor.audio_tokenizer, DacModel)
def test_model_input_names(self):
tokenizer = self.get_tokenizer()
self.assertListEqual(
self.processor.model_input_names,
list(dict.fromkeys(tokenizer.model_input_names + ["decoder_input_ids", "decoder_attention_mask"])),
msg="`processor` model input names do not match the expected names.",
)
def test_tokenize(self):
tokenizer = self.get_tokenizer()
random_text = ["This is a processing test for tokenization", "[S1] Dia template style [S2] Nice"]
input_tokenizer = tokenizer(random_text, padding=True, return_tensors="pt")
input_processor = self.processor(random_text)
for key in input_tokenizer.keys():
self.assertTrue((input_tokenizer[key] == input_processor[key]).all())
def test_no_audio(self):
random_text = ["Dummy Input"] * 2
input_processor = self.processor(random_text)
audio_tokens, audio_mask = input_processor["decoder_input_ids"], input_processor["decoder_attention_mask"]
# full mask with +1 for bos
self.assertTrue(audio_mask.sum() == (max(self.delay_pattern) + 1) * len(random_text))
self.assertTrue(
audio_tokens.shape
== (
len(random_text),
max(self.delay_pattern) + 1,
len(self.delay_pattern),
)
)
for channel_idx, delay in enumerate(self.delay_pattern):
expected_sequence = torch.ones(size=(audio_tokens.shape[:-1])) * self.pad_id
expected_sequence[:, : delay + 1] = self.bos_id
self.assertTrue((audio_tokens[..., channel_idx] == expected_sequence).all())
def test_audio(self):
audio_tokenizer = self.get_audio_tokenizer()
feature_extractor = self.get_feature_extractor()
random_text = ["Dummy Input"] * 2
# Dac only starts accepting audio from a certain length (ensured via >=1024)
raw_speeches = [np.random.rand(2048).astype(np.float32), np.random.rand(1024).astype(np.float32)]
input_processor = self.processor(random_text, raw_speeches)
audio_tokens, audio_mask = input_processor["decoder_input_ids"], input_processor["decoder_attention_mask"]
sequence_len = audio_mask.shape[1]
for batch_idx, speech in enumerate(raw_speeches):
raw_audio = feature_extractor(speech, return_tensors="pt")["input_values"]
codebooks = audio_tokenizer(raw_audio).audio_codes.transpose(1, 2)
pad_len = sequence_len - audio_mask.sum(dim=-1)[batch_idx]
for channel_idx, delay in enumerate(self.delay_pattern):
# Left padding is filled with bos; the right padding (from the delay) is filled with pad
start_idx = pad_len + delay + 1
end_idx = start_idx + codebooks.shape[1]
encoded_sequence = audio_tokens[batch_idx, :, channel_idx]
expected_sequence = torch.ones(size=(sequence_len,)) * self.pad_id
expected_sequence[:start_idx] = self.bos_id
expected_sequence[start_idx:end_idx] = codebooks[0, :, channel_idx]
self.assertTrue((encoded_sequence == expected_sequence).all())
# Just to make sure the masking correctly only ignores bos tokens
self.assertTrue((audio_tokens[~audio_mask.bool()] == self.bos_id).all())
@parameterized.expand([([1, 1],), ([1, 5],), ([2, 4, 6],)])
def test_decode_audio(self, audio_lens):
feature_extractor = self.get_feature_extractor()
audio_tokenizer = self.get_audio_tokenizer()
random_text = ["Dummy Input"] * len(audio_lens)
raw_speeches = [np.random.rand(self.dac_chunk_len * l).astype(np.float32) for l in audio_lens]
# we need eos (provided when not generating, i.e. training) to decode properly; during generation it is enforced via a custom logits processor
input_processor = self.processor(random_text, raw_speeches, generation=False)
audio_tokens = input_processor["decoder_input_ids"]
decoded_speeches = self.processor.batch_decode(audio_tokens)
for batch_idx, speech in enumerate(raw_speeches):
raw_audio = feature_extractor(speech, return_tensors="pt")["input_values"]
codebooks = audio_tokenizer(raw_audio).audio_codes
decoded_audio = decoded_speeches[batch_idx]
expected_audio = audio_tokenizer.decode(audio_codes=codebooks).audio_values
self.assertTrue((expected_audio == decoded_audio).all())
self.assertTrue(decoded_speeches[batch_idx].shape[-1] == audio_lens[batch_idx] * self.dac_chunk_len)
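# End-to-end sketch of the encode/decode round trip verified above (added for context; the
# checkpoint name is illustrative). generation=False keeps eos, as required for lossless decoding.
import numpy as np
from transformers.models.dia import DiaProcessor

processor = DiaProcessor.from_pretrained("AntonV/Dia-1.6B")
text = ["[S1] Dia template style [S2] Nice"]
speech = [np.random.rand(2048).astype(np.float32)]  # raw waveform at the feature extractor's sampling rate

inputs = processor(text, speech, generation=False)               # text ids + delayed audio codebooks
waveforms = processor.batch_decode(inputs["decoder_input_ids"])  # codebooks decoded back to audio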
@parameterized.expand([(1, 2, [0, 1, 4]), (2, 4, [1, 3, 2]), (4, 8, [0, 5, 7])])
def test_delay_in_audio(self, bsz, seq_len, delay_pattern):
# these static functions are crucial, hence we also test them here
build_indices_fn = DiaProcessor.build_indices
delay_fn = DiaProcessor.apply_audio_delay
bos, pad = -2, -1
num_channels = len(delay_pattern)
audio_input = torch.arange(bsz * seq_len * num_channels).view(bsz, seq_len, num_channels)
# imitate a delay mask with zeroes
audio_input = torch.cat([audio_input, torch.zeros(size=(bsz, max(delay_pattern), num_channels))], dim=1)
precomputed_idx = build_indices_fn(
bsz=bsz,
seq_len=seq_len + max(delay_pattern),
num_channels=num_channels,
delay_pattern=delay_pattern,
revert=False,
)
delayed_audio_out = delay_fn(
audio=audio_input,
pad_token_id=pad,
bos_token_id=bos,
precomputed_idx=precomputed_idx,
)
# every channel idx is shifted by delay_pattern[idx]
delayed_audio_res = audio_input.clone()
for idx, delay in enumerate(delay_pattern):
delayed_audio_res[:, :delay, idx] = bos
remaining_input = seq_len + max(delay_pattern) - delay
delayed_audio_res[:, delay:, idx] = audio_input[:, :remaining_input, idx]
self.assertTrue((delayed_audio_out == delayed_audio_res).all())
# we should get back to the original audio we had (when removing the delay pad)
bsz, new_seq_len, num_channels = delayed_audio_out.shape
precomputed_idx = build_indices_fn(
bsz=bsz,
seq_len=new_seq_len,
num_channels=num_channels,
delay_pattern=delay_pattern,
revert=True,
)
reverted_audio_out = delay_fn(
audio=delayed_audio_out,
pad_token_id=pad,
bos_token_id=bos,
precomputed_idx=precomputed_idx,
)
reverted_audio_res = audio_input.clone()[:, :seq_len]
self.assertTrue((reverted_audio_out[:, :seq_len] == reverted_audio_res).all())
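# Standalone reference sketch of the delay shift that the test reconstructs by hand above:
# channel c is shifted right by delay_pattern[c] and the freed positions are filled with bos.
import torch

def naive_delay(audio, delay_pattern, bos):
    delayed = audio.clone()
    for channel, delay in enumerate(delay_pattern):
        delayed[:, :delay, channel] = bos
        delayed[:, delay:, channel] = audio[:, : audio.shape[1] - delay, channel]
    return delayed

x = torch.arange(2 * 4 * 3).view(2, 4, 3)
# with delays [0, 1, 2], channels 1 and 2 start with one and two bos entries respectively
assert (naive_delay(x, [0, 1, 2], bos=-2)[:, 0, 1:] == -2).all()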


@@ -0,0 +1,123 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers.models.dia import DiaTokenizer
from transformers.testing_utils import slow
from ...test_tokenization_common import TokenizerTesterMixin
# Special tokens
PAD = 0
S1 = 1
S2 = 2
class DiaTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = DiaTokenizer
test_rust_tokenizer = False
@classmethod
def setUpClass(cls):
super().setUpClass()
tokenizer = DiaTokenizer()
tokenizer.save_pretrained(cls.tmpdirname)
def test_convert_token_and_id(self):
"""Test ``_convert_token_to_id`` and ``_convert_id_to_token``."""
token = "i"
token_id = 105
self.assertEqual(self.get_tokenizer()._convert_token_to_id(token), token_id)
self.assertEqual(self.get_tokenizer()._convert_id_to_token(token_id), token)
def test_get_vocab(self):
vocab_keys = list(self.get_tokenizer().get_vocab().keys())
self.assertEqual(vocab_keys[PAD], "<pad>")
self.assertEqual(vocab_keys[S1], "[S1]")
self.assertEqual(vocab_keys[S2], "[S2]")
self.assertEqual(len(vocab_keys), 256)
def test_vocab_size(self):
# byte-level vocab: 2**8 == 256
self.assertEqual(self.get_tokenizer().vocab_size, 256)
def test_full_tokenizer(self):
tokenizer = DiaTokenizer.from_pretrained(self.tmpdirname)
tokens = tokenizer.tokenize("Hello, world!")
self.assertListEqual(tokens, ["H", "e", "l", "l", "o", ",", " ", "w", "o", "r", "l", "d", "!"])
ids = tokenizer.convert_tokens_to_ids(tokens)
self.assertListEqual(ids, [72, 101, 108, 108, 111, 44, 32, 119, 111, 114, 108, 100, 33])
back_tokens = tokenizer.convert_ids_to_tokens(ids)
self.assertListEqual(back_tokens, ["H", "e", "l", "l", "o", ",", " ", "w", "o", "r", "l", "d", "!"])
tokens = tokenizer.tokenize("[S1] Hello [S2] Hello<pad>")
self.assertListEqual(
tokens,
["[S1]", " ", "H", "e", "l", "l", "o", " ", "[S2]", " ", "H", "e", "l", "l", "o", "<pad>"],
)
ids = tokenizer.convert_tokens_to_ids(tokens)
self.assertListEqual(ids, [S1, 32, 72, 101, 108, 108, 111, 32, S2, 32, 72, 101, 108, 108, 111, PAD])
back_tokens = tokenizer.convert_ids_to_tokens(ids)
self.assertListEqual(
back_tokens, ["[S1]", " ", "H", "e", "l", "l", "o", " ", "[S2]", " ", "H", "e", "l", "l", "o", "<pad>"]
)
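# Byte-level sketch: for regular text the ids are plain UTF-8 byte values (the special tokens occupy
# the low, unused byte slots), which is why "H" maps to 72 and "i" to 105 in the assertions above.
byte_ids = list("Hello, world!".encode("utf-8"))
assert byte_ids == [72, 101, 108, 108, 111, 44, 32, 119, 111, 114, 108, 100, 33]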
@slow
def test_tokenizer_integration(self):
# Overwritten as decoding returns individual bytes (i.e. characters), whereas the common test expects the joined string format
expected_encoding = {'input_ids': [[84, 114, 97, 110, 115, 102, 111, 114, 109, 101, 114, 115, 32, 40, 102, 111, 114, 109, 101, 114, 108, 121, 32, 107, 110, 111, 119, 110, 32, 97, 115, 32, 112, 121, 116, 111, 114, 99, 104, 45, 116, 114, 97, 110, 115, 102, 111, 114, 109, 101, 114, 115, 32, 97, 110, 100, 32, 112, 121, 116, 111, 114, 99, 104, 45, 112, 114, 101, 116, 114, 97, 105, 110, 101, 100, 45, 98, 101, 114, 116, 41, 32, 112, 114, 111, 118, 105, 100, 101, 115, 32, 103, 101, 110, 101, 114, 97, 108, 45, 112, 117, 114, 112, 111, 115, 101, 32, 97, 114, 99, 104, 105, 116, 101, 99, 116, 117, 114, 101, 115, 32, 40, 66, 69, 82, 84, 44, 32, 71, 80, 84, 45, 50, 44, 32, 82, 111, 66, 69, 82, 84, 97, 44, 32, 88, 76, 77, 44, 32, 68, 105, 115, 116, 105, 108, 66, 101, 114, 116, 44, 32, 88, 76, 78, 101, 116, 46, 46, 46, 41, 32, 102, 111, 114, 32, 78, 97, 116, 117, 114, 97, 108, 32, 76, 97, 110, 103, 117, 97, 103, 101, 32, 85, 110, 100, 101, 114, 115, 116, 97, 110, 100, 105, 110, 103, 32, 40, 78, 76, 85, 41, 32, 97, 110, 100, 32, 78, 97, 116, 117, 114, 97, 108, 32, 76, 97, 110, 103, 117, 97, 103, 101, 32, 71, 101, 110, 101, 114, 97, 116, 105, 111, 110, 32, 40, 78, 76, 71, 41, 32, 119, 105, 116, 104, 32, 111, 118, 101, 114, 32, 51, 50, 43, 32, 112, 114, 101, 116, 114, 97, 105, 110, 101, 100, 32, 109, 111, 100, 101, 108, 115, 32, 105, 110, 32, 49, 48, 48, 43, 32, 108, 97, 110, 103, 117, 97, 103, 101, 115, 32, 97, 110, 100, 32, 100, 101, 101, 112, 32, 105, 110, 116, 101, 114, 111, 112, 101, 114, 97, 98, 105, 108, 105, 116, 121, 32, 98, 101, 116, 119, 101, 101, 110, 32, 74, 97, 120, 44, 32, 80, 121, 84, 111, 114, 99, 104, 32, 97, 110, 100, 32, 84, 101, 110, 115, 111, 114, 70, 108, 111, 119, 46], [66, 69, 82, 84, 32, 105, 115, 32, 100, 101, 115, 105, 103, 110, 101, 100, 32, 116, 111, 32, 112, 114, 101, 45, 116, 114, 97, 105, 110, 32, 100, 101, 101, 112, 32, 98, 105, 100, 105, 114, 101, 99, 116, 105, 111, 110, 97, 108, 32, 114, 101, 112, 114, 101, 115, 101, 110, 116, 97, 116, 105, 111, 110, 115, 32, 102, 114, 111, 109, 32, 117, 110, 108, 97, 98, 101, 108, 101, 100, 32, 116, 101, 120, 116, 32, 98, 121, 32, 106, 111, 105, 110, 116, 108, 121, 32, 99, 111, 110, 100, 105, 116, 105, 111, 110, 105, 110, 103, 32, 111, 110, 32, 98, 111, 116, 104, 32, 108, 101, 102, 116, 32, 97, 110, 100, 32, 114, 105, 103, 104, 116, 32, 99, 111, 110, 116, 101, 120, 116, 32, 105, 110, 32, 97, 108, 108, 32, 108, 97, 121, 101, 114, 115, 46], [84, 104, 101, 32, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 32, 102, 111, 120, 32, 106, 117, 109, 112, 115, 32, 111, 118, 101, 114, 32, 116, 104, 101, 32, 108, 97, 122, 121, 32, 100, 111, 103, 46]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]} # fmt: skip
sequences = [
"Transformers (formerly known as pytorch-transformers and pytorch-pretrained-bert) provides "
"general-purpose architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet...) for Natural "
"Language Understanding (NLU) and Natural Language Generation (NLG) with over 32+ pretrained "
"models in 100+ languages and deep interoperability between Jax, PyTorch and TensorFlow.",
"BERT is designed to pre-train deep bidirectional representations from unlabeled text by jointly "
"conditioning on both left and right context in all layers.",
"The quick brown fox jumps over the lazy dog.",
]
tokenizer_classes = [self.tokenizer_class]
if self.test_rust_tokenizer:
tokenizer_classes.append(self.rust_tokenizer_class)
for tokenizer_class in tokenizer_classes:
tokenizer = tokenizer_class.from_pretrained("AntonV/Dia-1.6B")
encoding = tokenizer(sequences)
encoding_data = encoding.data
self.assertDictEqual(encoding_data, expected_encoding)
# Byte decoding yields individual characters, so we join them back into a string
decoded_sequences = [
"".join(tokenizer.decode(seq, skip_special_tokens=True)) for seq in encoding["input_ids"]
]
for expected, decoded in zip(sequences, decoded_sequences):
if self.test_sentencepiece_ignore_case:
expected = expected.lower()
self.assertEqual(expected, decoded)
@unittest.skip(reason="Dia relies on whole input string due to the byte-level nature.")
def test_pretokenized_inputs(self):
pass
@unittest.skip
def test_tokenizer_slow_store_full_signature(self):
pass


@@ -4574,6 +4574,11 @@ class ModelTesterMixin:
head_dim = config.head_dim
config.head_dim = max(16, config.head_dim)
cross_head_dim = None
if hasattr(config, "cross_head_dim") and config.cross_head_dim is not None:
cross_head_dim = config.cross_head_dim
config.cross_head_dim = max(16, config.cross_head_dim)
if (
getattr(config, "hidden_size", None) is not None
and getattr(config, "num_attention_heads", None) is not None
@@ -4588,6 +4593,17 @@ class ModelTesterMixin:
decoder_head_dim = config.decoder_hidden_size // config.decoder_num_attention_heads
config.decoder_hidden_size *= max(16 // decoder_head_dim, 1)
if (
getattr(config, "cross_hidden_size", None) is not None
and getattr(config, "cross_num_attention_heads", None) is not None
):
cross_head_dim = (
cross_head_dim
if cross_head_dim is not None
else config.cross_hidden_size // config.cross_num_attention_heads
)
config.cross_hidden_size *= max(16 // cross_head_dim, 1)
# Set default attention to flex and update config values
update_config_for_flex(config)
for key in config.sub_configs:
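
# Worked example of the head-dim bump above (illustrative numbers): the helper enforces a head dim
# of at least 16 before switching to flex attention, so cross_hidden_size=32 with
# cross_num_attention_heads=4 (head dim 8) is scaled by max(16 // 8, 1) == 2 up to 64.
cross_hidden_size, cross_num_attention_heads = 32, 4
cross_head_dim = cross_hidden_size // cross_num_attention_heads   # 8
cross_hidden_size *= max(16 // cross_head_dim, 1)                 # 32 * 2 == 64
assert cross_hidden_size // cross_num_attention_heads == 16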


@@ -32,6 +32,10 @@ transformers = direct_transformers_import(PATH_TO_TRANSFORMERS)
CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING
SPECIAL_CASES_TO_ALLOW = {
# used internally during generation to provide the custom logit processors with their necessary information
"DiaConfig": [
"delay_pattern",
],
# 'max_position_embeddings' is not used in modeling file, but needed for eval frameworks like Huggingface's lighteval (https://github.com/huggingface/lighteval/blob/af24080ea4f16eaf1683e353042a2dfc9099f038/src/lighteval/models/base_model.py#L264).
# periods and offsets are not used in modeling file, but used in the configuration file to define `layers_block_type` and `layers_num_experts`.
"BambaConfig": [