diff --git a/src/transformers/models/whisper/feature_extraction_whisper.py b/src/transformers/models/whisper/feature_extraction_whisper.py index 77c85a2477a..9526e2815a4 100644 --- a/src/transformers/models/whisper/feature_extraction_whisper.py +++ b/src/transformers/models/whisper/feature_extraction_whisper.py @@ -15,8 +15,8 @@ """ Feature extractor class for Whisper """ - -from typing import List, Optional, Union +import copy +from typing import Any, Dict, List, Optional, Union import numpy as np from numpy.fft import fft @@ -322,3 +322,16 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor): padded_inputs = padded_inputs.convert_to_tensors(return_tensors) return padded_inputs + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. + + Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + output = copy.deepcopy(self.__dict__) + output["feature_extractor_type"] = self.__class__.__name__ + if "mel_filters" in output: + del output["mel_filters"] + return output diff --git a/tests/models/whisper/test_feature_extraction_whisper.py b/tests/models/whisper/test_feature_extraction_whisper.py index b490556c5fe..f1ef36b1f28 100644 --- a/tests/models/whisper/test_feature_extraction_whisper.py +++ b/tests/models/whisper/test_feature_extraction_whisper.py @@ -128,8 +128,8 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest. dict_first = feat_extract_first.to_dict() dict_second = feat_extract_second.to_dict() - mel_1 = dict_first.pop("mel_filters") - mel_2 = dict_second.pop("mel_filters") + mel_1 = feat_extract_first.mel_filters + mel_2 = feat_extract_second.mel_filters self.assertTrue(np.allclose(mel_1, mel_2)) self.assertEqual(dict_first, dict_second) @@ -143,8 +143,8 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest. dict_first = feat_extract_first.to_dict() dict_second = feat_extract_second.to_dict() - mel_1 = dict_first.pop("mel_filters") - mel_2 = dict_second.pop("mel_filters") + mel_1 = feat_extract_first.mel_filters + mel_2 = feat_extract_second.mel_filters self.assertTrue(np.allclose(mel_1, mel_2)) self.assertEqual(dict_first, dict_second)