[AutoProcessor] Add Wav2Vec2WithLM & small fix (#14675)

* [AutoProcessor] Add Wav2Vec2WithLM & small fix

* revert line removal

* Update src/transformers/__init__.py

* add test

* up

* up

* small fix
This commit is contained in:
Patrick von Platen 2021-12-08 15:51:28 +01:00 committed by GitHub
parent 2294071a0c
commit ee4fa2e465
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 72 additions and 16 deletions

View File

@ -312,6 +312,7 @@ _import_structure = {
"Wav2Vec2Processor",
"Wav2Vec2Tokenizer",
],
"models.wav2vec2_with_lm": [],
"models.xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMTokenizer"],
"models.xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"],
"models.xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"],
@ -474,7 +475,7 @@ else:
]
if is_pyctcdecode_available():
_import_structure["models.wav2vec2"].append("Wav2Vec2ProcessorWithLM")
_import_structure["models.wav2vec2_with_lm"].append("Wav2Vec2ProcessorWithLM")
else:
from .utils import dummy_pyctcdecode_objects
@ -2470,7 +2471,7 @@ if TYPE_CHECKING:
from .utils.dummy_speech_objects import *
if is_pyctcdecode_available():
from .models.wav2vec2 import Wav2Vec2ProcessorWithLM
from .models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
else:
from .utils.dummy_pyctcdecode_objects import *

View File

@ -107,6 +107,7 @@ from . import (
visual_bert,
vit,
wav2vec2,
wav2vec2_with_lm,
xlm,
xlm_prophetnet,
xlm_roberta,

View File

@ -39,6 +39,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("speech_to_text_2", "Speech2Text2Processor"),
("trocr", "TrOCRProcessor"),
("wav2vec2", "Wav2Vec2Processor"),
("wav2vec2_with_lm", "Wav2Vec2ProcessorWithLM"),
("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"),
]
)
@ -145,6 +146,9 @@ class AutoProcessor:
key: kwargs[key] for key in ["revision", "use_auth_token", "local_files_only"] if key in kwargs
}
model_files = get_list_of_files(pretrained_model_name_or_path, **get_list_of_files_kwargs)
# strip to file name
model_files = [f.split("/")[-1] for f in model_files]
if FEATURE_EXTRACTOR_NAME in model_files:
config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
if "processor_class" in config_dict:

View File

