mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Auto feature extractor (#11097)
* AutoFeatureExtractor * Init and first tests * Tests * Damn you gitignore * Quality * Defensive test for when not all backends are here * Use pattern for Speech2Text models
This commit is contained in:
parent
520198f56f
commit
403d530eec
3
.gitignore
vendored
3
.gitignore
vendored
@ -9,8 +9,7 @@ __pycache__/
|
||||
*.so
|
||||
|
||||
# tests and logs
|
||||
tests/fixtures/*
|
||||
!tests/fixtures/sample_text_no_unicode.txt
|
||||
tests/fixtures/cached_*_text.txt
|
||||
logs/
|
||||
lightning_logs/
|
||||
lang_code_data/
|
||||
|
@ -44,6 +44,13 @@ AutoTokenizer
|
||||
:members:
|
||||
|
||||
|
||||
AutoFeatureExtractor
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.AutoFeatureExtractor
|
||||
:members:
|
||||
|
||||
|
||||
AutoModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
@ -45,6 +45,7 @@ from .file_utils import (
|
||||
_BaseLazyModule,
|
||||
is_flax_available,
|
||||
is_sentencepiece_available,
|
||||
is_speech_available,
|
||||
is_tf_available,
|
||||
is_tokenizers_available,
|
||||
is_torch_available,
|
||||
@ -102,6 +103,7 @@ _import_structure = {
|
||||
"is_py3nvml_available",
|
||||
"is_sentencepiece_available",
|
||||
"is_sklearn_available",
|
||||
"is_speech_available",
|
||||
"is_tf_available",
|
||||
"is_tokenizers_available",
|
||||
"is_torch_available",
|
||||
@ -133,9 +135,11 @@ _import_structure = {
|
||||
"models.auto": [
|
||||
"ALL_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
||||
"CONFIG_MAPPING",
|
||||
"FEATURE_EXTRACTOR_MAPPING",
|
||||
"MODEL_NAMES_MAPPING",
|
||||
"TOKENIZER_MAPPING",
|
||||
"AutoConfig",
|
||||
"AutoFeatureExtractor",
|
||||
"AutoTokenizer",
|
||||
],
|
||||
"models.bart": ["BartConfig", "BartTokenizer"],
|
||||
@ -202,7 +206,6 @@ _import_structure = {
|
||||
"models.speech_to_text": [
|
||||
"SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
||||
"Speech2TextConfig",
|
||||
"Speech2TextFeatureExtractor",
|
||||
],
|
||||
"models.squeezebert": ["SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "SqueezeBertConfig", "SqueezeBertTokenizer"],
|
||||
"models.t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config"],
|
||||
@ -288,7 +291,6 @@ if is_sentencepiece_available():
|
||||
_import_structure["models.pegasus"].append("PegasusTokenizer")
|
||||
_import_structure["models.reformer"].append("ReformerTokenizer")
|
||||
_import_structure["models.speech_to_text"].append("Speech2TextTokenizer")
|
||||
_import_structure["models.speech_to_text"].append("Speech2TextProcessor")
|
||||
_import_structure["models.t5"].append("T5Tokenizer")
|
||||
_import_structure["models.xlm_prophetnet"].append("XLMProphetNetTokenizer")
|
||||
_import_structure["models.xlm_roberta"].append("XLMRobertaTokenizer")
|
||||
@ -339,6 +341,7 @@ if is_tokenizers_available():
|
||||
|
||||
if is_sentencepiece_available():
|
||||
_import_structure["convert_slow_tokenizer"] = ["SLOW_TO_FAST_CONVERTERS", "convert_slow_tokenizer"]
|
||||
|
||||
else:
|
||||
from .utils import dummy_tokenizers_objects
|
||||
|
||||
@ -346,6 +349,20 @@ else:
|
||||
name for name in dir(dummy_tokenizers_objects) if not name.startswith("_")
|
||||
]
|
||||
|
||||
# Speech-specific objects
|
||||
if is_speech_available():
|
||||
_import_structure["models.speech_to_text"].append("Speech2TextFeatureExtractor")
|
||||
|
||||
if is_sentencepiece_available():
|
||||
_import_structure["models.speech_to_text"].append("Speech2TextProcessor")
|
||||
|
||||
else:
|
||||
from .utils import dummy_speech_objects
|
||||
|
||||
_import_structure["utils.dummy_speech_objects"] = [
|
||||
name for name in dir(dummy_speech_objects) if not name.startswith("_")
|
||||
]
|
||||
|
||||
# Vision-specific objects
|
||||
if is_vision_available():
|
||||
_import_structure["image_utils"] = ["ImageFeatureExtractionMixin"]
|
||||
@ -1394,6 +1411,7 @@ if TYPE_CHECKING:
|
||||
is_py3nvml_available,
|
||||
is_sentencepiece_available,
|
||||
is_sklearn_available,
|
||||
is_speech_available,
|
||||
is_tf_available,
|
||||
is_tokenizers_available,
|
||||
is_torch_available,
|
||||
@ -1429,9 +1447,11 @@ if TYPE_CHECKING:
|
||||
from .models.auto import (
|
||||
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
CONFIG_MAPPING,
|
||||
FEATURE_EXTRACTOR_MAPPING,
|
||||
MODEL_NAMES_MAPPING,
|
||||
TOKENIZER_MAPPING,
|
||||
AutoConfig,
|
||||
AutoFeatureExtractor,
|
||||
AutoTokenizer,
|
||||
)
|
||||
from .models.bart import BartConfig, BartTokenizer
|
||||
@ -1494,11 +1514,7 @@ if TYPE_CHECKING:
|
||||
from .models.reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
|
||||
from .models.retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig, RetriBertTokenizer
|
||||
from .models.roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig, RobertaTokenizer
|
||||
from .models.speech_to_text import (
|
||||
SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
Speech2TextConfig,
|
||||
Speech2TextFeatureExtractor,
|
||||
)
|
||||
from .models.speech_to_text import SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, Speech2TextConfig
|
||||
from .models.squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig, SqueezeBertTokenizer
|
||||
from .models.t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
|
||||
from .models.tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig, TapasTokenizer
|
||||
@ -1585,7 +1601,7 @@ if TYPE_CHECKING:
|
||||
from .models.mt5 import MT5Tokenizer
|
||||
from .models.pegasus import PegasusTokenizer
|
||||
from .models.reformer import ReformerTokenizer
|
||||
from .models.speech_to_text import Speech2TextProcessor, Speech2TextTokenizer
|
||||
from .models.speech_to_text import Speech2TextTokenizer
|
||||
from .models.t5 import T5Tokenizer
|
||||
from .models.xlm_prophetnet import XLMProphetNetTokenizer
|
||||
from .models.xlm_roberta import XLMRobertaTokenizer
|
||||
@ -1627,9 +1643,19 @@ if TYPE_CHECKING:
|
||||
|
||||
if is_sentencepiece_available():
|
||||
from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, convert_slow_tokenizer
|
||||
|
||||
else:
|
||||
from .utils.dummy_tokenizers_objects import *
|
||||
|
||||
if is_speech_available():
|
||||
from .models.speech_to_text import Speech2TextFeatureExtractor
|
||||
|
||||
if is_sentencepiece_available():
|
||||
from .models.speech_to_text import Speech2TextProcessor
|
||||
|
||||
else:
|
||||
from .utils.dummy_speech_objects import *
|
||||
|
||||
if is_vision_available():
|
||||
from .image_utils import ImageFeatureExtractionMixin
|
||||
from .models.vit import ViTFeatureExtractor
|
||||
|
@ -43,6 +43,7 @@ deps = {
|
||||
"sphinx-copybutton": "sphinx-copybutton",
|
||||
"sphinx-markdown-tables": "sphinx-markdown-tables",
|
||||
"sphinx-rtd-theme": "sphinx-rtd-theme==0.4.3",
|
||||
"sphinxext-opengraph": "sphinxext-opengraph==0.4.1",
|
||||
"sphinx": "sphinx==3.2.1",
|
||||
"starlette": "starlette",
|
||||
"tensorflow-cpu": "tensorflow-cpu>=2.3",
|
||||
|
@ -325,6 +325,13 @@ class FeatureExtractionMixin:
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
|
||||
from_pipeline = kwargs.pop("_from_pipeline", None)
|
||||
from_auto_class = kwargs.pop("_from_auto", False)
|
||||
|
||||
user_agent = {"file_type": "feature extractor", "from_auto_class": from_auto_class}
|
||||
if from_pipeline is not None:
|
||||
user_agent["using_pipeline"] = from_pipeline
|
||||
|
||||
if is_offline_mode() and not local_files_only:
|
||||
logger.info("Offline mode: forcing local_files_only=True")
|
||||
local_files_only = True
|
||||
@ -349,6 +356,7 @@ class FeatureExtractionMixin:
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
# Load feature_extractor dict
|
||||
with open(resolved_feature_extractor_file, "r", encoding="utf-8") as reader:
|
||||
@ -426,6 +434,7 @@ class FeatureExtractionMixin:
|
||||
:obj:`Dict[str, Any]`: Dictionary of all the attributes that make up this feature extractor instance.
|
||||
"""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
output["feature_extractor_type"] = self.__class__.__name__
|
||||
|
||||
return output
|
||||
|
||||
|
@ -397,6 +397,11 @@ def is_torchaudio_available():
|
||||
return _torchaudio_available
|
||||
|
||||
|
||||
def is_speech_available():
|
||||
# For now this depends on torchaudio but the exact dependency might evolve in the future.
|
||||
return _torchaudio_available
|
||||
|
||||
|
||||
def torch_only_method(fn):
|
||||
def wrapper(*args, **kwargs):
|
||||
if not _torch_available:
|
||||
@ -513,6 +518,13 @@ explained here: https://pandas.pydata.org/pandas-docs/stable/getting_started/ins
|
||||
"""
|
||||
|
||||
|
||||
# docstyle-ignore
|
||||
SPEECH_IMPORT_ERROR = """
|
||||
{0} requires the torchaudio library but it was not found in your environment. You can install it with pip:
|
||||
`pip install torchaudio`
|
||||
"""
|
||||
|
||||
|
||||
# docstyle-ignore
|
||||
VISION_IMPORT_ERROR = """
|
||||
{0} requires the PIL library but it was not found in your environment. You can install it with pip:
|
||||
@ -586,6 +598,12 @@ def requires_scatter(obj):
|
||||
raise ImportError(SCATTER_IMPORT_ERROR.format(name))
|
||||
|
||||
|
||||
def requires_speech(obj):
|
||||
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
|
||||
if not is_speech_available():
|
||||
raise ImportError(SPEECH_IMPORT_ERROR.format(name))
|
||||
|
||||
|
||||
def requires_vision(obj):
|
||||
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
|
||||
if not is_vision_available():
|
||||
|
@ -23,6 +23,7 @@ from ...file_utils import _BaseLazyModule, is_flax_available, is_tf_available, i
|
||||
|
||||
_import_structure = {
|
||||
"configuration_auto": ["ALL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CONFIG_MAPPING", "MODEL_NAMES_MAPPING", "AutoConfig"],
|
||||
"feature_extraction_auto": ["FEATURE_EXTRACTOR_MAPPING", "AutoFeatureExtractor"],
|
||||
"tokenization_auto": ["TOKENIZER_MAPPING", "AutoTokenizer"],
|
||||
}
|
||||
|
||||
@ -104,6 +105,7 @@ if is_flax_available():
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, AutoConfig
|
||||
from .feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
|
||||
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
|
||||
|
||||
if is_torch_available():
|
||||
|
150
src/transformers/models/auto/feature_extraction_auto.py
Normal file
150
src/transformers/models/auto/feature_extraction_auto.py
Normal file
@ -0,0 +1,150 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" AutoFeatureExtractor class. """
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
from ...feature_extraction_utils import FeatureExtractionMixin
|
||||
from ...file_utils import is_speech_available, is_vision_available
|
||||
from ..wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
|
||||
from .configuration_auto import replace_list_option_in_docstrings
|
||||
|
||||
|
||||
if is_speech_available():
|
||||
from ..speech_to_text.feature_extraction_speech_to_text import Speech2TextFeatureExtractor
|
||||
else:
|
||||
Speech2TextFeatureExtractor = None
|
||||
|
||||
if is_vision_available():
|
||||
from ..vit.feature_extraction_vit import ViTFeatureExtractor
|
||||
else:
|
||||
ViTFeatureExtractor = None
|
||||
|
||||
|
||||
# Build the list of all feature extractors
|
||||
FEATURE_EXTRACTOR_MAPPING = OrderedDict(
|
||||
[
|
||||
("s2t", Speech2TextFeatureExtractor),
|
||||
("vit", ViTFeatureExtractor),
|
||||
("wav2vec2", Wav2Vec2FeatureExtractor),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def feature_extractor_class_from_name(class_name: str):
|
||||
for c in FEATURE_EXTRACTOR_MAPPING.values():
|
||||
if c is not None and c.__name__ == class_name:
|
||||
return c
|
||||
|
||||
|
||||
class AutoFeatureExtractor:
|
||||
r"""
|
||||
This is a generic feature extractor class that will be instantiated as one of the feature extractor classes of the
|
||||
library when created with the :meth:`AutoFeatureExtractor.from_pretrained` class method.
|
||||
|
||||
This class cannot be instantiated directly using ``__init__()`` (throws an error).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
raise EnvironmentError(
|
||||
"AutoFeatureExtractor is designed to be instantiated "
|
||||
"using the `AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path)` method."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@replace_list_option_in_docstrings(FEATURE_EXTRACTOR_MAPPING)
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate one of the feature extractor classes of the library from a pretrained model vocabulary.
|
||||
|
||||
The tokenizer class to instantiate is selected based on the :obj:`model_type` property of the config object
|
||||
(either passed as an argument or loaded from :obj:`pretrained_model_name_or_path` if possible), or when it's
|
||||
missing, by falling back to using pattern matching on :obj:`pretrained_model_name_or_path`:
|
||||
|
||||
List options
|
||||
|
||||
Params:
|
||||
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
|
||||
This can be either:
|
||||
|
||||
- a string, the `model id` of a pretrained feature_extractor hosted inside a model repo on
|
||||
huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
|
||||
namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
|
||||
- a path to a `directory` containing a feature extractor file saved using the
|
||||
:func:`~transformers.feature_extraction_utils.FeatureExtractionMixin.save_pretrained` method, e.g.,
|
||||
``./my_model_directory/``.
|
||||
- a path or url to a saved feature extractor JSON `file`, e.g.,
|
||||
``./my_model_directory/feature_extraction_config.json``.
|
||||
cache_dir (:obj:`str` or :obj:`os.PathLike`, `optional`):
|
||||
Path to a directory in which a downloaded pretrained model feature extractor should be cached if the
|
||||
standard cache should not be used.
|
||||
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to force to (re-)download the feature extractor files and override the cached versions
|
||||
if they exist.
|
||||
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to delete incompletely received file. Attempts to resume the download if such a file
|
||||
exists.
|
||||
proxies (:obj:`Dict[str, str]`, `optional`):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
||||
use_auth_token (:obj:`str` or `bool`, `optional`):
|
||||
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token
|
||||
generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`).
|
||||
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
|
||||
identifier allowed by git.
|
||||
return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
If :obj:`False`, then this function returns just the final feature extractor object. If :obj:`True`,
|
||||
then this functions returns a :obj:`Tuple(feature_extractor, unused_kwargs)` where `unused_kwargs` is a
|
||||
dictionary consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the
|
||||
part of ``kwargs`` which has not been used to update ``feature_extractor`` and is otherwise ignored.
|
||||
kwargs (:obj:`Dict[str, Any]`, `optional`):
|
||||
The values in kwargs of any keys which are feature extractor attributes will be used to override the
|
||||
loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is
|
||||
controlled by the ``return_unused_kwargs`` keyword parameter.
|
||||
|
||||
.. note::
|
||||
|
||||
Passing :obj:`use_auth_token=True` is required when you want to use a private model.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import AutoFeatureExtractor
|
||||
|
||||
>>> # Download vocabulary from huggingface.co and cache.
|
||||
>>> feature_extractor = AutoFeatureExtractor.from_pretrained('facebook/wav2vec2-base-960h')
|
||||
|
||||
>>> # If vocabulary files are in a directory (e.g. feature extractor was saved using `save_pretrained('./test/saved_model/')`)
|
||||
>>> feature_extractor = AutoFeatureExtractor.from_pretrained('./test/saved_model/')
|
||||
|
||||
"""
|
||||
kwargs["_from_auto"] = True
|
||||
config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
if "feature_extractor_type" in config_dict:
|
||||
feature_extractor_class = feature_extractor_class_from_name(config_dict["feature_extractor_type"])
|
||||
return feature_extractor_class.from_dict(config_dict, **kwargs)
|
||||
else:
|
||||
# Fallback: use pattern matching on the string.
|
||||
for pattern, feature_extractor_class in FEATURE_EXTRACTOR_MAPPING.items():
|
||||
if pattern in str(pretrained_model_name_or_path):
|
||||
return feature_extractor_class.from_dict(config_dict, **kwargs)
|
||||
|
||||
raise ValueError(
|
||||
f"Unrecognized model in {pretrained_model_name_or_path}. Should have a `feature_extractor_type` key in "
|
||||
"its feature_extraction_config.json, or contain one of the following strings "
|
||||
f"in its name: {', '.join(FEATURE_EXTRACTOR_MAPPING.keys())}"
|
||||
)
|
@ -17,7 +17,7 @@
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...file_utils import _BaseLazyModule, is_sentencepiece_available, is_torch_available
|
||||
from ...file_utils import _BaseLazyModule, is_sentencepiece_available, is_speech_available, is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
@ -25,13 +25,17 @@ _import_structure = {
|
||||
"SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
||||
"Speech2TextConfig",
|
||||
],
|
||||
"feature_extraction_speech_to_text": ["Speech2TextFeatureExtractor"],
|
||||
}
|
||||
|
||||
if is_sentencepiece_available():
|
||||
_import_structure["processing_speech_to_text"] = ["Speech2TextProcessor"]
|
||||
_import_structure["tokenization_speech_to_text"] = ["Speech2TextTokenizer"]
|
||||
|
||||
if is_speech_available():
|
||||
_import_structure["feature_extraction_speech_to_text"] = ["Speech2TextFeatureExtractor"]
|
||||
|
||||
if is_sentencepiece_available():
|
||||
_import_structure["processing_speech_to_text"] = ["Speech2TextProcessor"]
|
||||
|
||||
if is_torch_available():
|
||||
_import_structure["modeling_speech_to_text"] = [
|
||||
"SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
@ -43,11 +47,15 @@ if is_torch_available():
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_speech_to_text import SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, Speech2TextConfig
|
||||
|
||||
if is_sentencepiece_available():
|
||||
from .tokenization_speech_to_text import Speech2TextTokenizer
|
||||
|
||||
if is_speech_available():
|
||||
from .feature_extraction_speech_to_text import Speech2TextFeatureExtractor
|
||||
|
||||
if is_sentencepiece_available():
|
||||
from .processing_speech_to_text import Speech2TextProcessor
|
||||
from .tokenization_speech_to_text import Speech2TextTokenizer
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_speech_to_text import (
|
||||
|
@ -19,19 +19,15 @@ Feature extractor class for Speech2Text
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio.compliance.kaldi as ta_kaldi
|
||||
|
||||
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...file_utils import PaddingStrategy, TensorType, is_torch_available, is_torchaudio_available
|
||||
from ...file_utils import PaddingStrategy, TensorType
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_torchaudio_available():
|
||||
import torchaudio.compliance.kaldi as ta_kaldi
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@ -75,8 +71,6 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
|
||||
normalize_vars=True,
|
||||
**kwargs
|
||||
):
|
||||
if not is_torchaudio_available():
|
||||
raise ImportError("`Speech2TextFeatureExtractor` requires torchaudio: `pip install torchaudio`.")
|
||||
super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
|
||||
self.num_mel_bins = num_mel_bins
|
||||
self.do_ceptral_normalize = do_ceptral_normalize
|
||||
|
@ -110,11 +110,6 @@ class ReformerTokenizer:
|
||||
requires_sentencepiece(self)
|
||||
|
||||
|
||||
class Speech2TextProcessor:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_sentencepiece(self)
|
||||
|
||||
|
||||
class Speech2TextTokenizer:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_sentencepiece(self)
|
||||
|
12
src/transformers/utils/dummy_speech_objects.py
Normal file
12
src/transformers/utils/dummy_speech_objects.py
Normal file
@ -0,0 +1,12 @@
|
||||
# This file is autogenerated by the command `make fix-copies`, do not edit.
|
||||
from ..file_utils import requires_speech
|
||||
|
||||
|
||||
class Speech2TextFeatureExtractor:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_speech(self)
|
||||
|
||||
|
||||
class Speech2TextProcessor:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_speech(self)
|
3
tests/fixtures/dummy_feature_extractor_config.json
vendored
Normal file
3
tests/fixtures/dummy_feature_extractor_config.json
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
{
|
||||
"feature_extractor_type": "Wav2Vec2FeatureExtractor"
|
||||
}
|
44
tests/test_feature_extraction_auto.py
Normal file
44
tests/test_feature_extraction_auto.py
Normal file
@ -0,0 +1,44 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from transformers import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor, Wav2Vec2FeatureExtractor
|
||||
|
||||
|
||||
SAMPLE_FEATURE_EXTRACTION_CONFIG = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy_feature_extractor_config.json"
|
||||
)
|
||||
|
||||
|
||||
class AutoFeatureExtractorTest(unittest.TestCase):
|
||||
def test_feature_extractor_from_model_shortcut(self):
|
||||
config = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
|
||||
|
||||
def test_feature_extractor_from_local_file(self):
|
||||
config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG)
|
||||
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
|
||||
|
||||
def test_pattern_matching_fallback(self):
|
||||
"""
|
||||
In cases where config.json doesn't include a model_type,
|
||||
perform a few safety checks on the config mapping's order.
|
||||
"""
|
||||
# no key string should be included in a later key string (typical failure case)
|
||||
keys = list(FEATURE_EXTRACTOR_MAPPING.keys())
|
||||
for i, key in enumerate(keys):
|
||||
self.assertFalse(any(key in later_key for later_key in keys[i + 1 :]))
|
@ -20,12 +20,15 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import Speech2TextFeatureExtractor
|
||||
from transformers import is_speech_available
|
||||
from transformers.testing_utils import require_torch, require_torchaudio
|
||||
|
||||
from .test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
|
||||
|
||||
|
||||
if is_speech_available():
|
||||
from transformers import Speech2TextFeatureExtractor
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
|
||||
@ -101,7 +104,7 @@ class Speech2TextFeatureExtractionTester(unittest.TestCase):
|
||||
@require_torchaudio
|
||||
class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
|
||||
|
||||
feature_extraction_class = Speech2TextFeatureExtractor
|
||||
feature_extraction_class = Speech2TextFeatureExtractor if is_speech_available() else None
|
||||
|
||||
def setUp(self):
|
||||
self.feat_extract_tester = Speech2TextFeatureExtractionTester(self)
|
||||
|
@ -19,7 +19,7 @@ import unittest
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
|
||||
from transformers import Speech2TextFeatureExtractor, Speech2TextProcessor, Speech2TextTokenizer
|
||||
from transformers import Speech2TextTokenizer, is_speech_available
|
||||
from transformers.file_utils import FEATURE_EXTRACTOR_NAME
|
||||
from transformers.models.speech_to_text.tokenization_speech_to_text import VOCAB_FILES_NAMES, save_json
|
||||
from transformers.testing_utils import require_sentencepiece, require_torch, require_torchaudio
|
||||
@ -27,6 +27,10 @@ from transformers.testing_utils import require_sentencepiece, require_torch, req
|
||||
from .test_feature_extraction_speech_to_text import floats_list
|
||||
|
||||
|
||||
if is_speech_available():
|
||||
from transformers import Speech2TextFeatureExtractor, Speech2TextProcessor
|
||||
|
||||
|
||||
SAMPLE_SP = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
|
||||
|
||||
|
||||
|
@ -26,7 +26,7 @@ _re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
|
||||
_re_test_backend = re.compile(r"^\s+if\s+is\_([a-z]*)\_available\(\):\s*$")
|
||||
|
||||
|
||||
BACKENDS = ["torch", "tf", "flax", "sentencepiece", "tokenizers", "vision"]
|
||||
BACKENDS = ["torch", "tf", "flax", "sentencepiece", "speech", "tokenizers", "vision"]
|
||||
|
||||
|
||||
DUMMY_CONSTANT = """
|
||||
|
@ -18,7 +18,7 @@ import re
|
||||
|
||||
|
||||
PATH_TO_TRANSFORMERS = "src/transformers"
|
||||
BACKENDS = ["torch", "tf", "flax", "sentencepiece", "tokenizers", "vision"]
|
||||
BACKENDS = ["torch", "tf", "flax", "sentencepiece", "speech", "tokenizers", "vision"]
|
||||
|
||||
# Catches a line with a key-values pattern: "bla": ["foo", "bar"]
|
||||
_re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]')
|
||||
|
Loading…
Reference in New Issue
Block a user