diff --git a/README.md b/README.md index bb51eb0c0a5..5d8e2340a40 100644 --- a/README.md +++ b/README.md @@ -231,6 +231,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. 1. **[GPT](https://huggingface.co/transformers/model_doc/gpt.html)** (from OpenAI) released with the paper [Improving Language Understanding by Generative Pre-Training](https://blog.openai.com/language-unsupervised/) by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever. 1. **[GPT-2](https://huggingface.co/transformers/model_doc/gpt2.html)** (from OpenAI) released with the paper [Language Models are Unsupervised Multitask Learners](https://blog.openai.com/better-language-models/) by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**. 1. **[GPT Neo](https://huggingface.co/transformers/model_doc/gpt_neo.html)** (from EleutherAI) released in the repository [EleutherAI/gpt-neo](https://github.com/EleutherAI/gpt-neo) by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy. +1. **[Hubert](https://huggingface.co/transformers/model_doc/hubert.html)** (from Facebook) released with the paper [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447) by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, Ruslan Salakhutdinov, Abdelrahman Mohamed. 1. **[I-BERT](https://huggingface.co/transformers/model_doc/ibert.html)** (from Berkeley) released with the paper [I-BERT: Integer-only BERT Quantization](https://arxiv.org/abs/2101.01321) by Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney, Kurt Keutzer 1. **[LayoutLM](https://huggingface.co/transformers/model_doc/layoutlm.html)** (from Microsoft Research Asia) released with the paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou. 1. **[LED](https://huggingface.co/transformers/model_doc/led.html)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan. diff --git a/docs/source/index.rst b/docs/source/index.rst index b95e48340e9..678c896fd36 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -186,98 +186,101 @@ Supported models Luan, Dario Amodei** and Ilya Sutskever**. 29. :doc:`GPT Neo ` (from EleutherAI) released in the repository `EleutherAI/gpt-neo `__ by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy. -30. :doc:`I-BERT ` (from Berkeley) released with the paper `I-BERT: Integer-only BERT Quantization +30. :doc:`Hubert ` (from Facebook) released with the paper `HuBERT: Self-Supervised Speech + Representation Learning by Masked Prediction of Hidden Units `__ by Wei-Ning Hsu, + Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, Ruslan Salakhutdinov, Abdelrahman Mohamed. +31. :doc:`I-BERT ` (from Berkeley) released with the paper `I-BERT: Integer-only BERT Quantization `__ by Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney, Kurt Keutzer -31. :doc:`LayoutLM ` (from Microsoft Research Asia) released with the paper `LayoutLM: Pre-training +32. :doc:`LayoutLM ` (from Microsoft Research Asia) released with the paper `LayoutLM: Pre-training of Text and Layout for Document Image Understanding `__ by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou. -32. 
:doc:`LED ` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer +33. :doc:`LED ` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer `__ by Iz Beltagy, Matthew E. Peters, Arman Cohan. -33. :doc:`Longformer ` (from AllenAI) released with the paper `Longformer: The Long-Document +34. :doc:`Longformer ` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer `__ by Iz Beltagy, Matthew E. Peters, Arman Cohan. -34. :doc:`LUKE ` (from Studio Ousia) released with the paper `LUKE: Deep Contextualized Entity +35. :doc:`LUKE ` (from Studio Ousia) released with the paper `LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention `__ by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto. -35. :doc:`LXMERT ` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality +36. :doc:`LXMERT ` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering `__ by Hao Tan and Mohit Bansal. -36. :doc:`M2M100 ` (from Facebook) released with the paper `Beyond English-Centric Multilingual +37. :doc:`M2M100 ` (from Facebook) released with the paper `Beyond English-Centric Multilingual Machine Translation `__ by by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. -37. :doc:`MarianMT ` Machine translation models trained using `OPUS `__ data by +38. :doc:`MarianMT ` Machine translation models trained using `OPUS `__ data by Jörg Tiedemann. The `Marian Framework `__ is being developed by the Microsoft Translator Team. -38. :doc:`MBart ` (from Facebook) released with the paper `Multilingual Denoising Pre-training for +39. :doc:`MBart ` (from Facebook) released with the paper `Multilingual Denoising Pre-training for Neural Machine Translation `__ by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. -39. :doc:`MBart-50 ` (from Facebook) released with the paper `Multilingual Translation with Extensible +40. :doc:`MBart-50 ` (from Facebook) released with the paper `Multilingual Translation with Extensible Multilingual Pretraining and Finetuning `__ by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan. -40. :doc:`Megatron-BERT ` (from NVIDIA) released with the paper `Megatron-LM: Training +41. :doc:`Megatron-BERT ` (from NVIDIA) released with the paper `Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism `__ by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro. -41. :doc:`Megatron-GPT2 ` (from NVIDIA) released with the paper `Megatron-LM: Training +42. :doc:`Megatron-GPT2 ` (from NVIDIA) released with the paper `Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism `__ by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro. -42. :doc:`MPNet ` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted +43. 
:doc:`MPNet ` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted Pre-training for Language Understanding `__ by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu. -43. :doc:`MT5 ` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained +44. :doc:`MT5 ` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained text-to-text transformer `__ by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel. -44. :doc:`Pegasus ` (from Google) released with the paper `PEGASUS: Pre-training with Extracted +45. :doc:`Pegasus ` (from Google) released with the paper `PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization `__> by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. -45. :doc:`ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting +46. :doc:`ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training `__ by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. -46. :doc:`Reformer ` (from Google Research) released with the paper `Reformer: The Efficient +47. :doc:`Reformer ` (from Google Research) released with the paper `Reformer: The Efficient Transformer `__ by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. -47. :doc:`RoBERTa ` (from Facebook), released together with the paper a `Robustly Optimized BERT +48. :doc:`RoBERTa ` (from Facebook), released together with the paper a `Robustly Optimized BERT Pretraining Approach `__ by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. -48. :doc:`RoFormer ` (from ZhuiyiTechnology), released together with the paper a `RoFormer: +49. :doc:`RoFormer ` (from ZhuiyiTechnology), released together with the paper a `RoFormer: Enhanced Transformer with Rotary Position Embedding `__ by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu. -49. :doc:`SpeechToTextTransformer ` (from Facebook), released together with the paper +50. :doc:`SpeechToTextTransformer ` (from Facebook), released together with the paper `fairseq S2T: Fast Speech-to-Text Modeling with fairseq `__ by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino. -50. :doc:`SqueezeBert ` released with the paper `SqueezeBERT: What can computer vision teach NLP +51. :doc:`SqueezeBert ` released with the paper `SqueezeBERT: What can computer vision teach NLP about efficient neural networks? `__ by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. -51. :doc:`T5 ` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a +52. :doc:`T5 ` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer `__ by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. -52. :doc:`TAPAS ` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via +53. :doc:`TAPAS ` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via Pre-training `__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. -53. 
:doc:`Transformer-XL ` (from Google/CMU) released with the paper `Transformer-XL: +54. :doc:`Transformer-XL ` (from Google/CMU) released with the paper `Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context `__ by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov. -54. :doc:`Vision Transformer (ViT) ` (from Google AI) released with the paper `An Image is Worth 16x16 +55. :doc:`Vision Transformer (ViT) ` (from Google AI) released with the paper `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `__ by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. -55. :doc:`VisualBERT ` (from UCLA NLP) released with the paper `VisualBERT: A Simple and +56. :doc:`VisualBERT ` (from UCLA NLP) released with the paper `VisualBERT: A Simple and Performant Baseline for Vision and Language `__ by Liunian Harold Li, Mark Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang. -56. :doc:`Wav2Vec2 ` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for +57. :doc:`Wav2Vec2 ` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations `__ by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli. -57. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model +58. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model Pretraining `__ by Guillaume Lample and Alexis Conneau. -58. :doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: +59. :doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training `__ by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. -59. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised +60. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised Cross-lingual Representation Learning at Scale `__ by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov. -60. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive +61. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive Pretraining for Language Understanding `__ by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le. -61. :doc:`XLSR-Wav2Vec2 ` (from Facebook AI) released with the paper `Unsupervised +62. :doc:`XLSR-Wav2Vec2 ` (from Facebook AI) released with the paper `Unsupervised Cross-Lingual Representation Learning For Speech Recognition `__ by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli. @@ -345,6 +348,8 @@ Flax), PyTorch, and/or TensorFlow. 
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | GPT Neo | ❌ | ❌ | ✅ | ❌ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| Hubert | ❌ | ❌ | ✅ | ❌ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | I-BERT | ❌ | ❌ | ✅ | ❌ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | LED | ✅ | ✅ | ✅ | ✅ | ❌ | @@ -534,6 +539,7 @@ Flax), PyTorch, and/or TensorFlow. model_doc/gpt model_doc/gpt2 model_doc/gpt_neo + model_doc/hubert model_doc/pegasus model_doc/phobert model_doc/prophetnet diff --git a/docs/source/model_doc/hubert.rst b/docs/source/model_doc/hubert.rst new file mode 100644 index 00000000000..a1e4e124522 --- /dev/null +++ b/docs/source/model_doc/hubert.rst @@ -0,0 +1,65 @@ +.. + Copyright 2021 The HuggingFace Team. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + specific language governing permissions and limitations under the License. + +Hubert +----------------------------------------------------------------------------------------------------------------------- + +Overview +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Hubert was proposed in `HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units +`__ by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, Ruslan +Salakhutdinov, Abdelrahman Mohamed. + +The abstract from the paper is the following: + +*Self-supervised approaches for speech representation learning are challenged by three unique problems: (1) there are +multiple sound units in each input utterance, (2) there is no lexicon of input sound units during the pre-training +phase, and (3) sound units have variable lengths with no explicit segmentation. To deal with these three problems, we +propose the Hidden-Unit BERT (HuBERT) approach for self-supervised speech representation learning, which utilizes an +offline clustering step to provide aligned target labels for a BERT-like prediction loss. A key ingredient of our +approach is applying the prediction loss over the masked regions only, which forces the model to learn a combined +acoustic and language model over the continuous inputs. HuBERT relies primarily on the consistency of the unsupervised +clustering step rather than the intrinsic quality of the assigned cluster labels. Starting with a simple k-means +teacher of 100 clusters, and using two iterations of clustering, the HuBERT model either matches or improves upon the +state-of-the-art wav2vec 2.0 performance on the Librispeech (960h) and Libri-light (60,000h) benchmarks with 10min, 1h, +10h, 100h, and 960h fine-tuning subsets. 
Using a 1B parameter model, HuBERT shows up to 19% and 13% relative WER +reduction on the more challenging dev-other and test-other evaluation subsets.* + +Tips: + +- Hubert is a speech model that accepts a float array corresponding to the raw waveform of the speech signal. +- Hubert model was fine-tuned using connectionist temporal classification (CTC) so the model output has to be decoded + using :class:`~transformers.Wav2Vec2CTCTokenizer`. + +This model was contributed by `patrickvonplaten `__. + + +HubertConfig +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.HubertConfig + :members: + + +HubertModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.HubertModel + :members: forward + + +HubertForCTC +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.HubertForCTC + :members: forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index f244d167535..9fcf97b119d 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -201,6 +201,7 @@ _import_structure = { "models.gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2Tokenizer"], "models.gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig"], "models.herbert": ["HerbertTokenizer"], + "models.hubert": ["HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "HubertConfig"], "models.ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig"], "models.layoutlm": ["LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "LayoutLMConfig", "LayoutLMTokenizer"], "models.led": ["LED_PRETRAINED_CONFIG_ARCHIVE_MAP", "LEDConfig", "LEDTokenizer"], @@ -777,6 +778,14 @@ if is_torch_available(): "load_tf_weights_in_gpt_neo", ] ) + _import_structure["models.hubert"].extend( + [ + "HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST", + "HubertForCTC", + "HubertModel", + "HubertPreTrainedModel", + ] + ) _import_structure["models.ibert"].extend( [ "IBERT_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -1742,6 +1751,7 @@ if TYPE_CHECKING: from .models.gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2Tokenizer from .models.gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig from .models.herbert import HerbertTokenizer + from .models.hubert import HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, HubertConfig from .models.ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig from .models.layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig, LayoutLMTokenizer from .models.led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig, LEDTokenizer @@ -2230,6 +2240,12 @@ if TYPE_CHECKING: GPTNeoPreTrainedModel, load_tf_weights_in_gpt_neo, ) + from .models.hubert import ( + HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST, + HubertForCTC, + HubertModel, + HubertPreTrainedModel, + ) from .models.ibert import ( IBERT_PRETRAINED_MODEL_ARCHIVE_LIST, IBertForMaskedLM, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index c103622d698..76d99362164 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -49,6 +49,7 @@ from ..fsmt.configuration_fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTCo from ..funnel.configuration_funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig from 
..gpt2.configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config from ..gpt_neo.configuration_gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig +from ..hubert.configuration_hubert import HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, HubertConfig from ..ibert.configuration_ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig from ..layoutlm.configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig from ..led.configuration_led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig @@ -144,6 +145,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict( MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP, TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, + HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ] for key, value, in pretrained_map.items() ) @@ -193,6 +195,7 @@ CONFIG_MAPPING = OrderedDict( ("flaubert", FlaubertConfig), ("fsmt", FSMTConfig), ("squeezebert", SqueezeBertConfig), + ("hubert", HubertConfig), ("bert", BertConfig), ("openai-gpt", OpenAIGPTConfig), ("gpt2", GPT2Config), @@ -274,6 +277,7 @@ MODEL_NAMES_MAPPING = OrderedDict( ("mt5", "mT5"), ("mpnet", "MPNet"), ("tapas", "TAPAS"), + ("hubert", "Hubert"), ] ) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index ce8b3592df3..f67213cd2d3 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -147,6 +147,7 @@ from ..funnel.modeling_funnel import ( ) from ..gpt2.modeling_gpt2 import GPT2ForSequenceClassification, GPT2LMHeadModel, GPT2Model from ..gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM, GPTNeoForSequenceClassification, GPTNeoModel +from ..hubert.modeling_hubert import HubertModel from ..ibert.modeling_ibert import ( IBertForMaskedLM, IBertForMultipleChoice, @@ -327,6 +328,7 @@ from .configuration_auto import ( FunnelConfig, GPT2Config, GPTNeoConfig, + HubertConfig, IBertConfig, LayoutLMConfig, LEDConfig, @@ -380,6 +382,7 @@ MODEL_MAPPING = OrderedDict( (Speech2TextConfig, Speech2TextModel), (ViTConfig, ViTModel), (Wav2Vec2Config, Wav2Vec2Model), + (HubertConfig, HubertModel), (M2M100Config, M2M100Model), (ConvBertConfig, ConvBertModel), (LEDConfig, LEDModel), diff --git a/src/transformers/models/hubert/__init__.py b/src/transformers/models/hubert/__init__.py new file mode 100644 index 00000000000..11f37eefeb2 --- /dev/null +++ b/src/transformers/models/hubert/__init__.py @@ -0,0 +1,64 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
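Reviewer note: the Tips in docs/source/model_doc/hubert.rst above say that Hubert consumes a raw-waveform float array and that the CTC head's output is decoded with Wav2Vec2CTCTokenizer (typically via Wav2Vec2Processor). A minimal usage sketch of the classes added in this PR; the fine-tuned checkpoint name "facebook/hubert-large-ls960-ft" and the zero waveform are placeholders for illustration, not part of this diff:

    import numpy as np
    import torch
    from transformers import HubertForCTC, Wav2Vec2Processor

    # hypothetical fine-tuned CTC checkpoint; any Hubert CTC checkpoint with a matching processor works
    checkpoint = "facebook/hubert-large-ls960-ft"
    processor = Wav2Vec2Processor.from_pretrained(checkpoint)
    model = HubertForCTC.from_pretrained(checkpoint)

    # stand-in for one second of 16 kHz mono audio; in practice pass the real waveform as a float array
    waveform = np.zeros(16_000, dtype=np.float32)
    inputs = processor(waveform, sampling_rate=16_000, return_tensors="pt")

    with torch.no_grad():
        logits = model(inputs.input_values).logits  # shape: (batch, time, vocab_size)

    # greedy CTC decoding of the character/token logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)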
+from typing import TYPE_CHECKING + +from ...file_utils import _BaseLazyModule, is_torch_available + + +_import_structure = { + "configuration_hubert": ["HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "HubertConfig"], +} + +if is_torch_available(): + _import_structure["modeling_hubert"] = [ + "HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST", + "HubertForCTC", + "HubertModel", + "HubertPreTrainedModel", + ] + + +if TYPE_CHECKING: + from .configuration_hubert import HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, HubertConfig + + if is_torch_available(): + from .modeling_hubert import ( + HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST, + HubertForCTC, + HubertModel, + HubertPreTrainedModel, + ) + + +else: + import importlib + import os + import sys + + class _LazyModule(_BaseLazyModule): + """ + Module class that surfaces all objects but only performs associated imports when the objects are requested. + """ + + __file__ = globals()["__file__"] + __path__ = [os.path.dirname(__file__)] + + def _get_module(self, module_name: str): + return importlib.import_module("." + module_name, self.__name__) + + sys.modules[__name__] = _LazyModule(__name__, _import_structure) diff --git a/src/transformers/models/hubert/configuration_hubert.py b/src/transformers/models/hubert/configuration_hubert.py new file mode 100644 index 00000000000..f3d2f77ed02 --- /dev/null +++ b/src/transformers/models/hubert/configuration_hubert.py @@ -0,0 +1,222 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Hubert model configuration """ + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "facebook/hubert-base-ls960": "https://huggingface.co/facebook/hubert-base-ls960/resolve/main/config.json", + # See all Hubert models at https://huggingface.co/models?filter=hubert +} + + +class HubertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class:`~transformers.HubertModel`. It is used to + instantiate an Hubert model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the Hubert + `facebook/hubert-base-ls960 `__ architecture. + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model + outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. + + + Args: + vocab_size (:obj:`int`, `optional`, defaults to 32): + Vocabulary size of the Hubert model. Defines the number of different tokens that can be represented by the + :obj:`inputs_ids` passed when calling :class:`~transformers.HubertModel`. Vocabulary size of the model. + Defines the different tokens that can be represented by the `inputs_ids` passed to the forward method of + :class:`~transformers.HubertModel`. 
+        hidden_size (:obj:`int`, `optional`, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (:obj:`int`, `optional`, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (:obj:`int`, `optional`, defaults to 3072):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string,
+            :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
+        hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
+            The dropout ratio for the attention probabilities.
+        initializer_range (:obj:`float`, `optional`, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        feat_extract_norm (:obj:`str`, `optional`, defaults to :obj:`"group"`):
+            The norm to be applied to 1D convolutional layers in the feature extractor. One of :obj:`"group"` for group
+            normalization of only the first 1D convolutional layer or :obj:`"layer"` for layer normalization of all 1D
+            convolutional layers.
+        feat_extract_dropout (:obj:`float`, `optional`, defaults to 0.0):
+            The dropout probability for all 1D convolutional layers in the feature extractor.
+        feat_extract_activation (:obj:`str`, `optional`, defaults to :obj:`"gelu"`):
+            The non-linear activation function (function or string) in the 1D convolutional layers of the feature
+            extractor. If string, :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
+        conv_dim (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(512, 512, 512, 512, 512, 512, 512)`):
+            A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
+            feature extractor. The length of `conv_dim` defines the number of 1D convolutional layers.
+        conv_stride (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(5, 2, 2, 2, 2, 2, 2)`):
+            A tuple of integers defining the stride of each 1D convolutional layer in the feature extractor. The length
+            of `conv_stride` defines the number of convolutional layers and has to match the length of `conv_dim`.
+        conv_kernel (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(10, 3, 3, 3, 3, 2, 2)`):
+            A tuple of integers defining the kernel size of each 1D convolutional layer in the feature extractor. The
+            length of `conv_kernel` defines the number of convolutional layers and has to match the length of
+            `conv_dim`.
+        conv_bias (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether the 1D convolutional layers have a bias.
+        num_conv_pos_embeddings (:obj:`int`, `optional`, defaults to 128):
+            Number of convolutional positional embeddings. Defines the kernel size of the 1D convolutional positional
+            embeddings layer.
+        num_conv_pos_embedding_groups (:obj:`int`, `optional`, defaults to 16):
+            Number of groups of the 1D convolutional positional embeddings layer.
+        do_stable_layer_norm (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether to apply the `stable` layer norm architecture of the Transformer encoder. ``do_stable_layer_norm is
+            True`` corresponds to applying layer norm before the attention layer, whereas ``do_stable_layer_norm is
+            False`` corresponds to applying layer norm after the attention layer.
+        apply_spec_augment (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            Whether to apply *SpecAugment* data augmentation to the outputs of the feature extractor. For reference see
+            `SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
+            <https://arxiv.org/abs/1904.08779>`__.
+        mask_time_prob (:obj:`float`, `optional`, defaults to 0.05):
+            Probability for each feature vector along the time axis to be chosen as the start of the vector span to be
+            masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature vectors will be
+            masked along the time axis. This is only relevant if ``apply_spec_augment is True``.
+        mask_time_length (:obj:`int`, `optional`, defaults to 10):
+            Length of vector span along the time axis.
+        mask_feature_prob (:obj:`float`, `optional`, defaults to 0.0):
+            Probability for each feature vector along the feature axis to be chosen as the start of the vector span to
+            be masked. Approximately ``mask_feature_prob * hidden_size // mask_feature_length`` feature vectors will be
+            masked along the feature axis. This is only relevant if ``apply_spec_augment is True``.
+        mask_feature_length (:obj:`int`, `optional`, defaults to 10):
+            Length of vector span along the feature axis.
+        ctc_loss_reduction (:obj:`str`, `optional`, defaults to :obj:`"sum"`):
+            Specifies the reduction to apply to the output of ``torch.nn.CTCLoss``. Only relevant when training an
+            instance of :class:`~transformers.HubertForCTC`.
+        ctc_zero_infinity (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether to zero infinite losses and the associated gradients of ``torch.nn.CTCLoss``. Infinite losses
+            mainly occur when the inputs are too short to be aligned to the targets. Only relevant when training an
+            instance of :class:`~transformers.HubertForCTC`.
+        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            If :obj:`True`, use gradient checkpointing to save memory at the expense of a slower backward pass.
+ + Example:: + + >>> from transformers import HubertModel, HubertConfig + + >>> # Initializing a Hubert facebook/hubert-base-ls960 style configuration + >>> configuration = HubertConfig() + + >>> # Initializing a model from the facebook/hubert-base-ls960 style configuration + >>> model = HubertModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + """ + model_type = "hubert" + + def __init__( + self, + vocab_size=32, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout=0.1, + activation_dropout=0.1, + attention_dropout=0.1, + feat_proj_dropout=0.1, + final_dropout=0.1, + layerdrop=0.1, + initializer_range=0.02, + layer_norm_eps=1e-5, + feat_extract_norm="group", + feat_extract_activation="gelu", + conv_dim=(512, 512, 512, 512, 512, 512, 512), + conv_stride=(5, 2, 2, 2, 2, 2, 2), + conv_kernel=(10, 3, 3, 3, 3, 2, 2), + conv_bias=False, + num_conv_pos_embeddings=128, + num_conv_pos_embedding_groups=16, + do_stable_layer_norm=False, + apply_spec_augment=True, + mask_time_prob=0.05, + mask_time_length=10, + mask_feature_prob=0.0, + mask_feature_length=10, + ctc_loss_reduction="sum", + ctc_zero_infinity=False, + gradient_checkpointing=False, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + **kwargs + ): + super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id) + self.hidden_size = hidden_size + self.feat_extract_norm = feat_extract_norm + self.feat_extract_activation = feat_extract_activation + self.conv_dim = list(conv_dim) + self.conv_stride = list(conv_stride) + self.conv_kernel = list(conv_kernel) + self.conv_bias = conv_bias + self.num_conv_pos_embeddings = num_conv_pos_embeddings + self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups + self.num_feat_extract_layers = len(self.conv_dim) + self.num_hidden_layers = num_hidden_layers + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.num_attention_heads = num_attention_heads + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.feat_proj_dropout = feat_proj_dropout + self.final_dropout = final_dropout + self.layerdrop = layerdrop + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + self.vocab_size = vocab_size + self.do_stable_layer_norm = do_stable_layer_norm + self.gradient_checkpointing = gradient_checkpointing + + if ( + (len(self.conv_stride) != self.num_feat_extract_layers) + or (len(self.conv_kernel) != self.num_feat_extract_layers) + or (len(self.conv_dim) != self.num_feat_extract_layers) + ): + raise ValueError( + "Configuration for convolutional layers is incorrect." + "It is required that `len(config.conv_dim)` == `len(config.conv_stride)` == `len(config.conv_kernel)`," + f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride)" + f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`." 
+ ) + + # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779 + self.apply_spec_augment = apply_spec_augment + self.mask_time_prob = mask_time_prob + self.mask_time_length = mask_time_length + self.mask_feature_prob = mask_feature_prob + self.mask_feature_length = mask_feature_length + + # ctc loss + self.ctc_loss_reduction = ctc_loss_reduction + self.ctc_zero_infinity = ctc_zero_infinity diff --git a/src/transformers/models/hubert/convert_hubert_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/hubert/convert_hubert_original_pytorch_checkpoint_to_pytorch.py new file mode 100644 index 00000000000..dee823e094d --- /dev/null +++ b/src/transformers/models/hubert/convert_hubert_original_pytorch_checkpoint_to_pytorch.py @@ -0,0 +1,244 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Hubert checkpoint.""" + + +import argparse +import json +import os + +import fairseq +import torch +from fairseq.data import Dictionary + +from transformers import ( + HubertConfig, + HubertForCTC, + HubertModel, + Wav2Vec2CTCTokenizer, + Wav2Vec2FeatureExtractor, + Wav2Vec2Processor, + logging, +) + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +MAPPING = { + "post_extract_proj": "feature_projection.projection", + "encoder.pos_conv.0": "encoder.pos_conv_embed.conv", + "self_attn.k_proj": "encoder.layers.*.attention.k_proj", + "self_attn.v_proj": "encoder.layers.*.attention.v_proj", + "self_attn.q_proj": "encoder.layers.*.attention.q_proj", + "self_attn.out_proj": "encoder.layers.*.attention.out_proj", + "self_attn_layer_norm": "encoder.layers.*.layer_norm", + "fc1": "encoder.layers.*.feed_forward.intermediate_dense", + "fc2": "encoder.layers.*.feed_forward.output_dense", + "final_layer_norm": "encoder.layers.*.final_layer_norm", + "encoder.layer_norm": "encoder.layer_norm", + "w2v_model.layer_norm": "feature_projection.layer_norm", + "w2v_encoder.proj": "lm_head", + "mask_emb": "masked_spec_embed", +} + + +def set_recursively(hf_pointer, key, value, full_name, weight_type): + for attribute in key.split("."): + hf_pointer = getattr(hf_pointer, attribute) + + if weight_type is not None: + hf_shape = getattr(hf_pointer, weight_type).shape + else: + hf_shape = hf_pointer.shape + + assert ( + hf_shape == value.shape + ), f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}" + + if weight_type == "weight": + hf_pointer.weight.data = value + elif weight_type == "weight_g": + hf_pointer.weight_g.data = value + elif weight_type == "weight_v": + hf_pointer.weight_v.data = value + elif weight_type == "bias": + hf_pointer.bias.data = value + else: + hf_pointer.data = value + + logger.info(f"{key + '.' 
+ weight_type if weight_type is not None else ''} was initialized from {full_name}.") + + +def recursively_load_weights(fairseq_model, hf_model, is_finetuned): + unused_weights = [] + fairseq_dict = fairseq_model.state_dict() + + feature_extractor = hf_model.hubert.feature_extractor if is_finetuned else hf_model.feature_extractor + + for name, value in fairseq_dict.items(): + is_used = False + if "conv_layers" in name: + load_conv_layer( + name, + value, + feature_extractor, + unused_weights, + hf_model.config.feat_extract_norm == "group", + ) + is_used = True + else: + for key, mapped_key in MAPPING.items(): + mapped_key = "hubert." + mapped_key if (is_finetuned and mapped_key != "lm_head") else mapped_key + + if key in name or (key.split("w2v_model.")[-1] == name.split(".")[0] and not is_finetuned): + is_used = True + if "*" in mapped_key: + layer_index = name.split(key)[0].split(".")[-2] + mapped_key = mapped_key.replace("*", layer_index) + if "weight_g" in name: + weight_type = "weight_g" + elif "weight_v" in name: + weight_type = "weight_v" + elif "weight" in name: + weight_type = "weight" + elif "bias" in name: + weight_type = "bias" + else: + weight_type = None + set_recursively(hf_model, mapped_key, value, name, weight_type) + continue + if not is_used: + unused_weights.append(name) + + logger.warning(f"Unused weights: {unused_weights}") + + +def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm): + name = full_name.split("conv_layers.")[-1] + items = name.split(".") + layer_id = int(items[0]) + type_id = int(items[1]) + + if type_id == 0: + if "bias" in name: + assert ( + value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape + ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found." + feature_extractor.conv_layers[layer_id].conv.bias.data = value + logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.") + elif "weight" in name: + assert ( + value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape + ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found." + feature_extractor.conv_layers[layer_id].conv.weight.data = value + logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.") + elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm): + if "bias" in name: + assert ( + value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape + ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was found." + feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value + logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.") + elif "weight" in name: + assert ( + value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape + ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.weight.data.shape} was found." 
+ feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value + logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.") + else: + unused_weights.append(full_name) + + +@torch.no_grad() +def convert_hubert_checkpoint( + checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True +): + """ + Copy/paste/tweak model's weights to transformers design. + """ + if config_path is not None: + config = HubertConfig.from_pretrained(config_path) + else: + config = HubertConfig() + + if is_finetuned: + if dict_path: + target_dict = Dictionary.load(dict_path) + + # important change bos & pad token id since CTC symbol is and + # not as in fairseq + config.bos_token_id = target_dict.pad_index + config.pad_token_id = target_dict.bos_index + config.eos_token_id = target_dict.eos_index + config.vocab_size = len(target_dict.symbols) + vocab_path = os.path.join(pytorch_dump_folder_path, "vocab.json") + if not os.path.isdir(pytorch_dump_folder_path): + logger.error("--pytorch_dump_folder_path ({}) should be a directory".format(pytorch_dump_folder_path)) + return + os.makedirs(pytorch_dump_folder_path, exist_ok=True) + with open(vocab_path, "w", encoding="utf-8") as vocab_handle: + json.dump(target_dict.indices, vocab_handle) + tokenizer = Wav2Vec2CTCTokenizer( + vocab_path, + unk_token=target_dict.unk_word, + pad_token=target_dict.pad_word, + bos_token=target_dict.bos_word, + eos_token=target_dict.eos_word, + word_delimiter_token="|", + do_lower_case=False, + ) + return_attention_mask = True if config.feat_extract_norm == "layer" else False + feature_extractor = Wav2Vec2FeatureExtractor( + feature_size=1, + sampling_rate=16000, + padding_value=0, + do_normalize=True, + return_attention_mask=return_attention_mask, + ) + processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer) + processor.save_pretrained(pytorch_dump_folder_path) + + hf_wav2vec = HubertForCTC(config) + else: + hf_wav2vec = HubertModel(config) + + if is_finetuned: + model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task( + [checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])} + ) + else: + model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path]) + + model = model[0].eval() + + recursively_load_weights(model, hf_wav2vec, is_finetuned) + + hf_wav2vec.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") + parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint") + parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model") + parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert") + parser.add_argument( + "--not_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not" + ) + args = parser.parse_args() + convert_hubert_checkpoint( + args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.dict_path, not args.not_finetuned + ) diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py new file mode 100755 index 00000000000..cad377eb666 --- /dev/null +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -0,0 +1,1065 @@ +# coding=utf-8 +# 
Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Hubert model. """ + +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.deepspeed import is_deepspeed_zero3_enabled + +from ...activations import ACT2FN +from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings +from ...modeling_outputs import BaseModelOutput, CausalLMOutput +from ...modeling_utils import PreTrainedModel +from ...utils import logging +from .configuration_hubert import HubertConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "HubertConfig" + +HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "facebook/hubert-base-ls960", + # See all Hubert models at https://huggingface.co/models?filter=hubert +] + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices +def _compute_mask_indices( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + device: torch.device, + min_masks: int = 0, +) -> torch.tensor: + """ + Computes random mask spans for a given shape. Used to implement `SpecAugment: A Simple Data Augmentation Method for + ASR `__. + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. 
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_length: size of the mask + min_masks: minimum number of masked spans + + """ + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`" + ) + + # compute number of masked spans in batch + num_masked_spans = int(mask_prob * sequence_length / mask_length + torch.rand((1,)).item()) + num_masked_spans = max(num_masked_spans, min_masks) + + # make sure num masked indices <= sequence_length + if num_masked_spans * mask_length > sequence_length: + num_masked_spans = sequence_length // mask_length + + # SpecAugment mask to fill + spec_aug_mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool) + + # uniform distribution to sample from, make sure that offset samples are < sequence_length + uniform_dist = torch.ones((batch_size, sequence_length - (mask_length - 1)), device=device) + + # get random indices to mask + spec_aug_mask_idxs = torch.multinomial(uniform_dist, num_masked_spans) + + # expand masked indices to masked spans + spec_aug_mask_idxs = ( + spec_aug_mask_idxs.unsqueeze(dim=-1) + .expand((batch_size, num_masked_spans, mask_length)) + .reshape(batch_size, num_masked_spans * mask_length) + ) + offsets = ( + torch.arange(mask_length, device=device)[None, None, :] + .expand((batch_size, num_masked_spans, mask_length)) + .reshape(batch_size, num_masked_spans * mask_length) + ) + spec_aug_mask_idxs = spec_aug_mask_idxs + offsets + + # scatter indices to mask + spec_aug_mask = spec_aug_mask.scatter(1, spec_aug_mask_idxs, True) + + return spec_aug_mask + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->Hubert +class HubertNoLayerNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->Hubert +class HubertLayerNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + + hidden_states = hidden_states.transpose(-2, -1) + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states.transpose(-2, -1) + + hidden_states = self.activation(hidden_states) + return hidden_states 
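Reviewer note: the `_compute_mask_indices` helper above draws roughly ``mask_prob * sequence_length / mask_length`` span starts per row and expands each to ``mask_length`` consecutive positions. A small sketch of how it is typically exercised; the shapes and probabilities below are illustrative only:

    import torch

    batch_size, sequence_length = 2, 100
    mask = _compute_mask_indices(
        shape=(batch_size, sequence_length),
        mask_prob=0.05,
        mask_length=10,
        device=torch.device("cpu"),
    )
    # boolean mask of shape (batch_size, sequence_length); True marks positions chosen for
    # SpecAugment masking (in the model these hidden states are replaced by a learned mask embedding)
    print(mask.shape, mask.sum(dim=-1))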
+ + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->Hubert +class HubertGroupNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->Hubert +class HubertPositionalConvEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=config.num_conv_pos_embeddings, + padding=config.num_conv_pos_embeddings // 2, + groups=config.num_conv_pos_embedding_groups, + ) + + if is_deepspeed_zero3_enabled(): + import deepspeed + + with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): + self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) + deepspeed.zero.register_external_parameter(self, self.conv.weight_v) + deepspeed.zero.register_external_parameter(self, self.conv.weight_g) + else: + self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) + + self.padding = HubertSamePadLayer(config.num_conv_pos_embeddings) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = hidden_states.transpose(1, 2) + + hidden_states = self.conv(hidden_states) + hidden_states = self.padding(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->Hubert +class HubertSamePadLayer(nn.Module): + def __init__(self, num_conv_pos_embeddings): + super().__init__() + self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 + + def forward(self, hidden_states): + if self.num_pad_remove > 0: + hidden_states = hidden_states[:, :, : -self.num_pad_remove] + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureExtractor with Wav2Vec2->Hubert +class HubertFeatureExtractor(nn.Module): + """Construct the featurs from raw audio waveform""" + + def __init__(self, config): + super().__init__() + + if config.feat_extract_norm == "group": + conv_layers = [HubertGroupNormConvLayer(config, layer_id=0)] + [ + HubertNoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1) + ] + elif config.feat_extract_norm == "layer": + conv_layers = [HubertLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)] + else: + raise ValueError( + f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']" + ) + self.conv_layers = nn.ModuleList(conv_layers) + + def _freeze_parameters(self): + for param in self.parameters(): + param.requires_grad = False + + def forward(self, 
input_values): + hidden_states = input_values[:, None] + for conv_layer in self.conv_layers: + hidden_states = conv_layer(hidden_states) + + return hidden_states + + +class HubertFeatureProjection(nn.Module): + def __init__(self, config): + super().__init__() + self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps) + self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size) + self.dropout = nn.Dropout(config.feat_proj_dropout) + + def forward(self, hidden_states): + # non-projected hidden states are needed for quantization + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.projection(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Hubert +class HubertAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})." + self.scaling = self.head_dim ** -0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, embed_dim = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 
+ # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. 
+ # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Hubert +class HubertFeedForward(nn.Module): + def __init__(self, config): + super().__init__() + self.intermediate_dropout = nn.Dropout(config.activation_dropout) + + self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.output_dropout = nn.Dropout(config.hidden_dropout) + + def forward(self, hidden_states): + hidden_states = self.intermediate_dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.intermediate_dropout(hidden_states) + + hidden_states = self.output_dense(hidden_states) + hidden_states = self.output_dropout(hidden_states) + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayer with Wav2Vec2->Hubert +class HubertEncoderLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = HubertAttention( + embed_dim=config.hidden_size, + num_heads=config.num_attention_heads, + dropout=config.attention_dropout, + is_decoder=False, + ) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.feed_forward = HubertFeedForward(config) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states, attention_mask=None, output_attentions=False): + attn_residual = hidden_states + hidden_states, attn_weights, _ = self.attention( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) + hidden_states = self.dropout(hidden_states) + hidden_states = attn_residual + hidden_states + + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states + self.feed_forward(hidden_states) + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayerStableLayerNorm with Wav2Vec2->Hubert +class HubertEncoderLayerStableLayerNorm(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = HubertAttention( + embed_dim=config.hidden_size, + num_heads=config.num_attention_heads, + 
dropout=config.attention_dropout, + is_decoder=False, + ) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.feed_forward = HubertFeedForward(config) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states, attention_mask=None, output_attentions=False): + attn_residual = hidden_states + hidden_states = self.layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.attention( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) + hidden_states = self.dropout(hidden_states) + hidden_states = attn_residual + hidden_states + hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states)) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Encoder with Wav2Vec2->Hubert +class HubertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.pos_conv_embed = HubertPositionalConvEmbedding(config) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layers = nn.ModuleList([HubertEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + + def forward( + self, + hidden_states, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if attention_mask is not None: + # make sure padded tokens output 0 + hidden_states[~attention_mask] = 0.0 + + # extend attention_mask + attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.expand( + attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] + ) + + position_embeddings = self.pos_conv_embed(hidden_states) + hidden_states = hidden_states + position_embeddings + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + + for layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = np.random.uniform(0, 1) + + skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False + if not skip_the_layer or deepspeed_zero3_is_enabled: + # under deepspeed zero3 all gpus must run in sync + if getattr(self.config, "gradient_checkpointing", False) and self.training: + # create gradient checkpointing function + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer), + hidden_states, + attention_mask, + ) + else: + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + 
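+        # The collected states are returned either as a plain tuple of the non-None
+        # values or wrapped in a ``BaseModelOutput``, depending on ``return_dict``.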
if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderStableLayerNorm with Wav2Vec2->Hubert +class HubertEncoderStableLayerNorm(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.pos_conv_embed = HubertPositionalConvEmbedding(config) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layers = nn.ModuleList( + [HubertEncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)] + ) + + def forward( + self, + hidden_states, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if attention_mask is not None: + # make sure padded tokens are not attended to + hidden_states[~attention_mask] = 0 + + # extend attention_mask + attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.expand( + attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] + ) + + position_embeddings = self.pos_conv_embed(hidden_states) + hidden_states = hidden_states + position_embeddings + hidden_states = self.dropout(hidden_states) + + deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + + for layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = np.random.uniform(0, 1) + + skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False + if not skip_the_layer or deepspeed_zero3_is_enabled: + # under deepspeed zero3 all gpus must run in sync + # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication + if getattr(self.config, "gradient_checkpointing", False) and self.training: + # create gradient checkpointing function + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer), + hidden_states, + attention_mask, + ) + else: + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class HubertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
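+    The class also defines ``_get_feat_extract_output_lengths``, which both ``HubertModel`` and
+    ``HubertForCTC`` use to map raw-waveform lengths to feature-frame lengths.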
+ """ + + config_class = HubertConfig + base_model_prefix = "hubert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + if is_deepspeed_zero3_enabled(): + import deepspeed + + if hasattr(module, "weight_v") and hasattr(module, "weight_g"): + with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0): + nn.init.kaiming_normal_(module.weight.data) + else: + with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0): + nn.init.kaiming_normal_(module.weight.data) + else: + nn.init.kaiming_normal_(module.weight.data) + + if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: + module.bias.data.zero_() + + def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): + """ + Computes the output length of the convolutional layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return (input_length - kernel_size) // stride + 1 + + for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + + return input_lengths + + +HUBERT_START_DOCSTRING = r""" + Hubert was proposed in `HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units + `__ by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, + Ruslan Salakhutdinov, Abdelrahman Mohamed. + + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving etc.). + + This model is a PyTorch `torch.nn.Module `_ sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config (:class:`~transformers.HubertConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. +""" + + +HUBERT_INPUTS_DOCSTRING = r""" + Args: + input_values (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`): + Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file + into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install + soundfile`). To prepare the array into `input_values`, the :class:`~transformers.Wav2Vec2Processor` should + be used for padding and conversion into a tensor of type `torch.FloatTensor`. See + :meth:`transformers.Wav2Vec2Processor.__call__` for details. + attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing convolution and attention on padding token indices. 
Mask values selected in ``[0, + 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + + .. warning:: + :obj:`attention_mask` should only be passed if the corresponding processor has + ``config.return_attention_mask == True``. For all models whose processor has + ``config.return_attention_mask == False``, such as `hubert-base + `__, :obj:`attention_mask` should **not** be passed + to avoid degraded performance when doing batched inference. For such models :obj:`input_values` should + simply be padded with 0 and passed without :obj:`attention_mask`. Be aware that these models also yield + slightly different results depending on whether :obj:`input_values` is padded or not. + + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Hubert Model transformer outputting raw hidden-states without any specific head on top.", + HUBERT_START_DOCSTRING, +) +class HubertModel(HubertPreTrainedModel): + def __init__(self, config: HubertConfig): + super().__init__(config) + self.config = config + self.feature_extractor = HubertFeatureExtractor(config) + self.feature_projection = HubertFeatureProjection(config) + + self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_()) + + if config.do_stable_layer_norm: + self.encoder = HubertEncoderStableLayerNorm(config) + else: + self.encoder = HubertEncoder(config) + + self.init_weights() + + def _mask_hidden_states( + self, hidden_states: torch.FloatTensor, mask_time_indices: Optional[torch.FloatTensor] = None + ): + """ + Masks extracted features along time axis and/or along feature axis according to `SpecAugment + `__ . 
+ """ + + # `config.apply_spec_augment` can set masking to False + if not getattr(self.config, "apply_spec_augment", True): + return hidden_states + + if mask_time_indices is not None: + # apply SpecAugment along time axis with given mask_time_indices + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + elif self.config.mask_time_prob > 0 and self.training: + # generate indices & apply SpecAugment along time axis + batch_size, sequence_length, hidden_size = hidden_states.size() + + mask_time_indices = _compute_mask_indices( + (batch_size, sequence_length), + mask_prob=self.config.mask_time_prob, + mask_length=self.config.mask_time_length, + device=hidden_states.device, + min_masks=2, + ) + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + + if self.config.mask_feature_prob > 0 and self.training: + # generate indices & apply SpecAugment along feature axis + mask_feature_indices = _compute_mask_indices( + (batch_size, hidden_size), + mask_prob=self.config.mask_feature_prob, + mask_length=self.config.mask_feature_length, + device=hidden_states.device, + ) + hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0 + + return hidden_states + + @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_values, + attention_mask=None, + mask_time_indices=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + """ + + Returns: + + Example:: + + >>> from transformers import Wav2Vec2Processor, HubertModel + >>> from datasets import load_dataset + >>> import soundfile as sf + + >>> processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft") + >>> model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft") + + >>> def map_to_array(batch): + ... speech, _ = sf.read(batch["file"]) + ... batch["speech"] = speech + ... 
return batch + + >>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation") + >>> ds = ds.map(map_to_array) + + >>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1 + >>> hidden_states = model(input_values).last_hidden_state + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + extract_features = self.feature_extractor(input_values) + extract_features = extract_features.transpose(1, 2) + + if attention_mask is not None: + # compute real output lengths according to convolution formula + output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) + + attention_mask = torch.zeros( + extract_features.shape[:2], dtype=extract_features.dtype, device=extract_features.device + ) + + # these two operations makes sure that all values + # before the output lengths indices are attended to + attention_mask[ + (torch.arange(attention_mask.shape[0], device=extract_features.device), output_lengths - 1) + ] = 1 + attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() + + hidden_states = self.feature_projection(extract_features) + + if mask_time_indices is not None: # apply SpecAugment along time axis with given indices + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + + hidden_states = self._mask_hidden_states(hidden_states) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + if not return_dict: + return (hidden_states,) + encoder_outputs[1:] + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """Hubert Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC). """, + HUBERT_START_DOCSTRING, +) +class HubertForCTC(HubertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.hubert = HubertModel(config) + self.dropout = nn.Dropout(config.final_dropout) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size) + + self.init_weights() + + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature extractor so that its parameter + will not be updated during training. + """ + self.hubert.feature_extractor._freeze_parameters() + + @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_values, + attention_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + labels=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_length)`, `optional`): + Labels for connectionist temporal classification. Note that ``target_length`` has to be smaller or equal to + the sequence length of the output logits. Indices are selected in ``[-100, 0, ..., config.vocab_size - + 1]``. 
All labels set to ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., + config.vocab_size - 1]``. + + Returns: + + Example:: + + >>> import torch + >>> from transformers import Wav2Vec2Processor, HubertForCTC + >>> from datasets import load_dataset + >>> import soundfile as sf + + >>> processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft") + >>> model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft") + + >>> def map_to_array(batch): + ... speech, _ = sf.read(batch["file"]) + ... batch["speech"] = speech + ... return batch + + >>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation") + >>> ds = ds.map(map_to_array) + + >>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1 + >>> logits = model(input_values).logits + >>> predicted_ids = torch.argmax(logits, dim=-1) + + >>> transcription = processor.decode(predicted_ids[0]) + + >>> # compute loss + >>> target_transcription = "A MAN SAID TO THE UNIVERSE SIR I EXIST" + + >>> # wrap processor as target processor to encode labels + >>> with processor.as_target_processor(): + ... labels = processor(target_transcription, return_tensors="pt").input_ids + + >>> loss = model(input_values, labels=labels).loss + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.hubert( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states) + + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + + # retrieve loss input_lengths from attention_mask + attention_mask = ( + attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long) + ) + input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) + + # assuming that padded tokens are filled with -100 + # when not being attended to + labels_mask = labels >= 0 + target_lengths = labels_mask.sum(-1) + flattened_targets = labels.masked_select(labels_mask) + + # ctc_loss doesn't support fp16 + log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1) + + with torch.backends.cudnn.flags(enabled=False): + loss = nn.functional.ctc_loss( + log_probs, + flattened_targets, + input_lengths, + target_lengths, + blank=self.config.pad_token_id, + reduction=self.config.ctc_loss_reduction, + zero_infinity=self.config.ctc_zero_infinity, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 0a995c29cbf..c8ce871ea38 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1709,6 +1709,32 @@ def load_tf_weights_in_gpt_neo(*args, **kwargs): requires_backends(load_tf_weights_in_gpt_neo, ["torch"]) +HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class HubertForCTC: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class HubertModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, 
*args, **kwargs): + requires_backends(cls, ["torch"]) + + +class HubertPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + IBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/src/transformers/utils/modeling_auto_mapping.py b/src/transformers/utils/modeling_auto_mapping.py index 10e7aabba4d..690a9fcf4a8 100644 --- a/src/transformers/utils/modeling_auto_mapping.py +++ b/src/transformers/utils/modeling_auto_mapping.py @@ -263,6 +263,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ("Speech2TextConfig", "Speech2TextModel"), ("ViTConfig", "ViTModel"), ("Wav2Vec2Config", "Wav2Vec2Model"), + ("HubertConfig", "HubertModel"), ("M2M100Config", "M2M100Model"), ("ConvBertConfig", "ConvBertModel"), ("LEDConfig", "LEDModel"), diff --git a/tests/test_modeling_hubert.py b/tests/test_modeling_hubert.py new file mode 100644 index 00000000000..90fc004393d --- /dev/null +++ b/tests/test_modeling_hubert.py @@ -0,0 +1,553 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch Hubert model. """ + + +import math +import unittest + +from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask +from transformers import is_torch_available +from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device + +from .test_configuration_common import ConfigTester +from .test_modeling_common import ModelTesterMixin, _config_zero_init + + +if is_torch_available(): + import torch + + from transformers import HubertConfig, HubertForCTC, HubertModel, Wav2Vec2Processor + from transformers.models.hubert.modeling_hubert import _compute_mask_indices + + +class HubertModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=1024, # speech is longer + is_training=False, + hidden_size=16, + feat_extract_norm="group", + feat_extract_dropout=0.0, + feat_extract_activation="gelu", + conv_dim=(32, 32, 32), + conv_stride=(4, 4, 4), + conv_kernel=(8, 8, 8), + conv_bias=False, + num_conv_pos_embeddings=16, + num_conv_pos_embedding_groups=2, + num_hidden_layers=4, + num_attention_heads=2, + hidden_dropout_prob=0.1, # this is most likely not correctly set yet + intermediate_size=20, + layer_norm_eps=1e-5, + hidden_act="gelu", + initializer_range=0.02, + vocab_size=32, + do_stable_layer_norm=False, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.hidden_size = hidden_size + self.feat_extract_norm = feat_extract_norm + self.feat_extract_dropout = feat_extract_dropout + self.feat_extract_activation = feat_extract_activation + self.conv_dim = conv_dim + self.conv_stride = conv_stride + self.conv_kernel = conv_kernel + self.conv_bias = conv_bias + self.num_conv_pos_embeddings = num_conv_pos_embeddings + 
self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_dropout_prob = hidden_dropout_prob + self.intermediate_size = intermediate_size + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.vocab_size = vocab_size + self.do_stable_layer_norm = do_stable_layer_norm + self.scope = scope + + output_seq_length = self.seq_length + for kernel, stride in zip(self.conv_kernel, self.conv_stride): + output_seq_length = (output_seq_length - (kernel - 1)) / stride + self.output_seq_length = int(math.ceil(output_seq_length)) + self.encoder_seq_length = self.output_seq_length + + def prepare_config_and_inputs(self): + input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size) + attention_mask = random_attention_mask([self.batch_size, self.seq_length]) + + config = HubertConfig( + hidden_size=self.hidden_size, + feat_extract_norm=self.feat_extract_norm, + feat_extract_dropout=self.feat_extract_dropout, + feat_extract_activation=self.feat_extract_activation, + conv_dim=self.conv_dim, + conv_stride=self.conv_stride, + conv_kernel=self.conv_kernel, + conv_bias=self.conv_bias, + num_conv_pos_embeddings=self.num_conv_pos_embeddings, + num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + hidden_dropout_prob=self.hidden_dropout_prob, + intermediate_size=self.intermediate_size, + layer_norm_eps=self.layer_norm_eps, + hidden_act=self.hidden_act, + initializer_range=self.initializer_range, + vocab_size=self.vocab_size, + ) + + return config, input_values, attention_mask + + def create_and_check_model(self, config, input_values, attention_mask): + model = HubertModel(config=config) + model.to(torch_device) + model.eval() + result = model(input_values, attention_mask=attention_mask) + self.parent.assertEqual( + result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size) + ) + + def create_and_check_batch_inference(self, config, input_values, *args): + # test does not pass for models making use of `group_norm` + # check: https://github.com/pytorch/fairseq/issues/3227 + model = HubertModel(config=config) + model.to(torch_device) + model.eval() + + input_values = input_values[:3] + attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.bool) + + input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]] + + # pad input + for i in range(len(input_lengths)): + input_values[i, input_lengths[i] :] = 0.0 + attention_mask[i, input_lengths[i] :] = 0.0 + + batch_outputs = model(input_values, attention_mask=attention_mask).last_hidden_state + + for i in range(input_values.shape[0]): + input_slice = input_values[i : i + 1, : input_lengths[i]] + output = model(input_slice).last_hidden_state + + batch_output = batch_outputs[i : i + 1, : output.shape[1]] + self.parent.assertTrue(torch.allclose(output, batch_output, atol=1e-3)) + + def check_ctc_loss(self, config, input_values, *args): + model = HubertForCTC(config=config) + model.to(torch_device) + + # make sure that dropout is disabled + model.eval() + + input_values = input_values[:3] + attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long) + + input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]] + max_length_labels = 
model._get_feat_extract_output_lengths(torch.tensor(input_lengths)) + labels = ids_tensor((input_values.shape[0], min(max_length_labels) - 1), model.config.vocab_size) + + # pad input + for i in range(len(input_lengths)): + input_values[i, input_lengths[i] :] = 0.0 + attention_mask[i, input_lengths[i] :] = 0 + + model.config.ctc_loss_reduction = "sum" + sum_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss + + model.config.ctc_loss_reduction = "mean" + mean_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss + + self.parent.assertTrue(abs(labels.shape[0] * labels.shape[1] * mean_loss.item() - sum_loss.item()) < 1e-3) + + def check_training(self, config, input_values, *args): + config.ctc_zero_infinity = True + model = HubertForCTC(config=config) + model.to(torch_device) + model.train() + + # freeze feature encoder + model.freeze_feature_extractor() + + input_values = input_values[:3] + + input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]] + max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths)) + labels = ids_tensor((input_values.shape[0], max(max_length_labels) - 2), model.config.vocab_size) + + # pad input + for i in range(len(input_lengths)): + input_values[i, input_lengths[i] :] = 0.0 + + if max_length_labels[i] < labels.shape[-1]: + # it's important that we make sure that target lenghts are at least + # one shorter than logit lenghts to prevent -inf + labels[i, max_length_labels[i] - 1 :] = -100 + + loss = model(input_values, labels=labels).loss + self.parent.assertFalse(torch.isinf(loss).item()) + + loss.backward() + + def prepare_config_and_inputs_for_common(self): + config, input_values, attention_mask = self.prepare_config_and_inputs() + inputs_dict = {"input_values": input_values, "attention_mask": attention_mask} + return config, inputs_dict + + +@require_torch +class HubertModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (HubertForCTC, HubertModel) if is_torch_available() else () + test_pruning = False + test_headmasking = False + test_torchscript = False + + def setUp(self): + self.model_tester = HubertModelTester(self) + self.config_tester = ConfigTester(self, config_class=HubertConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_ctc_loss_inference(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_ctc_loss(*config_and_inputs) + + def test_train(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_training(*config_and_inputs) + + # Hubert has no inputs_embeds + def test_inputs_embeds(self): + pass + + # `input_ids` is renamed to `input_values` + def test_forward_signature(self): + pass + + # Hubert cannot resize token embeddings + # since it has no tokens embeddings + def test_resize_tokens_embeddings(self): + pass + + # Hubert has no inputs_embeds + # and thus the `get_input_embeddings` fn + # is not implemented + def test_model_common_attributes(self): + pass + + def test_retain_grad_hidden_states_attentions(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.output_hidden_states = True + config.output_attentions = True + + # no need to test all models as different heads yield the same functionality + 
model_class = self.all_model_classes[0] + model = model_class(config) + model.to(torch_device) + + # set layer drop to 0 + model.config.layerdrop = 0.0 + + input_values = inputs_dict["input_values"] + + input_lengths = torch.tensor( + [input_values.shape[1] for _ in range(input_values.shape[0])], dtype=torch.long, device=torch_device + ) + output_lengths = model._get_feat_extract_output_lengths(input_lengths) + + labels = ids_tensor((input_values.shape[0], output_lengths[0] - 2), self.model_tester.vocab_size) + inputs_dict["attention_mask"] = torch.ones_like(inputs_dict["attention_mask"]) + inputs_dict["labels"] = labels + + outputs = model(**inputs_dict) + + output = outputs[0] + + # Encoder-/Decoder-only models + hidden_states = outputs.hidden_states[0] + attentions = outputs.attentions[0] + + hidden_states.retain_grad() + attentions.retain_grad() + + output.flatten()[0].backward(retain_graph=True) + + self.assertIsNotNone(hidden_states.grad) + self.assertIsNotNone(attentions.grad) + + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + uniform_init_parms = [ + "conv.weight", + "masked_spec_embed", + "quantizer.weight_proj.weight", + ] + if param.requires_grad: + if any([x in name for x in uniform_init_parms]): + self.assertTrue( + -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0, + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + else: + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + # overwrite from test_modeling_common + def _mock_init_weights(self, module): + if hasattr(module, "weight") and module.weight is not None: + module.weight.data.fill_(3) + if hasattr(module, "weight_g") and module.weight_g is not None: + module.weight_g.data.fill_(3) + if hasattr(module, "weight_v") and module.weight_v is not None: + module.weight_v.data.fill_(3) + if hasattr(module, "bias") and module.bias is not None: + module.bias.data.fill_(3) + if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None: + module.masked_spec_embed.data.fill_(3) + + @slow + def test_model_from_pretrained(self): + model = HubertModel.from_pretrained("facebook/hubert-base-ls960") + self.assertIsNotNone(model) + + +@require_torch +class HubertRobustModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (HubertForCTC, HubertModel) if is_torch_available() else () + test_pruning = False + test_headmasking = False + test_torchscript = False + + def setUp(self): + self.model_tester = HubertModelTester( + self, conv_stride=(3, 3, 3), feat_extract_norm="layer", do_stable_layer_norm=True + ) + self.config_tester = ConfigTester(self, config_class=HubertConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_batched_inference(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_batch_inference(*config_and_inputs) + + def test_ctc_loss_inference(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + 
self.model_tester.check_ctc_loss(*config_and_inputs) + + def test_train(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_training(*config_and_inputs) + + # Hubert has no inputs_embeds + def test_inputs_embeds(self): + pass + + # `input_ids` is renamed to `input_values` + def test_forward_signature(self): + pass + + # Hubert cannot resize token embeddings + # since it has no tokens embeddings + def test_resize_tokens_embeddings(self): + pass + + # Hubert has no inputs_embeds + # and thus the `get_input_embeddings` fn + # is not implemented + def test_model_common_attributes(self): + pass + + def test_retain_grad_hidden_states_attentions(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.output_hidden_states = True + config.output_attentions = True + + # no need to test all models as different heads yield the same functionality + model_class = self.all_model_classes[0] + model = model_class(config) + model.to(torch_device) + + # set layer drop to 0 + model.config.layerdrop = 0.0 + + input_values = inputs_dict["input_values"] + + input_lengths = torch.tensor( + [input_values.shape[1] for _ in range(input_values.shape[0])], dtype=torch.long, device=torch_device + ) + output_lengths = model._get_feat_extract_output_lengths(input_lengths) + + labels = ids_tensor((input_values.shape[0], output_lengths[0] - 2), self.model_tester.vocab_size) + inputs_dict["attention_mask"] = torch.ones_like(inputs_dict["attention_mask"]) + inputs_dict["labels"] = labels + + outputs = model(**inputs_dict) + + output = outputs[0] + + # Encoder-/Decoder-only models + hidden_states = outputs.hidden_states[0] + attentions = outputs.attentions[0] + + hidden_states.retain_grad() + attentions.retain_grad() + + output.flatten()[0].backward(retain_graph=True) + + self.assertIsNotNone(hidden_states.grad) + self.assertIsNotNone(attentions.grad) + + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + uniform_init_parms = [ + "conv.weight", + "masked_spec_embed", + "quantizer.weight_proj.weight", + ] + if param.requires_grad: + if any([x in name for x in uniform_init_parms]): + self.assertTrue( + -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0, + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + else: + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + # overwrite from test_modeling_common + def _mock_init_weights(self, module): + if hasattr(module, "weight") and module.weight is not None: + module.weight.data.fill_(3) + if hasattr(module, "weight_g") and module.weight_g is not None: + module.weight_g.data.fill_(3) + if hasattr(module, "weight_v") and module.weight_v is not None: + module.weight_v.data.fill_(3) + if hasattr(module, "bias") and module.bias is not None: + module.bias.data.fill_(3) + if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None: + module.masked_spec_embed.data.fill_(3) + + @slow + def test_model_from_pretrained(self): + model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft") + self.assertIsNotNone(model) + + +@require_torch +class 
HubertUtilsTest(unittest.TestCase): + def test_compute_mask_indices(self): + batch_size = 4 + sequence_length = 60 + mask_prob = 0.5 + mask_length = 1 + + mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device) + + self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)]) + + def test_compute_mask_indices_overlap(self): + batch_size = 4 + sequence_length = 80 + mask_prob = 0.5 + mask_length = 4 + + mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device) + + # because of overlap mask don't have to add up exactly to `mask_prob * sequence_length`, but have to be smaller or equal + for batch_sum in mask.sum(axis=-1): + self.assertTrue(int(batch_sum) <= mask_prob * sequence_length) + + +@require_torch +@require_datasets +@require_soundfile +@slow +class HubertModelIntegrationTest(unittest.TestCase): + def _load_datasamples(self, num_samples): + from datasets import load_dataset + + import soundfile as sf + + ids = [f"1272-141231-000{i}" for i in range(num_samples)] + + # map files to raw + def map_to_array(batch): + speech, _ = sf.read(batch["file"]) + batch["speech"] = speech + return batch + + ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation") + + ds = ds.filter(lambda x: x["id"] in ids).sort("id").map(map_to_array) + + return ds["speech"][:num_samples] + + def test_inference_ctc_batched(self): + model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").to(torch_device) + processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft", do_lower_case=True) + + input_speech = self._load_datasamples(2) + + inputs = processor(input_speech, return_tensors="pt", padding=True, truncation=True) + + input_values = inputs.input_values.to(torch_device) + attention_mask = inputs.attention_mask.to(torch_device) + + with torch.no_grad(): + logits = model(input_values, attention_mask=attention_mask).logits + + predicted_ids = torch.argmax(logits, dim=-1) + predicted_trans = processor.batch_decode(predicted_ids) + + EXPECTED_TRANSCRIPTIONS = [ + "a man said to the universe sir i exist", + "sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore", + ] + self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS) diff --git a/utils/check_repo.py b/utils/check_repo.py index 3a1bc7baa53..23285c93552 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -119,6 +119,7 @@ IGNORE_NON_AUTO_CONFIGURED = [ "TFRagSequenceForGeneration", "TFRagTokenForGeneration", "Wav2Vec2ForCTC", + "HubertForCTC", "XLMForQuestionAnswering", "XLNetForQuestionAnswering", "SeparableConv1D",
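A minimal sketch (not part of the patch) of the frame-count arithmetic the model and tests above rely on, using only the per-layer formula from ``_get_feat_extract_output_lengths``; the helper name below is illustrative, not a library API:

    def feat_extract_output_length(input_length, conv_kernel, conv_stride):
        # Same 1D-convolution output-length formula as `_get_feat_extract_output_lengths`:
        #     out_len = (in_len - kernel_size) // stride + 1
        for kernel_size, stride in zip(conv_kernel, conv_stride):
            input_length = (input_length - kernel_size) // stride + 1
        return input_length

    # Toy settings from HubertModelTester: a 1024-sample input yields 14 feature frames,
    # matching the tester's `output_seq_length`.
    assert feat_extract_output_length(1024, (8, 8, 8), (4, 4, 4)) == 14

    # The released hubert checkpoints use the default HubertConfig convolution stack,
    # which downsamples 16 kHz audio by roughly a factor of 320 (about 50 frames per second).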