@ -17,7 +17,7 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import _LazyModule, is_flax_available, is_pyctcdecode_available, is_tf_available, is_torch_available
from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
_import_structure = {
@ -27,8 +27,6 @@ _import_structure = {
"tokenization_wav2vec2": ["Wav2Vec2CTCTokenizer", "Wav2Vec2Tokenizer"],
}
if is_pyctcdecode_available():
_import_structure["processing_wav2vec2_with_lm"] = ["Wav2Vec2ProcessorWithLM"]
if is_torch_available():
_import_structure["modeling_wav2vec2"] = [
@ -64,9 +62,6 @@ if TYPE_CHECKING:
from .processing_wav2vec2 import Wav2Vec2Processor
from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2Tokenizer
if is_pyctcdecode_available():
from .processing_wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
if is_torch_available():
from .modeling_wav2vec2 import (
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,

View File

@ -0,0 +1,36 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import _LazyModule, is_pyctcdecode_available
# Maps a submodule name to the list of public names it provides; the dict is
# passed to _LazyModule on the last line of this file.
_import_structure = {}
# Register the processor only when is_pyctcdecode_available() reports true.
if is_pyctcdecode_available():
_import_structure["processing_wav2vec2_with_lm"] = ["Wav2Vec2ProcessorWithLM"]
# Under TYPE_CHECKING the name is imported directly, behind the same
# availability check that guards the registration above.
if TYPE_CHECKING:
if is_pyctcdecode_available():
from .processing_wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
else:
import sys
# Replace this module's entry in sys.modules with a _LazyModule built from
# _import_structure (presumably so the submodule is only imported on first
# attribute access -- confirm against _LazyModule in file_utils).
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)

View File

@ -35,8 +35,8 @@ from pyctcdecode.constants import (
from ...feature_extraction_utils import FeatureExtractionMixin
from ...file_utils import ModelOutput, requires_backends
from ...tokenization_utils import PreTrainedTokenizer
from .feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer
from ..wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
from ..wav2vec2.tokenization_wav2vec2 import Wav2Vec2CTCTokenizer
@dataclass
@ -159,6 +159,9 @@ class Wav2Vec2ProcessorWithLM:
if os.path.isdir(pretrained_model_name_or_path):
decoder = BeamSearchDecoderCTC.load_from_dir(pretrained_model_name_or_path)
else:
# BeamSearchDecoderCTC has no auto class
kwargs.pop("_from_auto", None)
decoder = BeamSearchDecoderCTC.load_from_hf_hub(pretrained_model_name_or_path, **kwargs)
# set language model attributes

View File

@ -1,3 +1,4 @@
{
"feature_extractor_type": "Wav2Vec2FeatureExtractor"
"feature_extractor_type": "Wav2Vec2FeatureExtractor",
"processor_class": "Wav2Vec2Processor"
}

View File

@ -16,15 +16,16 @@
import os
import tempfile
import unittest
from shutil import copyfile
from transformers import AutoProcessor, Wav2Vec2Config, Wav2Vec2Processor
from transformers.file_utils import FEATURE_EXTRACTOR_NAME
SAMPLE_PROCESSOR_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
SAMPLE_PROCESSOR_CONFIG = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy_feature_extractor_config.json"
)
SAMPLE_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json")
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/vocab.json")
class AutoFeatureExtractorTest(unittest.TestCase):
@ -32,7 +33,7 @@ class AutoFeatureExtractorTest(unittest.TestCase):
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
self.assertIsInstance(processor, Wav2Vec2Processor)
def test_processor_from_local_directory_from_config(self):
def test_processor_from_local_directory_from_repo(self):
with tempfile.TemporaryDirectory() as tmpdirname:
model_config = Wav2Vec2Config()
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
@ -44,3 +45,13 @@ class AutoFeatureExtractorTest(unittest.TestCase):
processor = AutoProcessor.from_pretrained(tmpdirname)
self.assertIsInstance(processor, Wav2Vec2Processor)
def test_processor_from_local_directory_from_extractor_config(self):
# Checks that AutoProcessor.from_pretrained resolves a local directory that
# holds only a feature-extractor config and a vocab file to a Wav2Vec2Processor.
with tempfile.TemporaryDirectory() as tmpdirname:
# copy relevant files
# The sample feature-extractor config is stored under FEATURE_EXTRACTOR_NAME,
# the file name the test imports from transformers.file_utils.
copyfile(SAMPLE_PROCESSOR_CONFIG, os.path.join(tmpdirname, FEATURE_EXTRACTOR_NAME))
copyfile(SAMPLE_VOCAB, os.path.join(tmpdirname, "vocab.json"))
# Load from the directory populated above; no model config is copied into it.
processor = AutoProcessor.from_pretrained(tmpdirname)
self.assertIsInstance(processor, Wav2Vec2Processor)

View File

@ -31,7 +31,7 @@ from .test_feature_extraction_wav2vec2 import floats_list
if is_pyctcdecode_available():
from pyctcdecode import BeamSearchDecoderCTC
from transformers.models.wav2vec2 import Wav2Vec2ProcessorWithLM
from transformers.models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
@require_pyctcdecode

View File

@ -27,6 +27,8 @@ _re_backend = re.compile(r"is\_([a-z_]*)_available()")
_re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]')
# Catches a line if is_foo_available
_re_test_backend = re.compile(r"^\s*if\s+is\_[a-z_]*\_available\(\)")
# Catches a line _import_struct["bla"] = ["foo"]
_re_import_struct_equal_one = re.compile(r'^\s*_import_structure\["\S*"\]\ = "\[(\S*)\]"')
# Catches a line _import_struct["bla"].append("foo")
_re_import_struct_add_one = re.compile(r'^\s*_import_structure\["\S*"\]\.append\("(\S*)"\)')
# Catches a line _import_struct["bla"].extend(["foo", "bar"]) or _import_struct["bla"] = ["foo", "bar"]
@ -88,7 +90,9 @@ def parse_init(init_file):
# Until we unindent, add backend objects to the list
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 4):
line = lines[line_index]
if _re_import_struct_add_one.search(line) is not None:
if _re_import_struct_equal_one.search(line) is not None:
objects.append(_re_import_struct_equal_one.search(line).groups()[0])
elif _re_import_struct_add_one.search(line) is not None:
objects.append(_re_import_struct_add_one.search(line).groups()[0])
elif _re_import_struct_add_many.search(line) is not None:
imports = _re_import_struct_add_many.search(line).groups()[0].split(", ")