diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index efa01a89f6d..ae756c2ceb2 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -312,6 +312,7 @@ _import_structure = {
         "Wav2Vec2Processor",
         "Wav2Vec2Tokenizer",
     ],
+    "models.wav2vec2_with_lm": [],
     "models.xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMTokenizer"],
     "models.xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"],
     "models.xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"],
@@ -474,7 +475,7 @@ else:
     ]
 
 if is_pyctcdecode_available():
-    _import_structure["models.wav2vec2"].append("Wav2Vec2ProcessorWithLM")
+    _import_structure["models.wav2vec2_with_lm"].append("Wav2Vec2ProcessorWithLM")
 else:
     from .utils import dummy_pyctcdecode_objects
 
@@ -2470,7 +2471,7 @@ if TYPE_CHECKING:
         from .utils.dummy_speech_objects import *
 
     if is_pyctcdecode_available():
-        from .models.wav2vec2 import Wav2Vec2ProcessorWithLM
+        from .models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
     else:
         from .utils.dummy_pyctcdecode_objects import *
 
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 5095d19d913..dee63727817 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -107,6 +107,7 @@ from . import (
     visual_bert,
     vit,
     wav2vec2,
+    wav2vec2_with_lm,
     xlm,
     xlm_prophetnet,
     xlm_roberta,
diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py
index df680529661..5a5cf8ac8ad 100644
--- a/src/transformers/models/auto/processing_auto.py
+++ b/src/transformers/models/auto/processing_auto.py
@@ -39,6 +39,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
         ("speech_to_text_2", "Speech2Text2Processor"),
         ("trocr", "TrOCRProcessor"),
         ("wav2vec2", "Wav2Vec2Processor"),
+        ("wav2vec2_with_lm", "Wav2Vec2ProcessorWithLM"),
         ("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"),
     ]
 )
@@ -145,6 +146,9 @@ class AutoProcessor:
             key: kwargs[key] for key in ["revision", "use_auth_token", "local_files_only"] if key in kwargs
         }
         model_files = get_list_of_files(pretrained_model_name_or_path, **get_list_of_files_kwargs)
+        # strip to file name
+        model_files = [f.split("/")[-1] for f in model_files]
+
         if FEATURE_EXTRACTOR_NAME in model_files:
             config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
             if "processor_class" in config_dict:
diff --git a/src/transformers/models/wav2vec2/__init__.py b/src/transformers/models/wav2vec2/__init__.py
index 0ca789825b8..fb03d8a572f 100644
--- a/src/transformers/models/wav2vec2/__init__.py
+++ b/src/transformers/models/wav2vec2/__init__.py
@@ -17,7 +17,7 @@
 # limitations under the License.
 from typing import TYPE_CHECKING
 
-from ...file_utils import _LazyModule, is_flax_available, is_pyctcdecode_available, is_tf_available, is_torch_available
+from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
 
 
 _import_structure = {
@@ -27,8 +27,6 @@ _import_structure = {
     "tokenization_wav2vec2": ["Wav2Vec2CTCTokenizer", "Wav2Vec2Tokenizer"],
 }
 
-if is_pyctcdecode_available():
-    _import_structure["processing_wav2vec2_with_lm"] = ["Wav2Vec2ProcessorWithLM"]
 
 if is_torch_available():
     _import_structure["modeling_wav2vec2"] = [
@@ -64,9 +62,6 @@ if TYPE_CHECKING:
     from .processing_wav2vec2 import Wav2Vec2Processor
     from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2Tokenizer
 
-    if is_pyctcdecode_available():
-        from .processing_wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
-
     if is_torch_available():
         from .modeling_wav2vec2 import (
             WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
diff --git a/src/transformers/models/wav2vec2_with_lm/__init__.py b/src/transformers/models/wav2vec2_with_lm/__init__.py
new file mode 100644
index 00000000000..b7f31c5581e
--- /dev/null
+++ b/src/transformers/models/wav2vec2_with_lm/__init__.py
@@ -0,0 +1,36 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2021 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...file_utils import _LazyModule, is_pyctcdecode_available
+
+
+_import_structure = {}
+
+
+if is_pyctcdecode_available():
+    _import_structure["processing_wav2vec2_with_lm"] = ["Wav2Vec2ProcessorWithLM"]
+
+
+if TYPE_CHECKING:
+    if is_pyctcdecode_available():
+        from .processing_wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
diff --git a/src/transformers/models/wav2vec2/processing_wav2vec2_with_lm.py b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
similarity index 98%
rename from src/transformers/models/wav2vec2/processing_wav2vec2_with_lm.py
rename to src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
index b0acbfbc608..750a49d473f 100644
--- a/src/transformers/models/wav2vec2/processing_wav2vec2_with_lm.py
+++ b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
@@ -35,8 +35,8 @@ from pyctcdecode.constants import (
 from ...feature_extraction_utils import FeatureExtractionMixin
 from ...file_utils import ModelOutput, requires_backends
 from ...tokenization_utils import PreTrainedTokenizer
-from .feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
-from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer
+from ..wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
+from ..wav2vec2.tokenization_wav2vec2 import Wav2Vec2CTCTokenizer
 
 
 @dataclass
@@ -159,6 +159,9 @@ class Wav2Vec2ProcessorWithLM:
         if os.path.isdir(pretrained_model_name_or_path):
             decoder = BeamSearchDecoderCTC.load_from_dir(pretrained_model_name_or_path)
         else:
+            # BeamSearchDecoderCTC has no auto class
+            kwargs.pop("_from_auto", None)
+
             decoder = BeamSearchDecoderCTC.load_from_hf_hub(pretrained_model_name_or_path, **kwargs)
 
         # set language model attributes
diff --git a/tests/fixtures/dummy_feature_extractor_config.json b/tests/fixtures/dummy_feature_extractor_config.json
index a38f627c316..674ef8a0b20 100644
--- a/tests/fixtures/dummy_feature_extractor_config.json
+++ b/tests/fixtures/dummy_feature_extractor_config.json
@@ -1,3 +1,4 @@
 {
-    "feature_extractor_type": "Wav2Vec2FeatureExtractor"
+    "feature_extractor_type": "Wav2Vec2FeatureExtractor",
+    "processor_class": "Wav2Vec2Processor"
 }
diff --git a/tests/test_processor_auto.py b/tests/test_processor_auto.py
index 3afc6db2415..a5a953e164a 100644
--- a/tests/test_processor_auto.py
+++ b/tests/test_processor_auto.py
@@ -16,15 +16,16 @@
 import os
 import tempfile
 import unittest
+from shutil import copyfile
 
 from transformers import AutoProcessor, Wav2Vec2Config, Wav2Vec2Processor
+from transformers.file_utils import FEATURE_EXTRACTOR_NAME
 
 
-SAMPLE_PROCESSOR_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
 SAMPLE_PROCESSOR_CONFIG = os.path.join(
     os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy_feature_extractor_config.json"
 )
-SAMPLE_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json")
+SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/vocab.json")
 
 
 class AutoFeatureExtractorTest(unittest.TestCase):
@@ -32,7 +33,7 @@ class AutoFeatureExtractorTest(unittest.TestCase):
         processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
         self.assertIsInstance(processor, Wav2Vec2Processor)
 
-    def test_processor_from_local_directory_from_config(self):
+    def test_processor_from_local_directory_from_repo(self):
         with tempfile.TemporaryDirectory() as tmpdirname:
             model_config = Wav2Vec2Config()
             processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
@@ -44,3 +45,13 @@ class AutoFeatureExtractorTest(unittest.TestCase):
             processor = AutoProcessor.from_pretrained(tmpdirname)
 
         self.assertIsInstance(processor, Wav2Vec2Processor)
+
+    def test_processor_from_local_directory_from_extractor_config(self):
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            # copy relevant files
+            copyfile(SAMPLE_PROCESSOR_CONFIG, os.path.join(tmpdirname, FEATURE_EXTRACTOR_NAME))
+            copyfile(SAMPLE_VOCAB, os.path.join(tmpdirname, "vocab.json"))
+
+            processor = AutoProcessor.from_pretrained(tmpdirname)
+
+            self.assertIsInstance(processor, Wav2Vec2Processor)
diff --git a/tests/test_processor_wav2vec2_with_lm.py b/tests/test_processor_wav2vec2_with_lm.py
index 155e09a22eb..14e76d38fd9 100644
--- a/tests/test_processor_wav2vec2_with_lm.py
+++ b/tests/test_processor_wav2vec2_with_lm.py
@@ -31,7 +31,7 @@ from .test_feature_extraction_wav2vec2 import floats_list
 if is_pyctcdecode_available():
     from pyctcdecode import BeamSearchDecoderCTC
 
-    from transformers.models.wav2vec2 import Wav2Vec2ProcessorWithLM
+    from transformers.models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
 
 
 @require_pyctcdecode
diff --git a/utils/check_inits.py b/utils/check_inits.py
index 8cfbfc18a4a..4b1dc574b53 100644
--- a/utils/check_inits.py
+++ b/utils/check_inits.py
@@ -27,6 +27,8 @@ _re_backend = re.compile(r"is\_([a-z_]*)_available()")
 _re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]')
 # Catches a line if is_foo_available
 _re_test_backend = re.compile(r"^\s*if\s+is\_[a-z_]*\_available\(\)")
+# Catches a line _import_struct["bla"] = ["foo"]
+_re_import_struct_equal_one = re.compile(r'^\s*_import_structure\["\S*"\]\ = \["(\S*)"\]')
 # Catches a line _import_struct["bla"].append("foo")
 _re_import_struct_add_one = re.compile(r'^\s*_import_structure\["\S*"\]\.append\("(\S*)"\)')
 # Catches a line _import_struct["bla"].extend(["foo", "bar"]) or _import_struct["bla"] = ["foo", "bar"]
@@ -88,7 +90,9 @@ def parse_init(init_file):
         # Until we unindent, add backend objects to the list
         while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 4):
             line = lines[line_index]
-            if _re_import_struct_add_one.search(line) is not None:
+            if _re_import_struct_equal_one.search(line) is not None:
+                objects.append(_re_import_struct_equal_one.search(line).groups()[0])
+            elif _re_import_struct_add_one.search(line) is not None:
                 objects.append(_re_import_struct_add_one.search(line).groups()[0])
             elif _re_import_struct_add_many.search(line) is not None:
                 imports = _re_import_struct_add_many.search(line).groups()[0].split(", ")
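
Note on the `check_inits.py` hunk above: the new `wav2vec2_with_lm/__init__.py` registers its object with a plain assignment (`_import_structure["..."] = ["..."]`) rather than `.append(...)` or `.extend([...])`, which is exactly the line shape the added `_re_import_struct_equal_one` pattern must catch. A minimal sanity check of that pattern (an illustrative sketch, not part of the patch):

import re

# Pattern added in utils/check_inits.py above:
# catches a line `_import_structure["bla"] = ["foo"]`.
_re_import_struct_equal_one = re.compile(r'^\s*_import_structure\["\S*"\]\ = \["(\S*)"\]')

# The assignment style introduced by the new wav2vec2_with_lm init,
# which the existing .append(...) / .extend([...]) regexes do not match.
line = '    _import_structure["processing_wav2vec2_with_lm"] = ["Wav2Vec2ProcessorWithLM"]'

match = _re_import_struct_equal_one.search(line)
assert match is not None
print(match.groups()[0])  # -> Wav2Vec2ProcessorWithLM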