mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
parent
4ab7a28216
commit
cd09a8dfbc
@ -566,17 +566,17 @@ class FeatureExtractionMixin(PushToHubMixin):
|
||||
"""
|
||||
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
|
||||
|
||||
feature_extractor = cls(**feature_extractor_dict)
|
||||
|
||||
# Update feature_extractor with kwargs if needed
|
||||
to_remove = []
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(feature_extractor, key):
|
||||
setattr(feature_extractor, key, value)
|
||||
if key in feature_extractor_dict:
|
||||
feature_extractor_dict[key] = value
|
||||
to_remove.append(key)
|
||||
for key in to_remove:
|
||||
kwargs.pop(key, None)
|
||||
|
||||
feature_extractor = cls(**feature_extractor_dict)
|
||||
|
||||
logger.info(f"Feature extractor {feature_extractor}")
|
||||
if return_unused_kwargs:
|
||||
return feature_extractor, kwargs
|
||||
|
@ -142,6 +142,20 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
||||
self.assertTrue(np.allclose(mel_1, mel_2))
|
||||
self.assertEqual(dict_first, dict_second)
|
||||
|
||||
def test_feat_extract_from_pretrained_kwargs(self):
|
||||
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
|
||||
check_json_file_has_correct_format(saved_file)
|
||||
feat_extract_second = self.feature_extraction_class.from_pretrained(
|
||||
tmpdirname, feature_size=2 * self.feat_extract_dict["feature_size"]
|
||||
)
|
||||
|
||||
mel_1 = feat_extract_first.mel_filters
|
||||
mel_2 = feat_extract_second.mel_filters
|
||||
self.assertTrue(2 * mel_1.shape[1] == mel_2.shape[1])
|
||||
|
||||
def test_call(self):
|
||||
# Tests that all call wrap to encode_plus and batch_encode_plus
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
|
Loading…
Reference in New Issue
Block a user