Register feature extractor (#15634)

* Rework AutoFeatureExtractor.from_pretrained internal

* Custom feature extractor

* Add more tests

* Add support for custom feature extractor code

* Clean up

* Add register API to AutoFeatureExtractor
This commit is contained in:
Sylvain Gugger 2022-02-14 13:35:16 -05:00 committed by GitHub
parent 0f71c29053
commit 2e11a04337
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 56 additions and 1 deletions

View File

@ -68,6 +68,10 @@ def feature_extractor_class_from_name(class_name: str):
return getattr(module, class_name) return getattr(module, class_name)
break break
for config, extractor in FEATURE_EXTRACTOR_MAPPING._extra_content.items():
if getattr(extractor, "__name__", None) == class_name:
return extractor
return None return None
@ -301,3 +305,15 @@ class AutoFeatureExtractor:
f"`feature_extractor_type` key in its {FEATURE_EXTRACTOR_NAME} of {CONFIG_NAME}, or one of the following " f"`feature_extractor_type` key in its {FEATURE_EXTRACTOR_NAME} of {CONFIG_NAME}, or one of the following "
"`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in FEATURE_EXTRACTOR_MAPPING_NAMES.keys())}" "`model_type` keys in its {CONFIG_NAME}: {', '.join(c for c in FEATURE_EXTRACTOR_MAPPING_NAMES.keys())}"
) )
@staticmethod
def register(config_class, feature_extractor_class):
"""
Register a new feature extractor for this class.
Args:
config_class ([`PretrainedConfig`]):
The configuration corresponding to the model to register.
feature_extractor_class ([`FeatureExtractorMixin`]): The feature extractor to register.
"""
FEATURE_EXTRACTOR_MAPPING.register(config_class, feature_extractor_class)

View File

@ -15,13 +15,28 @@
import json import json
import os import os
import sys
import tempfile import tempfile
import unittest import unittest
from pathlib import Path
from transformers import AutoFeatureExtractor, Wav2Vec2Config, Wav2Vec2FeatureExtractor from transformers import (
CONFIG_MAPPING,
FEATURE_EXTRACTOR_MAPPING,
AutoConfig,
AutoFeatureExtractor,
Wav2Vec2Config,
Wav2Vec2FeatureExtractor,
)
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
sys.path.append(str(Path(__file__).parent.parent / "utils"))
from test_module.custom_configuration import CustomConfig # noqa E402
from test_module.custom_feature_extraction import CustomFeatureExtractor # noqa E402
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures") SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
SAMPLE_FEATURE_EXTRACTION_CONFIG = os.path.join( SAMPLE_FEATURE_EXTRACTION_CONFIG = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy_feature_extractor_config.json" os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy_feature_extractor_config.json"
@ -88,3 +103,24 @@ class AutoFeatureExtractorTest(unittest.TestCase):
"hf-internal-testing/test_dynamic_feature_extractor", trust_remote_code=True "hf-internal-testing/test_dynamic_feature_extractor", trust_remote_code=True
) )
self.assertEqual(model.__class__.__name__, "NewFeatureExtractor") self.assertEqual(model.__class__.__name__, "NewFeatureExtractor")
def test_new_feature_extractor_registration(self):
try:
AutoConfig.register("custom", CustomConfig)
AutoFeatureExtractor.register(CustomConfig, CustomFeatureExtractor)
# Trying to register something existing in the Transformers library will raise an error
with self.assertRaises(ValueError):
AutoFeatureExtractor.register(Wav2Vec2Config, Wav2Vec2FeatureExtractor)
# Now that the config is registered, it can be used as any other config with the auto-API
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
with tempfile.TemporaryDirectory() as tmp_dir:
feature_extractor.save_pretrained(tmp_dir)
new_feature_extractor = AutoFeatureExtractor.from_pretrained(tmp_dir)
self.assertIsInstance(new_feature_extractor, CustomFeatureExtractor)
finally:
if "custom" in CONFIG_MAPPING._extra_content:
del CONFIG_MAPPING._extra_content["custom"]
if CustomConfig in FEATURE_EXTRACTOR_MAPPING._extra_content:
del FEATURE_EXTRACTOR_MAPPING._extra_content[CustomConfig]

View File

@ -43,6 +43,9 @@ if is_vision_available():
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures") SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
def prepare_image_inputs(feature_extract_tester, equal_resolution=False, numpify=False, torchify=False): def prepare_image_inputs(feature_extract_tester, equal_resolution=False, numpify=False, torchify=False):
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True, """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
or a list of PyTorch tensors if one specifies torchify=True. or a list of PyTorch tensors if one specifies torchify=True.