[whisper] small changes for faster tests (#38236)

Joao Gante 2025-05-21 14:11:08 +01:00 committed by GitHub
parent ddf67d2d73
commit e4decee9c0


@@ -27,7 +27,6 @@ import pytest
 from huggingface_hub import hf_hub_download
 from parameterized import parameterized
 
-import transformers
 from transformers import WhisperConfig
 from transformers.testing_utils import (
     is_flaky,
@@ -41,7 +40,7 @@ from transformers.testing_utils import (
     slow,
     torch_device,
 )
-from transformers.utils import cached_property, is_torch_available, is_torch_xpu_available, is_torchaudio_available
+from transformers.utils import is_torch_available, is_torch_xpu_available, is_torchaudio_available
 from transformers.utils.import_utils import is_datasets_available
 
 from ...generation.test_utils import GenerationTesterMixin
@@ -1432,33 +1431,22 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
 @require_torch
 @require_torchaudio
 class WhisperModelIntegrationTests(unittest.TestCase):
-    def setUp(self):
-        self._unpatched_generation_mixin_generate = transformers.GenerationMixin.generate
-
-    def tearDown(self):
-        transformers.GenerationMixin.generate = self._unpatched_generation_mixin_generate
-
-    @cached_property
-    def default_processor(self):
-        return WhisperProcessor.from_pretrained("openai/whisper-base")
+    _dataset = None
+
+    @classmethod
+    def _load_dataset(cls):
+        # Lazy loading of the dataset. Because it is a class method, it will only be loaded once per pytest process.
+        if cls._dataset is None:
+            cls._dataset = datasets.load_dataset(
+                "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
+            )
 
     def _load_datasamples(self, num_samples):
-        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
-        # automatic decoding with librispeech
+        self._load_dataset()
+        ds = self._dataset
         speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
 
         return [x["array"] for x in speech_samples]
 
-    def _patch_generation_mixin_generate(self, check_args_fn=None):
-        test = self
-
-        def generate(self, *args, **kwargs):
-            if check_args_fn is not None:
-                check_args_fn(*args, **kwargs)
-            return test._unpatched_generation_mixin_generate(self, *args, **kwargs)
-
-        transformers.GenerationMixin.generate = generate
-
     @slow
     def test_tiny_logits_librispeech(self):
         torch_device = "cpu"
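The core speed-up is in the hunk above: the LibriSpeech dummy dataset is now fetched once per pytest process and cached on the test class, instead of being re-downloaded and re-decoded by every test through load_dataset. A minimal self-contained sketch of the same caching pattern, with a hypothetical class name and helper method (only the dataset name comes from the diff):

import datasets


class CachedDatasetTests:
    # Class attribute shared by every test method in the process.
    _dataset = None

    @classmethod
    def _load_dataset(cls):
        # Hit the Hub only on the first call; later calls are no-ops.
        if cls._dataset is None:
            cls._dataset = datasets.load_dataset(
                "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
            )

    def first_n_audio_arrays(self, num_samples):
        self._load_dataset()
        samples = self._dataset.sort("id").select(range(num_samples))["audio"]
        return [sample["array"] for sample in samples]

Because _dataset lives on the class rather than on instances, unittest's habit of creating a fresh instance per test does not discard the cache.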
@@ -1586,8 +1574,6 @@ class WhisperModelIntegrationTests(unittest.TestCase):
 
     @slow
     def test_tiny_en_generation(self):
-        torch_device = "cpu"
-        set_seed(0)
         processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
         model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
         model.to(torch_device)
@@ -1605,8 +1591,6 @@ class WhisperModelIntegrationTests(unittest.TestCase):
 
     @slow
     def test_tiny_generation(self):
-        torch_device = "cpu"
-        set_seed(0)
         processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
         model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
         model.to(torch_device)
@@ -1623,8 +1607,6 @@ class WhisperModelIntegrationTests(unittest.TestCase):
 
     @slow
     def test_large_generation(self):
-        torch_device = "cpu"
-        set_seed(0)
         processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
         model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")
         model.to(torch_device)
@@ -1643,7 +1625,6 @@ class WhisperModelIntegrationTests(unittest.TestCase):
 
     @slow
     def test_large_generation_multilingual(self):
-        set_seed(0)
         processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
         model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")
         model.to(torch_device)
@@ -1710,8 +1691,6 @@ class WhisperModelIntegrationTests(unittest.TestCase):
 
     @slow
     def test_large_batched_generation_multilingual(self):
-        torch_device = "cpu"
-        set_seed(0)
         processor = WhisperProcessor.from_pretrained("openai/whisper-large")
         model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
         model.to(torch_device)
@@ -2727,11 +2706,6 @@ class WhisperModelIntegrationTests(unittest.TestCase):
             "renormalize_logits": True,  # necessary to match OAI beam search implementation
         }
 
-        def check_gen_kwargs(inputs, generation_config, *args, **kwargs):
-            self.assertEqual(generation_config.num_beams, gen_kwargs["num_beams"])
-
-        self._patch_generation_mixin_generate(check_args_fn=check_gen_kwargs)
-
         torch.manual_seed(0)
         result = model.generate(input_features, **gen_kwargs)
         decoded = processor.batch_decode(result, skip_special_tokens=True)
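This last hunk, together with the setUp/tearDown removal further up, drops the test-only monkey-patching of transformers.GenerationMixin.generate that asserted on the arguments actually reaching generate. For reference, a minimal sketch of the pattern being removed, with hypothetical function names, should it ever need to be reinstated locally:

import transformers

# Keep a reference to the genuine method so it can be restored afterwards.
_original_generate = transformers.GenerationMixin.generate


def patch_generate(check_args_fn=None):
    # Swap in a wrapper that inspects the arguments before delegating.
    def generate(self, *args, **kwargs):
        if check_args_fn is not None:
            check_args_fn(*args, **kwargs)
        return _original_generate(self, *args, **kwargs)

    transformers.GenerationMixin.generate = generate


def unpatch_generate():
    # Undo the patch (e.g. in tearDown) so other tests see the real method.
    transformers.GenerationMixin.generate = _original_generate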