mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Remove CLI spams with Whisper FeatureExtractor (#21267)
* Remove CLI spam with Whisper FeatureExtractor. The Whisper feature extractor representation includes the MEL filters, a list of lists that is printed as ~16,000 lines. This needlessly spams the command line. I added a `__repr__` method that replaces this list with the string "<array of shape (80, 201)>". * Remove mel_filters from to_dict output. Credits to @ArthurZucker. * Remove unused import. * Update feature extraction tests for the changes in to_dict.
This commit is contained in:
parent
129011c20b
commit
5b72b3412b
@ -15,8 +15,8 @@
|
||||
"""
|
||||
Feature extractor class for Whisper
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Union
|
||||
import copy
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from numpy.fft import fft
|
||||
@ -322,3 +322,16 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
||||
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
|
||||
|
||||
return padded_inputs
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
    """
    Serializes this instance to a Python dictionary.

    The `mel_filters` attribute is excluded from the result, so the serialized form (and anything printed from
    it) does not carry the full filter bank.

    Returns:
        `Dict[str, Any]`: Deep copy of the attributes that make up this configuration instance, without
        `mel_filters`, plus a `feature_extractor_type` entry holding the class name.
    """
    # Drop `mel_filters` *before* copying: deep-copying the filter bank only to delete it afterwards is wasted work.
    kept_attributes = {key: value for key, value in self.__dict__.items() if key != "mel_filters"}
    output = copy.deepcopy(kept_attributes)
    output["feature_extractor_type"] = self.__class__.__name__
    return output
|
||||
|
@ -128,8 +128,8 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
||||
|
||||
dict_first = feat_extract_first.to_dict()
|
||||
dict_second = feat_extract_second.to_dict()
|
||||
mel_1 = dict_first.pop("mel_filters")
|
||||
mel_2 = dict_second.pop("mel_filters")
|
||||
mel_1 = feat_extract_first.mel_filters
|
||||
mel_2 = feat_extract_second.mel_filters
|
||||
self.assertTrue(np.allclose(mel_1, mel_2))
|
||||
self.assertEqual(dict_first, dict_second)
|
||||
|
||||
@ -143,8 +143,8 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
||||
|
||||
dict_first = feat_extract_first.to_dict()
|
||||
dict_second = feat_extract_second.to_dict()
|
||||
mel_1 = dict_first.pop("mel_filters")
|
||||
mel_2 = dict_second.pop("mel_filters")
|
||||
mel_1 = feat_extract_first.mel_filters
|
||||
mel_2 = feat_extract_second.mel_filters
|
||||
self.assertTrue(np.allclose(mel_1, mel_2))
|
||||
self.assertEqual(dict_first, dict_second)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user