diff --git a/src/transformers/models/whisper/feature_extraction_whisper.py b/src/transformers/models/whisper/feature_extraction_whisper.py
index 3c4d413d88e..1519fb02862 100644
--- a/src/transformers/models/whisper/feature_extraction_whisper.py
+++ b/src/transformers/models/whisper/feature_extraction_whisper.py
@@ -129,18 +129,13 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
         Compute the log-mel spectrogram of the audio using PyTorch's GPU-accelerated STFT implementation with batching,
         yielding results similar to cpu computing with 1e-5 tolerance.
         """
-        waveform = torch.from_numpy(waveform).type(torch.float32)
+        waveform = torch.from_numpy(waveform).to(device, torch.float32)
+        window = torch.hann_window(self.n_fft, device=device)
 
-        window = torch.hann_window(self.n_fft)
-        if device != "cpu":
-            waveform = waveform.to(device)
-            window = window.to(device)
         stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
         magnitudes = stft[..., :-1].abs() ** 2
 
-        mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
-        if device != "cpu":
-            mel_filters = mel_filters.to(device)
+        mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32)
 
         mel_spec = mel_filters.T @ magnitudes
         log_spec = torch.clamp(mel_spec, min=1e-10).log10()
diff --git a/tests/models/whisper/test_feature_extraction_whisper.py b/tests/models/whisper/test_feature_extraction_whisper.py
index ec2e29a41e0..61106c04006 100644
--- a/tests/models/whisper/test_feature_extraction_whisper.py
+++ b/tests/models/whisper/test_feature_extraction_whisper.py
@@ -298,8 +298,9 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
         )
         # fmt: on
 
-        input_speech = self._load_datasamples(3)
-        feature_extractor = WhisperFeatureExtractor()
-        input_features = feature_extractor(input_speech, return_tensors="pt").input_features
+        with torch.device("cuda"):
+            input_speech = self._load_datasamples(3)
+            feature_extractor = WhisperFeatureExtractor()
+            input_features = feature_extractor(input_speech, return_tensors="pt").input_features
         self.assertEqual(input_features.shape, (3, 80, 3000))
         torch.testing.assert_close(input_features[:, 0, :30], EXPECTED_INPUT_FEATURES, rtol=1e-4, atol=1e-4)