From d70fab8b2062526e9c2c60196421a8bc96c7df03 Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Fri, 15 Sep 2023 10:00:36 +0100 Subject: [PATCH] [TTA Pipeline] Test MusicGen and VITS (#26146) --- tests/models/musicgen/test_modeling_musicgen.py | 2 +- tests/models/vits/test_modeling_vits.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py index 77228531614..02ab3b538c2 100644 --- a/tests/models/musicgen/test_modeling_musicgen.py +++ b/tests/models/musicgen/test_modeling_musicgen.py @@ -502,7 +502,7 @@ class MusicgenTester: class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = (MusicgenForConditionalGeneration,) if is_torch_available() else () greedy_sample_model_classes = (MusicgenForConditionalGeneration,) if is_torch_available() else () - pipeline_model_mapping = {} + pipeline_model_mapping = {"text-to-audio": MusicgenForConditionalGeneration} if is_torch_available() else {} test_pruning = False # training is not supported yet for MusicGen test_headmasking = False test_resize_embeddings = False diff --git a/tests/models/vits/test_modeling_vits.py b/tests/models/vits/test_modeling_vits.py index e767b036bea..459a5587cfe 100644 --- a/tests/models/vits/test_modeling_vits.py +++ b/tests/models/vits/test_modeling_vits.py @@ -40,6 +40,7 @@ from ...test_modeling_common import ( ids_tensor, random_attention_mask, ) +from ...test_pipeline_mixin import PipelineTesterMixin if is_torch_available(): @@ -153,8 +154,9 @@ class VitsModelTester: @require_torch -class VitsModelTest(ModelTesterMixin, unittest.TestCase): +class VitsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = (VitsModel,) if is_torch_available() else () + pipeline_model_mapping = {"text-to-audio": VitsModel} if is_torch_available() else {} is_encoder_decoder = False test_pruning = False test_headmasking = False