Fix probability computation in WhisperNoSpeechDetection when recomputing scores (#29248)

* Fix is_scores_logprobs in WhisperNoSpeechDetection

* Add test_whisper_longform_no_speech_detection

* Fix typo
This commit is contained in:
Ondřej Cífka 2024-04-03 17:53:07 +02:00 committed by GitHub
parent bcd42c4af9
commit 240e10626b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 57 additions and 1 deletions

View File

@ -1930,6 +1930,8 @@ class WhisperNoSpeechDetection(LogitsProcessor):
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
is_scores_logprobs = self.is_scores_logprobs
if input_ids.shape[1] == self.begin_index:
if self.start_of_trans_offset > 1:
with torch.no_grad():
@ -1937,10 +1939,11 @@ class WhisperNoSpeechDetection(LogitsProcessor):
no_speech_index = self.begin_index - self.start_of_trans_offset
no_speech_scores = logits[:, no_speech_index]
is_scores_logprobs = False
else:
no_speech_scores = scores
if self.is_scores_logprobs:
if is_scores_logprobs:
probs = no_speech_scores.exp()
else:
probs = no_speech_scores.float().softmax(dim=-1)

View File

@ -2670,6 +2670,59 @@ class WhisperModelIntegrationTests(unittest.TestCase):
for i in range(num_samples):
assert decoded_all[i] == EXPECTED_TEXT[i]
@slow
def test_whisper_longform_no_speech_detection(self):
# fmt: off
EXPECTED_TEXT = [
" Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories. Developing the central headline pawns, definitely maneuvering and also topical night to F6.",
" Folks, I spent a lot of time right over there night after night, actually. Carefully selecting for you the day's newsiest, most aerodynamic headlines, stress testing",
' Ladies and gentlemen, you know, I spent a lot of time right over there raising the finest Holstein news cattle firmly yet tenderly milking the latest headlines from their joke swollen teats',
' Folks, you watched this show, you know I spend most of my time right over there, carefully sorting through the days, big stories, and selecting only the most subtle and unblemished ostrich and crocodile news leather, which I then entrust to artisan graduates of the',
" You know, folks, I spent a lot of time crafting for you a bespoke playlist of the day's big stories right over there. meticulously selecting the most topical chakra affirming scented candles, using Feng Shui,",
' You know, folks, I spend most of my time right over there. Mining the days, biggest, most important stories, collecting the finest, most topical iron or hand hammering it into joke panels, then I craft sheets of bronze and blazing with patterns that tell an epic tale of conquest.',
" Folks, if you watch this show, you know I spend most of my time right over there, carefully blending for you the day's newsiest, most topical flower eggs, milk and butter. And straining into a fine batter to make delicate and informative comedy pancakes, then I glaze them in the juice and zest of the most...",
" Folks, if you watch the show and I hope you do, I spent a lot of time right over there. Tirelessly studying the lineage of the day's most important thoroughbred stories and whole-stiner headlines.",
]
# fmt: on
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model = model.to(torch_device)
ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
ds = ds.cast_column("audio", Audio(sampling_rate=16000))
num_samples = 8
audio = ds[:num_samples]["audio"]
audios = [x["array"] for x in audio]
# Make sure the second chunk is silent
for audio in audios:
audio[15 * 16000 : 60 * 16000] = 0.0
inputs = processor(
audios, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True
)
inputs = inputs.to(device=torch_device)
gen_kwargs = {
"return_timestamps": True,
"no_speech_threshold": 0.2,
"temperature": (0.0,),
"compression_ratio_threshold": 1.35,
"condition_on_prev_tokens": True,
"logprob_threshold": 0.0, # Ignore logprob, use only no-speech prob
"num_beams": 5,
}
torch.manual_seed(0)
result = model.generate(**inputs, **gen_kwargs)
decoded_all = processor.batch_decode(result, skip_special_tokens=True)
for i in range(num_samples):
assert decoded_all[i] == EXPECTED_TEXT[i]
def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
if head_mask is None: