def to_dict(self) -> Dict[str, Any]:
    """
    Serializes this instance to a Python dictionary.

    The `mel_filters` attribute is deliberately left out of the result: the filter bank is a large nested list of
    numbers, and including it floods any printed representation of the feature extractor with thousands of lines.

    Returns:
        `Dict[str, Any]`: Deep copy of every instance attribute except `mel_filters`, plus a
        `feature_extractor_type` entry holding the name of this instance's class.
    """
    # Filter the attributes *before* copying so the filter bank is never duplicated only to be thrown away.
    attributes = {name: value for name, value in self.__dict__.items() if name != "mel_filters"}
    # Deep copy so callers can mutate the returned dictionary without touching the instance state.
    output = copy.deepcopy(attributes)
    output["feature_extractor_type"] = self.__class__.__name__
    return output