mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 13:20:12 +06:00
Enrich TTS pipeline parameters naming (#26473)
* enrich TTS pipeline docstring for clearer forward_params use * change token leghts * update Pipeline parameters * correct docstring and make style * fix tests * make style * change music prompt Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * raise errors if generate_kwargs with forward-only models * make style --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
147e8ce4ae
commit
0ed6729bb1
@ -43,6 +43,29 @@ class TextToAudioPipeline(Pipeline):
|
|||||||
|
|
||||||
Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
|
Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
|
||||||
|
|
||||||
|
<Tip>
|
||||||
|
|
||||||
|
You can specify parameters passed to the model by using [`TextToAudioPipeline.__call__.forward_params`] or
|
||||||
|
[`TextToAudioPipeline.__call__.generate_kwargs`].
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import pipeline
|
||||||
|
|
||||||
|
>>> music_generator = pipeline(task="text-to-audio", model="facebook/musicgen-small", framework="pt")
|
||||||
|
|
||||||
|
>>> # diversify the music generation by adding randomness with a high temperature and set a maximum music length
|
||||||
|
>>> generate_kwargs = {
|
||||||
|
... "do_sample": True,
|
||||||
|
... "temperature": 0.7,
|
||||||
|
... "max_new_tokens": 35,
|
||||||
|
... }
|
||||||
|
|
||||||
|
>>> outputs = music_generator("Techno music with high melodic riffs", generate_kwargs=generate_kwargs)
|
||||||
|
```
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
This pipeline can currently be loaded from [`pipeline`] using the following task identifiers: `"text-to-speech"` or
|
This pipeline can currently be loaded from [`pipeline`] using the following task identifiers: `"text-to-speech"` or
|
||||||
`"text-to-audio"`.
|
`"text-to-audio"`.
|
||||||
@ -107,11 +130,26 @@ class TextToAudioPipeline(Pipeline):
|
|||||||
def _forward(self, model_inputs, **kwargs):
|
def _forward(self, model_inputs, **kwargs):
|
||||||
# we expect some kwargs to be additional tensors which need to be on the right device
|
# we expect some kwargs to be additional tensors which need to be on the right device
|
||||||
kwargs = self._ensure_tensor_on_device(kwargs, device=self.device)
|
kwargs = self._ensure_tensor_on_device(kwargs, device=self.device)
|
||||||
|
forward_params = kwargs["forward_params"]
|
||||||
|
generate_kwargs = kwargs["generate_kwargs"]
|
||||||
|
|
||||||
if self.model.can_generate():
|
if self.model.can_generate():
|
||||||
output = self.model.generate(**model_inputs, **kwargs)
|
# we expect some kwargs to be additional tensors which need to be on the right device
|
||||||
|
generate_kwargs = self._ensure_tensor_on_device(generate_kwargs, device=self.device)
|
||||||
|
|
||||||
|
# generate_kwargs get priority over forward_params
|
||||||
|
forward_params.update(generate_kwargs)
|
||||||
|
|
||||||
|
output = self.model.generate(**model_inputs, **forward_params)
|
||||||
else:
|
else:
|
||||||
output = self.model(**model_inputs, **kwargs)[0]
|
if len(generate_kwargs):
|
||||||
|
raise ValueError(
|
||||||
|
f"""You're using the `TextToAudioPipeline` with a forward-only model, but `generate_kwargs` is non empty.
|
||||||
|
For forward-only TTA models, please use `forward_params` instead of of
|
||||||
|
`generate_kwargs`. For reference, here are the `generate_kwargs` used here:
|
||||||
|
{generate_kwargs.keys()}"""
|
||||||
|
)
|
||||||
|
output = self.model(**model_inputs, **forward_params)[0]
|
||||||
|
|
||||||
if self.vocoder is not None:
|
if self.vocoder is not None:
|
||||||
# in that case, the output is a spectrogram that needs to be converted into a waveform
|
# in that case, the output is a spectrogram that needs to be converted into a waveform
|
||||||
@ -126,8 +164,14 @@ class TextToAudioPipeline(Pipeline):
|
|||||||
Args:
|
Args:
|
||||||
text_inputs (`str` or `List[str]`):
|
text_inputs (`str` or `List[str]`):
|
||||||
The text(s) to generate.
|
The text(s) to generate.
|
||||||
forward_params (*optional*):
|
forward_params (`dict`, *optional*):
|
||||||
Parameters passed to the model generation/forward method.
|
Parameters passed to the model generation/forward method. `forward_params` are always passed to the
|
||||||
|
underlying model.
|
||||||
|
generate_kwargs (`dict`, *optional*):
|
||||||
|
The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
|
||||||
|
complete overview of generate, check the [following
|
||||||
|
guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation). `generate_kwargs` are
|
||||||
|
only passed to the underlying model if the latter is a generative model.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
A `dict` or a list of `dict`: The dictionaries have two keys:
|
A `dict` or a list of `dict`: The dictionaries have two keys:
|
||||||
@ -141,14 +185,18 @@ class TextToAudioPipeline(Pipeline):
|
|||||||
self,
|
self,
|
||||||
preprocess_params=None,
|
preprocess_params=None,
|
||||||
forward_params=None,
|
forward_params=None,
|
||||||
|
generate_kwargs=None,
|
||||||
):
|
):
|
||||||
|
params = {
|
||||||
|
"forward_params": forward_params if forward_params else {},
|
||||||
|
"generate_kwargs": generate_kwargs if generate_kwargs else {},
|
||||||
|
}
|
||||||
|
|
||||||
if preprocess_params is None:
|
if preprocess_params is None:
|
||||||
preprocess_params = {}
|
preprocess_params = {}
|
||||||
if forward_params is None:
|
|
||||||
forward_params = {}
|
|
||||||
postprocess_params = {}
|
postprocess_params = {}
|
||||||
|
|
||||||
return preprocess_params, forward_params, postprocess_params
|
return preprocess_params, params, postprocess_params
|
||||||
|
|
||||||
def postprocess(self, waveform):
|
def postprocess(self, waveform):
|
||||||
output_dict = {}
|
output_dict = {}
|
||||||
|
@ -30,6 +30,7 @@ from transformers.testing_utils import (
|
|||||||
slow,
|
slow,
|
||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
|
from transformers.trainer_utils import set_seed
|
||||||
|
|
||||||
from .test_pipelines_common import ANY
|
from .test_pipelines_common import ANY
|
||||||
|
|
||||||
@ -174,6 +175,60 @@ class TextToAudioPipelineTests(unittest.TestCase):
|
|||||||
outputs = speech_generator(["This is a test", "This is a second test"], batch_size=2)
|
outputs = speech_generator(["This is a test", "This is a second test"], batch_size=2)
|
||||||
self.assertEqual(ANY(np.ndarray), outputs[0]["audio"])
|
self.assertEqual(ANY(np.ndarray), outputs[0]["audio"])
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
def test_forward_model_kwargs(self):
|
||||||
|
# use vits - a forward model
|
||||||
|
speech_generator = pipeline(task="text-to-audio", model="kakao-enterprise/vits-vctk", framework="pt")
|
||||||
|
|
||||||
|
# for reproducibility
|
||||||
|
set_seed(555)
|
||||||
|
outputs = speech_generator("This is a test", forward_params={"speaker_id": 5})
|
||||||
|
audio = outputs["audio"]
|
||||||
|
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
# assert error if generate parameter
|
||||||
|
outputs = speech_generator("This is a test", forward_params={"speaker_id": 5, "do_sample": True})
|
||||||
|
|
||||||
|
forward_params = {"speaker_id": 5}
|
||||||
|
generate_kwargs = {"do_sample": True}
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
# assert error if generate_kwargs with forward-only models
|
||||||
|
outputs = speech_generator(
|
||||||
|
"This is a test", forward_params=forward_params, generate_kwargs=generate_kwargs
|
||||||
|
)
|
||||||
|
self.assertTrue(np.abs(outputs["audio"] - audio).max() < 1e-5)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
def test_generative_model_kwargs(self):
|
||||||
|
# use musicgen - a generative model
|
||||||
|
music_generator = pipeline(task="text-to-audio", model="facebook/musicgen-small", framework="pt")
|
||||||
|
|
||||||
|
forward_params = {
|
||||||
|
"do_sample": True,
|
||||||
|
"max_new_tokens": 250,
|
||||||
|
}
|
||||||
|
|
||||||
|
# for reproducibility
|
||||||
|
set_seed(555)
|
||||||
|
outputs = music_generator("This is a test", forward_params=forward_params)
|
||||||
|
audio = outputs["audio"]
|
||||||
|
self.assertEqual(ANY(np.ndarray), audio)
|
||||||
|
|
||||||
|
# make sure generate kwargs get priority over forward params
|
||||||
|
forward_params = {
|
||||||
|
"do_sample": False,
|
||||||
|
"max_new_tokens": 250,
|
||||||
|
}
|
||||||
|
generate_kwargs = {"do_sample": True}
|
||||||
|
|
||||||
|
# for reproducibility
|
||||||
|
set_seed(555)
|
||||||
|
outputs = music_generator("This is a test", forward_params=forward_params, generate_kwargs=generate_kwargs)
|
||||||
|
self.assertListEqual(outputs["audio"].tolist(), audio.tolist())
|
||||||
|
|
||||||
def get_test_pipeline(self, model, tokenizer, processor):
|
def get_test_pipeline(self, model, tokenizer, processor):
|
||||||
speech_generator = TextToAudioPipeline(model=model, tokenizer=tokenizer)
|
speech_generator = TextToAudioPipeline(model=model, tokenizer=tokenizer)
|
||||||
return speech_generator, ["This is a test", "Another test"]
|
return speech_generator, ["This is a test", "Another test"]
|
||||||
|
Loading…
Reference in New Issue
Block a user