mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix device mismatch error in Whisper model during feature extraction (#35866)
* Fix device mismatch error in whisper feature extraction
* Set default device
* Address code review feedback

---------

Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
This commit is contained in:
parent
9afb904b15
commit
bc9a6d8302
@ -129,18 +129,13 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
||||
Compute the log-mel spectrogram of the audio using PyTorch's GPU-accelerated STFT implementation with batching,
|
||||
yielding results similar to cpu computing with 1e-5 tolerance.
|
||||
"""
|
||||
waveform = torch.from_numpy(waveform).type(torch.float32)
|
||||
waveform = torch.from_numpy(waveform).to(device, torch.float32)
|
||||
window = torch.hann_window(self.n_fft, device=device)
|
||||
|
||||
window = torch.hann_window(self.n_fft)
|
||||
if device != "cpu":
|
||||
waveform = waveform.to(device)
|
||||
window = window.to(device)
|
||||
stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
|
||||
magnitudes = stft[..., :-1].abs() ** 2
|
||||
|
||||
mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
|
||||
if device != "cpu":
|
||||
mel_filters = mel_filters.to(device)
|
||||
mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32)
|
||||
mel_spec = mel_filters.T @ magnitudes
|
||||
|
||||
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
||||
|
@ -298,8 +298,9 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
input_speech = self._load_datasamples(3)
|
||||
feature_extractor = WhisperFeatureExtractor()
|
||||
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
|
||||
with torch.device("cuda"):
|
||||
input_speech = self._load_datasamples(3)
|
||||
feature_extractor = WhisperFeatureExtractor()
|
||||
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
|
||||
self.assertEqual(input_features.shape, (3, 80, 3000))
|
||||
torch.testing.assert_close(input_features[:, 0, :30], EXPECTED_INPUT_FEATURES, rtol=1e-4, atol=1e-4)
|
||||
|
Loading…
Reference in New Issue
Block a user