# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Dia model."""

import copy
import pathlib
import tempfile
import unittest

import pytest

from transformers.models.dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig
from transformers.testing_utils import (
    cleanup,
    is_flaky,
    require_torch,
    require_torch_accelerator,
    require_torch_sdpa,
    slow,
    torch_device,
)
from transformers.utils import is_soundfile_available, is_torch_available, is_torchaudio_available
from transformers.utils.import_utils import is_datasets_available

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


if is_datasets_available():
    from datasets import Audio, load_dataset

if is_torch_available():
    import torch

    from transformers import (
        DiaForConditionalGeneration,
        DiaModel,
        DiaProcessor,
        PretrainedConfig,
        PreTrainedModel,
    )
    from transformers.cache_utils import (
        Cache,
        StaticCache,
    )
    from transformers.models.dia.modeling_dia import DiaDecoder, DiaEncoder

if is_torchaudio_available():
    import torchaudio

if is_soundfile_available():
    import soundfile as sf


@require_torch
class DiaModelTester:
    def __init__(
        self,
        parent,
        batch_size=3,  # need batch_size != num_hidden_layers
        seq_length=7,
        max_length=50,
        is_training=True,
        vocab_size=100,
        hidden_size=16,
        intermediate_size=37,
        num_hidden_layers=2,
        num_attention_heads=2,
        head_dim=8,
        decoder_hidden_size=32,  # typically larger than encoder
        hidden_act="silu",
        eos_token_id=97,  # special tokens all occur after eos
        pad_token_id=98,
        bos_token_id=99,
        delay_pattern=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.max_length = max_length
        self.is_training = is_training
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.decoder_hidden_size = decoder_hidden_size
        self.hidden_act = hidden_act
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        # Set default delay pattern if not provided
        self.delay_pattern = delay_pattern if delay_pattern is not None else [0, 1, 2]
        self.num_channels = len(self.delay_pattern)
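        # NOTE: the number of audio channels (codebooks) is derived from the delay pattern,
        # one channel per entry, each offset by its delay during generation.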

    def get_config(self):
        encoder_config = DiaEncoderConfig(
            max_position_embeddings=self.max_length,
            num_hidden_layers=self.num_hidden_layers,
            hidden_size=self.hidden_size,
            num_attention_heads=self.num_attention_heads,
            num_key_value_heads=self.num_attention_heads,  # same as num_attention_heads for testing
            head_dim=self.head_dim,
            intermediate_size=self.intermediate_size,
            vocab_size=self.vocab_size,
            hidden_act=self.hidden_act,
        )

        decoder_config = DiaDecoderConfig(
            max_position_embeddings=self.max_length,
            num_hidden_layers=self.num_hidden_layers,
            hidden_size=self.decoder_hidden_size,
            intermediate_size=self.intermediate_size,
            num_attention_heads=self.num_attention_heads,
            num_key_value_heads=1,  # GQA
            head_dim=self.head_dim,
            cross_num_attention_heads=self.num_attention_heads,
            cross_head_dim=self.head_dim,
            cross_num_key_value_heads=1,  # GQA
            cross_hidden_size=self.hidden_size,  # match encoder hidden size
            vocab_size=self.vocab_size,
            hidden_act=self.hidden_act,
            num_channels=self.num_channels,
        )

        config = DiaConfig(
            encoder_config=encoder_config,
            decoder_config=decoder_config,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            bos_token_id=self.bos_token_id,
            delay_pattern=self.delay_pattern,
        )

        return config

    def prepare_config_and_inputs(self) -> tuple[DiaConfig, dict]:
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
        attention_mask = input_ids.ne(self.pad_token_id)

        decoder_input_ids = ids_tensor([self.batch_size, self.seq_length, self.num_channels], self.vocab_size)
        decoder_attention_mask = decoder_input_ids[..., 0].ne(self.pad_token_id)
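        # NOTE: decoder inputs are 3D audio codes of shape (batch, seq_len, num_channels), while the
        # encoder consumes 2D text ids; the decoder padding mask is derived from the first channel only.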

        config = self.get_config()
        inputs_dict = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
            "decoder_attention_mask": decoder_attention_mask,
        }
        return config, inputs_dict

    def prepare_config_and_inputs_for_common(self) -> tuple[DiaConfig, dict]:
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict

    def create_and_check_model_forward(self, config, inputs_dict):
        model = DiaModel(config=config).to(torch_device).eval()

        input_ids = inputs_dict["input_ids"]
        decoder_input_ids = inputs_dict["decoder_input_ids"]

        # first forward pass
        last_hidden_state = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state

        self.parent.assertEqual(
            last_hidden_state.shape, (self.batch_size, self.seq_length, config.decoder_config.hidden_size)
        )

    def check_encoder_decoder_model_standalone(self, config, inputs_dict):
        model = DiaModel(config=config).to(torch_device).eval()
        outputs = model(**inputs_dict)

        encoder_last_hidden_state = outputs.encoder_last_hidden_state
        last_hidden_state = outputs.last_hidden_state

        with tempfile.TemporaryDirectory() as tmpdirname:
            encoder = model.get_encoder()
            encoder.save_pretrained(tmpdirname)
            encoder = DiaEncoder.from_pretrained(tmpdirname).to(torch_device)

        encoder_last_hidden_state_2 = encoder(
            input_ids=inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"]
        )[0]

        self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 3e-3)

        with tempfile.TemporaryDirectory() as tmpdirname:
            decoder = model.get_decoder()
            decoder.save_pretrained(tmpdirname)
            decoder = DiaDecoder.from_pretrained(tmpdirname).to(torch_device)

        last_hidden_state_2 = decoder(
            input_ids=inputs_dict["decoder_input_ids"],
            attention_mask=inputs_dict["decoder_attention_mask"],
            encoder_hidden_states=encoder_last_hidden_state,
        )[0]

        self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 3e-3)


@require_torch
class DiaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (DiaModel, DiaForConditionalGeneration) if is_torch_available() else ()
    # We only allow greedy search / sampling with one sequence; see `skip_non_greedy_generate`
    all_generative_model_classes = (DiaForConditionalGeneration,)
    # TODO: support new pipeline behavior in tests
    pipeline_model_mapping = {}
    # pipeline_model_mapping = {"text-to-audio": DiaForConditionalGeneration} if is_torch_available() else {}
    test_pruning = False
    test_head_masking = False
    test_resize_embeddings = False
    is_encoder_decoder = True
    # Usually an indicator for VLMs, but many audio models are composite as well
    _is_composite = True

    def setUp(self):
        self.model_tester = DiaModelTester(self)
        # Skipping `has_text_modality` but manually testing down below
        self.config_tester = ConfigTester(self, has_text_modality=False, config_class=DiaConfig)
        self.skip_non_greedy_generate()

    def skip_non_greedy_generate(self):
        skippable_tests = [
            "test_sample_generate_dict_output",  # return sequences > 1
            "test_beam",
            "test_group_beam",
            "test_constrained_beam",
            "test_contrastive",
            "test_assisted",
            "test_dola",
            "test_prompt_lookup",
            "test_model_parallel_beam_search",
            "test_generate_without_input_ids",
            "test_generate_with_head_masking",
        ]

        for test in skippable_tests:
            if self._testMethodName.startswith(test):
                self.skipTest(reason="Dia only supports greedy search / sampling with one sequence.")

    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
"""Overriden to account for the 2D flattened structure"""
        inputs_dict = copy.deepcopy(inputs_dict)

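        # Labels follow the flattened (batch_size * num_channels, seq_length) layout mentioned in the
        # docstring above, i.e. the channel dimension is folded into the batch dimension.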
        if return_labels:
            inputs_dict["labels"] = torch.ones(
                (
                    self.model_tester.batch_size * self.model_tester.num_channels,
                    self.model_tester.seq_length,
                ),
                dtype=torch.long,
                device=torch_device,
            )

        return inputs_dict

    def test_config(self):
        self.config_tester.run_common_tests()

        # Manual testing because of composite configs
        config = self.model_tester.prepare_config_and_inputs()[0]
        self.assertTrue(hasattr(config.encoder_config, "vocab_size"), msg="Encoder `vocab_size` does not exist")
        self.assertTrue(hasattr(config.decoder_config, "vocab_size"), msg="Decoder `vocab_size` does not exist")

    def test_model_forward(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model_forward(*config_and_inputs)

    @is_flaky
    def test_encoder_decoder_model_standalone(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
        self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)

    # Overriding shape checks as Dia has different shapes on encoder/decoder using a composite config
    # + additional special cases where 3D x 2D meshes confuse the expected shape
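    # (generation scores/logits come back flattened to (batch_size * num_channels, vocab), hence the rescaled batch size)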
    def _check_logits(self, batch_size, logits, config):
        batch_size *= len(config.delay_pattern)  # Account for flattening
        vocab_size = config.decoder_config.vocab_size
        self.assertIsInstance(logits, tuple)
        self.assertListEqual([iter_logits.shape[0] for iter_logits in logits], [batch_size] * len(logits))
        # vocabulary difference equal to one (imagegptmodel?) or zero (all other models)
        vocab_diff = vocab_size - logits[0].shape[-1]
        self.assertTrue(vocab_diff in [0, 1])
        self.assertListEqual([vocab_size - score.shape[-1] for score in logits], [vocab_diff] * len(logits))

    def _check_attentions_for_generate(
        self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
    ):
        self.assertIsInstance(attentions, tuple)
        self.assertListEqual(
            [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
        )
        self.assertEqual(len(attentions), (output_length - prompt_length))

        use_cache = decoder_past_key_values is not None
        has_static_cache = isinstance(decoder_past_key_values, StaticCache)

        # When `output_attentions=True`, each iteration of generate appends the attentions corresponding to the new
        # token(s)
        for generated_length, iter_attentions in enumerate(attentions):
            # regardless of using cache, the first forward pass will have the full prompt as input
            if use_cache and generated_length > 0:
                model_input_length = 1
            else:
                model_input_length = prompt_length + generated_length
            query_length = (
                prompt_length + generated_length
                if not has_static_cache
                else decoder_past_key_values.get_max_cache_shape()
            )

            expected_shape = (
                batch_size,
                config.decoder_config.num_attention_heads,  # Decoder config
                model_input_length,
                query_length,
            )
            # check attn size
            self.assertListEqual(
                [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
            )

    def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length):
        # Encoder config
        encoder_expected_shape = (batch_size, config.encoder_config.num_attention_heads, prompt_length, prompt_length)
        self.assertIsInstance(attentions, tuple)
        self.assertListEqual(
            [layer_attentions.shape for layer_attentions in attentions],
            [encoder_expected_shape] * len(attentions),
        )

    def _check_hidden_states_for_generate(
        self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False
    ):
        self.assertIsInstance(hidden_states, tuple)
        self.assertListEqual(
            [isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
            [True] * len(hidden_states),
        )
        self.assertEqual(len(hidden_states), (output_length - prompt_length))

        # When `output_hidden_states=True`, each iteration of generate appends the hidden states corresponding to the
        # new token(s)
        for generated_length, iter_hidden_states in enumerate(hidden_states):
            # regardless of using cache, the first forward pass will have the full prompt as input
            if use_cache and generated_length > 0:
                model_input_length = 1
            else:
                model_input_length = prompt_length + generated_length

            # check hidden size
            # we can have different hidden sizes between encoder and decoder --> check both
            expected_shape_encoder = (batch_size, model_input_length, config.encoder_config.hidden_size)
            expected_shape_decoder = (batch_size, model_input_length, config.decoder_config.hidden_size)
            self.assertTrue(
                [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states]
                == [expected_shape_encoder] * len(iter_hidden_states)
                or [layer_hidden_states.shape for layer_hidden_states in iter_hidden_states]
                == [expected_shape_decoder] * len(iter_hidden_states)
            )

    def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, prompt_length):
        # Encoder config
        encoder_expected_shape = (batch_size, prompt_length, config.encoder_config.hidden_size)
        self.assertIsInstance(hidden_states, tuple)
        self.assertListEqual(
            [layer_hidden_states.shape for layer_hidden_states in hidden_states],
            [encoder_expected_shape] * len(hidden_states),
        )

    def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config):
        self.assertIsInstance(decoder_past_key_values, (tuple, Cache))

        # we need the decoder config here
        config = config.decoder_config

        # (batch, head, seq_length, head_features)
        expected_shape = (
            batch_size,
            config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads,
            cache_length,
            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads,
        )

        if isinstance(decoder_past_key_values, Cache):
            self.assertListEqual(
                [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache],
                [expected_shape] * len(decoder_past_key_values.key_cache),
            )
            self.assertListEqual(
                [value_tensor.shape for value_tensor in decoder_past_key_values.value_cache],
                [expected_shape] * len(decoder_past_key_values.value_cache),
            )

    def _check_scores(self, batch_size, scores, generated_length, config):
        # Special case: Dia keeps scores in a 2D mesh of (bsz * channels, vocab)
        vocab_size = config.decoder_config.vocab_size
        expected_shape = (batch_size * len(config.delay_pattern), vocab_size)
        self.assertIsInstance(scores, tuple)
        self.assertEqual(len(scores), generated_length)
        self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))

    @require_torch_sdpa
    def test_sdpa_can_dispatch_composite_models(self):
        """
        Overwritten as the base test relies on hardcoded namings atm; here we check our specific composite case
        """
        for model_class in self.all_model_classes:
            config, _ = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model = model_class.from_pretrained(tmpdirname)

                sub_models_supporting_sdpa = [
                    (module._supports_sdpa or module._supports_attention_backend)
                    for name, module in model.named_modules()
                    if isinstance(module, PreTrainedModel) and name != ""
                ]
                supports_sdpa_all_modules = (
                    all(sub_models_supporting_sdpa)
                    if len(sub_models_supporting_sdpa) > 0
                    else (model._supports_sdpa or model._supports_attention_backend)
                )

                if not supports_sdpa_all_modules:
                    with self.assertRaises(ValueError):
                        model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa")
                else:
                    model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa")
                    for key in model_sdpa.config:
                        if isinstance(getattr(model_sdpa.config, key), PretrainedConfig):
                            sub_config = getattr(model_sdpa.config, key)
                            self.assertTrue(sub_config._attn_implementation == "sdpa")

    @pytest.mark.generate
    @unittest.skip(reason="Custom processor `DiaEOSDelayPatternLogitsProcessor` forces eos token.")
    def test_generate_continue_from_past_key_values(self):
        """Only a small change due to the expected shapes"""
        # Tests that we can continue generating from past key values, returned from a previous `generate` call
        for model_class in self.all_generative_model_classes:
            config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

            # Let's make it always:
            # 1. use cache (for obvious reasons)
            # 2. generate to max length (which can be achieved by setting the eos token to an invalid value); otherwise
            #    the test would be flaky (e.g. EOS is generated on iteration 1 on both generations, but the
            #    continuation would force it to generate beyond an EOS token)
            # 3. ignore `token_type_ids` for simplicity
            # 4. ignore `forced_eos_token_id`, which requires further manipulation of the continuation inputs and is
            #    active by default on some models
            # 5. ignore `encoder_no_repeat_ngram_size`, which is set by default in some encoder-decoder models. When
            #    we use their decoder as a stand-alone model, `encoder_no_repeat_ngram_size` actually prevents
            #    repetition exclusively from the prompt. This test relies on comparing one call vs 2 calls with
            #    cache, and what is considered a prompt is different in the two cases.

if "token_type_ids" in inputs:
|
|
del inputs["token_type_ids"]
|
|
|
|
model = model_class(config).to(torch_device)
|
|
model.eval()
|
|
|
|
generate_kwargs = {
|
|
"pad_token_id": -1,
|
|
"eos_token_id": -1,
|
|
"forced_eos_token_id": None,
|
|
"encoder_no_repeat_ngram_size": 0,
|
|
"use_cache": True,
|
|
"do_sample": False,
|
|
"return_dict_in_generate": True,
|
|
"output_scores": True,
|
|
}
|
|
|
|
# Traditional way of generating text, with `return_dict_in_generate` to return the past key values
|
|
outputs = model.generate(**inputs, **generate_kwargs, max_new_tokens=4)
|
|
|
|
# Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
|
|
# inputs may need to be tweaked across `generate` calls (like the attention mask).
|
|
outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=3)
|
|
|
|
# Continue from the tokens generated above, preparing the inputs accordingly
|
|
inputs["past_key_values"] = outputs_cached.past_key_values
|
|
new_attention_len = outputs_cached.sequences.shape[1] # the only real modification in this test
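            # (Dia sequences are 3D (batch, seq_len, channels), so the length lives in dim 1 rather than the last dim)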
            inputs["decoder_input_ids"] = outputs_cached.sequences
            if "decoder_attention_mask" in inputs:
                inputs["decoder_attention_mask"] = torch.nn.functional.pad(
                    inputs["decoder_attention_mask"],
                    (0, new_attention_len - inputs["decoder_attention_mask"].shape[1]),
                    mode="constant",
                    value=1,
                )

            first_caches_scores = outputs_cached.scores
            outputs_cached = model.generate(**inputs, **generate_kwargs, max_new_tokens=1)
            full_cached_scores = first_caches_scores + outputs_cached.scores
            outputs_cached.scores = full_cached_scores

            # The two sets of generated text and past kv should be equal to each other
            self._check_similar_generate_outputs(outputs, outputs_cached)
            for layer_idx in range(len(outputs_cached.past_key_values)):
                for kv_idx in range(len(outputs_cached.past_key_values[layer_idx])):
                    self.assertTrue(
                        torch.allclose(
                            outputs.past_key_values[layer_idx][kv_idx],
                            outputs_cached.past_key_values[layer_idx][kv_idx],
                        )
                    )

    @unittest.skip(reason="Indirectly checked in Dia through the generate methods.")
    def test_past_key_values_format(self, custom_all_cache_shapes=None):
        pass

    @unittest.skip(reason="Indirectly checked in Dia through the generate methods.")
    def test_hidden_states_output(self):
        pass

    @unittest.skip(
        reason="Dia has too many mixed embedding types which would cause unintentional side effects, e.g. attempts at tying embeddings"
    )
    def test_model_get_set_embeddings(self):
        pass

    @unittest.skip(reason="Theoretically works but kernel library causes issues.")
    def test_torchscript_output_hidden_state(self):
        pass

    @unittest.skip(reason="Theoretically works but kernel library causes issues.")
    def test_torchscript_simple(self):
        pass

    @unittest.skip(reason="Encoder-Decoder cache can not be initialized.")
    def test_multi_gpu_data_parallel_forward(self):
        pass


class DiaForConditionalGenerationIntegrationTest(unittest.TestCase):
    """
    See https://gist.github.com/vasqu/0e3b06360373a4e612aa3b9a7c09185e for generating the integration tests

    NOTE: We add a single `eos` line for the last channel which is skipped in the original Dia
    (It doesn't change the behaviour as we cut by the eos token position)
    """

    def setUp(self):
        # it's a dummy ckpt but should suffice for testing purposes
        self.model_checkpoint = "AntonV/Dia-1.6B"
        self.sampling_rate = 44100

        # prepare audio
        librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=self.sampling_rate))
        audio_sample_1 = librispeech_dummy[-1]["audio"]["array"]
        audio_sample_2 = librispeech_dummy[-2]["audio"]["array"]
        # 10 and 5 codebooks as prefix - saved as files as we need wav files for the original Dia
        dac_chunk_len = 512
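        # each DAC code frame corresponds to `dac_chunk_len` raw samples, so the slices below keep
        # exactly 10 and 5 frames of audio prefix respectively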
        self.audio_prompt_1_path = "/tmp/dia_test_sample_1.mp3"
        self.audio_prompt_2_path = "/tmp/dia_test_sample_2.mp3"
        sf.write(self.audio_prompt_1_path, audio_sample_1[: (dac_chunk_len * 10)], self.sampling_rate)
        sf.write(self.audio_prompt_2_path, audio_sample_2[: (dac_chunk_len * 5)], self.sampling_rate)

    def tearDown(self):
        pathlib.Path(self.audio_prompt_1_path).unlink()
        pathlib.Path(self.audio_prompt_2_path).unlink()
        cleanup(torch_device, gc_collect=True)

    @slow
    @require_torch_accelerator
    def test_dia_model_integration_generate_tts(self):
        text = ["[S1] Dia is an open weights text to dialogue model.", "This is a test"]
        processor = DiaProcessor.from_pretrained(self.model_checkpoint)
        inputs = processor(text=text, padding=True, return_tensors="pt").to(torch_device)

        model = DiaForConditionalGeneration.from_pretrained(self.model_checkpoint).to(torch_device)
        outputs = model.generate(**inputs, max_new_tokens=32, do_sample=False)
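        # generate returns 3D audio codes of shape (batch, frames, num_channels); 9 channels here,
        # matching the expected tokens below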

        # fmt: off
        EXPECTED_OUTPUT_TOKENS = torch.tensor([[[1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 568, 778, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 568, 778, 338, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 568, 804, 10, 524, 1026, 1026, 1026, 1026, 1026],
            [ 568, 804, 10, 674, 967, 1026, 1026, 1026, 1026],
            [ 568, 804, 10, 674, 364, 360, 1026, 1026, 1026],
            [ 568, 804, 10, 674, 364, 981, 728, 1026, 1026],
            [ 568, 804, 10, 674, 364, 981, 741, 550, 1026],
            [ 568, 804, 10, 674, 364, 981, 568, 378, 90],
            [1024, 804, 10, 674, 364, 981, 568, 378, 731],
            [1025, 804, 10, 674, 364, 981, 568, 378, 731],
            [1025, 804, 10, 674, 364, 981, 568, 378, 731],
            [1025, 804, 10, 674, 364, 981, 568, 378, 731],
            [1025, 804, 10, 674, 364, 981, 568, 378, 731],
            [1025, 804, 10, 674, 364, 981, 568, 378, 731],
            [1025, 804, 10, 674, 364, 981, 568, 378, 731],
            [1025, 804, 10, 674, 364, 981, 568, 378, 731],
            [1025, 1024, 10, 674, 364, 981, 568, 378, 731],
            [1025, 1025, 1024, 674, 364, 981, 568, 378, 731],
            [1025, 1025, 1025, 1024, 364, 981, 568, 378, 731],
            [1025, 1025, 1025, 1025, 1024, 981, 568, 378, 731],
            [1025, 1025, 1025, 1025, 1025, 1024, 568, 378, 731],
            [1025, 1025, 1025, 1025, 1025, 1025, 1024, 378, 731],
            [1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024, 731],
            [1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024]],

            [[1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 568, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 698, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 592, 778, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 592, 778, 338, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 592, 697, 10, 524, 1026, 1026, 1026, 1026, 1026],
            [ 592, 288, 476, 649, 967, 1026, 1026, 1026, 1026],
            [ 592, 740, 386, 674, 364, 360, 1026, 1026, 1026],
            [ 592, 402, 386, 347, 362, 981, 728, 1026, 1026],
            [ 592, 402, 721, 728, 327, 981, 741, 550, 1026],
            [ 592, 402, 721, 728, 460, 62, 676, 378, 90],
            [1024, 402, 721, 728, 837, 595, 195, 982, 784],
            [1025, 402, 721, 677, 497, 102, 692, 24, 330],
            [1025, 402, 721, 677, 511, 102, 503, 871, 609],
            [1025, 402, 721, 677, 511, 96, 801, 871, 894],
            [1025, 402, 721, 677, 511, 745, 314, 498, 775],
            [1025, 402, 721, 677, 511, 745, 314, 498, 105],
            [1025, 402, 721, 677, 511, 745, 314, 861, 889],
            [1025, 893, 721, 677, 511, 744, 314, 871, 353],
            [1025, 1024, 888, 677, 511, 744, 314, 871, 332],
            [1025, 1025, 1024, 518, 511, 744, 314, 871, 366],
            [1025, 1025, 1025, 1024, 611, 744, 314, 871, 366],
            [1025, 1025, 1025, 1025, 1024, 980, 314, 871, 366],
            [1025, 1025, 1025, 1025, 1025, 1024, 45, 124, 366],
            [1025, 1025, 1025, 1025, 1025, 1025, 1024, 871, 366],
            [1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024, 719],
            [1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024]]])
        # fmt: on

        torch.testing.assert_close(outputs.cpu(), EXPECTED_OUTPUT_TOKENS)

    @slow
    @require_torch_accelerator
    def test_dia_model_integration_generate_audio_context(self):
        text = ["[S1] Dia is an open weights text to dialogue model.", "This is a test"]
        audio_sample_1 = torchaudio.load(self.audio_prompt_1_path, channels_first=True)[0].squeeze().numpy()
        audio_sample_2 = torchaudio.load(self.audio_prompt_2_path, channels_first=True)[0].squeeze().numpy()
        audio = [audio_sample_1, audio_sample_2]

        processor = DiaProcessor.from_pretrained(self.model_checkpoint)
        inputs = processor(text=text, audio=audio, padding=True, return_tensors="pt").to(torch_device)

        model = DiaForConditionalGeneration.from_pretrained(self.model_checkpoint).to(torch_device)
        # dia has right padding while we have left padding (for faster prefill)
        # additionally we have new tokens vs dia's max tokens (hence we compare each in the respective settings)
        outputs_1 = model.generate(**inputs, max_new_tokens=22, do_sample=False)
        outputs_2 = model.generate(**inputs, max_new_tokens=27, do_sample=False)

        # fmt: off
        EXPECTED_OUTPUT_TOKENS_1 = torch.tensor([[1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 578, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 592, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 494, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 330, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 330, 501, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 330, 204, 34, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 330, 254, 915, 863, 1026, 1026, 1026, 1026, 1026],
            [ 330, 215, 458, 313, 50, 1026, 1026, 1026, 1026],
            [ 330, 615, 529, 216, 801, 237, 1026, 1026, 1026],
            [ 330, 580, 563, 233, 337, 37, 1018, 1026, 1026],
            [ 330, 567, 530, 753, 607, 179, 954, 242, 1026],
            [ 330, 627, 6, 1010, 500, 189, 598, 858, 247],
            [1024, 432, 480, 530, 122, 3, 788, 149, 814],
            [1025, 875, 826, 458, 98, 540, 181, 122, 608],
            [1025, 495, 840, 413, 337, 784, 591, 150, 1017],
            [1025, 808, 189, 137, 445, 0, 227, 658, 345],
            [1025, 397, 89, 753, 1016, 173, 984, 0, 910],
            [1025, 875, 460, 934, 50, 335, 670, 818, 722],
            [1025, 875, 460, 762, 119, 372, 503, 858, 584],
            [1025, 348, 555, 475, 469, 458, 963, 41, 664],
            [1025, 1024, 852, 683, 761, 193, 595, 895, 885],
            [1025, 1025, 1024, 135, 761, 902, 163, 623, 385],
            [1025, 1025, 1025, 1024, 852, 282, 581, 623, 70],
            [1025, 1025, 1025, 1025, 1024, 41, 661, 790, 977],
            [1025, 1025, 1025, 1025, 1025, 1024, 580, 401, 464],
            [1025, 1025, 1025, 1025, 1025, 1025, 1024, 756, 61],
            [1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024, 752],
            [1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024]])

        EXPECTED_OUTPUT_TOKENS_2 = torch.tensor([[1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 619, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 315, 1026, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 315, 968, 1026, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 315, 1007, 458, 1026, 1026, 1026, 1026, 1026, 1026],
            [ 315, 35, 266, 68, 1026, 1026, 1026, 1026, 1026],
            [ 315, 359, 285, 811, 154, 1026, 1026, 1026, 1026],
            [ 315, 906, 407, 297, 785, 649, 1026, 1026, 1026],
            [ 315, 249, 678, 868, 899, 257, 950, 1026, 1026],
            [ 315, 249, 217, 471, 292, 908, 196, 469, 1026],
            [ 315, 249, 825, 771, 839, 802, 633, 590, 531],
            [1024, 249, 150, 53, 126, 76, 794, 626, 442],
            [1025, 249, 825, 218, 359, 864, 526, 626, 770],
            [1025, 249, 150, 137, 530, 845, 877, 600, 111],
            [1025, 249, 150, 287, 730, 991, 135, 259, 39],
            [1025, 249, 825, 104, 198, 1020, 719, 625, 208],
            [1025, 249, 825, 997, 602, 256, 859, 322, 518],
            [1025, 668, 825, 979, 584, 256, 98, 665, 589],
            [1025, 954, 458, 54, 206, 52, 244, 822, 599],
            [1025, 1024, 104, 914, 435, 579, 860, 92, 661],
            [1025, 1025, 1024, 848, 126, 74, 304, 92, 753],
            [1025, 1025, 1025, 1024, 362, 376, 304, 586, 753],
            [1025, 1025, 1025, 1025, 1024, 633, 996, 586, 83],
            [1025, 1025, 1025, 1025, 1025, 1024, 179, 898, 928],
            [1025, 1025, 1025, 1025, 1025, 1025, 1024, 506, 102],
            [1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024, 79],
            [1025, 1025, 1025, 1025, 1025, 1025, 1025, 1025, 1024]])
        # fmt: on

        torch.testing.assert_close(outputs_1[0].cpu(), EXPECTED_OUTPUT_TOKENS_1)
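        # the second sample has a 5-frame shorter audio prefix; with left padding its first 5 positions
        # are padding frames and get skipped before comparing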
        torch.testing.assert_close(outputs_2[1, 5:].cpu(), EXPECTED_OUTPUT_TOKENS_2)  # left padding
|