# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import tempfile
import unittest

import numpy as np
from parameterized import parameterized

from transformers import DacModel, DiaFeatureExtractor, DiaProcessor, DiaTokenizer
from transformers.testing_utils import require_torch
from transformers.utils import is_torch_available
if is_torch_available():
    import torch
# Copied from tests.utils.test_modeling_utils.check_models_equal
def check_models_equal(model1, model2):
    models_are_equal = True
    for model1_p, model2_p in zip(model1.parameters(), model2.parameters()):
        if model1_p.data.ne(model2_p.data).sum() > 0:
            models_are_equal = False

    return models_are_equal
@require_torch
class DiaProcessorTest(unittest.TestCase):
    def setUp(self):
        self.checkpoint = "AntonV/Dia-1.6B"
        self.audio_tokenizer_checkpoint = "descript/dac_44khz"
        self.tmpdirname = tempfile.mkdtemp()

        # Audio tokenizer is a bigger model so we will reuse this if possible
        self.processor = DiaProcessor(
            tokenizer=self.get_tokenizer(),
            feature_extractor=self.get_feature_extractor(),
            audio_tokenizer=self.get_audio_tokenizer(),
        )

        # Default audio values based on Dia and Dac
        self.pad_id = 1025
        self.bos_id = 1026
        self.dac_chunk_len = 512
        self.delay_pattern = [0, 8, 9, 10, 11, 12, 13, 14, 15]
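        # Rough intuition (inferred from the delay tests below): channel 0 is emitted
        # without offset, while channel i is shifted right by delay_pattern[i] frames,
        # with bos filling the shifted-in positions and pad filling the tail.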
    def get_tokenizer(self, **kwargs):
        return DiaTokenizer.from_pretrained(self.checkpoint, **kwargs)

    def get_feature_extractor(self, **kwargs):
        return DiaFeatureExtractor.from_pretrained(self.checkpoint, **kwargs)

    def get_audio_tokenizer(self, **kwargs):
        return DacModel.from_pretrained(self.audio_tokenizer_checkpoint, **kwargs)

    def tearDown(self):
        shutil.rmtree(self.tmpdirname)
        del self.processor

    def test_save_load_pretrained_default(self):
        tokenizer = self.get_tokenizer()
        feature_extractor = self.get_feature_extractor()
        audio_tokenizer = self.get_audio_tokenizer()

        processor = DiaProcessor(
            tokenizer=tokenizer, feature_extractor=feature_extractor, audio_tokenizer=audio_tokenizer
        )

        processor.save_pretrained(self.tmpdirname)
        processor = DiaProcessor.from_pretrained(self.tmpdirname)

        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
        self.assertIsInstance(processor.tokenizer, DiaTokenizer)

        self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())
        self.assertIsInstance(processor.feature_extractor, DiaFeatureExtractor)

        self.assertEqual(processor.audio_tokenizer.__class__.__name__, audio_tokenizer.__class__.__name__)
        self.assertEqual(processor.audio_tokenizer.name_or_path, audio_tokenizer.name_or_path)
        self.assertTrue(check_models_equal(processor.audio_tokenizer, audio_tokenizer))
        self.assertIsInstance(processor.audio_tokenizer, DacModel)

    def test_save_load_pretrained_additional_features(self):
        processor = DiaProcessor(
            tokenizer=self.get_tokenizer(),
            feature_extractor=self.get_feature_extractor(),
            audio_tokenizer=self.get_audio_tokenizer(),
        )
        processor.save_pretrained(self.tmpdirname)

        tokenizer_add_kwargs = self.get_tokenizer()
        feature_extractor_add_kwargs = self.get_feature_extractor()
        audio_tokenizer_add_kwargs = self.get_audio_tokenizer()

        processor = DiaProcessor.from_pretrained(self.tmpdirname)

        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
        self.assertIsInstance(processor.tokenizer, DiaTokenizer)

        self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
        self.assertIsInstance(processor.feature_extractor, DiaFeatureExtractor)

        self.assertEqual(processor.audio_tokenizer.__class__.__name__, audio_tokenizer_add_kwargs.__class__.__name__)
        self.assertEqual(processor.audio_tokenizer.name_or_path, audio_tokenizer_add_kwargs.name_or_path)
        self.assertTrue(check_models_equal(processor.audio_tokenizer, audio_tokenizer_add_kwargs))
        self.assertIsInstance(processor.audio_tokenizer, DacModel)

    def test_model_input_names(self):
        tokenizer = self.get_tokenizer()

        self.assertListEqual(
            self.processor.model_input_names,
            list(dict.fromkeys(tokenizer.model_input_names + ["decoder_input_ids", "decoder_attention_mask"])),
            msg="`processor` model input names do not match the expected names.",
        )

    def test_tokenize(self):
        tokenizer = self.get_tokenizer()
        random_text = ["This is a processing test for tokenization", "[S1] Dia template style [S2] Nice"]

        input_tokenizer = tokenizer(random_text, padding=True, return_tensors="pt")
        input_processor = self.processor(random_text)

        for key in input_tokenizer.keys():
            self.assertTrue((input_tokenizer[key] == input_processor[key]).all())

    def test_no_audio(self):
        random_text = ["Dummy Input"] * 2
        input_processor = self.processor(random_text)
        audio_tokens, audio_mask = input_processor["decoder_input_ids"], input_processor["decoder_attention_mask"]
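        # Expected layout per the assertions below: with no audio prompt the decoder
        # prompt is a single bos frame expanded by the delay pattern, i.e. shape
        # (batch, max(delay_pattern) + 1, num_channels), where channel i holds bos for
        # its first delay_pattern[i] + 1 steps and pad afterwards.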
        # full mask with +1 for bos
        self.assertTrue(audio_mask.sum() == (max(self.delay_pattern) + 1) * len(random_text))
        self.assertTrue(
            audio_tokens.shape
            == (
                len(random_text),
                max(self.delay_pattern) + 1,
                len(self.delay_pattern),
            )
        )

        for channel_idx, delay in enumerate(self.delay_pattern):
            expected_sequence = torch.ones(size=(audio_tokens.shape[:-1])) * self.pad_id
            expected_sequence[:, : delay + 1] = self.bos_id
            self.assertTrue((audio_tokens[..., channel_idx] == expected_sequence).all())

    def test_audio(self):
        audio_tokenizer = self.get_audio_tokenizer()
        feature_extractor = self.get_feature_extractor()

        random_text = ["Dummy Input"] * 2
        # Dac only accepts audio above a certain minimum length (ensured here via >= 1024 samples)
        raw_speeches = [np.random.rand(2048).astype(np.float32), np.random.rand(1024).astype(np.float32)]
        input_processor = self.processor(random_text, raw_speeches)
        audio_tokens, audio_mask = input_processor["decoder_input_ids"], input_processor["decoder_attention_mask"]

        sequence_len = audio_mask.shape[1]
        for batch_idx, speech in enumerate(raw_speeches):
            raw_audio = feature_extractor(speech, return_tensors="pt")["input_values"]
            codebooks = audio_tokenizer(raw_audio).audio_codes.transpose(1, 2)
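            # audio_codes presumably comes back as (batch, num_codebooks, frames);
            # the transpose yields (batch, frames, num_codebooks) to match the
            # per-channel layout of the decoder inputs checked below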
            pad_len = sequence_len - audio_mask.sum(dim=-1)[batch_idx]
            for channel_idx, delay in enumerate(self.delay_pattern):
                # Left padding is filled with bos, the right padding (delay) with pad
                start_idx = pad_len + delay + 1
                end_idx = start_idx + codebooks.shape[1]

                encoded_sequence = audio_tokens[batch_idx, :, channel_idx]
                expected_sequence = torch.ones(size=(sequence_len,)) * self.pad_id
                expected_sequence[:start_idx] = self.bos_id
                expected_sequence[start_idx:end_idx] = codebooks[0, :, channel_idx]

                self.assertTrue((encoded_sequence == expected_sequence).all())

        # Just to make sure the masking correctly only ignores bos tokens
        self.assertTrue((audio_tokens[~audio_mask.bool()] == self.bos_id).all())

    @parameterized.expand([([1, 1],), ([1, 5],), ([2, 4, 6],)])
    def test_decode_audio(self, audio_lens):
        feature_extractor = self.get_feature_extractor()
        audio_tokenizer = self.get_audio_tokenizer()

        random_text = ["Dummy Input"] * len(audio_lens)
        raw_speeches = [np.random.rand(self.dac_chunk_len * l).astype(np.float32) for l in audio_lens]
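        # each DAC code frame corresponds to dac_chunk_len waveform samples here, so
        # exact multiples keep the round-trip length check at the end of this test exact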
        # we need an eos (appended when not generating, as in training) to decode properly;
        # during generation this is enforced via a custom logits processor
        input_processor = self.processor(random_text, raw_speeches, generation=False)
        audio_tokens = input_processor["decoder_input_ids"]

        decoded_speeches = self.processor.batch_decode(audio_tokens)
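        # batch_decode is expected to undo the delay pattern, drop the special tokens,
        # and run the remaining codes through the DAC decoder, so each waveform should
        # match decoding the original codebooks directly (checked below)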
        for batch_idx, speech in enumerate(raw_speeches):
            raw_audio = feature_extractor(speech, return_tensors="pt")["input_values"]
            codebooks = audio_tokenizer(raw_audio).audio_codes

            decoded_audio = decoded_speeches[batch_idx]
            expected_audio = audio_tokenizer.decode(audio_codes=codebooks).audio_values

            self.assertTrue((expected_audio == decoded_audio).all())
            self.assertTrue(decoded_speeches[batch_idx].shape[-1] == audio_lens[batch_idx] * self.dac_chunk_len)

    @parameterized.expand([(1, 2, [0, 1, 4]), (2, 4, [1, 3, 2]), (4, 8, [0, 5, 7])])
    def test_delay_in_audio(self, bsz, seq_len, delay_pattern):
        # static functions which are crucial, hence we also test them here
        build_indices_fn = DiaProcessor.build_indices
        delay_fn = DiaProcessor.apply_audio_delay

        bos, pad = -2, -1
        num_channels = len(delay_pattern)

        audio_input = torch.arange(bsz * seq_len * num_channels).view(bsz, seq_len, num_channels)
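        # arange gives every (batch, step, channel) slot a unique value so that the
        # channel-wise shifts can be verified exactly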
        # imitate a delay mask with zeroes
        audio_input = torch.cat([audio_input, torch.zeros(size=(bsz, max(delay_pattern), num_channels))], dim=1)
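        # these extra max(delay_pattern) all-zero frames give the shifted channels room,
        # presumably mirroring the padding the processor would append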
        precomputed_idx = build_indices_fn(
            bsz=bsz,
            seq_len=seq_len + max(delay_pattern),
            num_channels=num_channels,
            delay_pattern=delay_pattern,
            revert=False,
        )
        delayed_audio_out = delay_fn(
            audio=audio_input,
            pad_token_id=pad,
            bos_token_id=bos,
            precomputed_idx=precomputed_idx,
        )

        # every channel idx is shifted by delay_pattern[idx]
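        # e.g. with delay_pattern=[0, 1, 4], the value at step t on channel 2 of the
        # original audio should show up at step t + 4 after the delay is applied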
        delayed_audio_res = audio_input.clone()
        for idx, delay in enumerate(delay_pattern):
            delayed_audio_res[:, :delay, idx] = bos
            remaining_input = seq_len + max(delay_pattern) - delay
            delayed_audio_res[:, delay:, idx] = audio_input[:, :remaining_input, idx]

        self.assertTrue((delayed_audio_out == delayed_audio_res).all())

        # we should get back to the original audio we had (when removing the delay pad)
        bsz, new_seq_len, num_channels = delayed_audio_out.shape
        precomputed_idx = build_indices_fn(
            bsz=bsz,
            seq_len=new_seq_len,
            num_channels=num_channels,
            delay_pattern=delay_pattern,
            revert=True,
        )
        reverted_audio_out = delay_fn(
            audio=delayed_audio_out,
            pad_token_id=pad,
            bos_token_id=bos,
            precomputed_idx=precomputed_idx,
        )

        reverted_audio_res = audio_input.clone()[:, :seq_len]

        self.assertTrue((reverted_audio_out[:, :seq_len] == reverted_audio_res).all())