From e4decee9c02ac7776508f9fcce10891fb93ada4e Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Wed, 21 May 2025 14:11:08 +0100
Subject: [PATCH] [whisper] small changes for faster tests (#38236)

---
 tests/models/whisper/test_modeling_whisper.py | 48 +++++++++++-------------------------------------
 1 file changed, 11 insertions(+), 37 deletions(-)

diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py
index 0446fb2052d..2085d9f2844 100644
--- a/tests/models/whisper/test_modeling_whisper.py
+++ b/tests/models/whisper/test_modeling_whisper.py
@@ -27,7 +27,6 @@ import pytest
 from huggingface_hub import hf_hub_download
 from parameterized import parameterized
 
-import transformers
 from transformers import WhisperConfig
 from transformers.testing_utils import (
     is_flaky,
@@ -41,7 +40,7 @@ from transformers.testing_utils import (
     slow,
     torch_device,
 )
-from transformers.utils import cached_property, is_torch_available, is_torch_xpu_available, is_torchaudio_available
+from transformers.utils import is_torch_available, is_torch_xpu_available, is_torchaudio_available
 from transformers.utils.import_utils import is_datasets_available
 
 from ...generation.test_utils import GenerationTesterMixin
@@ -1432,33 +1431,22 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
 @require_torch
 @require_torchaudio
 class WhisperModelIntegrationTests(unittest.TestCase):
-    def setUp(self):
-        self._unpatched_generation_mixin_generate = transformers.GenerationMixin.generate
+    _dataset = None
 
-    def tearDown(self):
-        transformers.GenerationMixin.generate = self._unpatched_generation_mixin_generate
-
-    @cached_property
-    def default_processor(self):
-        return WhisperProcessor.from_pretrained("openai/whisper-base")
+    @classmethod
+    def _load_dataset(cls):
+        # Lazy loading of the dataset. Because it is a class method, it will only be loaded once per pytest process.
+        if cls._dataset is None:
+            cls._dataset = datasets.load_dataset(
+                "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
+            )
 
     def _load_datasamples(self, num_samples):
-        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
-        # automatic decoding with librispeech
+        self._load_dataset()
+        ds = self._dataset
         speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
-
         return [x["array"] for x in speech_samples]
 
-    def _patch_generation_mixin_generate(self, check_args_fn=None):
-        test = self
-
-        def generate(self, *args, **kwargs):
-            if check_args_fn is not None:
-                check_args_fn(*args, **kwargs)
-            return test._unpatched_generation_mixin_generate(self, *args, **kwargs)
-
-        transformers.GenerationMixin.generate = generate
-
     @slow
     def test_tiny_logits_librispeech(self):
         torch_device = "cpu"
@@ -1586,8 +1574,6 @@ class WhisperModelIntegrationTests(unittest.TestCase):
 
     @slow
     def test_tiny_en_generation(self):
-        torch_device = "cpu"
-        set_seed(0)
         processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
         model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
         model.to(torch_device)
@@ -1605,8 +1591,6 @@ class WhisperModelIntegrationTests(unittest.TestCase):
 
     @slow
     def test_tiny_generation(self):
-        torch_device = "cpu"
-        set_seed(0)
         processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
         model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
         model.to(torch_device)
@@ -1623,8 +1607,6 @@ class WhisperModelIntegrationTests(unittest.TestCase):
 
     @slow
     def test_large_generation(self):
-        torch_device = "cpu"
-        set_seed(0)
         processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
         model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")
         model.to(torch_device)
@@ -1643,7 +1625,6 @@ class WhisperModelIntegrationTests(unittest.TestCase):
 
     @slow
     def test_large_generation_multilingual(self):
-        set_seed(0)
         processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
         model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")
         model.to(torch_device)
@@ -1710,8 +1691,6 @@ class WhisperModelIntegrationTests(unittest.TestCase):
 
     @slow
     def test_large_batched_generation_multilingual(self):
-        torch_device = "cpu"
-        set_seed(0)
         processor = WhisperProcessor.from_pretrained("openai/whisper-large")
         model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
         model.to(torch_device)
@@ -2727,11 +2706,6 @@ class WhisperModelIntegrationTests(unittest.TestCase):
             "renormalize_logits": True,  # necessary to match OAI beam search implementation
         }
 
-        def check_gen_kwargs(inputs, generation_config, *args, **kwargs):
-            self.assertEqual(generation_config.num_beams, gen_kwargs["num_beams"])
-
-        self._patch_generation_mixin_generate(check_args_fn=check_gen_kwargs)
-
         torch.manual_seed(0)
         result = model.generate(input_features, **gen_kwargs)
         decoded = processor.batch_decode(result, skip_special_tokens=True)
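
Editor's note (not part of the patch): the main speed-up above is the class-level dataset cache, which replaces a per-test `load_dataset` call. The sketch below shows that pattern in isolation, assuming the `datasets` library is installed; the class `LazyDatasetTests` and its test method are hypothetical, while the dataset name mirrors the one used in the patch.

    import unittest

    import datasets


    class LazyDatasetTests(unittest.TestCase):
        # Class attribute, so the loaded dataset is shared by every test method
        # in this class for the lifetime of the pytest process.
        _dataset = None

        @classmethod
        def _load_dataset(cls):
            # Load at most once per process, instead of once per test method.
            if cls._dataset is None:
                cls._dataset = datasets.load_dataset(
                    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
                )

        def test_samples_have_audio(self):  # hypothetical test method
            self._load_dataset()
            self.assertIn("audio", self._dataset[0])

The other recurring change, deleting the hard-coded `torch_device = "cpu"` overrides (and the redundant `set_seed(0)` calls), lets each test fall back to the `torch_device` imported from `transformers.testing_utils`, which resolves to an accelerator when one is available instead of forcing CPU execution.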