Fix device mismatch error in Whisper model during feature extraction (#35866)

* Fix device mismatch error in whisper feature extraction

* Set default device

* Address code review feedback

---------

Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
This commit is contained in:
Sumit Vij 2025-02-04 03:23:08 -08:00 committed by GitHub
parent 9afb904b15
commit bc9a6d8302
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 11 deletions

View File

@ -129,18 +129,13 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
Compute the log-mel spectrogram of the audio using PyTorch's GPU-accelerated STFT implementation with batching,
yielding results similar to cpu computing with 1e-5 tolerance.
"""
waveform = torch.from_numpy(waveform).type(torch.float32)
waveform = torch.from_numpy(waveform).to(device, torch.float32)
window = torch.hann_window(self.n_fft, device=device)
window = torch.hann_window(self.n_fft)
if device != "cpu":
waveform = waveform.to(device)
window = window.to(device)
stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
magnitudes = stft[..., :-1].abs() ** 2
mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
if device != "cpu":
mel_filters = mel_filters.to(device)
mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32)
mel_spec = mel_filters.T @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10()

View File

@ -298,8 +298,9 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
)
# fmt: on
input_speech = self._load_datasamples(3)
feature_extractor = WhisperFeatureExtractor()
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
with torch.device("cuda"):
input_speech = self._load_datasamples(3)
feature_extractor = WhisperFeatureExtractor()
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
self.assertEqual(input_features.shape, (3, 80, 3000))
torch.testing.assert_close(input_features[:, 0, :30], EXPECTED_INPUT_FEATURES, rtol=1e-4, atol=1e-4)