transformers/tests/models/dia/test_modeling_dia.py
Jaeyong Sung 583db52bc6
Add Dia model (#38405)
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Dia model."""
import copy
import pathlib
import tempfile
import unittest
import pytest
from transformers.models.dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig
from transformers.testing_utils import (
cleanup,
is_flaky,
require_torch,
require_torch_accelerator,
require_torch_sdpa,
slow,
torch_device,
)
from transformers.utils import is_soundfile_available, is_torch_available, is_torchaudio_available
from transformers.utils.import_utils import is_datasets_available
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_datasets_available():
from datasets import Audio, load_dataset
if is_torch_available():
import torch
from transformers import (
DiaForConditionalGeneration,
DiaModel,
DiaProcessor,
PretrainedConfig,
PreTrainedModel,
)
from transformers.cache_utils import (
Cache,
StaticCache,
)
from transformers.models.dia.modeling_dia import DiaDecoder, DiaEncoder
if is_torchaudio_available():
import torchaudio
if is_soundfile_available():
import soundfile as sf
@require_torch
class DiaModelTester:
def __init__(
self,
parent,
batch_size=3, # need batch_size != num_hidden_layers
seq_length=7,
max_length=50,
is_training=True,
vocab_size=100,
hidden_size=16,
intermediate_size=37,
num_hidden_layers=2,
num_attention_heads=2,
head_dim=8,
decoder_hidden_size=32, # typically larger than encoder
hidden_act="silu",
eos_token_id=97, # special tokens all occur after eos
pad_token_id=98,
bos_token_id=99,
delay_pattern=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.max_length = max_length
self.is_training = is_training
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim
self.decoder_hidden_size = decoder_hidden_size
self.hidden_act = hidden_act
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
# Set default delay pattern if not provided
self.delay_pattern = delay_pattern if delay_pattern is not None else [0, 1, 2]
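        # One audio channel per entry; channel i is shifted by delay_pattern[i] steps during generation
        # (musicgen-style delay pattern)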
self.num_channels = len(self.delay_pattern)
def get_config(self):
encoder_config = DiaEncoderConfig(
max_position_embeddings=self.max_length,
num_hidden_layers=self.num_hidden_layers,
hidden_size=self.hidden_size,
num_attention_heads=self.num_attention_heads,
num_key_value_heads=self.num_attention_heads, # same as num_attention_heads for testing
head_dim=self.head_dim,
intermediate_size=self.intermediate_size,
vocab_size=self.vocab_size,
hidden_act=self.hidden_act,
)
decoder_config = DiaDecoderConfig(
max_position_embeddings=self.max_length,
num_hidden_layers=self.num_hidden_layers,
hidden_size=self.decoder_hidden_size,
intermediate_size=self.intermediate_size,
num_attention_heads=self.num_attention_heads,
num_key_value_heads=1, # GQA
head_dim=self.head_dim,
cross_num_attention_heads=self.num_attention_heads,
cross_head_dim=self.head_dim,
cross_num_key_value_heads=1, # GQA
cross_hidden_size=self.hidden_size, # match encoder hidden size
vocab_size=self.vocab_size,
hidden_act=self.hidden_act,
num_channels=self.num_channels,
)
config = DiaConfig(
encoder_config=encoder_config,
decoder_config=decoder_config,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
bos_token_id=self.bos_token_id,
delay_pattern=self.delay_pattern,
)
return config
def prepare_config_and_inputs(self) -> tuple[DiaConfig, dict]:
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
attention_mask = input_ids.ne(self.pad_token_id)
decoder_input_ids = ids_tensor([self.batch_size, self.seq_length, self.num_channels], self.vocab_size)
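        # Decoder inputs are 3D (batch, seq_len, num_channels); the padding mask is derived from the first
        # channel only, as all channels share the same temporal positions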
decoder_attention_mask = decoder_input_ids[..., 0].ne(self.pad_token_id)
config = self.get_config()
inputs_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
}
return config, inputs_dict
def prepare_config_and_inputs_for_common(self) -> tuple[DiaConfig, dict]:
config, inputs_dict = self.prepare_config_and_inputs()
return config, inputs_dict
def create_and_check_model_forward(self, config, inputs_dict):
model = DiaModel(config=config).to(torch_device).eval()
input_ids = inputs_dict["input_ids"]
decoder_input_ids = inputs_dict["decoder_input_ids"]
# first forward pass
last_hidden_state = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state
        self.parent.assertEqual(
            last_hidden_state.shape, (self.batch_size, self.seq_length, config.decoder_config.hidden_size)
        )
def check_encoder_decoder_model_standalone(self, config, inputs_dict):
model = DiaModel(config=config).to(torch_device).eval()
outputs = model(**inputs_dict)
encoder_last_hidden_state = outputs.encoder_last_hidden_state
last_hidden_state = outputs.last_hidden_state
with tempfile.TemporaryDirectory() as tmpdirname:
encoder = model.get_encoder()
encoder.save_pretrained(tmpdirname)
encoder = DiaEncoder.from_pretrained(tmpdirname).to(torch_device)
encoder_last_hidden_state_2 = encoder(
input_ids=inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"]
)[0]
self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 3e-3)
with tempfile.TemporaryDirectory() as tmpdirname:
decoder = model.get_decoder()
decoder.save_pretrained(tmpdirname)
decoder = DiaDecoder.from_pretrained(tmpdirname).to(torch_device)
last_hidden_state_2 = decoder(
input_ids=inputs_dict["decoder_input_ids"],
attention_mask=inputs_dict["decoder_attention_mask"],
encoder_hidden_states=encoder_last_hidden_state,
)[0]
self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 3e-3)
@require_torch
class DiaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (DiaModel, DiaForConditionalGeneration) if is_torch_available() else ()
# We only allow greedy search / sampling with one sequence; see `skip_non_greedy_generate`
all_generative_model_classes = (DiaForConditionalGeneration,)
# TODO: support new pipeline behavior in tests
pipeline_model_mapping = {}
# pipeline_model_mapping = {"text-to-audio": DiaForConditionalGeneration} if is_torch_available() else {}
test_pruning = False
test_head_masking = False
test_resize_embeddings = False
is_encoder_decoder = True
    # Usually indicates VLMs, but many audio models (like Dia) are composite as well
_is_composite = True
def setUp(self):
self.model_tester = DiaModelTester(self)
# Skipping `has_text_modality` but manually testing down below
self.config_tester = ConfigTester(self, has_text_modality=False, config_class=DiaConfig)
self.skip_non_greedy_generate()
def skip_non_greedy_generate(self):
skippable_tests = [
"test_sample_generate_dict_output", # return sequences > 1
"test_beam",
"test_group_beam",
"test_constrained_beam",
"test_contrastive",
"test_assisted",
"test_dola",
"test_prompt_lookup",
"test_model_parallel_beam_search",
"test_generate_without_input_ids",
"test_generate_with_head_masking",
]
for test in skippable_tests:
if self._testMethodName.startswith(test):
self.skipTest(reason="Dia only supports greedy search / sampling with one sequence.")
    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        """Overridden to account for the 2D flattened structure"""
inputs_dict = copy.deepcopy(inputs_dict)
if return_labels:
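            # Labels use Dia's flattened layout: the channel dim is folded into the batch dim,
            # i.e. (batch_size * num_channels, seq_length)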
inputs_dict["labels"] = torch.ones(
(
self.model_tester.batch_size * self.model_tester.num_channels,
self.model_tester.seq_length,
),
dtype=torch.long,
device=torch_device,
)
return inputs_dict
def test_config(self):
self.config_tester.run_common_tests()
# Manual testing because of composite configs
config = self.model_tester.prepare_config_and_inputs()[0]
self.assertTrue(hasattr(config.encoder_config, "vocab_size"), msg="Encoder `vocab_size` does not exist")
self.assertTrue(hasattr(config.decoder_config, "vocab_size"), msg="Decoder `vocab_size` does not exist")
def test_model_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_forward(*config_and_inputs)
@is_flaky
def test_encoder_decoder_model_standalone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)
    # Overriding the shape checks below: Dia uses a composite config with different encoder/decoder shapes,
    # plus special cases where mixing 3D (batch, channels, ...) and 2D (flattened) tensors changes the expected shapes
def _check_logits(self, batch_size, logits, config):
batch_size *= len(config.delay_pattern) # Account for flattening
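        # e.g. with batch_size=3 and delay_pattern=[0, 1, 2] the logits carry 3 * 3 = 9 flattened rows per step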
vocab_size = config.decoder_config.vocab_size
self.assertIsInstance(logits, tuple)
self.assertListEqual([iter_logits.shape[0] for iter_logits in logits], [batch_size] * len(logits))
# vocabulary difference equal to one (imagegptmodel?) or zero (all other models)
vocab_diff = vocab_size - logits[0].shape[-1]
self.assertTrue(vocab_diff in [0, 1])
self.assertListEqual([vocab_size - score.shape[-1] for score in logits], [vocab_diff] * len(logits))
def _check_attentions_for_generate(
self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
):
self.assertIsInstance(attentions, tuple)
self.assertListEqual(
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
)
self.assertEqual(len(attentions), (output_length - prompt_length))
use_cache = decoder_past_key_values is not None
has_static_cache = isinstance(decoder_past_key_values, StaticCache)
# When `output_attentions=True`, each iteration of generate appends the attentions corresponding to the new
# token(s)
for generated_length, iter_attentions in enumerate(attentions):
# regardless of using cache, the first forward pass will have the full prompt as input
if use_cache and generated_length > 0:
model_input_length = 1
else:
model_input_length = prompt_length + generated_length
query_length = (
prompt_length + generated_length
if not has_static_cache
else decoder_past_key_values.get_max_cache_shape()
)
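            # With a StaticCache the attention weights span the full pre-allocated cache length, so the last
            # dimension matches get_max_cache_shape() instead of the number of tokens processed so far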
expected_shape = (
batch_size,
config.decoder_config.num_attention_heads, # Decoder config
model_input_length,
query_length,
)
# check attn size
self.assertListEqual(
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
)
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length):
# Encoder config
encoder_expected_shape = (batch_size, config.encoder_config.num_attention_heads, prompt_length, prompt_length)
self.assertIsInstance(attentions, tuple)
self.assertListEqual(
[layer_attentions.shape for layer_attentions in attentions],
[encoder_expected_shape] * len(attentions),
)
def _check_hidden_states_for_generate(
self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False
):
self.assertIsInstance(hidden_states, tuple)
self.assertListEqual(
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
[True] * len(hidden_states),
)
self.assertEqual(len(hidden_states), (output_length - prompt_length))
# When `output_hidden_states=True`, each iteration of generate appends the hidden states corresponding to the
# new token(s)
for generated_length, iter_hidden_states in enumerate(hidden_states):
# regardless of using cache, the first forward pass will have the full prompt as input
if use_cache and generated_length > 0:
model_input_length = 1
else:
model_input_length = prompt_length + generated_length
# check hidden size
# we can have different hidden sizes between encoder and decoder --> check both
expected_shape_encoder = (batch_size, model_input_length, config.encoder_config.hidden_size)
expected_shape_decoder = (batch_size, model_input_length, config.decoder_config.hidden_size)
self.assertTrue(
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states]
== [expected_shape_encoder] * len(iter_hidden_states)
or [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states]
== [expected_shape_decoder] * len(iter_hidden_states)
)
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, prompt_length):
# Encoder config
encoder_expected_shape = (batch_size, prompt_length, config.encoder_config.hidden_size)
self.assertIsInstance(hidden_states, tuple)
self.assertListEqual(
[layer_hidden_states.shape for layer_hidden_states in hidden_states],
[encoder_expected_shape] * len(hidden_states),
)
def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config):
self.assertIsInstance(decoder_past_key_values, (tuple, Cache))
# we need the decoder config here
config = config.decoder_config
# (batch, head, seq_length, head_features)
expected_shape = (
batch_size,
config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads,
cache_length,
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads,
)
if isinstance(decoder_past_key_values, Cache):
self.assertListEqual(
[key_tensor.shape for key_tensor in decoder_past_key_values.key_cache],
[expected_shape] * len(decoder_past_key_values.key_cache),
)
self.assertListEqual(
[value_tensor.shape for value_tensor in decoder_past_key_values.value_cache],
[expected_shape] * len(decoder_past_key_values.value_cache),
)
def _check_scores(self, batch_size, scores, generated_length, config):
# Special case where Dia keeps score in a 2D mesh of (bsz * channels, vocab)
vocab_size = config.decoder_config.vocab_size
expected_shape = (batch_size * len(config.delay_pattern), vocab_size)
self.assertIsInstance(scores, tuple)
self.assertEqual(len(scores), generated_length)
self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))
@require_torch_sdpa
def test_sdpa_can_dispatch_composite_models(self):
"""
        Overwritten as the common test relies on hardcoded module names; here we check our composite case specifically.
"""
for model_class in self.all_model_classes:
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model = model_class.from_pretrained(tmpdirname)
sub_models_supporting_sdpa = [
(module._supports_sdpa or module._supports_attention_backend)
for name, module in model.named_modules()
if isinstance(module, PreTrainedModel) and name != ""
]
supports_sdpa_all_modules = (
all(sub_models_supporting_sdpa)
if len(sub_models_supporting_sdpa) > 0
else (model._supports_sdpa or model._supports_attention_backend)
)
if not supports_sdpa_all_modules:
with self.assertRaises(ValueError):
model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa")
else:
model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa")
for key in model_sdpa.config:
if isinstance(getattr(model_sdpa.config, key), PretrainedConfig):
sub_config = getattr(model_sdpa.config, key)
self.assertTrue(sub_config._attn_implementation == "sdpa")
@pytest.mark.generate
@unittest.skip(reason="Custom processor `DiaEOSDelayPatternLogitsProcessor` forces eos token.")
def test_generate_continue_from_past_key_values(self):
"""Only a small change due to the expected shapes"""
# Tests that we can continue generating from past key values, returned from a previous `generate` call
for model_class in self.all_generative_model_classes:
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
# Let's make it always:
# 1. use cache (for obvious reasons)
            # 2. generate to max length (which can be achieved by setting the eos token to an invalid value);
            #    stopping at an EOS would otherwise make the test flaky (e.g. EOS is generated on iteration 1 on both
            #    generations, but the continuation would force it to generate beyond an EOS token)
# 3. ignore `token_type_ids` for simplicity
# 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is
# active by default on some models
# 5. ignore `encoder_no_repeat_ngram_size`, which is set by default in some encoder-decoder models. When
# we use their decoder as a stand-alone model, `encoder_no_repeat_ngram_size` actually prevents
# repetition exclusively from the prompt. This test relies on comparing one call vs 2 calls
# with cache, what is considered a prompt is different in the two cases.
if "token_type_ids" in inputs:
del inputs["token_type_ids"]
model = model_class(config).to(torch_device)
model.eval()
generate_kwargs = {
"pad_token_id": -1,
"eos_token_id": -1,
"forced_eos_token_id": None,
"encoder_no_repeat_ngram_size": 0,
"use_cache": True,
"do_sample": False,
"return_dict_in_generate": True,
"output_scores": True,
}
# Traditional way of generating text, with `return_dict_in_generate` to return the past key values
outputs = model.generate(**inputs, **generate_kwargs, max_new_tokens=4)
# Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
# inputs may need to be tweaked across `generate` calls (like the attention mask).
outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=3)
# Continue from the tokens generated above, preparing the inputs accordingly
inputs["past_key_values"] = outputs_cached.past_key_values
new_attention_len = outputs_cached.sequences.shape[1] # the only real modification in this test
inputs["decoder_input_ids"] = outputs_cached.sequences
if "decoder_attention_mask" in inputs:
inputs["decoder_attention_mask"] = torch.nn.functional.pad(
inputs["decoder_attention_mask"],
(0, new_attention_len - inputs["decoder_attention_mask"].shape[1]),
mode="constant",
value=1,
)
first_caches_scores = outputs_cached.scores
outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=1)
full_cached_scores = first_caches_scores + outputs_cached.scores
outputs_cached.scores = full_cached_scores
# The two sets of generated text and past kv should be equal to each other
self._check_similar_generate_outputs(outputs, outputs_cached)
for layer_idx in range(len(outputs_cached.past_key_values)):
for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
self.assertTrue(
torch.allclose(
outputs.past_key_values[layer_idx][kv_idx],
outputs_cached.past_key_values[layer_idx][kv_idx],
)
)
@unittest.skip(reason="Indirectly checked in Dia through the generate methods.")
def test_past_key_values_format(self, custom_all_cache_shapes=None):
pass
@unittest.skip(reason="Indirectly checked in Dia through the generate methods.")
def test_hidden_states_output(self):
pass
@unittest.skip(
reason="Dia has too many mixed embedding types which would cause unintentional side effects, e.g. attempts at tying embeddings"
)
def test_model_get_set_embeddings(self):
pass
@unittest.skip(reason="Theoretically works but kernel library causes issues.")
def test_torchscript_output_hidden_state(self):
pass
@unittest.skip(reason="Theoretically works but kernel library causes issues.")
def test_torchscript_simple(self):
pass
@unittest.skip(reason="Encoder-Decoder cache can not be initialized.")
def test_multi_gpu_data_parallel_forward(self):
pass
class DiaForConditionalGenerationIntegrationTest(unittest.TestCase):
"""
    See https://gist.github.com/vasqu/0e3b06360373a4e612aa3b9a7c09185e for how these integration tests were generated.

    NOTE: We add a single `eos` line for the last channel, which is skipped in the original Dia
    (it doesn't change the behaviour as we cut at the eos token position).
"""
def setUp(self):
# it's a dummy ckpt but should suffice for testing purposes
self.model_checkpoint = "AntonV/Dia-1.6B"
self.sampling_rate = 44100
# prepare audio
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=self.sampling_rate))
audio_sample_1 = librispeech_dummy[-1]["audio"]["array"]
audio_sample_2 = librispeech_dummy[-2]["audio"]["array"]
        # 10 and 5 codebook frames as prefix - written to disk since the original Dia implementation loads audio from files
dac_chunk_len = 512
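        # 512 matches the DAC hop length (one codebook frame per 512 waveform samples), so the prefixes
        # below cover 10 and 5 frames respectively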
self.audio_prompt_1_path = "/tmp/dia_test_sample_1.mp3"
self.audio_prompt_2_path = "/tmp/dia_test_sample_2.mp3"
sf.write(self.audio_prompt_1_path, audio_sample_1[: (dac_chunk_len * 10)], self.sampling_rate)
sf.write(self.audio_prompt_2_path, audio_sample_2[: (dac_chunk_len * 5)], self.sampling_rate)
def tearDown(self):
pathlib.Path(self.audio_prompt_1_path).unlink()
pathlib.Path(self.audio_prompt_2_path).unlink()
cleanup(torch_device, gc_collect=True)
@slow
@require_torch_accelerator
def test_dia_model_integration_generate_tts(self):
text = ["[S1] Dia is an open weights text to dialogue model.", "This is a test"]
processor = DiaProcessor.from_pretrained(self.model_checkpoint)
inputs = processor(text=text, padding=True, return_tensors="pt").to(torch_device)
model = DiaForConditionalGeneration.from_pretrained(self.model_checkpoint).to(torch_device)
outputs = model.generate(**inputs, max_new_tokens=32, do_sample=False)
# fmt: off
EXPECTED_OUTPUT_TOKENS = torch.tensor([[[1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 778, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 778, 338, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 804, 10, 524, 1026, 1026, 1026, 1026, 1026],
[ 568, 804, 10, 674, 967, 1026, 1026, 1026, 1026],
[ 568, 804, 10, 674, 364, 360, 1026, 1026, 1026],
[ 568, 804, 10, 674, 364, 981, 728, 1026, 1026],
[ 568, 804, 10, 674, 364, 981, 741, 550, 1026],
[ 568, 804, 10, 674, 364, 981, 568, 378, 90],
[1024, 804, 10, 674, 364, 981, 568, 378, 731],
[1025, 804, 10, 674, 364, 981, 568, 378, 731],
[1025, 804, 10, 674, 364, 981, 568, 378, 731],
[1025, 804, 10, 674, 364, 981, 568, 378, 731],
[1025, 804, 10, 674, 364, 981, 568, 378, 731],
[1025, 804, 10, 674, 364, 981, 568, 378, 731],
[1025, 804, 10, 674, 364, 981, 568, 378, 731],
[1025, 804, 10, 674, 364, 981, 568, 378, 731],
[1025, 1024, 10, 674, 364, 981, 568, 378, 731],
[1025, 1025, 1024, 674, 364, 981, 568, 378, 731],
[1025, 1025, 1025, 1024, 364, 981, 568, 378, 731],
[1025, 1025, 1025, 1025, 1024, 981, 568, 378, 731],
[1025, 1025, 1025, 1025, 1025, 1024, 568, 378, 731],
[1025, 1025, 1025, 1025, 1025, 1025, 1024, 378, 731],
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024, 731],
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024]],
[[1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 698, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 592, 778, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 592, 778, 338, 1026, 1026, 1026, 1026, 1026, 1026],
[ 592, 697, 10, 524, 1026, 1026, 1026, 1026, 1026],
[ 592, 288, 476, 649, 967, 1026, 1026, 1026, 1026],
[ 592, 740, 386, 674, 364, 360, 1026, 1026, 1026],
[ 592, 402, 386, 347, 362, 981, 728, 1026, 1026],
[ 592, 402, 721, 728, 327, 981, 741, 550, 1026],
[ 592, 402, 721, 728, 460, 62, 676, 378, 90],
[1024, 402, 721, 728, 837, 595, 195, 982, 784],
[1025, 402, 721, 677, 497, 102, 692, 24, 330],
[1025, 402, 721, 677, 511, 102, 503, 871, 609],
[1025, 402, 721, 677, 511, 96, 801, 871, 894],
[1025, 402, 721, 677, 511, 745, 314, 498, 775],
[1025, 402, 721, 677, 511, 745, 314, 498, 105],
[1025, 402, 721, 677, 511, 745, 314, 861, 889],
[1025, 893, 721, 677, 511, 744, 314, 871, 353],
[1025, 1024, 888, 677, 511, 744, 314, 871, 332],
[1025, 1025, 1024, 518, 511, 744, 314, 871, 366],
[1025, 1025, 1025, 1024, 611, 744, 314, 871, 366],
[1025, 1025, 1025, 1025, 1024, 980, 314, 871, 366],
[1025, 1025, 1025, 1025, 1025, 1024, 45, 124, 366],
[1025, 1025, 1025, 1025, 1025, 1025, 1024, 871, 366],
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024, 719],
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024]]])
# fmt: on
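        # 1024 / 1025 / 1026 are presumably the audio eos / pad / bos ids; note how the 1024 entries trace
        # the per-channel eos positions shifted by the delay pattern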
torch.testing.assert_close(outputs.cpu(), EXPECTED_OUTPUT_TOKENS)
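        # To listen to the result locally one could decode the tokens back to a waveform, e.g.
        # (a sketch, assuming the processor exposes the usual `batch_decode` / `save_audio` helpers):
        #   audio = processor.batch_decode(outputs)
        #   processor.save_audio(audio, "dia_sample.wav")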
@slow
@require_torch_accelerator
def test_dia_model_integration_generate_audio_context(self):
text = ["[S1] Dia is an open weights text to dialogue model.", "This is a test"]
audio_sample_1 = torchaudio.load(self.audio_prompt_1_path, channels_first=True)[0].squeeze().numpy()
audio_sample_2 = torchaudio.load(self.audio_prompt_2_path, channels_first=True)[0].squeeze().numpy()
audio = [audio_sample_1, audio_sample_2]
processor = DiaProcessor.from_pretrained(self.model_checkpoint)
inputs = processor(text=text, audio=audio, padding=True, return_tensors="pt").to(torch_device)
model = DiaForConditionalGeneration.from_pretrained(self.model_checkpoint).to(torch_device)
        # Dia pads on the right while we pad on the left (for faster prefill); additionally we use
        # `max_new_tokens` whereas Dia uses a total max length, hence each output is compared in its respective setting
outputs_1 = model.generate(**inputs, max_new_tokens=22, do_sample=False)
outputs_2 = model.generate(**inputs, max_new_tokens=27, do_sample=False)
# fmt: off
EXPECTED_OUTPUT_TOKENS_1 = torch.tensor([[1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 578, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 494, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 330, 501, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 330, 204, 34, 1026, 1026, 1026, 1026, 1026, 1026],
[ 330, 254, 915, 863, 1026, 1026, 1026, 1026, 1026],
[ 330, 215, 458, 313, 50, 1026, 1026, 1026, 1026],
[ 330, 615, 529, 216, 801, 237, 1026, 1026, 1026],
[ 330, 580, 563, 233, 337, 37, 1018, 1026, 1026],
[ 330, 567, 530, 753, 607, 179, 954, 242, 1026],
[ 330, 627, 6, 1010, 500, 189, 598, 858, 247],
[1024, 432, 480, 530, 122, 3, 788, 149, 814],
[1025, 875, 826, 458, 98, 540, 181, 122, 608],
[1025, 495, 840, 413, 337, 784, 591, 150, 1017],
[1025, 808, 189, 137, 445, 0, 227, 658, 345],
[1025, 397, 89, 753, 1016, 173, 984, 0, 910],
[1025, 875, 460, 934, 50, 335, 670, 818, 722],
[1025, 875, 460, 762, 119, 372, 503, 858, 584],
[1025, 348, 555, 475, 469, 458, 963, 41, 664],
[1025, 1024, 852, 683, 761, 193, 595, 895, 885],
[1025, 1025, 1024, 135, 761, 902, 163, 623, 385],
[1025, 1025, 1025, 1024, 852, 282, 581, 623, 70],
[1025, 1025, 1025, 1025, 1024, 41, 661, 790, 977],
[1025, 1025, 1025, 1025, 1025, 1024, 580, 401, 464],
[1025, 1025, 1025, 1025, 1025, 1025, 1024, 756, 61],
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024, 752],
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024]])
EXPECTED_OUTPUT_TOKENS_2 = torch.tensor([[1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 619, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 315, 968, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
[ 315, 1007, 458, 1026, 1026, 1026, 1026, 1026, 1026],
[ 315, 35, 266, 68, 1026, 1026, 1026, 1026, 1026],
[ 315, 359, 285, 811, 154, 1026, 1026, 1026, 1026],
[ 315, 906, 407, 297, 785, 649, 1026, 1026, 1026],
[ 315, 249, 678, 868, 899, 257, 950, 1026, 1026],
[ 315, 249, 217, 471, 292, 908, 196, 469, 1026],
[ 315, 249, 825, 771, 839, 802, 633, 590, 531],
[1024, 249, 150, 53, 126, 76, 794, 626, 442],
[1025, 249, 825, 218, 359, 864, 526, 626, 770],
[1025, 249, 150, 137, 530, 845, 877, 600, 111],
[1025, 249, 150, 287, 730, 991, 135, 259, 39],
[1025, 249, 825, 104, 198, 1020, 719, 625, 208],
[1025, 249, 825, 997, 602, 256, 859, 322, 518],
[1025, 668, 825, 979, 584, 256, 98, 665, 589],
[1025, 954, 458, 54, 206, 52, 244, 822, 599],
[1025, 1024, 104, 914, 435, 579, 860, 92, 661],
[1025, 1025, 1024, 848, 126, 74, 304, 92, 753],
[1025, 1025, 1025, 1024, 362, 376, 304, 586, 753],
[1025, 1025, 1025, 1025, 1024, 633, 996, 586, 83],
[1025, 1025, 1025, 1025, 1025, 1024, 179, 898, 928],
[1025, 1025, 1025, 1025, 1025, 1025, 1024, 506, 102],
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024, 79],
[1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024]])
# fmt: on
torch.testing.assert_close(outputs_1[0].cpu(), EXPECTED_OUTPUT_TOKENS_1)
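        # Sample 2 has a 5-frame shorter audio prefix (5 vs 10 DAC chunks), so with left padding its first
        # 5 positions are padding and are skipped in the comparison below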
torch.testing.assert_close(outputs_2[1, 5:].cpu(), EXPECTED_OUTPUT_TOKENS_2) # left padding