mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Add Dia model (#38405)
* add dia model * add tokenizer files * cleanup some stuff * brut copy paste code * rough cleanup of the modeling code * nuke some stuff * more nuking * more cleanups * updates * add mulitLayerEmbedding vectorization * nits * more modeling simplifications * updates * update rope * update rope * just fixup * update configuration files * more cleanup! * default config values * update * forgotten comma * another comma! * update, more cleanups * just more nits * more config cleanups * time for the encoder * fix * sa=mall nit * nits * n * refacto a bit * cleanup * update cv scipt * fix last issues * fix last nits * styling * small fixes * just run 1 generation * fixes * nits * fix conversion * fix * more fixes * full generate * ouf! * fixes! * updates * fix * fix cvrt * fixup * nits * delete wrong test * update * update * test tokenization * let's start changing things bit by bit - fix encoder step * removing custom generation, moving to GenerationMixin * add encoder decoder attention masks for generation * mask changes, correctness checked against ad29837 in dia repo * refactor a bit already --> next cache * too important not to push :) * minimal cleanup + more todos * make main overwrite modeling utils * add cfg filter & eos filter * add eos countdown & delay pattern * update eos countdown * add max step eos countdown * fix tests * fix some things * fix generation with testing * move cfg & eos stuff to logits processor * make RepetitionPenaltyLogitsProcessor flexible - can accept 3D scores like (batch_size, channel, vocab) * fix input_ids concatenation dimension in GenerationMixin for flexibility * Add DiaHangoverLogitsProcessor and DiaExponentialDecayLengthPenalty classes; refactor logits processing in DiaForConditionalGeneration to utilize new configurations and improve flexibility. * Add stopping criteria * refactor * move delay pattern from processor to modeling like musicgen. - add docs - change eos countdown to eos delay pattern * fix processor & fix tests * refactor types * refactor imports * format code * fix docstring to pass ci * add docstring to DiaConfig & add DiaModel to test * fix docstring * add docstring * fix some bugs * check * porting / merging results from other branch - IMPORTANT: it very likely breaks generation, the goal is to have a proper forward path first * experimental testing of left padding for first channel * whoops * Fix merge to make generation work * fix cfg filter * add position ids * add todos, break things * revert changes to generation --> we will force 2d but go 3d on custom stuff * refactor a lot, change prepare decoder ids to work with left padding (needs testing), add todos * some first fixes to get to 10. 
in generation * some more generation fixes / adjustment * style + rope fixes * move cfg out, simplify a few things, more todos * nit * start working on custom logit processors * nit * quick fixes * cfg top k * more refactor of logits processing, needs a decision if gen config gets the new attributes or if we move it to config or similar * lets keep changes to core code minimal, only eos scaling is questionable atm * simpler eos delay logits processor * that was for debugging :D * proof of concept rope * small fix on device mismatch * cfg fixes + delay logits max len * transformers rope * modular dia * more cleanup * keep modeling consistently 3D, generate handles 2D internally * decoder starts with bos if nothing * post processing prototype * style * lol * force sample / greedy + fixes on padding * style * fixup tokenization * nits * revert * start working on dia tests * fix a lot of tests * more test fixes * nit * more test fixes + some features to simplify code more * more cleanup * forgot that one * autodocs * small consistency fixes * fix regression * small fixes * dia feature extraction * docs * wip processor * fix processor order * processing goes brrr * transpose before * small fix * fix major bug but needs now a closer look into the custom processors esp cfg * small thing on logits * nits * simplify indices and shifts * add simpler version of padding tests back (temporarily) * add logit processor tests * starting tests on processor * fix mask application during generation * some fixes on the weights conversion * style + fixup logits order * simplify conversion * nit * remove padding tests * nits on modeling * hmm * fix tests * trigger * probably gonna be reverted, just a quick design around audio tokenizer * fixup typing * post merge + more typing * initial design for audio tokenizer * more design changes * nit * more processor tests and style related things * add to init * protect import * not sure why tbh * add another protect * more fixes * wow * it aint stopping :D * another missed type issue * ... * change design around audio tokenizer to prioritize init and go for auto - in regards to the review * change to new causal mask function + docstrings * change ternary * docs * remove todo, i dont think its essential tbh * remove pipeline as current pipelines do not fit in the current scheme, same as csm * closer to wrapping up the processor * text to audio, just for demo purposes (will likely be reverted) * check if it's this * save audio function * ensure no grad * fixes on prefixed audio, hop length is used via preprocess dac, device fixes * integration tests (tested locally on a100) + some processor utils / fixes * style * nits * another round of smaller things * docs + some fixes (generate one might be big) * msytery solved * small fix on conversion * add abstract audio tokenizer, change init check to abstract class * nits * update docs + fix some processing :D * change inheritance scheme for audio tokenizer * delete dead / unnecessary code in copied generate loop * last nits on new pipeline behavior (+ todo on tests) + style * trigger --------- Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: Vasqu <antonprogamer@gmail.com>
This commit is contained in:
parent
5995cfa0a0
commit
583db52bc6
@ -839,6 +839,8 @@
  title: CSM
- local: model_doc/dac
  title: dac
- local: model_doc/dia
  title: Dia
- local: model_doc/encodec
  title: EnCodec
- local: model_doc/fastspeech2_conformer
@ -350,6 +350,10 @@ The following auto classes are available for the following audio tasks.
[[autodoc]] AutoModelForTextToWaveform

### AutoModelForAudioTokenization

[[autodoc]] AutoModelForAudioTokenization

## Multimodal

The following auto classes are available for the following multimodal tasks.
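For reference, a minimal sketch of how the new `AutoModelForAudioTokenization` class could be used with DAC, the only model registered for it in this commit. The checkpoint name and the output attribute names are assumptions, not part of this diff:

```python
import torch
from transformers import AutoModelForAudioTokenization

# Sketch only: "descript/dac_44khz" and the output attributes are assumed, not taken from this commit.
model = AutoModelForAudioTokenization.from_pretrained("descript/dac_44khz")

waveform = torch.randn(1, 1, 44100)  # one second of dummy mono audio at 44.1 kHz
with torch.no_grad():
    encoded = model.encode(waveform)                          # discrete codebook tokens
    decoded = model.decode(audio_codes=encoded.audio_codes)   # back to raw audio

print(encoded.audio_codes.shape, decoded.audio_values.shape)
```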
162 docs/source/en/model_doc/dia.md Normal file
@ -0,0 +1,162 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Dia

<div style="float: right;">
    <div class="flex flex-wrap space-x-1">
        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
        <img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
        <img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
    </div>
</div>

## Overview

Dia is an open-source text-to-speech (TTS) model (1.6B parameters) developed by [Nari Labs](https://huggingface.co/nari-labs).
It can generate highly realistic dialogue from a transcript, including nonverbal communication such as laughter and coughing.
Emotion and tone can also be controlled via audio conditioning (voice cloning).

**Model Architecture:**
Dia is an encoder-decoder transformer based on the original transformer architecture, with some more modern features such as
rotary positional embeddings (RoPE) added. For the text portion (encoder), a byte tokenizer is used, while
for the audio portion (decoder), a pretrained codec model, [DAC](./dac.md), is used: DAC encodes speech into discrete codebook
tokens and decodes them back into audio.
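For a concrete view of these pieces, the default configuration added in this commit can be inspected once this version of `transformers` is installed; a short sketch:

```python
from transformers import DiaConfig

# Default config values as added in this commit (see configuration_dia.py further down in this diff).
config = DiaConfig()
print(config.encoder_config.vocab_size)    # 256 -> byte-level text vocabulary
print(config.decoder_config.num_channels)  # 9 -> DAC codebook channels predicted per step
print(config.delay_pattern)                # [0, 8, 9, 10, 11, 12, 13, 14, 15]
```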
## Usage Tips

### Generation with Text

```python
from transformers import AutoProcessor, DiaForConditionalGeneration

torch_device = "cuda"
model_checkpoint = "buttercrab/dia-v1-1.6b"

text = ["[S1] Dia is an open weights text to dialogue model."]
processor = AutoProcessor.from_pretrained(model_checkpoint)
inputs = processor(text=text, padding=True, return_tensors="pt").to(torch_device)

model = DiaForConditionalGeneration.from_pretrained(model_checkpoint).to(torch_device)
outputs = model.generate(**inputs, max_new_tokens=256)  # corresponds to roughly 2s of audio

# save audio to a file
outputs = processor.batch_decode(outputs)
processor.save_audio(outputs, "example.wav")
```

### Generation with Text and Audio (Voice Cloning)

```python
from datasets import load_dataset, Audio
from transformers import AutoProcessor, DiaForConditionalGeneration

torch_device = "cuda"
model_checkpoint = "buttercrab/dia-v1-1.6b"

ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
ds = ds.cast_column("audio", Audio(sampling_rate=44100))
audio = ds[-1]["audio"]["array"]
# text is a transcript of the audio + additional text you want as new audio
text = ["[S1] I know. It's going to save me a lot of money, I hope. [S2] I sure hope so for you."]

processor = AutoProcessor.from_pretrained(model_checkpoint)
inputs = processor(text=text, audio=audio, padding=True, return_tensors="pt").to(torch_device)
prompt_len = processor.get_audio_prompt_len(inputs["decoder_attention_mask"])

model = DiaForConditionalGeneration.from_pretrained(model_checkpoint).to(torch_device)
outputs = model.generate(**inputs, max_new_tokens=256)  # corresponds to roughly 2s of audio

# retrieve only the newly generated audio (excluding the prompt) and save it to a file
outputs = processor.batch_decode(outputs, audio_prompt_len=prompt_len)
processor.save_audio(outputs, "example_with_audio.wav")
```

### Training

```python
from datasets import load_dataset, Audio
from transformers import AutoProcessor, DiaForConditionalGeneration

torch_device = "cuda"
model_checkpoint = "buttercrab/dia-v1-1.6b"

ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
ds = ds.cast_column("audio", Audio(sampling_rate=44100))
audio = ds[-1]["audio"]["array"]
# text is a transcript of the audio
text = ["[S1] I know. It's going to save me a lot of money, I hope."]

processor = AutoProcessor.from_pretrained(model_checkpoint)
inputs = processor(
    text=text,
    audio=audio,
    generation=False,
    output_labels=True,
    padding=True,
    return_tensors="pt"
).to(torch_device)

model = DiaForConditionalGeneration.from_pretrained(model_checkpoint).to(torch_device)
out = model(**inputs)
out.loss.backward()
```
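The training snippet only runs a forward and backward pass. To turn it into an actual fine-tuning step, an optimizer update can be added; a minimal sketch, reusing `model` and `inputs` from the example above:

```python
import torch

# Minimal optimization-step sketch (assumes `model` and `inputs` from the training example above).
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

model.train()
optimizer.zero_grad()
out = model(**inputs)  # returns a loss because `output_labels=True` was passed to the processor
out.loss.backward()
optimizer.step()
```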
This model was contributed by [Jaeyong Sung](https://huggingface.co/buttercrab), [Arthur Zucker](https://huggingface.co/ArthurZ),
and [Anton Vlasjuk](https://huggingface.co/AntonV). The original code can be found [here](https://github.com/nari-labs/dia/).


## DiaConfig

[[autodoc]] DiaConfig

## DiaDecoderConfig

[[autodoc]] DiaDecoderConfig

## DiaEncoderConfig

[[autodoc]] DiaEncoderConfig

## DiaTokenizer

[[autodoc]] DiaTokenizer
    - __call__

## DiaFeatureExtractor

[[autodoc]] DiaFeatureExtractor
    - __call__

## DiaProcessor

[[autodoc]] DiaProcessor
    - __call__
    - batch_decode
    - decode

## DiaModel

[[autodoc]] DiaModel
    - forward

## DiaForConditionalGeneration

[[autodoc]] DiaForConditionalGeneration
    - forward
    - generate
@ -271,7 +271,6 @@ class PretrainedConfig(PushToHubMixin):
        self.pad_token_id = kwargs.pop("pad_token_id", None)
        self.eos_token_id = kwargs.pop("eos_token_id", None)
        self.sep_token_id = kwargs.pop("sep_token_id", None)

        self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)

        # task specific arguments
@ -2975,3 +2975,224 @@ class SynthIDTextWatermarkLogitsProcessor(LogitsProcessor):
|
||||
The expected mean g-value for watermarked text.
|
||||
"""
|
||||
return coinflip_prob + coinflip_prob * (1 - coinflip_prob) * (1 - (1 / vocab_size))
|
||||
|
||||
|
||||
class DiaClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
|
||||
r"""
|
||||
[`LogitsProcessor`] for classifier free guidance (CFG). Similar to the original
|
||||
`ClassifierFreeGuidanceLogitsProcessor` with some modifications on the overall
|
||||
calculation, e.g. conditioned logits centered, and an additional top k selection
|
||||
option.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This logits processor is exclusively compatible with
|
||||
[Dia](https://huggingface.co/docs/transformers/main/en/model_doc/dia)
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
guidance_scale (float):
|
||||
The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
|
||||
Higher guidance scale encourages the model to generate samples that are more closely linked to the input
|
||||
prompt, usually at the expense of poorer quality.
|
||||
guidance_top_k (int, *optional*):
|
||||
The number of highest probability vocabulary tokens to keep for top-k-filtering. However, we do not keep
|
||||
the logits of the combined CFG output, but the conditioned output only.
|
||||
"""
|
||||
|
||||
def __init__(self, guidance_scale: float, guidance_top_k: Optional[int] = None):
|
||||
if guidance_scale > 1:
|
||||
self.guidance_scale = guidance_scale
|
||||
else:
|
||||
raise ValueError(
|
||||
"Require guidance scale >1 to use the classifier free guidance processor, got guidance scale "
|
||||
f"{guidance_scale}."
|
||||
)
|
||||
|
||||
self.guidance_top_k = guidance_top_k
|
||||
if self.guidance_top_k is not None and self.guidance_top_k < 1:
|
||||
raise ValueError(
|
||||
f"`guidance_top_k` has to be a strictly positive integer if given, but is {self.guidance_top_k}"
|
||||
)
|
||||
|
||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# simple check to make sure we have compatible batch sizes between our
|
||||
# logits scores (cond + uncond) and input ids (cond only)
|
||||
if scores.shape[0] != 2 * input_ids.shape[0]:
|
||||
raise ValueError(
|
||||
f"Logits should have twice the batch size of the input ids, the first half of batches corresponding to "
|
||||
f"the conditional inputs, and the second half of batches corresponding to the unconditional inputs. Got "
|
||||
f"batch size {scores.shape[0]} for the logits and {input_ids.shape[0]} for the input ids."
|
||||
)
|
||||
# Base CFG with center on cond_logits
|
||||
unguided_bsz = scores.shape[0] // 2
|
||||
cond_logits, uncond_logits = scores.split(unguided_bsz, dim=0)
|
||||
scores_processed = cond_logits + (cond_logits - uncond_logits) * self.guidance_scale
|
||||
|
||||
# Optional CFG top k filtering
|
||||
if self.guidance_top_k is not None:
|
||||
# Create top k based on the combined CFG output
|
||||
_, top_k_indices = torch.topk(scores_processed, k=self.guidance_top_k, dim=-1)
|
||||
top_k_mask = torch.ones_like(scores_processed, dtype=torch.bool)
|
||||
top_k_mask = top_k_mask.scatter(dim=-1, index=top_k_indices, value=False)
|
||||
# Only return conditioned logits with top k
|
||||
scores_processed = cond_logits.masked_fill(top_k_mask, -float("inf"))
|
||||
|
||||
return scores_processed
|
||||
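The CFG calculation above can be exercised in isolation on dummy tensors. A minimal sketch, assuming the class is importable from `transformers.generation.logits_process` once this commit is installed:

```python
import torch
from transformers.generation.logits_process import DiaClassifierFreeGuidanceLogitsProcessor

# Dummy example: 2 conditional sequences, vocab of 8; scores stack [conditional; unconditional] rows.
proc = DiaClassifierFreeGuidanceLogitsProcessor(guidance_scale=3.0, guidance_top_k=4)

input_ids = torch.zeros((2, 5), dtype=torch.long)  # conditional half only
scores = torch.randn(4, 8)                         # 2 * batch_size rows

out = proc(input_ids, scores)
print(out.shape)                 # torch.Size([2, 8]) -> conditioned logits only
print(torch.isinf(out).sum(-1))  # 4 entries per row masked to -inf by guidance_top_k=4
```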
|
||||
|
||||
class DiaEOSChannelFilterLogitsProcessor(LogitsProcessor):
|
||||
r"""Specialized processor that ensures certain properties around EOS sampling:
|
||||
1. Only channel 0 can generate an EOS token.
2. If EOS has the highest logit in channel 0, it becomes the only candidate token.
3. If EOS does not have the highest logit in channel 0, it is suppressed.

Conditions 2 and 3 are especially important when sampling is enabled, as they guarantee that the
respective tokens are (or are not) sampled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This logits processor is exclusively compatible with
|
||||
[Dia](https://huggingface.co/docs/transformers/en/model_doc/dia).
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
num_channels (`int`):
|
||||
Number of audio codebooks. Simplifies access to the first channel on the logits.
|
||||
eos_token_id (`int`):
|
||||
The id of *end-of-sequence* token.
|
||||
"""
|
||||
|
||||
def __init__(self, num_channels: int, eos_token_id: int):
|
||||
if num_channels < 1:
|
||||
raise ValueError(f"Audio codebooks need at least one channel, but found {num_channels} channels.")
|
||||
if eos_token_id < 1:
|
||||
raise ValueError(f"Expected `eos_token_id` to be a positive integer, found {eos_token_id} instead.")
|
||||
|
||||
self.num_channels = num_channels
|
||||
self.eos_id = eos_token_id
|
||||
|
||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# Reshape for easier channel indexing [B, C, V]
|
||||
scores = scores.reshape(-1, self.num_channels, scores.shape[-1])
|
||||
|
||||
# EOS filter
|
||||
# 1. Condition: Only the first channel can generate the EOS token
|
||||
# Side condition of disabling generation of special tokens (e.g. audio pad, bos, ...)
|
||||
# (Assumes them to be greater than audio eos token position)
|
||||
scores[:, 1:, self.eos_id :] = torch.full_like(
|
||||
scores[:, 1:, self.eos_id :],
|
||||
fill_value=-float("inf"),
|
||||
)
|
||||
scores[:, 0, self.eos_id + 1 :] = torch.full_like(
|
||||
scores[:, 0, self.eos_id + 1 :],
|
||||
fill_value=-float("inf"),
|
||||
)
|
||||
|
||||
# 2+3 Conditions: Force/Suppress EOS if (not) highest logit
|
||||
# Reshape back to original shape
|
||||
scores = scores.view(-1, scores.shape[-1])
|
||||
|
||||
# Sample highest tokens
|
||||
top_logit_indices = torch.argmax(scores, dim=-1)
|
||||
|
||||
# 2. Force EOS
|
||||
eos_highest_mask = top_logit_indices == self.eos_id
|
||||
mask_eos_highest = torch.zeros_like(scores, dtype=torch.bool)
|
||||
mask_eos_highest[eos_highest_mask, : self.eos_id] = True
|
||||
scores = scores.masked_fill(mask_eos_highest, -float("inf"))
|
||||
|
||||
# 3. Suppress EOS
|
||||
eos_not_highest_mask = top_logit_indices != self.eos_id
|
||||
mask_eos_unless_highest = torch.zeros_like(scores, dtype=torch.bool)
|
||||
mask_eos_unless_highest[eos_not_highest_mask, self.eos_id] = True
|
||||
scores = scores.masked_fill(mask_eos_unless_highest, -float("inf"))
|
||||
|
||||
return scores
|
||||
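A similar standalone sketch for the channel filter, again on dummy tensors with the same assumed import path:

```python
import torch
from transformers.generation.logits_process import DiaEOSChannelFilterLogitsProcessor

# Dummy example: 1 sequence, 3 codebook channels, vocab of 6, EOS id 4.
proc = DiaEOSChannelFilterLogitsProcessor(num_channels=3, eos_token_id=4)

scores = torch.randn(3, 6)                        # flattened [batch * channels, vocab]
input_ids = torch.zeros((3, 2), dtype=torch.long)

out = proc(input_ids, scores)
print(out[1:, 4:])  # EOS and special tokens on channels 1 and 2 are disabled (-inf)
```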
|
||||
|
||||
class DiaEOSDelayPatternLogitsProcessor(LogitsProcessor):
|
||||
r"""Special logits processor to handle the generation of the EOS token in Dia.
|
||||
This is needed because Dia only allows EOS to be generated in the first
channel (C0); the remaining channels follow with their respective delays.
|
||||
|
||||
Hence, based on the delay pattern, an EOS is forced after the respective delays
|
||||
in the channels. For example, if the delay pattern is [0, 2, 3, 4]:
|
||||
|
||||
s s+1 s+2 s+3 s+4 s+5 ...
|
||||
| | | | | |
|
||||
C0: EOS PAD PAD PAD PAD PAD ...
|
||||
C1: x x EOS PAD PAD PAD ...
|
||||
C2: x x x EOS PAD PAD ...
|
||||
C3: x x x x EOS PAD ...
|
||||
|
||||
If the first channel generated EOS at step s, channels Cx are forced to generate
|
||||
theirs at the respective delays (s+2, s+3, s+4). Subsequent padding tokens are
|
||||
handled by the `EosTokenCriteria` when an EOS has been detected.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This logits processor is exclusively compatible with
|
||||
[Dia](https://huggingface.co/docs/transformers/en/model_doc/dia).
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
delay_pattern (`List[int]`):
|
||||
The delays per channel in the audio codebooks.
|
||||
eos_token_id (`int`):
|
||||
The id of *end-of-sequence* token.
|
||||
max_generation_len (`int`):
|
||||
The max sequence length that can be generated.
|
||||
device (`str`, *optional*, defaults to `"cpu"`):
|
||||
The device to allocate the tensors on.
|
||||
"""
|
||||
|
||||
def __init__(self, delay_pattern: list[int], eos_token_id: int, max_generation_len: int, device: str = "cpu"):
|
||||
self.num_channels = len(delay_pattern)
|
||||
# Update during first iteration
|
||||
self.active_batches = None
|
||||
self.delay_pattern = torch.tensor(delay_pattern, device=device, dtype=torch.int)[None, :]
|
||||
self.eos_token_id = eos_token_id
|
||||
self.max_generation_len = max_generation_len - max(delay_pattern) - 1
|
||||
self.device = device
|
||||
|
||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# Reshape for easier channel indexing [B, C, V]
|
||||
scores = scores.reshape(-1, self.num_channels, scores.shape[-1])
|
||||
|
||||
# Initialize / expand values on first iteration
|
||||
if self.active_batches is None:
|
||||
self.delay_pattern = self.delay_pattern.repeat(scores.shape[0], 1)
|
||||
self.active_batches = torch.zeros(size=(scores.shape[0],), device=self.device, dtype=torch.bool)
|
||||
|
||||
# Check if eos has been generated in any batch
|
||||
channel_generated_eos = torch.argmax(scores, dim=-1)[:, 0] == self.eos_token_id
|
||||
# Check if max len has been reached
|
||||
reached_max_len = input_ids.shape[1] == self.max_generation_len
|
||||
|
||||
# Update active batches
|
||||
self.active_batches |= channel_generated_eos
|
||||
self.active_batches |= reached_max_len
|
||||
|
||||
# Find channels that need to force eos
|
||||
forced_eos_channels = self.active_batches[:, None] & (self.delay_pattern == 0)
|
||||
# Use indexing to avoid issues on all `False` by having empty tensors in that case
|
||||
idx_bsz, idx_channel = forced_eos_channels.nonzero(as_tuple=True)
|
||||
|
||||
# Force eos if delay is kicking in
|
||||
scores[idx_bsz, idx_channel, :] = -float("inf")
|
||||
scores[idx_bsz, idx_channel, self.eos_token_id] = 0.0
|
||||
|
||||
# Reshape back to [B * C, V]
|
||||
scores = scores.reshape(-1, scores.shape[-1])
|
||||
|
||||
# Update amount of delay left for each channel
|
||||
self.delay_pattern -= self.active_batches[:, None].int()
|
||||
|
||||
return scores
|
||||
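Finally, a small sketch that walks through the forced-EOS schedule of the delay pattern processor for an illustrative pattern of `[0, 2]`; note that the processor keeps state across generation steps (same assumed import path as above):

```python
import torch
from transformers.generation.logits_process import DiaEOSDelayPatternLogitsProcessor

# Dummy example: 1 sequence, 2 channels with delays [0, 2], EOS id 4, vocab of 6.
proc = DiaEOSDelayPatternLogitsProcessor(delay_pattern=[0, 2], eos_token_id=4, max_generation_len=100)

input_ids = torch.zeros((1, 3), dtype=torch.long)

# Step s: channel 0 proposes EOS (highest logit) -> the batch becomes "active"
scores = torch.full((2, 6), -1.0)
scores[0, 4] = 5.0
scores = proc(input_ids, scores)

# Steps s+1 and s+2: after its delay of 2, channel 1 is forced onto EOS as well
scores = proc(input_ids, torch.full((2, 6), -1.0))
scores = proc(input_ids, torch.full((2, 6), -1.0))
print(scores.reshape(1, 2, 6)[0, 1].argmax())  # tensor(4) -> forced EOS on channel 1
```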
|
@ -26,6 +26,7 @@ import re
import shutil
import tempfile
import warnings
from abc import abstractmethod
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import contextmanager
@ -5884,3 +5885,26 @@ class AttentionInterface(GeneralInterface):
# Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones
ALL_ATTENTION_FUNCTIONS: AttentionInterface = AttentionInterface()


class PreTrainedAudioTokenizerBase(PreTrainedModel):
    """
    Base class that defines the behavior expected from any `audio_tokenizer`. All of them share two
    characteristics:
        1. Encode raw audio into discrete audio codebooks (with x channels)
        2. Decode from discrete audio codebooks back to raw audio
    They may also be able to decode from other intermediate representations, but they must support
    decoding from audio codebooks (2.) nonetheless, e.g. see `DAC`.
    """

    @abstractmethod
    def encode(self, input_values: torch.Tensor, *args, **kwargs):
        """
        Encode raw audio retrieved from a respective `FeatureExtractor` into discrete audio codebooks (with x channels)
        """
        pass

    @abstractmethod
    def decode(self, audio_codes: torch.Tensor, *args, **kwargs):
        """Decode from discrete audio codebooks back to raw audio"""
        pass
@ -88,6 +88,7 @@ if TYPE_CHECKING:
|
||||
from .depth_anything import *
|
||||
from .depth_pro import *
|
||||
from .detr import *
|
||||
from .dia import *
|
||||
from .dialogpt import *
|
||||
from .diffllama import *
|
||||
from .dinat import *
|
||||
|
@ -106,6 +106,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
|
||||
("depth_pro", "DepthProConfig"),
|
||||
("deta", "DetaConfig"),
|
||||
("detr", "DetrConfig"),
|
||||
("dia", "DiaConfig"),
|
||||
("diffllama", "DiffLlamaConfig"),
|
||||
("dinat", "DinatConfig"),
|
||||
("dinov2", "Dinov2Config"),
|
||||
@ -478,6 +479,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
|
||||
("depth_pro", "DepthPro"),
|
||||
("deta", "DETA"),
|
||||
("detr", "DETR"),
|
||||
("dia", "Dia"),
|
||||
("dialogpt", "DialoGPT"),
|
||||
("diffllama", "DiffLlama"),
|
||||
("dinat", "DiNAT"),
|
||||
|
@ -55,6 +55,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
|
||||
("deformable_detr", "DeformableDetrFeatureExtractor"),
|
||||
("deit", "DeiTFeatureExtractor"),
|
||||
("detr", "DetrFeatureExtractor"),
|
||||
("dia", "DiaFeatureExtractor"),
|
||||
("dinat", "ViTFeatureExtractor"),
|
||||
("donut-swin", "DonutFeatureExtractor"),
|
||||
("dpt", "DPTFeatureExtractor"),
|
||||
|
@ -99,6 +99,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("depth_pro", "DepthProModel"),
|
||||
("deta", "DetaModel"),
|
||||
("detr", "DetrModel"),
|
||||
("dia", "DiaModel"),
|
||||
("diffllama", "DiffLlamaModel"),
|
||||
("dinat", "DinatModel"),
|
||||
("dinov2", "Dinov2Model"),
|
||||
@ -472,6 +473,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
|
||||
("data2vec-text", "Data2VecTextForMaskedLM"),
|
||||
("deberta", "DebertaForMaskedLM"),
|
||||
("deberta-v2", "DebertaV2ForMaskedLM"),
|
||||
("dia", "DiaForConditionalGeneration"),
|
||||
("distilbert", "DistilBertForMaskedLM"),
|
||||
("electra", "ElectraForMaskedLM"),
|
||||
("encoder-decoder", "EncoderDecoderModel"),
|
||||
@ -1059,6 +1061,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
|
||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("dia", "DiaForConditionalGeneration"),
|
||||
("granite_speech", "GraniteSpeechForConditionalGeneration"),
|
||||
("kyutai_speech_to_text", "KyutaiSpeechToTextForConditionalGeneration"),
|
||||
("moonshine", "MoonshineForConditionalGeneration"),
|
||||
@ -1629,6 +1632,12 @@ MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = OrderedDict(
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_FOR_AUDIO_TOKENIZATION_NAMES = OrderedDict(
|
||||
[
|
||||
("dac", "DacModel"),
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
|
||||
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
|
||||
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
|
||||
@ -1737,6 +1746,8 @@ MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING = _LazyAutoMapping(
|
||||
|
||||
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES)
|
||||
|
||||
MODEL_FOR_AUDIO_TOKENIZATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_TOKENIZATION_NAMES)
|
||||
|
||||
|
||||
class AutoModelForMaskGeneration(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING
|
||||
@ -2034,6 +2045,15 @@ class AutoModelForMaskedImageModeling(_BaseAutoModelClass):
|
||||
AutoModelForMaskedImageModeling = auto_class_update(AutoModelForMaskedImageModeling, head_doc="masked image modeling")
|
||||
|
||||
|
||||
class AutoModelForAudioTokenization(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_AUDIO_TOKENIZATION_MAPPING
|
||||
|
||||
|
||||
AutoModelForAudioTokenization = auto_class_update(
|
||||
AutoModelForAudioTokenization, head_doc="audio tokenization through codebooks"
|
||||
)
|
||||
|
||||
|
||||
class AutoModelWithLMHead(_AutoModelWithLMHead):
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
@ -2059,6 +2079,7 @@ class AutoModelWithLMHead(_AutoModelWithLMHead):
|
||||
__all__ = [
|
||||
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
|
||||
"MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING",
|
||||
"MODEL_FOR_AUDIO_TOKENIZATION_MAPPING",
|
||||
"MODEL_FOR_AUDIO_XVECTOR_MAPPING",
|
||||
"MODEL_FOR_BACKBONE_MAPPING",
|
||||
"MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING",
|
||||
@ -2106,6 +2127,7 @@ __all__ = [
|
||||
"AutoBackbone",
|
||||
"AutoModelForAudioClassification",
|
||||
"AutoModelForAudioFrameClassification",
|
||||
"AutoModelForAudioTokenization",
|
||||
"AutoModelForAudioXVector",
|
||||
"AutoModelForCausalLM",
|
||||
"AutoModelForCTC",
|
||||
|
@ -61,6 +61,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||
("clvp", "ClvpProcessor"),
|
||||
("colpali", "ColPaliProcessor"),
|
||||
("colqwen2", "ColQwen2Processor"),
|
||||
("dia", "DiaProcessor"),
|
||||
("emu3", "Emu3Processor"),
|
||||
("flava", "FlavaProcessor"),
|
||||
("fuyu", "FuyuProcessor"),
|
||||
|
@ -177,6 +177,7 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
|
||||
"LlamaTokenizerFast" if is_tokenizers_available() else None,
|
||||
),
|
||||
),
|
||||
("dia", ("DiaTokenizer", None)),
|
||||
(
|
||||
"diffllama",
|
||||
(
|
||||
|
@ -23,7 +23,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...modeling_utils import PreTrainedAudioTokenizerBase
|
||||
from ...utils import ModelOutput, auto_docstring
|
||||
from .configuration_dac import DacConfig
|
||||
|
||||
@ -471,7 +471,7 @@ class DacEncoder(nn.Module):
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class DacPreTrainedModel(PreTrainedModel):
|
||||
class DacPreTrainedModel(PreTrainedAudioTokenizerBase):
|
||||
config_class = DacConfig
|
||||
base_model_prefix = "dac"
|
||||
main_input_name = "input_values"
|
||||
|
31 src/transformers/models/dia/__init__.py Normal file
@ -0,0 +1,31 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import _LazyModule
|
||||
from ...utils.import_utils import define_import_structure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_dia import *
|
||||
from .feature_extraction_dia import *
|
||||
from .generation_dia import *
|
||||
from .modeling_dia import *
|
||||
from .processing_dia import *
|
||||
from .tokenization_dia import *
|
||||
else:
|
||||
import sys
|
||||
|
||||
_file = globals()["__file__"]
|
||||
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
376 src/transformers/models/dia/configuration_dia.py Normal file
@ -0,0 +1,376 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Dia model configuration"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class DiaEncoderConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`DiaEncoder`]. It is used to instantiate a Dia
|
||||
encoder according to the specified arguments, defining the encoder architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
max_position_embeddings (`int`, *optional*, defaults to 1024):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 12):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
hidden_size (`int`, *optional*, defaults to 1024):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
num_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 16):
|
||||
Number of key and value heads for each attention layer in the Transformer encoder.
|
||||
head_dim (`int`, *optional*, defaults to 128):
|
||||
Dimensionality of the attention head.
|
||||
intermediate_size (`int`, *optional*, defaults to 4096):
|
||||
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
|
||||
norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the normalization layers.
|
||||
vocab_size (`int`, *optional*, defaults to 256):
|
||||
Vocabulary size of the Dia model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`DiaModel`].
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"swish"` and `"gelu_new"` are supported.
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
"""
|
||||
|
||||
model_type = "dia_encoder"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_position_embeddings: int = 1024,
|
||||
num_hidden_layers: int = 12,
|
||||
hidden_size: int = 1024,
|
||||
num_attention_heads: int = 16,
|
||||
num_key_value_heads: int = 16,
|
||||
head_dim: int = 128,
|
||||
intermediate_size: int = 4096,
|
||||
norm_eps: float = 1e-5,
|
||||
vocab_size: int = 256,
|
||||
hidden_act: str = "silu",
|
||||
rope_theta: float = 10000.0,
|
||||
rope_scaling: Optional[dict] = None,
|
||||
initializer_range: float = 0.02,
|
||||
**kwargs,
|
||||
):
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.head_dim = head_dim
|
||||
self.norm_eps = norm_eps
|
||||
self.vocab_size = vocab_size
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, copy it to 'rope_type'.
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
self.initializer_range = initializer_range
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class DiaDecoderConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`DiaDecoder`]. It is used to instantiate a Dia
|
||||
decoder according to the specified arguments, defining the decoder architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
max_position_embeddings (`int`, *optional*, defaults to 3072):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 18):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
hidden_size (`int`, *optional*, defaults to 2048):
|
||||
Dimensionality of the decoder layers and the pooler layer.
|
||||
intermediate_size (`int`, *optional*, defaults to 8192):
|
||||
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer decoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 4):
|
||||
Number of key and value heads for each attention layer in the Transformer decoder.
|
||||
head_dim (`int`, *optional*, defaults to 128):
|
||||
Dimensionality of the attention head.
|
||||
cross_num_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads for each cross-attention layer in the Transformer decoder.
|
||||
cross_head_dim (`int`, *optional*, defaults to 128):
|
||||
Dimensionality of the cross-attention head.
|
||||
cross_num_key_value_heads (`int`, *optional*, defaults to 16):
|
||||
Number of key and value heads for each cross-attention layer in the Transformer decoder.
|
||||
cross_hidden_size (`int`, *optional*, defaults to 1024):
|
||||
Dimensionality of the cross-attention layers.
|
||||
norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the normalization layers.
|
||||
vocab_size (`int`, *optional*, defaults to 1028):
|
||||
Vocabulary size of the Dia model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`DiaModel`].
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder. If string, `"gelu"`, `"relu"`,
|
||||
`"swish"` and `"gelu_new"` are supported.
|
||||
num_channels (`int`, *optional*, defaults to 9):
|
||||
Number of channels for the Dia decoder.
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models).
|
||||
is_encoder_decoder (`bool`, *optional*, defaults to `True`):
|
||||
Indicating that this model is part of an encoder-decoder architecture.
|
||||
"""
|
||||
|
||||
model_type = "dia_decoder"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_position_embeddings: int = 3072,
|
||||
num_hidden_layers: int = 18,
|
||||
hidden_size: int = 2048,
|
||||
intermediate_size: int = 8192,
|
||||
num_attention_heads: int = 16,
|
||||
num_key_value_heads: int = 4,
|
||||
head_dim: int = 128,
|
||||
cross_num_attention_heads: int = 16,
|
||||
cross_head_dim: int = 128,
|
||||
cross_num_key_value_heads: int = 16,
|
||||
cross_hidden_size: int = 1024,
|
||||
norm_eps: float = 1e-5,
|
||||
vocab_size: int = 1028,
|
||||
hidden_act: str = "silu",
|
||||
num_channels: int = 9,
|
||||
rope_theta: float = 10000.0,
|
||||
rope_scaling: Optional[dict] = None,
|
||||
initializer_range: float = 0.02,
|
||||
use_cache: bool = True,
|
||||
is_encoder_decoder: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.head_dim = head_dim
|
||||
self.cross_num_key_value_heads = cross_num_key_value_heads
|
||||
self.cross_num_attention_heads = cross_num_attention_heads
|
||||
self.cross_head_dim = cross_head_dim
|
||||
self.cross_hidden_size = cross_hidden_size
|
||||
self.norm_eps = norm_eps
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_act = hidden_act
|
||||
self.num_channels = num_channels
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, copy it to 'rope_type'.
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
self.initializer_range = initializer_range
|
||||
self.use_cache = use_cache
|
||||
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
|
||||
|
||||
|
||||
class DiaConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`DiaModel`]. It is used to instantiate a
|
||||
Dia model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of the
|
||||
[nari-labs/Dia-1.6B](https://huggingface.co/nari-labs/Dia-1.6B) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
encoder_config (`DiaEncoderConfig`, *optional*):
|
||||
Configuration for the encoder part of the model. If not provided, a default `DiaEncoderConfig` will be used.
|
||||
decoder_config (`DiaDecoderConfig`, *optional*):
|
||||
Configuration for the decoder part of the model. If not provided, a default `DiaDecoderConfig` will be used.
|
||||
norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the normalization layers.
|
||||
is_encoder_decoder (`bool`, *optional*, defaults to `True`):
|
||||
Indicating that this model uses an encoder-decoder architecture.
|
||||
pad_token_id (`int`, *optional*, defaults to 1025):
|
||||
Padding token id.
|
||||
eos_token_id (`int`, *optional*, defaults to 1024):
|
||||
End of stream token id.
|
||||
bos_token_id (`int`, *optional*, defaults to 1026):
|
||||
Beginning of stream token id.
|
||||
delay_pattern (`list[int]`, *optional*, defaults to `[0, 8, 9, 10, 11, 12, 13, 14, 15]`):
|
||||
The delay pattern for the decoder. The length of this list must match `decoder_config.num_channels`.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models).
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import DiaConfig, DiaModel
|
||||
|
||||
>>> # Initializing a DiaConfig with default values
|
||||
>>> configuration = DiaConfig()
|
||||
|
||||
>>> # Initializing a DiaModel (with random weights) from the configuration
|
||||
>>> model = DiaModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```
|
||||
"""
|
||||
|
||||
model_type = "dia"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
sub_configs = {"encoder_config": DiaEncoderConfig, "decoder_config": DiaDecoderConfig}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder_config: Optional[DiaEncoderConfig] = None,
|
||||
decoder_config: Optional[DiaDecoderConfig] = None,
|
||||
norm_eps: float = 1e-5,
|
||||
is_encoder_decoder: bool = True,
|
||||
pad_token_id: int = 1025,
|
||||
eos_token_id: int = 1024,
|
||||
bos_token_id: int = 1026,
|
||||
delay_pattern: Optional[list[int]] = None,
|
||||
initializer_range: float = 0.02,
|
||||
use_cache: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(encoder_config, dict):
|
||||
encoder_config = DiaEncoderConfig(**encoder_config)
|
||||
if isinstance(decoder_config, dict):
|
||||
decoder_config = DiaDecoderConfig(**decoder_config)
|
||||
self.encoder_config = encoder_config if encoder_config is not None else DiaEncoderConfig()
|
||||
self.decoder_config = decoder_config if decoder_config is not None else DiaDecoderConfig()
|
||||
self.norm_eps = norm_eps
|
||||
self.delay_pattern = delay_pattern if delay_pattern is not None else [0, 8, 9, 10, 11, 12, 13, 14, 15]
|
||||
self.initializer_range = initializer_range
|
||||
self.use_cache = use_cache
|
||||
|
||||
assert self.decoder_config.num_channels == len(self.delay_pattern), (
|
||||
"Number of channels must match delay pattern length."
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
is_encoder_decoder=is_encoder_decoder,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_text_config(self, decoder=False):
    """Defaults to the decoder (audio) config, since in Dia the decoder takes the role usually played by the text backbone."""
|
||||
return self.decoder_config
|
||||
|
||||
|
||||
__all__ = ["DiaConfig", "DiaEncoderConfig", "DiaDecoderConfig"]
|
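As a quick illustration of the `num_channels` / `delay_pattern` consistency check above, a tiny hypothetical configuration could be built as follows (sizes are illustrative, not the released 1.6B values):

```python
from transformers import DiaConfig, DiaDecoderConfig, DiaEncoderConfig

# Hypothetical tiny config: every size below is made up for illustration.
encoder = DiaEncoderConfig(num_hidden_layers=2, hidden_size=64, intermediate_size=128, head_dim=16)
decoder = DiaDecoderConfig(num_hidden_layers=2, hidden_size=64, intermediate_size=128, num_channels=3)

# `delay_pattern` must have exactly `decoder_config.num_channels` entries, otherwise DiaConfig raises.
config = DiaConfig(encoder_config=encoder, decoder_config=decoder, delay_pattern=[0, 1, 2])
print(len(config.delay_pattern) == config.decoder_config.num_channels)  # True
```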
199 src/transformers/models/dia/convert_dia_to_hf.py Normal file
@ -0,0 +1,199 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Converts a Dia model in Nari Labs format to Hugging Face format."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from transformers import (
|
||||
DacModel,
|
||||
DiaConfig,
|
||||
DiaFeatureExtractor,
|
||||
DiaForConditionalGeneration,
|
||||
DiaProcessor,
|
||||
DiaTokenizer,
|
||||
GenerationConfig,
|
||||
)
|
||||
from transformers.utils.import_utils import _is_package_available
|
||||
|
||||
|
||||
# Provide just the list of layer keys you want to fix
|
||||
shape_mappings = [
|
||||
"encoder.layers.*.mlp.gate_up_proj.weight",
|
||||
"encoder.layers.*.mlp.down_proj.weight",
|
||||
"encoder.layers.*.self_attention.q_proj.weight",
|
||||
"encoder.layers.*.self_attention.k_proj.weight",
|
||||
"encoder.layers.*.self_attention.v_proj.weight",
|
||||
"encoder.layers.*.self_attention.o_proj.weight",
|
||||
"decoder.layers.*.mlp.gate_up_proj.weight",
|
||||
"decoder.layers.*.mlp.down_proj.weight",
|
||||
"decoder.layers.*.self_attention.q_proj.weight",
|
||||
"decoder.layers.*.self_attention.k_proj.weight",
|
||||
"decoder.layers.*.self_attention.v_proj.weight",
|
||||
"decoder.layers.*.self_attention.o_proj.weight",
|
||||
"decoder.layers.*.cross_attention.q_proj.weight",
|
||||
"decoder.layers.*.cross_attention.k_proj.weight",
|
||||
"decoder.layers.*.cross_attention.v_proj.weight",
|
||||
"decoder.layers.*.cross_attention.o_proj.weight",
|
||||
"decoder.logits_dense.weight",
|
||||
]
|
||||
|
||||
# Provide renamings here
|
||||
rename_mapping = {
|
||||
"mlp.wo": "mlp.down_proj",
|
||||
"mlp.wi_fused": "mlp.gate_up_proj",
|
||||
}
|
||||
|
||||
|
||||
def get_generation_config(config):
|
||||
model_generation_config = GenerationConfig.from_model_config(config)
|
||||
model_generation_config._from_model_config = False
|
||||
model_generation_config.do_sample = True
|
||||
model_generation_config.top_k = 45
|
||||
model_generation_config.top_p = 0.95
|
||||
model_generation_config.temperature = 1.2
|
||||
model_generation_config.guidance_scale = 3.0
|
||||
model_generation_config.max_length = 3072 # Decoder max length
|
||||
|
||||
return model_generation_config
|
||||
|
||||
|
||||
def convert_dia_model_to_hf(checkpoint_path, verbose=False):
|
||||
"""
|
||||
Converts a Dia model in Nari Labs format to Hugging Face format.
|
||||
Args:
|
||||
checkpoint_path (`str`):
|
||||
Path to the downloaded checkpoints.
|
||||
verbose (`bool`, *optional*):
|
||||
Whether to print information during conversion.
|
||||
"""
|
||||
# Resolve the checkpoint via the Hugging Face Hub (uses the local cache if already downloaded)
|
||||
checkpoint_path = snapshot_download(repo_id=checkpoint_path, allow_patterns=["*.pth", "*.safetensors"])
|
||||
print(f"Downloaded checkpoint from Hugging Face Hub: {checkpoint_path}")
|
||||
|
||||
# Initialize base model with default config == 1.6B model
|
||||
with torch.device("meta"):
|
||||
hf_model = DiaForConditionalGeneration(config=DiaConfig())
|
||||
hf_model_dict = hf_model.state_dict()
|
||||
hf_model_keys = hf_model_dict.keys()
|
||||
|
||||
# Iterate through dir to catch all respective files - prefers safetensors but allows pt
|
||||
files = os.listdir(checkpoint_path)
|
||||
for file in files:
|
||||
if file.endswith(".safetensors"):
|
||||
load_function = load_file
|
||||
elif file.endswith(".pth"):
|
||||
load_function = torch.load
|
||||
checkpoint_path = os.path.join(checkpoint_path, files[0])
|
||||
nari_state_dict = load_function(checkpoint_path, "cpu")
|
||||
|
||||
# Conversion starts here
|
||||
converted_state_dict = {}
|
||||
embeddings = {}
|
||||
for key, tensor in nari_state_dict.items():
|
||||
# add prefix
|
||||
key = "model." + key
|
||||
|
||||
# rename some weights
|
||||
for original, rename in rename_mapping.items():
|
||||
if original in key:
|
||||
key = re.sub(original, rename, key)
|
||||
|
||||
# decoder multi channel
|
||||
if "embeddings" in key:
|
||||
embeddings_key = key.rsplit(".", 2)[0] + ".embed.weight"
|
||||
if embeddings_key in embeddings:
|
||||
embeddings[embeddings_key] += [tensor]
|
||||
else:
|
||||
embeddings[embeddings_key] = [tensor]
|
||||
continue
|
||||
elif re.sub(r"\d+", "*", key).removeprefix("model.") in shape_mappings:
|
||||
# add exception to the head
|
||||
if "logits_dense" in key:
|
||||
key = re.sub("decoder.logits_dense", "logits_dense", key).removeprefix("model.")
|
||||
|
||||
# dense general
|
||||
if key in hf_model_keys:
|
||||
tensor_shape = tensor.shape
|
||||
target_shape = hf_model_dict[key].shape
|
||||
try:
|
||||
tensor = tensor.reshape(target_shape[1], target_shape[0]).T
|
||||
if verbose:
|
||||
print(f"{key}: transpose reshaped from {tensor_shape} to {target_shape}")
|
||||
except Exception as e:
|
||||
print(f"WARNING: Could not reshape {key}: {e}")
|
||||
|
||||
converted_state_dict[key] = tensor
|
||||
|
||||
# Combining the embeddings as last step
|
||||
embeddings = {k: torch.cat(v, dim=0) for k, v in embeddings.items()}
|
||||
converted_state_dict.update(embeddings)
|
||||
|
||||
# Load converted weights into HF model
|
||||
hf_model.load_state_dict(converted_state_dict, assign=True)
|
||||
|
||||
# Overwrite generation config
|
||||
hf_model.generation_config = get_generation_config(DiaConfig())
|
||||
|
||||
return hf_model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# # Required parameters
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", type=str, default="nari-labs/Dia-1.6B", help="Path to the downloaded checkpoints"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default="AntonV/Dia-1.6B", type=str, help="Path to the output PyTorch model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--convert_preprocessor",
|
||||
type=bool,
|
||||
default=True,
|
||||
help="Whether or not the preprocessor (tokenizer + feature extractor) should be converted along with the model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
type=bool,
|
||||
default=True,
|
||||
help="Whether or not to log information during conversion.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
model = convert_dia_model_to_hf(args.checkpoint_path, args.verbose)
|
||||
if args.convert_preprocessor:
|
||||
try:
|
||||
if not _is_package_available("tiktoken"):
|
||||
raise ModuleNotFoundError(
|
||||
"""`tiktoken` is not installed, use `pip install tiktoken` to convert the tokenizer"""
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
else:
|
||||
processor = DiaProcessor(
|
||||
DiaFeatureExtractor(sampling_rate=44100, hop_length=512),
|
||||
DiaTokenizer(),
|
||||
DacModel.from_pretrained("descript/dac_44khz"),
|
||||
)
|
||||
processor.save_pretrained(args.pytorch_dump_folder_path)
|
||||
|
||||
model.save_pretrained(args.pytorch_dump_folder_path)
|
||||
print(f"Saved converted checkpoint to {args.pytorch_dump_folder_path}")
|
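
# A minimal usage sketch (illustrative, not part of the script; "./dia-1.6b-hf" is
# a hypothetical output directory), assuming the module above is importable:
#
#     model = convert_dia_model_to_hf("nari-labs/Dia-1.6B", verbose=True)
#     model.save_pretrained("./dia-1.6b-hf")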
src/transformers/models/dia/feature_extraction_dia.py (new file, 183 lines)
@@ -0,0 +1,183 @@
# coding=utf-8
# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature extractor class for Dia"""

from typing import Optional, Union

import numpy as np

from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...utils import PaddingStrategy, TensorType, logging


logger = logging.get_logger(__name__)


class DiaFeatureExtractor(SequenceFeatureExtractor):
    r"""
    Constructs a Dia feature extractor.

    This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
    most of the main methods. Users should refer to this superclass for more information regarding those methods.

    Args:
        feature_size (`int`, *optional*, defaults to 1):
            The feature dimension of the extracted features. Use 1 for mono, 2 for stereo.
        sampling_rate (`int`, *optional*, defaults to 16000):
            The sampling rate at which the audio waveform should be digitalized, expressed in hertz (Hz).
        padding_value (`float`, *optional*, defaults to 0.0):
            The value that is used for padding.
        hop_length (`int`, *optional*, defaults to 512):
            Overlap length between successive windows.
    """

    model_input_names = ["input_values", "n_quantizers"]

    def __init__(
        self,
        feature_size: int = 1,
        sampling_rate: int = 16000,
        padding_value: float = 0.0,
        hop_length: int = 512,
        **kwargs,
    ):
        super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
        self.hop_length = hop_length

    def __call__(
        self,
        raw_audio: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
        padding: Optional[Union[bool, str, PaddingStrategy]] = None,
        truncation: Optional[bool] = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        sampling_rate: Optional[int] = None,
    ) -> BatchFeature:
        """
        Main method to featurize and prepare for the model one or several sequence(s).

        Args:
            raw_audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
                The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float
                values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape
                `(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio
                (`feature_size = 2`).
            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
                Select a strategy to pad the returned sequences (according to the model's padding side and padding
                index) among:

                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
                  sequence is provided).
                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
                  acceptable input length for the model if that argument is not provided.
                - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different
                  lengths).
            truncation (`bool`, *optional*, defaults to `False`):
                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
            max_length (`int`, *optional*):
                Maximum length of the returned list and optionally padding length (see above).
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors instead of list of python integers. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return Numpy `np.ndarray` objects.
            sampling_rate (`int`, *optional*):
                The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass
                `sampling_rate` at the forward call to prevent silent errors.
        """
        if sampling_rate is not None:
            if sampling_rate != self.sampling_rate:
                raise ValueError(
                    f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
                    f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with"
                    f" {self.sampling_rate} and not {sampling_rate}."
                )
        else:
            logger.warning(
                f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
                "Failing to do so can result in silent errors that might be hard to debug."
            )

        if padding and truncation:
            raise ValueError("Both padding and truncation were set. Make sure you only set one.")
        elif padding is None:
            # by default let's pad the inputs
            padding = True

        is_batched = bool(
            isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list)))
        )

        if is_batched:
            raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio]
        elif not is_batched and not isinstance(raw_audio, np.ndarray):
            raw_audio = np.asarray(raw_audio, dtype=np.float32)
        elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64):
            raw_audio = raw_audio.astype(np.float32)

        # always return batch
        if not is_batched:
            raw_audio = [np.asarray(raw_audio).T]

        # convert stereo to mono if necessary, unique to Dia
        for idx, example in enumerate(raw_audio):
            if self.feature_size == 2 and example.ndim == 2:
                raw_audio[idx] = np.mean(example, -1)

        # verify inputs are valid
        for idx, example in enumerate(raw_audio):
            if example.ndim > 2:
                raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}")
            if self.feature_size == 1 and example.ndim != 1:
                raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels")
            if self.feature_size == 2 and example.ndim != 1:  # note the stereo-to-mono conversion above
                raise ValueError(f"Expected stereo audio but example has {example.shape[-1]} channels")

        input_values = BatchFeature({"input_values": raw_audio})

        # temporarily treat it as if we were mono as we also convert stereo to mono
        original_feature_size = self.feature_size
        self.feature_size = 1

        # normal padding on batch
        padded_inputs = self.pad(
            input_values,
            max_length=max_length,
            truncation=truncation,
            padding=padding,
            return_attention_mask=True,
            pad_to_multiple_of=self.hop_length,
        )
        padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask")

        input_values = []
        for example in padded_inputs.pop("input_values"):
            if self.feature_size == 1:
                example = example[..., None]
            input_values.append(example.T)

        padded_inputs["input_values"] = input_values
        if return_tensors is not None:
            padded_inputs = padded_inputs.convert_to_tensors(return_tensors)

        # revert back to the original feature size
        self.feature_size = original_feature_size

        return padded_inputs


__all__ = ["DiaFeatureExtractor"]
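
# A minimal usage sketch for the feature extractor above (illustrative, dummy data;
# not part of the file):
#
#     import numpy as np
#     fe = DiaFeatureExtractor(sampling_rate=44100, hop_length=512)
#     clips = [np.zeros(44100, dtype=np.float32), np.zeros(22050, dtype=np.float32)]
#     features = fe(clips, sampling_rate=44100, return_tensors="pt")
#     # features["input_values"]: (batch, 1, padded_length), features["padding_mask"]: (batch, padded_length),
#     # with padded_length rounded up to a multiple of hop_length.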
src/transformers/models/dia/generation_dia.py (new file, 464 lines)
@@ -0,0 +1,464 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from ...generation.logits_process import (
|
||||
DiaClassifierFreeGuidanceLogitsProcessor,
|
||||
DiaEOSChannelFilterLogitsProcessor,
|
||||
DiaEOSDelayPatternLogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
TemperatureLogitsWarper,
|
||||
)
|
||||
from ...generation.stopping_criteria import StoppingCriteriaList
|
||||
from ...generation.streamers import BaseStreamer
|
||||
from ...generation.utils import GenerateOutput, GenerationConfig, GenerationMixin, GenerationMode
|
||||
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from ...integrations.fsdp import is_fsdp_managed_module
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class DiaGenerationMixin(GenerationMixin):
|
||||
# Indicates whether CFG is used; CFG needs preparation (input repeats along the batch dimension) to be handled properly
|
||||
_uses_cfg = None
|
||||
|
||||
def _get_logits_processor(
|
||||
self,
|
||||
generation_config: GenerationConfig,
|
||||
input_ids_seq_length: Optional[int] = None,
|
||||
encoder_input_ids: torch.LongTensor = None,
|
||||
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
|
||||
logits_processor: Optional[LogitsProcessorList] = None,
|
||||
device: Optional[str] = None,
|
||||
model_kwargs: Optional[dict[str, Any]] = None,
|
||||
negative_prompt_ids: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> LogitsProcessorList:
|
||||
# Need either custom order or custom processor instead
|
||||
# (Temporarily disabling those for the super function)
|
||||
original_guidance_scale = generation_config.guidance_scale
|
||||
original_temperature = generation_config.temperature
|
||||
generation_config.guidance_scale = None
|
||||
generation_config.temperature = None
|
||||
|
||||
# Get base processors and those we can integrate easily
|
||||
custom_processors = LogitsProcessorList()
|
||||
|
||||
if original_temperature is not None and original_temperature != 1.0:
|
||||
custom_processors.append(TemperatureLogitsWarper(original_temperature))
|
||||
|
||||
custom_processors.append(
|
||||
DiaEOSChannelFilterLogitsProcessor(
|
||||
num_channels=len(self.config.delay_pattern),
|
||||
eos_token_id=self.config.eos_token_id,
|
||||
)
|
||||
)
|
||||
|
||||
merged_processors = super()._get_logits_processor(
|
||||
generation_config=generation_config,
|
||||
input_ids_seq_length=input_ids_seq_length,
|
||||
encoder_input_ids=encoder_input_ids,
|
||||
prefix_allowed_tokens_fn=None,
|
||||
logits_processor=custom_processors,
|
||||
device=device,
|
||||
model_kwargs=model_kwargs,
|
||||
negative_prompt_ids=negative_prompt_ids,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
)
|
||||
|
||||
# Custom processors we need at specific positions
|
||||
if original_guidance_scale is not None and original_guidance_scale != 1:
|
||||
cfg_processor = DiaClassifierFreeGuidanceLogitsProcessor(
|
||||
guidance_scale=original_guidance_scale,
|
||||
guidance_top_k=generation_config.top_k,
|
||||
)
|
||||
merged_processors.insert(0, cfg_processor)
|
||||
|
||||
merged_processors.append(
|
||||
DiaEOSDelayPatternLogitsProcessor(
|
||||
delay_pattern=self.config.delay_pattern,
|
||||
eos_token_id=self.config.eos_token_id,
|
||||
max_generation_len=generation_config.max_length,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
# Enable temporarily disabled values back
|
||||
generation_config.guidance_scale = original_guidance_scale
|
||||
generation_config.temperature = original_temperature
|
||||
|
||||
return merged_processors
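
# Resulting processor order, roughly (sketch, not part of the file):
#   1. DiaClassifierFreeGuidanceLogitsProcessor  - prepended so guidance is applied before any warping
#   2. the default GenerationMixin processors    - including TemperatureLogitsWarper and the
#      DiaEOSChannelFilterLogitsProcessor passed in via `custom_processors`
#   3. DiaEOSDelayPatternLogitsProcessor         - appended last to schedule EOS per channel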
|
||||
|
||||
def _prepare_generation_config(
|
||||
self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: dict
|
||||
) -> tuple[GenerationConfig, dict]:
|
||||
generation_config, model_kwargs = super()._prepare_generation_config(
|
||||
generation_config, use_model_defaults, **kwargs
|
||||
)
|
||||
|
||||
# We allow generation up to max length + max delay pattern
|
||||
# (will revert back to max length after generation)
|
||||
generation_config.max_length += max(self.config.delay_pattern)
|
||||
|
||||
# Internal flag to indicate CFG that needs to prepare unconditioned input
|
||||
self._uses_cfg = generation_config.guidance_scale is not None and generation_config.guidance_scale != 1
|
||||
|
||||
return generation_config, model_kwargs
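
# Illustration (hypothetical numbers, not part of the file): with a delay pattern of
# [0, 1, 2] and a requested max_length of 100, generation internally runs up to 102
# steps so that the most delayed channel can still emit its tokens; per the comment
# above, this is reverted after generation.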
|
||||
|
||||
def _prepare_model_inputs(
|
||||
self,
|
||||
inputs: Optional[torch.Tensor] = None,
|
||||
bos_token_id: Optional[torch.Tensor] = None,
|
||||
model_kwargs: Optional[dict[str, torch.Tensor]] = None,
|
||||
) -> tuple[torch.Tensor, Optional[str], dict[str, torch.Tensor]]:
|
||||
inputs, input_name, model_kwargs = super()._prepare_model_inputs(
|
||||
inputs=inputs,
|
||||
bos_token_id=bos_token_id,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
|
||||
# If CFG is requested we fill in the unconditioned parts
|
||||
if self._uses_cfg:
|
||||
unconditioned_inputs = torch.zeros_like(inputs)
|
||||
inputs = torch.cat([inputs, unconditioned_inputs], dim=0)
|
||||
|
||||
if model_kwargs.get("attention_mask", None) is not None:
|
||||
model_kwargs["attention_mask"] = model_kwargs["attention_mask"].repeat(2, 1)
|
||||
|
||||
return inputs, input_name, model_kwargs
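
# Sketch of the CFG batch layout (illustrative, not part of the file): with guidance
# enabled, the encoder input becomes a concatenation of the conditional batch and an
# all-zero unconditional batch,
#
#     inputs: (B, S)  ->  torch.cat([inputs, torch.zeros_like(inputs)], dim=0)  # (2*B, S)
#
# and the attention mask is repeated accordingly.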
|
||||
|
||||
def _prepare_decoder_input_ids_for_generation(
|
||||
self,
|
||||
batch_size: int,
|
||||
model_input_name: str,
|
||||
model_kwargs: dict[str, torch.Tensor],
|
||||
decoder_start_token_id: torch.Tensor,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> tuple[torch.LongTensor, dict[str, torch.Tensor]]:
|
||||
"""Prepares `decoder_input_ids` for generation with encoder-decoder models"""
|
||||
# 1. Check whether the user has defined `decoder_input_ids` and `decoder_attention_mask`; if not, fall back with a warning
|
||||
decoder_input_ids = decoder_attention_mask = None
|
||||
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
|
||||
decoder_input_ids = model_kwargs.pop("decoder_input_ids")
|
||||
if model_kwargs is not None and "decoder_attention_mask" in model_kwargs:
|
||||
decoder_attention_mask = model_kwargs.pop("decoder_attention_mask")
|
||||
|
||||
# We allow generating without preparation (no proper delay) but discourage it
|
||||
if decoder_input_ids is None or decoder_attention_mask is None:
|
||||
logger.warning_once(
|
||||
"In order to generate with Dia, we need the processed audio input: Got `decoder_input_ids`:"
|
||||
f" {decoder_input_ids is not None} and got `decoder_attention_mask`={decoder_attention_mask is not None}."
|
||||
f" This can be achieved via the [`DiaProcessor`] but now defaulting to non-delayed generation."
|
||||
)
|
||||
|
||||
num_channels = self.config.decoder_config.num_channels
|
||||
real_batch_size = batch_size // 2 if self._uses_cfg else batch_size
|
||||
|
||||
if decoder_input_ids is None:
|
||||
decoder_input_ids = torch.full(
|
||||
(real_batch_size, 1, num_channels), decoder_start_token_id, dtype=torch.long, device=device
|
||||
)
|
||||
|
||||
decoder_attention_mask = torch.ones(
|
||||
size=(real_batch_size, decoder_input_ids.shape[1]), dtype=torch.long, device=device
|
||||
)
|
||||
|
||||
# 2. Determine the valid input and what works as mask within the input
|
||||
delay_mask = decoder_input_ids.long()
|
||||
valid_input_size = (
|
||||
decoder_input_ids.shape[1] - (decoder_input_ids[:, :, 0] == self.config.pad_token_id).sum(dim=-1).max()
|
||||
)
|
||||
decoder_input_ids = delay_mask[:, :valid_input_size].transpose(1, 2).long()
|
||||
decoder_attention_mask = decoder_attention_mask[:, :valid_input_size].long()
|
||||
|
||||
# 3. Overwrite into model kwargs
|
||||
model_kwargs["decoder_attention_mask"] = decoder_attention_mask
|
||||
model_kwargs["decoder_delay_mask"] = delay_mask
|
||||
|
||||
return decoder_input_ids, model_kwargs
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
encoder_outputs=None, # Using this to easily get the batch size
|
||||
decoder_delay_mask=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Reshape decoder input_ids to 3D to be compile friendly and to fit the expected model input shape
|
||||
batch_size = encoder_outputs[0].shape[0] // 2 if self._uses_cfg else encoder_outputs[0].shape[0]
|
||||
input_ids = input_ids.reshape(batch_size, self.config.decoder_config.num_channels, -1).transpose(1, 2)
|
||||
|
||||
# Base method handles most things except CFG and the delay pattern mask
|
||||
model_inputs = super().prepare_inputs_for_generation(input_ids, encoder_outputs=encoder_outputs, **kwargs)
|
||||
|
||||
# Post processing for CFG and overwriting via delay pattern mask
|
||||
# 1. Delay pattern mask -- force tokens if not allowed to predict (!= pad_token in mask)
|
||||
model_inputs["decoder_input_ids"] = self.apply_delay_mask(
|
||||
input_ids, self.config.pad_token_id, decoder_delay_mask
|
||||
)
|
||||
|
||||
# Depending on cache usage we need to pass all or just one
|
||||
if model_inputs.get("use_cache", False) and model_inputs["cache_position"][0] > 0:
|
||||
model_inputs["decoder_input_ids"] = model_inputs["decoder_input_ids"][:, -1, :][:, None, :]
|
||||
|
||||
# Be compile friendly
|
||||
model_inputs["decoder_input_ids"] = model_inputs["decoder_input_ids"].contiguous()
|
||||
|
||||
# 2. Apply CFG duplication if needed
|
||||
if self._uses_cfg:
|
||||
for key in ["decoder_input_ids", "decoder_attention_mask", "decoder_position_ids"]:
|
||||
if model_inputs.get(key, None) is not None:
|
||||
# double first dimension and keep everything else the same
|
||||
repeat_pattern = tuple([2] + [1] * (model_inputs[key].ndim - 1))
|
||||
model_inputs[key] = model_inputs[key].repeat(*repeat_pattern)
|
||||
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
def apply_delay_mask(input_ids: torch.Tensor, pad_id: int, delay_mask: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
if delay_mask is None:
|
||||
return input_ids
|
||||
|
||||
mask_len = min(input_ids.shape[1], delay_mask.shape[1])
|
||||
valid_mask = delay_mask[:, :mask_len, :]
|
||||
valid_input = input_ids[:, :mask_len, :]
|
||||
|
||||
# Overwrite the respective parts of the input
|
||||
input_ids[:, :mask_len, :] = torch.where(valid_mask == pad_id, valid_input, valid_mask)
|
||||
|
||||
return input_ids
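
# Tiny worked example for `apply_delay_mask` (hypothetical token ids, not part of the file):
#
#     import torch
#     pad_id = 1025
#     sampled = torch.tensor([[[7, 8], [9, 10]]])                # (bsz=1, seq=2, channels=2)
#     mask = torch.tensor([[[1026, pad_id], [pad_id, pad_id]]])  # hypothetical forced token on channel 0, step 0
#     DiaGenerationMixin.apply_delay_mask(sampled, pad_id, mask)
#     # -> tensor([[[1026,    8], [   9,   10]]])
#
# Positions where the mask is not the pad token are forced to the mask value; all
# other positions keep the sampled token.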
|
||||
|
||||
def _main_generate_loop(
|
||||
self,
|
||||
inputs: Optional[torch.Tensor] = None,
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
logits_processor: Optional[LogitsProcessorList] = None,
|
||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
|
||||
synced_gpus: Optional[bool] = None,
|
||||
assistant_model: Optional["PreTrainedModel"] = None,
|
||||
streamer: Optional["BaseStreamer"] = None,
|
||||
negative_prompt_ids: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
use_model_defaults: Optional[bool] = None,
|
||||
custom_generate: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# ********** mostly taken from main generate function up to calling the different methods (see NOTE) **********
|
||||
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
||||
tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
|
||||
assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation
|
||||
|
||||
generation_config, model_kwargs = self._prepare_generation_config(
|
||||
generation_config, use_model_defaults, **kwargs
|
||||
)
|
||||
self._validate_model_kwargs(model_kwargs.copy())
|
||||
self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)
|
||||
|
||||
# 2. Set generation parameters if not already defined
|
||||
if synced_gpus is None:
|
||||
synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1
|
||||
|
||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||
|
||||
# 3. Define model inputs
|
||||
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
|
||||
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
|
||||
inputs, generation_config.bos_token_id, model_kwargs
|
||||
)
|
||||
batch_size = inputs_tensor.shape[0]
|
||||
|
||||
device = inputs_tensor.device
|
||||
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
|
||||
|
||||
# 4. Define other model kwargs
|
||||
if "encoder_outputs" not in model_kwargs:
|
||||
# if the model is an encoder-decoder, `encoder_outputs` are created and added to `model_kwargs`
|
||||
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
|
||||
inputs_tensor, model_kwargs, model_input_name, generation_config
|
||||
)
|
||||
|
||||
# 5. Prepare `input_ids` which will be used for auto-regressive generation
|
||||
input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
|
||||
batch_size=batch_size,
|
||||
model_input_name=model_input_name,
|
||||
model_kwargs=model_kwargs,
|
||||
decoder_start_token_id=generation_config._decoder_start_token_tensor,
|
||||
device=inputs_tensor.device,
|
||||
)
|
||||
|
||||
if generation_config.token_healing:
|
||||
input_ids = self.heal_tokens(input_ids, tokenizer)
|
||||
|
||||
if streamer is not None:
|
||||
streamer.put(input_ids.cpu())
|
||||
|
||||
# 6. Prepare `max_length` depending on other stopping criteria.
|
||||
# NOTE: incorrect `input_ids.shape[1]` previously
|
||||
input_ids_length = input_ids.shape[-1]
|
||||
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
||||
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
|
||||
generation_config = self._prepare_generated_length(
|
||||
generation_config=generation_config,
|
||||
has_default_max_length=has_default_max_length,
|
||||
has_default_min_length=has_default_min_length,
|
||||
model_input_name=model_input_name,
|
||||
inputs_tensor=inputs_tensor,
|
||||
input_ids_length=input_ids_length,
|
||||
)
|
||||
|
||||
# If the model supports `logits_to_keep` in forward(), set it to 1 to avoid computing the whole
|
||||
# logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding
|
||||
# dynamically overrides this value as it can need more than the last token logits
|
||||
if self._supports_logits_to_keep() and "logits_to_keep" not in model_kwargs:
|
||||
model_kwargs["logits_to_keep"] = 1
|
||||
|
||||
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
|
||||
|
||||
# 7. Prepare the cache.
|
||||
# - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
|
||||
# - different models have a different cache name expected by the model (default = "past_key_values")
|
||||
# - `max_length`, prepared above, is used to determine the maximum cache length
|
||||
max_cache_length = generation_config.max_length - 1
|
||||
if (
|
||||
inputs_tensor.shape[1] != input_ids_length
|
||||
and model_input_name == "inputs_embeds"
|
||||
and not self.config.is_encoder_decoder
|
||||
):
|
||||
max_cache_length += inputs_tensor.shape[1]
|
||||
self._prepare_cache_for_generation(
|
||||
generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device
|
||||
)
|
||||
|
||||
# 8. determine generation mode
|
||||
generation_mode = generation_config.get_generation_mode(assistant_model)
|
||||
|
||||
if streamer is not None and (generation_config.num_beams > 1):
|
||||
raise ValueError(
|
||||
"`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
|
||||
)
|
||||
|
||||
# 9. prepare logits processors and stopping criteria
|
||||
prepared_logits_processor = self._get_logits_processor(
|
||||
generation_config=generation_config,
|
||||
input_ids_seq_length=input_ids_length,
|
||||
encoder_input_ids=inputs_tensor,
|
||||
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
||||
logits_processor=logits_processor,
|
||||
device=inputs_tensor.device,
|
||||
model_kwargs=model_kwargs,
|
||||
negative_prompt_ids=negative_prompt_ids,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
)
|
||||
prepared_stopping_criteria = self._get_stopping_criteria(
|
||||
generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
|
||||
)
|
||||
|
||||
# Set model_kwargs `use_cache` so we can use it later in forward runs
|
||||
model_kwargs["use_cache"] = generation_config.use_cache
|
||||
# ******************* taken from main generate function up to calling the different methods *******************
|
||||
|
||||
# Prepare inner 2D logic in generation loop
|
||||
input_ids = input_ids.reshape(-1, input_ids.shape[-1])
|
||||
|
||||
# 10. go into different generation modes
|
||||
if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
|
||||
# 11. expand input_ids with `num_return_sequences` additional sequences per batch
|
||||
if generation_config.num_return_sequences > 1:
|
||||
raise ValueError("`num_return_sequences>1` is incompatible with Dia.")
|
||||
|
||||
# 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
|
||||
return self._sample(
|
||||
input_ids,
|
||||
logits_processor=prepared_logits_processor,
|
||||
stopping_criteria=prepared_stopping_criteria,
|
||||
generation_config=generation_config,
|
||||
synced_gpus=synced_gpus,
|
||||
streamer=streamer,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Got incompatible mode for generation, should be one of greedy or sampling. "
|
||||
"Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`."
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
inputs: Optional[torch.Tensor] = None,
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
logits_processor: Optional[LogitsProcessorList] = None,
|
||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
|
||||
synced_gpus: Optional[bool] = None,
|
||||
assistant_model: Optional["PreTrainedModel"] = None,
|
||||
streamer: Optional["BaseStreamer"] = None,
|
||||
negative_prompt_ids: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
use_model_defaults: Optional[bool] = None,
|
||||
custom_generate: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Union[GenerateOutput, torch.LongTensor]:
|
||||
# We expect the initial input ids to be the complete mask (delayed input)
|
||||
delay_mask = kwargs.get("decoder_input_ids", None)
|
||||
if delay_mask is not None:
|
||||
delay_mask = delay_mask.clone()
|
||||
|
||||
output = self._main_generate_loop(
|
||||
inputs=inputs,
|
||||
generation_config=generation_config,
|
||||
logits_processor=logits_processor,
|
||||
stopping_criteria=stopping_criteria,
|
||||
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
||||
synced_gpus=synced_gpus,
|
||||
assistant_model=assistant_model,
|
||||
streamer=streamer,
|
||||
negative_prompt_ids=negative_prompt_ids,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
use_model_defaults=use_model_defaults,
|
||||
custom_generate=custom_generate,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return_dict_in_generate = not isinstance(output, torch.Tensor)
|
||||
|
||||
if return_dict_in_generate:
|
||||
output_sequences = output.sequences
|
||||
else:
|
||||
output_sequences = output
|
||||
|
||||
# Reshape from 2D (bsz * channels, seq_len) to 3D (bsz, seq_len, channels)
|
||||
num_channels = self.config.decoder_config.num_channels
|
||||
bsz = output_sequences.shape[0] // num_channels
|
||||
output_sequences = output_sequences.reshape(bsz, num_channels, -1).transpose(1, 2)
|
||||
|
||||
# Apply delay mask
|
||||
output_sequences = self.apply_delay_mask(output_sequences, self.config.pad_token_id, delay_mask)
|
||||
|
||||
if return_dict_in_generate:
|
||||
output.sequences = output_sequences
|
||||
else:
|
||||
output = output_sequences
|
||||
|
||||
return output
|
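
# Shape bookkeeping sketch (illustrative sizes, not part of the file): the inner loop
# works on 2D ids of shape (bsz * num_channels, seq_len); the public output is reshaped
# to (bsz, seq_len, num_channels) before the delay mask is re-applied:
#
#     flat = torch.arange(2 * 3 * 5).reshape(2 * 3, 5)     # bsz=2, channels=3, seq=5
#     audio_codes = flat.reshape(2, 3, 5).transpose(1, 2)  # -> shape (2, 5, 3)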
src/transformers/models/dia/modeling_dia.py (new file, 963 lines)
@@ -0,0 +1,963 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/dia/modular_dia.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_dia.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
|
||||
from ...integrations import use_kernel_forward_from_hub
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
Seq2SeqLMOutput,
|
||||
Seq2SeqModelOutput,
|
||||
)
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, is_torchdynamo_compiling, logging
|
||||
from .configuration_dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig
|
||||
from .generation_dia import DiaGenerationMixin
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class DiaPreTrainedModel(PreTrainedModel):
|
||||
config_class = DiaConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
main_input_name = "input_ids"
|
||||
_no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, DiaRMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
class DiaMultiChannelEmbedding(nn.Module):
|
||||
"""In order to efficiently compute the audio embedding from the 9 different channels,
|
||||
we vectorize the embedding process by using a single embedding layer and an offset.
|
||||
Example:
|
||||
- num_embeds = 4
|
||||
- vocab_size = 8
|
||||
- num_channels = 3
|
||||
We would have offsets = [0, 8, 16]
|
||||
If audio_codes = [0, 1, 2, 3], [1, 3, 4, 7], [5, 6, 7, 8],
|
||||
then tokens = audio_codes + offsets
|
||||
= [0, 1, 2, 3, 9, 11, 12, 15, 21, 22, 23, 24]
|
||||
This allows us to use a single embedding layer for all channels.
|
||||
"""
|
||||
|
||||
def __init__(self, config: DiaDecoderConfig):
|
||||
super().__init__()
|
||||
self.embed = nn.Embedding(config.vocab_size * config.num_channels, config.hidden_size)
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_channels = config.num_channels
|
||||
offsets = torch.arange(config.num_channels, dtype=torch.long) * config.vocab_size # (C,)
|
||||
self.register_buffer("offsets", offsets, persistent=False)
|
||||
|
||||
def forward(self, audio_codes: torch.Tensor) -> torch.Tensor:
|
||||
tokens = (audio_codes + self.offsets.to(audio_codes.device)).squeeze(1)
|
||||
embeds = self.embed(tokens).view(tokens.shape[0], audio_codes.shape[1], -1, self.hidden_size)
|
||||
return embeds.sum(dim=2)
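
# Quick numeric check of the offset trick described above (illustrative sizes, not part of the file):
#
#     import torch
#     vocab_size, num_channels = 8, 3
#     offsets = torch.arange(num_channels) * vocab_size   # tensor([ 0,  8, 16])
#     audio_codes = torch.tensor([[[0, 1, 5]]])            # (batch=1, seq=1, channels=3)
#     audio_codes + offsets                                # tensor([[[ 0,  9, 21]]])
#
# A single nn.Embedding(vocab_size * num_channels, hidden_size) then embeds all
# channels at once, and the per-channel embeddings are summed per frame.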
|
||||
|
||||
|
||||
class DiaMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
|
||||
self.activation_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
up_states = self.gate_up_proj(hidden_states)
|
||||
|
||||
gate, up_states = up_states.chunk(2, dim=-1)
|
||||
up_states = up_states * self.activation_fn(gate)
|
||||
|
||||
return self.down_proj(up_states)
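
# Note on the fused projection above (sketch, not part of the file): `gate_up_proj`
# produces a single tensor of size 2 * intermediate_size which `chunk` splits into
# `gate` and `up` halves; the activated gate multiplies the up half before
# `down_proj`, i.e. a GLU-style MLP.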
|
||||
|
||||
|
||||
@use_kernel_forward_from_hub("RMSNorm")
|
||||
class DiaRMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
DiaRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||
|
||||
|
||||
class DiaRotaryEmbedding(nn.Module):
|
||||
def __init__(self, config: DiaConfig, device=None):
|
||||
super().__init__()
|
||||
# BC: "rope_type" was originally "type"
|
||||
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
@torch.no_grad()
|
||||
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
||||
def forward(self, x, position_ids):
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
|
||||
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos() * self.attention_scaling
|
||||
sin = emb.sin() * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`, *optional*):
|
||||
Deprecated and unused.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
key_states = repeat_kv(key, module.num_key_value_groups)
|
||||
value_states = repeat_kv(value, module.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class DiaSelfAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: Union[DiaEncoderConfig, DiaDecoderConfig], layer_idx: int, is_causal: bool = False):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = self.config.num_attention_heads
|
||||
self.num_key_value_heads = self.config.num_key_value_heads or self.num_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
self.head_dim = getattr(config, "head_dim", config.hidden_size // self.num_heads)
|
||||
self.scaling = 1
|
||||
self.attention_dropout = 0.0
|
||||
self.is_causal = is_causal
|
||||
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class DiaCrossAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: DiaDecoderConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.hidden_size = config.hidden_size
|
||||
self.cross_hidden_size = config.cross_hidden_size
|
||||
self.num_heads = self.config.cross_num_attention_heads
|
||||
self.num_key_value_heads = self.config.cross_num_key_value_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
self.head_dim = config.cross_head_dim
|
||||
self.scaling = 1
|
||||
self.attention_dropout = 0.0
|
||||
self.is_causal = False
|
||||
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
||||
self.k_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cross_attention_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[EncoderDecoderCache] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
cross_shape = (*cross_attention_states.shape[:-1], -1, self.head_dim)
|
||||
|
||||
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
|
||||
if past_key_values is not None and is_updated:
|
||||
# reuse k, v from the cross-attention cache
|
||||
key_states = past_key_values.cross_attention_cache.key_cache[self.layer_idx]
|
||||
value_states = past_key_values.cross_attention_cache.value_cache[self.layer_idx]
|
||||
else:
|
||||
key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
|
||||
|
||||
if past_key_values is not None:
|
||||
# save all states to the cache
|
||||
key_states, value_states = past_key_values.cross_attention_cache.update(
|
||||
key_states,
|
||||
value_states,
|
||||
self.layer_idx,
|
||||
)
|
||||
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
|
||||
past_key_values.is_updated[self.layer_idx] = True
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape((*input_shape, -1)).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class DiaEncoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: DiaEncoderConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
|
||||
self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=False)
|
||||
self.post_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
|
||||
self.mlp = DiaMLP(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
residual = hidden_states
|
||||
normed_states = self.pre_sa_norm(hidden_states)
|
||||
self_attn_output, self_attn_weights = self.self_attention(
|
||||
normed_states,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + self_attn_output
|
||||
|
||||
residual = hidden_states
|
||||
normed_states = self.post_sa_norm(hidden_states)
|
||||
mlp_out = self.mlp(normed_states)
|
||||
hidden_states = residual + mlp_out
|
||||
|
||||
return hidden_states, self_attn_weights
|
||||
|
||||
|
||||
class DiaEncoder(DiaPreTrainedModel):
|
||||
def __init__(self, config: DiaEncoderConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = nn.ModuleList(
|
||||
[DiaEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
|
||||
self.rotary_embeddings = DiaRotaryEmbedding(config)
|
||||
|
||||
@auto_docstring
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[BaseModelOutput, tuple]:
|
||||
hidden_states = self.embedding(input_ids)
|
||||
|
||||
# RoPE
|
||||
# Note: We expect right padding and hence always generate
|
||||
# the position ids on the fly to reduce preparation overhead
|
||||
position_ids = torch.arange(input_ids.shape[-1], device=input_ids.device)[None, :]
|
||||
position_embeddings = self.rotary_embeddings(hidden_states, position_ids)
|
||||
|
||||
attention_mask = self._update_full_mask(
|
||||
attention_mask,
|
||||
hidden_states,
|
||||
)
|
||||
|
||||
encoder_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
|
||||
for encoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
if output_hidden_states:
|
||||
encoder_states += (hidden_states,)
|
||||
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
|
||||
def _update_full_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, None],
|
||||
inputs_embeds: torch.Tensor,
|
||||
):
|
||||
if attention_mask is not None:
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
attention_mask = attention_mask if 0 in attention_mask else None
|
||||
elif self.config._attn_implementation == "sdpa":
|
||||
# output_attentions=True & head_mask can not be supported when using SDPA, fall back to
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
|
||||
elif self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
|
||||
else:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
|
||||
|
||||
return attention_mask
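As a rough illustration of the SDPA/eager branches above, a minimal sketch (toy values, using the real helper from `transformers.modeling_attn_mask_utils`) of how a 2D padding mask becomes the 4D additive bias:

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

padding_mask = torch.tensor([[1, 1, 1, 0]])                     # last position is padding
bias = _prepare_4d_attention_mask(padding_mask, torch.float32)  # shape (1, 1, 4, 4)
# bias holds 0.0 where attention is allowed and the dtype minimum on the padded key positions,
# i.e. the additive mask that the non-flash attention paths above consume.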
|
||||
|
||||
|
||||
class DiaDecoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: DiaDecoderConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=True)
|
||||
self.cross_attention = DiaCrossAttention(config, layer_idx)
|
||||
self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
|
||||
self.pre_ca_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
|
||||
self.pre_mlp_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
|
||||
self.mlp = DiaMLP(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[EncoderDecoderCache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
self_attn_cache = past_key_values
|
||||
if isinstance(self_attn_cache, EncoderDecoderCache):
|
||||
self_attn_cache = self_attn_cache.self_attention_cache
|
||||
|
||||
residual = hidden_states
|
||||
normed_states = self.pre_sa_norm(hidden_states)
|
||||
self_attn_output, self_attn_weights = self.self_attention(
|
||||
normed_states,
|
||||
position_embeddings,
|
||||
attention_mask,
|
||||
# Needs to be passed as a positional arg so that in-place
# cache updates are carried through correctly (e.g. under compile)
|
||||
self_attn_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + self_attn_output
|
||||
|
||||
residual = hidden_states
|
||||
normed_states = self.pre_ca_norm(hidden_states)
|
||||
cross_states, cross_attn_weights = self.cross_attention(
|
||||
normed_states,
|
||||
encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + cross_states
|
||||
|
||||
residual = hidden_states
|
||||
normed_states = self.pre_mlp_norm(hidden_states)
|
||||
mlp_out = self.mlp(normed_states)
|
||||
hidden_states = residual + mlp_out
|
||||
|
||||
return hidden_states, self_attn_weights, cross_attn_weights
|
||||
|
||||
|
||||
class DiaDecoder(DiaPreTrainedModel):
|
||||
"""Transformer Decoder Stack using DenseGeneral."""
|
||||
|
||||
def __init__(self, config: DiaDecoderConfig):
|
||||
super().__init__(config)
|
||||
self.num_channels = config.num_channels
|
||||
self.vocab_size = config.vocab_size
|
||||
self.embeddings = DiaMultiChannelEmbedding(config)
|
||||
self.rotary_embeddings = DiaRotaryEmbedding(config)
|
||||
self.layers = nn.ModuleList(
|
||||
[DiaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
|
||||
|
||||
@auto_docstring
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[EncoderDecoderCache] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[BaseModelOutputWithPastAndCrossAttentions, tuple]:
|
||||
r"""
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`):
|
||||
The original `decoder_input_ids` in 3D shape to facilitate more efficient computations.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
"""
|
||||
|
||||
batch_size, seq_length = input_ids.size()[:-1]
|
||||
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(
|
||||
past_key_values_length, past_key_values_length + seq_length, device=input_ids.device
|
||||
)
|
||||
if position_ids is None:
|
||||
position_ids = cache_position[None, :]
|
||||
|
||||
# RoPE
|
||||
hidden_states = self.embeddings(input_ids)
|
||||
position_embeddings = self.rotary_embeddings(hidden_states, position_ids)
|
||||
|
||||
if attention_mask is None and not is_torchdynamo_compiling():
|
||||
# required mask seq length can be calculated via length of past cache
|
||||
mask_seq_length = past_key_values_length + seq_length
|
||||
attention_mask = torch.ones(batch_size, mask_seq_length, device=input_ids.device)
|
||||
|
||||
attention_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
)
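# `create_causal_mask` returns the causal mask in whatever format the configured attention
# implementation expects (e.g. a 4D additive mask, a flex-attention BlockMask, or `None` when the
# backend handles causality itself), while still honoring the padding mask built above.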
|
||||
encoder_attention_mask = self._update_cross_attn_mask(
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
hidden_states.shape[:2],
|
||||
hidden_states,
|
||||
)
|
||||
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||
|
||||
for layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
layer_outputs = layer(
|
||||
hidden_states,
|
||||
position_embeddings,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns = all_self_attns + (layer_outputs[1],)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=past_key_values,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
|
||||
def _update_cross_attn_mask(
|
||||
self,
|
||||
encoder_hidden_states: Union[torch.Tensor, None],
|
||||
encoder_attention_mask: Union[torch.Tensor, None],
|
||||
input_shape: torch.Size,
|
||||
inputs_embeds: torch.Tensor,
|
||||
):
|
||||
# expand encoder attention mask
|
||||
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
|
||||
elif self.config._attn_implementation == "sdpa":
|
||||
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
|
||||
encoder_attention_mask,
|
||||
inputs_embeds.dtype,
|
||||
tgt_len=input_shape[-1],
|
||||
)
|
||||
elif self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(encoder_attention_mask, torch.Tensor):
|
||||
encoder_attention_mask = make_flex_block_causal_mask(
|
||||
encoder_attention_mask,
|
||||
query_length=input_shape[-1],
|
||||
is_causal=False,
|
||||
)
|
||||
else:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
encoder_attention_mask = _prepare_4d_attention_mask(
|
||||
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
||||
)
|
||||
|
||||
return encoder_attention_mask
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The bare Dia model outputting raw hidden-states without any specific head on top.
|
||||
"""
|
||||
)
|
||||
class DiaModel(DiaPreTrainedModel):
|
||||
def __init__(self, config: DiaConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.encoder = DiaEncoder(config.encoder_config)
|
||||
self.decoder = DiaDecoder(config.decoder_config)
|
||||
self.post_init()
|
||||
|
||||
def get_encoder(self):
|
||||
return self.encoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.decoder
|
||||
|
||||
@auto_docstring
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||
decoder_position_ids: Optional[torch.LongTensor] = None,
|
||||
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||
encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None,
|
||||
past_key_values: Optional[EncoderDecoderCache] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[tuple, Seq2SeqModelOutput]:
|
||||
r"""
|
||||
decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
|
||||
or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
|
||||
1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
|
||||
the audio input codebooks are flattened into the batch dimension. This also aligns with the flat-
|
||||
tened audio logits which are used to calculate the loss.
|
||||
|
||||
2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of
|
||||
Dia to calculate embeddings and subsequent steps more efficiently.
|
||||
|
||||
If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
|
||||
`(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
|
||||
[`DiaProcessor.__call__`] for more details.
|
||||
|
||||
[What are decoder input IDs?](../glossary#decoder-input-ids)
|
||||
decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
|
||||
Indices of positions of each input sequence tokens in the position embeddings.
|
||||
Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
|
||||
|
||||
[What are position IDs?](../glossary#position-ids)
|
||||
"""
|
||||
|
||||
if input_ids is None and encoder_outputs is None:
|
||||
raise ValueError(
|
||||
"You should either provide text ids or the cached text encodings. Neither has been found."
|
||||
)
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
if self.is_gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
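# EncoderDecoderCache(self_attention_cache, cross_attention_cache): the first cache grows with the
# generated audio tokens, the second stores the encoder key/value states computed once.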
|
||||
|
||||
if encoder_outputs is None:
|
||||
encoder_outputs = self.encoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
**kwargs,
|
||||
)
|
||||
# If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
|
||||
elif not isinstance(encoder_outputs, BaseModelOutput):
|
||||
encoder_outputs = BaseModelOutput(
|
||||
last_hidden_state=encoder_outputs[0],
|
||||
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
|
||||
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
||||
)
|
||||
|
||||
# On default we initialize the decoder with bos tokens if nothing has been provided
|
||||
bsz, seq_len, channels = (encoder_outputs[0].shape[0], -1, self.config.decoder_config.num_channels)
|
||||
if decoder_input_ids is None:
|
||||
decoder_input_ids = torch.full(
|
||||
size=(bsz, 1, channels), fill_value=self.config.bos_token_id, device=self.device
|
||||
)
|
||||
# Ensure 3D
|
||||
if decoder_input_ids.ndim == 2:
|
||||
decoder_input_ids = decoder_input_ids.reshape(bsz, channels, seq_len).transpose(1, 2)
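# Shape sketch: e.g. with bsz=2 and 9 codebooks, a flattened (18, seq_len) input becomes
# (2, 9, seq_len) after the reshape and (2, seq_len, 9) after the transpose, which is the 3D
# layout the decoder embedding expects.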
|
||||
|
||||
decoder_outputs = self.decoder(
|
||||
input_ids=decoder_input_ids,
|
||||
position_ids=decoder_position_ids,
|
||||
attention_mask=decoder_attention_mask,
|
||||
encoder_hidden_states=encoder_outputs[0],
|
||||
encoder_attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return Seq2SeqModelOutput(
|
||||
last_hidden_state=decoder_outputs.last_hidden_state,
|
||||
past_key_values=decoder_outputs.past_key_values,
|
||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||
decoder_attentions=decoder_outputs.attentions,
|
||||
cross_attentions=decoder_outputs.cross_attentions,
|
||||
encoder_last_hidden_state=encoder_outputs[0],
|
||||
encoder_hidden_states=encoder_outputs.hidden_states,
|
||||
encoder_attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The Dia model consisting of a (byte) text encoder and audio decoder with a prediction head on top.
|
||||
"""
|
||||
)
|
||||
class DiaForConditionalGeneration(DiaPreTrainedModel, DiaGenerationMixin):
|
||||
base_model_prefix = "model"
|
||||
|
||||
def __init__(self, config: DiaConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.model = DiaModel(config)
|
||||
|
||||
self.num_channels = config.decoder_config.num_channels
|
||||
self.vocab_size = config.decoder_config.vocab_size
|
||||
self.logits_dense = nn.Linear(
|
||||
config.decoder_config.hidden_size, (self.num_channels * self.vocab_size), bias=False
|
||||
)
|
||||
self.loss_type = "ForMaskedLM"
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_encoder(self):
|
||||
return self.model.get_encoder()
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model.get_decoder()
|
||||
|
||||
@auto_docstring
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||
decoder_position_ids: Optional[torch.LongTensor] = None,
|
||||
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||
encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None,
|
||||
past_key_values: Optional[EncoderDecoderCache] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[tuple, Seq2SeqLMOutput]:
|
||||
r"""
|
||||
decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
|
||||
or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
|
||||
1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
|
||||
the audio input codebooks are flattened into the batch dimension. This also aligns with the flat-
|
||||
tened audio logits which are used to calculate the loss.
|
||||
|
||||
2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of
|
||||
Dia to calculate embeddings and subsequent steps more efficiently.
|
||||
|
||||
If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
|
||||
`(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
|
||||
[`DiaProcessor.__call__`] for more details.
|
||||
|
||||
[What are decoder input IDs?](../glossary#decoder-input-ids)
|
||||
decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
|
||||
Indices of positions of each input sequence tokens in the position embeddings.
|
||||
Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
|
||||
|
||||
[What are position IDs?](../glossary#position-ids)
|
||||
labels (`torch.LongTensor` of shape `(batch_size * num_codebooks,)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in
|
||||
`[0, ..., config.decoder_config.vocab_size - 1]` or -100. Tokens with indices set to `-100`
|
||||
are ignored (masked).
|
||||
"""
|
||||
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_position_ids=decoder_position_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
last_hidden_state = outputs[0]
|
||||
batch_size = last_hidden_state.shape[0]
|
||||
# 3D <-> 2D makes it necessary to prioritize channel dim
|
||||
audio_logits = (
|
||||
self.logits_dense(last_hidden_state)
|
||||
.view((batch_size, -1, self.num_channels, self.vocab_size))
|
||||
.transpose(1, 2)
|
||||
.contiguous()
|
||||
.view(batch_size * self.num_channels, -1, self.vocab_size)
|
||||
)
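# Shape flow: (batch, seq, hidden) -> dense -> (batch, seq, channels * vocab)
# -> view (batch, seq, channels, vocab) -> transpose (batch, channels, seq, vocab)
# -> view (batch * channels, seq, vocab), matching the flattened codebook/label layout.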
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=audio_logits, labels=labels, vocab_size=self.vocab_size, **kwargs)
|
||||
|
||||
return Seq2SeqLMOutput(
|
||||
loss=loss,
|
||||
logits=audio_logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
decoder_hidden_states=outputs.decoder_hidden_states,
|
||||
decoder_attentions=outputs.decoder_attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
||||
encoder_hidden_states=outputs.encoder_hidden_states,
|
||||
encoder_attentions=outputs.encoder_attentions,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["DiaModel", "DiaPreTrainedModel", "DiaForConditionalGeneration"]
|
src/transformers/models/dia/modular_dia.py (new file)
@@ -0,0 +1,789 @@
# coding=utf-8
|
||||
# Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch Dia model."""
|
||||
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...cache_utils import DynamicCache, EncoderDecoderCache
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_attn_mask_utils import (
|
||||
_prepare_4d_attention_mask,
|
||||
_prepare_4d_attention_mask_for_sdpa,
|
||||
)
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
Seq2SeqLMOutput,
|
||||
Seq2SeqModelOutput,
|
||||
)
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, is_torchdynamo_compiling, logging
|
||||
from ..llama.modeling_llama import (
|
||||
LlamaAttention,
|
||||
LlamaRMSNorm,
|
||||
LlamaRotaryEmbedding,
|
||||
eager_attention_forward,
|
||||
)
|
||||
from ..phi3.modeling_phi3 import Phi3MLP
|
||||
from .configuration_dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig
|
||||
from .generation_dia import DiaGenerationMixin
|
||||
|
||||
|
||||
if is_torch_flex_attn_available():
|
||||
from ...integrations.flex_attention import make_flex_block_causal_mask
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class DiaPreTrainedModel(PreTrainedModel):
|
||||
config_class = DiaConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
main_input_name = "input_ids"
|
||||
_no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, DiaRMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
|
||||
class DiaMultiChannelEmbedding(nn.Module):
|
||||
"""In order to efficiently compute the audio embedding from the 9 different channels,
|
||||
we vectorize the embedding process by using a single embedding layer and an offset.
|
||||
Example:
|
||||
- num_embeds = 4
|
||||
- vocab_size = 8
|
||||
- num_channels = 3
|
||||
We would have offsets = [0, 8, 16]
|
||||
If audio_codes = [0, 1, 2, 3], [1, 3, 4, 7], [5, 6, 7, 8],
|
||||
then tokens = audio_codes + offsets
|
||||
= [0, 1, 2, 3, 9, 11, 12, 15, 21, 22, 23, 24]
|
||||
This allows us to use a single embedding layer for all channels.
|
||||
"""
|
||||
|
||||
def __init__(self, config: DiaDecoderConfig):
|
||||
super().__init__()
|
||||
self.embed = nn.Embedding(config.vocab_size * config.num_channels, config.hidden_size)
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_channels = config.num_channels
|
||||
offsets = torch.arange(config.num_channels, dtype=torch.long) * config.vocab_size # (C,)
|
||||
self.register_buffer("offsets", offsets, persistent=False)
|
||||
|
||||
def forward(self, audio_codes: torch.Tensor) -> torch.Tensor:
|
||||
tokens = (audio_codes + self.offsets.to(audio_codes.device)).squeeze(1)
|
||||
embeds = self.embed(tokens).view(tokens.shape[0], audio_codes.shape[1], -1, self.hidden_size)
|
||||
return embeds.sum(dim=2)
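A minimal, self-contained sketch of the offset trick documented above, using toy sizes (vocab_size=8, num_channels=3, hidden_size=4) instead of the actual Dia configuration:

import torch
from torch import nn

vocab_size, num_channels, hidden_size = 8, 3, 4
embed = nn.Embedding(vocab_size * num_channels, hidden_size)
offsets = torch.arange(num_channels) * vocab_size           # tensor([0, 8, 16])

audio_codes = torch.tensor([[[0, 1, 5], [1, 3, 6], [2, 4, 7], [3, 7, 7]]])  # (batch=1, seq=4, channels=3)
tokens = audio_codes + offsets      # each channel now indexes its own slice of the shared table
summed = embed(tokens).sum(dim=2)   # (1, 4, hidden_size): one summed embedding per audio frame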
|
||||
|
||||
|
||||
class DiaMLP(Phi3MLP):
|
||||
pass
|
||||
|
||||
|
||||
class DiaRMSNorm(LlamaRMSNorm):
|
||||
pass
|
||||
|
||||
|
||||
class DiaRotaryEmbedding(LlamaRotaryEmbedding):
|
||||
pass
|
||||
|
||||
|
||||
class DiaSelfAttention(LlamaAttention, nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: Union[DiaEncoderConfig, DiaDecoderConfig], layer_idx: int, is_causal: bool = False):
|
||||
nn.Module.__init__(self)
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = self.config.num_attention_heads
|
||||
self.num_key_value_heads = self.config.num_key_value_heads or self.num_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
self.head_dim = getattr(config, "head_dim", config.hidden_size // self.num_heads)
|
||||
self.scaling = 1
|
||||
self.attention_dropout = 0.0
|
||||
self.is_causal = is_causal
|
||||
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
||||
|
||||
|
||||
class DiaCrossAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: DiaDecoderConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.hidden_size = config.hidden_size
|
||||
self.cross_hidden_size = config.cross_hidden_size
|
||||
self.num_heads = self.config.cross_num_attention_heads
|
||||
self.num_key_value_heads = self.config.cross_num_key_value_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
self.head_dim = config.cross_head_dim
|
||||
self.scaling = 1
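# scaling=1 means the attention scores are not rescaled by 1/sqrt(head_dim); the value is passed
# unchanged to the attention call below via `scaling=self.scaling`.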
|
||||
self.attention_dropout = 0.0
|
||||
self.is_causal = False
|
||||
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
||||
self.k_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cross_attention_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[EncoderDecoderCache] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
cross_shape = (*cross_attention_states.shape[:-1], -1, self.head_dim)
|
||||
|
||||
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
|
||||
if past_key_values is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_values.cross_attention_cache.key_cache[self.layer_idx]
|
||||
value_states = past_key_values.cross_attention_cache.value_cache[self.layer_idx]
|
||||
else:
|
||||
key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
|
||||
|
||||
if past_key_values is not None:
|
||||
# save all states to the cache
|
||||
key_states, value_states = past_key_values.cross_attention_cache.update(
|
||||
key_states,
|
||||
value_states,
|
||||
self.layer_idx,
|
||||
)
|
||||
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
|
||||
past_key_values.is_updated[self.layer_idx] = True
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape((*input_shape, -1)).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class DiaEncoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: DiaEncoderConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
|
||||
self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=False)
|
||||
self.post_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
|
||||
self.mlp = DiaMLP(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
residual = hidden_states
|
||||
normed_states = self.pre_sa_norm(hidden_states)
|
||||
self_attn_output, self_attn_weights = self.self_attention(
|
||||
normed_states,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + self_attn_output
|
||||
|
||||
residual = hidden_states
|
||||
normed_states = self.post_sa_norm(hidden_states)
|
||||
mlp_out = self.mlp(normed_states)
|
||||
hidden_states = residual + mlp_out
|
||||
|
||||
return hidden_states, self_attn_weights
|
||||
|
||||
|
||||
class DiaEncoder(DiaPreTrainedModel):
|
||||
def __init__(self, config: DiaEncoderConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = nn.ModuleList(
|
||||
[DiaEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
|
||||
self.rotary_embeddings = DiaRotaryEmbedding(config)
|
||||
|
||||
@auto_docstring
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[BaseModelOutput, tuple]:
|
||||
hidden_states = self.embedding(input_ids)
|
||||
|
||||
# RoPE
|
||||
# Note: We expect right padding and hence always generate
|
||||
# the position ids on the fly to reduce preparation overhead
|
||||
position_ids = torch.arange(input_ids.shape[-1], device=input_ids.device)[None, :]
|
||||
position_embeddings = self.rotary_embeddings(hidden_states, position_ids)
|
||||
|
||||
attention_mask = self._update_full_mask(
|
||||
attention_mask,
|
||||
hidden_states,
|
||||
)
|
||||
|
||||
encoder_states = () if output_hidden_states else None
|
||||
all_attentions = () if output_attentions else None
|
||||
|
||||
for encoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
if output_hidden_states:
|
||||
encoder_states += (hidden_states,)
|
||||
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
|
||||
def _update_full_mask(
|
||||
self,
|
||||
attention_mask: Union[torch.Tensor, None],
|
||||
inputs_embeds: torch.Tensor,
|
||||
):
|
||||
if attention_mask is not None:
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
attention_mask = attention_mask if 0 in attention_mask else None
|
||||
elif self.config._attn_implementation == "sdpa":
|
||||
# output_attentions=True & head_mask can not be supported when using SDPA, fall back to
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
|
||||
elif self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(attention_mask, torch.Tensor):
|
||||
attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
|
||||
else:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
|
||||
|
||||
return attention_mask
|
||||
|
||||
|
||||
class DiaDecoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: DiaDecoderConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=True)
|
||||
self.cross_attention = DiaCrossAttention(config, layer_idx)
|
||||
self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
|
||||
self.pre_ca_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
|
||||
self.pre_mlp_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
|
||||
self.mlp = DiaMLP(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[EncoderDecoderCache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
self_attn_cache = past_key_values
|
||||
if isinstance(self_attn_cache, EncoderDecoderCache):
|
||||
self_attn_cache = self_attn_cache.self_attention_cache
|
||||
|
||||
residual = hidden_states
|
||||
normed_states = self.pre_sa_norm(hidden_states)
|
||||
self_attn_output, self_attn_weights = self.self_attention(
|
||||
normed_states,
|
||||
position_embeddings,
|
||||
attention_mask,
|
||||
# Needs to be passed as a positional arg so that in-place
# cache updates are carried through correctly (e.g. under compile)
|
||||
self_attn_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + self_attn_output
|
||||
|
||||
residual = hidden_states
|
||||
normed_states = self.pre_ca_norm(hidden_states)
|
||||
cross_states, cross_attn_weights = self.cross_attention(
|
||||
normed_states,
|
||||
encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + cross_states
|
||||
|
||||
residual = hidden_states
|
||||
normed_states = self.pre_mlp_norm(hidden_states)
|
||||
mlp_out = self.mlp(normed_states)
|
||||
hidden_states = residual + mlp_out
|
||||
|
||||
return hidden_states, self_attn_weights, cross_attn_weights
|
||||
|
||||
|
||||
class DiaDecoder(DiaPreTrainedModel):
|
||||
"""Transformer Decoder Stack using DenseGeneral."""
|
||||
|
||||
def __init__(self, config: DiaDecoderConfig):
|
||||
super().__init__(config)
|
||||
self.num_channels = config.num_channels
|
||||
self.vocab_size = config.vocab_size
|
||||
self.embeddings = DiaMultiChannelEmbedding(config)
|
||||
self.rotary_embeddings = DiaRotaryEmbedding(config)
|
||||
self.layers = nn.ModuleList(
|
||||
[DiaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
|
||||
|
||||
@auto_docstring
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[EncoderDecoderCache] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_hidden_states: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[BaseModelOutputWithPastAndCrossAttentions, tuple]:
|
||||
r"""
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`):
|
||||
The original `decoder_input_ids` in 3D shape to facilitate more efficient computations.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
"""
|
||||
|
||||
batch_size, seq_length = input_ids.size()[:-1]
|
||||
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(
|
||||
past_key_values_length, past_key_values_length + seq_length, device=input_ids.device
|
||||
)
|
||||
if position_ids is None:
|
||||
position_ids = cache_position[None, :]
|
||||
|
||||
# RoPE
|
||||
hidden_states = self.embeddings(input_ids)
|
||||
position_embeddings = self.rotary_embeddings(hidden_states, position_ids)
|
||||
|
||||
if attention_mask is None and not is_torchdynamo_compiling():
|
||||
# required mask seq length can be calculated via length of past cache
|
||||
mask_seq_length = past_key_values_length + seq_length
|
||||
attention_mask = torch.ones(batch_size, mask_seq_length, device=input_ids.device)
|
||||
|
||||
attention_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
encoder_attention_mask = self._update_cross_attn_mask(
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
hidden_states.shape[:2],
|
||||
hidden_states,
|
||||
)
|
||||
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||
|
||||
for layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
layer_outputs = layer(
|
||||
hidden_states,
|
||||
position_embeddings,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns = all_self_attns + (layer_outputs[1],)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=past_key_values,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
|
||||
def _update_cross_attn_mask(
|
||||
self,
|
||||
encoder_hidden_states: Union[torch.Tensor, None],
|
||||
encoder_attention_mask: Union[torch.Tensor, None],
|
||||
input_shape: torch.Size,
|
||||
inputs_embeds: torch.Tensor,
|
||||
):
|
||||
# expand encoder attention mask
|
||||
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
|
||||
elif self.config._attn_implementation == "sdpa":
|
||||
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
|
||||
encoder_attention_mask,
|
||||
inputs_embeds.dtype,
|
||||
tgt_len=input_shape[-1],
|
||||
)
|
||||
elif self.config._attn_implementation == "flex_attention":
|
||||
if isinstance(encoder_attention_mask, torch.Tensor):
|
||||
encoder_attention_mask = make_flex_block_causal_mask(
|
||||
encoder_attention_mask,
|
||||
query_length=input_shape[-1],
|
||||
is_causal=False,
|
||||
)
|
||||
else:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
encoder_attention_mask = _prepare_4d_attention_mask(
|
||||
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
||||
)
|
||||
|
||||
return encoder_attention_mask
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The bare Dia model outputting raw hidden-states without any specific head on top.
|
||||
"""
|
||||
)
|
||||
class DiaModel(DiaPreTrainedModel):
|
||||
def __init__(self, config: DiaConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.encoder = DiaEncoder(config.encoder_config)
|
||||
self.decoder = DiaDecoder(config.decoder_config)
|
||||
self.post_init()
|
||||
|
||||
def get_encoder(self):
|
||||
return self.encoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.decoder
|
||||
|
||||
@auto_docstring
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||
decoder_position_ids: Optional[torch.LongTensor] = None,
|
||||
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||
encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None,
|
||||
past_key_values: Optional[EncoderDecoderCache] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[tuple, Seq2SeqModelOutput]:
|
||||
r"""
|
||||
decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
|
||||
or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
|
||||
1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
|
||||
the audio input codebooks are flattened into the batch dimension. This also aligns with the flat-
|
||||
tened audio logits which are used to calculate the loss.
|
||||
|
||||
2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of
|
||||
Dia to calculate embeddings and subsequent steps more efficiently.
|
||||
|
||||
If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
|
||||
`(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
|
||||
[`DiaProcessor.__call__`] for more details.
|
||||
|
||||
[What are decoder input IDs?](../glossary#decoder-input-ids)
|
||||
decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
|
||||
Indices of positions of each input sequence tokens in the position embeddings.
|
||||
Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
|
||||
|
||||
[What are position IDs?](../glossary#position-ids)
|
||||
"""
|
||||
|
||||
if input_ids is None and encoder_outputs is None:
|
||||
raise ValueError(
|
||||
"You should either provide text ids or the cached text encodings. Neither has been found."
|
||||
)
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
if self.is_gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
|
||||
|
||||
if encoder_outputs is None:
|
||||
encoder_outputs = self.encoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
**kwargs,
|
||||
)
|
||||
# If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
|
||||
elif not isinstance(encoder_outputs, BaseModelOutput):
|
||||
encoder_outputs = BaseModelOutput(
|
||||
last_hidden_state=encoder_outputs[0],
|
||||
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
|
||||
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
||||
)
|
||||
|
||||
# On default we initialize the decoder with bos tokens if nothing has been provided
|
||||
bsz, seq_len, channels = (encoder_outputs[0].shape[0], -1, self.config.decoder_config.num_channels)
|
||||
if decoder_input_ids is None:
|
||||
decoder_input_ids = torch.full(
|
||||
size=(bsz, 1, channels), fill_value=self.config.bos_token_id, device=self.device
|
||||
)
|
||||
# Ensure 3D
|
||||
if decoder_input_ids.ndim == 2:
|
||||
decoder_input_ids = decoder_input_ids.reshape(bsz, channels, seq_len).transpose(1, 2)
|
||||
|
||||
decoder_outputs = self.decoder(
|
||||
input_ids=decoder_input_ids,
|
||||
position_ids=decoder_position_ids,
|
||||
attention_mask=decoder_attention_mask,
|
||||
encoder_hidden_states=encoder_outputs[0],
|
||||
encoder_attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return Seq2SeqModelOutput(
|
||||
last_hidden_state=decoder_outputs.last_hidden_state,
|
||||
past_key_values=decoder_outputs.past_key_values,
|
||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||
decoder_attentions=decoder_outputs.attentions,
|
||||
cross_attentions=decoder_outputs.cross_attentions,
|
||||
encoder_last_hidden_state=encoder_outputs[0],
|
||||
encoder_hidden_states=encoder_outputs.hidden_states,
|
||||
encoder_attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The Dia model consisting of a (byte) text encoder and audio decoder with a prediction head on top.
|
||||
"""
|
||||
)
|
||||
class DiaForConditionalGeneration(DiaPreTrainedModel, DiaGenerationMixin):
|
||||
base_model_prefix = "model"
|
||||
|
||||
def __init__(self, config: DiaConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.model = DiaModel(config)
|
||||
|
||||
self.num_channels = config.decoder_config.num_channels
|
||||
self.vocab_size = config.decoder_config.vocab_size
|
||||
self.logits_dense = nn.Linear(
|
||||
config.decoder_config.hidden_size, (self.num_channels * self.vocab_size), bias=False
|
||||
)
|
||||
self.loss_type = "ForMaskedLM"
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_encoder(self):
|
||||
return self.model.get_encoder()
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model.get_decoder()
|
||||
|
||||
@auto_docstring
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||||
decoder_position_ids: Optional[torch.LongTensor] = None,
|
||||
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
||||
encoder_outputs: Optional[Union[BaseModelOutput, tuple]] = None,
|
||||
past_key_values: Optional[EncoderDecoderCache] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[tuple, Seq2SeqLMOutput]:
|
||||
r"""
|
||||
decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
|
||||
or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
|
||||
1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
|
||||
the audio input codebooks are flattened into the batch dimension. This also aligns with the flat-
|
||||
tened audio logits which are used to calculate the loss.
|
||||
|
||||
2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of
|
||||
Dia to calculate embeddings and subsequent steps more efficiently.
|
||||
|
||||
If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
|
||||
`(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
|
||||
[`DiaProcessor.__call__`] for more details.
|
||||
|
||||
[What are decoder input IDs?](../glossary#decoder-input-ids)
|
||||
decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
|
||||
Indices of positions of each input sequence tokens in the position embeddings.
|
||||
Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
|
||||
|
||||
[What are position IDs?](../glossary#position-ids)
|
||||
labels (`torch.LongTensor` of shape `(batch_size * num_codebooks,)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in
|
||||
`[0, ..., config.decoder_config.vocab_size - 1]` or -100. Tokens with indices set to `-100`
|
||||
are ignored (masked).
|
||||
"""
|
||||
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
decoder_position_ids=decoder_position_ids,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
encoder_outputs=encoder_outputs,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
last_hidden_state = outputs[0]
|
||||
batch_size = last_hidden_state.shape[0]
|
||||
# 3D <-> 2D makes it necessary to prioritize channel dim
|
||||
audio_logits = (
|
||||
self.logits_dense(last_hidden_state)
|
||||
.view((batch_size, -1, self.num_channels, self.vocab_size))
|
||||
.transpose(1, 2)
|
||||
.contiguous()
|
||||
.view(batch_size * self.num_channels, -1, self.vocab_size)
|
||||
)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=audio_logits, labels=labels, vocab_size=self.vocab_size, **kwargs)
|
||||
|
||||
return Seq2SeqLMOutput(
|
||||
loss=loss,
|
||||
logits=audio_logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
decoder_hidden_states=outputs.decoder_hidden_states,
|
||||
decoder_attentions=outputs.decoder_attentions,
|
||||
cross_attentions=outputs.cross_attentions,
|
||||
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
||||
encoder_hidden_states=outputs.encoder_hidden_states,
|
||||
encoder_attentions=outputs.encoder_attentions,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["DiaModel", "DiaPreTrainedModel", "DiaForConditionalGeneration"]
|
src/transformers/models/dia/processing_dia.py (new file)
@@ -0,0 +1,484 @@
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Processor class for Dia"""
|
||||
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from ...audio_utils import AudioInput, make_list_of_audio
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...processing_utils import AudioKwargs, ProcessingKwargs, ProcessorMixin, Unpack
|
||||
from ...utils import is_soundfile_available, is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_soundfile_available():
|
||||
import soundfile as sf
|
||||
|
||||
|
||||
class DiaAudioKwargs(AudioKwargs, total=False):
|
||||
bos_token_id: int
|
||||
eos_token_id: int
|
||||
pad_token_id: int
|
||||
delay_pattern: list[int]
|
||||
generation: bool
|
||||
|
||||
|
||||
class DiaProcessorKwargs(ProcessingKwargs, total=False):
|
||||
audio_kwargs: DiaAudioKwargs
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding": True,
|
||||
"padding_side": "right",
|
||||
"add_special_tokens": False,
|
||||
},
|
||||
"audio_kwargs": {
|
||||
"eos_token_id": 1024,
|
||||
"pad_token_id": 1025,
|
||||
"bos_token_id": 1026,
|
||||
"delay_pattern": [0, 8, 9, 10, 11, 12, 13, 14, 15],
|
||||
"generation": True,
|
||||
"sampling_rate": 44100,
|
||||
},
|
||||
"common_kwargs": {"return_tensors": "pt"},
|
||||
}
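A quick sanity check of the defaults above (a sketch, assuming they are left unchanged): the delay pattern is what later determines the number of codebook channels and the maximum delay inside `__call__`:

delay_pattern = [0, 8, 9, 10, 11, 12, 13, 14, 15]
num_channels = len(delay_pattern)   # 9 codebook channels
max_delay = max(delay_pattern)      # up to 15 frames of delay-induced padding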
|
||||
|
||||
|
||||
class DiaProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a Dia processor which wraps a [`DiaFeatureExtractor`], [`DiaTokenizer`], and a [`DacModel`] into
|
||||
a single processor. It inherits, the audio feature extraction, tokenizer, and audio encode/decode functio-
|
||||
nalities. See [`~DiaProcessor.__call__`], [`~DiaProcessor.encode`], and [`~DiaProcessor.decode`] for more
|
||||
information.
|
||||
|
||||
Args:
|
||||
feature_extractor (`DiaFeatureExtractor`):
|
||||
An instance of [`DiaFeatureExtractor`]. The feature extractor is a required input.
|
||||
tokenizer (`DiaTokenizer`):
|
||||
An instance of [`DiaTokenizer`]. The tokenizer is a required input.
|
||||
audio_tokenizer (`DacModel`):
|
||||
An instance of [`DacModel`] used to encode/decode audio into/from codebooks. It is is a required input.
|
||||
"""
|
||||
|
||||
feature_extractor_class = "DiaFeatureExtractor"
|
||||
tokenizer_class = "DiaTokenizer"
|
||||
audio_tokenizer_class = "DacModel"
|
||||
|
||||
def __init__(self, feature_extractor, tokenizer, audio_tokenizer):
|
||||
super().__init__(feature_extractor, tokenizer, audio_tokenizer=audio_tokenizer)
|
||||
|
||||
@property
|
||||
def model_input_names(self):
|
||||
"""
|
||||
We no longer pass the raw audio values but the codebooks encoded by the `audio_tokenizer`.
|
||||
Conventions may differ between audio models due to architectural choices.
|
||||
"""
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
audio_tokenizer_input_names = ["decoder_input_ids", "decoder_attention_mask"]
|
||||
return list(dict.fromkeys(tokenizer_input_names + audio_tokenizer_input_names))
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: Union[str, list[str]],
|
||||
audio: Optional[AudioInput] = None,
|
||||
output_labels: Optional[bool] = False,
|
||||
**kwargs: Unpack[DiaProcessorKwargs],
|
||||
):
|
||||
"""
|
||||
Main method to prepare text(s) and audio to be fed as input to the model. The `audio` argument is
|
||||
forwarded to the DiaFeatureExtractor's [`~DiaFeatureExtractor.__call__`] and subsequently to the
|
||||
DacModel's [`~DacModel.encode`]. The `text` argument to [`~DiaTokenizer.__call__`]. Please refer
|
||||
to the docstring of the above methods for more information.
|
||||
"""
|
||||
if not is_torch_available():
|
||||
raise ValueError(
|
||||
"The `DiaProcessor` relies on the `audio_tokenizer` which requires `torch` but we couldn't "
|
||||
"find it in your environment. You can install torch via `pip install torch`."
|
||||
)
|
||||
|
||||
if text is None:
|
||||
raise ValueError("You need to specify the `text` input to process.")
|
||||
|
||||
output_kwargs = self._merge_kwargs(
|
||||
DiaProcessorKwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
text_kwargs = output_kwargs["text_kwargs"]
|
||||
audio_kwargs = output_kwargs["audio_kwargs"]
|
||||
common_kwargs = output_kwargs["common_kwargs"]
|
||||
|
||||
return_tensors = common_kwargs.pop("return_tensors", None)
|
||||
if return_tensors != "pt":
|
||||
raise ValueError(f"{self.__class__.__name__} only supports `return_tensors='pt'`.")
|
||||
|
||||
data = {}
|
||||
|
||||
# Text
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
|
||||
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
|
||||
|
||||
encodings = self.tokenizer(text, **text_kwargs)
|
||||
data.update(encodings)
|
||||
|
||||
# Audio
|
||||
delay_pattern = audio_kwargs.pop("delay_pattern", None)
|
||||
audio_bos_token_id = audio_kwargs.pop("bos_token_id", None)
|
||||
audio_eos_token_id = audio_kwargs.pop("eos_token_id", None)
|
||||
audio_pad_token_id = audio_kwargs.pop("pad_token_id", None)
|
||||
generation = audio_kwargs.pop("generation", True)
|
||||
if (
|
||||
audio_bos_token_id is None
|
||||
or audio_eos_token_id is None
|
||||
or audio_pad_token_id is None
|
||||
or delay_pattern is None
|
||||
):
|
||||
raise ValueError(
|
||||
"To enable processing for Dia, we need the `bos_token_id`, `eos_token_id`, "
|
||||
"`pad_token_id`, and `delay_pattern`. You may have accidentally overwritten one of those."
|
||||
)
|
||||
|
||||
if generation and output_labels:
|
||||
raise ValueError(
|
||||
f"Labels with `generation` is incompatible, got generation={generation}, output_labels={output_labels}."
|
||||
)
|
||||
|
||||
batch_size = data["input_ids"].shape[0]
|
||||
num_channels = len(delay_pattern)
|
||||
max_delay = max(delay_pattern)
|
||||
|
||||
# Voice cloning generation / general training
|
||||
if audio is not None:
|
||||
audio = make_list_of_audio(audio)
|
||||
input_audios = self.feature_extractor(audio, **audio_kwargs)
|
||||
|
||||
compression_rate = math.prod(self.audio_tokenizer.config.downsampling_ratios)
|
||||
max_encoded_sequence_len = input_audios["padding_mask"][0].shape[-1] // compression_rate
|
||||
|
||||
decoder_input_ids = []
|
||||
decoder_attention_mask = []
|
||||
# TODO: dac with batching is currently broken, but non-batch is working
|
||||
# refer to https://gist.github.com/vasqu/643a45b680cf39fd7467271ee2eb6f80 for a validation script
|
||||
for padding_mask, audio in zip(input_audios["padding_mask"], input_audios["input_values"]):
|
||||
# get current length with hop length in mind (as if it were sampled as a single audio)
|
||||
base_pad_len = self.feature_extractor.hop_length
|
||||
current_audio_len = math.ceil(padding_mask.sum(dim=-1) / base_pad_len) * base_pad_len
|
||||
|
||||
encoded_sequence_len = current_audio_len // compression_rate
|
||||
padding_len = max_encoded_sequence_len - encoded_sequence_len
|
||||
|
||||
# compute non-padded forward pass; one extra bos (and eos if training) is added
|
||||
with torch.no_grad():
|
||||
audio = audio[None, ..., :current_audio_len].to(self.audio_tokenizer.device)
|
||||
input_ids = self.audio_tokenizer.encode(audio).audio_codes.transpose(1, 2)
|
||||
|
||||
if not generation:
|
||||
input_ids = torch.nn.functional.pad(
|
||||
input_ids, pad=(0, 0, 0, 1, 0, 0), mode="constant", value=audio_eos_token_id
|
||||
)
|
||||
|
||||
# apply padding
|
||||
# +1 for the bos within the real sequence
|
||||
input_ids = torch.nn.functional.pad(
|
||||
input_ids, pad=(0, 0, padding_len + 1, 0, 0, 0), mode="constant", value=audio_bos_token_id
|
||||
)
|
||||
num_valid_inputs = encoded_sequence_len + 1 + max_delay # sequence + bos + delay
|
||||
num_valid_inputs += 0 if generation else 1 # eos if training
|
||||
attention_mask = torch.tensor([0] * padding_len + [1] * num_valid_inputs, dtype=torch.long)[None, :]
|
||||
|
||||
decoder_input_ids.append(input_ids)
|
||||
decoder_attention_mask.append(attention_mask)
|
||||
|
||||
decoder_input_ids = torch.cat(decoder_input_ids, dim=0)
|
||||
decoder_attention_mask = torch.cat(decoder_attention_mask, dim=0)
|
||||
# TTS generation
|
||||
elif generation:
|
||||
# all bos to start with TTS
|
||||
decoder_input_ids = torch.full((batch_size, 1, num_channels), audio_bos_token_id, dtype=torch.long)
|
||||
|
||||
# we preemptively add the delay
|
||||
decoder_attention_mask = torch.ones(size=(batch_size, 1 + max_delay), dtype=torch.long)
|
||||
else:
|
||||
raise ValueError("If you try to train, you should provide audio data as well.")
|
||||
|
||||
if batch_size != decoder_input_ids.shape[0]:
|
||||
raise ValueError(
|
||||
f"Need the same amount of samples for both text and audio, but got text samples={batch_size} and "
|
||||
f"audio samples = {decoder_input_ids.shape[0]} instead."
|
||||
)
|
||||
|
||||
# prepare shift indices per delay
|
||||
max_seq_len = decoder_attention_mask.shape[-1]
|
||||
max_audio_len = max_seq_len - max_delay
|
||||
precomputed_idx = self.build_indices(
|
||||
bsz=batch_size,
|
||||
seq_len=max_seq_len,
|
||||
num_channels=num_channels,
|
||||
delay_pattern=delay_pattern,
|
||||
revert=False,
|
||||
)
|
||||
|
||||
# create delay pattern input
|
||||
# the pad token will be used for masking which input is valid for prediction during generation
|
||||
prefill = torch.full(
|
||||
(batch_size, max_seq_len, num_channels),
|
||||
fill_value=audio_pad_token_id,
|
||||
dtype=torch.int,
|
||||
)
|
||||
prefill[:, :max_audio_len] = decoder_input_ids
|
||||
|
||||
delayed_decoder_input_ids = self.apply_audio_delay(
|
||||
audio=prefill,
|
||||
pad_token_id=audio_pad_token_id,
|
||||
bos_token_id=audio_bos_token_id,
|
||||
precomputed_idx=precomputed_idx,
|
||||
)
|
||||
|
||||
data.update({"decoder_input_ids": delayed_decoder_input_ids, "decoder_attention_mask": decoder_attention_mask})
|
||||
|
||||
if output_labels:
|
||||
# Base idea is to shift on the sequence dim
|
||||
labels = data["decoder_input_ids"].clone()[:, 1:]
|
||||
labels[labels == audio_pad_token_id] = -100
|
||||
labels[labels == audio_bos_token_id] = -100
|
||||
|
||||
data["labels"] = labels.transpose(1, 2).reshape(batch_size * num_channels, -1).contiguous().long()
|
||||
data["decoder_input_ids"] = data["decoder_input_ids"][:, :-1]
|
||||
data["decoder_attention_mask"] = data["decoder_attention_mask"][:, :-1]
|
||||
|
||||
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
def batch_decode(
|
||||
self,
|
||||
decoder_input_ids: "torch.Tensor",
|
||||
audio_prompt_len: Optional[int] = None,
|
||||
**kwargs: Unpack[DiaProcessorKwargs],
|
||||
) -> list["torch.Tensor"]:
|
||||
"""
|
||||
Decodes a batch of audio codebook sequences into their respective audio waveforms via the
|
||||
`audio_tokenizer`. See [`~DacModel.decode`] for more information.
|
||||
|
||||
Args:
|
||||
decoder_input_ids (`torch.Tensor`): The complete output sequence of the decoder.
|
||||
audio_prompt_len (`int`, *optional*): The length of the audio prefix (e.g. when using voice cloning).
|
||||
"""
|
||||
output_kwargs = self._merge_kwargs(
|
||||
DiaProcessorKwargs,
|
||||
**kwargs,
|
||||
)
|
||||
audio_kwargs = output_kwargs["audio_kwargs"]
|
||||
|
||||
delay_pattern = audio_kwargs.pop("delay_pattern", None)
|
||||
audio_bos_token_id = audio_kwargs.pop("bos_token_id", None)
|
||||
audio_pad_token_id = audio_kwargs.pop("pad_token_id", None)
|
||||
if audio_bos_token_id is None or audio_pad_token_id is None or delay_pattern is None:
|
||||
raise ValueError(
|
||||
"To enable decoding for Dia, we need the `bos_token_id`, `pad_token_id`, "
|
||||
"and `delay_pattern`. You may have accidentally overwritten one of those."
|
||||
)
|
||||
|
||||
# either decode the whole audio sequence or only the generated parts
|
||||
if audio_prompt_len is not None:
|
||||
audio_prompt_len = torch.tensor(audio_prompt_len, device=decoder_input_ids.device, dtype=torch.long)
|
||||
start_of_generation_idx = audio_prompt_len[None].expand(decoder_input_ids.shape[0])
|
||||
else:
|
||||
start_of_generation_idx = (decoder_input_ids[:, :, 0] == audio_bos_token_id).sum(dim=-1)
|
||||
# -1 for the eos token
|
||||
end_of_generation_idx = (
|
||||
decoder_input_ids.shape[1] - (decoder_input_ids[:, :, 0] == audio_pad_token_id).sum(dim=-1) - 1
|
||||
)
|
||||
|
||||
# revert delay
|
||||
bsz, seq_len, num_channels = decoder_input_ids.shape
|
||||
precomputed_idx = self.build_indices(
|
||||
bsz=bsz,
|
||||
seq_len=seq_len,
|
||||
num_channels=num_channels,
|
||||
delay_pattern=delay_pattern,
|
||||
revert=True,
|
||||
)
|
||||
|
||||
output_sequences = self.apply_audio_delay(
|
||||
audio=decoder_input_ids,
|
||||
# We do not care about these values as we cut them out
|
||||
# with `start_of_generation_idx` and `end_of_generation_idx`
|
||||
pad_token_id=-1,
|
||||
bos_token_id=-1,
|
||||
precomputed_idx=precomputed_idx,
|
||||
).transpose(1, 2)
|
||||
|
||||
# retrieve the correct sequences each
|
||||
audios = []
|
||||
# TODO: see above, dac doesn't work in batches yet
|
||||
with torch.no_grad():
|
||||
for i in range(start_of_generation_idx.shape[0]):
|
||||
output_i = output_sequences[i, :, start_of_generation_idx[i] : end_of_generation_idx[i]][None, ...]
|
||||
output_i = output_i.to(self.audio_tokenizer.device)
|
||||
audio_i = self.audio_tokenizer.decode(audio_codes=output_i).audio_values.cpu().squeeze()
|
||||
audios.append(audio_i)
|
||||
|
||||
return audios
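
# When `audio_prompt_len` is not given, the start of generation is inferred per sample
# by counting the BOS tokens in channel 0, and the end excludes trailing PAD tokens and
# the final EOS, so only the newly generated codebook frames are decoded to waveforms.
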
|
||||
|
||||
def decode(
|
||||
self,
|
||||
decoder_input_ids: "torch.Tensor",
|
||||
audio_prompt_len: Optional[int] = None,
|
||||
**kwargs: Unpack[DiaProcessorKwargs],
|
||||
) -> "torch.Tensor":
|
||||
"""
|
||||
Decodes a single sequence of audio codebooks into the respective audio waveform via the
|
||||
`audio_tokenizer`. See [`~DacModel.decode`] and [`~DiaProcessor.batch_decode`] for more information.
|
||||
"""
|
||||
if decoder_input_ids.shape[0] != 1:
|
||||
raise ValueError(
|
||||
f"Expecting a single output to be decoded but received {decoder_input_ids.shape[0]} samples instead."
|
||||
)
|
||||
|
||||
return self.batch_decode(decoder_input_ids, audio_prompt_len, **kwargs)[0]
|
||||
|
||||
def get_audio_prompt_len(
|
||||
self,
|
||||
decoder_attention_mask: "torch.Tensor",
|
||||
**kwargs: Unpack[DiaProcessorKwargs],
|
||||
) -> int:
|
||||
"""Utility function to get the audio prompt length."""
|
||||
output_kwargs = self._merge_kwargs(
|
||||
DiaProcessorKwargs,
|
||||
**kwargs,
|
||||
)
|
||||
audio_kwargs = output_kwargs["audio_kwargs"]
|
||||
|
||||
delay_pattern = audio_kwargs.pop("delay_pattern", None)
|
||||
if delay_pattern is None:
|
||||
raise ValueError(
|
||||
"To enable the utility of retrieving the prompt length for Dia, we need the "
|
||||
"`delay_pattern`. You may have accidentally overwritten this."
|
||||
)
|
||||
return decoder_attention_mask.shape[1] - max(delay_pattern)
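
# For example, with the default delay pattern (max delay of 15), a `decoder_attention_mask`
# of length 115 corresponds to an audio prompt of 115 - 15 = 100 codebook frames.
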
|
||||
|
||||
# Copied from transformers.models.csm.processing_csm.CsmProcessor.save_audio with Csm->Dia
|
||||
def save_audio(
|
||||
self,
|
||||
audio: AudioInput,
|
||||
saving_path: Union[str, Path, list[Union[str, Path]]],
|
||||
**kwargs: Unpack[DiaProcessorKwargs],
|
||||
):
|
||||
# TODO: @eustlb, this should be in AudioProcessor
|
||||
if not is_soundfile_available():
|
||||
raise ImportError("Please install `soundfile` to save audio files.")
|
||||
|
||||
# ensure correct audio input
|
||||
audio = make_list_of_audio(audio)
|
||||
|
||||
# ensure correct saving path
|
||||
if isinstance(saving_path, (str, Path)):
|
||||
saving_path = [saving_path]
|
||||
elif not (isinstance(saving_path, (list, tuple)) and all(isinstance(p, (str, Path)) for p in saving_path)):
|
||||
raise ValueError("Invalid input path. Please provide a string, or a list of strings")
|
||||
|
||||
if len(audio) != len(saving_path):
|
||||
raise ValueError("The number of audio and saving paths must be the same")
|
||||
|
||||
output_kwargs = self._merge_kwargs(
|
||||
DiaProcessorKwargs,
|
||||
**kwargs,
|
||||
)
|
||||
audio_kwargs = output_kwargs["audio_kwargs"]
|
||||
sampling_rate = audio_kwargs["sampling_rate"]
|
||||
|
||||
for audio_value, p in zip(audio, saving_path):
|
||||
if isinstance(audio_value, torch.Tensor):
|
||||
audio_value = audio_value.cpu().float().numpy()
|
||||
sf.write(p, audio_value, sampling_rate)
|
||||
|
||||
@staticmethod
|
||||
def build_indices(
|
||||
bsz: int,
|
||||
seq_len: int,
|
||||
num_channels: int,
|
||||
delay_pattern: list[int],
|
||||
revert: bool = False,
|
||||
) -> tuple["torch.Tensor", "torch.Tensor"]:
|
||||
"""
|
||||
Precompute (sequence_idx, all_idx) so that out[seq, channel] = in[seq - delay[channel], channel]
|
||||
or in[seq, channel] = out[seq + delay[channel], channel] if `revert`.
|
||||
Negative sequence_idx => BOS; sequence_idx >= seq_len => PAD.
|
||||
"""
|
||||
delay_array = torch.tensor(delay_pattern, dtype=torch.int32)
|
||||
|
||||
# (0..seq_len-1)
|
||||
sequence_idx = torch.arange(seq_len, dtype=torch.int32)[None, :].expand(bsz, seq_len)[..., None]
|
||||
# + or - delay depending if we delay or revert the delay
|
||||
if not revert:
|
||||
sequence_idx = sequence_idx - delay_array[None, None, :]
|
||||
else:
|
||||
sequence_idx = sequence_idx + delay_array[None, None, :]
|
||||
# if delay goes over the range we clamp back to valid values
|
||||
valid_sequence_idx = torch.clamp(sequence_idx, 0, seq_len - 1)
|
||||
|
||||
batch_idx = torch.arange(bsz, dtype=torch.int32)[:, None, None].expand(bsz, seq_len, num_channels)
|
||||
channel_idx = torch.arange(num_channels, dtype=torch.int32)[None, None, :].expand(bsz, seq_len, num_channels)
|
||||
|
||||
all_idx = torch.stack(
|
||||
[batch_idx.reshape(-1), valid_sequence_idx.reshape(-1), channel_idx.reshape(-1)],
|
||||
dim=1,
|
||||
).long()
|
||||
|
||||
return sequence_idx, all_idx
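
# Worked example: with delay_pattern=[0, 2] and seq_len=4, channel 1 gets
# sequence_idx=[-2, -1, 0, 1] when applying the delay (revert=False), so positions 0-1
# become BOS and positions 2-3 read the undelayed codes shifted right by two steps.
# With revert=True the indices are [2, 3, 4, 5], so positions 2-3 exceed seq_len and
# are mapped to PAD by `apply_audio_delay`.
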
|
||||
|
||||
@staticmethod
|
||||
def apply_audio_delay(
|
||||
audio: "torch.Tensor",
|
||||
pad_token_id: int,
|
||||
bos_token_id: int,
|
||||
precomputed_idx: tuple["torch.Tensor", "torch.Tensor"],
|
||||
) -> "torch.Tensor":
|
||||
"""
|
||||
Applies or reverts the delay pattern to batched audio tokens using precomputed indices,
|
||||
inserting BOS where sequence_idx < 0 and PAD where sequence_idx >= seq_len.
|
||||
|
||||
Args:
|
||||
audio: audio tokens of shape [bsz, seq_len, num_channels]
|
||||
pad_token_id: the PAD token
|
||||
bos_token_id: the BOS token
|
||||
precomputed_idx: from `build_indices`
|
||||
|
||||
Returns:
|
||||
final_audio: delayed or reverted audio tokens of shape [bsz, seq_len, num_channels]
|
||||
"""
|
||||
# Move everything to the same device
|
||||
device = audio.device
|
||||
sequence_idx, all_idx = precomputed_idx
|
||||
sequence_idx = sequence_idx.to(device)
|
||||
all_idx = all_idx.to(device)
|
||||
|
||||
# Gather per precomputed indices
|
||||
batch_idx, valid_sequence_idx, channel_idx = torch.unbind(all_idx, dim=-1)
|
||||
gathered_audio = audio[batch_idx, valid_sequence_idx, channel_idx].view(audio.size())
|
||||
|
||||
# Mask according to negative sequence_idx => BOS; sequence_idx >= seq_len => PAD
|
||||
mask_bos = sequence_idx < 0
|
||||
mask_pad = sequence_idx >= audio.shape[1]
|
||||
final_audio = torch.where(mask_bos, bos_token_id, torch.where(mask_pad, pad_token_id, gathered_audio))
|
||||
|
||||
return final_audio
|
||||
|
||||
|
||||
__all__ = ["DiaProcessor"]
|
src/transformers/models/dia/tokenization_dia.py (new file, 118 lines)
@ -0,0 +1,118 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tokenization class for Dia."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class DiaTokenizer(PreTrainedTokenizer):
|
||||
"""
|
||||
Construct a Dia tokenizer. Dia simply uses raw bytes utf-8 encoding except for special tokens `[S1]` and `[S2]`.
|
||||
|
||||
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should
|
||||
refer to this superclass for more information regarding those methods.
|
||||
|
||||
Args:
|
||||
pad_token (`str`, *optional*, defaults to `"<pad>"`):
|
||||
The token used for padding, for example when batching sequences of different lengths.
|
||||
unk_token (`str`, *optional*, defaults to `"<pad>"`):
|
||||
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||
token instead.
|
||||
max_length (`int`, *optional*, defaults to 1024):
|
||||
The maximum length of the sequences when encoding. Sequences longer than this will be truncated.
|
||||
offset (`int`, *optional*, defaults to 0):
|
||||
The offset added to raw byte values when converting tokens to ids (and subtracted when decoding).
|
||||
"""
|
||||
|
||||
model_input_names = ["input_ids", "attention_mask"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pad_token: Optional[str] = "<pad>",
|
||||
unk_token: Optional[str] = "<pad>",
|
||||
max_length: Optional[int] = 1024,
|
||||
offset: int = 0,
|
||||
**kwargs,
|
||||
):
|
||||
# We have no eos/bos tokens but allow padding -- no l/r strip as we treat them as tokens as well
|
||||
pad_token = AddedToken(pad_token) if isinstance(pad_token, str) else pad_token
|
||||
unk_token = AddedToken(unk_token) if isinstance(unk_token, str) else unk_token
|
||||
|
||||
self._utf_vocab_size = 2**8 # utf is 8 bits
|
||||
self._added_tokens_decoder = {0: pad_token, 1: AddedToken("[S1]"), 2: AddedToken("[S2]")}
|
||||
self.offset = offset
|
||||
super().__init__(
|
||||
unk_token=unk_token,
|
||||
pad_token=pad_token,
|
||||
max_length=max_length,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return self._utf_vocab_size
|
||||
|
||||
def get_vocab(self):
|
||||
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size + self.offset)}
|
||||
vocab.update(self.added_tokens_encoder)
|
||||
return vocab
|
||||
|
||||
def _tokenize(self, text: str) -> list[str]:
|
||||
"""Take as input a string and return a list of strings (tokens) for words/sub-words"""
|
||||
tokens = [chr(i) for i in text.encode("utf-8")]
|
||||
return tokens
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
"""Converts a token (str) in an id using the vocab."""
|
||||
|
||||
if len(token) != 1:
|
||||
token_id = None
|
||||
else:
|
||||
token_id = ord(token) + self.offset
|
||||
|
||||
return token_id
|
||||
|
||||
def _convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||
token = chr(index - self.offset)
|
||||
return token
|
||||
|
||||
def convert_tokens_to_string(self, tokens: list[str]) -> str:
|
||||
"""Converts a sequence of tokens (string) in a single string."""
|
||||
bstring = b""
|
||||
for token in tokens:
|
||||
if token in self.added_tokens_decoder:
|
||||
added_token_obj = self.added_tokens_decoder[token]
|
||||
tok_string = str(added_token_obj).encode("utf-8")
|
||||
elif token in self.added_tokens_encoder:
|
||||
tok_string = token.encode("utf-8")
|
||||
else:
|
||||
tok_string = token.encode("utf-8") # Assume general string token
|
||||
bstring += tok_string
|
||||
string = bstring.decode("utf-8", errors="ignore")
|
||||
return string
|
||||
|
||||
# No vocab file
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
|
||||
return ()
|
||||
|
||||
|
||||
__all__ = ["DiaTokenizer"]
|
@ -80,15 +80,21 @@ class TextToAudioPipeline(Pipeline):
|
||||
See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=text-to-speech).
|
||||
"""
|
||||
|
||||
# Introducing the processor at load time for new behaviour
|
||||
_load_processor = True
|
||||
|
||||
_pipeline_calls_generate = True
|
||||
# Make sure the docstring is updated when the default generation config is changed
|
||||
_default_generation_config = GenerationConfig(
|
||||
max_new_tokens=256,
|
||||
)
|
||||
|
||||
def __init__(self, *args, vocoder=None, sampling_rate=None, **kwargs):
|
||||
def __init__(self, *args, vocoder=None, sampling_rate=None, no_processor=True, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Legacy behaviour only uses the tokenizer, while newer models use the processor as a whole.
|
||||
self.no_processor = no_processor
|
||||
|
||||
if self.framework == "tf":
|
||||
raise ValueError("The TextToAudioPipeline is only available in PyTorch.")
|
||||
|
||||
@ -117,6 +123,10 @@ class TextToAudioPipeline(Pipeline):
|
||||
if sampling_rate is not None:
|
||||
self.sampling_rate = sampling_rate
|
||||
|
||||
# last fallback to get the sampling rate based on processor
|
||||
if self.sampling_rate is None and not self.no_processor and hasattr(self.processor, "feature_extractor"):
|
||||
self.sampling_rate = self.processor.feature_extractor.sampling_rate
|
||||
|
||||
def preprocess(self, text, **kwargs):
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
@ -136,7 +146,8 @@ class TextToAudioPipeline(Pipeline):
|
||||
|
||||
kwargs = new_kwargs
|
||||
|
||||
output = self.tokenizer(text, **kwargs, return_tensors="pt")
|
||||
preprocessor = self.tokenizer if self.no_processor else self.processor
|
||||
output = preprocessor(text, **kwargs, return_tensors="pt")
|
||||
|
||||
return output
|
||||
|
||||
@ -228,12 +239,21 @@ class TextToAudioPipeline(Pipeline):
|
||||
|
||||
return preprocess_params, params, postprocess_params
|
||||
|
||||
def postprocess(self, waveform):
|
||||
def postprocess(self, audio):
|
||||
output_dict = {}
|
||||
if isinstance(waveform, dict):
|
||||
waveform = waveform["waveform"]
|
||||
elif isinstance(waveform, tuple):
|
||||
waveform = waveform[0]
|
||||
|
||||
# We directly get the waveform
|
||||
if self.no_processor:
|
||||
if isinstance(audio, dict):
|
||||
waveform = audio["waveform"]
|
||||
elif isinstance(audio, tuple):
|
||||
waveform = audio[0]
|
||||
else:
|
||||
waveform = audio
|
||||
# Or we need to postprocess to get the waveform
|
||||
else:
|
||||
waveform = self.processor.decode(audio)
|
||||
|
||||
output_dict["audio"] = waveform.to(device="cpu", dtype=torch.float).numpy()
|
||||
output_dict["sampling_rate"] = self.sampling_rate
|
||||
|
||||
|
@ -49,6 +49,7 @@ from .tokenization_utils_base import (
|
||||
TruncationStrategy,
|
||||
)
|
||||
from .utils import (
|
||||
AUDIO_TOKENIZER_NAME,
|
||||
CHAT_TEMPLATE_DIR,
|
||||
CHAT_TEMPLATE_FILE,
|
||||
LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE,
|
||||
@ -61,12 +62,17 @@ from .utils import (
|
||||
download_url,
|
||||
is_offline_mode,
|
||||
is_remote_url,
|
||||
is_torch_available,
|
||||
list_repo_templates,
|
||||
logging,
|
||||
)
|
||||
from .utils.deprecation import deprecate_kwarg
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_utils import PreTrainedAudioTokenizerBase
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# Dynamically import the Transformers module to grab the attribute classes of the processor from their names.
|
||||
@ -499,7 +505,7 @@ class ProcessorMixin(PushToHubMixin):
|
||||
"""
|
||||
|
||||
attributes = ["feature_extractor", "tokenizer"]
|
||||
optional_attributes = ["chat_template"]
|
||||
optional_attributes = ["chat_template", "audio_tokenizer"]
|
||||
optional_call_args: list[str] = []
|
||||
# Names need to be attr_class for attr in attributes
|
||||
feature_extractor_class = None
|
||||
@ -511,7 +517,19 @@ class ProcessorMixin(PushToHubMixin):
|
||||
# First, extract optional attributes from kwargs if present
|
||||
# Optional attributes can never be positional arguments
|
||||
for optional_attribute in self.optional_attributes:
|
||||
setattr(self, optional_attribute, kwargs.pop(optional_attribute, None))
|
||||
optional_attribute_value = kwargs.pop(optional_attribute, None)
|
||||
setattr(self, optional_attribute, optional_attribute_value)
|
||||
|
||||
# Check audio tokenizer for its class but do not treat it as attr to avoid saving weights
|
||||
if optional_attribute == "audio_tokenizer" and optional_attribute_value is not None:
|
||||
proper_class = self.check_argument_for_proper_class(optional_attribute, optional_attribute_value)
|
||||
|
||||
if not (is_torch_available() and isinstance(optional_attribute_value, PreTrainedAudioTokenizerBase)):
|
||||
raise ValueError(
|
||||
f"Tried to use `{proper_class}` for audio tokenization. However, this class is not"
|
||||
" registered for audio tokenization."
|
||||
)
|
||||
|
||||
# Sanitize args and kwargs
|
||||
for key in kwargs:
|
||||
if key not in self.attributes:
|
||||
@ -530,21 +548,30 @@ class ProcessorMixin(PushToHubMixin):
|
||||
|
||||
# Check each arg is of the proper class (this will also catch a user initializing in the wrong order)
|
||||
for attribute_name, arg in kwargs.items():
|
||||
class_name = getattr(self, f"{attribute_name}_class")
|
||||
# Nothing is ever going to be an instance of "AutoXxx", in that case we check the base class.
|
||||
class_name = AUTO_TO_BASE_CLASS_MAPPING.get(class_name, class_name)
|
||||
if isinstance(class_name, tuple):
|
||||
proper_class = tuple(self.get_possibly_dynamic_module(n) for n in class_name if n is not None)
|
||||
else:
|
||||
proper_class = self.get_possibly_dynamic_module(class_name)
|
||||
|
||||
if not isinstance(arg, proper_class):
|
||||
raise TypeError(
|
||||
f"Received a {type(arg).__name__} for argument {attribute_name}, but a {class_name} was expected."
|
||||
)
|
||||
|
||||
self.check_argument_for_proper_class(attribute_name, arg)
|
||||
setattr(self, attribute_name, arg)
|
||||
|
||||
def check_argument_for_proper_class(self, argument_name, argument):
|
||||
"""
|
||||
Checks the passed argument's class against the expected transformers class. In case of an unexpected
|
||||
mismatch between the expected and actual class, an error is raised. Otherwise, the retrieved proper class
|
||||
is returned.
|
||||
"""
|
||||
class_name = getattr(self, f"{argument_name}_class")
|
||||
# Nothing is ever going to be an instance of "AutoXxx", in that case we check the base class.
|
||||
class_name = AUTO_TO_BASE_CLASS_MAPPING.get(class_name, class_name)
|
||||
if isinstance(class_name, tuple):
|
||||
proper_class = tuple(self.get_possibly_dynamic_module(n) for n in class_name if n is not None)
|
||||
else:
|
||||
proper_class = self.get_possibly_dynamic_module(class_name)
|
||||
|
||||
if not isinstance(argument, proper_class):
|
||||
raise TypeError(
|
||||
f"Received a {type(argument).__name__} for argument {argument_name}, but a {class_name} was expected."
|
||||
)
|
||||
|
||||
return proper_class
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""
|
||||
Serializes this instance to a Python dictionary.
|
||||
@ -577,6 +604,8 @@ class ProcessorMixin(PushToHubMixin):
|
||||
del output["feature_extractor"]
|
||||
if "chat_template" in output:
|
||||
del output["chat_template"]
|
||||
if "audio_tokenizer" in output:
|
||||
del output["audio_tokenizer"]
|
||||
|
||||
# Some attributes have different names but containing objects that are not simple strings
|
||||
output = {
|
||||
@ -695,6 +724,7 @@ class ProcessorMixin(PushToHubMixin):
|
||||
save_directory, LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE
|
||||
) # Legacy filename
|
||||
chat_template_dir = os.path.join(save_directory, CHAT_TEMPLATE_DIR)
|
||||
output_audio_tokenizer_file = os.path.join(save_directory, AUDIO_TOKENIZER_NAME)
|
||||
|
||||
processor_dict = self.to_dict()
|
||||
# Save `chat_template` in its own file. We can't get it from `processor_dict` as we popped it in `to_dict`
|
||||
@ -737,6 +767,19 @@ class ProcessorMixin(PushToHubMixin):
|
||||
"separate files using the `save_jinja_files` argument."
|
||||
)
|
||||
|
||||
if self.audio_tokenizer is not None:
|
||||
audio_tokenizer_class = self.audio_tokenizer.__class__.__name__
|
||||
audio_tokenizer_name_or_path = self.audio_tokenizer.name_or_path
|
||||
|
||||
audio_tokenizer_dict = {
|
||||
"audio_tokenizer_class": audio_tokenizer_class,
|
||||
"audio_tokenizer_name_or_path": audio_tokenizer_name_or_path,
|
||||
}
|
||||
audio_tokenizer_json = json.dumps(audio_tokenizer_dict, indent=2, sort_keys=True) + "\n"
|
||||
|
||||
with open(output_audio_tokenizer_file, "w", encoding="utf-8") as writer:
|
||||
writer.write(audio_tokenizer_json)
|
||||
|
||||
# For now, let's not save to `processor_config.json` if the processor doesn't have extra attributes and
|
||||
# `auto_map` is not specified.
|
||||
if set(processor_dict.keys()) != {"processor_class"}:
|
||||
@ -774,6 +817,9 @@ class ProcessorMixin(PushToHubMixin):
|
||||
Returns:
|
||||
`tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the processor object.
|
||||
"""
|
||||
# holding a copy for optionally loading the audio tokenizer (if available)
|
||||
audio_tokenizer_kwargs = copy.deepcopy(kwargs)
|
||||
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", None)
|
||||
@ -803,16 +849,18 @@ class ProcessorMixin(PushToHubMixin):
|
||||
resolved_additional_chat_template_files = {}
|
||||
if os.path.isfile(pretrained_model_name_or_path):
|
||||
resolved_processor_file = pretrained_model_name_or_path
|
||||
# can't load chat-template when given a file as pretrained_model_name_or_path
|
||||
# can't load chat-template and audio tokenizer when given a file as pretrained_model_name_or_path
|
||||
resolved_chat_template_file = None
|
||||
resolved_raw_chat_template_file = None
|
||||
resolved_audio_tokenizer_file = None
|
||||
is_local = True
|
||||
elif is_remote_url(pretrained_model_name_or_path):
|
||||
processor_file = pretrained_model_name_or_path
|
||||
resolved_processor_file = download_url(pretrained_model_name_or_path)
|
||||
# can't load chat-template when given a file url as pretrained_model_name_or_path
|
||||
# can't load chat-template and audio tokenizer when given a file url as pretrained_model_name_or_path
|
||||
resolved_chat_template_file = None
|
||||
resolved_raw_chat_template_file = None
|
||||
resolved_audio_tokenizer_file = None
|
||||
else:
|
||||
if is_local:
|
||||
template_dir = Path(pretrained_model_name_or_path, CHAT_TEMPLATE_DIR)
|
||||
@ -899,6 +947,21 @@ class ProcessorMixin(PushToHubMixin):
|
||||
)
|
||||
for template_name, template_file in additional_chat_template_files.items()
|
||||
}
|
||||
|
||||
resolved_audio_tokenizer_file = cached_file(
|
||||
pretrained_model_name_or_path,
|
||||
AUDIO_TOKENIZER_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
user_agent=user_agent,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
_raise_exceptions_for_missing_entries=False,
|
||||
)
|
||||
except OSError:
|
||||
# Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
|
||||
# the original exception.
|
||||
@ -939,6 +1002,22 @@ class ProcessorMixin(PushToHubMixin):
|
||||
if chat_templates:
|
||||
kwargs["chat_template"] = chat_templates
|
||||
|
||||
# Same as chat template, adding as kwarg after loading the model
|
||||
audio_tokenizer = None
|
||||
if resolved_audio_tokenizer_file is not None:
|
||||
with open(resolved_audio_tokenizer_file, "r", encoding="utf-8") as reader:
|
||||
# The json contains the references we need to init the correct model
|
||||
audio_tokenizer_references = json.load(reader)
|
||||
audio_tokenizer_class = cls.get_possibly_dynamic_module(
|
||||
audio_tokenizer_references["audio_tokenizer_class"]
|
||||
)
|
||||
audio_tokenizer_path = audio_tokenizer_references["audio_tokenizer_name_or_path"]
|
||||
|
||||
audio_tokenizer = audio_tokenizer_class.from_pretrained(audio_tokenizer_path, **audio_tokenizer_kwargs)
|
||||
|
||||
if audio_tokenizer is not None:
|
||||
kwargs["audio_tokenizer"] = audio_tokenizer
|
||||
|
||||
# Existing processors on the Hub created before #27761 being merged don't have `processor_config.json` (if not
|
||||
# updated afterward), and we need to keep `from_pretrained` work. So here it fallbacks to the empty dict.
|
||||
# (`cached_file` called using `_raise_exceptions_for_missing_entries=False` to avoid exception)
|
||||
@ -947,7 +1026,9 @@ class ProcessorMixin(PushToHubMixin):
|
||||
# In any case we need to pass `chat_template` if it is available
|
||||
processor_dict = {}
|
||||
if "chat_template" in kwargs:
|
||||
processor_dict = {"chat_template": kwargs.pop("chat_template")}
|
||||
processor_dict["chat_template"] = kwargs.pop("chat_template")
|
||||
if "audio_tokenizer" in kwargs:
|
||||
processor_dict["audio_tokenizer"] = kwargs.pop("audio_tokenizer")
|
||||
return processor_dict, kwargs
|
||||
|
||||
try:
|
||||
@ -972,6 +1053,8 @@ class ProcessorMixin(PushToHubMixin):
|
||||
|
||||
if "chat_template" in kwargs:
|
||||
processor_dict["chat_template"] = kwargs.pop("chat_template")
|
||||
if "audio_tokenizer" in kwargs:
|
||||
processor_dict["audio_tokenizer"] = kwargs.pop("audio_tokenizer")
|
||||
|
||||
return processor_dict, kwargs
|
||||
|
||||
@ -1276,6 +1359,7 @@ class ProcessorMixin(PushToHubMixin):
|
||||
attribute_class = cls.get_possibly_dynamic_module(class_name)
|
||||
|
||||
args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
|
||||
|
||||
return args
|
||||
|
||||
@staticmethod
|
||||
@ -1287,6 +1371,7 @@ class ProcessorMixin(PushToHubMixin):
|
||||
transformers_module.VIDEO_PROCESSOR_MAPPING,
|
||||
transformers_module.TOKENIZER_MAPPING,
|
||||
transformers_module.FEATURE_EXTRACTOR_MAPPING,
|
||||
transformers_module.MODEL_FOR_AUDIO_TOKENIZATION_MAPPING,
|
||||
]
|
||||
for lookup_location in lookup_locations:
|
||||
for custom_class in lookup_location._extra_content.values():
|
||||
|
@ -292,6 +292,7 @@ CONFIG_NAME = "config.json"
|
||||
FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
|
||||
IMAGE_PROCESSOR_NAME = "preprocessor_config.json"
|
||||
VIDEO_PROCESSOR_NAME = "video_preprocessor_config.json"
|
||||
AUDIO_TOKENIZER_NAME = "audio_tokenizer_config.json"
|
||||
PROCESSOR_NAME = "processor_config.json"
|
||||
GENERATION_CONFIG_NAME = "generation_config.json"
|
||||
MODEL_CARD_NAME = "modelcard.json"
|
||||
|
@ -56,7 +56,12 @@ if is_torch_available():
|
||||
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||
WatermarkLogitsProcessor,
|
||||
)
|
||||
from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor
|
||||
from transformers.generation.logits_process import (
|
||||
BarkEosPrioritizerLogitsProcessor,
|
||||
DiaClassifierFreeGuidanceLogitsProcessor,
|
||||
DiaEOSChannelFilterLogitsProcessor,
|
||||
DiaEOSDelayPatternLogitsProcessor,
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
@ -1211,3 +1216,145 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
)
|
||||
)
|
||||
self.assertTrue(is_close)
|
||||
|
||||
def test_dia_classifier_free_guidance(self):
|
||||
input_ids = torch.LongTensor([[0]])
|
||||
logits_uncond = torch.tensor([[1.0, 0, 1.5]])
|
||||
logits_cond = torch.tensor([[1.0, 1.0, 1.0]])
|
||||
|
||||
# base cfg with conditioned as center
|
||||
cfg = DiaClassifierFreeGuidanceLogitsProcessor(guidance_scale=1.5)
|
||||
out = cfg(input_ids, torch.cat([logits_cond, logits_uncond], dim=0))
|
||||
|
||||
res = logits_cond + 1.5 * (logits_cond - logits_uncond)
|
||||
|
||||
self.assertAlmostEqual(out[0, 0].item(), res[0, 0].item())
|
||||
self.assertAlmostEqual(out[0, 1].item(), res[0, 1].item())
|
||||
self.assertAlmostEqual(out[0, 2].item(), res[0, 2].item())
|
||||
|
||||
# additional top k (on cond logits)
|
||||
cfg = DiaClassifierFreeGuidanceLogitsProcessor(guidance_scale=1.5, guidance_top_k=1)
|
||||
out = cfg(input_ids, torch.cat([logits_cond, logits_uncond], dim=0))
|
||||
|
||||
res = logits_cond + 1.5 * (logits_cond - logits_uncond)
|
||||
mask = res == res.max()
|
||||
res = logits_cond.clone()
|
||||
res[~mask.bool()] = -float("inf")
|
||||
|
||||
self.assertAlmostEqual(out[0, 0].item(), res[0, 0].item())
|
||||
self.assertAlmostEqual(out[0, 1].item(), res[0, 1].item())
|
||||
self.assertAlmostEqual(out[0, 2].item(), res[0, 2].item())
|
||||
|
||||
def test_dia_channel_filter(self):
|
||||
eos = 2
|
||||
bsz, channels, vocab = 2, 2, 4
|
||||
|
||||
input_ids = torch.LongTensor([[0]])
|
||||
logits = torch.zeros(size=(bsz, channels, vocab)).view(bsz * channels, vocab)
|
||||
logits[0, eos] = 1 # Eos max (forced)
|
||||
logits[1, eos] = 1 # Eos max (forced) but not channel 0
|
||||
|
||||
channel_filter = DiaEOSChannelFilterLogitsProcessor(num_channels=channels, eos_token_id=eos)
|
||||
out = channel_filter(input_ids, logits).view(bsz, channels, vocab)
|
||||
|
||||
for i in range(vocab):
|
||||
if i > eos:
|
||||
# special tokens are not to be predicted
|
||||
self.assertTrue((out[:, :, i] == -float("inf")).all())
|
||||
elif i == eos:
|
||||
# Eos forced on channel 0
|
||||
self.assertTrue(out[0, 0, i] == 1)
|
||||
# Eos suppressed on everything else (even if max before)
|
||||
self.assertTrue(out[0, 1, i] == -float("inf"))
|
||||
self.assertTrue((out[1, :, i] == -float("inf")).all())
|
||||
else:
|
||||
# Eos forced on channel 0
|
||||
self.assertTrue(out[0, 0, i] == -float("inf"))
|
||||
# previous values
|
||||
self.assertTrue(out[0, 1, i] == 0)
|
||||
self.assertTrue((out[1, :, i] == 0).all())
|
||||
|
||||
def test_dia_delay_pattern(self):
|
||||
def check_eos_logits(out, logits, batch, channel, eos):
|
||||
for i in range(vocab):
|
||||
if i == eos:
|
||||
self.assertTrue(out[batch, channel, i] == 0)
|
||||
else:
|
||||
self.assertTrue(out[batch, channel, i] == -float("inf"))
|
||||
|
||||
for c in range(channel):
|
||||
if c != channel:
|
||||
self.assertTrue((out[batch, c] == logits[batch, c]).all())
|
||||
|
||||
eos = 2
|
||||
delay_pattern = [0, 2, 3]
|
||||
max_generation_len = 10
|
||||
bsz, channels, vocab = 2, 3, 4
|
||||
|
||||
input_ids = torch.LongTensor([[0]])
|
||||
logits = torch.zeros(size=(bsz, channels, vocab))
|
||||
# Ensure that argmax can not result in eos
|
||||
logits[:, :, eos] = -1
|
||||
|
||||
delay_pattern_processor = DiaEOSDelayPatternLogitsProcessor(
|
||||
delay_pattern=delay_pattern, eos_token_id=eos, max_generation_len=max_generation_len
|
||||
)
|
||||
out = delay_pattern_processor(input_ids, logits.clone()).view(bsz, channels, vocab)
|
||||
|
||||
# Nothing should happen except for init of some attributes
|
||||
self.assertTrue((out == logits).all())
|
||||
self.assertTrue((~delay_pattern_processor.active_batches).all())
|
||||
self.assertTrue(
|
||||
(delay_pattern_processor.delay_pattern == torch.tensor([delay_pattern for _ in range(bsz)])).all()
|
||||
)
|
||||
|
||||
# Make first batch end
|
||||
logits[0, 0, eos] = 1
|
||||
|
||||
# Go through the complete delay pattern
|
||||
for i in range(max(delay_pattern) + 1):
|
||||
out = delay_pattern_processor(input_ids, logits.clone()).view(bsz, channels, vocab)
|
||||
|
||||
# no delay should kick in
|
||||
if i == 1:
|
||||
self.assertTrue((out == logits).all())
|
||||
else:
|
||||
j = i if i == 0 else i - 1
|
||||
check_eos_logits(out=out, logits=logits, batch=0, channel=j, eos=eos)
|
||||
self.assertTrue((out[1] == logits[1]).all())
|
||||
self.assertTrue(delay_pattern_processor.active_batches[0])
|
||||
self.assertFalse(delay_pattern_processor.active_batches[1])
|
||||
self.assertTrue(
|
||||
(
|
||||
delay_pattern_processor.delay_pattern[0]
|
||||
== torch.tensor([delay - (i + 1) for delay in delay_pattern])
|
||||
).all()
|
||||
)
|
||||
self.assertTrue((delay_pattern_processor.delay_pattern[1] == torch.tensor(delay_pattern)).all())
|
||||
|
||||
# Make second batch end
|
||||
logits[1, 0, eos] = 1
|
||||
|
||||
# Just to check if other batches could work
|
||||
out = delay_pattern_processor(input_ids, logits.clone()).view(bsz, channels, vocab)
|
||||
|
||||
self.assertTrue((out[0] == logits[0]).all())
|
||||
self.assertTrue(delay_pattern_processor.active_batches.all())
|
||||
self.assertTrue(
|
||||
(delay_pattern_processor.delay_pattern[0] == torch.tensor([delay - 5 for delay in delay_pattern])).all()
|
||||
)
|
||||
self.assertTrue(
|
||||
(delay_pattern_processor.delay_pattern[1] == torch.tensor([delay - 1 for delay in delay_pattern])).all()
|
||||
)
|
||||
|
||||
# Last check on max generation length reached (with delay in mind until last channel produces eos)
|
||||
input_ids = torch.LongTensor([[0] * (max_generation_len - max(delay_pattern) - 1)])
|
||||
delay_pattern_processor = DiaEOSDelayPatternLogitsProcessor(
|
||||
delay_pattern=delay_pattern, eos_token_id=eos, max_generation_len=max_generation_len
|
||||
)
|
||||
out = delay_pattern_processor(input_ids, logits.clone()).view(bsz, channels, vocab)
|
||||
|
||||
check_eos_logits(out=out, logits=logits, batch=0, channel=0, eos=eos)
|
||||
check_eos_logits(out=out, logits=logits, batch=1, channel=0, eos=eos)
|
||||
self.assertTrue(delay_pattern_processor.active_batches.all())
|
||||
self.assertTrue((delay_pattern_processor.delay_pattern == torch.tensor(delay_pattern) - 1).all())
|
||||
|
@ -26,6 +26,7 @@ import transformers
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
FEATURE_EXTRACTOR_MAPPING,
|
||||
MODEL_FOR_AUDIO_TOKENIZATION_MAPPING,
|
||||
PROCESSOR_MAPPING,
|
||||
TOKENIZER_MAPPING,
|
||||
AutoConfig,
|
||||
@ -265,6 +266,8 @@ class AutoFeatureExtractorTest(unittest.TestCase):
|
||||
del TOKENIZER_MAPPING._extra_content[CustomConfig]
|
||||
if CustomConfig in PROCESSOR_MAPPING._extra_content:
|
||||
del PROCESSOR_MAPPING._extra_content[CustomConfig]
|
||||
if CustomConfig in MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content:
|
||||
del MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content[CustomConfig]
|
||||
|
||||
def test_from_pretrained_dynamic_processor_conflict(self):
|
||||
class NewFeatureExtractor(Wav2Vec2FeatureExtractor):
|
||||
@ -317,6 +320,8 @@ class AutoFeatureExtractorTest(unittest.TestCase):
|
||||
del TOKENIZER_MAPPING._extra_content[CustomConfig]
|
||||
if CustomConfig in PROCESSOR_MAPPING._extra_content:
|
||||
del PROCESSOR_MAPPING._extra_content[CustomConfig]
|
||||
if CustomConfig in MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content:
|
||||
del MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content[CustomConfig]
|
||||
|
||||
def test_from_pretrained_dynamic_processor_with_extra_attributes(self):
|
||||
class NewFeatureExtractor(Wav2Vec2FeatureExtractor):
|
||||
@ -356,6 +361,8 @@ class AutoFeatureExtractorTest(unittest.TestCase):
|
||||
del TOKENIZER_MAPPING._extra_content[CustomConfig]
|
||||
if CustomConfig in PROCESSOR_MAPPING._extra_content:
|
||||
del PROCESSOR_MAPPING._extra_content[CustomConfig]
|
||||
if CustomConfig in MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content:
|
||||
del MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content[CustomConfig]
|
||||
|
||||
def test_dynamic_processor_with_specific_dynamic_subcomponents(self):
|
||||
class NewFeatureExtractor(Wav2Vec2FeatureExtractor):
|
||||
@ -390,6 +397,8 @@ class AutoFeatureExtractorTest(unittest.TestCase):
|
||||
del TOKENIZER_MAPPING._extra_content[CustomConfig]
|
||||
if CustomConfig in PROCESSOR_MAPPING._extra_content:
|
||||
del PROCESSOR_MAPPING._extra_content[CustomConfig]
|
||||
if CustomConfig in MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content:
|
||||
del MODEL_FOR_AUDIO_TOKENIZATION_MAPPING._extra_content[CustomConfig]
|
||||
|
||||
def test_auto_processor_creates_tokenizer(self):
|
||||
processor = AutoProcessor.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
tests/models/dia/__init__.py (new file, empty)
tests/models/dia/test_feature_extraction_dia.py (new file, 231 lines)
@ -0,0 +1,231 @@
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for the Dia feature extractor."""
|
||||
|
||||
import itertools
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import DiaFeatureExtractor
|
||||
from transformers.testing_utils import require_torch
|
||||
from transformers.utils.import_utils import is_torch_available
|
||||
|
||||
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
|
||||
# Copied from tests.models.whisper.test_feature_extraction_whisper.floats_list
|
||||
def floats_list(shape, scale=1.0, rng=None, name=None):
|
||||
"""Creates a random float32 tensor"""
|
||||
if rng is None:
|
||||
rng = global_rng
|
||||
|
||||
values = []
|
||||
for batch_idx in range(shape[0]):
|
||||
values.append([])
|
||||
for _ in range(shape[1]):
|
||||
values[-1].append(rng.random() * scale)
|
||||
|
||||
return values
|
||||
|
||||
|
||||
@require_torch
|
||||
class DiaFeatureExtractionTester:
|
||||
# Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTester.__init__
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=7,
|
||||
min_seq_length=400,
|
||||
max_seq_length=2000,
|
||||
feature_size=1,
|
||||
padding_value=0.0,
|
||||
sampling_rate=16000,
|
||||
hop_length=512,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.min_seq_length = min_seq_length
|
||||
self.max_seq_length = max_seq_length
|
||||
self.hop_length = hop_length
|
||||
self.seq_length_diff = (self.max_seq_length - self.min_seq_length) // (self.batch_size - 1)
|
||||
self.feature_size = feature_size
|
||||
self.padding_value = padding_value
|
||||
self.sampling_rate = sampling_rate
|
||||
|
||||
# Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTester.prepare_feat_extract_dict
|
||||
def prepare_feat_extract_dict(self):
|
||||
return {
|
||||
"feature_size": self.feature_size,
|
||||
"padding_value": self.padding_value,
|
||||
"sampling_rate": self.sampling_rate,
|
||||
"hop_length": self.hop_length,
|
||||
}
|
||||
|
||||
# Copied from tests.models.encodec.test_feature_extraction_encodec.EnCodecFeatureExtractionTester.prepare_inputs_for_common
|
||||
def prepare_inputs_for_common(self, equal_length=False, numpify=False):
|
||||
def _flatten(list_of_lists):
|
||||
return list(itertools.chain(*list_of_lists))
|
||||
|
||||
if equal_length:
|
||||
audio_inputs = floats_list((self.batch_size, self.max_seq_length))
|
||||
else:
|
||||
# make sure that inputs increase in size
|
||||
audio_inputs = [
|
||||
_flatten(floats_list((x, self.feature_size)))
|
||||
for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff)
|
||||
]
|
||||
|
||||
if numpify:
|
||||
audio_inputs = [np.asarray(x) for x in audio_inputs]
|
||||
|
||||
return audio_inputs
|
||||
|
||||
|
||||
@require_torch
|
||||
class DiaFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
|
||||
feature_extraction_class = DiaFeatureExtractor
|
||||
|
||||
def setUp(self):
|
||||
self.feat_extract_tester = DiaFeatureExtractionTester(self)
|
||||
|
||||
# Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTest.test_call
|
||||
def test_call(self):
|
||||
# Tests that all call wrap to encode_plus and batch_encode_plus
|
||||
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
# create three inputs of length 800, 1000, and 1200
|
||||
audio_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
||||
np_audio_inputs = [np.asarray(audio_input) for audio_input in audio_inputs]
|
||||
|
||||
# Test not batched input
|
||||
encoded_sequences_1 = feat_extract(audio_inputs[0], return_tensors="np").input_values
|
||||
encoded_sequences_2 = feat_extract(np_audio_inputs[0], return_tensors="np").input_values
|
||||
self.assertTrue(np.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3))
|
||||
|
||||
# Test batched
|
||||
encoded_sequences_1 = feat_extract(audio_inputs, padding=True, return_tensors="np").input_values
|
||||
encoded_sequences_2 = feat_extract(np_audio_inputs, padding=True, return_tensors="np").input_values
|
||||
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
||||
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
||||
|
||||
# Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTest.test_double_precision_pad
|
||||
def test_double_precision_pad(self):
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
np_audio_inputs = np.random.rand(100).astype(np.float64)
|
||||
py_audio_inputs = np_audio_inputs.tolist()
|
||||
|
||||
for inputs in [py_audio_inputs, np_audio_inputs]:
|
||||
np_processed = feature_extractor.pad([{"input_values": inputs}], return_tensors="np")
|
||||
self.assertTrue(np_processed.input_values.dtype == np.float32)
|
||||
pt_processed = feature_extractor.pad([{"input_values": inputs}], return_tensors="pt")
|
||||
self.assertTrue(pt_processed.input_values.dtype == torch.float32)
|
||||
|
||||
# Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTest._load_datasamples
|
||||
def _load_datasamples(self, num_samples):
|
||||
from datasets import load_dataset
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
# automatic decoding with librispeech
|
||||
audio_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
|
||||
|
||||
return [x["array"] for x in audio_samples]
|
||||
|
||||
# Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTest.test_integration with Dac->Dia
|
||||
def test_integration(self):
|
||||
# fmt: off
|
||||
EXPECTED_INPUT_VALUES = torch.tensor(
|
||||
[ 2.3803711e-03, 2.0751953e-03, 1.9836426e-03, 2.1057129e-03,
|
||||
1.6174316e-03, 3.0517578e-04, 9.1552734e-05, 3.3569336e-04,
|
||||
9.7656250e-04, 1.8310547e-03, 2.0141602e-03, 2.1057129e-03,
|
||||
1.7395020e-03, 4.5776367e-04, -3.9672852e-04, 4.5776367e-04,
|
||||
1.0070801e-03, 9.1552734e-05, 4.8828125e-04, 1.1596680e-03,
|
||||
7.3242188e-04, 9.4604492e-04, 1.8005371e-03, 1.8310547e-03,
|
||||
8.8500977e-04, 4.2724609e-04, 4.8828125e-04, 7.3242188e-04,
|
||||
1.0986328e-03, 2.1057129e-03]
|
||||
)
|
||||
# fmt: on
|
||||
input_audio = self._load_datasamples(1)
|
||||
feature_extractor = DiaFeatureExtractor()
|
||||
input_values = feature_extractor(input_audio, return_tensors="pt")["input_values"]
|
||||
self.assertEqual(input_values.shape, (1, 1, 93696))
|
||||
torch.testing.assert_close(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, rtol=1e-4, atol=1e-4)
|
||||
audio_input_end = torch.tensor(input_audio[0][-30:], dtype=torch.float32)
|
||||
torch.testing.assert_close(input_values[0, 0, -46:-16], audio_input_end, rtol=1e-4, atol=1e-4)
|
||||
|
||||
def test_integration_stereo(self):
|
||||
# fmt: off
|
||||
EXPECTED_INPUT_VALUES = torch.tensor(
|
||||
[2.3804e-03, 2.0752e-03, 1.9836e-03, 2.1057e-03, 1.6174e-03,
|
||||
3.0518e-04, 9.1553e-05, 3.3569e-04, 9.7656e-04, 1.8311e-03,
|
||||
2.0142e-03, 2.1057e-03, 1.7395e-03, 4.5776e-04, -3.9673e-04,
|
||||
4.5776e-04, 1.0071e-03, 9.1553e-05, 4.8828e-04, 1.1597e-03,
|
||||
7.3242e-04, 9.4604e-04, 1.8005e-03, 1.8311e-03, 8.8501e-04,
|
||||
4.2725e-04, 4.8828e-04, 7.3242e-04, 1.0986e-03, 2.1057e-03]
|
||||
)
|
||||
# fmt: on
|
||||
input_audio = self._load_datasamples(1)
|
||||
input_audio = [np.tile(input_audio[0][None], reps=(2, 1))]
|
||||
feature_extractor = DiaFeatureExtractor(feature_size=2)
|
||||
input_values = feature_extractor(input_audio, return_tensors="pt").input_values
|
||||
self.assertEqual(input_values.shape, (1, 1, 93696))
|
||||
torch.testing.assert_close(input_values[0, 0, :30], EXPECTED_INPUT_VALUES, rtol=1e-4, atol=1e-4)
|
||||
|
||||
# Copied from tests.models.dac.test_feature_extraction_dac.DacFeatureExtractionTest.test_truncation_and_padding with Dac->Dia
|
||||
def test_truncation_and_padding(self):
|
||||
input_audio = self._load_datasamples(2)
|
||||
# would be easier if the stride was like
|
||||
feature_extractor = DiaFeatureExtractor()
|
||||
|
||||
# pad and trunc raise an error ?
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"^Both padding and truncation were set. Make sure you only set one.$",
|
||||
):
|
||||
truncated_outputs = feature_extractor(
|
||||
input_audio, padding="max_length", truncation=True, return_tensors="pt"
|
||||
).input_values
|
||||
|
||||
# force truncate to max_length
|
||||
truncated_outputs = feature_extractor(
|
||||
input_audio, truncation=True, max_length=48000, return_tensors="pt"
|
||||
).input_values
|
||||
self.assertEqual(truncated_outputs.shape, (2, 1, 48128))
|
||||
|
||||
# pad:
|
||||
padded_outputs = feature_extractor(input_audio, padding=True, return_tensors="pt").input_values
|
||||
self.assertEqual(padded_outputs.shape, (2, 1, 93696))
|
||||
|
||||
# force pad to max length
|
||||
truncated_outputs = feature_extractor(
|
||||
input_audio, padding="max_length", max_length=100000, return_tensors="pt"
|
||||
).input_values
|
||||
self.assertEqual(truncated_outputs.shape, (2, 1, 100352))
|
||||
|
||||
# force no pad
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"^Unable to create tensor, you should probably activate padding with 'padding=True' to have batched tensors with the same length.$",
|
||||
):
|
||||
truncated_outputs = feature_extractor(input_audio, padding=False, return_tensors="pt").input_values
|
||||
|
||||
truncated_outputs = feature_extractor(input_audio[0], padding=False, return_tensors="pt").input_values
|
||||
self.assertEqual(truncated_outputs.shape, (1, 1, 93680))
|
tests/models/dia/test_modeling_dia.py (new file, 752 lines)
@ -0,0 +1,752 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch Dia model."""
|
||||
|
||||
import copy
|
||||
import pathlib
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers.models.dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig
|
||||
from transformers.testing_utils import (
|
||||
cleanup,
|
||||
is_flaky,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
require_torch_sdpa,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import is_soundfile_available, is_torch_available, is_torchaudio_available
|
||||
from transformers.utils.import_utils import is_datasets_available
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
if is_datasets_available():
|
||||
from datasets import Audio, load_dataset
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
DiaForConditionalGeneration,
|
||||
DiaModel,
|
||||
DiaProcessor,
|
||||
PretrainedConfig,
|
||||
PreTrainedModel,
|
||||
)
|
||||
from transformers.cache_utils import (
|
||||
Cache,
|
||||
StaticCache,
|
||||
)
|
||||
from transformers.models.dia.modeling_dia import DiaDecoder, DiaEncoder
|
||||
|
||||
if is_torchaudio_available():
|
||||
import torchaudio
|
||||
|
||||
if is_soundfile_available():
|
||||
import soundfile as sf
|
||||
|
||||
|
||||
@require_torch
|
||||
class DiaModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=3, # need batch_size != num_hidden_layers
|
||||
seq_length=7,
|
||||
max_length=50,
|
||||
is_training=True,
|
||||
vocab_size=100,
|
||||
hidden_size=16,
|
||||
intermediate_size=37,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=2,
|
||||
head_dim=8,
|
||||
decoder_hidden_size=32, # typically larger than encoder
|
||||
hidden_act="silu",
|
||||
eos_token_id=97, # special tokens all occur after eos
|
||||
pad_token_id=98,
|
||||
bos_token_id=99,
|
||||
delay_pattern=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.max_length = max_length
|
||||
self.is_training = is_training
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.head_dim = head_dim
|
||||
self.decoder_hidden_size = decoder_hidden_size
|
||||
self.hidden_act = hidden_act
|
||||
self.eos_token_id = eos_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.bos_token_id = bos_token_id
|
||||
# Set default delay pattern if not provided
|
||||
self.delay_pattern = delay_pattern if delay_pattern is not None else [0, 1, 2]
|
||||
self.num_channels = len(self.delay_pattern)
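# For intuition, a small worked example (illustrative only): with delay_pattern=[0, 1, 2], channel 0 is
# generated without offset, channel 1 lags one step behind and channel 2 lags two steps, so frames f0..f2
# are laid out per channel roughly as
#     channel 0: f0  f1  f2
#     channel 1: bos f0  f1
#     channel 2: bos bos f0
# This is the pattern the processor tests further below exercise explicitly.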
|
||||
|
||||
def get_config(self):
|
||||
encoder_config = DiaEncoderConfig(
|
||||
max_position_embeddings=self.max_length,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
hidden_size=self.hidden_size,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
num_key_value_heads=self.num_attention_heads, # same as num_attention_heads for testing
|
||||
head_dim=self.head_dim,
|
||||
intermediate_size=self.intermediate_size,
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_act=self.hidden_act,
|
||||
)
|
||||
|
||||
decoder_config = DiaDecoderConfig(
|
||||
max_position_embeddings=self.max_length,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
hidden_size=self.decoder_hidden_size,
|
||||
intermediate_size=self.intermediate_size,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
num_key_value_heads=1, # GQA
|
||||
head_dim=self.head_dim,
|
||||
cross_num_attention_heads=self.num_attention_heads,
|
||||
cross_head_dim=self.head_dim,
|
||||
cross_num_key_value_heads=1, # GQA
|
||||
cross_hidden_size=self.hidden_size, # match encoder hidden size
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_act=self.hidden_act,
|
||||
num_channels=self.num_channels,
|
||||
)
|
||||
|
||||
config = DiaConfig(
|
||||
encoder_config=encoder_config,
|
||||
decoder_config=decoder_config,
|
||||
eos_token_id=self.eos_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
bos_token_id=self.bos_token_id,
|
||||
delay_pattern=self.delay_pattern,
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
def prepare_config_and_inputs(self) -> tuple[DiaConfig, dict]:
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
attention_mask = input_ids.ne(self.pad_token_id)
|
||||
|
||||
decoder_input_ids = ids_tensor([self.batch_size, self.seq_length, self.num_channels], self.vocab_size)
|
||||
decoder_attention_mask = decoder_input_ids[..., 0].ne(self.pad_token_id)
|
||||
|
||||
config = self.get_config()
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
def prepare_config_and_inputs_for_common(self) -> tuple[DiaConfig, dict]:
|
||||
config, inputs_dict = self.prepare_config_and_inputs()
|
||||
return config, inputs_dict
|
||||
|
||||
def create_and_check_model_forward(self, config, inputs_dict):
|
||||
model = DiaModel(config=config).to(torch_device).eval()
|
||||
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
decoder_input_ids = inputs_dict["decoder_input_ids"]
|
||||
|
||||
# first forward pass
|
||||
last_hidden_state = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state
|
||||
|
||||
self.parent.assertEqual(
last_hidden_state.shape, (self.batch_size, self.seq_length, config.decoder_config.hidden_size)
)
|
||||
|
||||
def check_encoder_decoder_model_standalone(self, config, inputs_dict):
|
||||
model = DiaModel(config=config).to(torch_device).eval()
|
||||
outputs = model(**inputs_dict)
|
||||
|
||||
encoder_last_hidden_state = outputs.encoder_last_hidden_state
|
||||
last_hidden_state = outputs.last_hidden_state
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
encoder = model.get_encoder()
|
||||
encoder.save_pretrained(tmpdirname)
|
||||
encoder = DiaEncoder.from_pretrained(tmpdirname).to(torch_device)
|
||||
|
||||
encoder_last_hidden_state_2 = encoder(
|
||||
input_ids=inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"]
|
||||
)[0]
|
||||
|
||||
self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 3e-3)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
decoder = model.get_decoder()
|
||||
decoder.save_pretrained(tmpdirname)
|
||||
decoder = DiaDecoder.from_pretrained(tmpdirname).to(torch_device)
|
||||
|
||||
last_hidden_state_2 = decoder(
|
||||
input_ids=inputs_dict["decoder_input_ids"],
|
||||
attention_mask=inputs_dict["decoder_attention_mask"],
|
||||
encoder_hidden_states=encoder_last_hidden_state,
|
||||
)[0]
|
||||
|
||||
self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 3e-3)
|
||||
|
||||
|
||||
@require_torch
|
||||
class DiaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (DiaModel, DiaForConditionalGeneration) if is_torch_available() else ()
|
||||
# We only allow greedy search / sampling with one sequence; see `skip_non_greedy_generate`
|
||||
all_generative_model_classes = (DiaForConditionalGeneration,)
|
||||
# TODO: support new pipeline behavior in tests
|
||||
pipeline_model_mapping = {}
|
||||
# pipeline_model_mapping = {"text-to-audio": DiaForConditionalGeneration} if is_torch_available() else {}
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
test_resize_embeddings = False
|
||||
is_encoder_decoder = True
|
||||
# Usually indicates VLMs, but many audio models are composite as well
|
||||
_is_composite = True
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = DiaModelTester(self)
|
||||
# Skipping `has_text_modality` but manually testing down below
|
||||
self.config_tester = ConfigTester(self, has_text_modality=False, config_class=DiaConfig)
|
||||
self.skip_non_greedy_generate()
|
||||
|
||||
def skip_non_greedy_generate(self):
|
||||
skippable_tests = [
|
||||
"test_sample_generate_dict_output", # return sequences > 1
|
||||
"test_beam",
|
||||
"test_group_beam",
|
||||
"test_constrained_beam",
|
||||
"test_contrastive",
|
||||
"test_assisted",
|
||||
"test_dola",
|
||||
"test_prompt_lookup",
|
||||
"test_model_parallel_beam_search",
|
||||
"test_generate_without_input_ids",
|
||||
"test_generate_with_head_masking",
|
||||
]
|
||||
|
||||
for test in skippable_tests:
|
||||
if self._testMethodName.startswith(test):
|
||||
self.skipTest(reason="Dia only supports greedy search / sampling with one sequence.")
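# In other words, only generation calls of the following form are exercised by this suite (a sketch,
# assuming the standard `GenerationMixin` API):
#
#     model.generate(**inputs, do_sample=False, num_return_sequences=1)  # greedy search
#     model.generate(**inputs, do_sample=True, num_return_sequences=1)   # single-sequence sampling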
|
||||
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
"""Overriden to account for the 2D flattened structure"""
|
||||
inputs_dict = copy.deepcopy(inputs_dict)
|
||||
|
||||
if return_labels:
|
||||
inputs_dict["labels"] = torch.ones(
|
||||
(
|
||||
self.model_tester.batch_size * self.model_tester.num_channels,
|
||||
self.model_tester.seq_length,
|
||||
),
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
return inputs_dict
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
# Manual testing because of composite configs
|
||||
config = self.model_tester.prepare_config_and_inputs()[0]
|
||||
self.assertTrue(hasattr(config.encoder_config, "vocab_size"), msg="Encoder `vocab_size` does not exist")
|
||||
self.assertTrue(hasattr(config.decoder_config, "vocab_size"), msg="Decoder `vocab_size` does not exist")
|
||||
|
||||
def test_model_forward(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model_forward(*config_and_inputs)
|
||||
|
||||
@is_flaky
|
||||
def test_encoder_decoder_model_standalone(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)
|
||||
|
||||
# Overriding shape checks as Dia has different shapes on encoder/decoder using a composite config
|
||||
# + additional special cases where 3D x 2D meshes confuse the expected shape
|
||||
def _check_logits(self, batch_size, logits, config):
|
||||
batch_size *= len(config.delay_pattern) # Account for flattening
|
||||
vocab_size = config.decoder_config.vocab_size
|
||||
self.assertIsInstance(logits, tuple)
|
||||
self.assertListEqual([iter_logits.shape[0] for iter_logits in logits], [batch_size] * len(logits))
|
||||
# vocabulary difference equal to one (e.g. ImageGPT) or zero (all other models)
|
||||
vocab_diff = vocab_size - logits[0].shape[-1]
|
||||
self.assertTrue(vocab_diff in [0, 1])
|
||||
self.assertListEqual([vocab_size - score.shape[-1] for score in logits], [vocab_diff] * len(logits))
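# Sketch of the shape convention checked here: with batch_size=3 and delay_pattern=[0, 1, 2], each
# generation step yields logits of shape (batch_size * num_channels, vocab_size) = (9, vocab_size),
# i.e. the channel dimension is flattened into the batch dimension instead of kept as a third axis.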
|
||||
|
||||
def _check_attentions_for_generate(
|
||||
self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
|
||||
):
|
||||
self.assertIsInstance(attentions, tuple)
|
||||
self.assertListEqual(
|
||||
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
|
||||
)
|
||||
self.assertEqual(len(attentions), (output_length - prompt_length))
|
||||
|
||||
use_cache = decoder_past_key_values is not None
|
||||
has_static_cache = isinstance(decoder_past_key_values, StaticCache)
|
||||
|
||||
# When `output_attentions=True`, each iteration of generate appends the attentions corresponding to the new
|
||||
# token(s)
|
||||
for generated_length, iter_attentions in enumerate(attentions):
|
||||
# regardless of using cache, the first forward pass will have the full prompt as input
|
||||
if use_cache and generated_length > 0:
|
||||
model_input_length = 1
|
||||
else:
|
||||
model_input_length = prompt_length + generated_length
|
||||
query_length = (
|
||||
prompt_length + generated_length
|
||||
if not has_static_cache
|
||||
else decoder_past_key_values.get_max_cache_shape()
|
||||
)
|
||||
|
||||
expected_shape = (
|
||||
batch_size,
|
||||
config.decoder_config.num_attention_heads, # Decoder config
|
||||
model_input_length,
|
||||
query_length,
|
||||
)
|
||||
# check attn size
|
||||
self.assertListEqual(
|
||||
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
|
||||
)
|
||||
|
||||
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length):
|
||||
# Encoder config
|
||||
encoder_expected_shape = (batch_size, config.encoder_config.num_attention_heads, prompt_length, prompt_length)
|
||||
self.assertIsInstance(attentions, tuple)
|
||||
self.assertListEqual(
|
||||
[layer_attentions.shape for layer_attentions in attentions],
|
||||
[encoder_expected_shape] * len(attentions),
|
||||
)
|
||||
|
||||
def _check_hidden_states_for_generate(
|
||||
self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False
|
||||
):
|
||||
self.assertIsInstance(hidden_states, tuple)
|
||||
self.assertListEqual(
|
||||
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
|
||||
[True] * len(hidden_states),
|
||||
)
|
||||
self.assertEqual(len(hidden_states), (output_length - prompt_length))
|
||||
|
||||
# When `output_hidden_states=True`, each iteration of generate appends the hidden states corresponding to the
|
||||
# new token(s)
|
||||
for generated_length, iter_hidden_states in enumerate(hidden_states):
|
||||
# regardless of using cache, the first forward pass will have the full prompt as input
|
||||
if use_cache and generated_length > 0:
|
||||
model_input_length = 1
|
||||
else:
|
||||
model_input_length = prompt_length + generated_length
|
||||
|
||||
# check hidden size
|
||||
# we can have different hidden sizes between encoder and decoder --> check both
|
||||
expected_shape_encoder = (batch_size, model_input_length, config.encoder_config.hidden_size)
|
||||
expected_shape_decoder = (batch_size, model_input_length, config.decoder_config.hidden_size)
|
||||
self.assertTrue(
|
||||
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states]
|
||||
== [expected_shape_encoder] * len(iter_hidden_states)
|
||||
or [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states]
|
||||
== [expected_shape_decoder] * len(iter_hidden_states)
|
||||
)
|
||||
|
||||
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, prompt_length):
|
||||
# Encoder config
|
||||
encoder_expected_shape = (batch_size, prompt_length, config.encoder_config.hidden_size)
|
||||
self.assertIsInstance(hidden_states, tuple)
|
||||
self.assertListEqual(
|
||||
[layer_hidden_states.shape for layer_hidden_states in hidden_states],
|
||||
[encoder_expected_shape] * len(hidden_states),
|
||||
)
|
||||
|
||||
def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config):
|
||||
self.assertIsInstance(decoder_past_key_values, (tuple, Cache))
|
||||
|
||||
# we need the decoder config here
|
||||
config = config.decoder_config
|
||||
|
||||
# (batch, head, seq_length, head_features)
|
||||
expected_shape = (
|
||||
batch_size,
|
||||
config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads,
|
||||
cache_length,
|
||||
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads,
|
||||
)
|
||||
|
||||
if isinstance(decoder_past_key_values, Cache):
|
||||
self.assertListEqual(
|
||||
[key_tensor.shape for key_tensor in decoder_past_key_values.key_cache],
|
||||
[expected_shape] * len(decoder_past_key_values.key_cache),
|
||||
)
|
||||
self.assertListEqual(
|
||||
[value_tensor.shape for value_tensor in decoder_past_key_values.value_cache],
|
||||
[expected_shape] * len(decoder_past_key_values.value_cache),
|
||||
)
|
||||
|
||||
def _check_scores(self, batch_size, scores, generated_length, config):
|
||||
# Special case where Dia keeps score in a 2D mesh of (bsz * channels, vocab)
|
||||
vocab_size = config.decoder_config.vocab_size
|
||||
expected_shape = (batch_size * len(config.delay_pattern), vocab_size)
|
||||
self.assertIsInstance(scores, tuple)
|
||||
self.assertEqual(len(scores), generated_length)
|
||||
self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))
|
||||
|
||||
@require_torch_sdpa
|
||||
def test_sdpa_can_dispatch_composite_models(self):
|
||||
"""
|
||||
Overwritten as the common test relies on hardcoded attribute names; here we check our composite setup specifically
|
||||
"""
|
||||
for model_class in self.all_model_classes:
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
|
||||
sub_models_supporting_sdpa = [
|
||||
(module._supports_sdpa or module._supports_attention_backend)
|
||||
for name, module in model.named_modules()
|
||||
if isinstance(module, PreTrainedModel) and name != ""
|
||||
]
|
||||
supports_sdpa_all_modules = (
|
||||
all(sub_models_supporting_sdpa)
|
||||
if len(sub_models_supporting_sdpa) > 0
|
||||
else (model._supports_sdpa or model._supports_attention_backend)
|
||||
)
|
||||
|
||||
if not supports_sdpa_all_modules:
|
||||
with self.assertRaises(ValueError):
|
||||
model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa")
|
||||
else:
|
||||
model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa")
|
||||
for key in model_sdpa.config:
|
||||
if isinstance(getattr(model_sdpa.config, key), PretrainedConfig):
|
||||
sub_config = getattr(model_sdpa.config, key)
|
||||
self.assertTrue(sub_config._attn_implementation == "sdpa")
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="Custom processor `DiaEOSDelayPatternLogitsProcessor` forces eos token.")
|
||||
def test_generate_continue_from_past_key_values(self):
|
||||
"""Only a small change due to the expected shapes"""
|
||||
# Tests that we can continue generating from past key values, returned from a previous `generate` call
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# Let's make it always:
|
||||
# 1. use cache (for obvious reasons)
|
||||
# 2. generate to max length (which can be achieved by setting the eos token to an invalid value); otherwise
# the test would be flaky (e.g. EOS is generated on iteration 1 on both generations, but the
# continuation would force it to generate beyond an EOS token)
|
||||
# 3. ignore `token_type_ids` for simplicity
|
||||
# 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is
|
||||
# active by default on some models
|
||||
# 5. ignore `encoder_no_repeat_ngram_size`, which is set by default in some encoder-decoder models. When
|
||||
# we use their decoder as a stand-alone model, `encoder_no_repeat_ngram_size` actually prevents
|
||||
# repetition exclusively from the prompt. This test relies on comparing one call vs 2 calls
|
||||
# with cache, what is considered a prompt is different in the two cases.
|
||||
|
||||
if "token_type_ids" in inputs:
|
||||
del inputs["token_type_ids"]
|
||||
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
generate_kwargs = {
|
||||
"pad_token_id": -1,
|
||||
"eos_token_id": -1,
|
||||
"forced_eos_token_id": None,
|
||||
"encoder_no_repeat_ngram_size": 0,
|
||||
"use_cache": True,
|
||||
"do_sample": False,
|
||||
"return_dict_in_generate": True,
|
||||
"output_scores": True,
|
||||
}
|
||||
|
||||
# Traditional way of generating text, with `return_dict_in_generate` to return the past key values
|
||||
outputs = model.generate(**inputs, **generate_kwargs, max_new_tokens=4)
|
||||
|
||||
# Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
|
||||
# inputs may need to be tweaked across `generate` calls (like the attention mask).
|
||||
outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=3)
|
||||
|
||||
# Continue from the tokens generated above, preparing the inputs accordingly
|
||||
inputs["past_key_values"] = outputs_cached.past_key_values
|
||||
new_attention_len = outputs_cached.sequences.shape[1] # the only real modification in this test
|
||||
inputs["decoder_input_ids"] = outputs_cached.sequences
|
||||
if "decoder_attention_mask" in inputs:
|
||||
inputs["decoder_attention_mask"] = torch.nn.functional.pad(
|
||||
inputs["decoder_attention_mask"],
|
||||
(0, new_attention_len - inputs["decoder_attention_mask"].shape[1]),
|
||||
mode="constant",
|
||||
value=1,
|
||||
)
|
||||
|
||||
first_caches_scores = outputs_cached.scores
|
||||
outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=1)
|
||||
full_cached_scores = first_caches_scores + outputs_cached.scores
|
||||
outputs_cached.scores = full_cached_scores
|
||||
|
||||
# The two sets of generated text and past kv should be equal to each other
|
||||
self._check_similar_generate_outputs(outputs, outputs_cached)
|
||||
for layer_idx in range(len(outputs_cached.past_key_values)):
|
||||
for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
outputs.past_key_values[layer_idx][kv_idx],
|
||||
outputs_cached.past_key_values[layer_idx][kv_idx],
|
||||
)
|
||||
)
|
||||
|
||||
@unittest.skip(reason="Indirectly checked in Dia through the generate methods.")
|
||||
def test_past_key_values_format(self, custom_all_cache_shapes=None):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Indirectly checked in Dia through the generate methods.")
|
||||
def test_hidden_states_output(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
reason="Dia has too many mixed embedding types which would cause unintentional side effects, e.g. attempts at tying embeddings"
|
||||
)
|
||||
def test_model_get_set_embeddings(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Theoretically works but kernel library causes issues.")
|
||||
def test_torchscript_output_hidden_state(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Theoretically works but kernel library causes issues.")
|
||||
def test_torchscript_simple(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Encoder-Decoder cache can not be initialized.")
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
pass
|
||||
|
||||
|
||||
class DiaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
||||
"""
|
||||
See https://gist.github.com/vasqu/0e3b06360373a4e612aa3b9a7c09185e for generating the integration tests
|
||||
|
||||
NOTE: We add a single `eos` line for the last channel which is skipped in the original Dia
|
||||
(It doesn't change the behaviour as we cut by the eos token position)
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
# it's a dummy ckpt but should suffice for testing purposes
|
||||
self.model_checkpoint = "AntonV/Dia-1.6B"
|
||||
self.sampling_rate = 44100
|
||||
|
||||
# prepare audio
|
||||
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=self.sampling_rate))
|
||||
audio_sample_1 = librispeech_dummy[-1]["audio"]["array"]
|
||||
audio_sample_2 = librispeech_dummy[-2]["audio"]["array"]
|
||||
# 10 and 5 codebooks as prefix - saved as files as we need wav files for the original Dia
|
||||
dac_chunk_len = 512
|
||||
self.audio_prompt_1_path = "/tmp/dia_test_sample_1.mp3"
|
||||
self.audio_prompt_2_path = "/tmp/dia_test_sample_2.mp3"
|
||||
sf.write(self.audio_prompt_1_path, audio_sample_1[: (dac_chunk_len * 10)], self.sampling_rate)
|
||||
sf.write(self.audio_prompt_2_path, audio_sample_2[: (dac_chunk_len * 5)], self.sampling_rate)
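# Rough size check (illustrative): 10 * 512 = 5120 and 5 * 512 = 2560 samples at 44.1 kHz, i.e. audio
# prefixes of roughly 0.12 s and 0.06 s, which should encode to 10 and 5 DAC codebook frames respectively.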
|
||||
|
||||
def tearDown(self):
|
||||
pathlib.Path(self.audio_prompt_1_path).unlink()
|
||||
pathlib.Path(self.audio_prompt_2_path).unlink()
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
|
||||
@slow
|
||||
@require_torch_accelerator
|
||||
def test_dia_model_integration_generate_tts(self):
|
||||
text = ["[S1] Dia is an open weights text to dialogue model.", "This is a test"]
|
||||
processor = DiaProcessor.from_pretrained(self.model_checkpoint)
|
||||
inputs = processor(text=text, padding=True, return_tensors="pt").to(torch_device)
|
||||
|
||||
model = DiaForConditionalGeneration.from_pretrained(self.model_checkpoint).to(torch_device)
|
||||
outputs = model.generate(**inputs, max_new_tokens=32, do_sample=False)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_OUTPUT_TOKENS = torch.tensor([[[1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 568, 778, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 568, 778, 338, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 568, 804, 10, 524, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 568, 804, 10, 674, 967, 1026, 1026, 1026, 1026],
|
||||
[ 568, 804, 10, 674, 364, 360, 1026, 1026, 1026],
|
||||
[ 568, 804, 10, 674, 364, 981, 728, 1026, 1026],
|
||||
[ 568, 804, 10, 674, 364, 981, 741, 550, 1026],
|
||||
[ 568, 804, 10, 674, 364, 981, 568, 378, 90],
|
||||
[1024, 804, 10, 674, 364, 981, 568, 378, 731],
|
||||
[1025, 804, 10, 674, 364, 981, 568, 378, 731],
|
||||
[1025, 804, 10, 674, 364, 981, 568, 378, 731],
|
||||
[1025, 804, 10, 674, 364, 981, 568, 378, 731],
|
||||
[1025, 804, 10, 674, 364, 981, 568, 378, 731],
|
||||
[1025, 804, 10, 674, 364, 981, 568, 378, 731],
|
||||
[1025, 804, 10, 674, 364, 981, 568, 378, 731],
|
||||
[1025, 804, 10, 674, 364, 981, 568, 378, 731],
|
||||
[1025, 1024, 10, 674, 364, 981, 568, 378, 731],
|
||||
[1025, 1025, 1024, 674, 364, 981, 568, 378, 731],
|
||||
[1025, 1025, 1025, 1024, 364, 981, 568, 378, 731],
|
||||
[1025, 1025, 1025, 1025, 1024, 981, 568, 378, 731],
|
||||
[1025, 1025, 1025, 1025, 1025, 1024, 568, 378, 731],
|
||||
[1025, 1025, 1025, 1025, 1025, 1025, 1024, 378, 731],
|
||||
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024, 731],
|
||||
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024]],
|
||||
|
||||
[[1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 698, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 592, 778, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 592, 778, 338, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 592, 697, 10, 524, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 592, 288, 476, 649, 967, 1026, 1026, 1026, 1026],
|
||||
[ 592, 740, 386, 674, 364, 360, 1026, 1026, 1026],
|
||||
[ 592, 402, 386, 347, 362, 981, 728, 1026, 1026],
|
||||
[ 592, 402, 721, 728, 327, 981, 741, 550, 1026],
|
||||
[ 592, 402, 721, 728, 460, 62, 676, 378, 90],
|
||||
[1024, 402, 721, 728, 837, 595, 195, 982, 784],
|
||||
[1025, 402, 721, 677, 497, 102, 692, 24, 330],
|
||||
[1025, 402, 721, 677, 511, 102, 503, 871, 609],
|
||||
[1025, 402, 721, 677, 511, 96, 801, 871, 894],
|
||||
[1025, 402, 721, 677, 511, 745, 314, 498, 775],
|
||||
[1025, 402, 721, 677, 511, 745, 314, 498, 105],
|
||||
[1025, 402, 721, 677, 511, 745, 314, 861, 889],
|
||||
[1025, 893, 721, 677, 511, 744, 314, 871, 353],
|
||||
[1025, 1024, 888, 677, 511, 744, 314, 871, 332],
|
||||
[1025, 1025, 1024, 518, 511, 744, 314, 871, 366],
|
||||
[1025, 1025, 1025, 1024, 611, 744, 314, 871, 366],
|
||||
[1025, 1025, 1025, 1025, 1024, 980, 314, 871, 366],
|
||||
[1025, 1025, 1025, 1025, 1025, 1024, 45, 124, 366],
|
||||
[1025, 1025, 1025, 1025, 1025, 1025, 1024, 871, 366],
|
||||
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024, 719],
|
||||
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024]]])
|
||||
# fmt: on
|
||||
|
||||
torch.testing.assert_close(outputs.cpu(), EXPECTED_OUTPUT_TOKENS)
|
||||
|
||||
@slow
|
||||
@require_torch_accelerator
|
||||
def test_dia_model_integration_generate_audio_context(self):
|
||||
text = ["[S1] Dia is an open weights text to dialogue model.", "This is a test"]
|
||||
audio_sample_1 = torchaudio.load(self.audio_prompt_1_path, channels_first=True)[0].squeeze().numpy()
|
||||
audio_sample_2 = torchaudio.load(self.audio_prompt_2_path, channels_first=True)[0].squeeze().numpy()
|
||||
audio = [audio_sample_1, audio_sample_2]
|
||||
|
||||
processor = DiaProcessor.from_pretrained(self.model_checkpoint)
|
||||
inputs = processor(text=text, audio=audio, padding=True, return_tensors="pt").to(torch_device)
|
||||
|
||||
model = DiaForConditionalGeneration.from_pretrained(self.model_checkpoint).to(torch_device)
|
||||
# Dia uses right padding while we use left padding (for faster prefill)
# additionally, we generate with `max_new_tokens` while Dia counts max total tokens (hence we compare each sample in its respective setting)
|
||||
outputs_1 = model.generate(**inputs, max_new_tokens=22, do_sample=False)
|
||||
outputs_2 = model.generate(**inputs, max_new_tokens=27, do_sample=False)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_OUTPUT_TOKENS_1 = torch.tensor([[1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 578, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 494, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 330, 501, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 330, 204, 34, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 330, 254, 915, 863, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 330, 215, 458, 313, 50, 1026, 1026, 1026, 1026],
|
||||
[ 330, 615, 529, 216, 801, 237, 1026, 1026, 1026],
|
||||
[ 330, 580, 563, 233, 337, 37, 1018, 1026, 1026],
|
||||
[ 330, 567, 530, 753, 607, 179, 954, 242, 1026],
|
||||
[ 330, 627, 6, 1010, 500, 189, 598, 858, 247],
|
||||
[1024, 432, 480, 530, 122, 3, 788, 149, 814],
|
||||
[1025, 875, 826, 458, 98, 540, 181, 122, 608],
|
||||
[1025, 495, 840, 413, 337, 784, 591, 150, 1017],
|
||||
[1025, 808, 189, 137, 445, 0, 227, 658, 345],
|
||||
[1025, 397, 89, 753, 1016, 173, 984, 0, 910],
|
||||
[1025, 875, 460, 934, 50, 335, 670, 818, 722],
|
||||
[1025, 875, 460, 762, 119, 372, 503, 858, 584],
|
||||
[1025, 348, 555, 475, 469, 458, 963, 41, 664],
|
||||
[1025, 1024, 852, 683, 761, 193, 595, 895, 885],
|
||||
[1025, 1025, 1024, 135, 761, 902, 163, 623, 385],
|
||||
[1025, 1025, 1025, 1024, 852, 282, 581, 623, 70],
|
||||
[1025, 1025, 1025, 1025, 1024, 41, 661, 790, 977],
|
||||
[1025, 1025, 1025, 1025, 1025, 1024, 580, 401, 464],
|
||||
[1025, 1025, 1025, 1025, 1025, 1025, 1024, 756, 61],
|
||||
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024, 752],
|
||||
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024]])
|
||||
|
||||
EXPECTED_OUTPUT_TOKENS_2 = torch.tensor([[1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 619, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 315, 968, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 315, 1007, 458, 1026, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 315, 35, 266, 68, 1026, 1026, 1026, 1026, 1026],
|
||||
[ 315, 359, 285, 811, 154, 1026, 1026, 1026, 1026],
|
||||
[ 315, 906, 407, 297, 785, 649, 1026, 1026, 1026],
|
||||
[ 315, 249, 678, 868, 899, 257, 950, 1026, 1026],
|
||||
[ 315, 249, 217, 471, 292, 908, 196, 469, 1026],
|
||||
[ 315, 249, 825, 771, 839, 802, 633, 590, 531],
|
||||
[1024, 249, 150, 53, 126, 76, 794, 626, 442],
|
||||
[1025, 249, 825, 218, 359, 864, 526, 626, 770],
|
||||
[1025, 249, 150, 137, 530, 845, 877, 600, 111],
|
||||
[1025, 249, 150, 287, 730, 991, 135, 259, 39],
|
||||
[1025, 249, 825, 104, 198, 1020, 719, 625, 208],
|
||||
[1025, 249, 825, 997, 602, 256, 859, 322, 518],
|
||||
[1025, 668, 825, 979, 584, 256, 98, 665, 589],
|
||||
[1025, 954, 458, 54, 206, 52, 244, 822, 599],
|
||||
[1025, 1024, 104, 914, 435, 579, 860, 92, 661],
|
||||
[1025, 1025, 1024, 848, 126, 74, 304, 92, 753],
|
||||
[1025, 1025, 1025, 1024, 362, 376, 304, 586, 753],
|
||||
[1025, 1025, 1025, 1025, 1024, 633, 996, 586, 83],
|
||||
[1025, 1025, 1025, 1025, 1025, 1024, 179, 898, 928],
|
||||
[1025, 1025, 1025, 1025, 1025, 1025, 1024, 506, 102],
|
||||
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024, 79],
|
||||
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024]])
|
||||
# fmt: on
|
||||
|
||||
torch.testing.assert_close(outputs_1[0].cpu(), EXPECTED_OUTPUT_TOKENS_1)
|
||||
torch.testing.assert_close(outputs_2[1, 5:].cpu(), EXPECTED_OUTPUT_TOKENS_2) # left padding
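# The `[1, 5:]` slice presumably skips the 5 left-padding positions of the second sample: its audio
# prefix is 5 DAC chunks shorter than the first sample's (5 vs 10 in `setUp`), so with left padding the
# generated sequence is offset by exactly that difference.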
|
tests/models/dia/test_processor_dia.py (new file, 269 lines)
@@ -0,0 +1,269 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import DacModel, DiaFeatureExtractor, DiaProcessor, DiaTokenizer
|
||||
from transformers.testing_utils import require_torch
|
||||
from transformers.utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
# Copied from tests.utils.test_modeling_utils.check_models_equal
|
||||
def check_models_equal(model1, model2):
|
||||
models_are_equal = True
|
||||
for model1_p, model2_p in zip(model1.parameters(), model2.parameters()):
|
||||
if model1_p.data.ne(model2_p.data).sum() > 0:
|
||||
models_are_equal = False
|
||||
|
||||
return models_are_equal
|
||||
|
||||
|
||||
@require_torch
|
||||
class DiaProcessorTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.checkpoint = "AntonV/Dia-1.6B"
|
||||
self.audio_tokenizer_checkpoint = "descript/dac_44khz"
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
|
||||
# Audio tokenizer is a bigger model so we will reuse this if possible
|
||||
self.processor = DiaProcessor(
|
||||
tokenizer=self.get_tokenizer(),
|
||||
feature_extractor=self.get_feature_extractor(),
|
||||
audio_tokenizer=self.get_audio_tokenizer(),
|
||||
)
|
||||
|
||||
# Default audio values based on Dia and Dac
|
||||
self.pad_id = 1025
|
||||
self.bos_id = 1026
|
||||
self.dac_chunk_len = 512
|
||||
self.delay_pattern = [0, 8, 9, 10, 11, 12, 13, 14, 15]
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return DiaTokenizer.from_pretrained(self.checkpoint, **kwargs)
|
||||
|
||||
def get_feature_extractor(self, **kwargs):
|
||||
return DiaFeatureExtractor.from_pretrained(self.checkpoint, **kwargs)
|
||||
|
||||
def get_audio_tokenizer(self, **kwargs):
|
||||
return DacModel.from_pretrained(self.audio_tokenizer_checkpoint, **kwargs)
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
del self.processor
|
||||
|
||||
def test_save_load_pretrained_default(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
audio_tokenizer = self.get_audio_tokenizer()
|
||||
|
||||
processor = DiaProcessor(
|
||||
tokenizer=tokenizer, feature_extractor=feature_extractor, audio_tokenizer=audio_tokenizer
|
||||
)
|
||||
|
||||
processor.save_pretrained(self.tmpdirname)
|
||||
processor = DiaProcessor.from_pretrained(self.tmpdirname)
|
||||
|
||||
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
|
||||
self.assertIsInstance(processor.tokenizer, DiaTokenizer)
|
||||
|
||||
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())
|
||||
self.assertIsInstance(processor.feature_extractor, DiaFeatureExtractor)
|
||||
|
||||
self.assertEqual(processor.audio_tokenizer.__class__.__name__, audio_tokenizer.__class__.__name__)
|
||||
self.assertEqual(processor.audio_tokenizer.name_or_path, audio_tokenizer.name_or_path)
|
||||
self.assertTrue(check_models_equal(processor.audio_tokenizer, audio_tokenizer))
|
||||
self.assertIsInstance(processor.audio_tokenizer, DacModel)
|
||||
|
||||
def test_save_load_pretrained_additional_features(self):
|
||||
processor = DiaProcessor(
|
||||
tokenizer=self.get_tokenizer(),
|
||||
feature_extractor=self.get_feature_extractor(),
|
||||
audio_tokenizer=self.get_audio_tokenizer(),
|
||||
)
|
||||
processor.save_pretrained(self.tmpdirname)
|
||||
|
||||
tokenizer_add_kwargs = self.get_tokenizer()
|
||||
feature_extractor_add_kwargs = self.get_feature_extractor()
|
||||
audio_tokenizer_add_kwargs = self.get_audio_tokenizer()
|
||||
|
||||
processor = DiaProcessor.from_pretrained(self.tmpdirname)
|
||||
|
||||
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
|
||||
self.assertIsInstance(processor.tokenizer, DiaTokenizer)
|
||||
|
||||
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
|
||||
self.assertIsInstance(processor.feature_extractor, DiaFeatureExtractor)
|
||||
|
||||
self.assertEqual(processor.audio_tokenizer.__class__.__name__, audio_tokenizer_add_kwargs.__class__.__name__)
|
||||
self.assertEqual(processor.audio_tokenizer.name_or_path, audio_tokenizer_add_kwargs.name_or_path)
|
||||
self.assertTrue(check_models_equal(processor.audio_tokenizer, audio_tokenizer_add_kwargs))
|
||||
self.assertIsInstance(processor.audio_tokenizer, DacModel)
|
||||
|
||||
def test_model_input_names(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
self.assertListEqual(
|
||||
self.processor.model_input_names,
|
||||
list(dict.fromkeys(tokenizer.model_input_names + ["decoder_input_ids", "decoder_attention_mask"])),
|
||||
msg="`processor` model input names do not match the expected names.",
|
||||
)
|
||||
|
||||
def test_tokenize(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
random_text = ["This is a processing test for tokenization", "[S1] Dia template style [S2] Nice"]
|
||||
|
||||
input_tokenizer = tokenizer(random_text, padding=True, return_tensors="pt")
|
||||
input_processor = self.processor(random_text)
|
||||
|
||||
for key in input_tokenizer.keys():
|
||||
self.assertTrue((input_tokenizer[key] == input_processor[key]).all())
|
||||
|
||||
def test_no_audio(self):
|
||||
random_text = ["Dummy Input"] * 2
|
||||
input_processor = self.processor(random_text)
|
||||
audio_tokens, audio_mask = input_processor["decoder_input_ids"], input_processor["decoder_attention_mask"]
|
||||
|
||||
# full mask with +1 for bos
|
||||
self.assertTrue(audio_mask.sum() == (max(self.delay_pattern) + 1) * len(random_text))
|
||||
self.assertTrue(
|
||||
audio_tokens.shape
|
||||
== (
|
||||
len(random_text),
|
||||
max(self.delay_pattern) + 1,
|
||||
len(self.delay_pattern),
|
||||
)
|
||||
)
|
||||
|
||||
for channel_idx, delay in enumerate(self.delay_pattern):
|
||||
expected_sequence = torch.ones(size=(audio_tokens.shape[:-1])) * self.pad_id
|
||||
expected_sequence[:, : delay + 1] = self.bos_id
|
||||
self.assertTrue((audio_tokens[..., channel_idx] == expected_sequence).all())
|
||||
|
||||
def test_audio(self):
|
||||
audio_tokenizer = self.get_audio_tokenizer()
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
|
||||
random_text = ["Dummy Input"] * 2
|
||||
# Dac only starts accepting audio from a certain length (ensured via >=1024)
|
||||
raw_speeches = [np.random.rand(2048).astype(np.float32), np.random.rand(1024).astype(np.float32)]
|
||||
input_processor = self.processor(random_text, raw_speeches)
|
||||
audio_tokens, audio_mask = input_processor["decoder_input_ids"], input_processor["decoder_attention_mask"]
|
||||
|
||||
sequence_len = audio_mask.shape[1]
|
||||
for batch_idx, speech in enumerate(raw_speeches):
|
||||
raw_audio = feature_extractor(speech, return_tensors="pt")["input_values"]
|
||||
codebooks = audio_tokenizer(raw_audio).audio_codes.transpose(1, 2)
|
||||
|
||||
pad_len = sequence_len - audio_mask.sum(dim=-1)[batch_idx]
|
||||
for channel_idx, delay in enumerate(self.delay_pattern):
|
||||
# Left padding is filled with bos; the right padding (from the delay) is pad
|
||||
start_idx = pad_len + delay + 1
|
||||
end_idx = start_idx + codebooks.shape[1]
|
||||
|
||||
encoded_sequence = audio_tokens[batch_idx, :, channel_idx]
|
||||
expected_sequence = torch.ones(size=(sequence_len,)) * self.pad_id
|
||||
expected_sequence[:start_idx] = self.bos_id
|
||||
expected_sequence[start_idx:end_idx] = codebooks[0, :, channel_idx]
|
||||
|
||||
self.assertTrue((encoded_sequence == expected_sequence).all())
|
||||
|
||||
# Just to make sure the masking correctly only ignores bos tokens
|
||||
self.assertTrue((audio_tokens[~audio_mask.bool()] == self.bos_id).all())
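# Per-channel layout asserted above (illustrative sketch, lengths not to scale):
#     [ bos * (left_pad + delay + 1) | DAC codebook tokens | pad * remaining ]
# i.e. the bos block absorbs both the batch left-padding and the channel's delay offset, while the tail
# is filled with pad.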
|
||||
|
||||
@parameterized.expand([([1, 1],), ([1, 5],), ([2, 4, 6],)])
|
||||
def test_decode_audio(self, audio_lens):
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
audio_tokenizer = self.get_audio_tokenizer()
|
||||
|
||||
random_text = ["Dummy Input"] * len(audio_lens)
|
||||
raw_speeches = [np.random.rand(self.dac_chunk_len * l).astype(np.float32) for l in audio_lens]
|
||||
# we need eos (provided in training mode, i.e. `generation=False`) to decode properly; at generation time it is enforced via a custom logits processor
|
||||
input_processor = self.processor(random_text, raw_speeches, generation=False)
|
||||
audio_tokens = input_processor["decoder_input_ids"]
|
||||
|
||||
decoded_speeches = self.processor.batch_decode(audio_tokens)
|
||||
for batch_idx, speech in enumerate(raw_speeches):
|
||||
raw_audio = feature_extractor(speech, return_tensors="pt")["input_values"]
|
||||
codebooks = audio_tokenizer(raw_audio).audio_codes
|
||||
|
||||
decoded_audio = decoded_speeches[batch_idx]
|
||||
expected_audio = audio_tokenizer.decode(audio_codes=codebooks).audio_values
|
||||
|
||||
self.assertTrue((expected_audio == decoded_audio).all())
|
||||
self.assertTrue(decoded_speeches[batch_idx].shape[-1] == audio_lens[batch_idx] * self.dac_chunk_len)
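# The length check relies on DAC decoding each codebook frame back into `dac_chunk_len` (512) audio
# samples, so an input of N * 512 samples round-trips to N frames and back to exactly N * 512 samples.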
|
||||
|
||||
@parameterized.expand([(1, 2, [0, 1, 4]), (2, 4, [1, 3, 2]), (4, 8, [0, 5, 7])])
|
||||
def test_delay_in_audio(self, bsz, seq_len, delay_pattern):
|
||||
# static functions which are crucial, hence we also test them here
|
||||
build_indices_fn = DiaProcessor.build_indices
|
||||
delay_fn = DiaProcessor.apply_audio_delay
|
||||
|
||||
bos, pad = -2, -1
|
||||
num_channels = len(delay_pattern)
|
||||
|
||||
audio_input = torch.arange(bsz * seq_len * num_channels).view(bsz, seq_len, num_channels)
|
||||
# imitate a delay mask with zeroes
|
||||
audio_input = torch.cat([audio_input, torch.zeros(size=(bsz, max(delay_pattern), num_channels))], dim=1)
|
||||
|
||||
precomputed_idx = build_indices_fn(
|
||||
bsz=bsz,
|
||||
seq_len=seq_len + max(delay_pattern),
|
||||
num_channels=num_channels,
|
||||
delay_pattern=delay_pattern,
|
||||
revert=False,
|
||||
)
|
||||
delayed_audio_out = delay_fn(
|
||||
audio=audio_input,
|
||||
pad_token_id=pad,
|
||||
bos_token_id=bos,
|
||||
precomputed_idx=precomputed_idx,
|
||||
)
|
||||
|
||||
# every channel idx is shifted by delay_pattern[idx]
|
||||
delayed_audio_res = audio_input.clone()
|
||||
for idx, delay in enumerate(delay_pattern):
|
||||
delayed_audio_res[:, :delay, idx] = bos
|
||||
remaining_input = seq_len + max(delay_pattern) - delay
|
||||
delayed_audio_res[:, delay:, idx] = audio_input[:, :remaining_input, idx]
|
||||
|
||||
self.assertTrue((delayed_audio_out == delayed_audio_res).all())
|
||||
|
||||
# we should get back to the original audio we had (when removing the delay pad)
|
||||
bsz, new_seq_len, num_channels = delayed_audio_out.shape
|
||||
precomputed_idx = build_indices_fn(
|
||||
bsz=bsz,
|
||||
seq_len=new_seq_len,
|
||||
num_channels=num_channels,
|
||||
delay_pattern=delay_pattern,
|
||||
revert=True,
|
||||
)
|
||||
reverted_audio_out = delay_fn(
|
||||
audio=delayed_audio_out,
|
||||
pad_token_id=pad,
|
||||
bos_token_id=bos,
|
||||
precomputed_idx=precomputed_idx,
|
||||
)
|
||||
|
||||
reverted_audio_res = audio_input.clone()[:, :seq_len]
|
||||
|
||||
self.assertTrue((reverted_audio_out[:, :seq_len] == reverted_audio_res).all())
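# Tiny worked round trip (illustrative, delay_pattern=[0, 1], seq_len=2):
#     original:       ch0 = [a0, a1],      ch1 = [b0, b1]
#     after apply:    ch0 = [a0, a1, 0],   ch1 = [bos, b0, b1]   (padded by max(delay) = 1)
#     after revert:   ch0[:2] = [a0, a1],  ch1[:2] = [b0, b1]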
|
tests/models/dia/test_tokenization_dia.py (new file, 123 lines)
@@ -0,0 +1,123 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers.models.dia import DiaTokenizer
|
||||
from transformers.testing_utils import slow
|
||||
|
||||
from ...test_tokenization_common import TokenizerTesterMixin
|
||||
|
||||
|
||||
# Special tokens
|
||||
PAD = 0
|
||||
S1 = 1
|
||||
S2 = 2
|
||||
|
||||
|
||||
class DiaTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
tokenizer_class = DiaTokenizer
|
||||
test_rust_tokenizer = False
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
tokenizer = DiaTokenizer()
|
||||
tokenizer.save_pretrained(cls.tmpdirname)
|
||||
|
||||
def test_convert_token_and_id(self):
|
||||
"""Test ``_convert_token_to_id`` and ``_convert_id_to_token``."""
|
||||
token = "i"
|
||||
token_id = 105
|
||||
|
||||
self.assertEqual(self.get_tokenizer()._convert_token_to_id(token), token_id)
|
||||
self.assertEqual(self.get_tokenizer()._convert_id_to_token(token_id), token)
|
||||
|
||||
def test_get_vocab(self):
|
||||
vocab_keys = list(self.get_tokenizer().get_vocab().keys())
|
||||
|
||||
self.assertEqual(vocab_keys[PAD], "<pad>")
|
||||
self.assertEqual(vocab_keys[S1], "[S1]")
|
||||
self.assertEqual(vocab_keys[S2], "[S2]")
|
||||
self.assertEqual(len(vocab_keys), 256)
|
||||
|
||||
def test_vocab_size(self):
|
||||
# utf-8 == 2**8 == 256
|
||||
self.assertEqual(self.get_tokenizer().vocab_size, 256)
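# Being byte-level, printable ASCII appears to map directly to its byte value (sketch, matching the ids
# asserted in `test_full_tokenizer` below):
#
#     assert [ord(c) for c in "Hello"] == [72, 101, 108, 108, 111]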
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
tokenizer = DiaTokenizer.from_pretrained(self.tmpdirname)
|
||||
|
||||
tokens = tokenizer.tokenize("Hello, world!")
|
||||
self.assertListEqual(tokens, ["H", "e", "l", "l", "o", ",", " ", "w", "o", "r", "l", "d", "!"])
|
||||
ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
self.assertListEqual(ids, [72, 101, 108, 108, 111, 44, 32, 119, 111, 114, 108, 100, 33])
|
||||
back_tokens = tokenizer.convert_ids_to_tokens(ids)
|
||||
self.assertListEqual(back_tokens, ["H", "e", "l", "l", "o", ",", " ", "w", "o", "r", "l", "d", "!"])
|
||||
|
||||
tokens = tokenizer.tokenize("[S1] Hello [S2] Hello<pad>")
|
||||
self.assertListEqual(
|
||||
tokens,
|
||||
["[S1]", " ", "H", "e", "l", "l", "o", " ", "[S2]", " ", "H", "e", "l", "l", "o", "<pad>"],
|
||||
)
|
||||
ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
self.assertListEqual(ids, [S1, 32, 72, 101, 108, 108, 111, 32, S2, 32, 72, 101, 108, 108, 111, PAD])
|
||||
back_tokens = tokenizer.convert_ids_to_tokens(ids)
|
||||
self.assertListEqual(
|
||||
back_tokens, ["[S1]", " ", "H", "e", "l", "l", "o", " ", "[S2]", " ", "H", "e", "l", "l", "o", "<pad>"]
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_tokenizer_integration(self):
|
||||
# Overwritten as decoding will lead to all single bytes (i.e. characters) while usually the string format is expected
|
||||
expected_encoding = {'input_ids': [[84, 114, 97, 110, 115, 102, 111, 114, 109, 101, 114, 115, 32, 40, 102, 111, 114, 109, 101, 114, 108, 121, 32, 107, 110, 111, 119, 110, 32, 97, 115, 32, 112, 121, 116, 111, 114, 99, 104, 45, 116, 114, 97, 110, 115, 102, 111, 114, 109, 101, 114, 115, 32, 97, 110, 100, 32, 112, 121, 116, 111, 114, 99, 104, 45, 112, 114, 101, 116, 114, 97, 105, 110, 101, 100, 45, 98, 101, 114, 116, 41, 32, 112, 114, 111, 118, 105, 100, 101, 115, 32, 103, 101, 110, 101, 114, 97, 108, 45, 112, 117, 114, 112, 111, 115, 101, 32, 97, 114, 99, 104, 105, 116, 101, 99, 116, 117, 114, 101, 115, 32, 40, 66, 69, 82, 84, 44, 32, 71, 80, 84, 45, 50, 44, 32, 82, 111, 66, 69, 82, 84, 97, 44, 32, 88, 76, 77, 44, 32, 68, 105, 115, 116, 105, 108, 66, 101, 114, 116, 44, 32, 88, 76, 78, 101, 116, 46, 46, 46, 41, 32, 102, 111, 114, 32, 78, 97, 116, 117, 114, 97, 108, 32, 76, 97, 110, 103, 117, 97, 103, 101, 32, 85, 110, 100, 101, 114, 115, 116, 97, 110, 100, 105, 110, 103, 32, 40, 78, 76, 85, 41, 32, 97, 110, 100, 32, 78, 97, 116, 117, 114, 97, 108, 32, 76, 97, 110, 103, 117, 97, 103, 101, 32, 71, 101, 110, 101, 114, 97, 116, 105, 111, 110, 32, 40, 78, 76, 71, 41, 32, 119, 105, 116, 104, 32, 111, 118, 101, 114, 32, 51, 50, 43, 32, 112, 114, 101, 116, 114, 97, 105, 110, 101, 100, 32, 109, 111, 100, 101, 108, 115, 32, 105, 110, 32, 49, 48, 48, 43, 32, 108, 97, 110, 103, 117, 97, 103, 101, 115, 32, 97, 110, 100, 32, 100, 101, 101, 112, 32, 105, 110, 116, 101, 114, 111, 112, 101, 114, 97, 98, 105, 108, 105, 116, 121, 32, 98, 101, 116, 119, 101, 101, 110, 32, 74, 97, 120, 44, 32, 80, 121, 84, 111, 114, 99, 104, 32, 97, 110, 100, 32, 84, 101, 110, 115, 111, 114, 70, 108, 111, 119, 46], [66, 69, 82, 84, 32, 105, 115, 32, 100, 101, 115, 105, 103, 110, 101, 100, 32, 116, 111, 32, 112, 114, 101, 45, 116, 114, 97, 105, 110, 32, 100, 101, 101, 112, 32, 98, 105, 100, 105, 114, 101, 99, 116, 105, 111, 110, 97, 108, 32, 114, 101, 112, 114, 101, 115, 101, 110, 116, 97, 116, 105, 111, 110, 115, 32, 102, 114, 111, 109, 32, 117, 110, 108, 97, 98, 101, 108, 101, 100, 32, 116, 101, 120, 116, 32, 98, 121, 32, 106, 111, 105, 110, 116, 108, 121, 32, 99, 111, 110, 100, 105, 116, 105, 111, 110, 105, 110, 103, 32, 111, 110, 32, 98, 111, 116, 104, 32, 108, 101, 102, 116, 32, 97, 110, 100, 32, 114, 105, 103, 104, 116, 32, 99, 111, 110, 116, 101, 120, 116, 32, 105, 110, 32, 97, 108, 108, 32, 108, 97, 121, 101, 114, 115, 46], [84, 104, 101, 32, 113, 117, 105, 99, 107, 32, 98, 114, 111, 119, 110, 32, 102, 111, 120, 32, 106, 117, 109, 112, 115, 32, 111, 118, 101, 114, 32, 116, 104, 101, 32, 108, 97, 122, 121, 32, 100, 111, 103, 46]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]} # fmt: skip
|
||||
|
||||
sequences = [
|
||||
"Transformers (formerly known as pytorch-transformers and pytorch-pretrained-bert) provides "
|
||||
"general-purpose architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet...) for Natural "
|
||||
"Language Understanding (NLU) and Natural Language Generation (NLG) with over 32+ pretrained "
|
||||
"models in 100+ languages and deep interoperability between Jax, PyTorch and TensorFlow.",
|
||||
"BERT is designed to pre-train deep bidirectional representations from unlabeled text by jointly "
|
||||
"conditioning on both left and right context in all layers.",
|
||||
"The quick brown fox jumps over the lazy dog.",
|
||||
]
|
||||
|
||||
tokenizer_classes = [self.tokenizer_class]
|
||||
if self.test_rust_tokenizer:
|
||||
tokenizer_classes.append(self.rust_tokenizer_class)
|
||||
|
||||
for tokenizer_class in tokenizer_classes:
|
||||
tokenizer = tokenizer_class.from_pretrained("AntonV/Dia-1.6B")
|
||||
|
||||
encoding = tokenizer(sequences)
|
||||
encoding_data = encoding.data
|
||||
self.assertDictEqual(encoding_data, expected_encoding)
|
||||
|
||||
# Byte decoding leads to characters so we need to join them
|
||||
decoded_sequences = [
|
||||
"".join(tokenizer.decode(seq, skip_special_tokens=True)) for seq in encoding["input_ids"]
|
||||
]
|
||||
|
||||
for expected, decoded in zip(sequences, decoded_sequences):
|
||||
if self.test_sentencepiece_ignore_case:
|
||||
expected = expected.lower()
|
||||
self.assertEqual(expected, decoded)
|
||||
|
||||
@unittest.skip(reason="Dia relies on whole input string due to the byte-level nature.")
|
||||
def test_pretokenized_inputs(self):
|
||||
pass
|
||||
|
||||
@unittest.skip
|
||||
def test_tokenizer_slow_store_full_signature(self):
|
||||
pass
|
@@ -4574,6 +4574,11 @@ class ModelTesterMixin:
|
||||
head_dim = config.head_dim
|
||||
config.head_dim = max(16, config.head_dim)
|
||||
|
||||
cross_head_dim = None
|
||||
if hasattr(config, "cross_head_dim") and config.cross_head_dim is not None:
|
||||
cross_head_dim = config.cross_head_dim
|
||||
config.cross_head_dim = max(16, config.cross_head_dim)
|
||||
|
||||
if (
|
||||
getattr(config, "hidden_size", None) is not None
|
||||
and getattr(config, "num_attention_heads", None) is not None
|
||||
@@ -4588,6 +4593,17 @@
|
||||
decoder_head_dim = config.decoder_hidden_size // config.decoder_num_attention_heads
|
||||
config.decoder_hidden_size *= max(16 // decoder_head_dim, 1)
|
||||
|
||||
if (
|
||||
getattr(config, "cross_hidden_size", None) is not None
|
||||
and getattr(config, "cross_num_attention_heads", None) is not None
|
||||
):
|
||||
cross_head_dim = (
|
||||
cross_head_dim
|
||||
if cross_head_dim is not None
|
||||
else config.cross_hidden_size // config.cross_num_attention_heads
|
||||
)
|
||||
config.cross_hidden_size *= max(16 // cross_head_dim, 1)
|
||||
|
||||
# Set default attention to flex and update config values
|
||||
update_config_for_flex(config)
|
||||
for key in config.sub_configs:
|
||||
|
@@ -32,6 +32,10 @@ transformers = direct_transformers_import(PATH_TO_TRANSFORMERS)
|
||||
CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING
|
||||
|
||||
SPECIAL_CASES_TO_ALLOW = {
|
||||
# used internally during generation to provide the custom logit processors with their necessary information
|
||||
"DiaConfig": [
|
||||
"delay_pattern",
|
||||
],
|
||||
# 'max_position_embeddings' is not used in modeling file, but needed for eval frameworks like Huggingface's lighteval (https://github.com/huggingface/lighteval/blob/af24080ea4f16eaf1683e353042a2dfc9099f038/src/lighteval/models/base_model.py#L264).
|
||||
# periods and offsets are not used in modeling file, but used in the configuration file to define `layers_block_type` and `layers_num_experts`.
|
||||
"BambaConfig": [
|
||||
|