mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
[AutoProcessor] Add Wav2Vec2WithLM & small fix (#14675)
* [AutoProcessor] Add Wav2Vec2WithLM & small fix * revert line removal * Update src/transformers/__init__.py * add test * up * up * small fix
This commit is contained in:
parent
2294071a0c
commit
ee4fa2e465
@ -312,6 +312,7 @@ _import_structure = {
|
|||||||
"Wav2Vec2Processor",
|
"Wav2Vec2Processor",
|
||||||
"Wav2Vec2Tokenizer",
|
"Wav2Vec2Tokenizer",
|
||||||
],
|
],
|
||||||
|
"models.wav2vec2_with_lm": [],
|
||||||
"models.xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMTokenizer"],
|
"models.xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMTokenizer"],
|
||||||
"models.xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"],
|
"models.xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"],
|
||||||
"models.xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"],
|
"models.xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"],
|
||||||
@ -474,7 +475,7 @@ else:
|
|||||||
]
|
]
|
||||||
|
|
||||||
if is_pyctcdecode_available():
|
if is_pyctcdecode_available():
|
||||||
_import_structure["models.wav2vec2"].append("Wav2Vec2ProcessorWithLM")
|
_import_structure["models.wav2vec2_with_lm"].append("Wav2Vec2ProcessorWithLM")
|
||||||
else:
|
else:
|
||||||
from .utils import dummy_pyctcdecode_objects
|
from .utils import dummy_pyctcdecode_objects
|
||||||
|
|
||||||
@ -2470,7 +2471,7 @@ if TYPE_CHECKING:
|
|||||||
from .utils.dummy_speech_objects import *
|
from .utils.dummy_speech_objects import *
|
||||||
|
|
||||||
if is_pyctcdecode_available():
|
if is_pyctcdecode_available():
|
||||||
from .models.wav2vec2 import Wav2Vec2ProcessorWithLM
|
from .models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
|
||||||
else:
|
else:
|
||||||
from .utils.dummy_pyctcdecode_objects import *
|
from .utils.dummy_pyctcdecode_objects import *
|
||||||
|
|
||||||
|
@ -107,6 +107,7 @@ from . import (
|
|||||||
visual_bert,
|
visual_bert,
|
||||||
vit,
|
vit,
|
||||||
wav2vec2,
|
wav2vec2,
|
||||||
|
wav2vec2_with_lm,
|
||||||
xlm,
|
xlm,
|
||||||
xlm_prophetnet,
|
xlm_prophetnet,
|
||||||
xlm_roberta,
|
xlm_roberta,
|
||||||
|
@ -39,6 +39,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
|
|||||||
("speech_to_text_2", "Speech2Text2Processor"),
|
("speech_to_text_2", "Speech2Text2Processor"),
|
||||||
("trocr", "TrOCRProcessor"),
|
("trocr", "TrOCRProcessor"),
|
||||||
("wav2vec2", "Wav2Vec2Processor"),
|
("wav2vec2", "Wav2Vec2Processor"),
|
||||||
|
("wav2vec2_with_lm", "Wav2Vec2ProcessorWithLM"),
|
||||||
("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"),
|
("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@ -145,6 +146,9 @@ class AutoProcessor:
|
|||||||
key: kwargs[key] for key in ["revision", "use_auth_token", "local_files_only"] if key in kwargs
|
key: kwargs[key] for key in ["revision", "use_auth_token", "local_files_only"] if key in kwargs
|
||||||
}
|
}
|
||||||
model_files = get_list_of_files(pretrained_model_name_or_path, **get_list_of_files_kwargs)
|
model_files = get_list_of_files(pretrained_model_name_or_path, **get_list_of_files_kwargs)
|
||||||
|
# strip to file name
|
||||||
|
model_files = [f.split("/")[-1] for f in model_files]
|
||||||
|
|
||||||
if FEATURE_EXTRACTOR_NAME in model_files:
|
if FEATURE_EXTRACTOR_NAME in model_files:
|
||||||
config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
|
config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
|
||||||
if "processor_class" in config_dict:
|
if "processor_class" in config_dict:
|
||||||
|
@ -17,7 +17,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from ...file_utils import _LazyModule, is_flax_available, is_pyctcdecode_available, is_tf_available, is_torch_available
|
from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
|
||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
@ -27,8 +27,6 @@ _import_structure = {
|
|||||||
"tokenization_wav2vec2": ["Wav2Vec2CTCTokenizer", "Wav2Vec2Tokenizer"],
|
"tokenization_wav2vec2": ["Wav2Vec2CTCTokenizer", "Wav2Vec2Tokenizer"],
|
||||||
}
|
}
|
||||||
|
|
||||||
if is_pyctcdecode_available():
|
|
||||||
_import_structure["processing_wav2vec2_with_lm"] = ["Wav2Vec2ProcessorWithLM"]
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
_import_structure["modeling_wav2vec2"] = [
|
_import_structure["modeling_wav2vec2"] = [
|
||||||
@ -64,9 +62,6 @@ if TYPE_CHECKING:
|
|||||||
from .processing_wav2vec2 import Wav2Vec2Processor
|
from .processing_wav2vec2 import Wav2Vec2Processor
|
||||||
from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2Tokenizer
|
from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2Tokenizer
|
||||||
|
|
||||||
if is_pyctcdecode_available():
|
|
||||||
from .processing_wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from .modeling_wav2vec2 import (
|
from .modeling_wav2vec2 import (
|
||||||
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
36
src/transformers/models/wav2vec2_with_lm/__init__.py
Normal file
36
src/transformers/models/wav2vec2_with_lm/__init__.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
# flake8: noqa
|
||||||
|
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||||
|
# module, but to preserve other warnings. So, don't check this module at all.
|
||||||
|
|
||||||
|
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from ...file_utils import _LazyModule, is_pyctcdecode_available
|
||||||
|
|
||||||
|
|
||||||
|
_import_structure = {}
|
||||||
|
|
||||||
|
|
||||||
|
if is_pyctcdecode_available():
|
||||||
|
_import_structure["processing_wav2vec2_with_lm"] = ["Wav2Vec2ProcessorWithLM"]
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
if is_pyctcdecode_available():
|
||||||
|
from .processing_wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
|
||||||
|
else:
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
|
@ -35,8 +35,8 @@ from pyctcdecode.constants import (
|
|||||||
from ...feature_extraction_utils import FeatureExtractionMixin
|
from ...feature_extraction_utils import FeatureExtractionMixin
|
||||||
from ...file_utils import ModelOutput, requires_backends
|
from ...file_utils import ModelOutput, requires_backends
|
||||||
from ...tokenization_utils import PreTrainedTokenizer
|
from ...tokenization_utils import PreTrainedTokenizer
|
||||||
from .feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
|
from ..wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
|
||||||
from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer
|
from ..wav2vec2.tokenization_wav2vec2 import Wav2Vec2CTCTokenizer
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -159,6 +159,9 @@ class Wav2Vec2ProcessorWithLM:
|
|||||||
if os.path.isdir(pretrained_model_name_or_path):
|
if os.path.isdir(pretrained_model_name_or_path):
|
||||||
decoder = BeamSearchDecoderCTC.load_from_dir(pretrained_model_name_or_path)
|
decoder = BeamSearchDecoderCTC.load_from_dir(pretrained_model_name_or_path)
|
||||||
else:
|
else:
|
||||||
|
# BeamSearchDecoderCTC has no auto class
|
||||||
|
kwargs.pop("_from_auto", None)
|
||||||
|
|
||||||
decoder = BeamSearchDecoderCTC.load_from_hf_hub(pretrained_model_name_or_path, **kwargs)
|
decoder = BeamSearchDecoderCTC.load_from_hf_hub(pretrained_model_name_or_path, **kwargs)
|
||||||
|
|
||||||
# set language model attributes
|
# set language model attributes
|
@ -1,3 +1,4 @@
|
|||||||
{
|
{
|
||||||
"feature_extractor_type": "Wav2Vec2FeatureExtractor"
|
"feature_extractor_type": "Wav2Vec2FeatureExtractor",
|
||||||
|
"processor_class": "Wav2Vec2Processor"
|
||||||
}
|
}
|
||||||
|
@ -16,15 +16,16 @@
|
|||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
from shutil import copyfile
|
||||||
|
|
||||||
from transformers import AutoProcessor, Wav2Vec2Config, Wav2Vec2Processor
|
from transformers import AutoProcessor, Wav2Vec2Config, Wav2Vec2Processor
|
||||||
|
from transformers.file_utils import FEATURE_EXTRACTOR_NAME
|
||||||
|
|
||||||
|
|
||||||
SAMPLE_PROCESSOR_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
|
|
||||||
SAMPLE_PROCESSOR_CONFIG = os.path.join(
|
SAMPLE_PROCESSOR_CONFIG = os.path.join(
|
||||||
os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy_feature_extractor_config.json"
|
os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy_feature_extractor_config.json"
|
||||||
)
|
)
|
||||||
SAMPLE_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json")
|
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/vocab.json")
|
||||||
|
|
||||||
|
|
||||||
class AutoFeatureExtractorTest(unittest.TestCase):
|
class AutoFeatureExtractorTest(unittest.TestCase):
|
||||||
@ -32,7 +33,7 @@ class AutoFeatureExtractorTest(unittest.TestCase):
|
|||||||
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
|
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
|
||||||
self.assertIsInstance(processor, Wav2Vec2Processor)
|
self.assertIsInstance(processor, Wav2Vec2Processor)
|
||||||
|
|
||||||
def test_processor_from_local_directory_from_config(self):
|
def test_processor_from_local_directory_from_repo(self):
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
model_config = Wav2Vec2Config()
|
model_config = Wav2Vec2Config()
|
||||||
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
|
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
|
||||||
@ -44,3 +45,13 @@ class AutoFeatureExtractorTest(unittest.TestCase):
|
|||||||
processor = AutoProcessor.from_pretrained(tmpdirname)
|
processor = AutoProcessor.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
self.assertIsInstance(processor, Wav2Vec2Processor)
|
self.assertIsInstance(processor, Wav2Vec2Processor)
|
||||||
|
|
||||||
|
def test_processor_from_local_directory_from_extractor_config(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
# copy relevant files
|
||||||
|
copyfile(SAMPLE_PROCESSOR_CONFIG, os.path.join(tmpdirname, FEATURE_EXTRACTOR_NAME))
|
||||||
|
copyfile(SAMPLE_VOCAB, os.path.join(tmpdirname, "vocab.json"))
|
||||||
|
|
||||||
|
processor = AutoProcessor.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
self.assertIsInstance(processor, Wav2Vec2Processor)
|
||||||
|
@ -31,7 +31,7 @@ from .test_feature_extraction_wav2vec2 import floats_list
|
|||||||
|
|
||||||
if is_pyctcdecode_available():
|
if is_pyctcdecode_available():
|
||||||
from pyctcdecode import BeamSearchDecoderCTC
|
from pyctcdecode import BeamSearchDecoderCTC
|
||||||
from transformers.models.wav2vec2 import Wav2Vec2ProcessorWithLM
|
from transformers.models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
|
||||||
|
|
||||||
|
|
||||||
@require_pyctcdecode
|
@require_pyctcdecode
|
||||||
|
@ -27,6 +27,8 @@ _re_backend = re.compile(r"is\_([a-z_]*)_available()")
|
|||||||
_re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]')
|
_re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]')
|
||||||
# Catches a line if is_foo_available
|
# Catches a line if is_foo_available
|
||||||
_re_test_backend = re.compile(r"^\s*if\s+is\_[a-z_]*\_available\(\)")
|
_re_test_backend = re.compile(r"^\s*if\s+is\_[a-z_]*\_available\(\)")
|
||||||
|
# Catches a line _import_struct["bla"] = ["foo"]
|
||||||
|
_re_import_struct_equal_one = re.compile(r'^\s*_import_structure\["\S*"\]\ = "\[(\S*)\]"')
|
||||||
# Catches a line _import_struct["bla"].append("foo")
|
# Catches a line _import_struct["bla"].append("foo")
|
||||||
_re_import_struct_add_one = re.compile(r'^\s*_import_structure\["\S*"\]\.append\("(\S*)"\)')
|
_re_import_struct_add_one = re.compile(r'^\s*_import_structure\["\S*"\]\.append\("(\S*)"\)')
|
||||||
# Catches a line _import_struct["bla"].extend(["foo", "bar"]) or _import_struct["bla"] = ["foo", "bar"]
|
# Catches a line _import_struct["bla"].extend(["foo", "bar"]) or _import_struct["bla"] = ["foo", "bar"]
|
||||||
@ -88,7 +90,9 @@ def parse_init(init_file):
|
|||||||
# Until we unindent, add backend objects to the list
|
# Until we unindent, add backend objects to the list
|
||||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 4):
|
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 4):
|
||||||
line = lines[line_index]
|
line = lines[line_index]
|
||||||
if _re_import_struct_add_one.search(line) is not None:
|
if _re_import_struct_equal_one.search(line) is not None:
|
||||||
|
objects.append(_re_import_struct_equal_one.search(line).groups()[0])
|
||||||
|
elif _re_import_struct_add_one.search(line) is not None:
|
||||||
objects.append(_re_import_struct_add_one.search(line).groups()[0])
|
objects.append(_re_import_struct_add_one.search(line).groups()[0])
|
||||||
elif _re_import_struct_add_many.search(line) is not None:
|
elif _re_import_struct_add_many.search(line) is not None:
|
||||||
imports = _re_import_struct_add_many.search(line).groups()[0].split(", ")
|
imports = _re_import_struct_add_many.search(line).groups()[0].split(", ")
|
||||||
|
Loading…
Reference in New Issue
Block a user