mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Enrich TTS pipeline parameters naming (#26473)
* enrich TTS pipeline docstring for clearer forward_params use * change token lengths * update Pipeline parameters * correct docstring and make style * fix tests * make style * change music prompt Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * raise errors if generate_kwargs with forward-only models * make style --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
147e8ce4ae
commit
0ed6729bb1
@ -43,6 +43,29 @@ class TextToAudioPipeline(Pipeline):
|
||||
|
||||
Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
|
||||
|
||||
<Tip>
|
||||
|
||||
You can specify parameters passed to the model by using [`TextToAudioPipeline.__call__.forward_params`] or
|
||||
[`TextToAudioPipeline.__call__.generate_kwargs`].
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import pipeline
|
||||
|
||||
>>> music_generator = pipeline(task="text-to-audio", model="facebook/musicgen-small", framework="pt")
|
||||
|
||||
>>> # diversify the music generation by adding randomness with a high temperature and set a maximum music length
|
||||
>>> generate_kwargs = {
|
||||
... "do_sample": True,
|
||||
... "temperature": 0.7,
|
||||
... "max_new_tokens": 35,
|
||||
... }
|
||||
|
||||
>>> outputs = music_generator("Techno music with high melodic riffs", generate_kwargs=generate_kwargs)
|
||||
```
|
||||
|
||||
</Tip>
|
||||
|
||||
This pipeline can currently be loaded from [`pipeline`] using the following task identifiers: `"text-to-speech"` or
|
||||
`"text-to-audio"`.
|
||||
@ -107,11 +130,26 @@ class TextToAudioPipeline(Pipeline):
|
||||
def _forward(self, model_inputs, **kwargs):
|
||||
# we expect some kwargs to be additional tensors which need to be on the right device
|
||||
kwargs = self._ensure_tensor_on_device(kwargs, device=self.device)
|
||||
forward_params = kwargs["forward_params"]
|
||||
generate_kwargs = kwargs["generate_kwargs"]
|
||||
|
||||
if self.model.can_generate():
|
||||
output = self.model.generate(**model_inputs, **kwargs)
|
||||
# we expect some kwargs to be additional tensors which need to be on the right device
|
||||
generate_kwargs = self._ensure_tensor_on_device(generate_kwargs, device=self.device)
|
||||
|
||||
# generate_kwargs get priority over forward_params
|
||||
forward_params.update(generate_kwargs)
|
||||
|
||||
output = self.model.generate(**model_inputs, **forward_params)
|
||||
else:
|
||||
output = self.model(**model_inputs, **kwargs)[0]
|
||||
if len(generate_kwargs):
|
||||
raise ValueError(
|
||||
f"""You're using the `TextToAudioPipeline` with a forward-only model, but `generate_kwargs` is non empty.
|
||||
For forward-only TTA models, please use `forward_params` instead of of
|
||||
`generate_kwargs`. For reference, here are the `generate_kwargs` used here:
|
||||
{generate_kwargs.keys()}"""
|
||||
)
|
||||
output = self.model(**model_inputs, **forward_params)[0]
|
||||
|
||||
if self.vocoder is not None:
|
||||
# in that case, the output is a spectrogram that needs to be converted into a waveform
|
||||
@ -126,8 +164,14 @@ class TextToAudioPipeline(Pipeline):
|
||||
Args:
|
||||
text_inputs (`str` or `List[str]`):
|
||||
The text(s) to generate.
|
||||
forward_params (*optional*):
|
||||
Parameters passed to the model generation/forward method.
|
||||
forward_params (`dict`, *optional*):
|
||||
Parameters passed to the model generation/forward method. `forward_params` are always passed to the
|
||||
underlying model.
|
||||
generate_kwargs (`dict`, *optional*):
|
||||
The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
|
||||
complete overview of generate, check the [following
|
||||
guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation). `generate_kwargs` are
|
||||
only passed to the underlying model if the latter is a generative model.
|
||||
|
||||
Return:
|
||||
A `dict` or a list of `dict`: The dictionaries have two keys:
|
||||
@ -141,14 +185,18 @@ class TextToAudioPipeline(Pipeline):
|
||||
self,
|
||||
preprocess_params=None,
|
||||
forward_params=None,
|
||||
generate_kwargs=None,
|
||||
):
|
||||
params = {
|
||||
"forward_params": forward_params if forward_params else {},
|
||||
"generate_kwargs": generate_kwargs if generate_kwargs else {},
|
||||
}
|
||||
|
||||
if preprocess_params is None:
|
||||
preprocess_params = {}
|
||||
if forward_params is None:
|
||||
forward_params = {}
|
||||
postprocess_params = {}
|
||||
|
||||
return preprocess_params, forward_params, postprocess_params
|
||||
return preprocess_params, params, postprocess_params
|
||||
|
||||
def postprocess(self, waveform):
|
||||
output_dict = {}
|
||||
|
@ -30,6 +30,7 @@ from transformers.testing_utils import (
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.trainer_utils import set_seed
|
||||
|
||||
from .test_pipelines_common import ANY
|
||||
|
||||
@ -174,6 +175,60 @@ class TextToAudioPipelineTests(unittest.TestCase):
|
||||
outputs = speech_generator(["This is a test", "This is a second test"], batch_size=2)
|
||||
self.assertEqual(ANY(np.ndarray), outputs[0]["audio"])
|
||||
|
||||
@slow
@require_torch
def test_forward_model_kwargs(self):
    """Forward-only TTS model (VITS): `forward_params` must reach `forward()`,
    and generation-style arguments must raise instead of being silently dropped."""
    # VITS has no `generate()`, so the pipeline routes everything through `forward()`.
    tts = pipeline(task="text-to-audio", model="kakao-enterprise/vits-vctk", framework="pt")

    set_seed(555)  # pin sampling so the reference waveform is reproducible
    out = tts("This is a test", forward_params={"speaker_id": 5})
    audio_ref = out["audio"]

    # A generate-only kwarg smuggled into forward_params is rejected by forward() itself.
    with self.assertRaises(TypeError):
        out = tts("This is a test", forward_params={"speaker_id": 5, "do_sample": True})

    fwd = {"speaker_id": 5}
    gen = {"do_sample": True}

    # The pipeline refuses non-empty generate_kwargs for forward-only models.
    with self.assertRaises(ValueError):
        out = tts("This is a test", forward_params=fwd, generate_kwargs=gen)
    # NOTE: the raising calls never rebind `out`, so this compares the first
    # (successful) output against its own audio — it documents that `out` survived.
    self.assertTrue(np.abs(out["audio"] - audio_ref).max() < 1e-5)
|
||||
|
||||
@slow
@require_torch
def test_generative_model_kwargs(self):
    """Generative TTS model (MusicGen): `forward_params` feed `generate()`,
    and `generate_kwargs` take priority over `forward_params` on conflict."""
    generator = pipeline(task="text-to-audio", model="facebook/musicgen-small", framework="pt")

    sampling_params = {"do_sample": True, "max_new_tokens": 250}

    set_seed(555)  # pin sampling so both runs below are comparable
    result = generator("This is a test", forward_params=sampling_params)
    baseline_audio = result["audio"]
    self.assertEqual(ANY(np.ndarray), baseline_audio)

    # Conflicting settings: forward_params says greedy, generate_kwargs says sample.
    # generate_kwargs must win, reproducing the sampled baseline above.
    conflicting_params = {"do_sample": False, "max_new_tokens": 250}
    override = {"do_sample": True}

    set_seed(555)  # same seed → identical sampling trajectory if override wins
    result = generator("This is a test", forward_params=conflicting_params, generate_kwargs=override)
    self.assertListEqual(result["audio"].tolist(), baseline_audio.tolist())
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, processor):
    """Build a `TextToAudioPipeline` plus sample inputs for the shared pipeline tests."""
    # `processor` is part of the common-test signature but unused for text-to-audio.
    return TextToAudioPipeline(model=model, tokenizer=tokenizer), ["This is a test", "Another test"]
|
||||
|
Loading…
Reference in New Issue
Block a user