[AutoProcessor] Add Wav2Vec2WithLM & small fix (#14675)

* [AutoProcessor] Add Wav2Vec2WithLM & small fix

* revert line removal

* Update src/transformers/__init__.py

* add test

* up

* up

* small fix
This commit is contained in:
Patrick von Platen 2021-12-08 15:51:28 +01:00 committed by GitHub
parent 2294071a0c
commit ee4fa2e465
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 72 additions and 16 deletions

View File

@ -312,6 +312,7 @@ _import_structure = {
"Wav2Vec2Processor", "Wav2Vec2Processor",
"Wav2Vec2Tokenizer", "Wav2Vec2Tokenizer",
], ],
"models.wav2vec2_with_lm": [],
"models.xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMTokenizer"], "models.xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMTokenizer"],
"models.xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"], "models.xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"],
"models.xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"], "models.xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"],
@ -474,7 +475,7 @@ else:
] ]
if is_pyctcdecode_available(): if is_pyctcdecode_available():
_import_structure["models.wav2vec2"].append("Wav2Vec2ProcessorWithLM") _import_structure["models.wav2vec2_with_lm"].append("Wav2Vec2ProcessorWithLM")
else: else:
from .utils import dummy_pyctcdecode_objects from .utils import dummy_pyctcdecode_objects
@ -2470,7 +2471,7 @@ if TYPE_CHECKING:
from .utils.dummy_speech_objects import * from .utils.dummy_speech_objects import *
if is_pyctcdecode_available(): if is_pyctcdecode_available():
from .models.wav2vec2 import Wav2Vec2ProcessorWithLM from .models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
else: else:
from .utils.dummy_pyctcdecode_objects import * from .utils.dummy_pyctcdecode_objects import *

View File

@ -107,6 +107,7 @@ from . import (
visual_bert, visual_bert,
vit, vit,
wav2vec2, wav2vec2,
wav2vec2_with_lm,
xlm, xlm,
xlm_prophetnet, xlm_prophetnet,
xlm_roberta, xlm_roberta,

View File

@ -39,6 +39,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("speech_to_text_2", "Speech2Text2Processor"), ("speech_to_text_2", "Speech2Text2Processor"),
("trocr", "TrOCRProcessor"), ("trocr", "TrOCRProcessor"),
("wav2vec2", "Wav2Vec2Processor"), ("wav2vec2", "Wav2Vec2Processor"),
("wav2vec2_with_lm", "Wav2Vec2ProcessorWithLM"),
("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"), ("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"),
] ]
) )
@ -145,6 +146,9 @@ class AutoProcessor:
key: kwargs[key] for key in ["revision", "use_auth_token", "local_files_only"] if key in kwargs key: kwargs[key] for key in ["revision", "use_auth_token", "local_files_only"] if key in kwargs
} }
model_files = get_list_of_files(pretrained_model_name_or_path, **get_list_of_files_kwargs) model_files = get_list_of_files(pretrained_model_name_or_path, **get_list_of_files_kwargs)
# strip to file name
model_files = [f.split("/")[-1] for f in model_files]
if FEATURE_EXTRACTOR_NAME in model_files: if FEATURE_EXTRACTOR_NAME in model_files:
config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs) config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
if "processor_class" in config_dict: if "processor_class" in config_dict:

View File

