Remove CLI spams with Whisper FeatureExtractor (#21267)

* Remove CLI spams with Whisper FeatureExtractor

The Whisper feature extractor's representation includes the mel filters, a list of lists that is printed as ~16,000 lines. This needlessly spams the command line. I added a `__repr__` method that replaces this list with the string "<array of shape (80, 201)>".

* Remove mel_filters from to_dict output  

Credits to @ArthurZucker

* remove unused import

* update feature extraction tests for the changes in to_dict
This commit is contained in:
Quentin Meeus 2023-02-10 15:15:16 +01:00 committed by GitHub
parent 129011c20b
commit 5b72b3412b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 6 deletions

View File

@ -15,8 +15,8 @@
""" """
Feature extractor class for Whisper Feature extractor class for Whisper
""" """
import copy
from typing import List, Optional, Union from typing import Any, Dict, List, Optional, Union
import numpy as np import numpy as np
from numpy.fft import fft from numpy.fft import fft
@ -322,3 +322,16 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
padded_inputs = padded_inputs.convert_to_tensors(return_tensors) padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
return padded_inputs return padded_inputs
def to_dict(self) -> Dict[str, Any]:
    """
    Serializes this instance to a Python dictionary.

    The `mel_filters` attribute is deliberately left out of the result, and the instance itself is not modified.

    Returns:
        `Dict[str, Any]`: A deep copy of the attributes that make up this feature extractor instance, without
        `mel_filters`, plus a `feature_extractor_type` entry holding the class name.
    """
    # Filter out the mel filter bank *before* deep-copying, so it is never duplicated just to be discarded.
    # Copying the filtered dict in one call keeps any references shared between the remaining attributes shared.
    serializable = {name: value for name, value in self.__dict__.items() if name != "mel_filters"}
    output = copy.deepcopy(serializable)
    output["feature_extractor_type"] = self.__class__.__name__
    return output

View File

@ -128,8 +128,8 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
dict_first = feat_extract_first.to_dict() dict_first = feat_extract_first.to_dict()
dict_second = feat_extract_second.to_dict() dict_second = feat_extract_second.to_dict()
mel_1 = dict_first.pop("mel_filters") mel_1 = feat_extract_first.mel_filters
mel_2 = dict_second.pop("mel_filters") mel_2 = feat_extract_second.mel_filters
self.assertTrue(np.allclose(mel_1, mel_2)) self.assertTrue(np.allclose(mel_1, mel_2))
self.assertEqual(dict_first, dict_second) self.assertEqual(dict_first, dict_second)
@ -143,8 +143,8 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
dict_first = feat_extract_first.to_dict() dict_first = feat_extract_first.to_dict()
dict_second = feat_extract_second.to_dict() dict_second = feat_extract_second.to_dict()
mel_1 = dict_first.pop("mel_filters") mel_1 = feat_extract_first.mel_filters
mel_2 = dict_second.pop("mel_filters") mel_2 = feat_extract_second.mel_filters
self.assertTrue(np.allclose(mel_1, mel_2)) self.assertTrue(np.allclose(mel_1, mel_2))
self.assertEqual(dict_first, dict_second) self.assertEqual(dict_first, dict_second)