mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
[seamless_m4t] Skip some tests when speech is not available (#38430)
* Added the require_speech decorator * Added require_speecj to some seamless_m4t tests * Changed skip message
This commit is contained in:
parent
64d14ef28d
commit
493cf1554b
@ -130,6 +130,7 @@ from .utils import (
|
||||
is_seqio_available,
|
||||
is_soundfile_available,
|
||||
is_spacy_available,
|
||||
is_speech_available,
|
||||
is_spqr_available,
|
||||
is_sudachi_available,
|
||||
is_sudachi_projection_available,
|
||||
@ -1476,6 +1477,13 @@ def require_tiktoken(test_case):
|
||||
return unittest.skipUnless(is_tiktoken_available(), "test requires TikToken")(test_case)
|
||||
|
||||
|
||||
def require_speech(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires speech. These tests are skipped when speech isn't available.
|
||||
"""
|
||||
return unittest.skipUnless(is_speech_available(), "test requires torchaudio")(test_case)
|
||||
|
||||
|
||||
def get_gpu_count():
|
||||
"""
|
||||
Return the number of available gpus (regardless of whether torch, tf or jax is used)
|
||||
|
@ -18,7 +18,7 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import SeamlessM4TConfig, is_speech_available, is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
from transformers.testing_utils import require_speech, require_torch, slow, torch_device
|
||||
from transformers.trainer_utils import set_seed
|
||||
from transformers.utils import cached_property
|
||||
|
||||
@ -1028,6 +1028,7 @@ class SeamlessM4TModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
self.assertListAlmostEqual(expected_wav_slice, output.waveform.squeeze().tolist()[50:60])
|
||||
|
||||
@require_speech
|
||||
@slow
|
||||
def test_to_rus_speech(self):
|
||||
model = SeamlessM4TModel.from_pretrained(self.repo_id).to(torch_device)
|
||||
@ -1066,6 +1067,7 @@ class SeamlessM4TModelIntegrationTest(unittest.TestCase):
|
||||
}
|
||||
self.factory_test_task(SeamlessM4TModel, SeamlessM4TForTextToText, self.input_text, kwargs1, kwargs2)
|
||||
|
||||
@require_speech
|
||||
@slow
|
||||
def test_speech_to_text_model(self):
|
||||
kwargs1 = {"tgt_lang": "eng", "return_intermediate_token_ids": True, "generate_speech": False}
|
||||
@ -1077,6 +1079,7 @@ class SeamlessM4TModelIntegrationTest(unittest.TestCase):
|
||||
}
|
||||
self.factory_test_task(SeamlessM4TModel, SeamlessM4TForSpeechToText, self.input_audio, kwargs1, kwargs2)
|
||||
|
||||
@require_speech
|
||||
@slow
|
||||
def test_speech_to_speech_model(self):
|
||||
kwargs1 = {"tgt_lang": "eng", "return_intermediate_token_ids": True}
|
||||
|
@ -18,7 +18,7 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import SeamlessM4Tv2Config, is_speech_available, is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
from transformers.testing_utils import require_speech, require_torch, slow, torch_device
|
||||
from transformers.trainer_utils import set_seed
|
||||
from transformers.utils import cached_property
|
||||
|
||||
@ -1095,6 +1095,7 @@ class SeamlessM4Tv2ModelIntegrationTest(unittest.TestCase):
|
||||
[-2.001826e-04, 8.580012e-02], [output.waveform.mean().item(), output.waveform.std().item()]
|
||||
)
|
||||
|
||||
@require_speech
|
||||
@slow
|
||||
def test_to_rus_speech(self):
|
||||
model = SeamlessM4Tv2Model.from_pretrained(self.repo_id).to(torch_device)
|
||||
@ -1139,6 +1140,7 @@ class SeamlessM4Tv2ModelIntegrationTest(unittest.TestCase):
|
||||
}
|
||||
self.factory_test_task(SeamlessM4Tv2Model, SeamlessM4Tv2ForTextToText, self.input_text, kwargs1, kwargs2)
|
||||
|
||||
@require_speech
|
||||
@slow
|
||||
def test_speech_to_text_model(self):
|
||||
kwargs1 = {"tgt_lang": "eng", "return_intermediate_token_ids": True, "generate_speech": False}
|
||||
@ -1150,6 +1152,7 @@ class SeamlessM4Tv2ModelIntegrationTest(unittest.TestCase):
|
||||
}
|
||||
self.factory_test_task(SeamlessM4Tv2Model, SeamlessM4Tv2ForSpeechToText, self.input_audio, kwargs1, kwargs2)
|
||||
|
||||
@require_speech
|
||||
@slow
|
||||
def test_speech_to_speech_model(self):
|
||||
kwargs1 = {"tgt_lang": "eng", "return_intermediate_token_ids": True}
|
||||
|
Loading…
Reference in New Issue
Block a user