* add raw scaffold

* implement feat extract layers

* make style

* remove +

* correctly convert weights

* make feat extractor work

* make feature extraction proj work

* run forward pass

* finish forward pass

* Successful decoding example

* remove unused files

* more changes

* add wav2vec tokenizer

* add new structure

* fix run forward

* add other layer norm architecture

* finish 2nd structure

* add model tests

* finish tests for tok and model

* clean-up

* make style

* finish docstring for model and config

* make style

* correct docstring

* correct tests

* change checkpoints to fairseq

* fix examples

* finish wav2vec2

* make style

* apply Sylvain's suggestions

* apply Lysandre's suggestions

* change print to log.info

* re-add assert statement

* add input_values as required input name

* finish wav2vec2 tokenizer

* Update tests/test_tokenization_wav2vec2.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* apply Sylvain's suggestions

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Author: Patrick von Platen, 2021-02-02 15:52:10 +03:00 (committed by GitHub)
Parent: d996024af7
Commit: d6217fb30c
20 changed files with 2233 additions and 5 deletions


@ -228,6 +228,7 @@ ultilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/
1. **[T5](https://huggingface.co/transformers/model_doc/t5.html)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
1. **[TAPAS](https://huggingface.co/transformers/model_doc/tapas.html)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos.
1. **[Transformer-XL](https://huggingface.co/transformers/model_doc/transformerxl.html)** (from Google/CMU) released with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
1. **[Wav2Vec2](https://huggingface.co/transformers/model_doc/wav2vec2.html)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
1. **[XLM](https://huggingface.co/transformers/model_doc/xlm.html)** (from Facebook) released together with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau.
1. **[XLM-ProphetNet](https://huggingface.co/transformers/model_doc/xlmprophetnet.html)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
1. **[XLM-RoBERTa](https://huggingface.co/transformers/model_doc/xlmroberta.html)** (from Facebook AI), released together with the paper [Unsupervised Cross-lingual Representation Learning at Scale](https://arxiv.org/abs/1911.02116) by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov.


@ -192,16 +192,19 @@ and conversion utilities for the following models:
36. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
Attentive Language Models Beyond a Fixed-Length Context <https://arxiv.org/abs/1901.02860>`__ by Zihang Dai*,
Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
37. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
Self-Supervised Learning of Speech Representations <https://arxiv.org/abs/2006.11477>`__ by Alexei Baevski, Henry
Zhou, Abdelrahman Mohamed, Michael Auli.
38. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
Pretraining <https://arxiv.org/abs/1901.07291>`__ by Guillaume Lample and Alexis Conneau.
39. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
Predicting Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan,
Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
40. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
Cross-lingual Representation Learning at Scale <https://arxiv.org/abs/1911.02116>`__ by Alexis Conneau*, Kartikay
Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke
Zettlemoyer and Veselin Stoyanov.
41. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
Pretraining for Language Understanding <https://arxiv.org/abs/1906.08237>`__ by Zhilin Yang*, Zihang Dai*, Yiming
Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
@ -292,6 +295,8 @@ TensorFlow and/or Flax.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Wav2Vec2 | ✅ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| XLM | ✅ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| XLM-RoBERTa | ✅ | ✅ | ✅ | ✅ | ❌ |
@ -414,6 +419,7 @@ TensorFlow and/or Flax.
model_doc/t5
model_doc/tapas
model_doc/transformerxl
model_doc/wav2vec2
model_doc/xlm
model_doc/xlmprophetnet
model_doc/xlmroberta


@ -0,0 +1,65 @@
..
Copyright 2021 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
Wav2Vec2
-----------------------------------------------------------------------------------------------------------------------
Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The Wav2Vec2 model was proposed in `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations
<https://arxiv.org/abs/2006.11477>`__ by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
The abstract from the paper is the following:
*We show for the first time that learning powerful representations from speech audio alone followed by fine-tuning on
transcribed speech can outperform the best semi-supervised methods while being conceptually simpler. wav2vec 2.0 masks
the speech input in the latent space and solves a contrastive task defined over a quantization of the latent
representations which are jointly learned. Experiments using all labeled data of Librispeech achieve 1.8/3.3 WER on the
clean/other test sets. When lowering the amount of labeled data to one hour, wav2vec 2.0 outperforms the previous state
of the art on the 100 hour subset while using 100 times less labeled data. Using just ten minutes of labeled data and
pre-training on 53k hours of unlabeled data still achieves 4.8/8.2 WER. This demonstrates the feasibility of speech
recognition with limited amounts of labeled data.*
Tips:
- Wav2Vec2 is a speech model that accepts a float array corresponding to the raw waveform of the speech signal.
- Wav2Vec2 models were trained using connectionist temporal classification (CTC), so the model output has to be
decoded using :class:`~transformers.Wav2Vec2Tokenizer` (see the inference sketch below).
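A minimal inference sketch (illustrative only: it assumes a mono 16 kHz waveform already loaded into a float array
called ``speech``)::

    >>> import torch
    >>> from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForMaskedLM

    >>> tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
    >>> model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h")

    >>> input_values = tokenizer(speech, return_tensors="pt").input_values  # float waveform features
    >>> logits = model(input_values).logits                                 # CTC logits over the character vocabulary
    >>> predicted_ids = torch.argmax(logits, dim=-1)
    >>> transcription = tokenizer.batch_decode(predicted_ids)               # greedy CTC decoding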
Wav2Vec2Config
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.Wav2Vec2Config
:members:
Wav2Vec2Tokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.Wav2Vec2Tokenizer
:members: __call__, save_vocabulary
Wav2Vec2Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.Wav2Vec2Model
:members: forward
Wav2Vec2ForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.Wav2Vec2ForMaskedLM
:members: forward


@ -123,6 +123,7 @@ _deps = [
"sacremoses", "sacremoses",
"scikit-learn", "scikit-learn",
"sentencepiece==0.1.91", "sentencepiece==0.1.91",
"soundfile",
"sphinx-copybutton", "sphinx-copybutton",
"sphinx-markdown-tables", "sphinx-markdown-tables",
"sphinx-rtd-theme==0.4.3", # sphinx-rtd-theme==0.5.0 introduced big changes in the style. "sphinx-rtd-theme==0.4.3", # sphinx-rtd-theme==0.5.0 introduced big changes in the style.
@ -226,12 +227,14 @@ extras["onnxruntime"] = deps_list("onnxruntime", "onnxruntime-tools")
extras["modelcreation"] = deps_list("cookiecutter") extras["modelcreation"] = deps_list("cookiecutter")
extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette") extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
extras["speech"] = deps_list("soundfile")
extras["sentencepiece"] = deps_list("sentencepiece", "protobuf") extras["sentencepiece"] = deps_list("sentencepiece", "protobuf")
extras["testing"] = ( extras["testing"] = (
deps_list("pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil") deps_list("pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets")
+ extras["retrieval"] + extras["retrieval"]
+ extras["modelcreation"] + extras["modelcreation"]
+ extras["speech"]
) )
extras["docs"] = deps_list("recommonmark", "sphinx", "sphinx-markdown-tables", "sphinx-rtd-theme", "sphinx-copybutton") extras["docs"] = deps_list("recommonmark", "sphinx", "sphinx-markdown-tables", "sphinx-rtd-theme", "sphinx-copybutton")
extras["quality"] = deps_list("black", "isort", "flake8") extras["quality"] = deps_list("black", "isort", "flake8")


@ -125,6 +125,7 @@ _import_structure = {
],
"models": [],
# Models
"models.wav2vec2": ["WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Wav2Vec2Config", "Wav2Vec2Tokenizer"],
"models.convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig", "ConvBertTokenizer"],
"models.albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig"],
"models.auto": [
@ -363,6 +364,14 @@ if is_torch_available():
_import_structure["modeling_utils"] = ["Conv1D", "PreTrainedModel", "apply_chunking_to_forward", "prune_layer"] _import_structure["modeling_utils"] = ["Conv1D", "PreTrainedModel", "apply_chunking_to_forward", "prune_layer"]
# PyTorch models structure # PyTorch models structure
_import_structure["models.wav2vec2"].extend(
[
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"Wav2Vec2ForMaskedLM",
"Wav2Vec2Model",
"Wav2Vec2PreTrainedModel",
]
)
_import_structure["models.convbert"].extend( _import_structure["models.convbert"].extend(
[ [
"CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST", "CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@ -1312,6 +1321,7 @@ if TYPE_CHECKING:
TransfoXLCorpus,
TransfoXLTokenizer,
)
from .models.wav2vec2 import WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config, Wav2Vec2Tokenizer
from .models.xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig, XLMTokenizer
from .models.xlm_prophetnet import XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMProphetNetConfig
from .models.xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
@ -1791,6 +1801,12 @@ if TYPE_CHECKING:
TransfoXLPreTrainedModel,
load_tf_weights_in_transfo_xl,
)
from .models.wav2vec2 import (
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
Wav2Vec2ForMaskedLM,
Wav2Vec2Model,
Wav2Vec2PreTrainedModel,
)
from .models.xlm import (
XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
XLMForMultipleChoice,


@ -36,6 +36,7 @@ deps = {
"sacremoses": "sacremoses", "sacremoses": "sacremoses",
"scikit-learn": "scikit-learn", "scikit-learn": "scikit-learn",
"sentencepiece": "sentencepiece==0.1.91", "sentencepiece": "sentencepiece==0.1.91",
"soundfile": "soundfile",
"sphinx-copybutton": "sphinx-copybutton", "sphinx-copybutton": "sphinx-copybutton",
"sphinx-markdown-tables": "sphinx-markdown-tables", "sphinx-markdown-tables": "sphinx-markdown-tables",
"sphinx-rtd-theme": "sphinx-rtd-theme==0.4.3", "sphinx-rtd-theme": "sphinx-rtd-theme==0.4.3",


@ -155,6 +155,14 @@ except importlib_metadata.PackageNotFoundError:
_scatter_available = False
_soundfile_available = importlib.util.find_spec("soundfile") is not None
try:
_soundfile_version = importlib_metadata.version("soundfile")
logger.debug(f"Successfully imported soundfile version {_soundfile_version}")
except importlib_metadata.PackageNotFoundError:
_soundfile_available = False
torch_cache_home = os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
old_default_cache_path = os.path.join(torch_cache_home, "transformers")
# New default cache, shared with the Datasets library
@ -311,6 +319,10 @@ def is_sagemaker_distributed_available():
return importlib.util.find_spec("smdistributed") is not None return importlib.util.find_spec("smdistributed") is not None
def is_soundfile_availble():
return _soundfile_available
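# Downstream code can then guard optional audio I/O on this check, e.g. (illustrative sketch):
#
#     if is_soundfile_availble():
#         import soundfile as sf
#
#         speech, sampling_rate = sf.read("sample.wav")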
def torch_only_method(fn):
def wrapper(*args, **kwargs):
if not _torch_available:


@ -63,6 +63,7 @@ from . import (
t5,
tapas,
transfo_xl,
wav2vec2,
xlm,
xlm_roberta,
xlnet,


@ -59,6 +59,7 @@ from ..squeezebert.configuration_squeezebert import SQUEEZEBERT_PRETRAINED_CONFI
from ..t5.configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from ..tapas.configuration_tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig
from ..transfo_xl.configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
from ..wav2vec2.configuration_wav2vec2 import WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config
from ..xlm.configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
from ..xlm_prophetnet.configuration_xlm_prophetnet import (
XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
@ -72,6 +73,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
(key, value)
for pretrained_map in [
# Add archive maps here
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
LED_PRETRAINED_CONFIG_ARCHIVE_MAP,
BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
@ -114,6 +116,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
CONFIG_MAPPING = OrderedDict(
[
# Add configs here
("wav2vec2", Wav2Vec2Config),
("convbert", ConvBertConfig),
("led", LEDConfig),
("blenderbot-small", BlenderbotSmallConfig),
@ -162,6 +165,7 @@ CONFIG_MAPPING = OrderedDict(
MODEL_NAMES_MAPPING = OrderedDict(
[
# Add full (and cased) model names here
("wav2vec2", "Wav2Vec2"),
("convbert", "ConvBERT"),
("led", "LED"),
("blenderbot-small", "BlenderbotSmall"),


@ -205,6 +205,7 @@ from ..tapas.modeling_tapas import (
TapasModel,
)
from ..transfo_xl.modeling_transfo_xl import TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel
from ..wav2vec2.modeling_wav2vec2 import Wav2Vec2ForMaskedLM, Wav2Vec2Model
from ..xlm.modeling_xlm import (
XLMForMultipleChoice,
XLMForQuestionAnsweringSimple,
@ -274,6 +275,7 @@ from .configuration_auto import (
T5Config,
TapasConfig,
TransfoXLConfig,
Wav2Vec2Config,
XLMConfig,
XLMProphetNetConfig,
XLMRobertaConfig,
@ -288,6 +290,7 @@ logger = logging.get_logger(__name__)
MODEL_MAPPING = OrderedDict(
[
# Base model mapping
(Wav2Vec2Config, Wav2Vec2Model),
(ConvBertConfig, ConvBertModel),
(LEDConfig, LEDModel),
(BlenderbotSmallConfig, BlenderbotSmallModel),
@ -367,6 +370,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
[
# Model with LM heads mapping
(Wav2Vec2Config, Wav2Vec2ForMaskedLM),
(ConvBertConfig, ConvBertForMaskedLM),
(LEDConfig, LEDForConditionalGeneration),
(BlenderbotSmallConfig, BlenderbotSmallForConditionalGeneration),
@ -427,6 +431,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
[
# Model for Masked LM mapping
(Wav2Vec2Config, Wav2Vec2ForMaskedLM),
(ConvBertConfig, ConvBertForMaskedLM),
(LayoutLMConfig, LayoutLMForMaskedLM),
(DistilBertConfig, DistilBertForMaskedLM),


@ -52,6 +52,7 @@ from ..roberta.tokenization_roberta import RobertaTokenizer
from ..squeezebert.tokenization_squeezebert import SqueezeBertTokenizer
from ..tapas.tokenization_tapas import TapasTokenizer
from ..transfo_xl.tokenization_transfo_xl import TransfoXLTokenizer
from ..wav2vec2.tokenization_wav2vec2 import Wav2Vec2Tokenizer
from ..xlm.tokenization_xlm import XLMTokenizer
from .configuration_auto import (
AlbertConfig,
@ -93,6 +94,7 @@ from .configuration_auto import (
T5Config,
TapasConfig,
TransfoXLConfig,
Wav2Vec2Config,
XLMConfig,
XLMProphetNetConfig,
XLMRobertaConfig,
@ -238,6 +240,7 @@ TOKENIZER_MAPPING = OrderedDict(
(TapasConfig, (TapasTokenizer, None)),
(LEDConfig, (LEDTokenizer, LEDTokenizerFast)),
(ConvBertConfig, (ConvBertTokenizer, ConvBertTokenizerFast)),
(Wav2Vec2Config, (Wav2Vec2Tokenizer, None)),
]
)


@ -0,0 +1,66 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available
_import_structure = {
"configuration_wav2vec2": ["WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Wav2Vec2Config"],
"tokenization_wav2vec2": ["Wav2Vec2Tokenizer"],
}
if is_torch_available():
_import_structure["modeling_wav2vec2"] = [
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"Wav2Vec2ForMaskedLM",
"Wav2Vec2Model",
"Wav2Vec2PreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_wav2vec2 import WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config
from .tokenization_wav2vec2 import Wav2Vec2Tokenizer
if is_torch_available():
from .modeling_wav2vec2 import (
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
Wav2Vec2ForMaskedLM,
Wav2Vec2Model,
Wav2Vec2PreTrainedModel,
)
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)


@ -0,0 +1,171 @@
# coding=utf-8
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Wav2Vec2 model configuration """
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/config.json",
# See all Wav2Vec2 models at https://huggingface.co/models?filter=wav2vec2
}
class Wav2Vec2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a :class:`~transformers.Wav2Vec2Model`. It is used to
instantiate a Wav2Vec2 model according to the specified arguments, defining the model architecture. Instantiating
a configuration with the defaults will yield a similar configuration to that of the Wav2Vec2
`facebook/wav2vec2-base-960h <https://huggingface.co/facebook/wav2vec2-base-960h>`__ architecture.
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
Args:
vocab_size (:obj:`int`, `optional`, defaults to 32):
Vocabulary size of the Wav2Vec2 model. Defines the number of different tokens that can be represented by
the :obj:`inputs_ids` passed when calling :class:`~transformers.Wav2Vec2Model`.
hidden_size (:obj:`int`, `optional`, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (:obj:`int`, `optional`, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (:obj:`int`, `optional`, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout ratio for the attention probabilities.
initializer_range (:obj:`float`, `optional`, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-5):
The epsilon used by the layer normalization layers.
feat_extract_norm (:obj:`str`, `optional`, defaults to :obj:`"group"`):
The norm to be applied to 1D convolutional layers in feature extractor. One of :obj:`"group"` for group
normalization of only the first 1D convolutional layer or :obj:`"layer"` for layer normalization of all 1D
convolutional layers.
feat_extract_dropout (:obj:`float`, `optional`, defaults to 0.0):
The dropout probability for all 1D convolutional layers in the feature extractor.
feat_extract_activation (:obj:`str`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the 1D convolutional layers of the feature
extractor. If string, :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
conv_dim (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(512, 512, 512, 512, 512, 512, 512)`):
A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
feature extractor. The length of `conv_dim` defines the number of 1D convolutional layers.
conv_stride (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(5, 2, 2, 2, 2, 2, 2)`):
A tuple of integers defining the stride of each 1D convolutional layer in the feature extractor. The length
of `conv_stride` defines the number of convolutional layers and has to match the length of `conv_dim`.
conv_kernel (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(10, 3, 3, 3, 3, 2, 2)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the feature extractor. The
length of `conv_kernel` defines the number of convolutional layers and has to match the length of
`conv_dim`.
conv_bias (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether the 1D convolutional layers have a bias.
num_conv_pos_embeddings (:obj:`int`, `optional`, defaults to 128):
Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional
embeddings layer.
num_conv_pos_embedding_groups (:obj:`int`, `optional`, defaults to 16):
Number of groups of 1D convolutional positional embeddings layer.
do_stable_layer_norm (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to apply the `stable` layer norm architecture of the Transformer encoder. ``do_stable_layer_norm is
True`` corresponds to applying layer norm before the attention layer, whereas ``do_stable_layer_norm is
False`` corresponds to applying layer norm after the attention layer.
Example::
>>> from transformers import Wav2Vec2Model, Wav2Vec2Config
>>> # Initializing a Wav2Vec2 facebook/wav2vec2-base-960h style configuration
>>> configuration = Wav2Vec2Config()
>>> # Initializing a model from the facebook/wav2vec2-base-960h style configuration
>>> model = Wav2Vec2Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
"""
model_type = "wav2vec2"
def __init__(
self,
vocab_size=32,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1, # TODO(PVP) this is most likely not correctly set yet - correct when adding train
attention_probs_dropout_prob=0.1, # TODO(PVP) this is most likely not correctly set yet - correct when adding train
initializer_range=0.02,
layer_norm_eps=1e-5,
feat_extract_norm="group",
feat_extract_dropout=0.0,
feat_extract_activation="gelu",
conv_dim=(512, 512, 512, 512, 512, 512, 512),
conv_stride=(5, 2, 2, 2, 2, 2, 2),
conv_kernel=(10, 3, 3, 3, 3, 2, 2),
conv_bias=False,
num_conv_pos_embeddings=128,
num_conv_pos_embedding_groups=16,
do_stable_layer_norm=False,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
**kwargs
):
super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
self.hidden_size = hidden_size
self.feat_extract_norm = feat_extract_norm
self.feat_extract_dropout = feat_extract_dropout
self.feat_extract_activation = feat_extract_activation
self.conv_dim = list(conv_dim)
self.conv_stride = list(conv_stride)
self.conv_kernel = list(conv_kernel)
self.conv_bias = conv_bias
self.num_conv_pos_embeddings = num_conv_pos_embeddings
self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
self.num_feat_extract_layers = len(self.conv_dim)
self.num_hidden_layers = num_hidden_layers
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.num_attention_heads = num_attention_heads
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.layer_norm_eps = layer_norm_eps
self.initializer_range = initializer_range
self.vocab_size = vocab_size
self.do_stable_layer_norm = do_stable_layer_norm
if (
(len(self.conv_stride) != self.num_feat_extract_layers)
or (len(self.conv_kernel) != self.num_feat_extract_layers)
or (len(self.conv_dim) != self.num_feat_extract_layers)
):
raise ValueError(
"Configuration for convolutional layers is incorrect."
"It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`,"
f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride)"
f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
)
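# Illustrative sketch (not part of this module): each 1D convolutional layer shortens the sequence to
# floor((length - kernel) / stride) + 1, so the default `conv_kernel`/`conv_stride` above downsample the raw
# waveform by a factor of roughly 320, i.e. about one feature frame every 20 ms of 16 kHz audio.
#
#     def conv_output_length(input_length, kernels=(10, 3, 3, 3, 3, 2, 2), strides=(5, 2, 2, 2, 2, 2, 2)):
#         for kernel, stride in zip(kernels, strides):
#             input_length = (input_length - kernel) // stride + 1
#         return input_length
#
#     conv_output_length(16000)  # one second of 16 kHz audio -> 49 frames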


@ -0,0 +1,162 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Wav2Vec2 checkpoint."""
import argparse
import fairseq
import torch
from transformers import Wav2Vec2Config, Wav2Vec2ForMaskedLM, logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
MAPPING = {
"post_extract_proj": "wav2vec2.feature_projection.projection",
"encoder.pos_conv.0": "wav2vec2.encoder.pos_conv_embed.conv",
"self_attn.k_proj": "wav2vec2.encoder.layers.*.attention.k_proj",
"self_attn.v_proj": "wav2vec2.encoder.layers.*.attention.v_proj",
"self_attn.q_proj": "wav2vec2.encoder.layers.*.attention.q_proj",
"self_attn.out_proj": "wav2vec2.encoder.layers.*.attention.out_proj",
"self_attn_layer_norm": "wav2vec2.encoder.layers.*.layer_norm",
"fc1": "wav2vec2.encoder.layers.*.feed_forward.intermediate_dense",
"fc2": "wav2vec2.encoder.layers.*.feed_forward.output_dense",
"final_layer_norm": "wav2vec2.encoder.layers.*.final_layer_norm",
"encoder.layer_norm": "wav2vec2.encoder.layer_norm",
"w2v_model.layer_norm": "wav2vec2.feature_projection.layer_norm",
"w2v_encoder.proj": "lm_head",
}
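# Worked example (illustrative): a fairseq parameter such as "encoder.layers.3.self_attn.k_proj.weight" matches the
# "self_attn.k_proj" entry above; `recursively_load_weights` below parses the layer index from the prefix
# ("encoder.layers.3." -> "3"), substitutes it for the "*", and copies the tensor into
# "wav2vec2.encoder.layers.3.attention.k_proj" with weight_type "weight".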
def set_recursively(hf_pointer, key, value, full_name, weight_type):
for attribute in key.split("."):
hf_pointer = getattr(hf_pointer, attribute)
hf_shape = getattr(hf_pointer, weight_type).shape
assert (
hf_shape == value.shape
), f"Shape of hf {key + '.' + weight_type} is {hf_shape}, but should be {value.shape} for {full_name}"
if weight_type == "weight":
hf_pointer.weight.data = value
elif weight_type == "weight_g":
hf_pointer.weight_g.data = value
elif weight_type == "weight_v":
hf_pointer.weight_v.data = value
elif weight_type == "bias":
hf_pointer.bias.data = value
logger.info(f"{key + '.' + weight_type} was initialized from {full_name}.")
def recursively_load_weights(fairseq_model, hf_model):
unused_weights = []
fairseq_dict = fairseq_model.state_dict()
for name, value in fairseq_dict.items():
is_used = False
if "conv_layers" in name:
load_conv_layer(
name,
value,
hf_model.wav2vec2.feature_extractor,
unused_weights,
hf_model.config.feat_extract_norm == "group",
)
is_used = True
else:
for key, mapped_key in MAPPING.items():
if key in name:
is_used = True
if "*" in mapped_key:
layer_index = name.split(key)[0].split(".")[-2]
mapped_key = mapped_key.replace("*", layer_index)
if "weight_g" in name:
weight_type = "weight_g"
elif "weight_v" in name:
weight_type = "weight_v"
elif "weight" in name:
weight_type = "weight"
elif "bias" in name:
weight_type = "bias"
set_recursively(hf_model, mapped_key, value, name, weight_type)
continue
if not is_used:
unused_weights.append(name)
logger.info("Unused weights", unused_weights)
def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm):
name = full_name.split("conv_layers.")[-1]
items = name.split(".")
layer_id = int(items[0])
type_id = int(items[1])
if type_id == 0:
if "bias" in name:
assert (
value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape
), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
feature_extractor.conv_layers[layer_id].conv.bias.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
assert (
value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape
), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
feature_extractor.conv_layers[layer_id].conv.weight.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
if "bias" in name:
assert (
value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape
), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was found."
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
assert (
value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape
), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
else:
unused_weights.append(full_name)
@torch.no_grad()
def convert_wav2vec2_checkpoint(checkpoint_path, pytorch_dump_folder_path, dict_path=None):
"""
Copy/paste/tweak model's weights to transformers design.
"""
hf_wav2vec = Wav2Vec2ForMaskedLM(Wav2Vec2Config())
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
[checkpoint_path], arg_overrides={"data": dict_path}
)
model = model[0].eval()
recursively_load_weights(model, hf_wav2vec)
hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model")
args = parser.parse_args()
convert_wav2vec2_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.dict_path)
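# Example invocation (illustrative paths; adjust the script name to wherever this file is saved):
#
#     python convert_wav2vec2_checkpoint.py \
#         --checkpoint_path /path/to/fairseq/wav2vec2_checkpoint.pt \
#         --dict_path /path/to/fairseq/dict.ltr.txt \
#         --pytorch_dump_folder_path ./wav2vec2-converted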


@ -0,0 +1,731 @@
# coding=utf-8
# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Wav2Vec2 model. """
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...modeling_outputs import BaseModelOutput, MaskedLMOutput
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_wav2vec2 import Wav2Vec2Config
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "Wav2Vec2Config"
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/wav2vec2-base-960h"
# See all Wav2Vec2 models at https://huggingface.co/models?filter=wav2vec2
]
class Wav2Vec2NoLayerNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.dropout = nn.Dropout(config.feat_extract_dropout)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
class Wav2Vec2LayerNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.dropout = nn.Dropout(config.feat_extract_dropout)
self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = hidden_states.transpose(-2, -1)
hidden_states = self.layer_norm(hidden_states)
hidden_states = hidden_states.transpose(-2, -1)
hidden_states = self.activation(hidden_states)
return hidden_states
class Wav2Vec2GroupNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
self.out_conv_dim = config.conv_dim[layer_id]
self.conv = nn.Conv1d(
self.in_conv_dim,
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.dropout = nn.Dropout(config.feat_extract_dropout)
self.activation = ACT2FN[config.feat_extract_activation]
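# num_groups equals num_channels below, so every channel is normalized independently over the time axis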
self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
class Wav2Vec2PositionalConvEmbedding(nn.Module):
def __init__(self, config):
super().__init__()
self.conv = nn.Conv1d(
config.hidden_size,
config.hidden_size,
kernel_size=config.num_conv_pos_embeddings,
padding=config.num_conv_pos_embeddings // 2,
groups=config.num_conv_pos_embedding_groups,
)
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
self.padding = Wav2Vec2SamePadLayer(config.num_conv_pos_embeddings)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
hidden_states = hidden_states.transpose(1, 2)
hidden_states = self.conv(hidden_states)
hidden_states = self.padding(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = hidden_states.transpose(1, 2)
return hidden_states
class Wav2Vec2SamePadLayer(nn.Module):
def __init__(self, num_conv_pos_embeddings):
super().__init__()
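# the positional conv uses padding = num_conv_pos_embeddings // 2, so an even kernel size yields one extra
# output frame; forward() trims it so the sequence length stays unchanged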
self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
def forward(self, hidden_states):
if self.num_pad_remove > 0:
hidden_states = hidden_states[:, :, : -self.num_pad_remove]
return hidden_states
class Wav2Vec2FeatureExtractor(nn.Module):
"""Construct the featurs from raw audio waveform"""
def __init__(self, config):
super().__init__()
if config.feat_extract_norm == "group":
conv_layers = [Wav2Vec2GroupNormConvLayer(config, layer_id=0)] + [
Wav2Vec2NoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1)
]
elif config.feat_extract_norm == "layer":
conv_layers = [
Wav2Vec2LayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
]
else:
raise ValueError(
f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
)
self.conv_layers = nn.ModuleList(conv_layers)
def forward(self, input_values):
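# input_values: raw waveform of shape (batch_size, sequence_length); add a channel dimension so that the
# 1D convolutions below operate on (batch_size, 1, sequence_length)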
hidden_states = input_values[:, None]
for conv_layer in self.conv_layers:
hidden_states = conv_layer(hidden_states)
return hidden_states
class Wav2Vec2FeatureProjection(nn.Module):
def __init__(self, config):
super().__init__()
self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
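# project the last convolutional channel dimension (conv_dim[-1], 512 by default) to the Transformer
# hidden size (768 by default)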
self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
self.dropout = nn.Dropout(config.feat_extract_dropout)
def forward(self, hidden_states):
hidden_states = hidden_states.transpose(1, 2)
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.projection(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Wav2Vec2
class Wav2Vec2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})."
self.scaling = self.head_dim ** -0.5
self.is_decoder = is_decoder
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, embed_dim = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
# get key, value proj
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
assert attn_weights.size() == (
bsz * self.num_heads,
tgt_len,
src_len,
), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
if attention_mask is not None:
assert attention_mask.size() == (
bsz,
1,
tgt_len,
src_len,
), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
assert layer_head_mask.size() == (
self.num_heads,
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights has to be reshaped
# twice and has to be reused in the following
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
else:
attn_weights_reshaped = None
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
assert attn_output.size() == (
bsz * self.num_heads,
tgt_len,
self.head_dim,
), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
attn_output = (
attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
.transpose(1, 2)
.reshape(bsz, tgt_len, embed_dim)
)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped, past_key_value
class Wav2Vec2FeedForward(nn.Module):
def __init__(self, config):
super().__init__()
self.intermediate_dropout = nn.Dropout(config.hidden_dropout_prob)
self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states):
hidden_states = self.intermediate_dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
hidden_states = self.intermediate_dropout(hidden_states)
hidden_states = self.output_dense(hidden_states)
hidden_states = self.output_dropout(hidden_states)
return hidden_states
class Wav2Vec2Output(nn.Module):
def __init__(self, config):
super().__init__()
def forward(self, hidden_states, input_tensor):
return hidden_states
class Wav2Vec2EncoderLayer(nn.Module):
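# post-layer-norm variant (the config.do_stable_layer_norm is False case): layer norm is applied after the
# attention and feed-forward residual connections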
def __init__(self, config):
super().__init__()
self.attention = Wav2Vec2Attention(
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.hidden_dropout_prob,
is_decoder=False,
)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.feed_forward = Wav2Vec2FeedForward(config)
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states, output_attentions=False):
attn_residual = hidden_states
hidden_states, attn_weights, _ = self.attention(hidden_states, output_attentions=output_attentions)
hidden_states = self.dropout(hidden_states)
hidden_states = attn_residual + hidden_states
hidden_states = self.layer_norm(hidden_states)
hidden_states = hidden_states + self.feed_forward(hidden_states)
hidden_states = self.final_layer_norm(hidden_states)
return hidden_states, attn_weights
class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module):
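# "stable" pre-layer-norm variant (the config.do_stable_layer_norm is True case): layer norm is applied before
# the attention and feed-forward blocks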
def __init__(self, config):
super().__init__()
self.attention = Wav2Vec2Attention(
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.hidden_dropout_prob,
is_decoder=False,
)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.feed_forward = Wav2Vec2FeedForward(config)
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states, output_attentions=False):
attn_residual = hidden_states
hidden_states = self.layer_norm(hidden_states)
hidden_states, attn_weights, _ = self.attention(hidden_states, output_attentions=output_attentions)
hidden_states = self.dropout(hidden_states)
hidden_states = attn_residual + hidden_states
hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
return hidden_states, attn_weights
class Wav2Vec2Encoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
# NOTE: the parameter used for this dropout is probably not correct yet
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.layers = nn.ModuleList([Wav2Vec2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
def forward(
self,
hidden_states,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
position_embeddings = self.pos_conv_embed(hidden_states)
hidden_states = hidden_states + position_embeddings
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states)
for layer in self.layers:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
hidden_states, attn_weights = layer(hidden_states, output_attentions=output_attentions)
if output_attentions:
all_self_attentions = all_self_attentions + (attn_weights,)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
class Wav2Vec2EncoderStableLayerNorm(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
# NOTE: the parameter used for this dropout is probably not correct yet
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.layers = nn.ModuleList(
[Wav2Vec2EncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)]
)
def forward(
self,
hidden_states,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
position_embeddings = self.pos_conv_embed(hidden_states)
hidden_states = hidden_states + position_embeddings
hidden_states = self.dropout(hidden_states)
for layer in self.layers:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
hidden_states, attn_weights = layer(hidden_states, output_attentions=output_attentions)
if output_attentions:
all_self_attentions = all_self_attentions + (attn_weights,)
hidden_states = self.layer_norm(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
class Wav2Vec2PreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = Wav2Vec2Config
base_model_prefix = "wav2vec2"
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Conv1d):
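# 1D convolutions (feature extractor and positional convolutional embedding) use Kaiming-normal initialization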
torch.nn.init.kaiming_normal_(module.weight.data)
if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
module.bias.data.zero_()
WAV_2_VEC_2_START_DOCSTRING = r"""
Wav2Vec2 was proposed in `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations
<https://arxiv.org/abs/2006.11477>`__ by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
methods the library implements for all its models (such as downloading or saving, etc.).
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config (:class:`~transformers.Wav2Vec2Config`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
weights.
"""
WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
Args:
input_values (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`):
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
soundfile`). To prepare the array into `input_values`, the :class:`~transformers.Wav2Vec2Tokenizer` should
be used for padding and conversion into a tensor of type `torch.FloatTensor`. See
:meth:`transformers.Wav2Vec2Tokenizer.__call__` for details.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""
@add_start_docstrings(
"The bare Wav2Vec2 Model transformer outputting raw hidden-states without any specific head on top.",
WAV_2_VEC_2_START_DOCSTRING,
)
class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.config = config
self.feature_extractor = Wav2Vec2FeatureExtractor(config)
self.feature_projection = Wav2Vec2FeatureProjection(config)
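# the "stable layer norm" encoder variant applies layer norm before self-attention (pre-norm),
# while the default encoder applies it after the residual connection (see the two encoder layer classes above)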
if config.do_stable_layer_norm:
self.encoder = Wav2Vec2EncoderStableLayerNorm(config)
else:
self.encoder = Wav2Vec2Encoder(config)
self.init_weights()
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_values,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
"""
Returns:
Example::
>>> from transformers import Wav2Vec2Tokenizer, Wav2Vec2Model
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
>>> model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
>>> def map_to_array(batch):
>>> speech, _ = sf.read(batch["file"])
>>> batch["speech"] = speech
>>> return batch
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.map(map_to_array)
>>> input_values = tokenizer(ds["speech"][0], return_tensors="pt").input_values # Batch size 1
>>> hidden_states = model(input_values).last_hidden_state
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
hidden_states = self.feature_extractor(input_values)
hidden_states = self.feature_projection(hidden_states)
encoder_outputs = self.encoder(
hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = encoder_outputs[0]
if not return_dict:
return (hidden_states,) + encoder_outputs[1:]
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
@add_start_docstrings("""Wav2Vec2 Model with a `language modeling` head on top. """, WAV_2_VEC_2_START_DOCSTRING)
class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.wav2vec2 = Wav2Vec2Model(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
self.init_weights()
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_values,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
labels=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
TODO(PVP): Fill out when adding training
Returns:
Example::
>>> import torch
>>> from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForMaskedLM
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
>>> model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h")
>>> def map_to_array(batch):
>>> speech, _ = sf.read(batch["file"])
>>> batch["speech"] = speech
>>> return batch
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.map(map_to_array)
>>> input_values = tokenizer(ds["speech"][0], return_tensors="pt").input_values # Batch size 1
>>> logits = model(input_values).logits
>>> predicted_ids = torch.argmax(logits, dim=-1)
>>> transcription = tokenizer.decode(predicted_ids[0])
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.wav2vec2(
input_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
hidden_states = self.dropout(hidden_states)
logits = self.lm_head(hidden_states)
if not return_dict:
output = (logits,) + outputs[1:]
return output
return MaskedLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)


@@ -0,0 +1,282 @@
# coding=utf-8
# Copyright 2021 The Facebook Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization class for Wav2Vec2."""
import json
import os
from itertools import groupby
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from ...file_utils import add_end_docstrings
from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, TensorType
from ...utils import logging
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {
"vocab_file": "vocab.json",
"tokenizer_config_file": "tokenizer_config.json",
}
WAV2VEC2_KWARGS_DOCSTRING = r"""
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
Activates and controls padding. Accepts the following values:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
single sequence is provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
different lengths).
max_length (:obj:`int`, `optional`):
Controls the maximum length to use by one of the truncation/padding parameters.
If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum
length is required by one of the truncation/padding parameters. If the model has no specific maximum
input length (like XLNet) truncation/padding to a maximum length will be deactivated.
pad_to_multiple_of (:obj:`int`, `optional`):
If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
If set, will return tensors instead of list of python integers. Acceptable values are:
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
verbose (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to print more information and warnings.
"""
class Wav2Vec2Tokenizer(PreTrainedTokenizer):
"""
Constructs a Wav2Vec2 tokenizer.
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains some of the main methods.
Users should refer to the superclass for more information regarding such methods.
Args:
vocab_file (:obj:`str`):
File containing the vocabulary.
bos_token (:obj:`str`, `optional`, defaults to :obj:`"<s>"`):
The beginning of sentence token.
eos_token (:obj:`str`, `optional`, defaults to :obj:`"</s>"`):
The end of sentence token.
unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
pad_token (:obj:`str`, `optional`, defaults to :obj:`"<pad>"`):
The token used for padding, for example when batching sequences of different lengths.
word_delimiter_token (:obj:`str`, `optional`, defaults to :obj:`"|"`):
The token used for defining the end of a word.
**kwargs
Additional keyword arguments passed along to :class:`~transformers.PreTrainedTokenizer`
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = {
"vocab_file": {
"facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/vocab.json"
},
"tokenizer_config_file": {
"facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/tokenizer.json",
},
}
model_input_names = ["input_values"]
def __init__(
self,
vocab_file,
bos_token="<s>",
eos_token="</s>",
unk_token="<unk>",
pad_token="<pad>",
word_delimiter_token="|",
do_lower_case=False,
**kwargs
):
super().__init__(
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
do_lower_case=do_lower_case,
word_delimiter_token=word_delimiter_token,
**kwargs,
)
self._word_delimiter_token = word_delimiter_token
self.do_lower_case = do_lower_case
with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle)
self.decoder = {v: k for k, v in self.encoder.items()}
@property
def word_delimiter_token(self) -> str:
"""
:obj:`str`: Word delimiter token. Log an error if used while not having been set.
"""
if self._word_delimiter_token is None and self.verbose:
logger.error("Using word_delimiter_token, but it is not set yet.")
return None
return str(self._word_delimiter_token)
@property
def word_delimiter_token_id(self) -> Optional[int]:
"""
:obj:`Optional[int]`: Id of the word_delimiter_token in the vocabulary. Returns :obj:`None` if the token has
not been set.
"""
if self._word_delimiter_token is None:
return None
return self.convert_tokens_to_ids(self.word_delimiter_token)
@word_delimiter_token.setter
def word_delimiter_token(self, value):
self._word_delimiter_token = value
@word_delimiter_token_id.setter
def word_delimiter_token_id(self, value):
self._word_delimiter_token = self.convert_tokens_to_ids(value)
@add_end_docstrings(WAV2VEC2_KWARGS_DOCSTRING)
def __call__(
self,
raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
padding: Union[bool, str, PaddingStrategy] = False,
max_length: Optional[int] = None,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
verbose: bool = True,
**kwargs
) -> BatchEncoding:
"""
Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of
sequences.
Args:
raw_speech (:obj:`np.ndarray`, :obj:`List[float]`, :obj:`List[np.ndarray]`, :obj:`List[List[float]]`):
The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
values, a list of numpy arrays, or a list of lists of float values.
"""
is_batched = bool(
isinstance(raw_speech, (list, tuple))
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
)
# make sure input is in list format
if is_batched and not isinstance(raw_speech[0], np.ndarray):
raw_speech = [np.asarray(speech) for speech in raw_speech]
elif not is_batched and not isinstance(raw_speech, np.ndarray):
raw_speech = np.asarray(raw_speech)
# always return batch
if not is_batched:
raw_speech = [raw_speech]
# convert into correct format for padding
encoded_inputs = BatchEncoding({"input_values": raw_speech})
padded_inputs = self.pad(
encoded_inputs,
padding=padding,
max_length=max_length,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=False,
return_tensors=return_tensors,
verbose=verbose,
)
return padded_inputs
@property
def vocab_size(self) -> int:
return len(self.decoder)
def get_vocab(self) -> Dict:
return dict(self.encoder, **self.added_tokens_encoder)
def _convert_token_to_id(self, token: str) -> int:
"""Converts a token (str) in an index (integer) using the vocab."""
return self.encoder.get(token, self.encoder.get(self.unk_token))
def _convert_id_to_token(self, index: int) -> str:
"""Converts an index (integer) in a token (str) using the vocab."""
result = self.decoder.get(index, self.unk_token)
return result
def convert_tokens_to_string(self, tokens: List[str]) -> str:
"""
Converts connectionist temporal classification (CTC) output tokens into a single string.
"""
# group same tokens into non-repeating tokens in CTC style decoding
grouped_tokens = [token_group[0] for token_group in groupby(tokens)]
# filter self.pad_token which is used as CTC-blank token
filtered_tokens = list(filter(lambda token: token != self.pad_token, grouped_tokens))
# replace delimiter token
string = "".join([" " if token == self.word_delimiter_token else token for token in filtered_tokens]).strip()
if self.do_lower_case:
string = string.lower()
return string
def _decode(
self,
token_ids: List[int],
skip_special_tokens: bool = False,
clean_up_tokenization_spaces: bool = True,
) -> str:
"""
A special _decode function is needed for Wav2Vec2Tokenizer because added tokens should be treated exactly
the same as tokens of the base vocabulary; therefore the function `convert_tokens_to_string` has to be
called on the whole token list and not individually on added tokens
"""
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
result = []
for token in filtered_tokens:
if skip_special_tokens and token in self.all_special_tokens:
continue
result.append(token)
text = self.convert_tokens_to_string(result)
if clean_up_tokenization_spaces:
clean_text = self.clean_up_tokenization(text)
return clean_text
else:
return text
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
with open(vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(self.encoder, ensure_ascii=False))
return (vocab_file,)


@@ -31,6 +31,7 @@ from .file_utils import (
is_pandas_available,
is_scatter_available,
is_sentencepiece_available,
is_soundfile_availble,
is_tf_available,
is_tokenizers_available,
is_torch_available,
@@ -367,6 +368,19 @@ def require_ray(test_case):
return test_case
def require_soundfile(test_case):
"""
Decorator marking a test that requires soundfile.
These tests are skipped when soundfile isn't installed.
"""
if not is_soundfile_availble():
return unittest.skip("test requires soundfile")(test_case)
else:
return test_case
def get_gpu_count():
"""
Return the number of available gpus (regardless of whether torch or tf is used)


@@ -2196,6 +2196,36 @@ def load_tf_weights_in_transfo_xl(*args, **kwargs):
requires_pytorch(load_tf_weights_in_transfo_xl)
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = None
class Wav2Vec2ForMaskedLM:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
class Wav2Vec2Model:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
class Wav2Vec2PreTrainedModel:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
XLM_PRETRAINED_MODEL_ARCHIVE_LIST = None


@@ -0,0 +1,354 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Wav2Vec2 model. """
import math
import unittest
from tests.test_modeling_common import floats_tensor
from transformers import is_torch_available
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, _config_zero_init
if is_torch_available():
import torch
from transformers import Wav2Vec2Config, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Tokenizer
class Wav2Vec2ModelTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=1024, # speech is longer
is_training=False,
hidden_size=16,
feat_extract_norm="group",
feat_extract_dropout=0.0,
feat_extract_activation="gelu",
conv_dim=(32, 32, 32),
conv_stride=(4, 4, 4),
conv_kernel=(8, 8, 8),
conv_bias=False,
num_conv_pos_embeddings=16,
num_conv_pos_embedding_groups=2,
num_hidden_layers=4,
num_attention_heads=2,
hidden_dropout_prob=0.1, # this is most likely not correctly set yet
intermediate_size=20,
layer_norm_eps=1e-5,
hidden_act="gelu",
initializer_range=0.02,
vocab_size=32,
do_stable_layer_norm=False,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.hidden_size = hidden_size
self.feat_extract_norm = feat_extract_norm
self.feat_extract_dropout = feat_extract_dropout
self.feat_extract_activation = feat_extract_activation
self.conv_dim = conv_dim
self.conv_stride = conv_stride
self.conv_kernel = conv_kernel
self.conv_bias = conv_bias
self.num_conv_pos_embeddings = num_conv_pos_embeddings
self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_dropout_prob = hidden_dropout_prob
self.intermediate_size = intermediate_size
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.vocab_size = vocab_size
self.do_stable_layer_norm = do_stable_layer_norm
self.scope = scope
output_seq_length = self.seq_length
for kernel, stride in zip(self.conv_kernel, self.conv_stride):
output_seq_length = (output_seq_length - (kernel - 1)) / stride
self.output_seq_length = int(math.ceil(output_seq_length))
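# e.g. with the defaults above (seq_length=1024, three conv layers with kernel 8 and stride 4):
# ceil((((1024 - 7) / 4 - 7) / 4 - 7) / 4) = ceil(13.70...) = 14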
self.encoder_seq_length = self.output_seq_length
def prepare_config_and_inputs(self):
input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
config = Wav2Vec2Config(
hidden_size=self.hidden_size,
feat_extract_norm=self.feat_extract_norm,
feat_extract_dropout=self.feat_extract_dropout,
feat_extract_activation=self.feat_extract_activation,
conv_dim=self.conv_dim,
conv_stride=self.conv_stride,
conv_kernel=self.conv_kernel,
conv_bias=self.conv_bias,
num_conv_pos_embeddings=self.num_conv_pos_embeddings,
num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
hidden_dropout_prob=self.hidden_dropout_prob,
intermediate_size=self.intermediate_size,
layer_norm_eps=self.layer_norm_eps,
hidden_act=self.hidden_act,
initializer_range=self.initializer_range,
vocab_size=self.vocab_size,
)
return config, input_values
def create_and_check_model(self, config, input_values):
model = Wav2Vec2Model(config=config)
model.to(torch_device)
model.eval()
result = model(input_values)
self.parent.assertEqual(
result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
)
def prepare_config_and_inputs_for_common(self):
config, input_values = self.prepare_config_and_inputs()
inputs_dict = {"input_values": input_values}
return config, inputs_dict
@require_torch
class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
Wav2Vec2Model,
Wav2Vec2ForMaskedLM,
)
if is_torch_available()
else ()
)
test_pruning = False
test_headmasking = False
test_torchscript = False
def setUp(self):
self.model_tester = Wav2Vec2ModelTester(self)
self.config_tester = ConfigTester(self, config_class=Wav2Vec2Config, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
# Wav2Vec2 has no inputs_embeds
def test_inputs_embeds(self):
pass
# `input_ids` is renamed to `input_values`
def test_forward_signature(self):
pass
# Wav2Vec2 cannot resize token embeddings
# since it has no token embeddings
def test_resize_tokens_embeddings(self):
pass
# Wav2Vec2 has no inputs_embeds
# and thus the `get_input_embeddings` fn
# is not implemented
def test_model_common_attributes(self):
pass
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
if param.requires_grad:
if "conv.weight" in name:
self.assertTrue(
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
msg="Parameter {} of model {} seems not properly initialized".format(name, model_class),
)
else:
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg="Parameter {} of model {} seems not properly initialized".format(name, model_class),
)
@slow
def test_model_from_pretrained(self):
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
self.assertIsNotNone(model)
@require_torch
class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (Wav2Vec2Model, Wav2Vec2ForMaskedLM) if is_torch_available() else ()
test_pruning = False
test_headmasking = False
test_torchscript = False
def setUp(self):
self.model_tester = Wav2Vec2ModelTester(
self, conv_stride=(3, 3, 3), feat_extract_norm="layer", do_stable_layer_norm=True
)
self.config_tester = ConfigTester(self, config_class=Wav2Vec2Config, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
# Wav2Vec2 has no inputs_embeds
def test_inputs_embeds(self):
pass
# `input_ids` is renamed to `input_values`
def test_forward_signature(self):
pass
# Wav2Vec2 cannot resize token embeddings
# since it has no token embeddings
def test_resize_tokens_embeddings(self):
pass
# Wav2Vec2 has no inputs_embeds
# and thus the `get_input_embeddings` fn
# is not implemented
def test_model_common_attributes(self):
pass
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
if param.requires_grad:
if "conv.weight" in name:
self.assertTrue(
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
msg="Parameter {} of model {} seems not properly initialized".format(name, model_class),
)
else:
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg="Parameter {} of model {} seems not properly initialized".format(name, model_class),
)
@slow
def test_model_from_pretrained(self):
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
self.assertIsNotNone(model)
@require_torch
@slow
@require_datasets
@require_soundfile
class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
def _load_datasamples(self, num_samples):
from datasets import load_dataset
import soundfile as sf
# map audio files to raw speech arrays
def map_to_array(batch):
speech, _ = sf.read(batch["file"])
batch["speech"] = speech
return batch
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
ds = ds.select(range(num_samples)).map(map_to_array)
return ds["speech"][:num_samples]
def test_inference_masked_lm_normal(self):
model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h")
model.to(torch_device)
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
input_speech = self._load_datasamples(1)
input_values = tokenizer(input_speech, return_tensors="pt").input_values.to(torch_device)
with torch.no_grad():
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
predicted_trans = tokenizer.batch_decode(predicted_ids)
EXPECTED_TRANSCRIPTIONS = ["a man said to the universe sir i exist"]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
def test_inference_masked_lm_normal_batched(self):
model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h")
model.to(torch_device)
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
input_speech = self._load_datasamples(2)
input_values = tokenizer(input_speech, return_tensors="pt", padding=True, truncation=True).input_values.to(
torch_device
)
with torch.no_grad():
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
predicted_trans = tokenizer.batch_decode(predicted_ids)
EXPECTED_TRANSCRIPTIONS = [
"a man said to the universe sir i exist",
"sweat covered brion's body trickling into the tight lowing cloth that was the only garment he wore",
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
def test_inference_masked_lm_robust_batched(self):
model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(torch_device)
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)
input_speech = self._load_datasamples(4)
input_values = tokenizer(input_speech, return_tensors="pt", padding=True, truncation=True).input_values.to(
torch_device
)
with torch.no_grad():
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
predicted_trans = tokenizer.batch_decode(predicted_ids)
EXPECTED_TRANSCRIPTIONS = [
"a man said to the universe sir i exist",
"sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
"the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with the thousands of spectators were trivialities not worth thinking about",
"his instant panic was followed by a small sharp blow high on his chest",
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)


@@ -0,0 +1,301 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the Wav2Vec2 tokenizer."""
import inspect
import json
import os
import random
import shutil
import tempfile
import unittest
import numpy as np
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES, Wav2Vec2Tokenizer
global_rng = random.Random()
def floats_list(shape, scale=1.0, rng=None, name=None):
"""Creates a random float32 tensor"""
if rng is None:
rng = global_rng
values = []
for batch_idx in range(shape[0]):
values.append([])
for _ in range(shape[1]):
values[-1].append(rng.random() * scale)
return values
class Wav2Vec2TokenizerTest(unittest.TestCase):
tokenizer_class = Wav2Vec2Tokenizer
def setUp(self):
super().setUp()
vocab = "<pad> <s> </s> <unk> | E T A O N I H S R D L U M W C F G Y P B V K ' X J Q Z".split(" ")
vocab_tokens = dict(zip(vocab, range(len(vocab))))
self.special_tokens_map = {"pad_token": "<pad>", "unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
self.tmpdirname = tempfile.mkdtemp()
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
with open(self.vocab_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(vocab_tokens) + "\n")
def get_tokenizer(self, **kwargs):
kwargs.update(self.special_tokens_map)
return Wav2Vec2Tokenizer.from_pretrained(self.tmpdirname, **kwargs)
def test_tokenizer_decode(self):
# TODO(PVP) - change to facebook
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
sample_ids = [
[11, 5, 15, tokenizer.pad_token_id, 15, 8, 98],
[24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77],
]
tokens = tokenizer.decode(sample_ids[0])
batch_tokens = tokenizer.batch_decode(sample_ids)
self.assertEqual(tokens, batch_tokens[0])
self.assertEqual(batch_tokens, ["HELLO<unk>", "BYE BYE<unk>"])
def test_tokenizer_decode_special(self):
# TODO(PVP) - change to facebook
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
sample_ids = [
[11, 5, 15, tokenizer.pad_token_id, 15, 8, 98],
[24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77],
]
sample_ids_2 = [
[11, 5, 5, 5, 5, 5, 15, 15, 15, tokenizer.pad_token_id, 15, 8, 98],
[
24,
22,
5,
tokenizer.pad_token_id,
tokenizer.pad_token_id,
tokenizer.pad_token_id,
tokenizer.word_delimiter_token_id,
24,
22,
5,
77,
tokenizer.word_delimiter_token_id,
],
]
batch_tokens = tokenizer.batch_decode(sample_ids)
batch_tokens_2 = tokenizer.batch_decode(sample_ids_2)
self.assertEqual(batch_tokens, batch_tokens_2)
self.assertEqual(batch_tokens, ["HELLO<unk>", "BYE BYE<unk>"])
def test_tokenizer_decode_added_tokens(self):
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
tokenizer.add_tokens(["!", "?"])
tokenizer.add_special_tokens({"cls_token": "$$$"})
sample_ids = [
[
11,
5,
15,
tokenizer.pad_token_id,
15,
8,
98,
32,
32,
33,
tokenizer.word_delimiter_token_id,
32,
32,
33,
34,
34,
],
[24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77, tokenizer.pad_token_id, 34, 34],
]
batch_tokens = tokenizer.batch_decode(sample_ids)
self.assertEqual(batch_tokens, ["HELLO<unk>!?!?$$$", "BYE BYE<unk>$$$"])
def test_call(self):
# Tests that the tokenizer's __call__ handles single and batched inputs, given as lists or numpy arrays
tokenizer = self.get_tokenizer()
# create three inputs of length 800, 1000, and 1200
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]
# Test not batched input
encoded_sequences_1 = tokenizer(speech_inputs[0], return_tensors="np").input_values
encoded_sequences_2 = tokenizer(np_speech_inputs[0], return_tensors="np").input_values
self.assertTrue(np.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3))
# Test batched
encoded_sequences_1 = tokenizer(speech_inputs, return_tensors="np").input_values
encoded_sequences_2 = tokenizer(np_speech_inputs, return_tensors="np").input_values
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
def test_padding(self, max_length=50):
def _input_values_have_equal_length(input_values):
length = len(input_values[0])
for input_values_slice in input_values[1:]:
if len(input_values_slice) != length:
return False
return True
def _input_values_are_equal(input_values_1, input_values_2):
if len(input_values_1) != len(input_values_2):
return False
for input_values_slice_1, input_values_slice_2 in zip(input_values_1, input_values_2):
if not np.allclose(np.asarray(input_values_slice_1), np.asarray(input_values_slice_2), atol=1e-3):
return False
return True
tokenizer = self.get_tokenizer()
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
input_values_1 = tokenizer(speech_inputs).input_values
input_values_2 = tokenizer(speech_inputs, padding="longest").input_values
input_values_3 = tokenizer(speech_inputs, padding="longest", max_length=1600).input_values
self.assertFalse(_input_values_have_equal_length(input_values_1))
self.assertTrue(_input_values_have_equal_length(input_values_2))
self.assertTrue(_input_values_have_equal_length(input_values_3))
self.assertTrue(_input_values_are_equal(input_values_2, input_values_3))
self.assertTrue(len(input_values_1[0]) == 800)
self.assertTrue(len(input_values_2[0]) == 1200)
# padding should be 0.0
self.assertTrue(abs(sum(np.asarray(input_values_2[0])[800:])) < 1e-3)
self.assertTrue(abs(sum(np.asarray(input_values_2[1])[1000:])) < 1e-3)
input_values_4 = tokenizer(speech_inputs, padding="max_length").input_values
input_values_5 = tokenizer(speech_inputs, padding="max_length", max_length=1600).input_values
self.assertTrue(_input_values_are_equal(input_values_1, input_values_4))
self.assertTrue(input_values_5.shape, (3, 1600))
# padding should be 0.0
self.assertTrue(abs(sum(np.asarray(input_values_5[0])[800:1200])) < 1e-3)
input_values_6 = tokenizer(speech_inputs, pad_to_multiple_of=500).input_values
input_values_7 = tokenizer(speech_inputs, padding="longest", pad_to_multiple_of=500).input_values
input_values_8 = tokenizer(
speech_inputs, padding="max_length", pad_to_multiple_of=500, max_length=2400
).input_values
self.assertTrue(_input_values_are_equal(input_values_1, input_values_6))
self.assertTrue(input_values_7.shape, (3, 1500))
self.assertTrue(input_values_8.shape, (3, 2500))
# padding should be 0.0
self.assertTrue(abs(sum(np.asarray(input_values_7[0])[800:])) < 1e-3)
self.assertTrue(abs(sum(np.asarray(input_values_7[1])[1000:])) < 1e-3)
self.assertTrue(abs(sum(np.asarray(input_values_7[2])[1200:])) < 1e-3)
self.assertTrue(abs(sum(np.asarray(input_values_8[0])[800:])) < 1e-3)
self.assertTrue(abs(sum(np.asarray(input_values_8[1])[1000:])) < 1e-3)
self.assertTrue(abs(sum(np.asarray(input_values_8[2])[1200:])) < 1e-3)
def test_save_pretrained(self):
pretrained_name = list(self.tokenizer_class.pretrained_vocab_files_map["vocab_file"].keys())[0]
tokenizer = self.tokenizer_class.from_pretrained(pretrained_name)
tmpdirname2 = tempfile.mkdtemp()
tokenizer_files = tokenizer.save_pretrained(tmpdirname2)
self.assertSequenceEqual(
sorted(tuple(VOCAB_FILES_NAMES.values()) + ("special_tokens_map.json", "added_tokens.json")),
sorted(tuple(x.split("/")[-1] for x in tokenizer_files)),
)
# Checks everything loads correctly in the same way
tokenizer_p = self.tokenizer_class.from_pretrained(tmpdirname2)
# Check special tokens are set accordingly in the reloaded tokenizer
for key in tokenizer.special_tokens_map:
self.assertTrue(key in tokenizer_p.special_tokens_map)
shutil.rmtree(tmpdirname2)
def test_get_vocab(self):
tokenizer = self.get_tokenizer()
vocab_dict = tokenizer.get_vocab()
self.assertIsInstance(vocab_dict, dict)
self.assertGreaterEqual(len(tokenizer), len(vocab_dict))
vocab = [tokenizer.convert_ids_to_tokens(i) for i in range(len(tokenizer))]
self.assertEqual(len(vocab), len(tokenizer))
tokenizer.add_tokens(["asdfasdfasdfasdf"])
vocab = [tokenizer.convert_ids_to_tokens(i) for i in range(len(tokenizer))]
self.assertEqual(len(vocab), len(tokenizer))
def test_save_and_load_tokenizer(self):
tokenizer = self.get_tokenizer()
# Isolate this from the other tests because we save additional tokens/etc
tmpdirname = tempfile.mkdtemp()
sample_ids = [0, 1, 4, 8, 9, 0, 12]
before_tokens = tokenizer.decode(sample_ids)
before_vocab = tokenizer.get_vocab()
tokenizer.save_pretrained(tmpdirname)
after_tokenizer = tokenizer.__class__.from_pretrained(tmpdirname)
after_tokens = after_tokenizer.decode(sample_ids)
after_vocab = after_tokenizer.get_vocab()
self.assertEqual(before_tokens, after_tokens)
self.assertDictEqual(before_vocab, after_vocab)
shutil.rmtree(tmpdirname)
tokenizer = self.get_tokenizer()
# Isolate this from the other tests because we save additional tokens/etc
tmpdirname = tempfile.mkdtemp()
before_len = len(tokenizer)
sample_ids = [0, 1, 4, 8, 9, 0, 12, before_len, before_len + 1, before_len + 2]
tokenizer.add_tokens(["?", "!"])
additional_special_tokens = tokenizer.additional_special_tokens
additional_special_tokens.append("&")
tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
before_tokens = tokenizer.decode(sample_ids)
before_vocab = tokenizer.get_vocab()
tokenizer.save_pretrained(tmpdirname)
after_tokenizer = tokenizer.__class__.from_pretrained(tmpdirname)
after_tokens = after_tokenizer.decode(sample_ids)
after_vocab = after_tokenizer.get_vocab()
self.assertEqual(before_tokens, after_tokens)
self.assertDictEqual(before_vocab, after_vocab)
self.assertEqual(len(tokenizer), before_len + 3)
self.assertEqual(len(tokenizer), len(after_tokenizer))
shutil.rmtree(tmpdirname)
def test_tokenizer_slow_store_full_signature(self):
signature = inspect.signature(self.tokenizer_class.__init__)
tokenizer = self.get_tokenizer()
for parameter_name, parameter in signature.parameters.items():
if parameter.default != inspect.Parameter.empty:
self.assertIn(parameter_name, tokenizer.init_kwargs)