mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix seamless TTS generate (#34968)
* fix seamless tts generate
* apply same fix for v2
* [run-slow] seamless_m4t, seamless_m4t_v2
* remove TODO
* [run-slow] seamless_m4t, seamless_m4t_v2
* [run-slow] seamless_m4t, seamless_m4t_v2
* ignore failing test on multigpus
* [run-slow] seamless_m4t, seamless_m4t_v2
* [run-slow] seamless_m4t, seamless_m4t_v2
This commit is contained in:
parent
33c12e4d80
commit
6181c6b095
@ -293,6 +293,8 @@ def format_speech_generation_kwargs(kwargs):
|
||||
elif key.startswith("speech_"):
|
||||
key = key[len("speech_") :]
|
||||
kwargs_speech[key] = value
|
||||
elif key == "generation_config":
|
||||
kwargs_text[key] = value
|
||||
else:
|
||||
# If the key is already in a specific config, then it's been set with a
|
||||
# submodule-specific value and we don't override
|
||||
|
@ -421,6 +421,8 @@ def format_speech_generation_kwargs(kwargs):
|
||||
elif key.startswith("speech_"):
|
||||
key = key[len("speech_") :]
|
||||
kwargs_speech[key] = value
|
||||
elif key == "generation_config":
|
||||
kwargs_text[key] = value
|
||||
else:
|
||||
# If the key is already in a specific config, then it's been set with a
|
||||
# submodule-specific value and we don't override
|
||||
|
@ -589,6 +589,11 @@ class SeamlessM4Tv2ModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase)
|
||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||
)
|
||||
|
||||
# TODO: @ydshieh: refer to #34968
|
||||
@unittest.skip(reason="Failing on multi-gpu runner")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class SeamlessM4Tv2ModelWithTextInputTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
@ -27,7 +27,6 @@ from transformers.testing_utils import (
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
require_torch_or_tf,
|
||||
run_test_using_subprocess,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@ -67,10 +66,8 @@ class TextToAudioPipelineTests(unittest.TestCase):
|
||||
audio = [output["audio"] for output in outputs]
|
||||
self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)
|
||||
|
||||
# TODO: @ylacombe: `SeamlessM4TForTextToSpeech.generate` has an issue with `generation_config`. See issue #34811
|
||||
@slow
|
||||
@require_torch
|
||||
@run_test_using_subprocess
|
||||
def test_medium_seamless_m4t_pt(self):
|
||||
speech_generator = pipeline(task="text-to-audio", model="facebook/hf-seamless-m4t-medium", framework="pt")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user