Remove CLI spams with Whisper FeatureExtractor (#21267)

* Remove CLI spams with Whisper FeatureExtractor

Whisper feature extractor representation includes the MEL filters, a list of list that is represented as ~16,000 lines. This needlessly spams the command line. I added a `__repr__` method that replaces this list with a string "<array of shape (80, 201)>"

* Remove mel_filters from to_dict output  

Credits to @ArthurZucker

* remove unused import

* update feature extraction tests for the changes in to_dict
This commit is contained in:
Quentin Meeus 2023-02-10 15:15:16 +01:00 committed by GitHub
parent 129011c20b
commit 5b72b3412b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 6 deletions

View File

@ -15,8 +15,8 @@
"""
Feature extractor class for Whisper
"""
from typing import List, Optional, Union
import copy
from typing import Any, Dict, List, Optional, Union
import numpy as np
from numpy.fft import fft
@ -322,3 +322,16 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
return padded_inputs
def to_dict(self) -> Dict[str, Any]:
"""
Serializes this instance to a Python dictionary.
Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__)
output["feature_extractor_type"] = self.__class__.__name__
if "mel_filters" in output:
del output["mel_filters"]
return output

View File

@ -128,8 +128,8 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
dict_first = feat_extract_first.to_dict()
dict_second = feat_extract_second.to_dict()
mel_1 = dict_first.pop("mel_filters")
mel_2 = dict_second.pop("mel_filters")
mel_1 = feat_extract_first.mel_filters
mel_2 = feat_extract_second.mel_filters
self.assertTrue(np.allclose(mel_1, mel_2))
self.assertEqual(dict_first, dict_second)
@ -143,8 +143,8 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
dict_first = feat_extract_first.to_dict()
dict_second = feat_extract_second.to_dict()
mel_1 = dict_first.pop("mel_filters")
mel_2 = dict_second.pop("mel_filters")
mel_1 = feat_extract_first.mel_filters
mel_2 = feat_extract_second.mel_filters
self.assertTrue(np.allclose(mel_1, mel_2))
self.assertEqual(dict_first, dict_second)