Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-03 21:00:08 +06:00
[whisper] small changes for faster tests (#38236)
commit e4decee9c0
parent ddf67d2d73
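In short, as read from the diff below: (1) the LibriSpeech dummy dataset used by the Whisper integration tests is now loaded lazily through a classmethod and cached on the class, so each pytest process loads it once instead of once per `_load_datasamples` call; (2) the setUp/tearDown pair that monkey-patched `transformers.GenerationMixin.generate`, the `_patch_generation_mixin_generate` helper and its single call site, and the unused `default_processor` property are removed, which also lets the `import transformers` line and the `cached_property` import go; (3) several `@slow` generation tests drop their local `torch_device = "cpu"` and `set_seed(0)` lines.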
tests/models/whisper/test_modeling_whisper.py
@@ -27,7 +27,6 @@ import pytest
 from huggingface_hub import hf_hub_download
 from parameterized import parameterized
 
-import transformers
 from transformers import WhisperConfig
 from transformers.testing_utils import (
     is_flaky,
@@ -41,7 +40,7 @@ from transformers.testing_utils import (
     slow,
     torch_device,
 )
-from transformers.utils import cached_property, is_torch_available, is_torch_xpu_available, is_torchaudio_available
+from transformers.utils import is_torch_available, is_torch_xpu_available, is_torchaudio_available
 from transformers.utils.import_utils import is_datasets_available
 
 from ...generation.test_utils import GenerationTesterMixin
@@ -1432,33 +1431,22 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
 @require_torch
 @require_torchaudio
 class WhisperModelIntegrationTests(unittest.TestCase):
-    def setUp(self):
-        self._unpatched_generation_mixin_generate = transformers.GenerationMixin.generate
-
-    def tearDown(self):
-        transformers.GenerationMixin.generate = self._unpatched_generation_mixin_generate
-
-    @cached_property
-    def default_processor(self):
-        return WhisperProcessor.from_pretrained("openai/whisper-base")
+    _dataset = None
+
+    @classmethod
+    def _load_dataset(cls):
+        # Lazy loading of the dataset. Because it is a class method, it will only be loaded once per pytest process.
+        if cls._dataset is None:
+            cls._dataset = datasets.load_dataset(
+                "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
+            )
 
     def _load_datasamples(self, num_samples):
-        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
-        # automatic decoding with librispeech
+        self._load_dataset()
+        ds = self._dataset
         speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
 
         return [x["array"] for x in speech_samples]
 
-    def _patch_generation_mixin_generate(self, check_args_fn=None):
-        test = self
-
-        def generate(self, *args, **kwargs):
-            if check_args_fn is not None:
-                check_args_fn(*args, **kwargs)
-            return test._unpatched_generation_mixin_generate(self, *args, **kwargs)
-
-        transformers.GenerationMixin.generate = generate
-
     @slow
     def test_tiny_logits_librispeech(self):
         torch_device = "cpu"
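The hunk above is the heart of the speed-up, and the pattern generalizes: keep the handle on the class (`_dataset = None`) and have a classmethod guard populate it on first use, so every test method in the process shares one load. A minimal self-contained sketch of the same idea, with illustrative names (ExpensiveResourceTests and expensive_load are stand-ins, not from this PR):

import unittest


def expensive_load():
    # Stand-in for datasets.load_dataset(...); pretend this takes minutes.
    return list(range(1000))


class ExpensiveResourceTests(unittest.TestCase):
    _resource = None  # stored on the class, shared by every test method

    @classmethod
    def _load_resource(cls):
        # Loaded at most once per process; later calls return the cached object.
        if cls._resource is None:
            cls._resource = expensive_load()
        return cls._resource

    def test_first_access_loads(self):
        self.assertEqual(len(self._load_resource()), 1000)

    def test_second_access_reuses_cache(self):
        self.assertIs(self._load_resource(), ExpensiveResourceTests._resource)


if __name__ == "__main__":
    unittest.main()

Unlike unittest's setUpClass, which pays the load cost as soon as any test in the class runs, the lazy guard defers the cost until a test actually asks for the data, so deselected or data-free tests never trigger the download.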
@@ -1586,8 +1574,6 @@ class WhisperModelIntegrationTests(unittest.TestCase):
 
     @slow
     def test_tiny_en_generation(self):
-        torch_device = "cpu"
-        set_seed(0)
         processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
         model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
         model.to(torch_device)
@@ -1605,8 +1591,6 @@ class WhisperModelIntegrationTests(unittest.TestCase):
 
     @slow
     def test_tiny_generation(self):
-        torch_device = "cpu"
-        set_seed(0)
         processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
         model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
         model.to(torch_device)
@@ -1623,8 +1607,6 @@ class WhisperModelIntegrationTests(unittest.TestCase):
 
     @slow
     def test_large_generation(self):
-        torch_device = "cpu"
-        set_seed(0)
         processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
         model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")
         model.to(torch_device)
@@ -1643,7 +1625,6 @@ class WhisperModelIntegrationTests(unittest.TestCase):
 
     @slow
     def test_large_generation_multilingual(self):
-        set_seed(0)
         processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
         model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")
         model.to(torch_device)
@@ -1710,8 +1691,6 @@ class WhisperModelIntegrationTests(unittest.TestCase):
 
     @slow
     def test_large_batched_generation_multilingual(self):
-        torch_device = "cpu"
-        set_seed(0)
         processor = WhisperProcessor.from_pretrained("openai/whisper-large")
         model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
         model.to(torch_device)
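A note on the five deletion-only hunks above: removing the local `torch_device = "cpu"` stops shadowing the `torch_device` imported from `transformers.testing_utils`, so these `@slow` tests can presumably run on whatever accelerator the test machine provides instead of being pinned to CPU; and the `set_seed(0)` calls appear to have been dead weight for tests that decode deterministically.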
@@ -2727,11 +2706,6 @@ class WhisperModelIntegrationTests(unittest.TestCase):
             "renormalize_logits": True,  # necessary to match OAI beam search implementation
         }
 
-        def check_gen_kwargs(inputs, generation_config, *args, **kwargs):
-            self.assertEqual(generation_config.num_beams, gen_kwargs["num_beams"])
-
-        self._patch_generation_mixin_generate(check_args_fn=check_gen_kwargs)
-
         torch.manual_seed(0)
         result = model.generate(input_features, **gen_kwargs)
         decoded = processor.batch_decode(result, skip_special_tokens=True)
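This final hunk deletes the only caller of `_patch_generation_mixin_generate`, which is what allowed the setUp/tearDown bookkeeping to disappear in the first class-level hunk. If a test ever needs to spy on `generate()` again, unittest.mock gives the same interception with automatic restore; a hypothetical sketch (the Generator toy class stands in for a model and is not from this PR):

from unittest import mock


class Generator:
    # Illustrative stand-in for a model class with a generate() method.
    def generate(self, *args, **kwargs):
        return "output"


def test_kwargs_reach_generate():
    seen = {}
    original = Generator.generate

    def spy(self, *args, **kwargs):
        seen.update(kwargs)  # record what generate() actually received
        return original(self, *args, **kwargs)

    # patch.object swaps the method in for the duration of the with-block,
    # replacing the manual setUp/tearDown bookkeeping the PR deleted.
    with mock.patch.object(Generator, "generate", spy):
        Generator().generate(num_beams=5)

    assert seen["num_beams"] == 5


test_kwargs_reach_generate()

Because mock.patch.object restores the original attribute when the with-block exits, even on assertion failure, no tearDown bookkeeping is needed.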