Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-24 14:58:56 +06:00)
Add Text-To-Speech pipeline (#24952)
* add AutoModelForTextToSpeech class
* add TTS pipeline and testing
* add docstrings to text_to_speech pipeline
* fix torch dependency
* correct 'processor is None' case in Pipeline
* correct repo id
* modify text-to-speech -> text-to-audio
* remove processor
* rename text_to_speech pipelines files to text_audio
* add textToWaveform and textToSpectrogram instead of textToAudio classes
* update TTS pipeline to the bare minimum
* update tests TTS pipeline
* make style and erase useless import torch in TTS pipeline tests
* modify how to check if generate or forward in TTS pipeline
* remove unnecessary extra new lines
* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* refactor input_texts -> text_inputs
* correct docstrings of TTS.__call__
* correct the shape of generated waveform
* take care of Bark tokenizer special case
* correct run_pipeline_test TTS
* make style
* update TTS docstrings
* address Sylvain nit refactors
* make style
* refactor into one liners
* correct squeeze
* correct way to test if forward or generate
* Update output audio waveform shape
* make style
* correct import
* modify how the TTS pipeline tests if a model can generate
* align shape output of TTS pipeline with consistent shape

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Parent: c4c0ceff09
Commit: b8f69d0d10
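For context, a minimal usage sketch of the pipeline this commit introduces, based on the default Bark checkpoint and the output format defined in the diff below (scipy is only pulled in here to write the result to disk; the exact waveform shape depends on the model):

```python
import scipy.io.wavfile

from transformers import pipeline

# The new "text-to-audio" task also answers to the "text-to-speech" alias (see TASK_ALIASES below).
tts = pipeline("text-to-audio", model="suno/bark-small")

output = tts("Hey it's HuggingFace on the phone!")
audio = output["audio"]                  # np.ndarray waveform, shape (nb_channels, audio_length)
sampling_rate = output["sampling_rate"]  # int, 24000 for Bark

# The waveform plus its sampling rate is all that is needed to write a playable WAV file.
scipy.io.wavfile.write("bark_out.wav", rate=sampling_rate, data=audio.squeeze())
```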
@@ -318,6 +318,13 @@ Pipelines available for audio tasks include the following.
     - __call__
     - all
 
+### TextToAudioPipeline
+
+[[autodoc]] TextToAudioPipeline
+    - __call__
+    - all
+
 ### ZeroShotAudioClassificationPipeline
 
 [[autodoc]] ZeroShotAudioClassificationPipeline
@@ -643,6 +643,7 @@ _import_structure = {
         "Text2TextGenerationPipeline",
         "TextClassificationPipeline",
         "TextGenerationPipeline",
+        "TextToAudioPipeline",
         "TokenClassificationPipeline",
         "TranslationPipeline",
         "VideoClassificationPipeline",
@@ -1095,6 +1096,8 @@ else:
             "MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
             "MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
             "MODEL_FOR_TEXT_ENCODING_MAPPING",
+            "MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING",
+            "MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING",
             "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
             "MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING",
             "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING",
@@ -4607,6 +4610,7 @@ if TYPE_CHECKING:
         Text2TextGenerationPipeline,
         TextClassificationPipeline,
         TextGenerationPipeline,
+        TextToAudioPipeline,
         TokenClassificationPipeline,
         TranslationPipeline,
         VideoClassificationPipeline,
@@ -5007,6 +5011,8 @@ if TYPE_CHECKING:
         MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
         MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
         MODEL_FOR_TEXT_ENCODING_MAPPING,
+        MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING,
+        MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING,
         MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
         MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING,
         MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
@@ -65,6 +65,8 @@ else:
         "MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
         "MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
         "MODEL_FOR_TEXT_ENCODING_MAPPING",
+        "MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING",
+        "MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING",
         "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
         "MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING",
         "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING",
@@ -241,6 +243,8 @@ if TYPE_CHECKING:
         MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
         MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
         MODEL_FOR_TEXT_ENCODING_MAPPING,
+        MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING,
+        MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING,
         MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
         MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING,
         MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
@@ -1014,6 +1014,21 @@ MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
     ]
 )
 
+MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = OrderedDict(
+    [
+        # Model for Text-To-Spectrogram mapping
+        ("speecht5", "SpeechT5ForTextToSpeech"),
+    ]
+)
+
+MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = OrderedDict(
+    [
+        # Model for Text-To-Waveform mapping
+        ("bark", "BarkModel"),
+        ("musicgen", "MusicgenForConditionalGeneration"),
+    ]
+)
+
 MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
     [
         # Model for Zero Shot Image Classification mapping
@@ -1152,6 +1167,12 @@ MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping(
 )
 MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES)
 
+MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING = _LazyAutoMapping(
+    CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES
+)
+
+MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES)
+
 MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BACKBONE_MAPPING_NAMES)
 
 MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)
@@ -1407,6 +1428,14 @@ class AutoModelForAudioXVector(_BaseAutoModelClass):
     _model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING
 
 
+class AutoModelForTextToSpectrogram(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING
+
+
+class AutoModelForTextToWaveform(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING
+
+
 class AutoBackbone(_BaseAutoBackboneClass):
     _model_mapping = MODEL_FOR_BACKBONE_MAPPING
 
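The two new auto classes mirror the split between spectrogram-producing and waveform-producing models. A short sketch of how they would be used (the checkpoint names are illustrative and chosen to match the mappings registered above):

```python
from transformers import AutoModelForTextToSpectrogram, AutoModelForTextToWaveform

# SpeechT5 is registered in MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES: it predicts a
# spectrogram and needs a separate vocoder (e.g. SpeechT5HifiGan) to obtain a waveform.
spectrogram_model = AutoModelForTextToSpectrogram.from_pretrained("microsoft/speecht5_tts")

# Bark and MusicGen are registered in MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES and
# generate the waveform directly, so no vocoder is required.
waveform_model = AutoModelForTextToWaveform.from_pretrained("suno/bark-small")
```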
@@ -60,6 +60,7 @@ else:
             ),
         ),
         ("align", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
+        ("bark", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
         ("bart", ("BartTokenizer", "BartTokenizerFast")),
         (
             "barthez",
@@ -224,6 +225,7 @@ else:
                 "MT5TokenizerFast" if is_tokenizers_available() else None,
             ),
         ),
+        ("musicgen", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
         ("mvp", ("MvpTokenizer", "MvpTokenizerFast" if is_tokenizers_available() else None)),
         ("nezha", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
         (
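With the `bark` and `musicgen` entries above, `AutoTokenizer` can resolve the corresponding checkpoints without naming a tokenizer class manually. A small sketch (the checkpoint name is assumed from the rest of this diff):

```python
from transformers import AutoTokenizer

# Bark reuses the BERT tokenizer, per the new ("bark", ...) entry in TOKENIZER_MAPPING_NAMES.
tokenizer = AutoTokenizer.from_pretrained("suno/bark-small")
print(type(tokenizer).__name__)  # BertTokenizerFast when the tokenizers library is installed
```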
@@ -70,6 +70,7 @@ from .table_question_answering import TableQuestionAnsweringArgumentHandler, Tab
 from .text2text_generation import SummarizationPipeline, Text2TextGenerationPipeline, TranslationPipeline
 from .text_classification import TextClassificationPipeline
 from .text_generation import TextGenerationPipeline
+from .text_to_audio import TextToAudioPipeline
 from .token_classification import (
     AggregationStrategy,
     NerPipeline,
@@ -121,6 +122,8 @@ if is_torch_available():
         AutoModelForSequenceClassification,
         AutoModelForSpeechSeq2Seq,
         AutoModelForTableQuestionAnswering,
+        AutoModelForTextToSpectrogram,
+        AutoModelForTextToWaveform,
         AutoModelForTokenClassification,
         AutoModelForVideoClassification,
         AutoModelForVision2Seq,
@@ -144,6 +147,7 @@ TASK_ALIASES = {
     "sentiment-analysis": "text-classification",
     "ner": "token-classification",
     "vqa": "visual-question-answering",
+    "text-to-speech": "text-to-audio",
 }
 SUPPORTED_TASKS = {
     "audio-classification": {
@@ -160,6 +164,13 @@ SUPPORTED_TASKS = {
         "default": {"model": {"pt": ("facebook/wav2vec2-base-960h", "55bb623")}},
         "type": "multimodal",
     },
+    "text-to-audio": {
+        "impl": TextToAudioPipeline,
+        "tf": (),
+        "pt": (AutoModelForTextToWaveform, AutoModelForTextToSpectrogram) if is_torch_available() else (),
+        "default": {"model": {"pt": ("suno/bark-small", "645cfba")}},
+        "type": "text",
+    },
     "feature-extraction": {
         "impl": FeatureExtractionPipeline,
         "tf": (TFAutoModel,) if is_tf_available() else (),
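Given the alias and the default checkpoint registered above, calling `pipeline` without a model should resolve as follows (a sketch; `645cfba` is the pinned revision from the task default):

```python
from transformers import pipeline

# "text-to-speech" resolves to "text-to-audio" via TASK_ALIASES, and the task default
# loads suno/bark-small at revision 645cfba.
tts = pipeline("text-to-speech")
print(tts.model.config.model_type)  # "bark"
```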
@@ -386,6 +397,7 @@ SUPPORTED_TASKS = {
 NO_FEATURE_EXTRACTOR_TASKS = set()
 NO_IMAGE_PROCESSOR_TASKS = set()
 NO_TOKENIZER_TASKS = set()
+
 # Those model configs are special, they are generic over their task, meaning
 # any tokenizer/feature_extractor might be use for a given model so we cannot
 # use the statically defined TOKENIZER_MAPPING and FEATURE_EXTRACTOR_MAPPING to
@@ -465,6 +477,7 @@ def check_task(task: str) -> Tuple[str, Dict, Any]:
             - `"text2text-generation"`
             - `"text-classification"` (alias `"sentiment-analysis"` available)
             - `"text-generation"`
+            - `"text-to-audio"` (alias `"text-to-speech"` available)
             - `"token-classification"` (alias `"ner"` available)
             - `"translation"`
             - `"translation_xx_to_yy"`
@@ -551,6 +564,7 @@ def pipeline(
             - `"text-classification"` (alias `"sentiment-analysis"` available): will return a
               [`TextClassificationPipeline`].
             - `"text-generation"`: will return a [`TextGenerationPipeline`]:.
+            - `"text-to-audio"` (alias `"text-to-speech"` available): will return a [`TextToAudioPipeline`]:.
             - `"token-classification"` (alias `"ner"` available): will return a [`TokenClassificationPipeline`].
             - `"translation"`: will return a [`TranslationPipeline`].
             - `"translation_xx_to_yy"`: will return a [`TranslationPipeline`].
src/transformers/pipelines/text_to_audio.py (new file, 159 lines)
@@ -0,0 +1,159 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union

from ..utils import is_torch_available
from .base import Pipeline


if is_torch_available():
    from ..models.auto.modeling_auto import MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING
    from ..models.speecht5.modeling_speecht5 import SpeechT5HifiGan

DEFAULT_VOCODER_ID = "microsoft/speecht5_hifigan"


class TextToAudioPipeline(Pipeline):
    """
    Text-to-audio generation pipeline using any `AutoModelForTextToWaveform` or `AutoModelForTextToSpectrogram`. This
    pipeline generates an audio file from an input text and optional other conditional inputs.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> classifier = pipeline(model="suno/bark")
    >>> output = classifier("Hey it's HuggingFace on the phone!")

    >>> audio = output["audio"]
    >>> sampling_rate = output["sampling_rate"]
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    This pipeline can currently be loaded from [`pipeline`] using the following task identifiers: `"text-to-speech"` or
    `"text-to-audio"`.

    See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=text-to-speech).
    """

    def __init__(self, *args, vocoder=None, sampling_rate=None, **kwargs):
        super().__init__(*args, **kwargs)

        if self.framework == "tf":
            raise ValueError("The TextToAudioPipeline is only available in PyTorch.")

        self.forward_method = self.model.generate if self.model.can_generate() else self.model

        self.vocoder = None
        if self.model.__class__ in MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING.values():
            self.vocoder = (
                SpeechT5HifiGan.from_pretrained(DEFAULT_VOCODER_ID).to(self.model.device)
                if vocoder is None
                else vocoder
            )

        self.sampling_rate = sampling_rate
        if self.vocoder is not None:
            self.sampling_rate = self.vocoder.config.sampling_rate

        if self.sampling_rate is None:
            # get sampling_rate from config and generation config

            config = self.model.config.to_dict()
            gen_config = self.model.__dict__.get("generation_config", None)
            if gen_config is not None:
                config.update(gen_config.to_dict())

            for sampling_rate_name in ["sample_rate", "sampling_rate"]:
                sampling_rate = config.get(sampling_rate_name, None)
                if sampling_rate is not None:
                    self.sampling_rate = sampling_rate

    def preprocess(self, text, **kwargs):
        if isinstance(text, str):
            text = [text]

        if self.model.config.model_type == "bark":
            # bark Tokenizer is called with BarkProcessor which uses those kwargs
            new_kwargs = {
                "max_length": self.model.generation_config.semantic_config.get("max_input_semantic_length", 256),
                "add_special_tokens": False,
                "return_attention_mask": True,
                "return_token_type_ids": False,
                "padding": "max_length",
            }

            # priority is given to kwargs
            new_kwargs.update(kwargs)

            kwargs = new_kwargs

        output = self.tokenizer(text, **kwargs, return_tensors="pt")

        return output

    def _forward(self, model_inputs, **kwargs):
        # we expect some kwargs to be additional tensors which need to be on the right device
        kwargs = self._ensure_tensor_on_device(kwargs, device=self.device)

        # call generate by default, or the forward method if the model cannot generate
        output = self.forward_method(**model_inputs, **kwargs)

        if self.vocoder is not None:
            # in that case, the output is a spectrogram that needs to be converted into a waveform
            output = self.vocoder(output)

        return output

    def __call__(self, text_inputs: Union[str, List[str]], **forward_params):
        """
        Generates speech/audio from the inputs. See the [`TextToAudioPipeline`] documentation for more information.

        Args:
            text_inputs (`str` or `List[str]`):
                The text(s) to generate.
            forward_params (*optional*):
                Parameters passed to the model generation/forward method.

        Return:
            A `dict` or a list of `dict`: The dictionaries have two keys:

            - **audio** (`np.ndarray` of shape `(nb_channels, audio_length)`) -- The generated audio waveform.
            - **sampling_rate** (`int`) -- The sampling rate of the generated audio waveform.
        """
        return super().__call__(text_inputs, **forward_params)

    def _sanitize_parameters(
        self,
        preprocess_params=None,
        forward_params=None,
    ):
        if preprocess_params is None:
            preprocess_params = {}
        if forward_params is None:
            forward_params = {}
        postprocess_params = {}

        return preprocess_params, forward_params, postprocess_params

    def postprocess(self, waveform):
        output_dict = {}

        output_dict["audio"] = waveform.cpu().float().numpy()
        output_dict["sampling_rate"] = self.sampling_rate

        return output_dict
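Everything passed through `forward_params` in `__call__` is handed to `generate` (or `forward` for non-generative models), so model-specific generation kwargs flow through the pipeline untouched. A sketch using the Bark kwargs exercised in the tests below:

```python
from transformers import pipeline

tts = pipeline("text-to-audio", model="suno/bark-small")

# Bark-specific generation arguments are forwarded as-is to model.generate().
forward_params = {
    "do_sample": False,            # deterministic output
    "semantic_max_new_tokens": 100,
}

output = tts("This is a test", forward_params=forward_params)
print(output["audio"].shape, output["sampling_rate"])
```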
@@ -527,6 +527,12 @@ MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = None
 MODEL_FOR_TEXT_ENCODING_MAPPING = None
 
 
+MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING = None
+
+
+MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING = None
+
+
 MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
 
 
tests/pipelines/test_pipelines_text_to_audio.py (new file, 190 lines)
@@ -0,0 +1,190 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

from transformers import (
    MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING,
    AutoProcessor,
    TextToAudioPipeline,
    pipeline,
)
from transformers.testing_utils import (
    is_pipeline_test,
    require_torch,
    require_torch_gpu,
    require_torch_or_tf,
    slow,
)

from .test_pipelines_common import ANY


@is_pipeline_test
@require_torch_or_tf
class TextToAudioPipelineTests(unittest.TestCase):
    model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING
    # for now only text_to_waveform and not text_to_spectrogram

    @slow
    @require_torch
    def test_small_model_pt(self):
        speech_generator = pipeline(task="text-to-audio", model="facebook/musicgen-small", framework="pt")

        forward_params = {
            "do_sample": False,
            "max_new_tokens": 250,
        }

        outputs = speech_generator("This is a test", forward_params=forward_params)

        # musicgen sampling_rate is not straightforward to get
        self.assertIsNone(outputs["sampling_rate"])

        audio = outputs["audio"]

        self.assertEqual(ANY(np.ndarray), audio)

        # test two examples side-by-side
        outputs = speech_generator(["This is a test", "This is a second test"], forward_params=forward_params)

        audio = [output["audio"] for output in outputs]

        self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)

        # test batching
        outputs = speech_generator(
            ["This is a test", "This is a second test"], forward_params=forward_params, batch_size=2
        )

        self.assertEqual(ANY(np.ndarray), outputs[0]["audio"])

    @slow
    @require_torch
    def test_large_model_pt(self):
        speech_generator = pipeline(task="text-to-audio", model="suno/bark-small", framework="pt")

        # test text-to-speech

        forward_params = {
            # Using `do_sample=False` to force deterministic output
            "do_sample": False,
            "semantic_max_new_tokens": 100,
        }

        outputs = speech_generator("This is a test", forward_params=forward_params)

        self.assertEqual(
            {"audio": ANY(np.ndarray), "sampling_rate": 24000},
            outputs,
        )

        # test two examples side-by-side
        outputs = speech_generator(
            ["This is a test", "This is a second test"],
            forward_params=forward_params,
        )

        audio = [output["audio"] for output in outputs]

        self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)

        # test other generation strategy

        forward_params = {
            "do_sample": True,
            "semantic_max_new_tokens": 100,
            "semantic_num_return_sequences": 2,
        }

        outputs = speech_generator("This is a test", forward_params=forward_params)

        audio = outputs["audio"]

        self.assertEqual(ANY(np.ndarray), audio)

        # test using a speaker embedding
        processor = AutoProcessor.from_pretrained("suno/bark-small")
        temp_inp = processor("hey, how are you?", voice_preset="v2/en_speaker_5")
        history_prompt = temp_inp["history_prompt"]
        forward_params["history_prompt"] = history_prompt

        outputs = speech_generator(
            ["This is a test", "This is a second test"],
            forward_params=forward_params,
            batch_size=2,
        )

        audio = [output["audio"] for output in outputs]

        self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)

    @slow
    @require_torch_gpu
    def test_conversion_additional_tensor(self):
        speech_generator = pipeline(task="text-to-audio", model="suno/bark-small", framework="pt", device=0)
        processor = AutoProcessor.from_pretrained("suno/bark-small")

        forward_params = {
            "do_sample": True,
            "semantic_max_new_tokens": 100,
        }

        # atm, must do to stay coherent with BarkProcessor
        preprocess_params = {
            "max_length": 256,
            "add_special_tokens": False,
            "return_attention_mask": True,
            "return_token_type_ids": False,
            "padding": "max_length",
        }

        outputs = speech_generator(
            "This is a test",
            forward_params=forward_params,
            preprocess_params=preprocess_params,
        )

        temp_inp = processor("hey, how are you?", voice_preset="v2/en_speaker_5")
        history_prompt = temp_inp["history_prompt"]
        forward_params["history_prompt"] = history_prompt

        # history_prompt is a torch.Tensor passed as a forward_param
        # if generation is successful, it means that it was passed to the right device
        outputs = speech_generator(
            "This is a test", forward_params=forward_params, preprocess_params=preprocess_params
        )

        self.assertEqual(
            {"audio": ANY(np.ndarray), "sampling_rate": 24000},
            outputs,
        )

    def get_test_pipeline(self, model, tokenizer, processor):
        speech_generator = TextToAudioPipeline(model=model, tokenizer=tokenizer)
        return speech_generator, ["This is a test", "Another test"]

    def run_pipeline_test(self, speech_generator, _):
        outputs = speech_generator("This is a test")

        self.assertEqual(ANY(np.ndarray), outputs["audio"])

        forward_params = {"num_return_sequences": 2, "do_sample": True}

        outputs = speech_generator(["This is great !", "Something else"], forward_params=forward_params)
        audio = [output["audio"] for output in outputs]

        self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
@@ -49,6 +49,7 @@ from .pipelines.test_pipelines_table_question_answering import TQAPipelineTests
 from .pipelines.test_pipelines_text2text_generation import Text2TextGenerationPipelineTests
 from .pipelines.test_pipelines_text_classification import TextClassificationPipelineTests
 from .pipelines.test_pipelines_text_generation import TextGenerationPipelineTests
+from .pipelines.test_pipelines_text_to_audio import TextToAudioPipelineTests
 from .pipelines.test_pipelines_token_classification import TokenClassificationPipelineTests
 from .pipelines.test_pipelines_translation import TranslationPipelineTests
 from .pipelines.test_pipelines_video_classification import VideoClassificationPipelineTests
@@ -78,6 +79,7 @@ pipeline_test_mapping = {
     "text2text-generation": {"test": Text2TextGenerationPipelineTests},
     "text-classification": {"test": TextClassificationPipelineTests},
     "text-generation": {"test": TextGenerationPipelineTests},
+    "text-to-audio": {"test": TextToAudioPipelineTests},
     "token-classification": {"test": TokenClassificationPipelineTests},
     "translation": {"test": TranslationPipelineTests},
     "video-classification": {"test": VideoClassificationPipelineTests},
@@ -405,6 +407,11 @@ class PipelineTesterMixin:
     def test_pipeline_text_generation(self):
         self.run_task_tests(task="text-generation")
 
+    @is_pipeline_test
+    @require_torch
+    def test_pipeline_text_to_audio(self):
+        self.run_task_tests(task="text-to-audio")
+
     @is_pipeline_test
     def test_pipeline_token_classification(self):
         self.run_task_tests(task="token-classification")
@@ -115,6 +115,7 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [
     ("depth-estimation", "MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES", "AutoModelForDepthEstimation"),
     ("video-classification", "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES", "AutoModelForVideoClassification"),
     ("mask-generation", "MODEL_FOR_MASK_GENERATION_MAPPING_NAMES", "AutoModelForMaskGeneration"),
+    ("text-to-audio", "MODEL_FOR_TEXT_TO_WAVEFORM_NAMES", "AutoModelForTextToWaveform"),
 ]