[seamless_m4t] Skip some tests when speech is not available (#38430)

* Added the require_speech decorator

* Added require_speech to some seamless_m4t tests

* Changed skip message
Rémi Ouazan 2025-06-02 11:17:28 +02:00 committed by GitHub
parent 64d14ef28d
commit 493cf1554b
3 changed files with 16 additions and 2 deletions

View File

@@ -130,6 +130,7 @@ from .utils import (
     is_seqio_available,
     is_soundfile_available,
     is_spacy_available,
+    is_speech_available,
     is_spqr_available,
     is_sudachi_available,
     is_sudachi_projection_available,
@@ -1476,6 +1477,13 @@ def require_tiktoken(test_case):
     return unittest.skipUnless(is_tiktoken_available(), "test requires TikToken")(test_case)
+def require_speech(test_case):
+    """
+    Decorator marking a test that requires speech. These tests are skipped when speech isn't available.
+    """
+    return unittest.skipUnless(is_speech_available(), "test requires torchaudio")(test_case)
 def get_gpu_count():
     """
     Return the number of available gpus (regardless of whether torch, tf or jax is used)
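
For context, a minimal usage sketch of the new decorator (hypothetical test class and test name; the real call sites are in the seamless_m4t test diffs below). When torchaudio is missing, unittest skips the decorated test with the message "test requires torchaudio":

import unittest

from transformers.testing_utils import require_speech


class ExampleSpeechTest(unittest.TestCase):
    # In the actual test files this is stacked with @slow as well.
    @require_speech
    def test_needs_torchaudio(self):
        import torchaudio  # only reached when torchaudio is installed

        self.assertTrue(hasattr(torchaudio, "load"))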

View File

@@ -18,7 +18,7 @@ import tempfile
 import unittest
 from transformers import SeamlessM4TConfig, is_speech_available, is_torch_available
-from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.testing_utils import require_speech, require_torch, slow, torch_device
 from transformers.trainer_utils import set_seed
 from transformers.utils import cached_property
@@ -1028,6 +1028,7 @@ class SeamlessM4TModelIntegrationTest(unittest.TestCase):
         self.assertListAlmostEqual(expected_wav_slice, output.waveform.squeeze().tolist()[50:60])
+    @require_speech
     @slow
     def test_to_rus_speech(self):
         model = SeamlessM4TModel.from_pretrained(self.repo_id).to(torch_device)
@@ -1066,6 +1067,7 @@ class SeamlessM4TModelIntegrationTest(unittest.TestCase):
         }
         self.factory_test_task(SeamlessM4TModel, SeamlessM4TForTextToText, self.input_text, kwargs1, kwargs2)
+    @require_speech
     @slow
     def test_speech_to_text_model(self):
         kwargs1 = {"tgt_lang": "eng", "return_intermediate_token_ids": True, "generate_speech": False}
@@ -1077,6 +1079,7 @@ class SeamlessM4TModelIntegrationTest(unittest.TestCase):
         }
         self.factory_test_task(SeamlessM4TModel, SeamlessM4TForSpeechToText, self.input_audio, kwargs1, kwargs2)
+    @require_speech
     @slow
     def test_speech_to_speech_model(self):
         kwargs1 = {"tgt_lang": "eng", "return_intermediate_token_ids": True}

View File

@@ -18,7 +18,7 @@ import tempfile
 import unittest
 from transformers import SeamlessM4Tv2Config, is_speech_available, is_torch_available
-from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.testing_utils import require_speech, require_torch, slow, torch_device
 from transformers.trainer_utils import set_seed
 from transformers.utils import cached_property
@@ -1095,6 +1095,7 @@ class SeamlessM4Tv2ModelIntegrationTest(unittest.TestCase):
             [-2.001826e-04, 8.580012e-02], [output.waveform.mean().item(), output.waveform.std().item()]
         )
+    @require_speech
     @slow
     def test_to_rus_speech(self):
         model = SeamlessM4Tv2Model.from_pretrained(self.repo_id).to(torch_device)
@@ -1139,6 +1140,7 @@ class SeamlessM4Tv2ModelIntegrationTest(unittest.TestCase):
         }
         self.factory_test_task(SeamlessM4Tv2Model, SeamlessM4Tv2ForTextToText, self.input_text, kwargs1, kwargs2)
+    @require_speech
     @slow
     def test_speech_to_text_model(self):
         kwargs1 = {"tgt_lang": "eng", "return_intermediate_token_ids": True, "generate_speech": False}
@@ -1150,6 +1152,7 @@ class SeamlessM4Tv2ModelIntegrationTest(unittest.TestCase):
         }
         self.factory_test_task(SeamlessM4Tv2Model, SeamlessM4Tv2ForSpeechToText, self.input_audio, kwargs1, kwargs2)
+    @require_speech
     @slow
     def test_speech_to_speech_model(self):
         kwargs1 = {"tgt_lang": "eng", "return_intermediate_token_ids": True}