Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)

Hubert (#11889)

* fix_torch_device_generate_test
* remove @
* add hubert
* add first test file
* more docs
* fix bugs
* fix bug
* finish
* finish
* finish docstring
* fix
* fix
* finalize
* add to ignored
* finish
* Apply suggestions from code review
* correct naming
* finish
* fix auto config
* finish
* correct convert script
* Apply suggestions from code review
* apply suggestions lysandre & suraj

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>

This commit is contained in:
parent c3c39f7e84
commit ccca510276
@@ -231,6 +231,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
 1. **[GPT](https://huggingface.co/transformers/model_doc/gpt.html)** (from OpenAI) released with the paper [Improving Language Understanding by Generative Pre-Training](https://blog.openai.com/language-unsupervised/) by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever.
 1. **[GPT-2](https://huggingface.co/transformers/model_doc/gpt2.html)** (from OpenAI) released with the paper [Language Models are Unsupervised Multitask Learners](https://blog.openai.com/better-language-models/) by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**.
 1. **[GPT Neo](https://huggingface.co/transformers/model_doc/gpt_neo.html)** (from EleutherAI) released in the repository [EleutherAI/gpt-neo](https://github.com/EleutherAI/gpt-neo) by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy.
+1. **[Hubert](https://huggingface.co/transformers/model_doc/hubert.html)** (from Facebook) released with the paper [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447) by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, Ruslan Salakhutdinov, Abdelrahman Mohamed.
 1. **[I-BERT](https://huggingface.co/transformers/model_doc/ibert.html)** (from Berkeley) released with the paper [I-BERT: Integer-only BERT Quantization](https://arxiv.org/abs/2101.01321) by Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney, Kurt Keutzer
 1. **[LayoutLM](https://huggingface.co/transformers/model_doc/layoutlm.html)** (from Microsoft Research Asia) released with the paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou.
 1. **[LED](https://huggingface.co/transformers/model_doc/led.html)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
@@ -186,98 +186,101 @@ Supported models
     Luan, Dario Amodei** and Ilya Sutskever**.
 29. :doc:`GPT Neo <model_doc/gpt_neo>` (from EleutherAI) released in the repository `EleutherAI/gpt-neo
     <https://github.com/EleutherAI/gpt-neo>`__ by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy.
-30. :doc:`I-BERT <model_doc/ibert>` (from Berkeley) released with the paper `I-BERT: Integer-only BERT Quantization
+30. :doc:`Hubert <model_doc/hubert>` (from Facebook) released with the paper `HuBERT: Self-Supervised Speech
+    Representation Learning by Masked Prediction of Hidden Units <https://arxiv.org/abs/2106.07447>`__ by Wei-Ning Hsu,
+    Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, Ruslan Salakhutdinov, Abdelrahman Mohamed.
+31. :doc:`I-BERT <model_doc/ibert>` (from Berkeley) released with the paper `I-BERT: Integer-only BERT Quantization
     <https://arxiv.org/abs/2101.01321>`__ by Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney, Kurt Keutzer
-31. :doc:`LayoutLM <model_doc/layoutlm>` (from Microsoft Research Asia) released with the paper `LayoutLM: Pre-training
+32. :doc:`LayoutLM <model_doc/layoutlm>` (from Microsoft Research Asia) released with the paper `LayoutLM: Pre-training
     of Text and Layout for Document Image Understanding <https://arxiv.org/abs/1912.13318>`__ by Yiheng Xu, Minghao Li,
     Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou.
-32. :doc:`LED <model_doc/led>` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer
+33. :doc:`LED <model_doc/led>` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer
     <https://arxiv.org/abs/2004.05150>`__ by Iz Beltagy, Matthew E. Peters, Arman Cohan.
-33. :doc:`Longformer <model_doc/longformer>` (from AllenAI) released with the paper `Longformer: The Long-Document
+34. :doc:`Longformer <model_doc/longformer>` (from AllenAI) released with the paper `Longformer: The Long-Document
     Transformer <https://arxiv.org/abs/2004.05150>`__ by Iz Beltagy, Matthew E. Peters, Arman Cohan.
-34. :doc:`LUKE <model_doc/luke>` (from Studio Ousia) released with the paper `LUKE: Deep Contextualized Entity
+35. :doc:`LUKE <model_doc/luke>` (from Studio Ousia) released with the paper `LUKE: Deep Contextualized Entity
     Representations with Entity-aware Self-attention <https://arxiv.org/abs/2010.01057>`__ by Ikuya Yamada, Akari Asai,
     Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto.
-35. :doc:`LXMERT <model_doc/lxmert>` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality
+36. :doc:`LXMERT <model_doc/lxmert>` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality
     Encoder Representations from Transformers for Open-Domain Question Answering <https://arxiv.org/abs/1908.07490>`__
     by Hao Tan and Mohit Bansal.
-36. :doc:`M2M100 <model_doc/m2m_100>` (from Facebook) released with the paper `Beyond English-Centric Multilingual
+37. :doc:`M2M100 <model_doc/m2m_100>` (from Facebook) released with the paper `Beyond English-Centric Multilingual
     Machine Translation <https://arxiv.org/abs/2010.11125>`__ by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi
     Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman
     Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin.
-37. :doc:`MarianMT <model_doc/marian>` Machine translation models trained using `OPUS <http://opus.nlpl.eu/>`__ data by
+38. :doc:`MarianMT <model_doc/marian>` Machine translation models trained using `OPUS <http://opus.nlpl.eu/>`__ data by
     Jörg Tiedemann. The `Marian Framework <https://marian-nmt.github.io/>`__ is being developed by the Microsoft
     Translator Team.
-38. :doc:`MBart <model_doc/mbart>` (from Facebook) released with the paper `Multilingual Denoising Pre-training for
+39. :doc:`MBart <model_doc/mbart>` (from Facebook) released with the paper `Multilingual Denoising Pre-training for
     Neural Machine Translation <https://arxiv.org/abs/2001.08210>`__ by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li,
     Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer.
-39. :doc:`MBart-50 <model_doc/mbart>` (from Facebook) released with the paper `Multilingual Translation with Extensible
+40. :doc:`MBart-50 <model_doc/mbart>` (from Facebook) released with the paper `Multilingual Translation with Extensible
     Multilingual Pretraining and Finetuning <https://arxiv.org/abs/2008.00401>`__ by Yuqing Tang, Chau Tran, Xian Li,
     Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan.
-40. :doc:`Megatron-BERT <model_doc/megatron_bert>` (from NVIDIA) released with the paper `Megatron-LM: Training
+41. :doc:`Megatron-BERT <model_doc/megatron_bert>` (from NVIDIA) released with the paper `Megatron-LM: Training
     Multi-Billion Parameter Language Models Using Model Parallelism <https://arxiv.org/abs/1909.08053>`__ by Mohammad
     Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro.
-41. :doc:`Megatron-GPT2 <model_doc/megatron_gpt2>` (from NVIDIA) released with the paper `Megatron-LM: Training
+42. :doc:`Megatron-GPT2 <model_doc/megatron_gpt2>` (from NVIDIA) released with the paper `Megatron-LM: Training
     Multi-Billion Parameter Language Models Using Model Parallelism <https://arxiv.org/abs/1909.08053>`__ by Mohammad
     Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro.
-42. :doc:`MPNet <model_doc/mpnet>` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted
+43. :doc:`MPNet <model_doc/mpnet>` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted
     Pre-training for Language Understanding <https://arxiv.org/abs/2004.09297>`__ by Kaitao Song, Xu Tan, Tao Qin,
     Jianfeng Lu, Tie-Yan Liu.
-43. :doc:`MT5 <model_doc/mt5>` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained
+44. :doc:`MT5 <model_doc/mt5>` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained
     text-to-text transformer <https://arxiv.org/abs/2010.11934>`__ by Linting Xue, Noah Constant, Adam Roberts, Mihir
     Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
-44. :doc:`Pegasus <model_doc/pegasus>` (from Google) released with the paper `PEGASUS: Pre-training with Extracted
+45. :doc:`Pegasus <model_doc/pegasus>` (from Google) released with the paper `PEGASUS: Pre-training with Extracted
     Gap-sentences for Abstractive Summarization <https://arxiv.org/abs/1912.08777>`__ by Jingqing Zhang, Yao Zhao,
     Mohammad Saleh and Peter J. Liu.
-45. :doc:`ProphetNet <model_doc/prophetnet>` (from Microsoft Research) released with the paper `ProphetNet: Predicting
+46. :doc:`ProphetNet <model_doc/prophetnet>` (from Microsoft Research) released with the paper `ProphetNet: Predicting
     Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan, Weizhen Qi,
     Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
-46. :doc:`Reformer <model_doc/reformer>` (from Google Research) released with the paper `Reformer: The Efficient
+47. :doc:`Reformer <model_doc/reformer>` (from Google Research) released with the paper `Reformer: The Efficient
     Transformer <https://arxiv.org/abs/2001.04451>`__ by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
-47. :doc:`RoBERTa <model_doc/roberta>` (from Facebook), released together with the paper `RoBERTa: A Robustly Optimized BERT
+48. :doc:`RoBERTa <model_doc/roberta>` (from Facebook), released together with the paper `RoBERTa: A Robustly Optimized BERT
     Pretraining Approach <https://arxiv.org/abs/1907.11692>`__ by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar
     Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
-48. :doc:`RoFormer <model_doc/roformer>` (from ZhuiyiTechnology), released together with the paper `RoFormer:
+49. :doc:`RoFormer <model_doc/roformer>` (from ZhuiyiTechnology), released together with the paper `RoFormer:
     Enhanced Transformer with Rotary Position Embedding <https://arxiv.org/pdf/2104.09864v1.pdf>`__ by Jianlin Su and
     Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
-49. :doc:`SpeechToTextTransformer <model_doc/speech_to_text>` (from Facebook), released together with the paper
+50. :doc:`SpeechToTextTransformer <model_doc/speech_to_text>` (from Facebook), released together with the paper
     `fairseq S2T: Fast Speech-to-Text Modeling with fairseq <https://arxiv.org/abs/2010.05171>`__ by Changhan Wang, Yun
     Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino.
-50. :doc:`SqueezeBert <model_doc/squeezebert>` released with the paper `SqueezeBERT: What can computer vision teach NLP
+51. :doc:`SqueezeBert <model_doc/squeezebert>` released with the paper `SqueezeBERT: What can computer vision teach NLP
     about efficient neural networks? <https://arxiv.org/abs/2006.11316>`__ by Forrest N. Iandola, Albert E. Shaw, Ravi
     Krishna, and Kurt W. Keutzer.
-51. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a
+52. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a
     Unified Text-to-Text Transformer <https://arxiv.org/abs/1910.10683>`__ by Colin Raffel and Noam Shazeer and Adam
     Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
-52. :doc:`TAPAS <model_doc/tapas>` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via
+53. :doc:`TAPAS <model_doc/tapas>` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via
     Pre-training <https://arxiv.org/abs/2004.02349>`__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller,
     Francesco Piccinno and Julian Martin Eisenschlos.
-53. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
+54. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
     Attentive Language Models Beyond a Fixed-Length Context <https://arxiv.org/abs/1901.02860>`__ by Zihang Dai*,
     Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
-54. :doc:`Vision Transformer (ViT) <model_doc/vit>` (from Google AI) released with the paper `An Image is Worth 16x16
+55. :doc:`Vision Transformer (ViT) <model_doc/vit>` (from Google AI) released with the paper `An Image is Worth 16x16
     Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`__ by Alexey Dosovitskiy,
     Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias
     Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
-55. :doc:`VisualBERT <model_doc/visual_bert>` (from UCLA NLP) released with the paper `VisualBERT: A Simple and
+56. :doc:`VisualBERT <model_doc/visual_bert>` (from UCLA NLP) released with the paper `VisualBERT: A Simple and
     Performant Baseline for Vision and Language <https://arxiv.org/pdf/1908.03557>`__ by Liunian Harold Li, Mark
     Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang.
-56. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
+57. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
     Self-Supervised Learning of Speech Representations <https://arxiv.org/abs/2006.11477>`__ by Alexei Baevski, Henry
     Zhou, Abdelrahman Mohamed, Michael Auli.
-57. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
+58. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
     Pretraining <https://arxiv.org/abs/1901.07291>`__ by Guillaume Lample and Alexis Conneau.
-58. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
+59. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
     Predicting Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan,
     Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
-59. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
+60. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
     Cross-lingual Representation Learning at Scale <https://arxiv.org/abs/1911.02116>`__ by Alexis Conneau*, Kartikay
     Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke
     Zettlemoyer and Veselin Stoyanov.
-60. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
+61. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
     Pretraining for Language Understanding <https://arxiv.org/abs/1906.08237>`__ by Zhilin Yang*, Zihang Dai*, Yiming
     Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
-61. :doc:`XLSR-Wav2Vec2 <model_doc/xlsr_wav2vec2>` (from Facebook AI) released with the paper `Unsupervised
+62. :doc:`XLSR-Wav2Vec2 <model_doc/xlsr_wav2vec2>` (from Facebook AI) released with the paper `Unsupervised
     Cross-Lingual Representation Learning For Speech Recognition <https://arxiv.org/abs/2006.13979>`__ by Alexis
     Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
@@ -345,6 +348,8 @@ Flax), PyTorch, and/or TensorFlow.
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | GPT Neo                     | ❌             | ❌             | ✅              | ❌                 | ❌           |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
+| Hubert                      | ❌             | ❌             | ✅              | ❌                 | ❌           |
++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | I-BERT                      | ❌             | ❌             | ✅              | ❌                 | ❌           |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | LED                         | ✅             | ✅             | ✅              | ✅                 | ❌           |
@@ -534,6 +539,7 @@ Flax), PyTorch, and/or TensorFlow.
    model_doc/gpt
    model_doc/gpt2
    model_doc/gpt_neo
+   model_doc/hubert
    model_doc/pegasus
    model_doc/phobert
    model_doc/prophetnet
docs/source/model_doc/hubert.rst (new file, 65 lines)
@@ -0,0 +1,65 @@
..
    Copyright 2021 The HuggingFace Team. All rights reserved.

    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
    the License. You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
    an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
    specific language governing permissions and limitations under the License.

Hubert
-----------------------------------------------------------------------------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Hubert was proposed in `HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units
<https://arxiv.org/abs/2106.07447>`__ by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, Ruslan
Salakhutdinov, Abdelrahman Mohamed.

The abstract from the paper is the following:

*Self-supervised approaches for speech representation learning are challenged by three unique problems: (1) there are
multiple sound units in each input utterance, (2) there is no lexicon of input sound units during the pre-training
phase, and (3) sound units have variable lengths with no explicit segmentation. To deal with these three problems, we
propose the Hidden-Unit BERT (HuBERT) approach for self-supervised speech representation learning, which utilizes an
offline clustering step to provide aligned target labels for a BERT-like prediction loss. A key ingredient of our
approach is applying the prediction loss over the masked regions only, which forces the model to learn a combined
acoustic and language model over the continuous inputs. HuBERT relies primarily on the consistency of the unsupervised
clustering step rather than the intrinsic quality of the assigned cluster labels. Starting with a simple k-means
teacher of 100 clusters, and using two iterations of clustering, the HuBERT model either matches or improves upon the
state-of-the-art wav2vec 2.0 performance on the Librispeech (960h) and Libri-light (60,000h) benchmarks with 10min, 1h,
10h, 100h, and 960h fine-tuning subsets. Using a 1B parameter model, HuBERT shows up to 19% and 13% relative WER
reduction on the more challenging dev-other and test-other evaluation subsets.*

Tips:

- Hubert is a speech model that accepts a float array corresponding to the raw waveform of the speech signal.
- The Hubert model was fine-tuned using connectionist temporal classification (CTC), so the model output has to be
  decoded using :class:`~transformers.Wav2Vec2CTCTokenizer` (see the sketch below).

This model was contributed by `patrickvonplaten <https://huggingface.co/patrickvonplaten>`__.
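A minimal transcription sketch follows; the checkpoint name and the audio-loading code are illustrative assumptions,
not part of this commit::

    import torch
    import soundfile as sf
    from transformers import HubertForCTC, Wav2Vec2Processor

    # hypothetical CTC fine-tuned checkpoint; any Hubert model with a CTC head works
    processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
    model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft")

    # Hubert consumes the raw waveform as a float array, sampled at 16 kHz
    speech, sampling_rate = sf.read("sample.flac")
    inputs = processor(speech, sampling_rate=sampling_rate, return_tensors="pt")

    with torch.no_grad():
        logits = model(inputs.input_values).logits

    # greedy CTC decoding through the Wav2Vec2 tokenizer
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)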
HubertConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.HubertConfig
    :members:


HubertModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.HubertModel
    :members: forward


HubertForCTC
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.HubertForCTC
    :members: forward
@@ -201,6 +201,7 @@ _import_structure = {
     "models.gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2Tokenizer"],
     "models.gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig"],
     "models.herbert": ["HerbertTokenizer"],
+    "models.hubert": ["HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "HubertConfig"],
     "models.ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig"],
     "models.layoutlm": ["LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "LayoutLMConfig", "LayoutLMTokenizer"],
     "models.led": ["LED_PRETRAINED_CONFIG_ARCHIVE_MAP", "LEDConfig", "LEDTokenizer"],
@@ -777,6 +778,14 @@ if is_torch_available():
             "load_tf_weights_in_gpt_neo",
         ]
     )
+    _import_structure["models.hubert"].extend(
+        [
+            "HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "HubertForCTC",
+            "HubertModel",
+            "HubertPreTrainedModel",
+        ]
+    )
     _import_structure["models.ibert"].extend(
         [
             "IBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1742,6 +1751,7 @@ if TYPE_CHECKING:
     from .models.gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2Tokenizer
     from .models.gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig
     from .models.herbert import HerbertTokenizer
+    from .models.hubert import HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, HubertConfig
     from .models.ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig
     from .models.layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig, LayoutLMTokenizer
     from .models.led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig, LEDTokenizer
@@ -2230,6 +2240,12 @@ if TYPE_CHECKING:
         GPTNeoPreTrainedModel,
         load_tf_weights_in_gpt_neo,
     )
+    from .models.hubert import (
+        HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
+        HubertForCTC,
+        HubertModel,
+        HubertPreTrainedModel,
+    )
     from .models.ibert import (
         IBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
         IBertForMaskedLM,
@@ -49,6 +49,7 @@ from ..fsmt.configuration_fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTCo
 from ..funnel.configuration_funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig
 from ..gpt2.configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
 from ..gpt_neo.configuration_gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig
+from ..hubert.configuration_hubert import HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, HubertConfig
 from ..ibert.configuration_ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig
 from ..layoutlm.configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig
 from ..led.configuration_led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig
@@ -144,6 +145,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
         MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
         TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP,
         IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+        HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
     ]
     for key, value in pretrained_map.items()
 )
@@ -193,6 +195,7 @@ CONFIG_MAPPING = OrderedDict(
         ("flaubert", FlaubertConfig),
         ("fsmt", FSMTConfig),
         ("squeezebert", SqueezeBertConfig),
+        ("hubert", HubertConfig),
         ("bert", BertConfig),
         ("openai-gpt", OpenAIGPTConfig),
         ("gpt2", GPT2Config),
@@ -274,6 +277,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
         ("mt5", "mT5"),
         ("mpnet", "MPNet"),
         ("tapas", "TAPAS"),
+        ("hubert", "Hubert"),
     ]
 )
@@ -147,6 +147,7 @@ from ..funnel.modeling_funnel import (
 )
 from ..gpt2.modeling_gpt2 import GPT2ForSequenceClassification, GPT2LMHeadModel, GPT2Model
 from ..gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM, GPTNeoForSequenceClassification, GPTNeoModel
+from ..hubert.modeling_hubert import HubertModel
 from ..ibert.modeling_ibert import (
     IBertForMaskedLM,
     IBertForMultipleChoice,
@@ -327,6 +328,7 @@ from .configuration_auto import (
     FunnelConfig,
     GPT2Config,
     GPTNeoConfig,
+    HubertConfig,
     IBertConfig,
     LayoutLMConfig,
     LEDConfig,
@@ -380,6 +382,7 @@ MODEL_MAPPING = OrderedDict(
         (Speech2TextConfig, Speech2TextModel),
         (ViTConfig, ViTModel),
         (Wav2Vec2Config, Wav2Vec2Model),
+        (HubertConfig, HubertModel),
         (M2M100Config, M2M100Model),
         (ConvBertConfig, ConvBertModel),
         (LEDConfig, LEDModel),
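These registrations are what let the Auto classes resolve Hubert checkpoints without model-specific imports; a brief
sketch of the effect (the checkpoint name is taken from the configuration archive map added below)::

    from transformers import AutoConfig, AutoModel

    config = AutoConfig.from_pretrained("facebook/hubert-base-ls960")  # resolves to HubertConfig
    model = AutoModel.from_pretrained("facebook/hubert-base-ls960")    # resolves to HubertModel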
src/transformers/models/hubert/__init__.py (new file, 64 lines)
@@ -0,0 +1,64 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...file_utils import _BaseLazyModule, is_torch_available


_import_structure = {
    "configuration_hubert": ["HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "HubertConfig"],
}

if is_torch_available():
    _import_structure["modeling_hubert"] = [
        "HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
        "HubertForCTC",
        "HubertModel",
        "HubertPreTrainedModel",
    ]


if TYPE_CHECKING:
    from .configuration_hubert import HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, HubertConfig

    if is_torch_available():
        from .modeling_hubert import (
            HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
            HubertForCTC,
            HubertModel,
            HubertPreTrainedModel,
        )


else:
    import importlib
    import os
    import sys

    class _LazyModule(_BaseLazyModule):
        """
        Module class that surfaces all objects but only performs associated imports when the objects are requested.
        """

        __file__ = globals()["__file__"]
        __path__ = [os.path.dirname(__file__)]

        def _get_module(self, module_name: str):
            return importlib.import_module("." + module_name, self.__name__)

    sys.modules[__name__] = _LazyModule(__name__, _import_structure)
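One consequence of this lazy-module pattern, sketched below: importing the package only records the import structure,
and the torch-dependent ``modeling_hubert`` submodule is imported the first time one of its attributes is accessed::

    # cheap: no torch-dependent modeling code is imported yet
    from transformers.models import hubert

    # first attribute access triggers _get_module("modeling_hubert") behind the scenes
    model_cls = hubert.HubertModel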
src/transformers/models/hubert/configuration_hubert.py (new file, 222 lines)
@@ -0,0 +1,222 @@
# coding=utf-8
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Hubert model configuration """

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)

HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "facebook/hubert-base-ls960": "https://huggingface.co/facebook/hubert-base-ls960/resolve/main/config.json",
    # See all Hubert models at https://huggingface.co/models?filter=hubert
}


class HubertConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a :class:`~transformers.HubertModel`. It is used to
    instantiate a Hubert model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the Hubert
    `facebook/hubert-base-ls960 <https://huggingface.co/facebook/hubert-base-ls960>`__ architecture.

    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
    outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.


    Args:
        vocab_size (:obj:`int`, `optional`, defaults to 32):
            Vocabulary size of the Hubert model. Defines the number of different tokens that can be represented by the
            :obj:`inputs_ids` passed when calling :class:`~transformers.HubertModel`.
        hidden_size (:obj:`int`, `optional`, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (:obj:`int`, `optional`, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (:obj:`int`, `optional`, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string,
            :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
        hidden_dropout (:obj:`float`, `optional`, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_dropout (:obj:`float`, `optional`, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        initializer_range (:obj:`float`, `optional`, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-5):
            The epsilon used by the layer normalization layers.
        feat_extract_norm (:obj:`str`, `optional`, defaults to :obj:`"group"`):
            The norm to be applied to 1D convolutional layers in the feature extractor. One of :obj:`"group"` for
            group normalization of only the first 1D convolutional layer or :obj:`"layer"` for layer normalization of
            all 1D convolutional layers.
        feat_proj_dropout (:obj:`float`, `optional`, defaults to 0.1):
            The dropout probability for the output of the feature extractor's projection.
        feat_extract_activation (:obj:`str`, `optional`, defaults to :obj:`"gelu"`):
            The non-linear activation function (function or string) in the 1D convolutional layers of the feature
            extractor. If string, :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
        conv_dim (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(512, 512, 512, 512, 512, 512, 512)`):
            A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
            feature extractor. The length of `conv_dim` defines the number of 1D convolutional layers.
        conv_stride (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(5, 2, 2, 2, 2, 2, 2)`):
            A tuple of integers defining the stride of each 1D convolutional layer in the feature extractor. The
            length of `conv_stride` defines the number of convolutional layers and has to match the length of
            `conv_dim`.
        conv_kernel (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(10, 3, 3, 3, 3, 2, 2)`):
            A tuple of integers defining the kernel size of each 1D convolutional layer in the feature extractor. The
            length of `conv_kernel` defines the number of convolutional layers and has to match the length of
            `conv_dim`.
        conv_bias (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether the 1D convolutional layers have a bias.
        num_conv_pos_embeddings (:obj:`int`, `optional`, defaults to 128):
            Number of convolutional positional embeddings. Defines the kernel size of the 1D convolutional positional
            embeddings layer.
        num_conv_pos_embedding_groups (:obj:`int`, `optional`, defaults to 16):
            Number of groups of the 1D convolutional positional embeddings layer.
        do_stable_layer_norm (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to apply the `stable` layer norm architecture of the Transformer encoder. ``do_stable_layer_norm
            is True`` corresponds to applying layer norm before the attention layer, whereas ``do_stable_layer_norm is
            False`` corresponds to applying layer norm after the attention layer.
        apply_spec_augment (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether to apply *SpecAugment* data augmentation to the outputs of the feature extractor. For reference see
            `SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
            <https://arxiv.org/abs/1904.08779>`__.
        mask_time_prob (:obj:`float`, `optional`, defaults to 0.05):
            Probability of each feature vector along the time axis to be chosen as the start of the vector span to be
            masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature vectors will be
            masked along the time axis. This is only relevant if ``apply_spec_augment is True``.
        mask_time_length (:obj:`int`, `optional`, defaults to 10):
            Length of vector span along the time axis.
        mask_feature_prob (:obj:`float`, `optional`, defaults to 0.0):
            Probability of each feature vector along the feature axis to be chosen as the start of the vector span to
            be masked. Approximately ``mask_feature_prob * hidden_size // mask_feature_length`` feature vectors will
            be masked along the feature axis. This is only relevant if ``apply_spec_augment is True``.
        mask_feature_length (:obj:`int`, `optional`, defaults to 10):
            Length of vector span along the feature axis.
        ctc_loss_reduction (:obj:`str`, `optional`, defaults to :obj:`"sum"`):
            Specifies the reduction to apply to the output of ``torch.nn.CTCLoss``. Only relevant when training an
            instance of :class:`~transformers.HubertForCTC`.
        ctc_zero_infinity (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to zero infinite losses and the associated gradients of ``torch.nn.CTCLoss``. Infinite losses
            mainly occur when the inputs are too short to be aligned to the targets. Only relevant when training an
            instance of :class:`~transformers.HubertForCTC`.
        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If :obj:`True`, use gradient checkpointing to save memory at the expense of a slower backward pass.

    Example::

        >>> from transformers import HubertModel, HubertConfig

        >>> # Initializing a Hubert facebook/hubert-base-ls960 style configuration
        >>> configuration = HubertConfig()

        >>> # Initializing a model from the facebook/hubert-base-ls960 style configuration
        >>> model = HubertModel(configuration)

        >>> # Accessing the model configuration
        >>> configuration = model.config
    """
    model_type = "hubert"

    def __init__(
        self,
        vocab_size=32,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout=0.1,
        activation_dropout=0.1,
        attention_dropout=0.1,
        feat_proj_dropout=0.1,
        final_dropout=0.1,
        layerdrop=0.1,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        feat_extract_norm="group",
        feat_extract_activation="gelu",
        conv_dim=(512, 512, 512, 512, 512, 512, 512),
        conv_stride=(5, 2, 2, 2, 2, 2, 2),
        conv_kernel=(10, 3, 3, 3, 3, 2, 2),
        conv_bias=False,
        num_conv_pos_embeddings=128,
        num_conv_pos_embedding_groups=16,
        do_stable_layer_norm=False,
        apply_spec_augment=True,
        mask_time_prob=0.05,
        mask_time_length=10,
        mask_feature_prob=0.0,
        mask_feature_length=10,
        ctc_loss_reduction="sum",
        ctc_zero_infinity=False,
        gradient_checkpointing=False,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        **kwargs
    ):
        super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
        self.hidden_size = hidden_size
        self.feat_extract_norm = feat_extract_norm
        self.feat_extract_activation = feat_extract_activation
        self.conv_dim = list(conv_dim)
        self.conv_stride = list(conv_stride)
        self.conv_kernel = list(conv_kernel)
        self.conv_bias = conv_bias
        self.num_conv_pos_embeddings = num_conv_pos_embeddings
        self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
        self.num_feat_extract_layers = len(self.conv_dim)
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.num_attention_heads = num_attention_heads
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.feat_proj_dropout = feat_proj_dropout
        self.final_dropout = final_dropout
        self.layerdrop = layerdrop
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        self.vocab_size = vocab_size
        self.do_stable_layer_norm = do_stable_layer_norm
        self.gradient_checkpointing = gradient_checkpointing

        if (
            (len(self.conv_stride) != self.num_feat_extract_layers)
            or (len(self.conv_kernel) != self.num_feat_extract_layers)
            or (len(self.conv_dim) != self.num_feat_extract_layers)
        ):
            raise ValueError(
                "Configuration for convolutional layers is incorrect. "
                "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`, "
                f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride) "
                f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
            )

        # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
        self.apply_spec_augment = apply_spec_augment
        self.mask_time_prob = mask_time_prob
        self.mask_time_length = mask_time_length
        self.mask_feature_prob = mask_feature_prob
        self.mask_feature_length = mask_feature_length

        # ctc loss
        self.ctc_loss_reduction = ctc_loss_reduction
        self.ctc_zero_infinity = ctc_zero_infinity
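As a concrete illustration of the constraint enforced above, a shorter feature extractor would be configured with
three equal-length tuples (the values below are illustrative, not defaults from this commit)::

    from transformers import HubertConfig

    # three conv layers, so each tuple must have exactly three entries
    config = HubertConfig(
        conv_dim=(256, 256, 256),
        conv_stride=(5, 2, 2),
        conv_kernel=(10, 3, 3),
    )

    # mismatched lengths, e.g. conv_stride=(5, 2), would raise the ValueError above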
src/transformers/models/hubert/convert_hubert_original_pytorch_checkpoint_to_pytorch.py (new file, 244 lines)
@@ -0,0 +1,244 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Hubert checkpoint."""


import argparse
import json
import os

import fairseq
import torch
from fairseq.data import Dictionary

from transformers import (
    HubertConfig,
    HubertForCTC,
    HubertModel,
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2Processor,
    logging,
)


logging.set_verbosity_info()
logger = logging.get_logger(__name__)

MAPPING = {
    "post_extract_proj": "feature_projection.projection",
    "encoder.pos_conv.0": "encoder.pos_conv_embed.conv",
    "self_attn.k_proj": "encoder.layers.*.attention.k_proj",
    "self_attn.v_proj": "encoder.layers.*.attention.v_proj",
    "self_attn.q_proj": "encoder.layers.*.attention.q_proj",
    "self_attn.out_proj": "encoder.layers.*.attention.out_proj",
    "self_attn_layer_norm": "encoder.layers.*.layer_norm",
    "fc1": "encoder.layers.*.feed_forward.intermediate_dense",
    "fc2": "encoder.layers.*.feed_forward.output_dense",
    "final_layer_norm": "encoder.layers.*.final_layer_norm",
    "encoder.layer_norm": "encoder.layer_norm",
    "w2v_model.layer_norm": "feature_projection.layer_norm",
    "w2v_encoder.proj": "lm_head",
    "mask_emb": "masked_spec_embed",
}


def set_recursively(hf_pointer, key, value, full_name, weight_type):
    for attribute in key.split("."):
        hf_pointer = getattr(hf_pointer, attribute)

    if weight_type is not None:
        hf_shape = getattr(hf_pointer, weight_type).shape
    else:
        hf_shape = hf_pointer.shape

    assert (
        hf_shape == value.shape
    ), f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}"

    if weight_type == "weight":
        hf_pointer.weight.data = value
    elif weight_type == "weight_g":
        hf_pointer.weight_g.data = value
    elif weight_type == "weight_v":
        hf_pointer.weight_v.data = value
    elif weight_type == "bias":
        hf_pointer.bias.data = value
    else:
        hf_pointer.data = value

    logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.")


def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
    unused_weights = []
    fairseq_dict = fairseq_model.state_dict()

    feature_extractor = hf_model.hubert.feature_extractor if is_finetuned else hf_model.feature_extractor

    for name, value in fairseq_dict.items():
        is_used = False
        if "conv_layers" in name:
            load_conv_layer(
                name,
                value,
                feature_extractor,
                unused_weights,
                hf_model.config.feat_extract_norm == "group",
            )
            is_used = True
        else:
            for key, mapped_key in MAPPING.items():
                mapped_key = "hubert." + mapped_key if (is_finetuned and mapped_key != "lm_head") else mapped_key

                if key in name or (key.split("w2v_model.")[-1] == name.split(".")[0] and not is_finetuned):
                    is_used = True
                    if "*" in mapped_key:
                        layer_index = name.split(key)[0].split(".")[-2]
                        mapped_key = mapped_key.replace("*", layer_index)
                    if "weight_g" in name:
                        weight_type = "weight_g"
                    elif "weight_v" in name:
                        weight_type = "weight_v"
                    elif "weight" in name:
                        weight_type = "weight"
                    elif "bias" in name:
                        weight_type = "bias"
                    else:
                        weight_type = None
                    set_recursively(hf_model, mapped_key, value, name, weight_type)
                continue
        if not is_used:
            unused_weights.append(name)

    logger.warning(f"Unused weights: {unused_weights}")


def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm):
    name = full_name.split("conv_layers.")[-1]
    items = name.split(".")
    layer_id = int(items[0])
    type_id = int(items[1])

    if type_id == 0:
        if "bias" in name:
            assert (
                value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape
            ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
            feature_extractor.conv_layers[layer_id].conv.bias.data = value
            logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
        elif "weight" in name:
            assert (
                value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape
            ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
            feature_extractor.conv_layers[layer_id].conv.weight.data = value
            logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
    elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
        if "bias" in name:
            assert (
                value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape
            ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape} was found."
            feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
            logger.info(f"Feat extract layer norm bias of layer {layer_id} was initialized from {full_name}.")
        elif "weight" in name:
            assert (
                value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape
            ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape} was found."
            feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
            logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
    else:
        unused_weights.append(full_name)


@torch.no_grad()
def convert_hubert_checkpoint(
    checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True
):
    """
    Copy/paste/tweak model's weights to transformers design.
    """
    if config_path is not None:
        config = HubertConfig.from_pretrained(config_path)
    else:
        config = HubertConfig()

    if is_finetuned:
        if dict_path:
            target_dict = Dictionary.load(dict_path)

            # important: change bos & pad token id since the CTC symbol is <pad> and
            # not <s> as in fairseq
            config.bos_token_id = target_dict.pad_index
            config.pad_token_id = target_dict.bos_index
            config.eos_token_id = target_dict.eos_index
            config.vocab_size = len(target_dict.symbols)
            vocab_path = os.path.join(pytorch_dump_folder_path, "vocab.json")
            if not os.path.isdir(pytorch_dump_folder_path):
                logger.error("--pytorch_dump_folder_path ({}) should be a directory".format(pytorch_dump_folder_path))
                return
            os.makedirs(pytorch_dump_folder_path, exist_ok=True)
            with open(vocab_path, "w", encoding="utf-8") as vocab_handle:
                json.dump(target_dict.indices, vocab_handle)
            tokenizer = Wav2Vec2CTCTokenizer(
                vocab_path,
                unk_token=target_dict.unk_word,
                pad_token=target_dict.pad_word,
                bos_token=target_dict.bos_word,
                eos_token=target_dict.eos_word,
                word_delimiter_token="|",
                do_lower_case=False,
            )
            return_attention_mask = True if config.feat_extract_norm == "layer" else False
            feature_extractor = Wav2Vec2FeatureExtractor(
                feature_size=1,
                sampling_rate=16000,
                padding_value=0,
                do_normalize=True,
                return_attention_mask=return_attention_mask,
            )
            processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
            processor.save_pretrained(pytorch_dump_folder_path)

        hf_wav2vec = HubertForCTC(config)
    else:
        hf_wav2vec = HubertModel(config)

    if is_finetuned:
        model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])}
        )
    else:
        model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path])

    model = model[0].eval()

    recursively_load_weights(model, hf_wav2vec, is_finetuned)

    hf_wav2vec.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
    parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model")
    parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
    parser.add_argument(
        "--not_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not"
    )
    args = parser.parse_args()
    convert_hubert_checkpoint(
        args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.dict_path, not args.not_finetuned
    )
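For reference, a hypothetical invocation of the conversion script (all paths are placeholders)::

    python convert_hubert_original_pytorch_checkpoint_to_pytorch.py \
        --checkpoint_path /path/to/hubert_fairseq.pt \
        --dict_path /path/to/dict.ltr.txt \
        --pytorch_dump_folder_path ./hubert-converted

    # pass --not_finetuned instead of --dict_path to convert a pretraining-only
    # checkpoint without a CTC head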
src/transformers/models/hubert/modeling_hubert.py (new executable file, 1065 lines)
(File diff suppressed because it is too large.)
@@ -1709,6 +1709,32 @@ def load_tf_weights_in_gpt_neo(*args, **kwargs):
     requires_backends(load_tf_weights_in_gpt_neo, ["torch"])


+HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class HubertForCTC:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class HubertModel:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
+class HubertPreTrainedModel:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 IBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -263,6 +263,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
         ("Speech2TextConfig", "Speech2TextModel"),
         ("ViTConfig", "ViTModel"),
         ("Wav2Vec2Config", "Wav2Vec2Model"),
+        ("HubertConfig", "HubertModel"),
         ("M2M100Config", "M2M100Model"),
         ("ConvBertConfig", "ConvBertModel"),
         ("LEDConfig", "LEDModel"),
tests/test_modeling_hubert.py (new file, 553 lines)
@@ -0,0 +1,553 @@
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch Hubert model. """
|
||||
|
||||
|
||||
import math
|
||||
import unittest
|
||||
|
||||
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_common import ModelTesterMixin, _config_zero_init
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import HubertConfig, HubertForCTC, HubertModel, Wav2Vec2Processor
|
||||
from transformers.models.hubert.modeling_hubert import _compute_mask_indices
|
||||
|
||||
|
||||
class HubertModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=1024,  # speech is longer
        is_training=False,
        hidden_size=16,
        feat_extract_norm="group",
        feat_extract_dropout=0.0,
        feat_extract_activation="gelu",
        conv_dim=(32, 32, 32),
        conv_stride=(4, 4, 4),
        conv_kernel=(8, 8, 8),
        conv_bias=False,
        num_conv_pos_embeddings=16,
        num_conv_pos_embedding_groups=2,
        num_hidden_layers=4,
        num_attention_heads=2,
        hidden_dropout_prob=0.1,  # this is most likely not correctly set yet
        intermediate_size=20,
        layer_norm_eps=1e-5,
        hidden_act="gelu",
        initializer_range=0.02,
        vocab_size=32,
        do_stable_layer_norm=False,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.hidden_size = hidden_size
        self.feat_extract_norm = feat_extract_norm
        self.feat_extract_dropout = feat_extract_dropout
        self.feat_extract_activation = feat_extract_activation
        self.conv_dim = conv_dim
        self.conv_stride = conv_stride
        self.conv_kernel = conv_kernel
        self.conv_bias = conv_bias
        self.num_conv_pos_embeddings = num_conv_pos_embeddings
        self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_dropout_prob = hidden_dropout_prob
        self.intermediate_size = intermediate_size
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.vocab_size = vocab_size
        self.do_stable_layer_norm = do_stable_layer_norm
        self.scope = scope

        output_seq_length = self.seq_length
        for kernel, stride in zip(self.conv_kernel, self.conv_stride):
            output_seq_length = (output_seq_length - (kernel - 1)) / stride
        self.output_seq_length = int(math.ceil(output_seq_length))
        self.encoder_seq_length = self.output_seq_length
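        # worked example with the defaults above: each of the three conv layers
        # (kernel 8, stride 4) maps L -> (L - 7) / 4, so 1024 samples shrink to
        # 254.25 -> 61.8125 -> 13.703125 and output_seq_length = ceil(13.703125) = 14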

    def prepare_config_and_inputs(self):
        input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
        attention_mask = random_attention_mask([self.batch_size, self.seq_length])
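        # floats_tensor scales its random values by its second argument, so the synthetic
        # "waveform" here simply takes float values in [0, vocab_size)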

        config = HubertConfig(
            hidden_size=self.hidden_size,
            feat_extract_norm=self.feat_extract_norm,
            feat_extract_dropout=self.feat_extract_dropout,
            feat_extract_activation=self.feat_extract_activation,
            conv_dim=self.conv_dim,
            conv_stride=self.conv_stride,
            conv_kernel=self.conv_kernel,
            conv_bias=self.conv_bias,
            num_conv_pos_embeddings=self.num_conv_pos_embeddings,
            num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            hidden_dropout_prob=self.hidden_dropout_prob,
            intermediate_size=self.intermediate_size,
            layer_norm_eps=self.layer_norm_eps,
            hidden_act=self.hidden_act,
            initializer_range=self.initializer_range,
            vocab_size=self.vocab_size,
        )

        return config, input_values, attention_mask

    def create_and_check_model(self, config, input_values, attention_mask):
        model = HubertModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_values, attention_mask=attention_mask)
        self.parent.assertEqual(
            result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
        )

    def create_and_check_batch_inference(self, config, input_values, *args):
        # test does not pass for models making use of `group_norm`
        # check: https://github.com/pytorch/fairseq/issues/3227
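        # (group norm in the feature encoder also normalizes over the time axis, so
        # zero-padded frames shift its statistics and batched vs. per-example outputs
        # can diverge)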
        model = HubertModel(config=config)
        model.to(torch_device)
        model.eval()

        input_values = input_values[:3]
        attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.bool)

        input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]

        # pad input
        for i in range(len(input_lengths)):
            input_values[i, input_lengths[i] :] = 0.0
            attention_mask[i, input_lengths[i] :] = 0.0

        batch_outputs = model(input_values, attention_mask=attention_mask).last_hidden_state

        for i in range(input_values.shape[0]):
            input_slice = input_values[i : i + 1, : input_lengths[i]]
            output = model(input_slice).last_hidden_state

            batch_output = batch_outputs[i : i + 1, : output.shape[1]]
            self.parent.assertTrue(torch.allclose(output, batch_output, atol=1e-3))

    def check_ctc_loss(self, config, input_values, *args):
        model = HubertForCTC(config=config)
        model.to(torch_device)

        # make sure that dropout is disabled
        model.eval()

        input_values = input_values[:3]
        attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long)

        input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
        max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
        labels = ids_tensor((input_values.shape[0], min(max_length_labels) - 1), model.config.vocab_size)

        # pad input
        for i in range(len(input_lengths)):
            input_values[i, input_lengths[i] :] = 0.0
            attention_mask[i, input_lengths[i] :] = 0

        model.config.ctc_loss_reduction = "sum"
        sum_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss

        model.config.ctc_loss_reduction = "mean"
        mean_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss

        self.parent.assertTrue(abs(labels.shape[0] * labels.shape[1] * mean_loss.item() - sum_loss.item()) < 1e-3)
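        # with "mean" reduction the loss is normalized by the target lengths, and every
        # example here carries the same number of labels, so multiplying the mean loss
        # by batch_size * target_length should recover the "sum" loss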

    def check_training(self, config, input_values, *args):
        config.ctc_zero_infinity = True
        model = HubertForCTC(config=config)
        model.to(torch_device)
        model.train()

        # freeze feature encoder
        model.freeze_feature_extractor()

        input_values = input_values[:3]

        input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
        max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
        labels = ids_tensor((input_values.shape[0], max(max_length_labels) - 2), model.config.vocab_size)

        # pad input
        for i in range(len(input_lengths)):
            input_values[i, input_lengths[i] :] = 0.0

            if max_length_labels[i] < labels.shape[-1]:
                # it's important that we make sure that target lengths are at least
                # one shorter than logit lengths to prevent -inf
                labels[i, max_length_labels[i] - 1 :] = -100
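                # labels set to -100 are treated as padding and dropped before the CTC
                # loss gathers its targets, so this effectively truncates the sequence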

        loss = model(input_values, labels=labels).loss
        self.parent.assertFalse(torch.isinf(loss).item())

        loss.backward()

    def prepare_config_and_inputs_for_common(self):
        config, input_values, attention_mask = self.prepare_config_and_inputs()
        inputs_dict = {"input_values": input_values, "attention_mask": attention_mask}
        return config, inputs_dict


@require_torch
class HubertModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (HubertForCTC, HubertModel) if is_torch_available() else ()
    test_pruning = False
    test_headmasking = False
    test_torchscript = False

    def setUp(self):
        self.model_tester = HubertModelTester(self)
        self.config_tester = ConfigTester(self, config_class=HubertConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_ctc_loss_inference(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.check_ctc_loss(*config_and_inputs)

    def test_train(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.check_training(*config_and_inputs)

    # Hubert has no inputs_embeds
    def test_inputs_embeds(self):
        pass

    # `input_ids` is renamed to `input_values`
    def test_forward_signature(self):
        pass

    # Hubert cannot resize token embeddings
    # since it has no token embeddings
    def test_resize_tokens_embeddings(self):
        pass

    # Hubert has no inputs_embeds
    # and thus the `get_input_embeddings` fn
    # is not implemented
    def test_model_common_attributes(self):
        pass

    def test_retain_grad_hidden_states_attentions(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.output_hidden_states = True
        config.output_attentions = True

        # no need to test all models as different heads yield the same functionality
        model_class = self.all_model_classes[0]
        model = model_class(config)
        model.to(torch_device)

        # set layer drop to 0
        model.config.layerdrop = 0.0
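        # a nonzero layerdrop would let whole encoder layers be skipped stochastically,
        # which could detach the retained tensors from the loss; 0.0 keeps the graph
        # deterministic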

        input_values = inputs_dict["input_values"]

        input_lengths = torch.tensor(
            [input_values.shape[1] for _ in range(input_values.shape[0])], dtype=torch.long, device=torch_device
        )
        output_lengths = model._get_feat_extract_output_lengths(input_lengths)

        labels = ids_tensor((input_values.shape[0], output_lengths[0] - 2), self.model_tester.vocab_size)
        inputs_dict["attention_mask"] = torch.ones_like(inputs_dict["attention_mask"])
        inputs_dict["labels"] = labels

        outputs = model(**inputs_dict)

        output = outputs[0]

        # Encoder-/Decoder-only models
        hidden_states = outputs.hidden_states[0]
        attentions = outputs.attentions[0]

        hidden_states.retain_grad()
        attentions.retain_grad()

        output.flatten()[0].backward(retain_graph=True)

        self.assertIsNotNone(hidden_states.grad)
        self.assertIsNotNone(attentions.grad)

    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                uniform_init_parms = [
                    "conv.weight",
                    "masked_spec_embed",
                    "quantizer.weight_proj.weight",
                ]
                if param.requires_grad:
                    if any([x in name for x in uniform_init_parms]):
                        self.assertTrue(
                            -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )
                    else:
                        self.assertIn(
                            ((param.data.mean() * 1e9).round() / 1e9).item(),
                            [0.0, 1.0],
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )

    # overwrite from test_modeling_common
    def _mock_init_weights(self, module):
        if hasattr(module, "weight") and module.weight is not None:
            module.weight.data.fill_(3)
        if hasattr(module, "weight_g") and module.weight_g is not None:
            module.weight_g.data.fill_(3)
        if hasattr(module, "weight_v") and module.weight_v is not None:
            module.weight_v.data.fill_(3)
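        # `weight_g` / `weight_v` are the magnitude and direction tensors created by
        # torch.nn.utils.weight_norm, which the positional convolution embedding uses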
        if hasattr(module, "bias") and module.bias is not None:
            module.bias.data.fill_(3)
        if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
            module.masked_spec_embed.data.fill_(3)

    @slow
    def test_model_from_pretrained(self):
        model = HubertModel.from_pretrained("facebook/hubert-base-ls960")
        self.assertIsNotNone(model)


@require_torch
class HubertRobustModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (HubertForCTC, HubertModel) if is_torch_available() else ()
    test_pruning = False
    test_headmasking = False
    test_torchscript = False

    def setUp(self):
        self.model_tester = HubertModelTester(
            self, conv_stride=(3, 3, 3), feat_extract_norm="layer", do_stable_layer_norm=True
        )
        self.config_tester = ConfigTester(self, config_class=HubertConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_batched_inference(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_batch_inference(*config_and_inputs)

    def test_ctc_loss_inference(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.check_ctc_loss(*config_and_inputs)

    def test_train(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.check_training(*config_and_inputs)

    # Hubert has no inputs_embeds
    def test_inputs_embeds(self):
        pass

    # `input_ids` is renamed to `input_values`
    def test_forward_signature(self):
        pass

    # Hubert cannot resize token embeddings
    # since it has no token embeddings
    def test_resize_tokens_embeddings(self):
        pass

    # Hubert has no inputs_embeds
    # and thus the `get_input_embeddings` fn
    # is not implemented
    def test_model_common_attributes(self):
        pass

    def test_retain_grad_hidden_states_attentions(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.output_hidden_states = True
        config.output_attentions = True

        # no need to test all models as different heads yield the same functionality
        model_class = self.all_model_classes[0]
        model = model_class(config)
        model.to(torch_device)

        # set layer drop to 0
        model.config.layerdrop = 0.0

        input_values = inputs_dict["input_values"]

        input_lengths = torch.tensor(
            [input_values.shape[1] for _ in range(input_values.shape[0])], dtype=torch.long, device=torch_device
        )
        output_lengths = model._get_feat_extract_output_lengths(input_lengths)

        labels = ids_tensor((input_values.shape[0], output_lengths[0] - 2), self.model_tester.vocab_size)
        inputs_dict["attention_mask"] = torch.ones_like(inputs_dict["attention_mask"])
        inputs_dict["labels"] = labels

        outputs = model(**inputs_dict)

        output = outputs[0]

        # Encoder-/Decoder-only models
        hidden_states = outputs.hidden_states[0]
        attentions = outputs.attentions[0]

        hidden_states.retain_grad()
        attentions.retain_grad()

        output.flatten()[0].backward(retain_graph=True)

        self.assertIsNotNone(hidden_states.grad)
        self.assertIsNotNone(attentions.grad)

    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                uniform_init_parms = [
                    "conv.weight",
                    "masked_spec_embed",
                    "quantizer.weight_proj.weight",
                ]
                if param.requires_grad:
                    if any([x in name for x in uniform_init_parms]):
                        self.assertTrue(
                            -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )
                    else:
                        self.assertIn(
                            ((param.data.mean() * 1e9).round() / 1e9).item(),
                            [0.0, 1.0],
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )

    # overwrite from test_modeling_common
    def _mock_init_weights(self, module):
        if hasattr(module, "weight") and module.weight is not None:
            module.weight.data.fill_(3)
        if hasattr(module, "weight_g") and module.weight_g is not None:
            module.weight_g.data.fill_(3)
        if hasattr(module, "weight_v") and module.weight_v is not None:
            module.weight_v.data.fill_(3)
        if hasattr(module, "bias") and module.bias is not None:
            module.bias.data.fill_(3)
        if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
            module.masked_spec_embed.data.fill_(3)

    @slow
    def test_model_from_pretrained(self):
        model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
        self.assertIsNotNone(model)


@require_torch
class HubertUtilsTest(unittest.TestCase):
    def test_compute_mask_indices(self):
        batch_size = 4
        sequence_length = 60
        mask_prob = 0.5
        mask_length = 1

        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)

        self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])
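        # with mask_length=1 spans cannot overlap, so every row masks exactly
        # mask_prob * sequence_length = 0.5 * 60 = 30 indices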

    def test_compute_mask_indices_overlap(self):
        batch_size = 4
        sequence_length = 80
        mask_prob = 0.5
        mask_length = 4

        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)

        # because of overlap, the mask sums don't have to add up exactly to
        # `mask_prob * sequence_length`, but they have to be smaller or equal
        for batch_sum in mask.sum(axis=-1):
            self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)


@require_torch
@require_datasets
@require_soundfile
@slow
class HubertModelIntegrationTest(unittest.TestCase):
    def _load_datasamples(self, num_samples):
        from datasets import load_dataset

        import soundfile as sf

        ids = [f"1272-141231-000{i}" for i in range(num_samples)]

        # map audio files to raw waveform arrays
        def map_to_array(batch):
            speech, _ = sf.read(batch["file"])
            batch["speech"] = speech
            return batch

        ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")

        ds = ds.filter(lambda x: x["id"] in ids).sort("id").map(map_to_array)
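        # the ids match utterances of LibriSpeech speaker 1272, chapter 141231; sorting
        # by id keeps the sample order deterministic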

        return ds["speech"][:num_samples]

    def test_inference_ctc_batched(self):
        model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").to(torch_device)
        processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft", do_lower_case=True)

        input_speech = self._load_datasamples(2)

        inputs = processor(input_speech, return_tensors="pt", padding=True, truncation=True)

        input_values = inputs.input_values.to(torch_device)
        attention_mask = inputs.attention_mask.to(torch_device)

        with torch.no_grad():
            logits = model(input_values, attention_mask=attention_mask).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        predicted_trans = processor.batch_decode(predicted_ids)

        EXPECTED_TRANSCRIPTIONS = [
            "a man said to the universe sir i exist",
            "sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
        ]
        self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
@ -119,6 +119,7 @@ IGNORE_NON_AUTO_CONFIGURED = [
    "TFRagSequenceForGeneration",
    "TFRagTokenForGeneration",
    "Wav2Vec2ForCTC",
    "HubertForCTC",
    "XLMForQuestionAnswering",
    "XLNetForQuestionAnswering",
    "SeparableConv1D",