Wav2Vec2 (#9659)
* add raw scaffold
* implement feat extract layers
* make style
* remove +
* correctly convert weights
* make feat extractor work
* make feature extraction proj work
* run forward pass
* finish forward pass
* Successful decoding example
* remove unused files
* more changes
* add wav2vec tokenizer
* add new structure
* fix run forward
* add other layer norm architecture
* finish 2nd structure
* add model tests
* finish tests for tok and model
* clean-up
* make style
* finish docstring for model and config
* make style
* correct docstring
* correct tests
* change checkpoints to fairseq
* fix examples
* finish wav2vec2
* make style
* apply Sylvain's suggestions
* apply Lysandre's suggestions
* change print to log.info
* re-add assert statement
* add input_values as required input name
* finish wav2vec2 tokenizer
* Update tests/test_tokenization_wav2vec2.py
  Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
* apply Sylvain's suggestions
  Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent d996024af7
commit d6217fb30c
@@ -228,6 +228,7 @@ ultilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/
 1. **[T5](https://huggingface.co/transformers/model_doc/t5.html)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
 1. **[TAPAS](https://huggingface.co/transformers/model_doc/tapas.html)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos.
 1. **[Transformer-XL](https://huggingface.co/transformers/model_doc/transformerxl.html)** (from Google/CMU) released with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
+1. **[Wav2Vec2](https://huggingface.co/transformers/model_doc/wav2vec2.html)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
 1. **[XLM](https://huggingface.co/transformers/model_doc/xlm.html)** (from Facebook) released together with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau.
 1. **[XLM-ProphetNet](https://huggingface.co/transformers/model_doc/xlmprophetnet.html)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
 1. **[XLM-RoBERTa](https://huggingface.co/transformers/model_doc/xlmroberta.html)** (from Facebook AI), released together with the paper [Unsupervised Cross-lingual Representation Learning at Scale](https://arxiv.org/abs/1911.02116) by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov.
@@ -192,16 +192,19 @@ and conversion utilities for the following models:
 36. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
     Attentive Language Models Beyond a Fixed-Length Context <https://arxiv.org/abs/1901.02860>`__ by Zihang Dai*,
     Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
-37. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
+37. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
+    Self-Supervised Learning of Speech Representations <https://arxiv.org/abs/2006.11477>`__ by Alexei Baevski, Henry
+    Zhou, Abdelrahman Mohamed, Michael Auli.
+38. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
     Pretraining <https://arxiv.org/abs/1901.07291>`__ by Guillaume Lample and Alexis Conneau.
-38. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
+39. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
     Predicting Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan,
     Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
-39. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
+40. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
     Cross-lingual Representation Learning at Scale <https://arxiv.org/abs/1911.02116>`__ by Alexis Conneau*, Kartikay
     Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke
     Zettlemoyer and Veselin Stoyanov.
-40. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
+41. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
     Pretraining for Language Understanding <https://arxiv.org/abs/1906.08237>`__ by Zhilin Yang*, Zihang Dai*, Yiming
     Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
@@ -292,6 +295,8 @@ TensorFlow and/or Flax.
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
+| Wav2Vec2 | ✅ | ❌ | ✅ | ❌ | ❌ |
++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | XLM | ✅ | ❌ | ✅ | ✅ | ❌ |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | XLM-RoBERTa | ✅ | ✅ | ✅ | ✅ | ❌ |
@@ -414,6 +419,7 @@ TensorFlow and/or Flax.
     model_doc/t5
     model_doc/tapas
     model_doc/transformerxl
+    model_doc/wav2vec2
     model_doc/xlm
     model_doc/xlmprophetnet
     model_doc/xlmroberta
docs/source/model_doc/wav2vec2.rst (new file, 65 lines)
@@ -0,0 +1,65 @@
..
    Copyright 2021 The HuggingFace Team. All rights reserved.

    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
    the License. You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
    an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
    specific language governing permissions and limitations under the License.

Wav2Vec2
-----------------------------------------------------------------------------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The Wav2Vec2 model was proposed in `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations
<https://arxiv.org/abs/2006.11477>`__ by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.

The abstract from the paper is the following:

*We show for the first time that learning powerful representations from speech audio alone followed by fine-tuning on
transcribed speech can outperform the best semi-supervised methods while being conceptually simpler. wav2vec 2.0 masks
the speech input in the latent space and solves a contrastive task defined over a quantization of the latent
representations which are jointly learned. Experiments using all labeled data of Librispeech achieve 1.8/3.3 WER on the
clean/other test sets. When lowering the amount of labeled data to one hour, wav2vec 2.0 outperforms the previous state
of the art on the 100 hour subset while using 100 times less labeled data. Using just ten minutes of labeled data and
pre-training on 53k hours of unlabeled data still achieves 4.8/8.2 WER. This demonstrates the feasibility of speech
recognition with limited amounts of labeled data.*

Tips:

- Wav2Vec2 is a speech model that accepts a float array corresponding to the raw waveform of the speech signal.
- Wav2Vec2 model was trained using connectionist temporal classification (CTC) so the model output has to be decoded
  using :class:`~transformers.Wav2Vec2Tokenizer`.
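A minimal decoding sketch illustrating the two tips above (not part of the diff; it assumes the `facebook/wav2vec2-base-960h` checkpoint referenced elsewhere in this PR, `soundfile` from the new "speech" extra, and a 16 kHz mono recording):

```python
# Hypothetical usage sketch, not part of this PR's files.
import soundfile as sf
import torch
from transformers import Wav2Vec2ForMaskedLM, Wav2Vec2Tokenizer

tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h")

speech, sampling_rate = sf.read("sample.wav")  # expected to be 16 kHz mono
input_values = tokenizer(speech, return_tensors="pt").input_values  # raw waveform, not token ids

with torch.no_grad():
    logits = model(input_values).logits

predicted_ids = torch.argmax(logits, dim=-1)  # greedy CTC decoding
transcription = tokenizer.decode(predicted_ids[0])
print(transcription)
```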
Wav2Vec2Config
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.Wav2Vec2Config
    :members:


Wav2Vec2Tokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.Wav2Vec2Tokenizer
    :members: __call__, save_vocabulary


Wav2Vec2Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.Wav2Vec2Model
    :members: forward


Wav2Vec2ForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.Wav2Vec2ForMaskedLM
    :members: forward
setup.py (5 changed lines)
@@ -123,6 +123,7 @@ _deps = [
     "sacremoses",
     "scikit-learn",
     "sentencepiece==0.1.91",
+    "soundfile",
     "sphinx-copybutton",
     "sphinx-markdown-tables",
     "sphinx-rtd-theme==0.4.3",  # sphinx-rtd-theme==0.5.0 introduced big changes in the style.
@@ -226,12 +227,14 @@ extras["onnxruntime"] = deps_list("onnxruntime", "onnxruntime-tools")
 extras["modelcreation"] = deps_list("cookiecutter")

 extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
+extras["speech"] = deps_list("soundfile")

 extras["sentencepiece"] = deps_list("sentencepiece", "protobuf")
 extras["testing"] = (
-    deps_list("pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil")
+    deps_list("pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets")
     + extras["retrieval"]
     + extras["modelcreation"]
+    + extras["speech"]
 )
 extras["docs"] = deps_list("recommonmark", "sphinx", "sphinx-markdown-tables", "sphinx-rtd-theme", "sphinx-copybutton")
 extras["quality"] = deps_list("black", "isort", "flake8")
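Note on the new extra: with the setup.py changes above, the optional audio dependency can presumably be pulled in with `pip install "transformers[speech]"` (or `pip install -e ".[speech]"` from a source checkout); at this point the extra only contains `soundfile`.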
@@ -125,6 +125,7 @@ _import_structure = {
     ],
     "models": [],
     # Models
+    "models.wav2vec2": ["WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Wav2Vec2Config", "Wav2Vec2Tokenizer"],
     "models.convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig", "ConvBertTokenizer"],
     "models.albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig"],
     "models.auto": [
@@ -363,6 +364,14 @@ if is_torch_available():
     _import_structure["modeling_utils"] = ["Conv1D", "PreTrainedModel", "apply_chunking_to_forward", "prune_layer"]
     # PyTorch models structure

+    _import_structure["models.wav2vec2"].extend(
+        [
+            "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "Wav2Vec2ForMaskedLM",
+            "Wav2Vec2Model",
+            "Wav2Vec2PreTrainedModel",
+        ]
+    )
     _import_structure["models.convbert"].extend(
         [
             "CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1312,6 +1321,7 @@ if TYPE_CHECKING:
         TransfoXLCorpus,
         TransfoXLTokenizer,
     )
+    from .models.wav2vec2 import WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config, Wav2Vec2Tokenizer
     from .models.xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig, XLMTokenizer
     from .models.xlm_prophetnet import XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMProphetNetConfig
     from .models.xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
@@ -1791,6 +1801,12 @@ if TYPE_CHECKING:
             TransfoXLPreTrainedModel,
             load_tf_weights_in_transfo_xl,
         )
+        from .models.wav2vec2 import (
+            WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
+            Wav2Vec2ForMaskedLM,
+            Wav2Vec2Model,
+            Wav2Vec2PreTrainedModel,
+        )
         from .models.xlm import (
             XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
             XLMForMultipleChoice,
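Taken together, the `_import_structure` and `TYPE_CHECKING` additions above are what make the new classes importable from the top-level package; a quick sanity-check sketch (assumes a torch install):

```python
# Sketch: verify the new public names resolve at the top level of the transformers package.
from transformers import (
    Wav2Vec2Config,
    Wav2Vec2ForMaskedLM,
    Wav2Vec2Model,
    Wav2Vec2Tokenizer,
)

print(Wav2Vec2Config().model_type)  # "wav2vec2"
```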
@@ -36,6 +36,7 @@ deps = {
     "sacremoses": "sacremoses",
     "scikit-learn": "scikit-learn",
     "sentencepiece": "sentencepiece==0.1.91",
+    "soundfile": "soundfile",
     "sphinx-copybutton": "sphinx-copybutton",
     "sphinx-markdown-tables": "sphinx-markdown-tables",
     "sphinx-rtd-theme": "sphinx-rtd-theme==0.4.3",
@@ -155,6 +155,14 @@ except importlib_metadata.PackageNotFoundError:
     _scatter_available = False


+_soundfile_available = importlib.util.find_spec("soundfile") is not None
+try:
+    _soundfile_version = importlib_metadata.version("soundfile")
+    logger.debug(f"Successfully imported soundfile version {_soundfile_version}")
+except importlib_metadata.PackageNotFoundError:
+    _soundfile_available = False
+
+
 torch_cache_home = os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
 old_default_cache_path = os.path.join(torch_cache_home, "transformers")
 # New default cache, shared with the Datasets library
@@ -311,6 +319,10 @@ def is_sagemaker_distributed_available():
     return importlib.util.find_spec("smdistributed") is not None


+def is_soundfile_availble():
+    return _soundfile_available
+
+
 def torch_only_method(fn):
     def wrapper(*args, **kwargs):
         if not _torch_available:
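A small sketch of how this availability flag could be used to guard audio-dependent tests (the `require_soundfile` helper is hypothetical and not part of the diff; note the function name really is spelled `is_soundfile_availble` above):

```python
# Sketch: skip audio work when soundfile is not installed.
import unittest

from transformers.file_utils import is_soundfile_availble


def require_soundfile(test_case):
    """Hypothetical decorator that skips a test unless soundfile is installed."""
    if not is_soundfile_availble():
        return unittest.skip("test requires soundfile")(test_case)
    return test_case
```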
@@ -63,6 +63,7 @@ from . import (
     t5,
     tapas,
     transfo_xl,
+    wav2vec2,
     xlm,
     xlm_roberta,
     xlnet,
@@ -59,6 +59,7 @@ from ..squeezebert.configuration_squeezebert import SQUEEZEBERT_PRETRAINED_CONFI
 from ..t5.configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
 from ..tapas.configuration_tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig
 from ..transfo_xl.configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
+from ..wav2vec2.configuration_wav2vec2 import WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config
 from ..xlm.configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
 from ..xlm_prophetnet.configuration_xlm_prophetnet import (
     XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
@@ -72,6 +73,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
     (key, value)
     for pretrained_map in [
         # Add archive maps here
+        WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
         CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
         LED_PRETRAINED_CONFIG_ARCHIVE_MAP,
         BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
@@ -114,6 +116,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
 CONFIG_MAPPING = OrderedDict(
     [
         # Add configs here
+        ("wav2vec2", Wav2Vec2Config),
         ("convbert", ConvBertConfig),
         ("led", LEDConfig),
         ("blenderbot-small", BlenderbotSmallConfig),
@@ -162,6 +165,7 @@ CONFIG_MAPPING = OrderedDict(
 MODEL_NAMES_MAPPING = OrderedDict(
     [
         # Add full (and cased) model names here
+        ("wav2vec2", "Wav2Vec2"),
         ("convbert", "ConvBERT"),
         ("led", "LED"),
         ("blenderbot-small", "BlenderbotSmall"),
@@ -205,6 +205,7 @@ from ..tapas.modeling_tapas import (
     TapasModel,
 )
 from ..transfo_xl.modeling_transfo_xl import TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel
+from ..wav2vec2.modeling_wav2vec2 import Wav2Vec2ForMaskedLM, Wav2Vec2Model
 from ..xlm.modeling_xlm import (
     XLMForMultipleChoice,
     XLMForQuestionAnsweringSimple,
@@ -274,6 +275,7 @@ from .configuration_auto import (
     T5Config,
     TapasConfig,
     TransfoXLConfig,
+    Wav2Vec2Config,
     XLMConfig,
     XLMProphetNetConfig,
     XLMRobertaConfig,
@@ -288,6 +290,7 @@ logger = logging.get_logger(__name__)
 MODEL_MAPPING = OrderedDict(
     [
         # Base model mapping
+        (Wav2Vec2Config, Wav2Vec2Model),
         (ConvBertConfig, ConvBertModel),
         (LEDConfig, LEDModel),
         (BlenderbotSmallConfig, BlenderbotSmallModel),
@@ -367,6 +370,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
 MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
     [
         # Model with LM heads mapping
+        (Wav2Vec2Config, Wav2Vec2ForMaskedLM),
         (ConvBertConfig, ConvBertForMaskedLM),
         (LEDConfig, LEDForConditionalGeneration),
         (BlenderbotSmallConfig, BlenderbotSmallForConditionalGeneration),
@@ -427,6 +431,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
 MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
     [
         # Model for Masked LM mapping
+        (Wav2Vec2Config, Wav2Vec2ForMaskedLM),
         (ConvBertConfig, ConvBertForMaskedLM),
         (LayoutLMConfig, LayoutLMForMaskedLM),
         (DistilBertConfig, DistilBertForMaskedLM),
@@ -52,6 +52,7 @@ from ..roberta.tokenization_roberta import RobertaTokenizer
 from ..squeezebert.tokenization_squeezebert import SqueezeBertTokenizer
 from ..tapas.tokenization_tapas import TapasTokenizer
 from ..transfo_xl.tokenization_transfo_xl import TransfoXLTokenizer
+from ..wav2vec2.tokenization_wav2vec2 import Wav2Vec2Tokenizer
 from ..xlm.tokenization_xlm import XLMTokenizer
 from .configuration_auto import (
     AlbertConfig,
@@ -93,6 +94,7 @@ from .configuration_auto import (
     T5Config,
     TapasConfig,
     TransfoXLConfig,
+    Wav2Vec2Config,
     XLMConfig,
     XLMProphetNetConfig,
     XLMRobertaConfig,
@@ -238,6 +240,7 @@ TOKENIZER_MAPPING = OrderedDict(
         (TapasConfig, (TapasTokenizer, None)),
         (LEDConfig, (LEDTokenizer, LEDTokenizerFast)),
         (ConvBertConfig, (ConvBertTokenizer, ConvBertTokenizerFast)),
+        (Wav2Vec2Config, (Wav2Vec2Tokenizer, None)),
     ]
 )
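With the config, model and tokenizer now registered in the auto mappings above, the `Auto*` classes should dispatch to the Wav2Vec2 classes based on the checkpoint's config type; a minimal sketch (assumes the `facebook/wav2vec2-base-960h` checkpoint and a torch install):

```python
# Sketch: the auto classes resolve Wav2Vec2 via the mappings added above.
from transformers import AutoConfig, AutoModel, AutoTokenizer

config = AutoConfig.from_pretrained("facebook/wav2vec2-base-960h")        # -> Wav2Vec2Config
tokenizer = AutoTokenizer.from_pretrained("facebook/wav2vec2-base-960h")  # -> Wav2Vec2Tokenizer
model = AutoModel.from_pretrained("facebook/wav2vec2-base-960h")          # -> Wav2Vec2Model
print(type(config).__name__, type(tokenizer).__name__, type(model).__name__)
```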
src/transformers/models/wav2vec2/__init__.py (new file, 66 lines)
@@ -0,0 +1,66 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available


_import_structure = {
    "configuration_wav2vec2": ["WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP", "Wav2Vec2Config"],
    "tokenization_wav2vec2": ["Wav2Vec2Tokenizer"],
}

if is_torch_available():
    _import_structure["modeling_wav2vec2"] = [
        "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
        "Wav2Vec2ForMaskedLM",
        "Wav2Vec2Model",
        "Wav2Vec2PreTrainedModel",
    ]


if TYPE_CHECKING:
    from .configuration_wav2vec2 import WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config
    from .tokenization_wav2vec2 import Wav2Vec2Tokenizer

    if is_torch_available():
        from .modeling_wav2vec2 import (
            WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
            Wav2Vec2ForMaskedLM,
            Wav2Vec2Model,
            Wav2Vec2PreTrainedModel,
        )


else:
    import importlib
    import os
    import sys

    class _LazyModule(_BaseLazyModule):
        """
        Module class that surfaces all objects but only performs associated imports when the objects are requested.
        """

        __file__ = globals()["__file__"]
        __path__ = [os.path.dirname(__file__)]

        def _get_module(self, module_name: str):
            return importlib.import_module("." + module_name, self.__name__)

    sys.modules[__name__] = _LazyModule(__name__, _import_structure)
src/transformers/models/wav2vec2/configuration_wav2vec2.py (new file, 171 lines)
@@ -0,0 +1,171 @@
# coding=utf-8
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Wav2Vec2 model configuration """

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)

WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/config.json",
    # See all Wav2Vec2 models at https://huggingface.co/models?filter=wav2vec2
}


class Wav2Vec2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a :class:`~transformers.Wav2Vec2Model`. It is used to
    instantiate a Wav2Vec2 model according to the specified arguments, defining the model architecture. Instantiating
    a configuration with the defaults will yield a similar configuration to that of the Wav2Vec2
    `facebook/wav2vec2-base-960h <https://huggingface.co/facebook/wav2vec2-base-960h>`__ architecture.

    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
    outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.


    Args:
        vocab_size (:obj:`int`, `optional`, defaults to 32):
            Vocabulary size of the Wav2Vec2 model. Defines the number of different tokens that can be represented by
            the :obj:`inputs_ids` passed when calling :class:`~transformers.Wav2Vec2Model` or
            :class:`~transformers.TFWav2Vec2Model`. Vocabulary size of the model. Defines the different tokens that can
            be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.Wav2Vec2Model`.
        hidden_size (:obj:`int`, `optional`, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (:obj:`int`, `optional`, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (:obj:`int`, `optional`, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string,
            :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
        hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        initializer_range (:obj:`float`, `optional`, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        feat_extract_norm (:obj:`str`, `optional`, defaults to :obj:`"group"`):
            The norm to be applied to 1D convolutional layers in the feature extractor. One of :obj:`"group"` for group
            normalization of only the first 1D convolutional layer or :obj:`"layer"` for layer normalization of all 1D
            convolutional layers.
        feat_extract_dropout (:obj:`float`, `optional`, defaults to 0.0):
            The dropout probability for all 1D convolutional layers in the feature extractor.
        feat_extract_activation (:obj:`str`, `optional`, defaults to :obj:`"gelu"`):
            The non-linear activation function (function or string) in the 1D convolutional layers of the feature
            extractor. If string, :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
        conv_dim (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(512, 512, 512, 512, 512, 512, 512)`):
            A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
            feature extractor. The length of `conv_dim` defines the number of 1D convolutional layers.
        conv_stride (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(5, 2, 2, 2, 2, 2, 2)`):
            A tuple of integers defining the stride of each 1D convolutional layer in the feature extractor. The length
            of `conv_stride` defines the number of convolutional layers and has to match the length of `conv_dim`.
        conv_kernel (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(10, 3, 3, 3, 3, 2, 2)`):
            A tuple of integers defining the kernel size of each 1D convolutional layer in the feature extractor. The
            length of `conv_kernel` defines the number of convolutional layers and has to match the length of
            `conv_dim`.
        conv_bias (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether the 1D convolutional layers have a bias.
        num_conv_pos_embeddings (:obj:`int`, `optional`, defaults to 128):
            Number of convolutional positional embeddings. Defines the kernel size of the 1D convolutional positional
            embeddings layer.
        num_conv_pos_embedding_groups (:obj:`int`, `optional`, defaults to 16):
            Number of groups of the 1D convolutional positional embeddings layer.
        do_stable_layer_norm (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to apply the `stable` layer norm architecture of the Transformer encoder. ``do_stable_layer_norm is
            True`` corresponds to applying layer norm before the attention layer, whereas ``do_stable_layer_norm is
            False`` corresponds to applying layer norm after the attention layer.

    Example::

        >>> from transformers import Wav2Vec2Model, Wav2Vec2Config

        >>> # Initializing a Wav2Vec2 facebook/wav2vec2-base-960h style configuration
        >>> configuration = Wav2Vec2Config()

        >>> # Initializing a model from the facebook/wav2vec2-base-960h style configuration
        >>> model = Wav2Vec2Model(configuration)

        >>> # Accessing the model configuration
        >>> configuration = model.config
    """
    model_type = "wav2vec2"

    def __init__(
        self,
        vocab_size=32,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,  # TODO(PVP) this is most likely not correctly set yet - correct when adding train
        attention_probs_dropout_prob=0.1,  # TODO(PVP) this is most likely not correctly set yet - correct when adding train
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        feat_extract_norm="group",
        feat_extract_dropout=0.0,
        feat_extract_activation="gelu",
        conv_dim=(512, 512, 512, 512, 512, 512, 512),
        conv_stride=(5, 2, 2, 2, 2, 2, 2),
        conv_kernel=(10, 3, 3, 3, 3, 2, 2),
        conv_bias=False,
        num_conv_pos_embeddings=128,
        num_conv_pos_embedding_groups=16,
        do_stable_layer_norm=False,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        **kwargs
    ):
        super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
        self.hidden_size = hidden_size
        self.feat_extract_norm = feat_extract_norm
        self.feat_extract_dropout = feat_extract_dropout
        self.feat_extract_activation = feat_extract_activation
        self.conv_dim = list(conv_dim)
        self.conv_stride = list(conv_stride)
        self.conv_kernel = list(conv_kernel)
        self.conv_bias = conv_bias
        self.num_conv_pos_embeddings = num_conv_pos_embeddings
        self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
        self.num_feat_extract_layers = len(self.conv_dim)
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.num_attention_heads = num_attention_heads
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        self.vocab_size = vocab_size
        self.do_stable_layer_norm = do_stable_layer_norm

        if (
            (len(self.conv_stride) != self.num_feat_extract_layers)
            or (len(self.conv_kernel) != self.num_feat_extract_layers)
            or (len(self.conv_dim) != self.num_feat_extract_layers)
        ):
            raise ValueError(
                "Configuration for convolutional layers is incorrect."
                "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`,"
                f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride)"
                f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
            )
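The `conv_kernel`/`conv_stride` defaults above imply a fixed waveform downsampling rate; a small sketch of the resulting frame rate (my own arithmetic, using the `__init__` defaults and the standard Conv1d output-length formula):

```python
# Sketch: frame count produced by the default feature-extractor conv stack for 1 s of 16 kHz audio.
conv_kernel = (10, 3, 3, 3, 3, 2, 2)
conv_stride = (5, 2, 2, 2, 2, 2, 2)


def num_frames(num_samples: int) -> int:
    length = num_samples
    for kernel, stride in zip(conv_kernel, conv_stride):
        length = (length - kernel) // stride + 1  # no padding in the conv feature extractor
    return length


print(num_frames(16000))  # 49 frames, i.e. roughly one frame per 20 ms of audio
```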
@@ -0,0 +1,162 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Wav2Vec2 checkpoint."""


import argparse

import fairseq
import torch

from transformers import Wav2Vec2Config, Wav2Vec2ForMaskedLM, logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)

MAPPING = {
    "post_extract_proj": "wav2vec2.feature_projection.projection",
    "encoder.pos_conv.0": "wav2vec2.encoder.pos_conv_embed.conv",
    "self_attn.k_proj": "wav2vec2.encoder.layers.*.attention.k_proj",
    "self_attn.v_proj": "wav2vec2.encoder.layers.*.attention.v_proj",
    "self_attn.q_proj": "wav2vec2.encoder.layers.*.attention.q_proj",
    "self_attn.out_proj": "wav2vec2.encoder.layers.*.attention.out_proj",
    "self_attn_layer_norm": "wav2vec2.encoder.layers.*.layer_norm",
    "fc1": "wav2vec2.encoder.layers.*.feed_forward.intermediate_dense",
    "fc2": "wav2vec2.encoder.layers.*.feed_forward.output_dense",
    "final_layer_norm": "wav2vec2.encoder.layers.*.final_layer_norm",
    "encoder.layer_norm": "wav2vec2.encoder.layer_norm",
    "w2v_model.layer_norm": "wav2vec2.feature_projection.layer_norm",
    "w2v_encoder.proj": "lm_head",
}


def set_recursively(hf_pointer, key, value, full_name, weight_type):
    for attribute in key.split("."):
        hf_pointer = getattr(hf_pointer, attribute)

    hf_shape = getattr(hf_pointer, weight_type).shape
    assert (
        hf_shape == value.shape
    ), f"Shape of hf {key + '.' + weight_type} is {hf_shape}, but should be {value.shape} for {full_name}"

    if weight_type == "weight":
        hf_pointer.weight.data = value
    elif weight_type == "weight_g":
        hf_pointer.weight_g.data = value
    elif weight_type == "weight_v":
        hf_pointer.weight_v.data = value
    elif weight_type == "bias":
        hf_pointer.bias.data = value

    logger.info(f"{key + '.' + weight_type} was initialized from {full_name}.")


def recursively_load_weights(fairseq_model, hf_model):
    unused_weights = []
    fairseq_dict = fairseq_model.state_dict()

    for name, value in fairseq_dict.items():
        is_used = False
        if "conv_layers" in name:
            load_conv_layer(
                name,
                value,
                hf_model.wav2vec2.feature_extractor,
                unused_weights,
                hf_model.config.feat_extract_norm == "group",
            )
            is_used = True
        else:
            for key, mapped_key in MAPPING.items():
                if key in name:
                    is_used = True
                    if "*" in mapped_key:
                        layer_index = name.split(key)[0].split(".")[-2]
                        mapped_key = mapped_key.replace("*", layer_index)
                    if "weight_g" in name:
                        weight_type = "weight_g"
                    elif "weight_v" in name:
                        weight_type = "weight_v"
                    elif "weight" in name:
                        weight_type = "weight"
                    elif "bias" in name:
                        weight_type = "bias"
                    set_recursively(hf_model, mapped_key, value, name, weight_type)
                continue
        if not is_used:
            unused_weights.append(name)

    logger.warning(f"Unused weights: {unused_weights}")
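To make the `MAPPING` plus `*` substitution in `recursively_load_weights` concrete, here is a small sketch of how one fairseq-style key is renamed (the example key is hypothetical and only illustrates the string manipulation above):

```python
# Sketch: how a fairseq key such as "encoder.layers.3.fc1.weight" would be renamed by the logic above.
name = "encoder.layers.3.fc1.weight"  # hypothetical fairseq parameter name
key, mapped_key = "fc1", "wav2vec2.encoder.layers.*.feed_forward.intermediate_dense"

layer_index = name.split(key)[0].split(".")[-2]    # -> "3"
mapped_key = mapped_key.replace("*", layer_index)  # -> "wav2vec2.encoder.layers.3.feed_forward.intermediate_dense"
weight_type = "weight"                             # chosen from the suffix of `name`
print(mapped_key, weight_type)
```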
def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm):
    name = full_name.split("conv_layers.")[-1]
    items = name.split(".")
    layer_id = int(items[0])
    type_id = int(items[1])

    if type_id == 0:
        if "bias" in name:
            assert (
                value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape
            ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
            feature_extractor.conv_layers[layer_id].conv.bias.data = value
            logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
        elif "weight" in name:
            assert (
                value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape
            ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
            feature_extractor.conv_layers[layer_id].conv.weight.data = value
            logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
    elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
        if "bias" in name:
            assert (
                value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape
            ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was found."
            feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
            logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
        elif "weight" in name:
            assert (
                value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape
            ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
            feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
            logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
    else:
        unused_weights.append(full_name)


@torch.no_grad()
def convert_wav2vec2_checkpoint(checkpoint_path, pytorch_dump_folder_path, dict_path=None):
    """
    Copy/paste/tweak model's weights to transformers design.
    """
    hf_wav2vec = Wav2Vec2ForMaskedLM(Wav2Vec2Config())

    model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
        [checkpoint_path], arg_overrides={"data": dict_path}
    )
    model = model[0].eval()

    recursively_load_weights(model, hf_wav2vec)

    hf_wav2vec.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
    parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model")
    args = parser.parse_args()
    convert_wav2vec2_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.dict_path)
src/transformers/models/wav2vec2/modeling_wav2vec2.py (new executable file, 731 lines)
@@ -0,0 +1,731 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" PyTorch Wav2Vec2 model. """
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from ...activations import ACT2FN
|
||||||
|
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||||
|
from ...modeling_outputs import BaseModelOutput, MaskedLMOutput
|
||||||
|
from ...modeling_utils import PreTrainedModel
|
||||||
|
from ...utils import logging
|
||||||
|
from .configuration_wav2vec2 import Wav2Vec2Config
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
_CONFIG_FOR_DOC = "Wav2Vec2Config"
|
||||||
|
|
||||||
|
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||||
|
"facebook/wav2vec2-base-960h"
|
||||||
|
# See all Wav2Vec2 models at https://huggingface.co/models?filter=wav2vec2
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class Wav2Vec2NoLayerNormConvLayer(nn.Module):
|
||||||
|
def __init__(self, config, layer_id=0):
|
||||||
|
super().__init__()
|
||||||
|
self.in_conv_dim = config.conv_dim[layer_id] if layer_id > 0 else 1
|
||||||
|
self.out_conv_dim = config.conv_dim[layer_id]
|
||||||
|
|
||||||
|
self.conv = nn.Conv1d(
|
||||||
|
self.in_conv_dim,
|
||||||
|
self.out_conv_dim,
|
||||||
|
kernel_size=config.conv_kernel[layer_id],
|
||||||
|
stride=config.conv_stride[layer_id],
|
||||||
|
bias=config.conv_bias,
|
||||||
|
)
|
||||||
|
self.dropout = nn.Dropout(config.feat_extract_dropout)
|
||||||
|
self.activation = ACT2FN[config.feat_extract_activation]
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states = self.conv(hidden_states)
|
||||||
|
hidden_states = self.dropout(hidden_states)
|
||||||
|
hidden_states = self.activation(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Wav2Vec2LayerNormConvLayer(nn.Module):
|
||||||
|
def __init__(self, config, layer_id=0):
|
||||||
|
super().__init__()
|
||||||
|
self.in_conv_dim = config.conv_dim[layer_id] if layer_id > 0 else 1
|
||||||
|
self.out_conv_dim = config.conv_dim[layer_id]
|
||||||
|
|
||||||
|
self.conv = nn.Conv1d(
|
||||||
|
self.in_conv_dim,
|
||||||
|
self.out_conv_dim,
|
||||||
|
kernel_size=config.conv_kernel[layer_id],
|
||||||
|
stride=config.conv_stride[layer_id],
|
||||||
|
bias=config.conv_bias,
|
||||||
|
)
|
||||||
|
self.dropout = nn.Dropout(config.feat_extract_dropout)
|
||||||
|
self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
|
||||||
|
self.activation = ACT2FN[config.feat_extract_activation]
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states = self.conv(hidden_states)
|
||||||
|
hidden_states = self.dropout(hidden_states)
|
||||||
|
|
||||||
|
hidden_states = hidden_states.transpose(-2, -1)
|
||||||
|
hidden_states = self.layer_norm(hidden_states)
|
||||||
|
hidden_states = hidden_states.transpose(-2, -1)
|
||||||
|
|
||||||
|
hidden_states = self.activation(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Wav2Vec2GroupNormConvLayer(nn.Module):
|
||||||
|
def __init__(self, config, layer_id=0):
|
||||||
|
super().__init__()
|
||||||
|
self.in_conv_dim = config.conv_dim[layer_id] if layer_id > 0 else 1
|
||||||
|
self.out_conv_dim = config.conv_dim[layer_id]
|
||||||
|
|
||||||
|
self.conv = nn.Conv1d(
|
||||||
|
self.in_conv_dim,
|
||||||
|
self.out_conv_dim,
|
||||||
|
kernel_size=config.conv_kernel[layer_id],
|
||||||
|
stride=config.conv_stride[layer_id],
|
||||||
|
bias=config.conv_bias,
|
||||||
|
)
|
||||||
|
self.dropout = nn.Dropout(config.feat_extract_dropout)
|
||||||
|
self.activation = ACT2FN[config.feat_extract_activation]
|
||||||
|
|
||||||
|
self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states = self.conv(hidden_states)
|
||||||
|
hidden_states = self.dropout(hidden_states)
|
||||||
|
hidden_states = self.layer_norm(hidden_states)
|
||||||
|
hidden_states = self.activation(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Wav2Vec2PositionalConvEmbedding(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Conv1d(
|
||||||
|
config.hidden_size,
|
||||||
|
config.hidden_size,
|
||||||
|
kernel_size=config.num_conv_pos_embeddings,
|
||||||
|
padding=config.num_conv_pos_embeddings // 2,
|
||||||
|
groups=config.num_conv_pos_embedding_groups,
|
||||||
|
)
|
||||||
|
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
|
||||||
|
self.padding = Wav2Vec2SamePadLayer(config.num_conv_pos_embeddings)
|
||||||
|
self.activation = ACT2FN[config.feat_extract_activation]
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states = hidden_states.transpose(1, 2)
|
||||||
|
|
||||||
|
hidden_states = self.conv(hidden_states)
|
||||||
|
hidden_states = self.padding(hidden_states)
|
||||||
|
hidden_states = self.activation(hidden_states)
|
||||||
|
|
||||||
|
hidden_states = hidden_states.transpose(1, 2)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Wav2Vec2SamePadLayer(nn.Module):
|
||||||
|
def __init__(self, num_conv_pos_embeddings):
|
||||||
|
super().__init__()
|
||||||
|
self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
if self.num_pad_remove > 0:
|
||||||
|
hidden_states = hidden_states[:, :, : -self.num_pad_remove]
|
||||||
|
return hidden_states
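# Note: with kernel_size = num_conv_pos_embeddings and padding = num_conv_pos_embeddings // 2,
# the positional convolution produces one extra frame whenever the kernel size is even;
# Wav2Vec2SamePadLayer trims that trailing frame so the output length matches the input length.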
|
||||||
|
|
||||||
|
|
||||||
|
class Wav2Vec2FeatureExtractor(nn.Module):
|
||||||
|
"""Construct the featurs from raw audio waveform"""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if config.feat_extract_norm == "group":
|
||||||
|
conv_layers = [Wav2Vec2GroupNormConvLayer(config, layer_id=0)] + [
|
||||||
|
Wav2Vec2NoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1)
|
||||||
|
]
|
||||||
|
elif config.feat_extract_norm == "layer":
|
||||||
|
conv_layers = [
|
||||||
|
Wav2Vec2LayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
|
||||||
|
)
|
||||||
|
self.conv_layers = nn.ModuleList(conv_layers)
|
||||||
|
|
||||||
|
def forward(self, input_values):
|
||||||
|
hidden_states = input_values[:, None]
|
||||||
|
for conv_layer in self.conv_layers:
|
||||||
|
hidden_states = conv_layer(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states
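# `input_values[:, None]` turns the raw waveform of shape (batch_size, num_samples) into a
# 1-channel signal of shape (batch_size, 1, num_samples); the strided convolutions then
# downsample it, so the number of output frames shrinks roughly by the product of
# config.conv_stride.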
|
||||||
|
|
||||||
|
|
||||||
|
class Wav2Vec2FeatureProjection(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
|
||||||
|
self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
|
||||||
|
self.dropout = nn.Dropout(config.feat_extract_dropout)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states = hidden_states.transpose(1, 2)
|
||||||
|
hidden_states = self.layer_norm(hidden_states)
|
||||||
|
hidden_states = self.projection(hidden_states)
|
||||||
|
hidden_states = self.dropout(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Wav2Vec2
|
||||||
|
class Wav2Vec2Attention(nn.Module):
|
||||||
|
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
embed_dim: int,
|
||||||
|
num_heads: int,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
is_decoder: bool = False,
|
||||||
|
bias: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.dropout = dropout
|
||||||
|
self.head_dim = embed_dim // num_heads
|
||||||
|
assert (
|
||||||
|
self.head_dim * num_heads == self.embed_dim
|
||||||
|
), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})."
|
||||||
|
self.scaling = self.head_dim ** -0.5
|
||||||
|
self.is_decoder = is_decoder
|
||||||
|
|
||||||
|
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||||
|
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||||
|
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||||
|
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||||
|
|
||||||
|
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||||
|
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
key_value_states: Optional[torch.Tensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
"""Input shape: Batch x Time x Channel"""
|
||||||
|
|
||||||
|
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||||
|
# for the decoder
|
||||||
|
is_cross_attention = key_value_states is not None
|
||||||
|
bsz, tgt_len, embed_dim = hidden_states.size()
|
||||||
|
|
||||||
|
# get query proj
|
||||||
|
query_states = self.q_proj(hidden_states) * self.scaling
|
||||||
|
# get key, value proj
|
||||||
|
if is_cross_attention and past_key_value is not None:
|
||||||
|
# reuse k,v, cross_attentions
|
||||||
|
key_states = past_key_value[0]
|
||||||
|
value_states = past_key_value[1]
|
||||||
|
elif is_cross_attention:
|
||||||
|
# cross_attentions
|
||||||
|
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
||||||
|
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
||||||
|
elif past_key_value is not None:
|
||||||
|
# reuse k, v, self_attention
|
||||||
|
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||||
|
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||||
|
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||||
|
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||||
|
else:
|
||||||
|
# self_attention
|
||||||
|
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||||
|
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||||
|
|
||||||
|
if self.is_decoder:
|
||||||
|
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||||
|
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||||
|
# key/value_states (first "if" case)
|
||||||
|
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||||
|
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||||
|
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||||
|
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||||
|
past_key_value = (key_states, value_states)
|
||||||
|
|
||||||
|
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||||
|
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||||
|
key_states = key_states.view(*proj_shape)
|
||||||
|
value_states = value_states.view(*proj_shape)
|
||||||
|
|
||||||
|
src_len = key_states.size(1)
|
||||||
|
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||||
|
|
||||||
|
assert attn_weights.size() == (
|
||||||
|
bsz * self.num_heads,
|
||||||
|
tgt_len,
|
||||||
|
src_len,
|
||||||
|
), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
assert attention_mask.size() == (
|
||||||
|
bsz,
|
||||||
|
1,
|
||||||
|
tgt_len,
|
||||||
|
src_len,
|
||||||
|
), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||||
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||||
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
|
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
|
if layer_head_mask is not None:
|
||||||
|
assert layer_head_mask.size() == (
|
||||||
|
self.num_heads,
|
||||||
|
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||||||
|
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
# this operation is a bit awkward, but it's required to
|
||||||
|
# make sure that attn_weights keeps its gradient.
|
||||||
|
# In order to do so, attn_weights have to be reshaped
|
||||||
|
# twice and have to be reused in the following
|
||||||
|
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||||
|
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
||||||
|
else:
|
||||||
|
attn_weights_reshaped = None
|
||||||
|
|
||||||
|
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||||
|
|
||||||
|
attn_output = torch.bmm(attn_probs, value_states)
|
||||||
|
|
||||||
|
assert attn_output.size() == (
|
||||||
|
bsz * self.num_heads,
|
||||||
|
tgt_len,
|
||||||
|
self.head_dim,
|
||||||
|
), f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
|
||||||
|
|
||||||
|
attn_output = (
|
||||||
|
attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||||
|
.transpose(1, 2)
|
||||||
|
.reshape(bsz, tgt_len, embed_dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = self.out_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output, attn_weights_reshaped, past_key_value
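# Shape bookkeeping for the attention above: the projections are reshaped from
# (batch, seq_len, embed_dim) to (batch * num_heads, seq_len, head_dim) so that a single
# batched matmul (torch.bmm) computes all heads at once, and the context is reshaped back
# to (batch, seq_len, embed_dim) before `out_proj`.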
|
||||||
|
|
||||||
|
|
||||||
|
class Wav2Vec2FeedForward(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.intermediate_dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
|
self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||||
|
if isinstance(config.hidden_act, str):
|
||||||
|
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
||||||
|
else:
|
||||||
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
|
self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||||
|
self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
hidden_states = self.intermediate_dense(hidden_states)
|
||||||
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
|
hidden_states = self.intermediate_dropout(hidden_states)
|
||||||
|
|
||||||
|
hidden_states = self.output_dense(hidden_states)
|
||||||
|
hidden_states = self.output_dropout(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Wav2Vec2Output(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, hidden_states, input_tensor):
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Wav2Vec2EncoderLayer(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.attention = Wav2Vec2Attention(
|
||||||
|
embed_dim=config.hidden_size,
|
||||||
|
num_heads=config.num_attention_heads,
|
||||||
|
dropout=config.hidden_dropout_prob,
|
||||||
|
is_decoder=False,
|
||||||
|
)
|
||||||
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
self.feed_forward = Wav2Vec2FeedForward(config)
|
||||||
|
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, output_attentions=False):
|
||||||
|
attn_residual = hidden_states
|
||||||
|
hidden_states, attn_weights, _ = self.attention(hidden_states, output_attentions=output_attentions)
|
||||||
|
hidden_states = self.dropout(hidden_states)
|
||||||
|
hidden_states = attn_residual + hidden_states
|
||||||
|
|
||||||
|
hidden_states = self.layer_norm(hidden_states)
|
||||||
|
hidden_states = hidden_states + self.feed_forward(hidden_states)
|
||||||
|
hidden_states = self.final_layer_norm(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states, attn_weights
|
||||||
|
|
||||||
|
|
||||||
|
class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.attention = Wav2Vec2Attention(
|
||||||
|
embed_dim=config.hidden_size,
|
||||||
|
num_heads=config.num_attention_heads,
|
||||||
|
dropout=config.hidden_dropout_prob,
|
||||||
|
is_decoder=False,
|
||||||
|
)
|
||||||
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
self.feed_forward = Wav2Vec2FeedForward(config)
|
||||||
|
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, output_attentions=False):
|
||||||
|
attn_residual = hidden_states
|
||||||
|
hidden_states = self.layer_norm(hidden_states)
|
||||||
|
hidden_states, attn_weights, _ = self.attention(hidden_states, output_attentions=output_attentions)
|
||||||
|
hidden_states = self.dropout(hidden_states)
|
||||||
|
hidden_states = attn_residual + hidden_states
|
||||||
|
hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
|
||||||
|
|
||||||
|
return hidden_states, attn_weights
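# Wav2Vec2EncoderLayer applies layer norm after each residual addition (post-norm), while
# Wav2Vec2EncoderLayerStableLayerNorm normalizes before the attention and feed-forward
# blocks (pre-norm). The config flag `do_stable_layer_norm` selects between the two encoder
# variants further below.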
|
||||||
|
|
||||||
|
|
||||||
|
class Wav2Vec2Encoder(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config)
|
||||||
|
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
# IMPORTANT: the param for dropout is probably wrong
|
||||||
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
self.layers = nn.ModuleList([Wav2Vec2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
output_attentions=False,
|
||||||
|
output_hidden_states=False,
|
||||||
|
return_dict=True,
|
||||||
|
):
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attentions = () if output_attentions else None
|
||||||
|
|
||||||
|
position_embeddings = self.pos_conv_embed(hidden_states)
|
||||||
|
hidden_states = hidden_states + position_embeddings
|
||||||
|
hidden_states = self.layer_norm(hidden_states)
|
||||||
|
hidden_states = self.dropout(hidden_states)
|
||||||
|
|
||||||
|
for layer in self.layers:
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
hidden_states, attn_weights = layer(hidden_states, output_attentions=output_attentions)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attentions = all_self_attentions + (attn_weights,)
|
||||||
|
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
|
||||||
|
return BaseModelOutput(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Wav2Vec2EncoderStableLayerNorm(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config)
|
||||||
|
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
# IMPORTANT: the param for dropout is probably wrong
|
||||||
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[Wav2Vec2EncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)]
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
output_attentions=False,
|
||||||
|
output_hidden_states=False,
|
||||||
|
return_dict=True,
|
||||||
|
):
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attentions = () if output_attentions else None
|
||||||
|
|
||||||
|
position_embeddings = self.pos_conv_embed(hidden_states)
|
||||||
|
hidden_states = hidden_states + position_embeddings
|
||||||
|
hidden_states = self.dropout(hidden_states)
|
||||||
|
|
||||||
|
for layer in self.layers:
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
hidden_states, attn_weights = layer(hidden_states, output_attentions=output_attentions)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attentions = all_self_attentions + (attn_weights,)
|
||||||
|
|
||||||
|
hidden_states = self.layer_norm(hidden_states)
|
||||||
|
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
|
||||||
|
return BaseModelOutput(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
||||||
|
"""
|
||||||
|
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||||
|
models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
config_class = Wav2Vec2Config
|
||||||
|
base_model_prefix = "wav2vec2"
|
||||||
|
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||||
|
|
||||||
|
def _init_weights(self, module):
|
||||||
|
""" Initialize the weights """
|
||||||
|
if isinstance(module, nn.Linear):
|
||||||
|
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||||
|
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||||
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||||
|
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
|
||||||
|
module.bias.data.zero_()
|
||||||
|
module.weight.data.fill_(1.0)
|
||||||
|
elif isinstance(module, nn.Conv1d):
|
||||||
|
torch.nn.init.kaiming_normal_(module.weight.data)
|
||||||
|
if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
|
||||||
|
module.bias.data.zero_()
|
||||||
|
|
||||||
|
|
||||||
|
WAV_2_VEC_2_START_DOCSTRING = r"""
|
||||||
|
Wav2Vec2 was proposed in `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations
|
||||||
|
<https://arxiv.org/abs/2006.11477>`__ by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
|
||||||
|
|
||||||
|
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
|
||||||
|
methods the library implements for all its models (such as downloading or saving, etc.).
|
||||||
|
|
||||||
|
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use
|
||||||
|
it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
|
||||||
|
behavior.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
config (:class:`~transformers.Wav2Vec2Config`): Model configuration class with all the parameters of the model.
|
||||||
|
Initializing with a config file does not load the weights associated with the model, only the
|
||||||
|
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
|
||||||
|
weights.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
|
||||||
|
Args:
|
||||||
|
input_values (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||||
|
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
|
||||||
|
into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
|
||||||
|
soundfile`). To prepare the array into `input_values`, the :class:`~transformers.Wav2Vec2Tokenizer` should
|
||||||
|
be used for padding and conversion into a tensor of type `torch.FloatTensor`. See
|
||||||
|
:meth:`transformers.Wav2Vec2Tokenizer.__call__` for details.
|
||||||
|
output_attentions (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
||||||
|
tensors for more detail.
|
||||||
|
output_hidden_states (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
||||||
|
more detail.
|
||||||
|
return_dict (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
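# A minimal sketch of the input preparation described in the docstring above (the file name
# "sample.wav" and the 16 kHz sampling-rate assumption are illustrative, not part of the API):
#
#     import soundfile as sf
#     from transformers import Wav2Vec2Tokenizer
#
#     speech, sampling_rate = sf.read("sample.wav")  # 1-D float array, ideally sampled at 16 kHz
#     tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
#     input_values = tokenizer(speech, return_tensors="pt").input_values  # shape (1, num_samples)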
@add_start_docstrings(
|
||||||
|
"The bare Wav2Vec2 Model transformer outputting raw hidden-states without any specific head on top.",
|
||||||
|
WAV_2_VEC_2_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
self.feature_extractor = Wav2Vec2FeatureExtractor(config)
|
||||||
|
self.feature_projection = Wav2Vec2FeatureProjection(config)
|
||||||
|
|
||||||
|
if config.do_stable_layer_norm:
|
||||||
|
self.encoder = Wav2Vec2EncoderStableLayerNorm(config)
|
||||||
|
else:
|
||||||
|
self.encoder = Wav2Vec2Encoder(config)
|
||||||
|
|
||||||
|
self.init_weights()
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_values,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> from transformers import Wav2Vec2Tokenizer, Wav2Vec2Model
|
||||||
|
>>> from datasets import load_dataset
|
||||||
|
>>> import soundfile as sf
|
||||||
|
|
||||||
|
>>> tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
|
||||||
|
>>> model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
|
||||||
|
|
||||||
|
>>> def map_to_array(batch):
|
||||||
|
...     speech, _ = sf.read(batch["file"])
|
||||||
|
...     batch["speech"] = speech
|
||||||
|
...     return batch
|
||||||
|
|
||||||
|
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
||||||
|
>>> ds = ds.map(map_to_array)
|
||||||
|
|
||||||
|
>>> input_values = tokenizer(ds["speech"][0], return_tensors="pt").input_values # Batch size 1
|
||||||
|
>>> hidden_states = model(input_values).last_hidden_state
|
||||||
|
"""
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
hidden_states = self.feature_extractor(input_values)
|
||||||
|
hidden_states = self.feature_projection(hidden_states)
|
||||||
|
|
||||||
|
encoder_outputs = self.encoder(
|
||||||
|
hidden_states,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
hidden_states = encoder_outputs[0]
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (hidden_states,) + encoder_outputs[1:]
|
||||||
|
|
||||||
|
return BaseModelOutput(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
hidden_states=encoder_outputs.hidden_states,
|
||||||
|
attentions=encoder_outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings("""Wav2Vec2 Model with a `language modeling` head on top. """, WAV_2_VEC_2_START_DOCSTRING)
|
||||||
|
class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
self.wav2vec2 = Wav2Vec2Model(config)
|
||||||
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
|
||||||
|
|
||||||
|
self.init_weights()
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_values,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
labels=None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
|
TODO(PVP): Fill out when adding training
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> from transformers import Wav2Vec2Tokenizer, Wav2Vec2Model
|
||||||
|
>>> from datasets import load_dataset
|
||||||
|
>>> import soundfile as sf
|
||||||
|
|
||||||
|
>>> tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
|
||||||
|
>>> model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h")
|
||||||
|
|
||||||
|
>>> def map_to_array(batch):
|
||||||
|
...     speech, _ = sf.read(batch["file"])
|
||||||
|
...     batch["speech"] = speech
|
||||||
|
...     return batch
|
||||||
|
|
||||||
|
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
||||||
|
>>> ds = ds.map(map_to_array)
|
||||||
|
|
||||||
|
>>> input_values = tokenizer(ds["speech"][0], return_tensors="pt").input_values # Batch size 1
|
||||||
|
>>> logits = model(input_values).logits
|
||||||
|
|
||||||
|
>>> predicted_ids = torch.argmax(logits, dim=-1)
|
||||||
|
>>> transcription = tokenizer.decode(predicted_ids[0])
|
||||||
|
"""
|
||||||
|
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
outputs = self.wav2vec2(
|
||||||
|
input_values,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
hidden_states = self.dropout(hidden_states)
|
||||||
|
logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return output
|
||||||
|
|
||||||
|
return MaskedLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
|
282  src/transformers/models/wav2vec2/tokenization_wav2vec2.py  Normal file
@ -0,0 +1,282 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2021 The Facebook Inc. and The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Tokenization class for Wav2Vec2."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from itertools import groupby
|
||||||
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ...file_utils import add_end_docstrings
|
||||||
|
from ...tokenization_utils import PreTrainedTokenizer
|
||||||
|
from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, TensorType
|
||||||
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
VOCAB_FILES_NAMES = {
|
||||||
|
"vocab_file": "vocab.json",
|
||||||
|
"tokenizer_config_file": "tokenizer_config.json",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
WAV2VEC2_KWARGS_DOCSTRING = r"""
|
||||||
|
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
|
||||||
|
Activates and controls padding. Accepts the following values:
|
||||||
|
|
||||||
|
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
|
||||||
|
single sequence is provided).
|
||||||
|
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
||||||
|
maximum acceptable input length for the model if that argument is not provided.
|
||||||
|
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
|
||||||
|
different lengths).
|
||||||
|
max_length (:obj:`int`, `optional`):
|
||||||
|
Controls the maximum length to use by one of the truncation/padding parameters.
|
||||||
|
|
||||||
|
If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum
|
||||||
|
length is required by one of the truncation/padding parameters. If the model has no specific maximum
|
||||||
|
input length (like XLNet) truncation/padding to a maximum length will be deactivated.
|
||||||
|
pad_to_multiple_of (:obj:`int`, `optional`):
|
||||||
|
If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
|
||||||
|
the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
|
||||||
|
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
|
||||||
|
If set, will return tensors instead of list of python integers. Acceptable values are:
|
||||||
|
|
||||||
|
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
|
||||||
|
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
|
||||||
|
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
|
||||||
|
verbose (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
Whether or not to print more information and warnings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class Wav2Vec2Tokenizer(PreTrainedTokenizer):
|
||||||
|
"""
|
||||||
|
Constructs a Wav2Vec2 tokenizer.
|
||||||
|
|
||||||
|
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains some of the main methods.
|
||||||
|
Users should refer to the superclass for more information regarding such methods.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_file (:obj:`str`):
|
||||||
|
File containing the vocabulary.
|
||||||
|
bos_token (:obj:`str`, `optional`, defaults to :obj:`"<s>"`):
|
||||||
|
The beginning of sentence token.
|
||||||
|
eos_token (:obj:`str`, `optional`, defaults to :obj:`"</s>"`):
|
||||||
|
The end of sentence token.
|
||||||
|
unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
|
||||||
|
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||||
|
token instead.
|
||||||
|
pad_token (:obj:`str`, `optional`, defaults to :obj:`"<pad>"`):
|
||||||
|
The token used for padding, for example when batching sequences of different lengths.
|
||||||
|
word_delimiter_token (:obj:`str`, `optional`, defaults to :obj:`"|"`):
|
||||||
|
The token used for defining the end of a word.
|
||||||
|
**kwargs
|
||||||
|
Additional keyword arguments passed along to :class:`~transformers.PreTrainedTokenizer`
|
||||||
|
"""
|
||||||
|
|
||||||
|
vocab_files_names = VOCAB_FILES_NAMES
|
||||||
|
pretrained_vocab_files_map = {
|
||||||
|
"vocab_file": {
|
||||||
|
"facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/vocab.json"
|
||||||
|
},
|
||||||
|
"tokenizer_config_file": {
|
||||||
|
"facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/tokenizer.json",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
model_input_names = ["input_values"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_file,
|
||||||
|
bos_token="<s>",
|
||||||
|
eos_token="</s>",
|
||||||
|
unk_token="<unk>",
|
||||||
|
pad_token="<pad>",
|
||||||
|
word_delimiter_token="|",
|
||||||
|
do_lower_case=False,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
unk_token=unk_token,
|
||||||
|
bos_token=bos_token,
|
||||||
|
eos_token=eos_token,
|
||||||
|
pad_token=pad_token,
|
||||||
|
do_lower_case=do_lower_case,
|
||||||
|
word_delimiter_token=word_delimiter_token,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
self._word_delimiter_token = word_delimiter_token
|
||||||
|
self.do_lower_case = do_lower_case
|
||||||
|
|
||||||
|
with open(vocab_file, encoding="utf-8") as vocab_handle:
|
||||||
|
self.encoder = json.load(vocab_handle)
|
||||||
|
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def word_delimiter_token(self) -> str:
|
||||||
|
"""
|
||||||
|
:obj:`str`: Word delimiter token. Log an error if used while not having been set.
|
||||||
|
"""
|
||||||
|
if self._word_delimiter_token is None and self.verbose:
|
||||||
|
logger.error("Using word_delimiter_token, but it is not set yet.")
|
||||||
|
return None
|
||||||
|
return str(self._word_delimiter_token)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def word_delimiter_token_id(self) -> Optional[int]:
|
||||||
|
"""
|
||||||
|
:obj:`Optional[int]`: Id of the word_delimiter_token in the vocabulary. Returns :obj:`None` if the token has
|
||||||
|
not been set.
|
||||||
|
"""
|
||||||
|
if self._word_delimiter_token is None:
|
||||||
|
return None
|
||||||
|
return self.convert_tokens_to_ids(self.word_delimiter_token)
|
||||||
|
|
||||||
|
@word_delimiter_token.setter
|
||||||
|
def word_delimiter_token(self, value):
|
||||||
|
self._word_delimiter_token = value
|
||||||
|
|
||||||
|
@word_delimiter_token_id.setter
|
||||||
|
def word_delimiter_token_id(self, value):
|
||||||
|
self._word_delimiter_token = self.convert_tokens_to_ids(value)
|
||||||
|
|
||||||
|
@add_end_docstrings(WAV2VEC2_KWARGS_DOCSTRING)
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
|
||||||
|
padding: Union[bool, str, PaddingStrategy] = False,
|
||||||
|
max_length: Optional[int] = None,
|
||||||
|
pad_to_multiple_of: Optional[int] = None,
|
||||||
|
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||||
|
verbose: bool = True,
|
||||||
|
**kwargs
|
||||||
|
) -> BatchEncoding:
|
||||||
|
"""
|
||||||
|
Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of
|
||||||
|
sequences.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
raw_speech (:obj:`np.ndarray`, :obj:`List[float]`, :obj:`List[np.ndarray]`, :obj:`List[List[float]]`):
|
||||||
|
The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
|
||||||
|
values, a list of numpy arrays or a list of lists of float values.
|
||||||
|
"""
|
||||||
|
|
||||||
|
is_batched = bool(
|
||||||
|
isinstance(raw_speech, (list, tuple))
|
||||||
|
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
|
||||||
|
)
|
||||||
|
|
||||||
|
# make sure input is in list format
|
||||||
|
if is_batched and not isinstance(raw_speech[0], np.ndarray):
|
||||||
|
raw_speech = [np.asarray(speech) for speech in raw_speech]
|
||||||
|
elif not is_batched and not isinstance(raw_speech, np.ndarray):
|
||||||
|
raw_speech = np.asarray(raw_speech)
|
||||||
|
|
||||||
|
# always return batch
|
||||||
|
if not is_batched:
|
||||||
|
raw_speech = [raw_speech]
|
||||||
|
|
||||||
|
# convert into correct format for padding
|
||||||
|
encoded_inputs = BatchEncoding({"input_values": raw_speech})
|
||||||
|
|
||||||
|
padded_inputs = self.pad(
|
||||||
|
encoded_inputs,
|
||||||
|
padding=padding,
|
||||||
|
max_length=max_length,
|
||||||
|
pad_to_multiple_of=pad_to_multiple_of,
|
||||||
|
return_attention_mask=False,
|
||||||
|
return_tensors=return_tensors,
|
||||||
|
verbose=verbose,
|
||||||
|
)
|
||||||
|
|
||||||
|
return padded_inputs
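# A minimal usage sketch of __call__ (here `speech` stands for any 1-D float numpy array,
# e.g. a 16 kHz waveform loaded with soundfile):
#
#     tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
#     batch = tokenizer([speech, speech[: len(speech) // 2]], padding=True, return_tensors="pt")
#     batch.input_values.shape  # (2, len(speech)); the shorter sample is padded to the longest one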
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vocab_size(self) -> int:
|
||||||
|
return len(self.decoder)
|
||||||
|
|
||||||
|
def get_vocab(self) -> Dict:
|
||||||
|
return dict(self.encoder, **self.added_tokens_encoder)
|
||||||
|
|
||||||
|
def _convert_token_to_id(self, token: str) -> int:
|
||||||
|
"""Converts a token (str) in an index (integer) using the vocab."""
|
||||||
|
return self.encoder.get(token, self.encoder.get(self.unk_token))
|
||||||
|
|
||||||
|
def _convert_id_to_token(self, index: int) -> str:
|
||||||
|
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||||
|
result = self.decoder.get(index, self.unk_token)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
||||||
|
"""
|
||||||
|
Converts connectionist temporal classification (CTC) output tokens into a single string.
|
||||||
|
"""
|
||||||
|
# group same tokens into non-repeating tokens in CTC style decoding
|
||||||
|
grouped_tokens = [token_group[0] for token_group in groupby(tokens)]
|
||||||
|
|
||||||
|
# filter self.pad_token which is used as CTC-blank token
|
||||||
|
filtered_tokens = list(filter(lambda token: token != self.pad_token, grouped_tokens))
|
||||||
|
|
||||||
|
# replace delimiter token
|
||||||
|
string = "".join([" " if token == self.word_delimiter_token else token for token in filtered_tokens]).strip()
|
||||||
|
|
||||||
|
if self.do_lower_case:
|
||||||
|
string = string.lower()
|
||||||
|
return string
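# Example of the CTC-style collapse above: the token sequence
# ["H", "H", "E", "L", "<pad>", "L", "O", "|"] is first grouped into
# ["H", "E", "L", "<pad>", "L", "O", "|"]; the "<pad>" blank is then dropped and "|" is
# mapped to a space, giving "HELLO " and finally "HELLO" after strip().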
|
||||||
|
|
||||||
|
def _decode(
|
||||||
|
self,
|
||||||
|
token_ids: List[int],
|
||||||
|
skip_special_tokens: bool = False,
|
||||||
|
clean_up_tokenization_spaces: bool = True,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
A special `_decode` function is needed for Wav2Vec2Tokenizer because added tokens should be treated exactly the
|
||||||
|
same as tokens of the base vocabulary and therefore the function `convert_tokens_to_string` has to be called on
|
||||||
|
the whole token list and not individually on added tokens
|
||||||
|
"""
|
||||||
|
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for token in filtered_tokens:
|
||||||
|
if skip_special_tokens and token in self.all_special_tokens:
|
||||||
|
continue
|
||||||
|
result.append(token)
|
||||||
|
|
||||||
|
text = self.convert_tokens_to_string(result)
|
||||||
|
|
||||||
|
if clean_up_tokenization_spaces:
|
||||||
|
clean_text = self.clean_up_tokenization(text)
|
||||||
|
return clean_text
|
||||||
|
else:
|
||||||
|
return text
|
||||||
|
|
||||||
|
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||||
|
if not os.path.isdir(save_directory):
|
||||||
|
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
|
||||||
|
return
|
||||||
|
vocab_file = os.path.join(
|
||||||
|
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(vocab_file, "w", encoding="utf-8") as f:
|
||||||
|
f.write(json.dumps(self.encoder, ensure_ascii=False))
|
||||||
|
|
||||||
|
return (vocab_file,)
|
@ -31,6 +31,7 @@ from .file_utils import (
    is_pandas_available,
    is_scatter_available,
    is_sentencepiece_available,
    is_soundfile_availble,
    is_tf_available,
    is_tokenizers_available,
    is_torch_available,
@ -367,6 +368,19 @@ def require_ray(test_case):
    return test_case


def require_soundfile(test_case):
    """
    Decorator marking a test that requires soundfile

    These tests are skipped when soundfile isn't installed.

    """
    if not is_soundfile_availble():
        return unittest.skip("test requires soundfile")(test_case)
    else:
        return test_case


def get_gpu_count():
    """
    Return the number of available gpus (regardless of whether torch or tf is used)
|
@ -2196,6 +2196,36 @@ def load_tf_weights_in_transfo_xl(*args, **kwargs):
    requires_pytorch(load_tf_weights_in_transfo_xl)


WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = None


class Wav2Vec2ForMaskedLM:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)

    @classmethod
    def from_pretrained(self, *args, **kwargs):
        requires_pytorch(self)


class Wav2Vec2Model:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)

    @classmethod
    def from_pretrained(self, *args, **kwargs):
        requires_pytorch(self)


class Wav2Vec2PreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)

    @classmethod
    def from_pretrained(self, *args, **kwargs):
        requires_pytorch(self)


XLM_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
354  tests/test_modeling_wav2vec2.py  Normal file
@ -0,0 +1,354 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" Testing suite for the PyTorch Wav2Vec2 model. """
|
||||||
|
|
||||||
|
|
||||||
|
import math
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from tests.test_modeling_common import floats_tensor
|
||||||
|
from transformers import is_torch_available
|
||||||
|
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
|
||||||
|
|
||||||
|
from .test_configuration_common import ConfigTester
|
||||||
|
from .test_modeling_common import ModelTesterMixin, _config_zero_init
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers import Wav2Vec2Config, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class Wav2Vec2ModelTester:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
batch_size=13,
|
||||||
|
seq_length=1024, # speech is longer
|
||||||
|
is_training=False,
|
||||||
|
hidden_size=16,
|
||||||
|
feat_extract_norm="group",
|
||||||
|
feat_extract_dropout=0.0,
|
||||||
|
feat_extract_activation="gelu",
|
||||||
|
conv_dim=(32, 32, 32),
|
||||||
|
conv_stride=(4, 4, 4),
|
||||||
|
conv_kernel=(8, 8, 8),
|
||||||
|
conv_bias=False,
|
||||||
|
num_conv_pos_embeddings=16,
|
||||||
|
num_conv_pos_embedding_groups=2,
|
||||||
|
num_hidden_layers=4,
|
||||||
|
num_attention_heads=2,
|
||||||
|
hidden_dropout_prob=0.1, # this is most likely not correctly set yet
|
||||||
|
intermediate_size=20,
|
||||||
|
layer_norm_eps=1e-5,
|
||||||
|
hidden_act="gelu",
|
||||||
|
initializer_range=0.02,
|
||||||
|
vocab_size=32,
|
||||||
|
do_stable_layer_norm=False,
|
||||||
|
scope=None,
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.seq_length = seq_length
|
||||||
|
self.is_training = is_training
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.feat_extract_norm = feat_extract_norm
|
||||||
|
self.feat_extract_dropout = feat_extract_dropout
|
||||||
|
self.feat_extract_activation = feat_extract_activation
|
||||||
|
self.conv_dim = conv_dim
|
||||||
|
self.conv_stride = conv_stride
|
||||||
|
self.conv_kernel = conv_kernel
|
||||||
|
self.conv_bias = conv_bias
|
||||||
|
self.num_conv_pos_embeddings = num_conv_pos_embeddings
|
||||||
|
self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.hidden_dropout_prob = hidden_dropout_prob
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.layer_norm_eps = layer_norm_eps
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.do_stable_layer_norm = do_stable_layer_norm
|
||||||
|
self.scope = scope
|
||||||
|
|
||||||
|
output_seq_length = self.seq_length
|
||||||
|
for kernel, stride in zip(self.conv_kernel, self.conv_stride):
|
||||||
|
output_seq_length = (output_seq_length - (kernel - 1)) / stride
|
||||||
|
self.output_seq_length = int(math.ceil(output_seq_length))
|
||||||
|
self.encoder_seq_length = self.output_seq_length
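# With the defaults above (seq_length=1024, three conv layers with kernel 8 and stride 4),
# each layer maps L -> (L - 7) / 4, i.e. 1024 -> 254.25 -> 61.81 -> 13.70, so
# output_seq_length = ceil(13.70) = 14 frames.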
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
|
|
||||||
|
config = Wav2Vec2Config(
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
feat_extract_norm=self.feat_extract_norm,
|
||||||
|
feat_extract_dropout=self.feat_extract_dropout,
|
||||||
|
feat_extract_activation=self.feat_extract_activation,
|
||||||
|
conv_dim=self.conv_dim,
|
||||||
|
conv_stride=self.conv_stride,
|
||||||
|
conv_kernel=self.conv_kernel,
|
||||||
|
conv_bias=self.conv_bias,
|
||||||
|
num_conv_pos_embeddings=self.num_conv_pos_embeddings,
|
||||||
|
num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups,
|
||||||
|
num_hidden_layers=self.num_hidden_layers,
|
||||||
|
num_attention_heads=self.num_attention_heads,
|
||||||
|
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||||
|
intermediate_size=self.intermediate_size,
|
||||||
|
layer_norm_eps=self.layer_norm_eps,
|
||||||
|
hidden_act=self.hidden_act,
|
||||||
|
initializer_range=self.initializer_range,
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
return config, input_values
|
||||||
|
|
||||||
|
def create_and_check_model(self, config, input_values):
|
||||||
|
model = Wav2Vec2Model(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(input_values)
|
||||||
|
self.parent.assertEqual(
|
||||||
|
result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
config, input_values = self.prepare_config_and_inputs()
|
||||||
|
inputs_dict = {"input_values": input_values}
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
all_model_classes = (
|
||||||
|
(
|
||||||
|
Wav2Vec2Model,
|
||||||
|
Wav2Vec2ForMaskedLM,
|
||||||
|
)
|
||||||
|
if is_torch_available()
|
||||||
|
else ()
|
||||||
|
)
|
||||||
|
test_pruning = False
|
||||||
|
test_headmasking = False
|
||||||
|
test_torchscript = False
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.model_tester = Wav2Vec2ModelTester(self)
|
||||||
|
self.config_tester = ConfigTester(self, config_class=Wav2Vec2Config, hidden_size=37)
|
||||||
|
|
||||||
|
def test_config(self):
|
||||||
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
|
def test_model(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
# Wav2Vec2 has no inputs_embeds
|
||||||
|
def test_inputs_embeds(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# `input_ids` is renamed to `input_values`
|
||||||
|
def test_forward_signature(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Wav2Vec2 cannot resize token embeddings
|
||||||
|
# since it has no tokens embeddings
|
||||||
|
def test_resize_tokens_embeddings(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Wav2Vec2 has no inputs_embeds
|
||||||
|
# and thus the `get_input_embeddings` fn
|
||||||
|
# is not implemented
|
||||||
|
def test_model_common_attributes(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_initialization(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
configs_no_init = _config_zero_init(config)
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config=configs_no_init)
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if param.requires_grad:
|
||||||
|
if "conv.weight" in name:
|
||||||
|
self.assertTrue(
|
||||||
|
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
|
||||||
|
msg="Parameter {} of model {} seems not properly initialized".format(name, model_class),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.assertIn(
|
||||||
|
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||||
|
[0.0, 1.0],
|
||||||
|
msg="Parameter {} of model {} seems not properly initialized".format(name, model_class),
|
||||||
|
)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_model_from_pretrained(self):
|
||||||
|
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
all_model_classes = (Wav2Vec2Model, Wav2Vec2ForMaskedLM) if is_torch_available() else ()
|
||||||
|
test_pruning = False
|
||||||
|
test_headmasking = False
|
||||||
|
test_torchscript = False
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.model_tester = Wav2Vec2ModelTester(
|
||||||
|
self, conv_stride=(3, 3, 3), feat_extract_norm="layer", do_stable_layer_norm=True
|
||||||
|
)
|
||||||
|
self.config_tester = ConfigTester(self, config_class=Wav2Vec2Config, hidden_size=37)
|
||||||
|
|
||||||
|
def test_config(self):
|
||||||
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
|
def test_model(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
# Wav2Vec2 has no inputs_embeds
|
||||||
|
def test_inputs_embeds(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# `input_ids` is renamed to `input_values`
|
||||||
|
def test_forward_signature(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Wav2Vec2 cannot resize token embeddings
|
||||||
|
# since it has no tokens embeddings
|
||||||
|
def test_resize_tokens_embeddings(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Wav2Vec2 has no inputs_embeds
|
||||||
|
# and thus the `get_input_embeddings` fn
|
||||||
|
# is not implemented
|
||||||
|
def test_model_common_attributes(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_initialization(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
configs_no_init = _config_zero_init(config)
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config=configs_no_init)
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if param.requires_grad:
|
||||||
|
if "conv.weight" in name:
|
||||||
|
self.assertTrue(
|
||||||
|
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
|
||||||
|
msg="Parameter {} of model {} seems not properly initialized".format(name, model_class),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.assertIn(
|
||||||
|
((param.data.mean() * 1e9).round() / 1e9).item(),
|
||||||
|
[0.0, 1.0],
|
||||||
|
msg="Parameter {} of model {} seems not properly initialized".format(name, model_class),
|
||||||
|
        )

    @slow
    def test_model_from_pretrained(self):
        model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
        self.assertIsNotNone(model)


@require_torch
@slow
@require_datasets
@require_soundfile
class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
    def _load_datasamples(self, num_samples):
        from datasets import load_dataset

        import soundfile as sf

        # map files to raw
        def map_to_array(batch):
            speech, _ = sf.read(batch["file"])
            batch["speech"] = speech
            return batch

        ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
        ds = ds.select(range(num_samples)).map(map_to_array)

        return ds["speech"][:num_samples]

    def test_inference_masked_lm_normal(self):
        model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h")
        model.to(torch_device)
        tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)

        input_speech = self._load_datasamples(1)

        input_values = tokenizer(input_speech, return_tensors="pt").input_values.to(torch_device)

        with torch.no_grad():
            logits = model(input_values).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        predicted_trans = tokenizer.batch_decode(predicted_ids)

        EXPECTED_TRANSCRIPTIONS = ["a man said to the universe sir i exist"]
        self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)

    def test_inference_masked_lm_normal_batched(self):
        model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h")
        model.to(torch_device)
        tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)

        input_speech = self._load_datasamples(2)

        input_values = tokenizer(input_speech, return_tensors="pt", padding=True, truncation=True).input_values.to(
            torch_device
        )

        with torch.no_grad():
            logits = model(input_values).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        predicted_trans = tokenizer.batch_decode(predicted_ids)

        EXPECTED_TRANSCRIPTIONS = [
            "a man said to the universe sir i exist",
            "sweat covered brion's body trickling into the tight lowing cloth that was the only garment he wore",
        ]
        self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)

    def test_inference_masked_lm_robust_batched(self):
        model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(torch_device)
        tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)

        input_speech = self._load_datasamples(4)

        input_values = tokenizer(input_speech, return_tensors="pt", padding=True, truncation=True).input_values.to(
            torch_device
        )

        with torch.no_grad():
            logits = model(input_values).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        predicted_trans = tokenizer.batch_decode(predicted_ids)

        EXPECTED_TRANSCRIPTIONS = [
            "a man said to the universe sir i exist",
            "sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
            "the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with the thousands of spectators were trivialities not worth thinking about",
            "his instant panic was followed by a small sharp blow high on his chest",
        ]
        self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
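For reference, the integration tests above all exercise the same speech-to-text recipe: tokenize the raw waveform into input_values, run Wav2Vec2ForMaskedLM, take the argmax over the logits, and batch_decode the predicted ids. A minimal standalone sketch of that recipe follows; it assumes the facebook/wav2vec2-base-960h checkpoint, a 16 kHz mono file called sample.wav (a placeholder name, not part of this PR), and that the classes are exported at the top level of transformers as elsewhere in this change.

# Minimal sketch of the inference path exercised by the integration tests above.
# "sample.wav" is a placeholder; any 16 kHz mono speech recording works.
import soundfile as sf
import torch

from transformers import Wav2Vec2ForMaskedLM, Wav2Vec2Tokenizer

tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h")

speech, _ = sf.read("sample.wav")  # raw float waveform
input_values = tokenizer(speech, return_tensors="pt").input_values

with torch.no_grad():
    logits = model(input_values).logits

predicted_ids = torch.argmax(logits, dim=-1)
transcription = tokenizer.batch_decode(predicted_ids)[0]
print(transcription)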
tests/test_tokenization_wav2vec2.py (new file, 301 lines)
@@ -0,0 +1,301 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the Wav2Vec2 tokenizer."""
import inspect
import json
import os
import random
import shutil
import tempfile
import unittest

import numpy as np

from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES, Wav2Vec2Tokenizer


global_rng = random.Random()


def floats_list(shape, scale=1.0, rng=None, name=None):
    """Creates a random float32 tensor"""
    if rng is None:
        rng = global_rng

    values = []
    for batch_idx in range(shape[0]):
        values.append([])
        for _ in range(shape[1]):
            values[-1].append(rng.random() * scale)

    return values


class Wav2Vec2TokenizerTest(unittest.TestCase):
    tokenizer_class = Wav2Vec2Tokenizer

    def setUp(self):
        super().setUp()

        vocab = "<pad> <s> </s> <unk> | E T A O N I H S R D L U M W C F G Y P B V K ' X J Q Z".split(" ")
        vocab_tokens = dict(zip(vocab, range(len(vocab))))

        self.special_tokens_map = {"pad_token": "<pad>", "unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}

        self.tmpdirname = tempfile.mkdtemp()
        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as fp:
            fp.write(json.dumps(vocab_tokens) + "\n")

    def get_tokenizer(self, **kwargs):
        kwargs.update(self.special_tokens_map)
        return Wav2Vec2Tokenizer.from_pretrained(self.tmpdirname, **kwargs)
    def test_tokenizer_decode(self):
        # TODO(PVP) - change to facebook
        tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")

        sample_ids = [
            [11, 5, 15, tokenizer.pad_token_id, 15, 8, 98],
            [24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77],
        ]
        tokens = tokenizer.decode(sample_ids[0])
        batch_tokens = tokenizer.batch_decode(sample_ids)
        self.assertEqual(tokens, batch_tokens[0])
        self.assertEqual(batch_tokens, ["HELLO<unk>", "BYE BYE<unk>"])

    def test_tokenizer_decode_special(self):
        # TODO(PVP) - change to facebook
        tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")

        sample_ids = [
            [11, 5, 15, tokenizer.pad_token_id, 15, 8, 98],
            [24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77],
        ]
        sample_ids_2 = [
            [11, 5, 5, 5, 5, 5, 15, 15, 15, tokenizer.pad_token_id, 15, 8, 98],
            [
                24,
                22,
                5,
                tokenizer.pad_token_id,
                tokenizer.pad_token_id,
                tokenizer.pad_token_id,
                tokenizer.word_delimiter_token_id,
                24,
                22,
                5,
                77,
                tokenizer.word_delimiter_token_id,
            ],
        ]

        batch_tokens = tokenizer.batch_decode(sample_ids)
        batch_tokens_2 = tokenizer.batch_decode(sample_ids_2)
        self.assertEqual(batch_tokens, batch_tokens_2)
        self.assertEqual(batch_tokens, ["HELLO<unk>", "BYE BYE<unk>"])

    def test_tokenizer_decode_added_tokens(self):
        tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
        tokenizer.add_tokens(["!", "?"])
        tokenizer.add_special_tokens({"cls_token": "$$$"})

        sample_ids = [
            [
                11,
                5,
                15,
                tokenizer.pad_token_id,
                15,
                8,
                98,
                32,
                32,
                33,
                tokenizer.word_delimiter_token_id,
                32,
                32,
                33,
                34,
                34,
            ],
            [24, 22, 5, tokenizer.word_delimiter_token_id, 24, 22, 5, 77, tokenizer.pad_token_id, 34, 34],
        ]
        batch_tokens = tokenizer.batch_decode(sample_ids)

        self.assertEqual(batch_tokens, ["HELLO<unk>!?!?$$$", "BYE BYE<unk>$$$"])
    def test_call(self):
        # Tests that all call wrap to encode_plus and batch_encode_plus
        tokenizer = self.get_tokenizer()
        # create three inputs of length 800, 1000, and 1200
        speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
        np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]

        # Test not batched input
        encoded_sequences_1 = tokenizer(speech_inputs[0], return_tensors="np").input_values
        encoded_sequences_2 = tokenizer(np_speech_inputs[0], return_tensors="np").input_values
        self.assertTrue(np.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3))

        # Test batched
        encoded_sequences_1 = tokenizer(speech_inputs, return_tensors="np").input_values
        encoded_sequences_2 = tokenizer(np_speech_inputs, return_tensors="np").input_values
        for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
            self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))

    def test_padding(self, max_length=50):
        def _input_values_have_equal_length(input_values):
            length = len(input_values[0])
            for input_values_slice in input_values[1:]:
                if len(input_values_slice) != length:
                    return False
            return True

        def _input_values_are_equal(input_values_1, input_values_2):
            if len(input_values_1) != len(input_values_2):
                return False

            for input_values_slice_1, input_values_slice_2 in zip(input_values_1, input_values_2):
                if not np.allclose(np.asarray(input_values_slice_1), np.asarray(input_values_slice_2), atol=1e-3):
                    return False
            return True

        tokenizer = self.get_tokenizer()
        speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]

        input_values_1 = tokenizer(speech_inputs).input_values
        input_values_2 = tokenizer(speech_inputs, padding="longest").input_values
        input_values_3 = tokenizer(speech_inputs, padding="longest", max_length=1600).input_values

        self.assertFalse(_input_values_have_equal_length(input_values_1))
        self.assertTrue(_input_values_have_equal_length(input_values_2))
        self.assertTrue(_input_values_have_equal_length(input_values_3))
        self.assertTrue(_input_values_are_equal(input_values_2, input_values_3))
        self.assertTrue(len(input_values_1[0]) == 800)
        self.assertTrue(len(input_values_2[0]) == 1200)
        # padding should be 0.0
        self.assertTrue(abs(sum(np.asarray(input_values_2[0])[800:])) < 1e-3)
        self.assertTrue(abs(sum(np.asarray(input_values_2[1])[1000:])) < 1e-3)

        input_values_4 = tokenizer(speech_inputs, padding="max_length").input_values
        input_values_5 = tokenizer(speech_inputs, padding="max_length", max_length=1600).input_values

        self.assertTrue(_input_values_are_equal(input_values_1, input_values_4))
        self.assertTrue(input_values_5.shape, (3, 1600))
        # padding should be 0.0
        self.assertTrue(abs(sum(np.asarray(input_values_5[0])[800:1200])) < 1e-3)

        input_values_6 = tokenizer(speech_inputs, pad_to_multiple_of=500).input_values
        input_values_7 = tokenizer(speech_inputs, padding="longest", pad_to_multiple_of=500).input_values
        input_values_8 = tokenizer(
            speech_inputs, padding="max_length", pad_to_multiple_of=500, max_length=2400
        ).input_values

        self.assertTrue(_input_values_are_equal(input_values_1, input_values_6))
        self.assertTrue(input_values_7.shape, (3, 1500))
        self.assertTrue(input_values_8.shape, (3, 2500))
        # padding should be 0.0
        self.assertTrue(abs(sum(np.asarray(input_values_7[0])[800:])) < 1e-3)
        self.assertTrue(abs(sum(np.asarray(input_values_7[1])[1000:])) < 1e-3)
        self.assertTrue(abs(sum(np.asarray(input_values_7[2])[1200:])) < 1e-3)
        self.assertTrue(abs(sum(np.asarray(input_values_8[0])[800:])) < 1e-3)
        self.assertTrue(abs(sum(np.asarray(input_values_8[1])[1000:])) < 1e-3)
        self.assertTrue(abs(sum(np.asarray(input_values_8[2])[1200:])) < 1e-3)
    def test_save_pretrained(self):
        pretrained_name = list(self.tokenizer_class.pretrained_vocab_files_map["vocab_file"].keys())[0]
        tokenizer = self.tokenizer_class.from_pretrained(pretrained_name)
        tmpdirname2 = tempfile.mkdtemp()

        tokenizer_files = tokenizer.save_pretrained(tmpdirname2)
        self.assertSequenceEqual(
            sorted(tuple(VOCAB_FILES_NAMES.values()) + ("special_tokens_map.json", "added_tokens.json")),
            sorted(tuple(x.split("/")[-1] for x in tokenizer_files)),
        )

        # Checks everything loads correctly in the same way
        tokenizer_p = self.tokenizer_class.from_pretrained(tmpdirname2)

        # Check special tokens are set accordingly on Rust and Python
        for key in tokenizer.special_tokens_map:
            self.assertTrue(key in tokenizer_p.special_tokens_map)

        shutil.rmtree(tmpdirname2)

    def test_get_vocab(self):
        tokenizer = self.get_tokenizer()
        vocab_dict = tokenizer.get_vocab()
        self.assertIsInstance(vocab_dict, dict)
        self.assertGreaterEqual(len(tokenizer), len(vocab_dict))

        vocab = [tokenizer.convert_ids_to_tokens(i) for i in range(len(tokenizer))]
        self.assertEqual(len(vocab), len(tokenizer))

        tokenizer.add_tokens(["asdfasdfasdfasdf"])
        vocab = [tokenizer.convert_ids_to_tokens(i) for i in range(len(tokenizer))]
        self.assertEqual(len(vocab), len(tokenizer))

    def test_save_and_load_tokenizer(self):
        tokenizer = self.get_tokenizer()
        # Isolate this from the other tests because we save additional tokens/etc
        tmpdirname = tempfile.mkdtemp()

        sample_ids = [0, 1, 4, 8, 9, 0, 12]
        before_tokens = tokenizer.decode(sample_ids)
        before_vocab = tokenizer.get_vocab()
        tokenizer.save_pretrained(tmpdirname)

        after_tokenizer = tokenizer.__class__.from_pretrained(tmpdirname)
        after_tokens = after_tokenizer.decode(sample_ids)
        after_vocab = after_tokenizer.get_vocab()

        self.assertEqual(before_tokens, after_tokens)
        self.assertDictEqual(before_vocab, after_vocab)

        shutil.rmtree(tmpdirname)

        tokenizer = self.get_tokenizer()

        # Isolate this from the other tests because we save additional tokens/etc
        tmpdirname = tempfile.mkdtemp()

        before_len = len(tokenizer)
        sample_ids = [0, 1, 4, 8, 9, 0, 12, before_len, before_len + 1, before_len + 2]
        tokenizer.add_tokens(["?", "!"])
        additional_special_tokens = tokenizer.additional_special_tokens
        additional_special_tokens.append("&")
        tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
        before_tokens = tokenizer.decode(sample_ids)
        before_vocab = tokenizer.get_vocab()
        tokenizer.save_pretrained(tmpdirname)

        after_tokenizer = tokenizer.__class__.from_pretrained(tmpdirname)
        after_tokens = after_tokenizer.decode(sample_ids)
        after_vocab = after_tokenizer.get_vocab()

        self.assertEqual(before_tokens, after_tokens)
        self.assertDictEqual(before_vocab, after_vocab)

        self.assertTrue(len(tokenizer), before_len + 3)
        self.assertTrue(len(tokenizer), len(after_tokenizer))
        shutil.rmtree(tmpdirname)

    def test_tokenizer_slow_store_full_signature(self):
        signature = inspect.signature(self.tokenizer_class.__init__)
        tokenizer = self.get_tokenizer()

        for parameter_name, parameter in signature.parameters.items():
            if parameter.default != inspect.Parameter.empty:
                self.assertIn(parameter_name, tokenizer.init_kwargs)
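The setUp/get_tokenizer pair above shows that a Wav2Vec2Tokenizer can be instantiated from nothing more than a character-level vocabulary file plus the special-token kwargs. A minimal sketch of that same pattern outside the test harness follows; the temporary directory and printed smoke check are illustrative only, not part of this PR.

# Build a throwaway character vocabulary on disk and load it, mirroring setUp above.
import json
import os
import tempfile

from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES, Wav2Vec2Tokenizer

vocab = "<pad> <s> </s> <unk> | E T A O N I H S R D L U M W C F G Y P B V K ' X J Q Z".split(" ")
vocab_tokens = dict(zip(vocab, range(len(vocab))))

tmpdir = tempfile.mkdtemp()
with open(os.path.join(tmpdir, VOCAB_FILES_NAMES["vocab_file"]), "w", encoding="utf-8") as fp:
    json.dump(vocab_tokens, fp)

tokenizer = Wav2Vec2Tokenizer.from_pretrained(
    tmpdir, pad_token="<pad>", unk_token="<unk>", bos_token="<s>", eos_token="</s>"
)
# "|" serves as the word delimiter token, so it must be present in the vocabulary.
print(len(tokenizer), tokenizer.pad_token_id, tokenizer.word_delimiter_token_id)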