Enrich TTS pipeline parameters naming (#26473)

* enrich TTS pipeline docstring for clearer forward_params use

* change token leghts

* update Pipeline parameters

* correct docstring and make style

* fix tests

* make style

* change music prompt

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* raise errors if generate_kwargs with forward-only models

* make style

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Yoach Lacombe 2023-11-02 17:06:56 +00:00 committed by GitHub
parent 147e8ce4ae
commit 0ed6729bb1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 110 additions and 7 deletions

View File

@ -43,6 +43,29 @@ class TextToAudioPipeline(Pipeline):
Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial) Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
<Tip>
You can specify parameters passed to the model by using [`TextToAudioPipeline.__call__.forward_params`] or
[`TextToAudioPipeline.__call__.generate_kwargs`].
Example:
```python
>>> from transformers import pipeline
>>> music_generator = pipeline(task="text-to-audio", model="facebook/musicgen-small", framework="pt")
>>> # diversify the music generation by adding randomness with a high temperature and set a maximum music length
>>> generate_kwargs = {
... "do_sample": True,
... "temperature": 0.7,
... "max_new_tokens": 35,
... }
>>> outputs = music_generator("Techno music with high melodic riffs", generate_kwargs=generate_kwargs)
```
</Tip>
This pipeline can currently be loaded from [`pipeline`] using the following task identifiers: `"text-to-speech"` or This pipeline can currently be loaded from [`pipeline`] using the following task identifiers: `"text-to-speech"` or
`"text-to-audio"`. `"text-to-audio"`.
@ -107,11 +130,26 @@ class TextToAudioPipeline(Pipeline):
def _forward(self, model_inputs, **kwargs): def _forward(self, model_inputs, **kwargs):
# we expect some kwargs to be additional tensors which need to be on the right device # we expect some kwargs to be additional tensors which need to be on the right device
kwargs = self._ensure_tensor_on_device(kwargs, device=self.device) kwargs = self._ensure_tensor_on_device(kwargs, device=self.device)
forward_params = kwargs["forward_params"]
generate_kwargs = kwargs["generate_kwargs"]
if self.model.can_generate(): if self.model.can_generate():
output = self.model.generate(**model_inputs, **kwargs) # we expect some kwargs to be additional tensors which need to be on the right device
generate_kwargs = self._ensure_tensor_on_device(generate_kwargs, device=self.device)
# generate_kwargs get priority over forward_params
forward_params.update(generate_kwargs)
output = self.model.generate(**model_inputs, **forward_params)
else: else:
output = self.model(**model_inputs, **kwargs)[0] if len(generate_kwargs):
raise ValueError(
f"""You're using the `TextToAudioPipeline` with a forward-only model, but `generate_kwargs` is non empty.
For forward-only TTA models, please use `forward_params` instead of of
`generate_kwargs`. For reference, here are the `generate_kwargs` used here:
{generate_kwargs.keys()}"""
)
output = self.model(**model_inputs, **forward_params)[0]
if self.vocoder is not None: if self.vocoder is not None:
# in that case, the output is a spectrogram that needs to be converted into a waveform # in that case, the output is a spectrogram that needs to be converted into a waveform
@ -126,8 +164,14 @@ class TextToAudioPipeline(Pipeline):
Args: Args:
text_inputs (`str` or `List[str]`): text_inputs (`str` or `List[str]`):
The text(s) to generate. The text(s) to generate.
forward_params (*optional*): forward_params (`dict`, *optional*):
Parameters passed to the model generation/forward method. Parameters passed to the model generation/forward method. `forward_params` are always passed to the
underlying model.
generate_kwargs (`dict`, *optional*):
The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
complete overview of generate, check the [following
guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation). `generate_kwargs` are
only passed to the underlying model if the latter is a generative model.
Return: Return:
A `dict` or a list of `dict`: The dictionaries have two keys: A `dict` or a list of `dict`: The dictionaries have two keys:
@ -141,14 +185,18 @@ class TextToAudioPipeline(Pipeline):
self, self,
preprocess_params=None, preprocess_params=None,
forward_params=None, forward_params=None,
generate_kwargs=None,
): ):
params = {
"forward_params": forward_params if forward_params else {},
"generate_kwargs": generate_kwargs if generate_kwargs else {},
}
if preprocess_params is None: if preprocess_params is None:
preprocess_params = {} preprocess_params = {}
if forward_params is None:
forward_params = {}
postprocess_params = {} postprocess_params = {}
return preprocess_params, forward_params, postprocess_params return preprocess_params, params, postprocess_params
def postprocess(self, waveform): def postprocess(self, waveform):
output_dict = {} output_dict = {}

View File

@ -30,6 +30,7 @@ from transformers.testing_utils import (
slow, slow,
torch_device, torch_device,
) )
from transformers.trainer_utils import set_seed
from .test_pipelines_common import ANY from .test_pipelines_common import ANY
@ -174,6 +175,60 @@ class TextToAudioPipelineTests(unittest.TestCase):
outputs = speech_generator(["This is a test", "This is a second test"], batch_size=2) outputs = speech_generator(["This is a test", "This is a second test"], batch_size=2)
self.assertEqual(ANY(np.ndarray), outputs[0]["audio"]) self.assertEqual(ANY(np.ndarray), outputs[0]["audio"])
@slow
@require_torch
def test_forward_model_kwargs(self):
# use vits - a forward model
speech_generator = pipeline(task="text-to-audio", model="kakao-enterprise/vits-vctk", framework="pt")
# for reproducibility
set_seed(555)
outputs = speech_generator("This is a test", forward_params={"speaker_id": 5})
audio = outputs["audio"]
with self.assertRaises(TypeError):
# assert error if generate parameter
outputs = speech_generator("This is a test", forward_params={"speaker_id": 5, "do_sample": True})
forward_params = {"speaker_id": 5}
generate_kwargs = {"do_sample": True}
with self.assertRaises(ValueError):
# assert error if generate_kwargs with forward-only models
outputs = speech_generator(
"This is a test", forward_params=forward_params, generate_kwargs=generate_kwargs
)
self.assertTrue(np.abs(outputs["audio"] - audio).max() < 1e-5)
@slow
@require_torch
def test_generative_model_kwargs(self):
# use musicgen - a generative model
music_generator = pipeline(task="text-to-audio", model="facebook/musicgen-small", framework="pt")
forward_params = {
"do_sample": True,
"max_new_tokens": 250,
}
# for reproducibility
set_seed(555)
outputs = music_generator("This is a test", forward_params=forward_params)
audio = outputs["audio"]
self.assertEqual(ANY(np.ndarray), audio)
# make sure generate kwargs get priority over forward params
forward_params = {
"do_sample": False,
"max_new_tokens": 250,
}
generate_kwargs = {"do_sample": True}
# for reproducibility
set_seed(555)
outputs = music_generator("This is a test", forward_params=forward_params, generate_kwargs=generate_kwargs)
self.assertListEqual(outputs["audio"].tolist(), audio.tolist())
def get_test_pipeline(self, model, tokenizer, processor): def get_test_pipeline(self, model, tokenizer, processor):
speech_generator = TextToAudioPipeline(model=model, tokenizer=tokenizer) speech_generator = TextToAudioPipeline(model=model, tokenizer=tokenizer)
return speech_generator, ["This is a test", "Another test"] return speech_generator, ["This is a test", "Another test"]