[Feature Extractors] Fix kwargs to pre-trained (#30260)

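Apply keyword overrides to the feature extractor dict before the feature extractor is instantiated, so that the constructor sees the overridden values (e.g. the Whisper `mel_filters` follow an overridden `feature_size`). Adds a regression test for `from_pretrained` with kwargs.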
This commit is contained in:
Sanchit Gandhi 2024-04-19 11:16:08 +01:00 committed by GitHub
parent 4ab7a28216
commit cd09a8dfbc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 18 additions and 4 deletions

View File

@ -566,17 +566,17 @@ class FeatureExtractionMixin(PushToHubMixin):
"""
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
feature_extractor = cls(**feature_extractor_dict)
# Update feature_extractor with kwargs if needed
to_remove = []
for key, value in kwargs.items():
if hasattr(feature_extractor, key):
setattr(feature_extractor, key, value)
if key in feature_extractor_dict:
feature_extractor_dict[key] = value
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)
feature_extractor = cls(**feature_extractor_dict)
logger.info(f"Feature extractor {feature_extractor}")
if return_unused_kwargs:
return feature_extractor, kwargs

View File

@ -142,6 +142,20 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
self.assertTrue(np.allclose(mel_1, mel_2))
self.assertEqual(dict_first, dict_second)
def test_feat_extract_from_pretrained_kwargs(self):
    """Check that a `feature_size` override passed to `from_pretrained` takes effect.

    A feature extractor is saved to a temporary directory and reloaded with
    `feature_size` doubled. The second dimension of the reloaded extractor's
    `mel_filters` is expected to be twice that of the original.
    """
    feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)

    with tempfile.TemporaryDirectory() as tmpdirname:
        saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
        check_json_file_has_correct_format(saved_file)
        # Reload while the temporary directory still exists, overriding
        # `feature_size` through a keyword argument.
        feat_extract_second = self.feature_extraction_class.from_pretrained(
            tmpdirname, feature_size=2 * self.feat_extract_dict["feature_size"]
        )

    mel_1 = feat_extract_first.mel_filters
    mel_2 = feat_extract_second.mel_filters

    # assertEqual (rather than assertTrue on a comparison) reports both
    # values when the check fails.
    self.assertEqual(2 * mel_1.shape[1], mel_2.shape[1])
def test_call(self):
# Tests that all call wrap to encode_plus and batch_encode_plus
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())