diff --git a/docs/source/en/main_classes/pipelines.md b/docs/source/en/main_classes/pipelines.md
index 94272b43834..a3bf5678097 100644
--- a/docs/source/en/main_classes/pipelines.md
+++ b/docs/source/en/main_classes/pipelines.md
@@ -318,6 +318,13 @@ Pipelines available for audio tasks include the following.
     - __call__
     - all
 
+### TextToAudioPipeline
+
+[[autodoc]] TextToAudioPipeline
+    - __call__
+    - all
+
+
 ### ZeroShotAudioClassificationPipeline
 
 [[autodoc]] ZeroShotAudioClassificationPipeline
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 695e51fbe29..b06254a1c0a 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -643,6 +643,7 @@ _import_structure = {
         "Text2TextGenerationPipeline",
         "TextClassificationPipeline",
         "TextGenerationPipeline",
+        "TextToAudioPipeline",
         "TokenClassificationPipeline",
         "TranslationPipeline",
         "VideoClassificationPipeline",
@@ -1095,6 +1096,8 @@ else:
         "MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
         "MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
         "MODEL_FOR_TEXT_ENCODING_MAPPING",
+        "MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING",
+        "MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING",
         "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
         "MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING",
         "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING",
@@ -4607,6 +4610,7 @@ if TYPE_CHECKING:
         Text2TextGenerationPipeline,
         TextClassificationPipeline,
         TextGenerationPipeline,
+        TextToAudioPipeline,
         TokenClassificationPipeline,
         TranslationPipeline,
         VideoClassificationPipeline,
@@ -5007,6 +5011,8 @@ if TYPE_CHECKING:
         MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
         MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
         MODEL_FOR_TEXT_ENCODING_MAPPING,
+        MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING,
+        MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING,
         MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
         MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING,
         MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py
index 36286f2be0b..3a5095c2173 100644
--- a/src/transformers/models/auto/__init__.py
+++ b/src/transformers/models/auto/__init__.py
@@ -65,6 +65,8 @@ else:
         "MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
         "MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
         "MODEL_FOR_TEXT_ENCODING_MAPPING",
+        "MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING",
+        "MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING",
         "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
         "MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING",
         "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING",
@@ -241,6 +243,8 @@ if TYPE_CHECKING:
         MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
         MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
         MODEL_FOR_TEXT_ENCODING_MAPPING,
+        MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING,
+        MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING,
         MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
         MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING,
         MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING,
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index aec9eacc2a7..a0c22f58766 100755
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -1014,6 +1014,21 @@ MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
     ]
 )
 
+MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = OrderedDict(
+    [
+        # Model for Text-To-Spectrogram mapping
+        ("speecht5", "SpeechT5ForTextToSpeech"),
+    ]
+)
+
+MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = OrderedDict(
+    [
+        # Model for Text-To-Waveform mapping
+        ("bark", "BarkModel"),
+        ("musicgen", "MusicgenForConditionalGeneration"),
+    ]
+)
+
 MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
     [
         # Model for Zero Shot Image Classification mapping
@@ -1152,6 +1167,12 @@ MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping(
 )
 MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES)
 
+MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING = _LazyAutoMapping(
+    CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES
+)
+
+MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES)
+
 MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BACKBONE_MAPPING_NAMES)
 
 MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)
@@ -1407,6 +1428,14 @@ class AutoModelForAudioXVector(_BaseAutoModelClass):
     _model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING
 
 
+class AutoModelForTextToSpectrogram(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING
+
+
+class AutoModelForTextToWaveform(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING
+
+
 class AutoBackbone(_BaseAutoBackboneClass):
     _model_mapping = MODEL_FOR_BACKBONE_MAPPING
diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py
index e49c7687bd0..5d5f194975f 100644
--- a/src/transformers/models/auto/tokenization_auto.py
+++ b/src/transformers/models/auto/tokenization_auto.py
@@ -60,6 +60,7 @@ else:
             ),
         ),
         ("align", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
+        ("bark", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
         ("bart", ("BartTokenizer", "BartTokenizerFast")),
         (
             "barthez",
@@ -224,6 +225,7 @@ else:
                 "MT5TokenizerFast" if is_tokenizers_available() else None,
             ),
         ),
+        ("musicgen", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
         ("mvp", ("MvpTokenizer", "MvpTokenizerFast" if is_tokenizers_available() else None)),
         ("nezha", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
         (
diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py
index 2810a4713ad..b1d6ec4bdac 100755
--- a/src/transformers/pipelines/__init__.py
+++ b/src/transformers/pipelines/__init__.py
@@ -70,6 +70,7 @@ from .table_question_answering import TableQuestionAnsweringArgumentHandler, Tab
 from .text2text_generation import SummarizationPipeline, Text2TextGenerationPipeline, TranslationPipeline
 from .text_classification import TextClassificationPipeline
 from .text_generation import TextGenerationPipeline
+from .text_to_audio import TextToAudioPipeline
 from .token_classification import (
     AggregationStrategy,
     NerPipeline,
@@ -121,6 +122,8 @@ if is_torch_available():
         AutoModelForSequenceClassification,
         AutoModelForSpeechSeq2Seq,
         AutoModelForTableQuestionAnswering,
+        AutoModelForTextToSpectrogram,
+        AutoModelForTextToWaveform,
         AutoModelForTokenClassification,
         AutoModelForVideoClassification,
         AutoModelForVision2Seq,
@@ -144,6 +147,7 @@ TASK_ALIASES = {
     "sentiment-analysis": "text-classification",
     "ner": "token-classification",
     "vqa": "visual-question-answering",
+    "text-to-speech": "text-to-audio",
 }
 SUPPORTED_TASKS = {
     "audio-classification": {
@@ -160,6 +164,13 @@ SUPPORTED_TASKS = {
         "default": {"model": {"pt": ("facebook/wav2vec2-base-960h", "55bb623")}},
         "type": "multimodal",
     },
+    "text-to-audio": {
+        "impl": TextToAudioPipeline,
+        "tf": (),
+        "pt": (AutoModelForTextToWaveform, AutoModelForTextToSpectrogram) if is_torch_available() else (),
+        "default": {"model": {"pt": ("suno/bark-small", "645cfba")}},
+        "type": "text",
+    },
     "feature-extraction": {
         "impl": FeatureExtractionPipeline,
         "tf": (TFAutoModel,) if is_tf_available() else (),
@@ -386,6 +397,7 @@ SUPPORTED_TASKS = {
 NO_FEATURE_EXTRACTOR_TASKS = set()
 NO_IMAGE_PROCESSOR_TASKS = set()
 NO_TOKENIZER_TASKS = set()
+
 # Those model configs are special, they are generic over their task, meaning
 # any tokenizer/feature_extractor might be use for a given model so we cannot
 # use the statically defined TOKENIZER_MAPPING and FEATURE_EXTRACTOR_MAPPING to
@@ -465,6 +477,7 @@ def check_task(task: str) -> Tuple[str, Dict, Any]:
             - `"text2text-generation"`
             - `"text-classification"` (alias `"sentiment-analysis"` available)
             - `"text-generation"`
+            - `"text-to-audio"` (alias `"text-to-speech"` available)
             - `"token-classification"` (alias `"ner"` available)
             - `"translation"`
             - `"translation_xx_to_yy"`
@@ -551,6 +564,7 @@ def pipeline(
             - `"text-classification"` (alias `"sentiment-analysis"` available): will return a
               [`TextClassificationPipeline`].
             - `"text-generation"`: will return a [`TextGenerationPipeline`]:.
+            - `"text-to-audio"` (alias `"text-to-speech"` available): will return a [`TextToAudioPipeline`].
             - `"token-classification"` (alias `"ner"` available): will return a
               [`TokenClassificationPipeline`].
             - `"translation"`: will return a [`TranslationPipeline`].
             - `"translation_xx_to_yy"`: will return a [`TranslationPipeline`].
diff --git a/src/transformers/pipelines/text_to_audio.py b/src/transformers/pipelines/text_to_audio.py
new file mode 100644
index 00000000000..5676aa6c210
--- /dev/null
+++ b/src/transformers/pipelines/text_to_audio.py
@@ -0,0 +1,159 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List, Union
+
+from ..utils import is_torch_available
+from .base import Pipeline
+
+
+if is_torch_available():
+    from ..models.auto.modeling_auto import MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING
+    from ..models.speecht5.modeling_speecht5 import SpeechT5HifiGan
+
+DEFAULT_VOCODER_ID = "microsoft/speecht5_hifigan"
+
+
+class TextToAudioPipeline(Pipeline):
+    """
+    Text-to-audio generation pipeline using any `AutoModelForTextToWaveform` or `AutoModelForTextToSpectrogram`. This
+    pipeline generates an audio waveform from an input text and optional other conditional inputs.
+
+    Example:
+
+    ```python
+    >>> from transformers import pipeline
+
+    >>> tts = pipeline(model="suno/bark")
+    >>> output = tts("Hey it's HuggingFace on the phone!")
+
+    >>> audio = output["audio"]
+    >>> sampling_rate = output["sampling_rate"]
+    ```
+
+    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
+
+
+    This pipeline can currently be loaded from [`pipeline`] using the following task identifiers: `"text-to-speech"` or
+    `"text-to-audio"`.
+
+    See the list of available models on [huggingface.co/models](https://huggingface.co/models?filter=text-to-speech).
+    """
+
+    def __init__(self, *args, vocoder=None, sampling_rate=None, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        if self.framework == "tf":
+            raise ValueError("The TextToAudioPipeline is only available in PyTorch.")
+
+        self.forward_method = self.model.generate if self.model.can_generate() else self.model
+
+        self.vocoder = None
+        if self.model.__class__ in MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING.values():
+            self.vocoder = (
+                SpeechT5HifiGan.from_pretrained(DEFAULT_VOCODER_ID).to(self.model.device)
+                if vocoder is None
+                else vocoder
+            )
+
+        self.sampling_rate = sampling_rate
+        if self.vocoder is not None:
+            self.sampling_rate = self.vocoder.config.sampling_rate
+
+        if self.sampling_rate is None:
+            # get sampling_rate from config and generation config
+
+            config = self.model.config.to_dict()
+            gen_config = self.model.__dict__.get("generation_config", None)
+            if gen_config is not None:
+                config.update(gen_config.to_dict())
+
+            for sampling_rate_name in ["sample_rate", "sampling_rate"]:
+                sampling_rate = config.get(sampling_rate_name, None)
+                if sampling_rate is not None:
+                    self.sampling_rate = sampling_rate
+
+    def preprocess(self, text, **kwargs):
+        if isinstance(text, str):
+            text = [text]
+
+        if self.model.config.model_type == "bark":
+            # bark's tokenizer is usually called through BarkProcessor, which uses these kwargs
+            new_kwargs = {
+                "max_length": self.model.generation_config.semantic_config.get("max_input_semantic_length", 256),
+                "add_special_tokens": False,
+                "return_attention_mask": True,
+                "return_token_type_ids": False,
+                "padding": "max_length",
+            }
+
+            # priority is given to kwargs
+            new_kwargs.update(kwargs)
+
+            kwargs = new_kwargs
+
+        output = self.tokenizer(text, **kwargs, return_tensors="pt")
+
+        return output
+
+    def _forward(self, model_inputs, **kwargs):
+        # we expect some kwargs to be additional tensors which need to be on the right device
+        kwargs = self._ensure_tensor_on_device(kwargs, device=self.device)
+
+        # call `generate` by default, or the forward method if the model cannot generate
+        output = self.forward_method(**model_inputs, **kwargs)
+
+        if self.vocoder is not None:
+            # in that case, the output is a spectrogram that needs to be converted into a waveform
+            output = self.vocoder(output)
+
+        return output
+
+    def __call__(self, text_inputs: Union[str, List[str]], **forward_params):
+        """
+        Generates speech/audio from the inputs. See the [`TextToAudioPipeline`] documentation for more information.
+
+        Args:
+            text_inputs (`str` or `List[str]`):
+                The text(s) to generate.
+            forward_params (*optional*):
+                Parameters passed to the model generation/forward method.
+
+        Return:
+            A `dict` or a list of `dict`: The dictionaries have two keys:
+
+            - **audio** (`np.ndarray` of shape `(nb_channels, audio_length)`) -- The generated audio waveform.
+            - **sampling_rate** (`int`) -- The sampling rate of the generated audio waveform.
+        """
+        return super().__call__(text_inputs, **forward_params)
+
+    def _sanitize_parameters(
+        self,
+        preprocess_params=None,
+        forward_params=None,
+    ):
+        if preprocess_params is None:
+            preprocess_params = {}
+        if forward_params is None:
+            forward_params = {}
+        postprocess_params = {}
+
+        return preprocess_params, forward_params, postprocess_params
+
+    def postprocess(self, waveform):
+        output_dict = {}
+
+        output_dict["audio"] = waveform.cpu().float().numpy()
+        output_dict["sampling_rate"] = self.sampling_rate
+
+        return output_dict
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index c27d8c3da97..1e8baba71c1 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -527,6 +527,12 @@ MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = None
 MODEL_FOR_TEXT_ENCODING_MAPPING = None
 
 
+MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING = None
+
+
+MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING = None
+
+
 MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
diff --git a/tests/pipelines/test_pipelines_text_to_audio.py b/tests/pipelines/test_pipelines_text_to_audio.py
new file mode 100644
index 00000000000..164ec245718
--- /dev/null
+++ b/tests/pipelines/test_pipelines_text_to_audio.py
@@ -0,0 +1,190 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+from transformers import (
+    MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING,
+    AutoProcessor,
+    TextToAudioPipeline,
+    pipeline,
+)
+from transformers.testing_utils import (
+    is_pipeline_test,
+    require_torch,
+    require_torch_gpu,
+    require_torch_or_tf,
+    slow,
+)
+
+from .test_pipelines_common import ANY
+
+
+@is_pipeline_test
+@require_torch_or_tf
+class TextToAudioPipelineTests(unittest.TestCase):
+    model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING
+    # for now only text_to_waveform and not text_to_spectrogram
+
+    @slow
+    @require_torch
+    def test_small_model_pt(self):
+        speech_generator = pipeline(task="text-to-audio", model="facebook/musicgen-small", framework="pt")
+
+        forward_params = {
+            "do_sample": False,
+            "max_new_tokens": 250,
+        }
+
+        outputs = speech_generator("This is a test", forward_params=forward_params)
+
+        # musicgen sampling_rate is not straightforward to get
+        self.assertIsNone(outputs["sampling_rate"])
+
+        audio = outputs["audio"]
+
+        self.assertEqual(ANY(np.ndarray), audio)
+
+        # test two examples side-by-side
+        outputs = speech_generator(["This is a test", "This is a second test"], forward_params=forward_params)
+
+        audio = [output["audio"] for output in outputs]
+
+        self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
+
+        # test batching
+        outputs = speech_generator(
+            ["This is a test", "This is a second test"], forward_params=forward_params, batch_size=2
+        )
+
+        self.assertEqual(ANY(np.ndarray), outputs[0]["audio"])
+
+    @slow
+    @require_torch
+    def test_large_model_pt(self):
+        speech_generator = pipeline(task="text-to-audio", model="suno/bark-small", framework="pt")
+
+        # test text-to-speech
+
+        forward_params = {
+            # Using `do_sample=False` to force deterministic output
+            "do_sample": False,
+            "semantic_max_new_tokens": 100,
+        }
+
+        outputs = speech_generator("This is a test", forward_params=forward_params)
+
+        self.assertEqual(
+            {"audio": ANY(np.ndarray), "sampling_rate": 24000},
+            outputs,
+        )
+
+        # test two examples side-by-side
+        outputs = speech_generator(
+            ["This is a test", "This is a second test"],
+            forward_params=forward_params,
+        )
+
+        audio = [output["audio"] for output in outputs]
+
+        self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
+
+        # test other generation strategy
+
+        forward_params = {
+            "do_sample": True,
+            "semantic_max_new_tokens": 100,
+            "semantic_num_return_sequences": 2,
+        }
+
+        outputs = speech_generator("This is a test", forward_params=forward_params)
+
+        audio = outputs["audio"]
+
+        self.assertEqual(ANY(np.ndarray), audio)
+
+        # test using a speaker embedding
+        processor = AutoProcessor.from_pretrained("suno/bark-small")
+        temp_inp = processor("hey, how are you?", voice_preset="v2/en_speaker_5")
+        history_prompt = temp_inp["history_prompt"]
+        forward_params["history_prompt"] = history_prompt
+
+        outputs = speech_generator(
+            ["This is a test", "This is a second test"],
+            forward_params=forward_params,
+            batch_size=2,
+        )
+
+        audio = [output["audio"] for output in outputs]
+
+        self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
+
+    @slow
+    @require_torch_gpu
+    def test_conversion_additional_tensor(self):
+        speech_generator = pipeline(task="text-to-audio", model="suno/bark-small", framework="pt", device=0)
+        processor = AutoProcessor.from_pretrained("suno/bark-small")
+
+        forward_params = {
+            "do_sample": True,
+            "semantic_max_new_tokens": 100,
+        }
+
+        # atm, must do to stay coherent with BarkProcessor
+        preprocess_params = {
+            "max_length": 256,
+            "add_special_tokens": False,
+            "return_attention_mask": True,
+            "return_token_type_ids": False,
+            "padding": "max_length",
+        }
+
+        outputs = speech_generator(
+            "This is a test",
+            forward_params=forward_params,
+            preprocess_params=preprocess_params,
+        )
+
+        temp_inp = processor("hey, how are you?", voice_preset="v2/en_speaker_5")
+        history_prompt = temp_inp["history_prompt"]
+        forward_params["history_prompt"] = history_prompt
+
+        # history_prompt is a torch.Tensor passed as a forward_param
+        # if generation is successful, it means that it was passed to the right device
+        outputs = speech_generator(
+            "This is a test", forward_params=forward_params, preprocess_params=preprocess_params
+        )
+
+        self.assertEqual(
+            {"audio": ANY(np.ndarray), "sampling_rate": 24000},
+            outputs,
+        )
+
+    def get_test_pipeline(self, model, tokenizer, processor):
+        speech_generator = TextToAudioPipeline(model=model, tokenizer=tokenizer)
+        return speech_generator, ["This is a test", "Another test"]
+
+    def run_pipeline_test(self, speech_generator, _):
+        outputs = speech_generator("This is a test")
+
+        self.assertEqual(ANY(np.ndarray), outputs["audio"])
+
+        forward_params = {"num_return_sequences": 2, "do_sample": True}
+
+        outputs = speech_generator(["This is great !", "Something else"], forward_params=forward_params)
+        audio = [output["audio"] for output in outputs]
+
+        self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
diff --git a/tests/test_pipeline_mixin.py b/tests/test_pipeline_mixin.py
index 1fa8f378e40..cf37ec5ad23 100644
--- a/tests/test_pipeline_mixin.py
+++ b/tests/test_pipeline_mixin.py
@@ -49,6 +49,7 @@ from .pipelines.test_pipelines_table_question_answering import TQAPipelineTests
 from .pipelines.test_pipelines_text2text_generation import Text2TextGenerationPipelineTests
 from .pipelines.test_pipelines_text_classification import TextClassificationPipelineTests
 from .pipelines.test_pipelines_text_generation import TextGenerationPipelineTests
+from .pipelines.test_pipelines_text_to_audio import TextToAudioPipelineTests
 from .pipelines.test_pipelines_token_classification import TokenClassificationPipelineTests
 from .pipelines.test_pipelines_translation import TranslationPipelineTests
 from .pipelines.test_pipelines_video_classification import VideoClassificationPipelineTests
@@ -78,6 +79,7 @@ pipeline_test_mapping = {
     "text2text-generation": {"test": Text2TextGenerationPipelineTests},
     "text-classification": {"test": TextClassificationPipelineTests},
     "text-generation": {"test": TextGenerationPipelineTests},
+    "text-to-audio": {"test": TextToAudioPipelineTests},
     "token-classification": {"test": TokenClassificationPipelineTests},
     "translation": {"test": TranslationPipelineTests},
     "video-classification": {"test": VideoClassificationPipelineTests},
@@ -405,6 +407,11 @@ class PipelineTesterMixin:
     def test_pipeline_text_generation(self):
         self.run_task_tests(task="text-generation")
 
+    @is_pipeline_test
+    @require_torch
+    def test_pipeline_text_to_audio(self):
+        self.run_task_tests(task="text-to-audio")
+
     @is_pipeline_test
     def test_pipeline_token_classification(self):
         self.run_task_tests(task="token-classification")
diff --git a/utils/update_metadata.py b/utils/update_metadata.py
index 637cdd959c7..c3aec0a868a 100644
--- a/utils/update_metadata.py
+++ b/utils/update_metadata.py
@@ -115,6 +115,7 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [
     ("depth-estimation", "MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES", "AutoModelForDepthEstimation"),
     ("video-classification", "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES", "AutoModelForVideoClassification"),
     ("mask-generation", "MODEL_FOR_MASK_GENERATION_MAPPING_NAMES", "AutoModelForMaskGeneration"),
+    ("text-to-audio", "MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES", "AutoModelForTextToWaveform"),
 ]
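A minimal usage sketch of the pipeline introduced by this patch, for trying it out locally. The checkpoint and generation kwargs below are illustrative choices, not mandated by the patch; only the `"text-to-speech"` alias, the `forward_params` routing, and the `audio`/`sampling_rate` output keys come from the diff above.

```python
from transformers import pipeline

# "text-to-speech" is registered as an alias of the new "text-to-audio" task.
# suno/bark-small is the default checkpoint for the task in this patch; any
# model covered by the text-to-waveform/text-to-spectrogram mappings should work.
tts = pipeline("text-to-speech", model="suno/bark-small")

# forward_params are forwarded to the model's generate()/forward() call.
output = tts("Hello, my dog is cute", forward_params={"do_sample": True})

audio = output["audio"]  # np.ndarray with shape (nb_channels, audio_length)
sampling_rate = output["sampling_rate"]  # e.g. 24000 for the bark checkpoints used in the tests
```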