mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[VITS] Add to TTA pipeline (#25906)
* [VITS] Add to TTA pipeline * Update tests/pipelines/test_pipelines_text_to_audio.py Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> * remove extra spaces --------- Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>
This commit is contained in:
parent
be0e189bd3
commit
b439129e74
@ -1036,6 +1036,7 @@ MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = OrderedDict(
|
||||
# Model for Text-To-Waveform mapping
|
||||
("bark", "BarkModel"),
|
||||
("musicgen", "MusicgenForConditionalGeneration"),
|
||||
("vits", "VitsModel"),
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -56,8 +56,6 @@ class TextToAudioPipeline(Pipeline):
|
||||
if self.framework == "tf":
|
||||
raise ValueError("The TextToAudioPipeline is only available in PyTorch.")
|
||||
|
||||
self.forward_method = self.model.generate if self.model.can_generate() else self.model
|
||||
|
||||
self.vocoder = None
|
||||
if self.model.__class__ in MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING.values():
|
||||
self.vocoder = (
|
||||
@ -110,8 +108,10 @@ class TextToAudioPipeline(Pipeline):
|
||||
# we expect some kwargs to be additional tensors which need to be on the right device
|
||||
kwargs = self._ensure_tensor_on_device(kwargs, device=self.device)
|
||||
|
||||
# call the generate by defaults or the forward method if the model cannot generate
|
||||
output = self.forward_method(**model_inputs, **kwargs)
|
||||
if self.model.can_generate():
|
||||
output = self.model.generate(**model_inputs, **kwargs)
|
||||
else:
|
||||
output = self.model(**model_inputs, **kwargs)[0]
|
||||
|
||||
if self.vocoder is not None:
|
||||
# in that case, the output is a spectrogram that needs to be converted into a waveform
|
||||
|
@ -37,7 +37,7 @@ from .test_pipelines_common import ANY
|
||||
@require_torch_or_tf
|
||||
class TextToAudioPipelineTests(unittest.TestCase):
|
||||
model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING
|
||||
# for now only text_to_waveform and not text_to_spectrogram
|
||||
# for now only test text_to_waveform and not text_to_spectrogram
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@ -50,26 +50,21 @@ class TextToAudioPipelineTests(unittest.TestCase):
|
||||
}
|
||||
|
||||
outputs = speech_generator("This is a test", forward_params=forward_params)
|
||||
|
||||
# musicgen sampling_rate is not straightforward to get
|
||||
self.assertIsNone(outputs["sampling_rate"])
|
||||
|
||||
audio = outputs["audio"]
|
||||
|
||||
self.assertEqual(ANY(np.ndarray), audio)
|
||||
|
||||
# test two examples side-by-side
|
||||
outputs = speech_generator(["This is a test", "This is a second test"], forward_params=forward_params)
|
||||
|
||||
audio = [output["audio"] for output in outputs]
|
||||
|
||||
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
|
||||
|
||||
# test batching
|
||||
outputs = speech_generator(
|
||||
["This is a test", "This is a second test"], forward_params=forward_params, batch_size=2
|
||||
)
|
||||
|
||||
self.assertEqual(ANY(np.ndarray), outputs[0]["audio"])
|
||||
|
||||
@slow
|
||||
@ -77,8 +72,6 @@ class TextToAudioPipelineTests(unittest.TestCase):
|
||||
def test_large_model_pt(self):
|
||||
speech_generator = pipeline(task="text-to-audio", model="suno/bark-small", framework="pt")
|
||||
|
||||
# test text-to-speech
|
||||
|
||||
forward_params = {
|
||||
# Using `do_sample=False` to force deterministic output
|
||||
"do_sample": False,
|
||||
@ -86,7 +79,6 @@ class TextToAudioPipelineTests(unittest.TestCase):
|
||||
}
|
||||
|
||||
outputs = speech_generator("This is a test", forward_params=forward_params)
|
||||
|
||||
self.assertEqual(
|
||||
{"audio": ANY(np.ndarray), "sampling_rate": 24000},
|
||||
outputs,
|
||||
@ -97,13 +89,10 @@ class TextToAudioPipelineTests(unittest.TestCase):
|
||||
["This is a test", "This is a second test"],
|
||||
forward_params=forward_params,
|
||||
)
|
||||
|
||||
audio = [output["audio"] for output in outputs]
|
||||
|
||||
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
|
||||
|
||||
# test other generation strategy
|
||||
|
||||
forward_params = {
|
||||
"do_sample": True,
|
||||
"semantic_max_new_tokens": 100,
|
||||
@ -111,9 +100,7 @@ class TextToAudioPipelineTests(unittest.TestCase):
|
||||
}
|
||||
|
||||
outputs = speech_generator("This is a test", forward_params=forward_params)
|
||||
|
||||
audio = outputs["audio"]
|
||||
|
||||
self.assertEqual(ANY(np.ndarray), audio)
|
||||
|
||||
# test using a speaker embedding
|
||||
@ -127,9 +114,7 @@ class TextToAudioPipelineTests(unittest.TestCase):
|
||||
forward_params=forward_params,
|
||||
batch_size=2,
|
||||
)
|
||||
|
||||
audio = [output["audio"] for output in outputs]
|
||||
|
||||
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
|
||||
|
||||
@slow
|
||||
@ -151,7 +136,6 @@ class TextToAudioPipelineTests(unittest.TestCase):
|
||||
"return_token_type_ids": False,
|
||||
"padding": "max_length",
|
||||
}
|
||||
|
||||
outputs = speech_generator(
|
||||
"This is a test",
|
||||
forward_params=forward_params,
|
||||
@ -163,28 +147,44 @@ class TextToAudioPipelineTests(unittest.TestCase):
|
||||
forward_params["history_prompt"] = history_prompt
|
||||
|
||||
# history_prompt is a torch.Tensor passed as a forward_param
|
||||
# if generation is successfull, it means that it was passed to the right device
|
||||
# if generation is successful, it means that it was passed to the right device
|
||||
outputs = speech_generator(
|
||||
"This is a test", forward_params=forward_params, preprocess_params=preprocess_params
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
{"audio": ANY(np.ndarray), "sampling_rate": 24000},
|
||||
outputs,
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_vits_model_pt(self):
|
||||
speech_generator = pipeline(task="text-to-audio", model="facebook/mms-tts-eng", framework="pt")
|
||||
|
||||
outputs = speech_generator("This is a test")
|
||||
self.assertEqual(outputs["sampling_rate"], 16000)
|
||||
|
||||
audio = outputs["audio"]
|
||||
self.assertEqual(ANY(np.ndarray), audio)
|
||||
|
||||
# test two examples side-by-side
|
||||
outputs = speech_generator(["This is a test", "This is a second test"])
|
||||
audio = [output["audio"] for output in outputs]
|
||||
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
|
||||
|
||||
# test batching
|
||||
outputs = speech_generator(["This is a test", "This is a second test"], batch_size=2)
|
||||
self.assertEqual(ANY(np.ndarray), outputs[0]["audio"])
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, processor):
|
||||
speech_generator = TextToAudioPipeline(model=model, tokenizer=tokenizer)
|
||||
return speech_generator, ["This is a test", "Another test"]
|
||||
|
||||
def run_pipeline_test(self, speech_generator, _):
|
||||
outputs = speech_generator("This is a test")
|
||||
|
||||
self.assertEqual(ANY(np.ndarray), outputs["audio"])
|
||||
|
||||
forward_params = {"num_return_sequences": 2, "do_sample": True}
|
||||
|
||||
outputs = speech_generator(["This is great !", "Something else"], forward_params=forward_params)
|
||||
audio = [output["audio"] for output in outputs]
|
||||
|
||||
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
|
||||
|
Loading…
Reference in New Issue
Block a user