mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
Remove CLI spams with Whisper FeatureExtractor (#21267)
* Remove CLI spams with Whisper FeatureExtractor
  The Whisper feature extractor representation includes the mel filters, a list of lists that is printed as ~16,000 lines. This needlessly spams the command line. I added a `__repr__` method that replaces this list with the string "<array of shape (80, 201)>".
* Remove mel_filters from the to_dict output. Credits to @ArthurZucker.
* Remove unused import.
* Update feature extraction tests for the changes in to_dict.
This commit is contained in:
parent
129011c20b
commit
5b72b3412b
@ -15,8 +15,8 @@
|
|||||||
"""
|
"""
|
||||||
Feature extractor class for Whisper
|
Feature extractor class for Whisper
|
||||||
"""
|
"""
|
||||||
|
import copy
|
||||||
from typing import List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy.fft import fft
|
from numpy.fft import fft
|
||||||
@ -322,3 +322,16 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
|||||||
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
|
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
|
||||||
|
|
||||||
return padded_inputs
|
return padded_inputs
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
    """
    Serializes this instance to a Python dictionary.

    The `mel_filters` attribute is left out of the result. It is dropped *before* the deep copy so the filter
    bank is never duplicated only to be thrown away.

    Returns:
        `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance, except
        `mel_filters`, plus a `feature_extractor_type` entry holding the class name.
    """
    # Shallow-filter first, then deep-copy once: a single `deepcopy` call keeps
    # references shared between the remaining attributes intact.
    serializable = {key: value for key, value in self.__dict__.items() if key != "mel_filters"}
    output = copy.deepcopy(serializable)
    output["feature_extractor_type"] = self.__class__.__name__
    return output
|
||||||
|
@ -128,8 +128,8 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
|||||||
|
|
||||||
dict_first = feat_extract_first.to_dict()
|
dict_first = feat_extract_first.to_dict()
|
||||||
dict_second = feat_extract_second.to_dict()
|
dict_second = feat_extract_second.to_dict()
|
||||||
mel_1 = dict_first.pop("mel_filters")
|
mel_1 = feat_extract_first.mel_filters
|
||||||
mel_2 = dict_second.pop("mel_filters")
|
mel_2 = feat_extract_second.mel_filters
|
||||||
self.assertTrue(np.allclose(mel_1, mel_2))
|
self.assertTrue(np.allclose(mel_1, mel_2))
|
||||||
self.assertEqual(dict_first, dict_second)
|
self.assertEqual(dict_first, dict_second)
|
||||||
|
|
||||||
@ -143,8 +143,8 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
|||||||
|
|
||||||
dict_first = feat_extract_first.to_dict()
|
dict_first = feat_extract_first.to_dict()
|
||||||
dict_second = feat_extract_second.to_dict()
|
dict_second = feat_extract_second.to_dict()
|
||||||
mel_1 = dict_first.pop("mel_filters")
|
mel_1 = feat_extract_first.mel_filters
|
||||||
mel_2 = dict_second.pop("mel_filters")
|
mel_2 = feat_extract_second.mel_filters
|
||||||
self.assertTrue(np.allclose(mel_1, mel_2))
|
self.assertTrue(np.allclose(mel_1, mel_2))
|
||||||
self.assertEqual(dict_first, dict_second)
|
self.assertEqual(dict_first, dict_second)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user