mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Add Text-To-Speech pipeline (#24952)
* add AutoModelForTextToSpeech class * add TTS pipeline and tessting * add docstrings to text_to_speech pipeline * fix torch dependency * corrector 'processor is None' case in Pipeline * correct repo id * modify text-to-speech -> text-to-audio * remove processor * rename text_to_speech pipelines files to text_audio * add textToWaveform and textToSpectrogram instead of textToAudio classes * update TTS pipeline to the bare minimum * update tests TTS pipeline * make style and erase useless import torch in TTS pipeline tests * modify how to check if generate or forward in TTS pipeline * remove unnecessary extra new lines * Apply suggestions from code review Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * refactor input_texts -> text_inputs * correct docstrings of TTS.__call__ * correct the shape of generated waveform * take care of Bark tokenizer special case * correct run_pipeline_test TTS * make style * update TTS docstrings * address Sylvain nit refactors * make style * refactor into one liners * correct squeeze * correct way to test if forward or generate * Update output audio waveform shape * make style * correct import * modify how the TTS pipeline test if a model can generate * align shape output of TTS pipeline with consistent shape --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
parent
c4c0ceff09
commit
b8f69d0d10
@ -318,6 +318,13 @@ Pipelines available for audio tasks include the following.
|
||||
- __call__
|
||||
- all
|
||||
|
||||
### TextToAudioPipeline
|
||||
|
||||
[[autodoc]] TextToAudioPipeline
|
||||
- __call__
|
||||
- all
|
||||
|
||||
|
||||
### ZeroShotAudioClassificationPipeline
|
||||
|
||||
[[autodoc]] ZeroShotAudioClassificationPipeline
|
||||
|
@ -643,6 +643,7 @@ _import_structure = {
|
||||
"Text2TextGenerationPipeline",
|
||||
"TextClassificationPipeline",
|
||||
"TextGenerationPipeline",
|
||||
"TextToAudioPipeline",
|
||||
"TokenClassificationPipeline",
|
||||
"TranslationPipeline",
|
||||
"VideoClassificationPipeline",
|
||||
@ -1095,6 +1096,8 @@ else:
|
||||
"MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
|
||||
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
|
||||
"MODEL_FOR_TEXT_ENCODING_MAPPING",
|
||||
"MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING",
|
||||
"MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING",
|
||||
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
||||
"MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING",
|
||||
"MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING",
|
||||
@ -4607,6 +4610,7 @@ if TYPE_CHECKING:
|
||||
Text2TextGenerationPipeline,
|
||||
TextClassificationPipeline,
|
||||
TextGenerationPipeline,
|
||||
TextToAudioPipeline,
|
||||
TokenClassificationPipeline,
|
||||
TranslationPipeline,
|
||||
VideoClassificationPipeline,
|
||||
@ -5007,6 +5011,8 @@ if TYPE_CHECKING:
|
||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
|
||||
MODEL_FOR_TEXT_ENCODING_MAPPING,
|
||||
MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING,
|
||||
MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING,
|
||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING,
|
||||
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
|
||||
|
@ -65,6 +65,8 @@ else:
|
||||
"MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
|
||||
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
|
||||
"MODEL_FOR_TEXT_ENCODING_MAPPING",
|
||||
"MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING",
|
||||
"MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING",
|
||||
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
||||
"MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING",
|
||||
"MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING",
|
||||
@ -241,6 +243,8 @@ if TYPE_CHECKING:
|
||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
|
||||
MODEL_FOR_TEXT_ENCODING_MAPPING,
|
||||
MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING,
|
||||
MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING,
|
||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING,
|
||||
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
|
||||
|
@ -1014,6 +1014,21 @@ MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for Text-To-Spectrogram mapping
|
||||
("speecht5", "SpeechT5ForTextToSpeech"),
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for Text-To-Waveform mapping
|
||||
("bark", "BarkModel"),
|
||||
("musicgen", "MusicgenForConditionalGeneration"),
|
||||
]
|
||||
)
|
||||
|
||||
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for Zero Shot Image Classification mapping
|
||||
@ -1152,6 +1167,12 @@ MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
||||
)
|
||||
MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES)
|
||||
|
||||
MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING = _LazyAutoMapping(
|
||||
CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES
|
||||
)
|
||||
|
||||
MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES)
|
||||
|
||||
MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BACKBONE_MAPPING_NAMES)
|
||||
|
||||
MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)
|
||||
@ -1407,6 +1428,14 @@ class AutoModelForAudioXVector(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING
|
||||
|
||||
|
||||
class AutoModelForTextToSpectrogram(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING
|
||||
|
||||
|
||||
class AutoModelForTextToWaveform(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING
|
||||
|
||||
|
||||
class AutoBackbone(_BaseAutoBackboneClass):
|
||||
_model_mapping = MODEL_FOR_BACKBONE_MAPPING
|
||||
|
||||
|
@ -60,6 +60,7 @@ else:
|
||||
),
|
||||
),
|
||||
("align", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("bark", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("bart", ("BartTokenizer", "BartTokenizerFast")),
|
||||
(
|
||||
"barthez",
|
||||
@ -224,6 +225,7 @@ else:
|
||||
"MT5TokenizerFast" if is_tokenizers_available() else None,
|
||||
),
|
||||
),
|
||||
("musicgen", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
|
||||
("mvp", ("MvpTokenizer", "MvpTokenizerFast" if is_tokenizers_available() else None)),
|
||||
("nezha", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
|
||||
(
|
||||
|
@ -70,6 +70,7 @@ from .table_question_answering import TableQuestionAnsweringArgumentHandler, Tab
|
||||
from .text2text_generation import SummarizationPipeline, Text2TextGenerationPipeline, TranslationPipeline
|
||||
from .text_classification import TextClassificationPipeline
|
||||
from .text_generation import TextGenerationPipeline
|
||||
from .text_to_audio import TextToAudioPipeline
|
||||
from .token_classification import (
|
||||
AggregationStrategy,
|
||||
NerPipeline,
|
||||
@ -121,6 +122,8 @@ if is_torch_available():
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelForSpeechSeq2Seq,
|
||||
AutoModelForTableQuestionAnswering,
|
||||
AutoModelForTextToSpectrogram,
|
||||
AutoModelForTextToWaveform,
|
||||
AutoModelForTokenClassification,
|
||||
AutoModelForVideoClassification,
|
||||
AutoModelForVision2Seq,
|
||||
@ -144,6 +147,7 @@ TASK_ALIASES = {
|
||||
"sentiment-analysis": "text-classification",
|
||||
"ner": "token-classification",
|
||||
"vqa": "visual-question-answering",
|
||||
"text-to-speech": "text-to-audio",
|
||||
}
|
||||
SUPPORTED_TASKS = {
|
||||
"audio-classification": {
|
||||
@ -160,6 +164,13 @@ SUPPORTED_TASKS = {
|
||||
"default": {"model": {"pt": ("facebook/wav2vec2-base-960h", "55bb623")}},
|
||||
"type": "multimodal",
|
||||
},
|
||||
"text-to-audio": {
|
||||
"impl": TextToAudioPipeline,
|
||||
"tf": (),
|
||||
"pt": (AutoModelForTextToWaveform, AutoModelForTextToSpectrogram) if is_torch_available() else (),
|
||||
"default": {"model": {"pt": ("suno/bark-small", "645cfba")}},
|
||||
"type": "text",
|
||||
},
|
||||
"feature-extraction": {
|
||||
"impl": FeatureExtractionPipeline,
|
||||
"tf": (TFAutoModel,) if is_tf_available() else (),
|
||||
@ -386,6 +397,7 @@ SUPPORTED_TASKS = {
|
||||
NO_FEATURE_EXTRACTOR_TASKS = set()
|
||||
NO_IMAGE_PROCESSOR_TASKS = set()
|
||||
NO_TOKENIZER_TASKS = set()
|
||||
|
||||
# Those model configs are special, they are generic over their task, meaning
|
||||
# any tokenizer/feature_extractor might be use for a given model so we cannot
|
||||
# use the statically defined TOKENIZER_MAPPING and FEATURE_EXTRACTOR_MAPPING to
|
||||
@ -465,6 +477,7 @@ def check_task(task: str) -> Tuple[str, Dict, Any]:
|
||||
- `"text2text-generation"`
|
||||
- `"text-classification"` (alias `"sentiment-analysis"` available)
|
||||
- `"text-generation"`
|
||||
- `"text-to-audio"` (alias `"text-to-speech"` available)
|
||||
- `"token-classification"` (alias `"ner"` available)
|
||||
- `"translation"`
|
||||
- `"translation_xx_to_yy"`
|
||||
@ -551,6 +564,7 @@ def pipeline(
|
||||
- `"text-classification"` (alias `"sentiment-analysis"` available): will return a
|
||||
[`TextClassificationPipeline`].
|
||||
- `"text-generation"`: will return a [`TextGenerationPipeline`]:.
|
||||
- `"text-to-audio"` (alias `"text-to-speech"` available): will return a [`TextToAudioPipeline`]:.
|
||||
- `"token-classification"` (alias `"ner"` available): will return a [`TokenClassificationPipeline`].
|
||||
- `"translation"`: will return a [`TranslationPipeline`].
|
||||
- `"translation_xx_to_yy"`: will return a [`TranslationPipeline`].
|
||||
|
159
src/transformers/pipelines/text_to_audio.py
Normal file
159
src/transformers/pipelines/text_to_audio.py
Normal file
@ -0,0 +1,159 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.from typing import List, Union
|
||||
from typing import List, Union
|
||||
|
||||
from ..utils import is_torch_available
|
||||
from .base import Pipeline
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from ..models.auto.modeling_auto import MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING
|
||||
from ..models.speecht5.modeling_speecht5 import SpeechT5HifiGan
|
||||
|
||||
DEFAULT_VOCODER_ID = "microsoft/speecht5_hifigan"
|
||||
|
||||
|
||||
class TextToAudioPipeline(Pipeline):
|
||||
"""
|
||||
Text-to-audio generation pipeline using any `AutoModelForTextToWaveform` or `AutoModelForTextToSpectrogram`. This
|
||||
pipeline generates an audio file from an input text and optional other conditional inputs.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import pipeline
|
||||
|
||||
>>> classifier = pipeline(model="suno/bark")
|
||||
>>> output = pipeline("Hey it's HuggingFace on the phone!")
|
||||
|
||||
>>> audio = output["audio"]
|
||||
>>> sampling_rate = output["sampling_rate"]
|
||||
```
|
||||
|
||||
Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
|
||||
|
||||
|
||||
This pipeline can currently be loaded from [`pipeline`] using the following task identifiers: `"text-to-speech"` or
|
||||
`"text-to-audio"`.
|
||||
|
||||
See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=text-to-speech).
|
||||
"""
|
||||
|
||||
def __init__(self, *args, vocoder=None, sampling_rate=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if self.framework == "tf":
|
||||
raise ValueError("The TextToAudioPipeline is only available in PyTorch.")
|
||||
|
||||
self.forward_method = self.model.generate if self.model.can_generate() else self.model
|
||||
|
||||
self.vocoder = None
|
||||
if self.model.__class__ in MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING.values():
|
||||
self.vocoder = (
|
||||
SpeechT5HifiGan.from_pretrained(DEFAULT_VOCODER_ID).to(self.model.device)
|
||||
if vocoder is None
|
||||
else vocoder
|
||||
)
|
||||
|
||||
self.sampling_rate = sampling_rate
|
||||
if self.vocoder is not None:
|
||||
self.sampling_rate = self.vocoder.config.sampling_rate
|
||||
|
||||
if self.sampling_rate is None:
|
||||
# get sampling_rate from config and generation config
|
||||
|
||||
config = self.model.config.to_dict()
|
||||
gen_config = self.model.__dict__.get("generation_config", None)
|
||||
if gen_config is not None:
|
||||
config.update(gen_config.to_dict())
|
||||
|
||||
for sampling_rate_name in ["sample_rate", "sampling_rate"]:
|
||||
sampling_rate = config.get(sampling_rate_name, None)
|
||||
if sampling_rate is not None:
|
||||
self.sampling_rate = sampling_rate
|
||||
|
||||
def preprocess(self, text, **kwargs):
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
|
||||
if self.model.config.model_type == "bark":
|
||||
# bark Tokenizer is called with BarkProcessor which uses those kwargs
|
||||
new_kwargs = {
|
||||
"max_length": self.model.generation_config.semantic_config.get("max_input_semantic_length", 256),
|
||||
"add_special_tokens": False,
|
||||
"return_attention_mask": True,
|
||||
"return_token_type_ids": False,
|
||||
"padding": "max_length",
|
||||
}
|
||||
|
||||
# priority is given to kwargs
|
||||
new_kwargs.update(kwargs)
|
||||
|
||||
kwargs = new_kwargs
|
||||
|
||||
output = self.tokenizer(text, **kwargs, return_tensors="pt")
|
||||
|
||||
return output
|
||||
|
||||
def _forward(self, model_inputs, **kwargs):
|
||||
# we expect some kwargs to be additional tensors which need to be on the right device
|
||||
kwargs = self._ensure_tensor_on_device(kwargs, device=self.device)
|
||||
|
||||
# call the generate by defaults or the forward method if the model cannot generate
|
||||
output = self.forward_method(**model_inputs, **kwargs)
|
||||
|
||||
if self.vocoder is not None:
|
||||
# in that case, the output is a spectrogram that needs to be converted into a waveform
|
||||
output = self.vocoder(output)
|
||||
|
||||
return output
|
||||
|
||||
def __call__(self, text_inputs: Union[str, List[str]], **forward_params):
|
||||
"""
|
||||
Generates speech/audio from the inputs. See the [`TextToAudioPipeline`] documentation for more information.
|
||||
|
||||
Args:
|
||||
text_inputs (`str` or `List[str]`):
|
||||
The text(s) to generate.
|
||||
forward_params (*optional*):
|
||||
Parameters passed to the model generation/forward method.
|
||||
|
||||
Return:
|
||||
A `dict` or a list of `dict`: The dictionaries have two keys:
|
||||
|
||||
- **audio** (`np.ndarray` of shape `(nb_channels, audio_length)`) -- The generated audio waveform.
|
||||
- **sampling_rate** (`int`) -- The sampling rate of the generated audio waveform.
|
||||
"""
|
||||
return super().__call__(text_inputs, **forward_params)
|
||||
|
||||
def _sanitize_parameters(
|
||||
self,
|
||||
preprocess_params=None,
|
||||
forward_params=None,
|
||||
):
|
||||
if preprocess_params is None:
|
||||
preprocess_params = {}
|
||||
if forward_params is None:
|
||||
forward_params = {}
|
||||
postprocess_params = {}
|
||||
|
||||
return preprocess_params, forward_params, postprocess_params
|
||||
|
||||
def postprocess(self, waveform):
|
||||
output_dict = {}
|
||||
|
||||
output_dict["audio"] = waveform.cpu().float().numpy()
|
||||
output_dict["sampling_rate"] = self.sampling_rate
|
||||
|
||||
return output_dict
|
@ -527,6 +527,12 @@ MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = None
|
||||
MODEL_FOR_TEXT_ENCODING_MAPPING = None
|
||||
|
||||
|
||||
MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING = None
|
||||
|
||||
|
||||
MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING = None
|
||||
|
||||
|
||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
|
||||
|
||||
|
||||
|
190
tests/pipelines/test_pipelines_text_to_audio.py
Normal file
190
tests/pipelines/test_pipelines_text_to_audio.py
Normal file
@ -0,0 +1,190 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import (
|
||||
MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING,
|
||||
AutoProcessor,
|
||||
TextToAudioPipeline,
|
||||
pipeline,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
is_pipeline_test,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torch_or_tf,
|
||||
slow,
|
||||
)
|
||||
|
||||
from .test_pipelines_common import ANY
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
@require_torch_or_tf
|
||||
class TextToAudioPipelineTests(unittest.TestCase):
|
||||
model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING
|
||||
# for now only text_to_waveform and not text_to_spectrogram
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_small_model_pt(self):
|
||||
speech_generator = pipeline(task="text-to-audio", model="facebook/musicgen-small", framework="pt")
|
||||
|
||||
forward_params = {
|
||||
"do_sample": False,
|
||||
"max_new_tokens": 250,
|
||||
}
|
||||
|
||||
outputs = speech_generator("This is a test", forward_params=forward_params)
|
||||
|
||||
# musicgen sampling_rate is not straightforward to get
|
||||
self.assertIsNone(outputs["sampling_rate"])
|
||||
|
||||
audio = outputs["audio"]
|
||||
|
||||
self.assertEqual(ANY(np.ndarray), audio)
|
||||
|
||||
# test two examples side-by-side
|
||||
outputs = speech_generator(["This is a test", "This is a second test"], forward_params=forward_params)
|
||||
|
||||
audio = [output["audio"] for output in outputs]
|
||||
|
||||
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
|
||||
|
||||
# test batching
|
||||
outputs = speech_generator(
|
||||
["This is a test", "This is a second test"], forward_params=forward_params, batch_size=2
|
||||
)
|
||||
|
||||
self.assertEqual(ANY(np.ndarray), outputs[0]["audio"])
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_large_model_pt(self):
|
||||
speech_generator = pipeline(task="text-to-audio", model="suno/bark-small", framework="pt")
|
||||
|
||||
# test text-to-speech
|
||||
|
||||
forward_params = {
|
||||
# Using `do_sample=False` to force deterministic output
|
||||
"do_sample": False,
|
||||
"semantic_max_new_tokens": 100,
|
||||
}
|
||||
|
||||
outputs = speech_generator("This is a test", forward_params=forward_params)
|
||||
|
||||
self.assertEqual(
|
||||
{"audio": ANY(np.ndarray), "sampling_rate": 24000},
|
||||
outputs,
|
||||
)
|
||||
|
||||
# test two examples side-by-side
|
||||
outputs = speech_generator(
|
||||
["This is a test", "This is a second test"],
|
||||
forward_params=forward_params,
|
||||
)
|
||||
|
||||
audio = [output["audio"] for output in outputs]
|
||||
|
||||
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
|
||||
|
||||
# test other generation strategy
|
||||
|
||||
forward_params = {
|
||||
"do_sample": True,
|
||||
"semantic_max_new_tokens": 100,
|
||||
"semantic_num_return_sequences": 2,
|
||||
}
|
||||
|
||||
outputs = speech_generator("This is a test", forward_params=forward_params)
|
||||
|
||||
audio = outputs["audio"]
|
||||
|
||||
self.assertEqual(ANY(np.ndarray), audio)
|
||||
|
||||
# test using a speaker embedding
|
||||
processor = AutoProcessor.from_pretrained("suno/bark-small")
|
||||
temp_inp = processor("hey, how are you?", voice_preset="v2/en_speaker_5")
|
||||
history_prompt = temp_inp["history_prompt"]
|
||||
forward_params["history_prompt"] = history_prompt
|
||||
|
||||
outputs = speech_generator(
|
||||
["This is a test", "This is a second test"],
|
||||
forward_params=forward_params,
|
||||
batch_size=2,
|
||||
)
|
||||
|
||||
audio = [output["audio"] for output in outputs]
|
||||
|
||||
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_conversion_additional_tensor(self):
|
||||
speech_generator = pipeline(task="text-to-audio", model="suno/bark-small", framework="pt", device=0)
|
||||
processor = AutoProcessor.from_pretrained("suno/bark-small")
|
||||
|
||||
forward_params = {
|
||||
"do_sample": True,
|
||||
"semantic_max_new_tokens": 100,
|
||||
}
|
||||
|
||||
# atm, must do to stay coherent with BarkProcessor
|
||||
preprocess_params = {
|
||||
"max_length": 256,
|
||||
"add_special_tokens": False,
|
||||
"return_attention_mask": True,
|
||||
"return_token_type_ids": False,
|
||||
"padding": "max_length",
|
||||
}
|
||||
|
||||
outputs = speech_generator(
|
||||
"This is a test",
|
||||
forward_params=forward_params,
|
||||
preprocess_params=preprocess_params,
|
||||
)
|
||||
|
||||
temp_inp = processor("hey, how are you?", voice_preset="v2/en_speaker_5")
|
||||
history_prompt = temp_inp["history_prompt"]
|
||||
forward_params["history_prompt"] = history_prompt
|
||||
|
||||
# history_prompt is a torch.Tensor passed as a forward_param
|
||||
# if generation is successfull, it means that it was passed to the right device
|
||||
outputs = speech_generator(
|
||||
"This is a test", forward_params=forward_params, preprocess_params=preprocess_params
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
{"audio": ANY(np.ndarray), "sampling_rate": 24000},
|
||||
outputs,
|
||||
)
|
||||
|
||||
def get_test_pipeline(self, model, tokenizer, processor):
|
||||
speech_generator = TextToAudioPipeline(model=model, tokenizer=tokenizer)
|
||||
return speech_generator, ["This is a test", "Another test"]
|
||||
|
||||
def run_pipeline_test(self, speech_generator, _):
|
||||
outputs = speech_generator("This is a test")
|
||||
|
||||
self.assertEqual(ANY(np.ndarray), outputs["audio"])
|
||||
|
||||
forward_params = {"num_return_sequences": 2, "do_sample": True}
|
||||
|
||||
outputs = speech_generator(["This is great !", "Something else"], forward_params=forward_params)
|
||||
audio = [output["audio"] for output in outputs]
|
||||
|
||||
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
|
@ -49,6 +49,7 @@ from .pipelines.test_pipelines_table_question_answering import TQAPipelineTests
|
||||
from .pipelines.test_pipelines_text2text_generation import Text2TextGenerationPipelineTests
|
||||
from .pipelines.test_pipelines_text_classification import TextClassificationPipelineTests
|
||||
from .pipelines.test_pipelines_text_generation import TextGenerationPipelineTests
|
||||
from .pipelines.test_pipelines_text_to_audio import TextToAudioPipelineTests
|
||||
from .pipelines.test_pipelines_token_classification import TokenClassificationPipelineTests
|
||||
from .pipelines.test_pipelines_translation import TranslationPipelineTests
|
||||
from .pipelines.test_pipelines_video_classification import VideoClassificationPipelineTests
|
||||
@ -78,6 +79,7 @@ pipeline_test_mapping = {
|
||||
"text2text-generation": {"test": Text2TextGenerationPipelineTests},
|
||||
"text-classification": {"test": TextClassificationPipelineTests},
|
||||
"text-generation": {"test": TextGenerationPipelineTests},
|
||||
"text-to-audio": {"test": TextToAudioPipelineTests},
|
||||
"token-classification": {"test": TokenClassificationPipelineTests},
|
||||
"translation": {"test": TranslationPipelineTests},
|
||||
"video-classification": {"test": VideoClassificationPipelineTests},
|
||||
@ -405,6 +407,11 @@ class PipelineTesterMixin:
|
||||
def test_pipeline_text_generation(self):
|
||||
self.run_task_tests(task="text-generation")
|
||||
|
||||
@is_pipeline_test
|
||||
@require_torch
|
||||
def test_pipeline_text_to_audio(self):
|
||||
self.run_task_tests(task="text-to-audio")
|
||||
|
||||
@is_pipeline_test
|
||||
def test_pipeline_token_classification(self):
|
||||
self.run_task_tests(task="token-classification")
|
||||
|
@ -115,6 +115,7 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [
|
||||
("depth-estimation", "MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES", "AutoModelForDepthEstimation"),
|
||||
("video-classification", "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES", "AutoModelForVideoClassification"),
|
||||
("mask-generation", "MODEL_FOR_MASK_GENERATION_MAPPING_NAMES", "AutoModelForMaskGeneration"),
|
||||
("text-to-audio", "MODEL_FOR_TEXT_TO_WAVEFORM_NAMES", "AutoModelForTextToWaveform"),
|
||||
]
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user