Run some TF Whisper tests in subprocesses to avoid GPU OOM (#19772)

* Run some TF Whisper tests in subprocesses to avoid GPU OOM

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>

parent e0b825a8d0
commit 3436842102
@@ -17,6 +17,7 @@ import contextlib
 import functools
 import inspect
 import logging
+import multiprocessing
 import os
 import re
 import shlex
@@ -1672,3 +1673,43 @@ def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None):
         return wrapper

     return decorator
+
+
+def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=600):
+    """
+    Run a test in a subprocess. In particular, this can avoid (GPU) out-of-memory issues.
+
+    Args:
+        test_case (`unittest.TestCase`):
+            The test that will run `target_func`.
+        target_func (`Callable`):
+            The function implementing the actual testing logic.
+        inputs (`dict`, *optional*, defaults to `None`):
+            The inputs that will be passed to `target_func` through an (input) queue.
+        timeout (`int`, *optional*, defaults to 600):
+            The timeout (in seconds) that will be passed to the input and output queues.
+    """
+    start_method = "spawn"  # a fresh interpreter (and fresh GPU state) in the child process
+    ctx = multiprocessing.get_context(start_method)
+
+    input_queue = ctx.Queue(1)
+    output_queue = ctx.JoinableQueue(1)
+
+    # We can't send `unittest.TestCase` to the child process, otherwise we get pickling issues.
+    input_queue.put(inputs, timeout=timeout)
+
+    process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout))
+    process.start()
+    # Kill the child process if we can't get outputs from it in time: otherwise, the hanging subprocess prevents
+    # the test from exiting properly.
+    try:
+        results = output_queue.get(timeout=timeout)
+        output_queue.task_done()
+    except Exception as e:
+        process.terminate()
+        test_case.fail(e)
+    process.join(timeout=timeout)
+
+    if results["error"] is not None:
+        test_case.fail(f'{results["error"]}')
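Note: the helper expects `target_func` to follow a small queue protocol: drain the input queue, run the checks, then report an `{"error": ...}` dict through the output queue and join it. Below is a minimal sketch of that protocol; the `_test_dummy` function, the test class, and the values are hypothetical, for illustration only, not part of the commit.

import traceback
import unittest

from transformers.testing_utils import run_test_in_subprocess


def _test_dummy(in_queue, out_queue, timeout):
    error = None
    try:
        inputs = in_queue.get(timeout=timeout)
        # The actual assertions run here, inside the child process.
        assert inputs["x"] == 1
    except Exception:
        error = f"{traceback.format_exc()}"
    # Always report back, even on failure, so the parent doesn't block on the output queue.
    out_queue.put({"error": error}, timeout=timeout)
    out_queue.join()


class DummySubprocessTest(unittest.TestCase):
    def test_dummy(self):
        run_test_in_subprocess(test_case=self, target_func=_test_dummy, inputs={"x": 1}, timeout=600)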
@@ -15,13 +15,15 @@
 """ Testing suite for the TensorFlow Whisper model. """

 import inspect
+import os
 import tempfile
+import traceback
 import unittest

 import numpy as np

 from transformers import WhisperConfig, WhisperFeatureExtractor, WhisperProcessor
-from transformers.testing_utils import is_tf_available, require_tf, require_tokenizers, slow
+from transformers.testing_utils import is_tf_available, require_tf, require_tokenizers, run_test_in_subprocess, slow
 from transformers.utils import cached_property
 from transformers.utils.import_utils import is_datasets_available
@@ -626,6 +628,184 @@ class TFWhisperModelTest(TFModelTesterMixin, unittest.TestCase):
         self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))


+def _load_datasamples(num_samples):
+    ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+    # automatic decoding with librispeech
+    speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
+
+    return [x["array"] for x in speech_samples]
+
+
+def _test_large_logits_librispeech(in_queue, out_queue, timeout):
+    error = None
+    try:
+        _ = in_queue.get(timeout=timeout)
+
+        set_seed(0)
+
+        model = TFWhisperModel.from_pretrained("openai/whisper-large")
+
+        input_speech = _load_datasamples(1)
+
+        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
+        processed_inputs = processor(audio=input_speech, text="This part of the speech", return_tensors="tf")
+        input_features = processed_inputs.input_features
+        labels = processed_inputs.labels
+
+        logits = model(
+            input_features,
+            decoder_input_ids=labels,
+            output_hidden_states=False,
+            output_attentions=False,
+            use_cache=False,
+        )
+
+        logits = logits.last_hidden_state @ tf.transpose(model.model.decoder.embed_tokens.weights[0])
+
+        # fmt: off
+        EXPECTED_LOGITS = tf.convert_to_tensor(
+            [
+                2.1382, 0.9381, 4.4671, 3.5589, 2.4022, 3.8576, -0.6521, 2.5472,
+                1.8301, 1.9957, 2.3432, 1.4678, 0.5459, 2.2597, 1.5179, 2.5357,
+                1.1624, 0.6194, 1.0757, 1.8259, 2.4076, 1.6601, 2.3503, 1.3376,
+                1.9891, 1.8635, 3.8931, 5.3699, 4.4772, 3.9184
+            ]
+        )
+        # fmt: on
+
+        unittest.TestCase().assertTrue(np.allclose(logits[0, 0, :30], EXPECTED_LOGITS, atol=1e-4))
+    except Exception:
+        error = f"{traceback.format_exc()}"
+
+    results = {"error": error}
+    out_queue.put(results, timeout=timeout)
+    out_queue.join()
+
+
+def _test_large_generation(in_queue, out_queue, timeout):
+    error = None
+    try:
+        _ = in_queue.get(timeout=timeout)
+
+        set_seed(0)
+        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
+        model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
+
+        input_speech = _load_datasamples(1)
+        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
+
+        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
+        generated_ids = model.generate(input_features, do_sample=False, max_length=20)
+        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+        EXPECTED_TRANSCRIPT = " Mr. Quilter is the apostle of the middle classes and we are glad"
+        unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)
+    except Exception:
+        error = f"{traceback.format_exc()}"
+
+    results = {"error": error}
+    out_queue.put(results, timeout=timeout)
+    out_queue.join()
+
+
+def _test_large_generation_multilingual(in_queue, out_queue, timeout):
+    error = None
+    try:
+        _ = in_queue.get(timeout=timeout)
+
+        set_seed(0)
+        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
+        model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
+
+        ds = load_dataset("common_voice", "ja", split="test", streaming=True)
+        ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
+        input_speech = next(iter(ds))["audio"]["array"]
+        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
+
+        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")
+        generated_ids = model.generate(input_features, do_sample=False, max_length=20)
+        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+        EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
+        unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)
+
+        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
+        generated_ids = model.generate(
+            input_features,
+            do_sample=False,
+            max_length=20,
+        )
+        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+        EXPECTED_TRANSCRIPT = " Kimura san ni denwa wo kaite moraimashita"
+        unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)
+
+        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate")
+        generated_ids = model.generate(input_features, do_sample=False, max_length=20)
+        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+        EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san"
+        unittest.TestCase().assertEqual(transcript, EXPECTED_TRANSCRIPT)
+    except Exception:
+        error = f"{traceback.format_exc()}"
+
+    results = {"error": error}
+    out_queue.put(results, timeout=timeout)
+    out_queue.join()
+
+
+def _test_large_batched_generation(in_queue, out_queue, timeout):
+    error = None
+    try:
+        _ = in_queue.get(timeout=timeout)
+
+        set_seed(0)
+        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
+        model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
+
+        input_speech = _load_datasamples(4)
+        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
+        # Generate in two batches of 2 samples to keep the peak (GPU) memory usage low.
+        generated_ids_1 = model.generate(input_features[0:2], max_length=20)
+        generated_ids_2 = model.generate(input_features[2:4], max_length=20)
+        generated_ids = np.concatenate([generated_ids_1, generated_ids_2])
+
+        # fmt: off
+        EXPECTED_LOGITS = tf.convert_to_tensor(
+            [
+                [50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281],
+                [50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257],
+                [50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256],
+                [50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11]
+            ]
+        )
+        # fmt: on
+
+        unittest.TestCase().assertTrue(np.allclose(generated_ids, EXPECTED_LOGITS))
+
+        # fmt: off
+        EXPECTED_TRANSCRIPT = [
+            " Mr. Quilter is the apostle of the middle classes and we are glad to",
+            " Nor is Mr. Quilter's manner less interesting than his matter.",
+            " He tells us that at this festive season of the year, with Christmas and roast beef",
+            " He has grave doubts whether Sir Frederick Layton's work is really Greek after all,"
+        ]
+        # fmt: on
+
+        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
+        unittest.TestCase().assertListEqual(transcript, EXPECTED_TRANSCRIPT)
+    except Exception:
+        error = f"{traceback.format_exc()}"
+
+    results = {"error": error}
+    out_queue.put(results, timeout=timeout)
+    out_queue.join()
+
+
 @require_tf
 @require_tokenizers
 class TFWhisperModelIntegrationTests(unittest.TestCase):
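The `_test_*` target functions above are defined at module level rather than as methods of the test class: the "spawn" start method re-imports the test module in a fresh interpreter and pickles the target and its arguments, which is also why the `unittest.TestCase` instance itself is never sent through the queue. A standalone sketch of this constraint (not part of the commit):

import multiprocessing


def _child(queue):
    # Module-level functions are pickled by reference, so a "spawn" child can import and run them;
    # a lambda or a locally defined function would fail to pickle here.
    queue.put("ok")


if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")
    queue = ctx.Queue(1)
    process = ctx.Process(target=_child, args=(queue,))
    process.start()
    print(queue.get(timeout=60))  # prints "ok"
    process.join()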
@@ -634,12 +814,7 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):
         return WhisperProcessor.from_pretrained("openai/whisper-base")

     def _load_datasamples(self, num_samples):
-        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
-        # automatic decoding with librispeech
-        speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
-
-        return [x["array"] for x in speech_samples]
+        return _load_datasamples(num_samples)

     @slow
     def test_tiny_logits_librispeech(self):
@@ -719,40 +894,11 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):

     @slow
     def test_large_logits_librispeech(self):
-        set_seed(0)
-
-        model = TFWhisperModel.from_pretrained("openai/whisper-large")
-
-        input_speech = self._load_datasamples(1)
-
-        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
-        processed_inputs = processor(audio=input_speech, text="This part of the speech", return_tensors="tf")
-        input_features = processed_inputs.input_features
-        labels = processed_inputs.labels
-
-        logits = model(
-            input_features,
-            decoder_input_ids=labels,
-            output_hidden_states=False,
-            output_attentions=False,
-            use_cache=False,
+        # `os.environ` values are strings, so cast before passing the timeout to the queues.
+        timeout = int(os.environ.get("PYTEST_TIMEOUT", 600))
+        run_test_in_subprocess(
+            test_case=self, target_func=_test_large_logits_librispeech, inputs=None, timeout=timeout
         )
-
-        logits = logits.last_hidden_state @ tf.transpose(model.model.decoder.embed_tokens.weights[0])
-
-        # fmt: off
-        EXPECTED_LOGITS = tf.convert_to_tensor(
-            [
-                2.1382, 0.9381, 4.4671, 3.5589, 2.4022, 3.8576, -0.6521, 2.5472,
-                1.8301, 1.9957, 2.3432, 1.4678, 0.5459, 2.2597, 1.5179, 2.5357,
-                1.1624, 0.6194, 1.0757, 1.8259, 2.4076, 1.6601, 2.3503, 1.3376,
-                1.9891, 1.8635, 3.8931, 5.3699, 4.4772, 3.9184
-            ]
-        )
-        # fmt: on
-
-        self.assertTrue(np.allclose(logits[0, 0, :30], EXPECTED_LOGITS, atol=1e-4))

     @slow
     def test_tiny_en_generation(self):
         set_seed(0)
@@ -816,90 +962,22 @@ class TFWhisperModelIntegrationTests(unittest.TestCase):

     @slow
     def test_large_generation(self):
-        set_seed(0)
-        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
-        model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
-
-        input_speech = self._load_datasamples(1)
-        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
-
-        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
-        generated_ids = model.generate(input_features, do_sample=False, max_length=20)
-        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-
-        EXPECTED_TRANSCRIPT = " Mr. Quilter is the apostle of the middle classes and we are glad"
-        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
+        timeout = int(os.environ.get("PYTEST_TIMEOUT", 600))
+        run_test_in_subprocess(test_case=self, target_func=_test_large_generation, inputs=None, timeout=timeout)

     @slow
     def test_large_generation_multilingual(self):
-        set_seed(0)
-        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
-        model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
-
-        ds = load_dataset("common_voice", "ja", split="test", streaming=True)
-        ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
-        input_speech = next(iter(ds))["audio"]["array"]
-        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
-
-        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")
-        generated_ids = model.generate(input_features, do_sample=False, max_length=20)
-        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-
-        EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
-        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
-
-        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
-        generated_ids = model.generate(
-            input_features,
-            do_sample=False,
-            max_length=20,
+        timeout = int(os.environ.get("PYTEST_TIMEOUT", 600))
+        run_test_in_subprocess(
+            test_case=self, target_func=_test_large_generation_multilingual, inputs=None, timeout=timeout
         )
-        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-
-        EXPECTED_TRANSCRIPT = " Kimura san ni denwa wo kaite moraimashita"
-        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
-
-        model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate")
-        generated_ids = model.generate(input_features, do_sample=False, max_length=20)
-        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-
-        EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san"
-        self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

     @slow
     def test_large_batched_generation(self):
-        set_seed(0)
-        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
-        model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
-
-        input_speech = self._load_datasamples(4)
-        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
-        generated_ids = model.generate(input_features, max_length=20)
-
-        # fmt: off
-        EXPECTED_LOGITS = tf.convert_to_tensor(
-            [
-                [50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281],
-                [50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257],
-                [50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256],
-                [50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11]
-            ]
+        timeout = int(os.environ.get("PYTEST_TIMEOUT", 600))
+        run_test_in_subprocess(
+            test_case=self, target_func=_test_large_batched_generation, inputs=None, timeout=timeout
         )
-        # fmt: on
-
-        self.assertTrue(np.allclose(generated_ids, EXPECTED_LOGITS))
-
-        # fmt: off
-        EXPECTED_TRANSCRIPT = [
-            " Mr. Quilter is the apostle of the middle classes and we are glad to",
-            " Nor is Mr. Quilter's manner less interesting than his matter.",
-            " He tells us that at this festive season of the year, with Christmas and roast beef",
-            " He has grave doubts whether Sir Frederick Layton's work is really Greek after all,"
-        ]
-        # fmt: on
-
-        transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
-        self.assertListEqual(transcript, EXPECTED_TRANSCRIPT)

     @slow
     def test_tiny_en_batched_generation(self):