@ -17,7 +17,7 @@
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...file_utils import _LazyModule, is_flax_available, is_pyctcdecode_available, is_tf_available, is_torch_available from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
_import_structure = { _import_structure = {
@ -27,8 +27,6 @@ _import_structure = {
"tokenization_wav2vec2": ["Wav2Vec2CTCTokenizer", "Wav2Vec2Tokenizer"], "tokenization_wav2vec2": ["Wav2Vec2CTCTokenizer", "Wav2Vec2Tokenizer"],
} }
if is_pyctcdecode_available():
_import_structure["processing_wav2vec2_with_lm"] = ["Wav2Vec2ProcessorWithLM"]
if is_torch_available(): if is_torch_available():
_import_structure["modeling_wav2vec2"] = [ _import_structure["modeling_wav2vec2"] = [
@ -64,9 +62,6 @@ if TYPE_CHECKING:
from .processing_wav2vec2 import Wav2Vec2Processor from .processing_wav2vec2 import Wav2Vec2Processor
from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2Tokenizer from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2Tokenizer
if is_pyctcdecode_available():
from .processing_wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
if is_torch_available(): if is_torch_available():
from .modeling_wav2vec2 import ( from .modeling_wav2vec2 import (
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST, WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,

View File

@ -0,0 +1,36 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import _LazyModule, is_pyctcdecode_available
# Maps a submodule name to the public names it provides; passed to _LazyModule at the end.
# It starts empty: the only entry is registered when pyctcdecode is available.
_import_structure = {}
if is_pyctcdecode_available():
_import_structure["processing_wav2vec2_with_lm"] = ["Wav2Vec2ProcessorWithLM"]
# Under static type checking the processor is imported directly, still gated on pyctcdecode.
if TYPE_CHECKING:
if is_pyctcdecode_available():
from .processing_wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
# NOTE(review): indentation is not visible in this view; this `else` presumably pairs with
# `if TYPE_CHECKING:` -- confirm. The entry for this module in sys.modules is replaced by a
# _LazyModule built from _import_structure.
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)

View File

@ -35,8 +35,8 @@ from pyctcdecode.constants import (
from ...feature_extraction_utils import FeatureExtractionMixin from ...feature_extraction_utils import FeatureExtractionMixin
from ...file_utils import ModelOutput, requires_backends from ...file_utils import ModelOutput, requires_backends
from ...tokenization_utils import PreTrainedTokenizer from ...tokenization_utils import PreTrainedTokenizer
from .feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor from ..wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer from ..wav2vec2.tokenization_wav2vec2 import Wav2Vec2CTCTokenizer
@dataclass @dataclass
@ -159,6 +159,9 @@ class Wav2Vec2ProcessorWithLM:
if os.path.isdir(pretrained_model_name_or_path): if os.path.isdir(pretrained_model_name_or_path):
decoder = BeamSearchDecoderCTC.load_from_dir(pretrained_model_name_or_path) decoder = BeamSearchDecoderCTC.load_from_dir(pretrained_model_name_or_path)
else: else:
# BeamSearchDecoderCTC has no auto class
kwargs.pop("_from_auto", None)
decoder = BeamSearchDecoderCTC.load_from_hf_hub(pretrained_model_name_or_path, **kwargs) decoder = BeamSearchDecoderCTC.load_from_hf_hub(pretrained_model_name_or_path, **kwargs)
# set language model attributes # set language model attributes

View File

@ -1,3 +1,4 @@
{ {
"feature_extractor_type": "Wav2Vec2FeatureExtractor" "feature_extractor_type": "Wav2Vec2FeatureExtractor",
"processor_class": "Wav2Vec2Processor"
} }

View File

@ -16,15 +16,16 @@
import os import os
import tempfile import tempfile
import unittest import unittest
from shutil import copyfile
from transformers import AutoProcessor, Wav2Vec2Config, Wav2Vec2Processor from transformers import AutoProcessor, Wav2Vec2Config, Wav2Vec2Processor
from transformers.file_utils import FEATURE_EXTRACTOR_NAME
SAMPLE_PROCESSOR_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
SAMPLE_PROCESSOR_CONFIG = os.path.join( SAMPLE_PROCESSOR_CONFIG = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy_feature_extractor_config.json" os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy_feature_extractor_config.json"
) )
SAMPLE_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json") SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/vocab.json")
class AutoFeatureExtractorTest(unittest.TestCase): class AutoFeatureExtractorTest(unittest.TestCase):
@ -32,7 +33,7 @@ class AutoFeatureExtractorTest(unittest.TestCase):
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h") processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
self.assertIsInstance(processor, Wav2Vec2Processor) self.assertIsInstance(processor, Wav2Vec2Processor)
def test_processor_from_local_directory_from_config(self): def test_processor_from_local_directory_from_repo(self):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
model_config = Wav2Vec2Config() model_config = Wav2Vec2Config()
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h") processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
@ -44,3 +45,13 @@ class AutoFeatureExtractorTest(unittest.TestCase):
processor = AutoProcessor.from_pretrained(tmpdirname) processor = AutoProcessor.from_pretrained(tmpdirname)
self.assertIsInstance(processor, Wav2Vec2Processor) self.assertIsInstance(processor, Wav2Vec2Processor)
# Loads a processor from a fresh temporary directory that holds two copied fixtures:
# SAMPLE_PROCESSOR_CONFIG (saved under FEATURE_EXTRACTOR_NAME) and SAMPLE_VOCAB (saved as
# "vocab.json"), then checks the type of the object AutoProcessor returns.
def test_processor_from_local_directory_from_extractor_config(self):
with tempfile.TemporaryDirectory() as tmpdirname:
# copy relevant files
copyfile(SAMPLE_PROCESSOR_CONFIG, os.path.join(tmpdirname, FEATURE_EXTRACTOR_NAME))
copyfile(SAMPLE_VOCAB, os.path.join(tmpdirname, "vocab.json"))
# Only the directory path is passed; the processor class is presumably resolved from the
# feature-extractor config's "processor_class" entry -- confirm against the fixture.
processor = AutoProcessor.from_pretrained(tmpdirname)
self.assertIsInstance(processor, Wav2Vec2Processor)

View File

@ -31,7 +31,7 @@ from .test_feature_extraction_wav2vec2 import floats_list
if is_pyctcdecode_available(): if is_pyctcdecode_available():
from pyctcdecode import BeamSearchDecoderCTC from pyctcdecode import BeamSearchDecoderCTC
from transformers.models.wav2vec2 import Wav2Vec2ProcessorWithLM from transformers.models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
@require_pyctcdecode @require_pyctcdecode

View File

@ -27,6 +27,8 @@ _re_backend = re.compile(r"is\_([a-z_]*)_available()")
_re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]') _re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]')
# Catches a line if is_foo_available # Catches a line if is_foo_available
_re_test_backend = re.compile(r"^\s*if\s+is\_[a-z_]*\_available\(\)") _re_test_backend = re.compile(r"^\s*if\s+is\_[a-z_]*\_available\(\)")
# Catches a line _import_struct["bla"] = ["foo"]
_re_import_struct_equal_one = re.compile(r'^\s*_import_structure\["\S*"\]\ = "\[(\S*)\]"')
# Catches a line _import_struct["bla"].append("foo") # Catches a line _import_struct["bla"].append("foo")
_re_import_struct_add_one = re.compile(r'^\s*_import_structure\["\S*"\]\.append\("(\S*)"\)') _re_import_struct_add_one = re.compile(r'^\s*_import_structure\["\S*"\]\.append\("(\S*)"\)')
# Catches a line _import_struct["bla"].extend(["foo", "bar"]) or _import_struct["bla"] = ["foo", "bar"] # Catches a line _import_struct["bla"].extend(["foo", "bar"]) or _import_struct["bla"] = ["foo", "bar"]
@ -88,7 +90,9 @@ def parse_init(init_file):
# Until we unindent, add backend objects to the list # Until we unindent, add backend objects to the list
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 4): while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 4):
line = lines[line_index] line = lines[line_index]
if _re_import_struct_add_one.search(line) is not None: if _re_import_struct_equal_one.search(line) is not None:
objects.append(_re_import_struct_equal_one.search(line).groups()[0])
elif _re_import_struct_add_one.search(line) is not None:
objects.append(_re_import_struct_add_one.search(line).groups()[0]) objects.append(_re_import_struct_add_one.search(line).groups()[0])
elif _re_import_struct_add_many.search(line) is not None: elif _re_import_struct_add_many.search(line) is not None:
imports = _re_import_struct_add_many.search(line).groups()[0].split(", ") imports = _re_import_struct_add_many.search(line).groups()[0].split(", ")