diff --git a/README.md b/README.md index 1b1d727cba1..37f1a71c3c8 100644 --- a/README.md +++ b/README.md @@ -195,6 +195,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[BERT](https://huggingface.co/transformers/model_doc/bert.html)** (from Google) released with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. 1. **[BERT For Sequence Generation](https://huggingface.co/transformers/model_doc/bertgeneration.html)** (from Google) released with the paper [Leveraging Pre-trained Checkpoints for Sequence Generation Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. 1. **[BigBird-RoBERTa](https://huggingface.co/transformers/model_doc/bigbird.html)** (from Google Research) released with the paper [Big Bird: Transformers for Longer Sequences](https://arxiv.org/abs/2007.14062) by Manzil Zaheer, Guru Guruganesh, Avinava Dubey, Joshua Ainslie, Chris Alberti, Santiago Ontanon, Philip Pham, Anirudh Ravula, Qifan Wang, Li Yang, Amr Ahmed. +1. **[BigBird-Pegasus](https://huggingface.co/transformers/model_doc/bigbird_pegasus.html)** (from Google Research) released with the paper [Big Bird: Transformers for Longer Sequences](https://arxiv.org/abs/2007.14062) by Manzil Zaheer, Guru Guruganesh, Avinava Dubey, Joshua Ainslie, Chris Alberti, Santiago Ontanon, Philip Pham, Anirudh Ravula, Qifan Wang, Li Yang, Amr Ahmed. 1. **[Blenderbot](https://huggingface.co/transformers/model_doc/blenderbot.html)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston. 1. **[BlenderbotSmall](https://huggingface.co/transformers/model_doc/blenderbot_small.html)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston. 1. **[BORT](https://huggingface.co/transformers/model_doc/bort.html)** (from Alexa) released with the paper [Optimal Subarchitecture Extraction For BERT](https://arxiv.org/abs/2010.10499) by Adrian de Wynter and Daniel J. Perry. diff --git a/docs/source/index.rst b/docs/source/index.rst index 9af14e3b539..92eecc75542 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -100,153 +100,156 @@ conversion utilities for the following models: 6. :doc:`BigBird-RoBERTa ` (from Google Research) released with the paper `Big Bird: Transformers for Longer Sequences `__ by Manzil Zaheer, Guru Guruganesh, Avinava Dubey, Joshua Ainslie, Chris Alberti, Santiago Ontanon, Philip Pham, Anirudh Ravula, Qifan Wang, Li Yang, Amr Ahmed. -7. :doc:`Blenderbot ` (from Facebook) released with the paper `Recipes for building an +7. :doc:`BigBird-Pegasus ` (from Google Research) released with the paper `Big Bird: + Transformers for Longer Sequences `__ by Manzil Zaheer, Guru Guruganesh, Avinava + Dubey, Joshua Ainslie, Chris Alberti, Santiago Ontanon, Philip Pham, Anirudh Ravula, Qifan Wang, Li Yang, Amr Ahmed. +8. :doc:`Blenderbot ` (from Facebook) released with the paper `Recipes for building an open-domain chatbot `__ by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston. -8. :doc:`BlenderbotSmall ` (from Facebook) released with the paper `Recipes for building an +9. :doc:`BlenderbotSmall ` (from Facebook) released with the paper `Recipes for building an open-domain chatbot `__ by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston. -9. :doc:`BORT ` (from Alexa) released with the paper `Optimal Subarchitecture Extraction For BERT - `__ by Adrian de Wynter and Daniel J. Perry. -10. :doc:`CamemBERT ` (from Inria/Facebook/Sorbonne) released with the paper `CamemBERT: a Tasty +10. :doc:`BORT ` (from Alexa) released with the paper `Optimal Subarchitecture Extraction For BERT + `__ by Adrian de Wynter and Daniel J. Perry. +11. :doc:`CamemBERT ` (from Inria/Facebook/Sorbonne) released with the paper `CamemBERT: a Tasty French Language Model `__ by Louis Martin*, Benjamin Muller*, Pedro Javier Ortiz Suárez*, Yoann Dupont, Laurent Romary, Éric Villemonte de la Clergerie, Djamé Seddah and Benoît Sagot. -11. :doc:`ConvBERT ` (from YituTech) released with the paper `ConvBERT: Improving BERT with +12. :doc:`ConvBERT ` (from YituTech) released with the paper `ConvBERT: Improving BERT with Span-based Dynamic Convolution `__ by Zihang Jiang, Weihao Yu, Daquan Zhou, Yunpeng Chen, Jiashi Feng, Shuicheng Yan. -12. :doc:`CPM ` (from Tsinghua University) released with the paper `CPM: A Large-scale Generative +13. :doc:`CPM ` (from Tsinghua University) released with the paper `CPM: A Large-scale Generative Chinese Pre-trained Language Model `__ by Zhengyan Zhang, Xu Han, Hao Zhou, Pei Ke, Yuxian Gu, Deming Ye, Yujia Qin, Yusheng Su, Haozhe Ji, Jian Guan, Fanchao Qi, Xiaozhi Wang, Yanan Zheng, Guoyang Zeng, Huanqi Cao, Shengqi Chen, Daixuan Li, Zhenbo Sun, Zhiyuan Liu, Minlie Huang, Wentao Han, Jie Tang, Juanzi Li, Xiaoyan Zhu, Maosong Sun. -13. :doc:`CTRL ` (from Salesforce) released with the paper `CTRL: A Conditional Transformer Language +14. :doc:`CTRL ` (from Salesforce) released with the paper `CTRL: A Conditional Transformer Language Model for Controllable Generation `__ by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher. -14. :doc:`DeBERTa ` (from Microsoft) released with the paper `DeBERTa: Decoding-enhanced BERT with +15. :doc:`DeBERTa ` (from Microsoft) released with the paper `DeBERTa: Decoding-enhanced BERT with Disentangled Attention `__ by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. -15. :doc:`DeBERTa-v2 ` (from Microsoft) released with the paper `DeBERTa: Decoding-enhanced BERT +16. :doc:`DeBERTa-v2 ` (from Microsoft) released with the paper `DeBERTa: Decoding-enhanced BERT with Disentangled Attention `__ by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. -16. :doc:`DeiT ` (from Facebook) released with the paper `Training data-efficient image transformers & +17. :doc:`DeiT ` (from Facebook) released with the paper `Training data-efficient image transformers & distillation through attention `__ by Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, Hervé Jégou. -17. :doc:`DialoGPT ` (from Microsoft Research) released with the paper `DialoGPT: Large-Scale +18. :doc:`DialoGPT ` (from Microsoft Research) released with the paper `DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation `__ by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan. -18. :doc:`DistilBERT ` (from HuggingFace), released together with the paper `DistilBERT, a +19. :doc:`DistilBERT ` (from HuggingFace), released together with the paper `DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter `__ by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into `DistilGPT2 `__, RoBERTa into `DistilRoBERTa `__, Multilingual BERT into `DistilmBERT `__ and a German version of DistilBERT. -19. :doc:`DPR ` (from Facebook) released with the paper `Dense Passage Retrieval for Open-Domain +20. :doc:`DPR ` (from Facebook) released with the paper `Dense Passage Retrieval for Open-Domain Question Answering `__ by Vladimir Karpukhin, Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. -20. :doc:`ELECTRA ` (from Google Research/Stanford University) released with the paper `ELECTRA: +21. :doc:`ELECTRA ` (from Google Research/Stanford University) released with the paper `ELECTRA: Pre-training text encoders as discriminators rather than generators `__ by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning. -21. :doc:`FlauBERT ` (from CNRS) released with the paper `FlauBERT: Unsupervised Language Model +22. :doc:`FlauBERT ` (from CNRS) released with the paper `FlauBERT: Unsupervised Language Model Pre-training for French `__ by Hang Le, Loïc Vial, Jibril Frej, Vincent Segonne, Maximin Coavoux, Benjamin Lecouteux, Alexandre Allauzen, Benoît Crabbé, Laurent Besacier, Didier Schwab. -22. :doc:`Funnel Transformer ` (from CMU/Google Brain) released with the paper `Funnel-Transformer: +23. :doc:`Funnel Transformer ` (from CMU/Google Brain) released with the paper `Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing `__ by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le. -23. :doc:`GPT ` (from OpenAI) released with the paper `Improving Language Understanding by Generative +24. :doc:`GPT ` (from OpenAI) released with the paper `Improving Language Understanding by Generative Pre-Training `__ by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever. -24. :doc:`GPT-2 ` (from OpenAI) released with the paper `Language Models are Unsupervised Multitask +25. :doc:`GPT-2 ` (from OpenAI) released with the paper `Language Models are Unsupervised Multitask Learners `__ by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**. -25. :doc:`GPT Neo ` (from EleutherAI) released in the repository `EleutherAI/gpt-neo +26. :doc:`GPT Neo ` (from EleutherAI) released in the repository `EleutherAI/gpt-neo `__ by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy. -26. :doc:`I-BERT ` (from Berkeley) released with the paper `I-BERT: Integer-only BERT Quantization +27. :doc:`I-BERT ` (from Berkeley) released with the paper `I-BERT: Integer-only BERT Quantization `__ by Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney, Kurt Keutzer -27. :doc:`LayoutLM ` (from Microsoft Research Asia) released with the paper `LayoutLM: Pre-training +28. :doc:`LayoutLM ` (from Microsoft Research Asia) released with the paper `LayoutLM: Pre-training of Text and Layout for Document Image Understanding `__ by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou. -28. :doc:`LED ` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer +29. :doc:`LED ` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer `__ by Iz Beltagy, Matthew E. Peters, Arman Cohan. -29. :doc:`Longformer ` (from AllenAI) released with the paper `Longformer: The Long-Document +30. :doc:`Longformer ` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer `__ by Iz Beltagy, Matthew E. Peters, Arman Cohan. -30. :doc:`LUKE ` (from Studio Ousia) released with the paper `LUKE: Deep Contextualized Entity +31. :doc:`LUKE ` (from Studio Ousia) released with the paper `LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention `__ by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto. -31. :doc:`LXMERT ` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality +32. :doc:`LXMERT ` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering `__ by Hao Tan and Mohit Bansal. -32. :doc:`M2M100 ` (from Facebook) released with the paper `Beyond English-Centric Multilingual +33. :doc:`M2M100 ` (from Facebook) released with the paper `Beyond English-Centric Multilingual Machine Translation `__ by by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. -33. :doc:`MarianMT ` Machine translation models trained using `OPUS `__ data by +34. :doc:`MarianMT ` Machine translation models trained using `OPUS `__ data by Jörg Tiedemann. The `Marian Framework `__ is being developed by the Microsoft Translator Team. -34. :doc:`MBart ` (from Facebook) released with the paper `Multilingual Denoising Pre-training for +35. :doc:`MBart ` (from Facebook) released with the paper `Multilingual Denoising Pre-training for Neural Machine Translation `__ by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. -35. :doc:`MBart-50 ` (from Facebook) released with the paper `Multilingual Translation with Extensible +36. :doc:`MBart-50 ` (from Facebook) released with the paper `Multilingual Translation with Extensible Multilingual Pretraining and Finetuning `__ by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan. -36. :doc:`Megatron-BERT ` (from NVIDIA) released with the paper `Megatron-LM: Training +37. :doc:`Megatron-BERT ` (from NVIDIA) released with the paper `Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism `__ by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro. -37. :doc:`Megatron-GPT2 ` (from NVIDIA) released with the paper `Megatron-LM: Training +38. :doc:`Megatron-GPT2 ` (from NVIDIA) released with the paper `Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism `__ by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro. -38. :doc:`MPNet ` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted +39. :doc:`MPNet ` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted Pre-training for Language Understanding `__ by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu. -39. :doc:`MT5 ` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained +40. :doc:`MT5 ` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained text-to-text transformer `__ by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel. -40. :doc:`Pegasus ` (from Google) released with the paper `PEGASUS: Pre-training with Extracted +41. :doc:`Pegasus ` (from Google) released with the paper `PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization `__> by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. -41. :doc:`ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting +42. :doc:`ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training `__ by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. -42. :doc:`Reformer ` (from Google Research) released with the paper `Reformer: The Efficient +43. :doc:`Reformer ` (from Google Research) released with the paper `Reformer: The Efficient Transformer `__ by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. -43. :doc:`RoBERTa ` (from Facebook), released together with the paper a `Robustly Optimized BERT +44. :doc:`RoBERTa ` (from Facebook), released together with the paper a `Robustly Optimized BERT Pretraining Approach `__ by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. -44. :doc:`SpeechToTextTransformer ` (from Facebook), released together with the paper +45. :doc:`SpeechToTextTransformer ` (from Facebook), released together with the paper `fairseq S2T: Fast Speech-to-Text Modeling with fairseq `__ by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino. -45. :doc:`SqueezeBert ` released with the paper `SqueezeBERT: What can computer vision teach NLP +46. :doc:`SqueezeBert ` released with the paper `SqueezeBERT: What can computer vision teach NLP about efficient neural networks? `__ by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. -46. :doc:`T5 ` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a +47. :doc:`T5 ` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer `__ by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. -47. :doc:`TAPAS ` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via +48. :doc:`TAPAS ` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via Pre-training `__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. -48. :doc:`Transformer-XL ` (from Google/CMU) released with the paper `Transformer-XL: +49. :doc:`Transformer-XL ` (from Google/CMU) released with the paper `Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context `__ by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov. -49. :doc:`Vision Transformer (ViT) ` (from Google AI) released with the paper `An Image is Worth 16x16 +50. :doc:`Vision Transformer (ViT) ` (from Google AI) released with the paper `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `__ by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. -50. :doc:`Wav2Vec2 ` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for +51. :doc:`Wav2Vec2 ` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations `__ by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli. -51. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model +52. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model Pretraining `__ by Guillaume Lample and Alexis Conneau. -52. :doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: +53. :doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training `__ by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. -53. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised +54. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised Cross-lingual Representation Learning at Scale `__ by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov. -54. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive +55. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive Pretraining for Language Understanding `__ by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le. -55. :doc:`XLSR-Wav2Vec2 ` (from Facebook AI) released with the paper `Unsupervised +56. :doc:`XLSR-Wav2Vec2 ` (from Facebook AI) released with the paper `Unsupervised Cross-Lingual Representation Learning For Speech Recognition `__ by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli. @@ -275,6 +278,8 @@ Flax), PyTorch, and/or TensorFlow. +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | BigBird | ✅ | ❌ | ✅ | ❌ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| BigBirdPegasus | ❌ | ❌ | ✅ | ❌ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | Blenderbot | ✅ | ❌ | ✅ | ✅ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | BlenderbotSmall | ✅ | ❌ | ✅ | ✅ | ❌ | @@ -451,6 +456,7 @@ Flax), PyTorch, and/or TensorFlow. model_doc/bertgeneration model_doc/bert_japanese model_doc/bigbird + model_doc/bigbird_pegasus model_doc/blenderbot model_doc/blenderbot_small model_doc/bort diff --git a/docs/source/model_doc/bigbird_pegasus.rst b/docs/source/model_doc/bigbird_pegasus.rst new file mode 100644 index 00000000000..3e0ece9bf6c --- /dev/null +++ b/docs/source/model_doc/bigbird_pegasus.rst @@ -0,0 +1,98 @@ +.. + Copyright 2021 The HuggingFace Team. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + specific language governing permissions and limitations under the License. + +BigBirdPegasus +----------------------------------------------------------------------------------------------------------------------- + +Overview +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The BigBird model was proposed in `Big Bird: Transformers for Longer Sequences `__ by +Zaheer, Manzil and Guruganesh, Guru and Dubey, Kumar Avinava and Ainslie, Joshua and Alberti, Chris and Ontanon, +Santiago and Pham, Philip and Ravula, Anirudh and Wang, Qifan and Yang, Li and others. BigBird, is a sparse-attention +based transformer which extends Transformer based models, such as BERT to much longer sequences. In addition to sparse +attention, BigBird also applies global attention as well as random attention to the input sequence. Theoretically, it +has been shown that applying sparse, global, and random attention approximates full attention, while being +computationally much more efficient for longer sequences. As a consequence of the capability to handle longer context, +BigBird has shown improved performance on various long document NLP tasks, such as question answering and +summarization, compared to BERT or RoBERTa. + +The abstract from the paper is the following: + +*Transformers-based models, such as BERT, have been one of the most successful deep learning models for NLP. +Unfortunately, one of their core limitations is the quadratic dependency (mainly in terms of memory) on the sequence +length due to their full attention mechanism. To remedy this, we propose, BigBird, a sparse attention mechanism that +reduces this quadratic dependency to linear. We show that BigBird is a universal approximator of sequence functions and +is Turing complete, thereby preserving these properties of the quadratic, full attention model. Along the way, our +theoretical analysis reveals some of the benefits of having O(1) global tokens (such as CLS), that attend to the entire +sequence as part of the sparse attention mechanism. The proposed sparse attention can handle sequences of length up to +8x of what was previously possible using similar hardware. As a consequence of the capability to handle longer context, +BigBird drastically improves performance on various NLP tasks such as question answering and summarization. We also +propose novel applications to genomics data.* + +Tips: + +- For an in-detail explanation on how BigBird's attention works, see `this blog post + `__. +- BigBird comes with 2 implementations: **original_full** & **block_sparse**. For the sequence length < 1024, using + **original_full** is advised as there is no benefit in using **block_sparse** attention. +- The code currently uses window size of 3 blocks and 2 global blocks. +- Sequence length must be divisible by block size. +- Current implementation supports only **ITC**. +- Current implementation doesn't support **num_random_blocks = 0**. +- BigBirdPegasus uses the `PegasusTokenizer + `__. + +The original code can be found `here `__. + +BigBirdPegasusConfig +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.BigBirdPegasusConfig + :members: + + +BigBirdPegasusModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.BigBirdPegasusModel + :members: forward + + +BigBirdPegasusForConditionalGeneration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.BigBirdPegasusForConditionalGeneration + :members: forward + + +BigBirdPegasusForSequenceClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.BigBirdPegasusForSequenceClassification + :members: forward + + +BigBirdPegasusForQuestionAnswering +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.BigBirdPegasusForQuestionAnswering + :members: forward + + +BigBirdPegasusForCausalLM +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.BigBirdPegasusForCausalLM + :members: forward + + diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 58ae8ac3873..b1de18192c9 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -155,6 +155,10 @@ _import_structure = { "models.bert_japanese": ["BertJapaneseTokenizer", "CharacterTokenizer", "MecabTokenizer"], "models.bertweet": ["BertweetTokenizer"], "models.big_bird": ["BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP", "BigBirdConfig", "BigBirdTokenizer"], + "models.bigbird_pegasus": [ + "BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", + "BigBirdPegasusConfig", + ], "models.blenderbot": ["BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BlenderbotConfig", "BlenderbotTokenizer"], "models.blenderbot_small": [ "BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP", @@ -543,6 +547,16 @@ if is_torch_available(): "load_tf_weights_in_big_bird", ] ) + _import_structure["models.bigbird_pegasus"].extend( + [ + "BIGBIRD_PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST", + "BigBirdPegasusForCausalLM", + "BigBirdPegasusForConditionalGeneration", + "BigBirdPegasusForQuestionAnswering", + "BigBirdPegasusForSequenceClassification", + "BigBirdPegasusModel", + ] + ) _import_structure["models.blenderbot"].extend( [ "BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -1541,6 +1555,7 @@ if TYPE_CHECKING: from .models.bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer from .models.bertweet import BertweetTokenizer from .models.big_bird import BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdConfig, BigBirdTokenizer + from .models.bigbird_pegasus import BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdPegasusConfig from .models.blenderbot import BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotConfig, BlenderbotTokenizer from .models.blenderbot_small import ( BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP, @@ -1885,6 +1900,14 @@ if TYPE_CHECKING: BigBirdPreTrainedModel, load_tf_weights_in_big_bird, ) + from .models.bigbird_pegasus import ( + BIGBIRD_PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST, + BigBirdPegasusForCausalLM, + BigBirdPegasusForConditionalGeneration, + BigBirdPegasusForQuestionAnswering, + BigBirdPegasusForSequenceClassification, + BigBirdPegasusModel, + ) from .models.blenderbot import ( BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST, BlenderbotForCausalLM, diff --git a/src/transformers/convert_slow_tokenizer.py b/src/transformers/convert_slow_tokenizer.py index 9775339bb45..cbed3a6b4e5 100644 --- a/src/transformers/convert_slow_tokenizer.py +++ b/src/transformers/convert_slow_tokenizer.py @@ -635,9 +635,17 @@ class PegasusConverter(SpmConverter): vocab = [ (self.original_tokenizer.pad_token, 0.0), (self.original_tokenizer.eos_token, 0.0), - (self.original_tokenizer.mask_token_sent, 0.0), - (self.original_tokenizer.mask_token, 0.0), ] + + if self.original_tokenizer.mask_token_sent is not None: + vocab += [(self.original_tokenizer.mask_token_sent, 0.0)] + + if ( + self.original_tokenizer.mask_token is not None + and self.original_tokenizer.mask_token_id < self.original_tokenizer.offset + ): + vocab += [(self.original_tokenizer.mask_token, 0.0)] + vocab += [(f"", -100.0) for i in range(2, self.original_tokenizer.offset)] vocab += [(piece.piece, piece.score) for piece in proto.pieces[2:]] return vocab diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index b1ee27e7257..7fd6d63acdc 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -26,6 +26,7 @@ from . import ( bert_japanese, bertweet, big_bird, + bigbird_pegasus, blenderbot, blenderbot_small, camembert, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index f343348a7c7..e3c78dd3404 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -23,6 +23,10 @@ from ..bart.configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartCo from ..bert.configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig from ..bert_generation.configuration_bert_generation import BertGenerationConfig from ..big_bird.configuration_big_bird import BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdConfig +from ..bigbird_pegasus.configuration_bigbird_pegasus import ( + BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP, + BigBirdPegasusConfig, +) from ..blenderbot.configuration_blenderbot import BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotConfig from ..blenderbot_small.configuration_blenderbot_small import ( BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP, @@ -86,6 +90,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict( (key, value) for pretrained_map in [ # Add archive maps here + BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP, DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, @@ -139,6 +144,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict( CONFIG_MAPPING = OrderedDict( [ # Add configs here + ("bigbird_pegasus", BigBirdPegasusConfig), ("deit", DeiTConfig), ("luke", LukeConfig), ("gpt_neo", GPTNeoConfig), @@ -198,6 +204,7 @@ CONFIG_MAPPING = OrderedDict( MODEL_NAMES_MAPPING = OrderedDict( [ # Add full (and cased) model names here + ("bigbird_pegasus", "BigBirdPegasus"), ("deit", "DeiT"), ("luke", "LUKE"), ("gpt_neo", "GPT Neo"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 22028d173bd..f28b8466676 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -59,6 +59,13 @@ from ..big_bird.modeling_big_bird import ( BigBirdForTokenClassification, BigBirdModel, ) +from ..bigbird_pegasus.modeling_bigbird_pegasus import ( + BigBirdPegasusForCausalLM, + BigBirdPegasusForConditionalGeneration, + BigBirdPegasusForQuestionAnswering, + BigBirdPegasusForSequenceClassification, + BigBirdPegasusModel, +) from ..blenderbot.modeling_blenderbot import BlenderbotForCausalLM, BlenderbotForConditionalGeneration, BlenderbotModel from ..blenderbot_small.modeling_blenderbot_small import ( BlenderbotSmallForCausalLM, @@ -288,6 +295,7 @@ from .configuration_auto import ( BertConfig, BertGenerationConfig, BigBirdConfig, + BigBirdPegasusConfig, BlenderbotConfig, BlenderbotSmallConfig, CamembertConfig, @@ -344,6 +352,7 @@ logger = logging.get_logger(__name__) MODEL_MAPPING = OrderedDict( [ # Base model mapping + (BigBirdPegasusConfig, BigBirdPegasusModel), (DeiTConfig, DeiTModel), (LukeConfig, LukeModel), (GPTNeoConfig, GPTNeoModel), @@ -439,6 +448,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict( MODEL_WITH_LM_HEAD_MAPPING = OrderedDict( [ # Model with LM heads mapping + (BigBirdPegasusConfig, BigBirdPegasusForConditionalGeneration), (GPTNeoConfig, GPTNeoForCausalLM), (BigBirdConfig, BigBirdForMaskedLM), (Speech2TextConfig, Speech2TextForConditionalGeneration), @@ -485,6 +495,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict( MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict( [ # Model for Causal LM mapping + (BigBirdPegasusConfig, BigBirdPegasusForCausalLM), (GPTNeoConfig, GPTNeoForCausalLM), (BigBirdConfig, BigBirdForCausalLM), (CamembertConfig, CamembertForCausalLM), @@ -557,6 +568,7 @@ MODEL_FOR_MASKED_LM_MAPPING = OrderedDict( MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict( [ # Model for Seq2Seq Causal LM mapping + (BigBirdPegasusConfig, BigBirdPegasusForConditionalGeneration), (M2M100Config, M2M100ForConditionalGeneration), (LEDConfig, LEDForConditionalGeneration), (BlenderbotSmallConfig, BlenderbotSmallForConditionalGeneration), @@ -577,6 +589,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict( MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict( [ # Model for Sequence Classification mapping + (BigBirdPegasusConfig, BigBirdPegasusForSequenceClassification), (BigBirdConfig, BigBirdForSequenceClassification), (ConvBertConfig, ConvBertForSequenceClassification), (LEDConfig, LEDForSequenceClassification), @@ -614,6 +627,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict( MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict( [ # Model for Question Answering mapping + (BigBirdPegasusConfig, BigBirdPegasusForQuestionAnswering), (BigBirdConfig, BigBirdForQuestionAnswering), (ConvBertConfig, ConvBertForQuestionAnswering), (LEDConfig, LEDForQuestionAnswering), diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index 45da61b9913..7acea14b9ee 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -549,6 +549,7 @@ class BigBirdBlockSparseAttention(nn.Module): rsqrt_d = 1 / math.sqrt(attention_head_size) bsz = batch_size + attn_mask_penalty = -10000.0 # generate random attention and corresponding masks np.random.seed(seed) @@ -606,7 +607,7 @@ class BigBirdBlockSparseAttention(nn.Module): first_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, 0], key_layer, ndim=4) first_product = first_product * rsqrt_d - first_product += (1.0 - to_mask) * -10000.0 + first_product += (1.0 - to_mask) * attn_mask_penalty first_attn_weights = F.softmax(first_product, dim=-1) # [bsz, n_heads, from_block_size, to_seq_len] # [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1] @@ -658,7 +659,7 @@ class BigBirdBlockSparseAttention(nn.Module): dim=3, ) second_product = second_product * rsqrt_d - second_product += (1.0 - torch.minimum(second_seq_pad, second_rand_pad)) * -10000.0 + second_product += (1.0 - torch.minimum(second_seq_pad, second_rand_pad)) * attn_mask_penalty second_attn_weights = F.softmax( second_product, dim=-1 ) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] @@ -709,10 +710,10 @@ class BigBirdBlockSparseAttention(nn.Module): last_band_product = last_band_product * rsqrt_d # masking padded tokens - inner_band_product += (1.0 - band_mask) * -10000.0 - first_band_product += (1.0 - to_mask[:, :, :, :to_block_size].unsqueeze(3)) * -10000.0 - last_band_product += (1.0 - to_mask[:, :, :, -to_block_size:].unsqueeze(3)) * -10000.0 - rand_band_product += (1.0 - rand_mask[:, :, 1:-1]) * -10000.0 + inner_band_product += (1.0 - band_mask) * attn_mask_penalty + first_band_product += (1.0 - to_mask[:, :, :, :to_block_size].unsqueeze(3)) * attn_mask_penalty + last_band_product += (1.0 - to_mask[:, :, :, -to_block_size:].unsqueeze(3)) * attn_mask_penalty + rand_band_product += (1.0 - rand_mask[:, :, 1:-1]) * attn_mask_penalty # completing attention scores matrix for all q[-2:2] band_product = torch.cat( @@ -792,7 +793,7 @@ class BigBirdBlockSparseAttention(nn.Module): dim=3, ) second_last_product = second_last_product * rsqrt_d - second_last_product += (1.0 - torch.minimum(second_last_seq_pad, second_last_rand_pad)) * -10000.0 + second_last_product += (1.0 - torch.minimum(second_last_seq_pad, second_last_rand_pad)) * attn_mask_penalty second_last_attn_weights = F.softmax( second_last_product, dim=-1 ) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] @@ -808,7 +809,7 @@ class BigBirdBlockSparseAttention(nn.Module): # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, to_seq_len] last_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, -1], key_layer, ndim=4) last_product = last_product * rsqrt_d - last_product += (1.0 - to_mask) * -10000.0 + last_product += (1.0 - to_mask) * attn_mask_penalty last_attn_weights = F.softmax(last_product, dim=-1) # [bsz, n_heads, from_block_size, n] # [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1] diff --git a/src/transformers/models/bigbird_pegasus/__init__.py b/src/transformers/models/bigbird_pegasus/__init__.py new file mode 100644 index 00000000000..270cb757801 --- /dev/null +++ b/src/transformers/models/bigbird_pegasus/__init__.py @@ -0,0 +1,70 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...file_utils import _BaseLazyModule, is_torch_available + + +_import_structure = { + "configuration_bigbird_pegasus": ["BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "BigBirdPegasusConfig"], +} + +if is_torch_available(): + _import_structure["modeling_bigbird_pegasus"] = [ + "BIGBIRD_PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST", + "BigBirdPegasusForCausalLM", + "BigBirdPegasusForConditionalGeneration", + "BigBirdPegasusForQuestionAnswering", + "BigBirdPegasusForSequenceClassification", + "BigBirdPegasusModel", + "BigBirdPegasusPreTrainedModel", + ] + + +if TYPE_CHECKING: + from .configuration_bigbird_pegasus import BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdPegasusConfig + + if is_torch_available(): + from .modeling_bigbird_pegasus import ( + BIGBIRD_PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST, + BigBirdPegasusForCausalLM, + BigBirdPegasusForConditionalGeneration, + BigBirdPegasusForQuestionAnswering, + BigBirdPegasusForSequenceClassification, + BigBirdPegasusModel, + BigBirdPegasusPreTrainedModel, + ) + + +else: + import importlib + import os + import sys + + class _LazyModule(_BaseLazyModule): + """ + Module class that surfaces all objects but only performs associated imports when the objects are requested. + """ + + __file__ = globals()["__file__"] + __path__ = [os.path.dirname(__file__)] + + def _get_module(self, module_name: str): + return importlib.import_module("." + module_name, self.__name__) + + sys.modules[__name__] = _LazyModule(__name__, _import_structure) diff --git a/src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py new file mode 100644 index 00000000000..49c18a44f8e --- /dev/null +++ b/src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py @@ -0,0 +1,196 @@ +# coding=utf-8 +# Copyright Google Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" BigBirdPegasus model configuration """ + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "google/bigbird-pegasus-large-arxiv": "https://huggingface.co/google/bigbird-pegasus-large-arxiv/resolve/main/config.json", + "google/bigbird-pegasus-large-pubmed": "https://huggingface.co/google/bigbird-pegasus-large-pubmed/resolve/main/config.json", + "google/bigbird-pegasus-large-bigpatent": "https://huggingface.co/google/bigbird-pegasus-large-bigpatent/resolve/main/config.json", + # See all BigBirdPegasus models at https://huggingface.co/models?filter=bigbird_pegasus +} + + +class BigBirdPegasusConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class:`~transformers.BigBirdPegasusModel`. It is + used to instantiate an BigBirdPegasus model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the BigBirdPegasus + `google/bigbird-pegasus-large-arxiv `__ architecture. + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model + outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. + + + Args: + vocab_size (:obj:`int`, `optional`, defaults to 96103): + Vocabulary size of the BigBirdPegasus model. Defines the number of different tokens that can be represented + by the :obj:`inputs_ids` passed when calling :class:`~transformers.BigBirdPegasusModel`. + d_model (:obj:`int`, `optional`, defaults to 1024): + Dimension of the layers and the pooler layer. + encoder_layers (:obj:`int`, `optional`, defaults to 16): + Number of encoder layers. + decoder_layers (:obj:`int`, `optional`, defaults to 16): + Number of decoder layers. + encoder_attention_heads (:obj:`int`, `optional`, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + decoder_attention_heads (:obj:`int`, `optional`, defaults to 16): + Number of attention heads for each attention layer in the Transformer decoder. + decoder_ffn_dim (:obj:`int`, `optional`, defaults to 4096): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + encoder_ffn_dim (:obj:`int`, `optional`, defaults to 4096): + Dimension of the "intermediate" (often named feed-forward) layer in decoder. + activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu_fast"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, + :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"`, :obj:`"gelu_fast"` and :obj:`"gelu_new"` are supported. + dropout (:obj:`float`, `optional`, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout ratio for the attention probabilities. + activation_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + classifier_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout ratio for classifier. + max_position_embeddings (:obj:`int`, `optional`, defaults to 4096): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 1024 or 2048 or 4096). + init_std (:obj:`float`, `optional`, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + encoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0): + The LayerDrop probability for the encoder. See the `LayerDrop paper `__ for more details. + decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0): + The LayerDrop probability for the decoder. See the `LayerDrop paper `__ for more details. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not the model should return the last key/values attentions (not used by all models). + attention_type (:obj:`str`, `optional`, defaults to :obj:`"block_sparse"`) + Whether to use block sparse attention (with n complexity) as introduced in paper or original attention + layer (with n^2 complexity) in encoder. Possible values are :obj:`"original_full"` and + :obj:`"block_sparse"`. + use_bias (:obj:`bool`, `optional`, defaults to :obj:`False`) + Whether to use bias in query, key, value. + block_size (:obj:`int`, `optional`, defaults to 64) + Size of each block. Useful only when :obj:`attention_type == "block_sparse"`. + num_random_blocks (:obj:`int`, `optional`, defaults to 3) + Each query is going to attend these many number of random blocks. Useful only when :obj:`attention_type == + "block_sparse"`. + scale_embeddings (:obj:`bool`, `optional`, defaults to :obj:`True`) + Whether to rescale embeddings with (hidden_size ** 0.5). + gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): + If True, use gradient checkpointing to save memory at the expense of slower backward pass. + + Example:: + + >>> from transformers import BigBirdPegasusModel, BigBirdPegasusConfig + + >>> # Initializing a BigBirdPegasus bigbird-pegasus-base style configuration + >>> configuration = BigBirdPegasusConfig() + + >>> # Initializing a model from the bigbird-pegasus-base style configuration + >>> model = BigBirdPegasusModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + """ + model_type = "bigbird_pegasus" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=96103, + max_position_embeddings=4096, + encoder_layers=16, + encoder_ffn_dim=4096, + encoder_attention_heads=16, + decoder_layers=16, + decoder_ffn_dim=4096, + decoder_attention_heads=16, + encoder_layerdrop=0.0, + decoder_layerdrop=0.0, + use_cache=True, + is_encoder_decoder=True, + activation_function="gelu_fast", + d_model=1024, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + decoder_start_token_id=2, + classifier_dropout=0.0, + scale_embedding=True, + gradient_checkpointing=False, + pad_token_id=0, + bos_token_id=2, + eos_token_id=1, + attention_type="block_sparse", # only for encoder + block_size=64, + num_random_blocks=3, + use_bias=False, + **kwargs + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + decoder_start_token_id=decoder_start_token_id, + **kwargs, + ) + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.classifier_dropout = classifier_dropout + self.use_cache = use_cache + self.num_hidden_layers = encoder_layers + self.gradient_checkpointing = gradient_checkpointing + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + + # extra config + self.attention_type = attention_type + self.block_size = block_size + self.num_random_blocks = num_random_blocks + self.use_bias = use_bias + + @property + def num_attention_heads(self) -> int: + return self.encoder_attention_heads + + @property + def hidden_size(self) -> int: + return self.d_model + + @property + def attention_probs_dropout_prob(self) -> float: + return self.attention_dropout diff --git a/src/transformers/models/bigbird_pegasus/convert_bigbird_pegasus_tf_to_pytorch.py b/src/transformers/models/bigbird_pegasus/convert_bigbird_pegasus_tf_to_pytorch.py new file mode 100644 index 00000000000..2d2efdec774 --- /dev/null +++ b/src/transformers/models/bigbird_pegasus/convert_bigbird_pegasus_tf_to_pytorch.py @@ -0,0 +1,171 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from typing import Dict + +import tensorflow as tf +import torch +from tqdm import tqdm + +from transformers import BigBirdPegasusConfig, BigBirdPegasusForConditionalGeneration + + +INIT_COMMON = [ + # tf -> hf + ("/", "."), + ("layer_", "layers."), + ("kernel", "weight"), + ("beta", "bias"), + ("gamma", "weight"), + ("pegasus", "model"), +] +END_COMMON = [ + (".output.dense", ".fc2"), + ("intermediate.LayerNorm", "final_layer_norm"), + ("intermediate.dense", "fc1"), +] + +DECODER_PATTERNS = ( + INIT_COMMON + + [ + ("attention.self.LayerNorm", "self_attn_layer_norm"), + ("attention.output.dense", "self_attn.out_proj"), + ("attention.self", "self_attn"), + ("attention.encdec.LayerNorm", "encoder_attn_layer_norm"), + ("attention.encdec_output.dense", "encoder_attn.out_proj"), + ("attention.encdec", "encoder_attn"), + ("key", "k_proj"), + ("value", "v_proj"), + ("query", "q_proj"), + ("decoder.LayerNorm", "decoder.layernorm_embedding"), + ] + + END_COMMON +) + +REMAINING_PATTERNS = ( + INIT_COMMON + + [ + ("embeddings.word_embeddings", "shared.weight"), + ("embeddings.position_embeddings", "embed_positions.weight"), + ("attention.self.LayerNorm", "self_attn_layer_norm"), + ("attention.output.dense", "self_attn.output"), + ("attention.self", "self_attn.self"), + ("encoder.LayerNorm", "encoder.layernorm_embedding"), + ] + + END_COMMON +) + +KEYS_TO_IGNORE = [ + "encdec/key/bias", + "encdec/query/bias", + "encdec/value/bias", + "self/key/bias", + "self/query/bias", + "self/value/bias", + "encdec_output/dense/bias", + "attention/output/dense/bias", +] + + +def rename_state_dict_key(k, patterns): + for tf_name, hf_name in patterns: + k = k.replace(tf_name, hf_name) + return k + + +def convert_bigbird_pegasus(tf_weights: dict, config_update: dict) -> BigBirdPegasusForConditionalGeneration: + + cfg = BigBirdPegasusConfig(**config_update) + torch_model = BigBirdPegasusForConditionalGeneration(cfg) + state_dict = torch_model.state_dict() + mapping = {} + + # separating decoder weights + decoder_weights = {k: tf_weights[k] for k in tf_weights if k.startswith("pegasus/decoder")} + remaining_weights = {k: tf_weights[k] for k in tf_weights if not k.startswith("pegasus/decoder")} + + for k, v in tqdm(decoder_weights.items(), "tf -> hf conversion"): + conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE] + if any(conditions): + continue + patterns = DECODER_PATTERNS + new_k = rename_state_dict_key(k, patterns) + if new_k not in state_dict: + raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})") + if any([True if i in k else False for i in ["dense", "query", "key", "value"]]): + v = v.T + mapping[new_k] = torch.from_numpy(v) + assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}" + + for k, v in tqdm(remaining_weights.items(), "tf -> hf conversion"): + conditions = [k.endswith(ending) for ending in KEYS_TO_IGNORE] + if any(conditions): + continue + patterns = REMAINING_PATTERNS + new_k = rename_state_dict_key(k, patterns) + if new_k not in state_dict and k != "pegasus/embeddings/position_embeddings": + raise ValueError(f"could not find new key {new_k} in state dict. (converted from {k})") + if any([True if i in k else False for i in ["dense", "query", "key", "value"]]): + v = v.T + mapping[new_k] = torch.from_numpy(v) + if k != "pegasus/embeddings/position_embeddings": + assert v.shape == state_dict[new_k].shape, f"{new_k}, {k}, {v.shape}, {state_dict[new_k].shape}" + + mapping["model.encoder.embed_positions.weight"] = mapping["model.embed_positions.weight"] + mapping["model.decoder.embed_positions.weight"] = mapping.pop("model.embed_positions.weight") + missing, extra = torch_model.load_state_dict(mapping, strict=False) + unexpected_missing = [ + k + for k in missing + if k + not in [ + "final_logits_bias", + "model.encoder.embed_tokens.weight", + "model.decoder.embed_tokens.weight", + "lm_head.weight", + ] + ] + assert unexpected_missing == [], f"no matches found for the following torch keys {unexpected_missing}" + assert extra == [], f"no matches found for the following tf keys {extra}" + return torch_model + + +def get_tf_weights_as_numpy(path) -> Dict: + init_vars = tf.train.list_variables(path) + tf_weights = {} + ignore_name = ["global_step"] + for name, shape in tqdm(init_vars, desc="converting tf checkpoint to dict"): + skip_key = any([pat in name for pat in ignore_name]) + if skip_key: + continue + array = tf.train.load_variable(path, name) + tf_weights[name] = array + return tf_weights + + +def convert_bigbird_pegasus_ckpt_to_pytorch(ckpt_path: str, save_dir: str, config_update: dict): + tf_weights = get_tf_weights_as_numpy(ckpt_path) + torch_model = convert_bigbird_pegasus(tf_weights, config_update) + torch_model.save_pretrained(save_dir) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--tf_ckpt_path", type=str, help="passed to tf.train.list_variables") + parser.add_argument("--save_dir", default=None, type=str, help="Path to the output PyTorch model.") + args = parser.parse_args() + config_update = {} + convert_bigbird_pegasus_ckpt_to_pytorch(args.tf_ckpt_path, args.save_dir, config_update=config_update) diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py new file mode 100755 index 00000000000..524a9f3484b --- /dev/null +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -0,0 +1,3009 @@ +# coding=utf-8 +# Copyright 2021 Google Research The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch BigBirdPegasus model. """ + + +import copy +import math +import random +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...file_utils import ( + add_code_sample_docstrings, + add_end_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + Seq2SeqQuestionAnsweringModelOutput, + Seq2SeqSequenceClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import logging +from .configuration_bigbird_pegasus import BigBirdPegasusConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "google/bigbird-pegasus-large-arxiv" +_CONFIG_FOR_DOC = "BigBirdPegasusConfig" +_TOKENIZER_FOR_DOC = "PegasusTokenizer" + + +BIGBIRD_PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "google/bigbird-pegasus-large-arxiv", + "google/bigbird-pegasus-large-pubmed", + "google/bigbird-pegasus-large-bigpatent", + # See all BigBirdPegasus models at https://huggingface.co/models?filter=bigbird_pegasus +] + + +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), float("-inf")) + mask_cond = torch.arange(mask.size(-1)) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) + + +class BigBirdPegasusLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + super().__init__(num_embeddings, embedding_dim) + + def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + bsz, seq_len = input_ids_shape[:2] + positions = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ) + return super().forward(positions) + + +# Copied from transformers.models.big_bird.modeling_big_bird.BigBirdSelfAttention with BigBird->BigBirdPegasus +class BigBirdPegasusSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BigBirdPegasusModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = F.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +# Copied from transformers.models.big_bird.modeling_big_bird.BigBirdBlockSparseAttention with BigBird->BigBirdPegasus +class BigBirdPegasusBlockSparseAttention(nn.Module): + def __init__(self, config, seed=None): + super().__init__() + + self.max_seqlen = config.max_position_embeddings + self.seed = seed + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size {config.hidden_size} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.num_random_blocks = config.num_random_blocks + self.block_size = config.block_size + + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.use_bias) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + band_mask=None, + from_mask=None, + to_mask=None, + from_blocked_mask=None, + to_blocked_mask=None, + output_attentions=None, + ): + # Currently this `class` can't be used in decoder. + + batch_size, seqlen, _ = hidden_states.size() + to_seq_length = from_seq_length = seqlen + from_block_size = to_block_size = self.block_size + + assert from_seq_length % from_block_size == 0, "Query sided sequence length must be multiple of block size" + assert to_seq_length % to_block_size == 0, "Key/Value sided sequence length must be multiple of block size" + + query_layer = self.transpose_for_scores(self.query(hidden_states)) + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + context_layer, attention_probs = self.bigbird_block_sparse_attention( + query_layer, + key_layer, + value_layer, + band_mask, + from_mask, + to_mask, + from_blocked_mask, + to_blocked_mask, + self.num_attention_heads, + self.num_random_blocks, + self.attention_head_size, + from_block_size, + to_block_size, + batch_size, + from_seq_length, + to_seq_length, + seed=self.seed, + plan_from_length=None, + plan_num_rand_blocks=None, + output_attentions=output_attentions, + ) + + context_layer = context_layer.contiguous().view(batch_size, from_seq_length, -1) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + return outputs + + @staticmethod + def torch_bmm_nd(inp_1, inp_2, ndim=None): + """Fast nd matrix multiplication""" + # faster replacement of torch.einsum ("bhqk,bhkd->bhqd") + return torch.bmm(inp_1.reshape((-1,) + inp_1.shape[-2:]), inp_2.reshape((-1,) + inp_2.shape[-2:])).view( + inp_1.shape[: ndim - 2] + (inp_1.shape[ndim - 2], inp_2.shape[ndim - 1]) + ) + + @staticmethod + def torch_bmm_nd_transpose(inp_1, inp_2, ndim=None): + """Fast nd matrix multiplication with transpose""" + # faster replacement of torch.einsum (bhqd,bhkd->bhqk) + return torch.bmm( + inp_1.reshape((-1,) + inp_1.shape[-2:]), inp_2.reshape((-1,) + inp_2.shape[-2:]).transpose(1, 2) + ).view(inp_1.shape[: ndim - 2] + (inp_1.shape[ndim - 2], inp_2.shape[ndim - 2])) + + def bigbird_block_sparse_attention( + self, + query_layer, + key_layer, + value_layer, + band_mask, + from_mask, + to_mask, + from_blocked_mask, + to_blocked_mask, + n_heads, + n_rand_blocks, + attention_head_size, + from_block_size, + to_block_size, + batch_size, + from_seq_len, + to_seq_len, + seed, + plan_from_length, + plan_num_rand_blocks, + output_attentions, + ): + + # BigBirdPegasus block-sparse attention as suggested in paper + + # ITC: + # global tokens: 2 x block_size + # window tokens: 3 x block_size + # random tokens: num_rand_tokens x block_size + + # ETC: + # global tokens: extra_globals_tokens + 2 x block_size + # window tokens: 3 x block_size + # random tokens: num_rand_tokens x block_size + + # Note: + # 1) Currently, ETC is not supported. + # 2) Window size is fixed to 3 blocks & it can be changed only by + # changing `block_size`. + # 3) Number of global blocks are fixed (2 blocks here) & global tokens can be + # controlled only by `block_size`. + + # attention is calculated separately for q[0], q[1], q[2:-2], q[-2], q[-1] in order to use special trick of shifting tokens (for calculating sliding attention) + # hence following code can be divided into 5 parts. + + if from_seq_len // from_block_size != to_seq_len // to_block_size: + raise ValueError("Error the number of blocks needs to be same!") + + rsqrt_d = 1 / math.sqrt(attention_head_size) + bsz = batch_size + attn_mask_penalty = -10000.0 + + # generate random attention and corresponding masks + np.random.seed(seed) + if from_seq_len in [1024, 3072, 4096]: # old plans used in paper + rand_attn = [ + self._bigbird_block_rand_mask( + self.max_seqlen, self.max_seqlen, from_block_size, to_block_size, n_rand_blocks, last_idx=1024 + )[: (from_seq_len // from_block_size - 2)] + for _ in range(n_heads) + ] + else: + if plan_from_length is None: + plan_from_length, plan_num_rand_blocks = self._get_rand_attn_plan( + from_seq_len, from_block_size, n_rand_blocks + ) + + rand_attn = self._bigbird_block_rand_mask_with_head( + from_seq_length=from_seq_len, + to_seq_length=to_seq_len, + from_block_size=from_block_size, + to_block_size=to_block_size, + num_heads=n_heads, + plan_from_length=plan_from_length, + plan_num_rand_blocks=plan_num_rand_blocks, + ) + + rand_attn = np.stack(rand_attn, axis=0) + rand_attn = torch.tensor(rand_attn, device=query_layer.device, dtype=torch.long) + rand_attn.unsqueeze_(0) + rand_attn = torch.cat([rand_attn for _ in range(batch_size)], dim=0) + + rand_mask = self._create_rand_mask_from_inputs( + from_blocked_mask, to_blocked_mask, rand_attn, n_heads, n_rand_blocks, bsz, from_seq_len, from_block_size + ) + + blocked_query_matrix = query_layer.view(bsz, n_heads, from_seq_len // from_block_size, from_block_size, -1) + blocked_key_matrix = key_layer.view(bsz, n_heads, to_seq_len // to_block_size, to_block_size, -1) + blocked_value_matrix = value_layer.view(bsz, n_heads, to_seq_len // to_block_size, to_block_size, -1) + + # preparing block for randn attn + gathered_key = self.torch_gather_b2(blocked_key_matrix, rand_attn) + gathered_key = gathered_key.view( + bsz, n_heads, to_seq_len // to_block_size - 2, n_rand_blocks * to_block_size, -1 + ) # [bsz, n_heads, to_seq_len//to_block_size-2, n_rand_blocks, to_block_size, -1] + gathered_value = self.torch_gather_b2(blocked_value_matrix, rand_attn) + gathered_value = gathered_value.view( + bsz, n_heads, to_seq_len // to_block_size - 2, n_rand_blocks * to_block_size, -1 + ) # [bsz, n_heads, to_seq_len//to_block_size-2, n_rand_blocks, to_block_size, -1] + + # 1st PART + # 1st block (global block) attention scores + # q[0] x (k[0], k[1], k[2], k[3], k[4] .... ) + + # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, to_seq_len] + first_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, 0], key_layer, ndim=4) + + first_product = first_product * rsqrt_d + first_product += (1.0 - to_mask) * attn_mask_penalty + first_attn_weights = F.softmax(first_product, dim=-1) # [bsz, n_heads, from_block_size, to_seq_len] + + # [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1] + first_context_layer = self.torch_bmm_nd(first_attn_weights, value_layer, ndim=4) + first_context_layer.unsqueeze_(2) + + # 2nd PART + # 2nd block attention scores + # q[1] x (sliding_keys, random_keys, global_keys) + # sliding key blocks -> 2nd, 3rd blocks + # global key blocks -> 1st block + + second_key_mat = torch.cat( + [ + blocked_key_matrix[:, :, 0], + blocked_key_matrix[:, :, 1], + blocked_key_matrix[:, :, 2], + blocked_key_matrix[:, :, -1], + gathered_key[:, :, 0], + ], + dim=2, + ) # [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] + second_value_mat = torch.cat( + [ + blocked_value_matrix[:, :, 0], + blocked_value_matrix[:, :, 1], + blocked_value_matrix[:, :, 2], + blocked_value_matrix[:, :, -1], + gathered_value[:, :, 0], + ], + dim=2, + ) # [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] + + # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] ==> [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] + second_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, 1], second_key_mat, ndim=4) + second_seq_pad = torch.cat( + [ + to_mask[:, :, :, : 3 * to_block_size], + to_mask[:, :, :, -to_block_size:], + first_context_layer.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]), + ], + dim=3, + ) + second_rand_pad = torch.cat( + [ + first_context_layer.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]), + rand_mask[:, :, 0], + ], + dim=3, + ) + second_product = second_product * rsqrt_d + second_product += (1.0 - torch.minimum(second_seq_pad, second_rand_pad)) * attn_mask_penalty + second_attn_weights = F.softmax( + second_product, dim=-1 + ) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] + + # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] x [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] ==> [bsz, n_heads, from_block_size, -1] + second_context_layer = self.torch_bmm_nd(second_attn_weights, second_value_mat, ndim=4) + + second_context_layer.unsqueeze_(2) + + # 3rd PART + # Middle blocks attention scores + # q[-2:2] x (sliding_keys, random_keys, global_keys) + # sliding attn is calculated using special trick of shifting tokens as discussed in paper + # random keys are generated by taking random indices as per `rand_attn` + # global keys -> 1st & last block + + exp_blocked_key_matrix = torch.cat( + [blocked_key_matrix[:, :, 1:-3], blocked_key_matrix[:, :, 2:-2], blocked_key_matrix[:, :, 3:-1]], dim=3 + ) # [bsz, n_heads, from_seq_len//from_block_size-4, 3*to_block_size, -1] + exp_blocked_value_matrix = torch.cat( + [blocked_value_matrix[:, :, 1:-3], blocked_value_matrix[:, :, 2:-2], blocked_value_matrix[:, :, 3:-1]], + dim=3, + ) # [bsz, n_heads, from_seq_len//from_block_size-4, 3*to_block_size, -1] + middle_query_matrix = blocked_query_matrix[:, :, 2:-2] + + # sliding attention scores for q[-2:2] + # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] x [b, n_heads, from_seq_len//from_block_size-4, 3*to_block_size, -1] + inner_band_product = self.torch_bmm_nd_transpose(middle_query_matrix, exp_blocked_key_matrix, ndim=5) + # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, 3*to_block_size] + inner_band_product = inner_band_product * rsqrt_d + + # randn attention scores for q[-2:2] + # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] x [bsz, n_heads, from_seq_len//from_block_size-4, n_rand_blocks*to_block_size, -1] + rand_band_product = self.torch_bmm_nd_transpose(middle_query_matrix, gathered_key[:, :, 1:-1], ndim=5) + # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, n_rand_blocks*to_block_size] + rand_band_product = rand_band_product * rsqrt_d + + # Including 1st block (since it's global) + first_band_product = torch.einsum( + "bhlqd,bhkd->bhlqk", middle_query_matrix, blocked_key_matrix[:, :, 0] + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] x [bsz, n_heads, to_block_size, -1] ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] + first_band_product = first_band_product * rsqrt_d + + # Including last block (since it's global) + last_band_product = torch.einsum( + "bhlqd,bhkd->bhlqk", middle_query_matrix, blocked_key_matrix[:, :, -1] + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] x [bsz, n_heads, to_block_size, -1] ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] + last_band_product = last_band_product * rsqrt_d + + # masking padded tokens + inner_band_product += (1.0 - band_mask) * attn_mask_penalty + first_band_product += (1.0 - to_mask[:, :, :, :to_block_size].unsqueeze(3)) * attn_mask_penalty + last_band_product += (1.0 - to_mask[:, :, :, -to_block_size:].unsqueeze(3)) * attn_mask_penalty + rand_band_product += (1.0 - rand_mask[:, :, 1:-1]) * attn_mask_penalty + + # completing attention scores matrix for all q[-2:2] + band_product = torch.cat( + [first_band_product, inner_band_product, rand_band_product, last_band_product], dim=-1 + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, (5+n_rand_blocks)*to_block_size] + + # safely doing softmax since attention matrix is completed + attn_weights = F.softmax( + band_product, dim=-1 + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, (5+n_rand_blocks)*to_block_size] + + # contribution of sliding keys + # [bsz, n_heads, m//from_block_size-4, from_block_size, 3*to_block_size] x [bsz, n_heads, from_seq_len//from_block_size-4, 3*to_block_size, -1] + context_layer = self.torch_bmm_nd( + attn_weights[:, :, :, :, to_block_size : 4 * to_block_size], exp_blocked_value_matrix, ndim=5 + ) + # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] + + # adding contribution of random keys + # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, n_rand_blocks*to_block_size] x [bsz, n_heads, from_seq_len//from_block_size-4, n_rand_blocks*to_block_size, -1] + context_layer += self.torch_bmm_nd( + attn_weights[:, :, :, :, 4 * to_block_size : -to_block_size], gathered_value[:, :, 1:-1], ndim=5 + ) + # ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] + + # adding contribution of global keys + context_layer += torch.einsum( + "bhlqk,bhkd->bhlqd", attn_weights[:, :, :, :, :to_block_size], blocked_value_matrix[:, :, 0] + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] x [bsz, n_heads, to_block_size, -1] ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] + context_layer += torch.einsum( + "bhlqk,bhkd->bhlqd", attn_weights[:, :, :, :, -to_block_size:], blocked_value_matrix[:, :, -1] + ) # [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, to_block_size] x [bsz, n_heads, to_block_size, -1] ==> [bsz, n_heads, from_seq_len//from_block_size-4, from_block_size, -1] + + # 4th PART + # last 2nd token attention scores + # q[-2] x (sliding_keys, random_keys, global_keys) + # sliding key blocks -> last 3 blocks + # global key block -> 1st block + # random key block -> based on indices stored in `randn_attn` + + second_last_key_mat = torch.cat( + [ + blocked_key_matrix[:, :, 0], + blocked_key_matrix[:, :, -3], + blocked_key_matrix[:, :, -2], + blocked_key_matrix[:, :, -1], + gathered_key[:, :, -1], + ], + dim=2, + ) # [bsz, n_heads, (4+n_random_blocks)*to_block_size, -1] + second_last_value_mat = torch.cat( + [ + blocked_value_matrix[:, :, 0], + blocked_value_matrix[:, :, -3], + blocked_value_matrix[:, :, -2], + blocked_value_matrix[:, :, -1], + gathered_value[:, :, -1], + ], + dim=2, + ) # [bsz, n_heads, (4+r)*to_block_size, -1] + + # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] ==> [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] + second_last_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, -2], second_last_key_mat, ndim=4) + second_last_seq_pad = torch.cat( + [ + to_mask[:, :, :, :to_block_size], + to_mask[:, :, :, -3 * to_block_size :], + context_layer.new_ones([bsz, 1, 1, n_rand_blocks * to_block_size]), + ], + dim=3, + ) + second_last_rand_pad = torch.cat( + [ + context_layer.new_ones([bsz, n_heads, from_block_size, 4 * to_block_size]), + rand_mask[:, :, -1], + ], + dim=3, + ) + second_last_product = second_last_product * rsqrt_d + second_last_product += (1.0 - torch.minimum(second_last_seq_pad, second_last_rand_pad)) * attn_mask_penalty + second_last_attn_weights = F.softmax( + second_last_product, dim=-1 + ) # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] + + # [bsz, n_heads, from_block_size, (4+n_rand_blocks)*to_block_size] x [bsz, n_heads, (4+n_rand_blocks)*to_block_size, -1] ==> [bsz, n_heads, from_block_size, -1] + second_last_context_layer = self.torch_bmm_nd(second_last_attn_weights, second_last_value_mat, ndim=4) + second_last_context_layer.unsqueeze_(2) + + # 5th PART + # last block (global) attention scores + # q[-1] x (k[0], k[1], k[2], k[3], .... ) + + # [bsz, n_heads, from_block_size, -1] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, to_seq_len] + last_product = self.torch_bmm_nd_transpose(blocked_query_matrix[:, :, -1], key_layer, ndim=4) + last_product = last_product * rsqrt_d + last_product += (1.0 - to_mask) * attn_mask_penalty + last_attn_weights = F.softmax(last_product, dim=-1) # [bsz, n_heads, from_block_size, n] + + # [bsz, n_heads, from_block_size, to_seq_len] x [bsz, n_heads, to_seq_len, -1] ==> [bsz, n_heads, from_block_size, -1] + last_context_layer = self.torch_bmm_nd(last_attn_weights, value_layer, ndim=4) + last_context_layer.unsqueeze_(2) + + # combining representations of all tokens + context_layer = torch.cat( + [first_context_layer, second_context_layer, context_layer, second_last_context_layer, last_context_layer], + dim=2, + ) + context_layer = context_layer.view((bsz, n_heads, from_seq_len, -1)) * from_mask + context_layer = torch.transpose(context_layer, 1, 2) + + # this is just for visualizing; forward pass doesn't depend on following code + if output_attentions: + # TODO(PVP): need to verify if below code is correct + attention_probs = torch.zeros( + bsz, n_heads, from_seq_len, to_seq_len, dtype=torch.float, device=context_layer.device + ) + + # 1st query block + # corresponding to `first_context_layer` + attention_probs[:, :, :from_block_size, :] = first_attn_weights # all keys global + + # 2nd query block + # corresponding to `second_context_layer` + attention_probs[:, :, from_block_size : 2 * from_block_size, : 3 * to_block_size] = second_attn_weights[ + :, :, :, : 3 * to_block_size + ] # 1st three key blocks (global + sliding) + attention_probs[:, :, from_block_size : 2 * from_block_size, -to_block_size:] = second_attn_weights[ + :, :, :, 3 * to_block_size : 4 * to_block_size + ] # last key block (global) + # random keys + for p1, i1, w1 in zip(range(bsz), rand_attn, second_attn_weights): + # p1, i1, w1 corresponds to batch_dim i.e. following operation is done for each sequence in batch + for p2, i2, w2 in zip(range(n_heads), i1, w1): + # p2, i2, w2 corresponds to head_dim i.e. following operation is done for each heads + attn_probs_view = attention_probs.view( + bsz, + n_heads, + from_seq_len // from_block_size, + from_block_size, + to_seq_len // to_block_size, + to_block_size, + ) + right_slice = w2[:, 4 * to_block_size :] + attn_probs_view[p1, p2, 1, :, i2[0]] = right_slice.view( + from_block_size, n_rand_blocks, to_block_size + ) + + # Middle query blocks + # corresponding to `context_layer` + # sliding keys + for q_idx in range(from_seq_len // from_block_size - 4): + attn_probs_view = attention_probs.view( + bsz, + n_heads, + from_seq_len // from_block_size, + from_block_size, + to_seq_len // to_block_size, + to_block_size, + )[:, :, 2:-2, :, 1:-1, :] + right_slice = attn_weights[:, :, q_idx, :, to_block_size : 4 * to_block_size] + attn_probs_view[:, :, q_idx, :, q_idx : q_idx + 3, :] = right_slice.view( + bsz, n_heads, from_block_size, 3, to_block_size + ) # inner_band_product + # global keys (corresponding to 1st key block) + attention_probs[:, :, 2 * from_block_size : -2 * from_block_size, :to_block_size] = attn_weights[ + :, :, :, :, :to_block_size + ].view( + bsz, n_heads, -1, to_block_size + ) # first_band_product + # global keys (corresponding to last key block) + attention_probs[:, :, 2 * from_block_size : -2 * from_block_size, -to_block_size:] = attn_weights[ + :, :, :, :, -to_block_size: + ].view( + bsz, n_heads, -1, to_block_size + ) # last_band_product + # random keys + for p1, i1, w1 in zip(range(bsz), rand_attn, attn_weights): + # p1, i1, w1 corresponds to batch_dim i.e. following operation is done for each sequence in batch + for p2, i2, w2 in zip(range(n_heads), i1, w1): + # p2, i2, w2 corresponds to head_dim i.e. following operation is done for each heads + for q_idx in range(1, len(i2) - 1): + attn_probs_view = attention_probs.view( + bsz, + n_heads, + from_seq_len // from_block_size, + from_block_size, + to_seq_len // to_block_size, + to_block_size, + ) + right_slice = w2[q_idx - 1, :, 4 * to_block_size : -to_block_size] + attn_probs_view[p1, p2, q_idx + 1, :, i2[q_idx]] = right_slice.view( + from_block_size, n_rand_blocks, to_block_size + ) + + # Second-last query block + # corresponding to `second_last_context_layer` + attention_probs[:, :, -2 * from_block_size : -from_block_size, :to_block_size] = second_last_attn_weights[ + :, :, :, :to_block_size + ] # 1st key block (global) + attention_probs[ + :, :, -2 * from_block_size : -from_block_size, -3 * to_block_size : + ] = second_last_attn_weights[ + :, :, :, to_block_size : 4 * to_block_size + ] # last three blocks (global + sliding) + # random keys + for p1, i1, w1 in zip(range(bsz), rand_attn, second_last_attn_weights): + # p1, i1, w1 corresponds to batch_dim i.e. following operation is done for each sequence in batch + for p2, i2, w2 in zip(range(n_heads), i1, w1): + # p2, i2, w2 corresponds to head_dim i.e. following operation is done for each heads + attn_probs_view = attention_probs.view( + bsz, + n_heads, + from_seq_len // from_block_size, + from_block_size, + to_seq_len // to_block_size, + to_block_size, + ) + right_slice = w2[:, 4 * to_block_size :] + attn_probs_view[p1, p2, -2, :, i2[-1]] = right_slice.view( + from_block_size, n_rand_blocks, to_block_size + ) + + # last query block + # corresponding to `last_context_layer` + attention_probs[:, :, -from_block_size:, :] = last_attn_weights # all keys global + + else: + attention_probs = None + + return context_layer, attention_probs + + @staticmethod + def torch_gather_b2(params, indices): + # this operation is equivalent to tf.gather when batch_dims=2 + + if params.shape[:2] != indices.shape[:2]: + raise ValueError( + f"Make sure that the first two dimensions of params and indices are identical, \ + but they are params: {params.shape[:2]} vs. indices: {params.shape[:2]}" + ) + num_indices_to_gather = indices.shape[-2] * indices.shape[-1] + num_indices_to_pick_from = params.shape[2] + + indices_shift = ( + torch.arange(indices.shape[0] * indices.shape[1] * num_indices_to_gather, device=indices.device) + // num_indices_to_gather + * num_indices_to_pick_from + ) + + flattened_indices = indices.view(-1) + indices_shift + flattened_params = params.reshape(-1, params.shape[-2], params.shape[-1]) + + out_flattened = flattened_params.index_select(0, flattened_indices) + + out = out_flattened.reshape(params.shape[:2] + (num_indices_to_gather,) + params.shape[3:]) + return out + + @staticmethod + def _create_rand_mask_from_inputs( + from_blocked_mask, + to_blocked_mask, + rand_attn, + num_attention_heads, + num_rand_blocks, + batch_size, + from_seq_length, + from_block_size, + ): + """ + Create 3D attention mask from a 2D tensor mask. + + Args: + from_blocked_mask: 2D Tensor of shape [batch_size, + from_seq_length//from_block_size, from_block_size]. + to_blocked_mask: int32 Tensor of shape [batch_size, + to_seq_length//to_block_size, to_block_size]. + rand_attn: [batch_size, num_attention_heads, + from_seq_length//from_block_size-2, num_rand_blocks] + num_attention_heads: int. Number of attention heads. + num_rand_blocks: int. Number of random chunks per row. + batch_size: int. Batch size for computation. + from_seq_length: int. length of from sequence. + from_block_size: int. size of block in from sequence. + + Returns: + float Tensor of shape [batch_size, num_attention_heads, from_seq_length//from_block_size-2, + from_block_size, num_rand_blocks*to_block_size]. + """ + num_windows = from_seq_length // from_block_size - 2 + rand_mask = torch.stack([p1[i1.flatten()] for p1, i1 in zip(to_blocked_mask, rand_attn)]) + rand_mask = rand_mask.view(batch_size, num_attention_heads, num_windows, num_rand_blocks * from_block_size) + rand_mask = torch.einsum("blq,bhlk->bhlqk", from_blocked_mask[:, 1:-1], rand_mask) + return rand_mask + + @staticmethod + def _get_rand_attn_plan(from_seq_length, from_block_size, num_rand_blocks): + """ + Gives the plan of where to put random attention. + + Args: + from_seq_length: int. length of from sequence. + from_block_size: int. size of block in from sequence. + num_rand_blocks: int. Number of random chunks per row. + + Returns: + plan_from_length: ending location of from block plan_num_rand_blocks: number of random ending location for + each block + """ + + plan_from_length = [] + plan_num_rand_blocks = [] + if (2 * num_rand_blocks + 5) < (from_seq_length // from_block_size): + plan_from_length.append(int((2 * num_rand_blocks + 5) * from_block_size)) + plan_num_rand_blocks.append(num_rand_blocks) + plan_from_length.append(from_seq_length) + plan_num_rand_blocks.append(0) + elif (num_rand_blocks + 5) < (from_seq_length // from_block_size): + plan_from_length.append(int((num_rand_blocks + 5) * from_block_size)) + plan_num_rand_blocks.append(num_rand_blocks // 2) + plan_from_length.append(from_seq_length) + plan_num_rand_blocks.append(num_rand_blocks - (num_rand_blocks // 2)) + else: + plan_from_length.append(from_seq_length) + plan_num_rand_blocks.append(num_rand_blocks) + + return plan_from_length, plan_num_rand_blocks + + @staticmethod + def _bigbird_block_rand_mask( + from_seq_length, to_seq_length, from_block_size, to_block_size, num_rand_blocks, last_idx=-1 + ): + """ + Create adjacency list of random attention. + + Args: + from_seq_length: int. length of from sequence. + to_seq_length: int. length of to sequence. + from_block_size: int. size of block in from sequence. + to_block_size: int. size of block in to sequence. + num_rand_blocks: int. Number of random chunks per row. + last_idx: if -1 then num_rand_blocks blocks chosen anywhere in to sequence, + if positive then num_rand_blocks blocks chosen only up to last_idx. + + Returns: + adjacency list of size from_seq_length//from_block_size-2 by num_rand_blocks + """ + # using this method when from_seq_length in [1024, 3072, 4096] + + assert ( + from_seq_length // from_block_size == to_seq_length // to_block_size + ), "Error the number of blocks needs to be same!" + + rand_attn = np.zeros((from_seq_length // from_block_size - 2, num_rand_blocks), dtype=np.int32) + middle_seq = np.arange(1, to_seq_length // to_block_size - 1, dtype=np.int32) + last = to_seq_length // to_block_size - 1 + if last_idx > (2 * to_block_size): + last = (last_idx // to_block_size) - 1 + + r = num_rand_blocks # shorthand + for i in range(1, from_seq_length // from_block_size - 1): + start = i - 2 + end = i + if i == 1: + rand_attn[i - 1, :] = np.random.permutation(middle_seq[2:last])[:r] + elif i == 2: + rand_attn[i - 1, :] = np.random.permutation(middle_seq[3:last])[:r] + elif i == from_seq_length // from_block_size - 3: + rand_attn[i - 1, :] = np.random.permutation(middle_seq[:last])[:r] + # Missing -3: should have been sliced till last-3 + elif i == from_seq_length // from_block_size - 2: + rand_attn[i - 1, :] = np.random.permutation(middle_seq[:last])[:r] + # Missing -4: should have been sliced till last-4 + else: + if start > last: + start = last + rand_attn[i - 1, :] = np.random.permutation(middle_seq[:start])[:r] + elif (end + 1) == last: + rand_attn[i - 1, :] = np.random.permutation(middle_seq[:start])[:r] + else: + rand_attn[i - 1, :] = np.random.permutation( + np.concatenate((middle_seq[:start], middle_seq[end + 1 : last])) + )[:r] + return rand_attn + + def _bigbird_block_rand_mask_with_head( + self, + from_seq_length, + to_seq_length, + from_block_size, + to_block_size, + num_heads, + plan_from_length, + plan_num_rand_blocks, + window_block_left=1, + window_block_right=1, + global_block_top=1, + global_block_bottom=1, + global_block_left=1, + global_block_right=1, + ): + """ + Create adjacency list of random attention. + + Args: + from_seq_length: int. length of from sequence. + to_seq_length: int. length of to sequence. + from_block_size: int. size of block in from sequence. + to_block_size: int. size of block in to sequence. + num_heads: int. total number of heads. + plan_from_length: list. plan from length where num_random_blocks are choosen from. + plan_num_rand_blocks: list. number of rand blocks within the plan. + window_block_left: int. number of blocks of window to left of a block. + window_block_right: int. number of blocks of window to right of a block. + global_block_top: int. number of blocks at the top. + global_block_bottom: int. number of blocks at the bottom. + global_block_left: int. Number of blocks globally used to the left. + global_block_right: int. Number of blocks globally used to the right. + + Returns: + adjacency list of size num_head where each element is of size from_seq_length//from_block_size-2 by + num_rand_blocks + """ + # using this method when from_seq_length not in [1024, 3072, 4096] + + assert ( + from_seq_length // from_block_size == to_seq_length // to_block_size + ), "Error the number of blocks needs to be same!" + + assert from_seq_length in plan_from_length, "Error from sequence length not in plan!" + + # Total number of blocks in the mmask + num_blocks = from_seq_length // from_block_size + # Number of blocks per plan + plan_block_length = np.array(plan_from_length) // from_block_size + # till when to follow plan + max_plan_idx = plan_from_length.index(from_seq_length) + # Random Attention adjacency list + rand_attn = [ + np.zeros((num_blocks, np.sum(plan_num_rand_blocks[: max_plan_idx + 1])), dtype=np.int32) + for i in range(num_heads) + ] + + # We will go iteratively over the plan blocks and pick random number of + # Attention blocks from the legally allowed blocks + for plan_idx in range(max_plan_idx + 1): + rnd_r_cnt = 0 + if plan_idx > 0: + # set the row for all from_blocks starting from 0 to + # plan_block_length[plan_idx-1] + # column indx start fromm plan_block_length[plan_idx-1] and ends at + # plan_block_length[plan_idx] + if plan_num_rand_blocks[plan_idx] > 0: + rnd_r_cnt = int(np.sum(plan_num_rand_blocks[:plan_idx])) + curr_r_cnt = int(np.sum(plan_num_rand_blocks[: plan_idx + 1])) + for blk_rw_idx in range(global_block_top, plan_block_length[plan_idx - 1]): + for h in range(num_heads): + rand_attn[h][blk_rw_idx, rnd_r_cnt:curr_r_cnt] = self._get_single_block_row_attention( + block_id=blk_rw_idx, + to_start_block_id=plan_block_length[plan_idx - 1], + to_end_block_id=plan_block_length[plan_idx], + num_rand_blocks=plan_num_rand_blocks[plan_idx], + window_block_left=window_block_left, + window_block_right=window_block_right, + global_block_left=global_block_left, + global_block_right=global_block_right, + ) + + for pl_id in range(plan_idx): + if plan_num_rand_blocks[pl_id] == 0: + continue + for blk_rw_idx in range(plan_block_length[plan_idx - 1], plan_block_length[plan_idx]): + rnd_r_cnt = 0 + to_start_block_id = 0 + if pl_id > 0: + rnd_r_cnt = int(np.sum(plan_num_rand_blocks[:pl_id])) + to_start_block_id = plan_block_length[pl_id - 1] + curr_r_cnt = int(np.sum(plan_num_rand_blocks[: pl_id + 1])) + for h in range(num_heads): + rand_attn[h][blk_rw_idx, rnd_r_cnt:curr_r_cnt] = self._get_single_block_row_attention( + block_id=blk_rw_idx, + to_start_block_id=to_start_block_id, + to_end_block_id=plan_block_length[pl_id], + num_rand_blocks=plan_num_rand_blocks[pl_id], + window_block_left=window_block_left, + window_block_right=window_block_right, + global_block_left=global_block_left, + global_block_right=global_block_right, + ) + + if plan_num_rand_blocks[plan_idx] == 0: + continue + curr_r_cnt = int(np.sum(plan_num_rand_blocks[: plan_idx + 1])) + from_start_block_id = global_block_top + to_start_block_id = 0 + if plan_idx > 0: + rnd_r_cnt = int(np.sum(plan_num_rand_blocks[:plan_idx])) + from_start_block_id = plan_block_length[plan_idx - 1] + to_start_block_id = plan_block_length[plan_idx - 1] + + for blk_rw_idx in range(from_start_block_id, plan_block_length[plan_idx]): + for h in range(num_heads): + rand_attn[h][blk_rw_idx, rnd_r_cnt:curr_r_cnt] = self._get_single_block_row_attention( + block_id=blk_rw_idx, + to_start_block_id=to_start_block_id, + to_end_block_id=plan_block_length[plan_idx], + num_rand_blocks=plan_num_rand_blocks[plan_idx], + window_block_left=window_block_left, + window_block_right=window_block_right, + global_block_left=global_block_left, + global_block_right=global_block_right, + ) + + for nh in range(num_heads): + rand_attn[nh] = rand_attn[nh][global_block_top : num_blocks - global_block_bottom, :] + + return rand_attn + + @staticmethod + def _get_single_block_row_attention( + block_id, + to_start_block_id, + to_end_block_id, + num_rand_blocks, + window_block_left=1, + window_block_right=1, + global_block_left=1, + global_block_right=1, + ): + """ + For a single row block get random row attention. + + Args: + block_id: int. block id of row. + to_start_block_id: int. random attention column start id. + to_end_block_id: int. random attention column end id. + num_rand_blocks: int. number of random blocks to be selected. + window_block_left: int. number of blocks of window to left of a block. + window_block_right: int. number of blocks of window to right of a block. + global_block_left: int. Number of blocks globally used to the left. + global_block_right: int. Number of blocks globally used to the right. + + Returns: + row containing the random attention vector of size num_rand_blocks. + """ + # list of to_blocks from which to choose random attention + to_block_list = np.arange(to_start_block_id, to_end_block_id, dtype=np.int32) + # permute the blocks + perm_block = np.random.permutation(to_block_list) + + # illegal blocks for the current block id, using window + illegal_blocks = list(range(block_id - window_block_left, block_id + window_block_right + 1)) + + # Add blocks at the start and at the end + illegal_blocks.extend(list(range(global_block_left))) + illegal_blocks.extend(list(range(to_end_block_id - global_block_right, to_end_block_id))) + + # The second from_block cannot choose random attention on second last to_block + if block_id == 1: + illegal_blocks.append(to_end_block_id - 2) + + # The second last from_block cannot choose random attention on second to_block + if block_id == to_end_block_id - 2: + illegal_blocks.append(1) + + selected_random_blokcs = [] + + for i in range(to_end_block_id - to_start_block_id): + if perm_block[i] not in illegal_blocks: + selected_random_blokcs.append(perm_block[i]) + if len(selected_random_blokcs) == num_rand_blocks: + break + return np.array(selected_random_blokcs, dtype=np.int32) + + +class BigBirdPegasusEncoderAttention(nn.Module): + def __init__(self, config, seed=None): + super().__init__() + self.config = config + self.seed = seed + + self.attention_type = config.attention_type + + if self.attention_type == "original_full": + self.self = BigBirdPegasusSelfAttention(config) + elif self.attention_type == "block_sparse": + self.self = BigBirdPegasusBlockSparseAttention(config, seed) + else: + raise ValueError( + f"attention_type can either be original_full or block_sparse, but is {self.config.attention_type}" + ) + + self.output = nn.Linear(config.hidden_size, config.hidden_size, bias=config.use_bias) + + def set_attention_type(self, value: str): + if value not in ["original_full", "block_sparse"]: + raise ValueError( + f"attention_type can only be set to either 'original_full' or 'block_sparse', but is {value}" + ) + # attention type is already correctly set + if value == self.attention_type: + return + + self.attention_type = value + if value == "original_full": + # copy all weights to new full attention class + attn_weights = BigBirdPegasusSelfAttention(self.config) + else: + # copy all weights to new sparse attention class + attn_weights = BigBirdPegasusBlockSparseAttention(self.config, self.seed) + + attn_weights.query = self.self.query + attn_weights.value = self.self.value + attn_weights.key = self.self.key + self.self = attn_weights + self.attention_type = value + + if not self.training: + self.self.eval() + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + past_key_value=None, + output_attentions=False, + band_mask=None, + from_mask=None, + to_mask=None, + from_blocked_mask=None, + to_blocked_mask=None, + ): + + if self.config.attention_type == "original_full": + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + ) + else: + self_outputs = self.self( + hidden_states, band_mask, from_mask, to_mask, from_blocked_mask, to_blocked_mask, output_attentions + ) + + attention_output = self.output(self_outputs[0]) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->BigBirdPegasusDecoder +class BigBirdPegasusDecoderAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})." + self.scaling = self.head_dim ** -0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, embed_dim = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + assert attn_weights.size() == ( + bsz * self.num_heads, + tgt_len, + src_len, + ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" + + if attention_mask is not None: + assert attention_mask.size() == ( + bsz, + 1, + tgt_len, + src_len, + ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = F.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + assert layer_head_mask.size() == ( + self.num_heads, + ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + assert attn_output.size() == ( + bsz * self.num_heads, + tgt_len, + self.head_dim, + ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}" + + attn_output = ( + attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + .transpose(1, 2) + .reshape(bsz, tgt_len, embed_dim) + ) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class BigBirdPegasusEncoderLayer(nn.Module): + def __init__(self, config: BigBirdPegasusConfig, seed=None): + super().__init__() + self.attention_type = config.attention_type + self.embed_dim = config.d_model + self.self_attn = BigBirdPegasusEncoderAttention(config, seed=seed) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_head_mask: torch.Tensor, + band_mask=None, + from_mask=None, + to_mask=None, + from_blocked_mask=None, + to_blocked_mask=None, + output_attentions: bool = False, + ): + """ + Args: + hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape :obj:`(seq_len, batch, embed_dim)` + attention_mask (:obj:`torch.FloatTensor`): attention mask of size + :obj:`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + self_attention_outputs = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + band_mask=band_mask, + from_mask=from_mask, + to_mask=to_mask, + from_blocked_mask=from_blocked_mask, + to_blocked_mask=to_blocked_mask, + ) + hidden_states = self_attention_outputs[0] + + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + + hidden_states = self.fc2(hidden_states) + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attention_outputs[1],) + + return outputs + + def set_attention_type(self, value: str): + if value not in ["original_full", "block_sparse"]: + raise ValueError( + f"attention_type can only be set to either 'original_full' or 'block_sparse', but is {value}" + ) + # attention type is already correctly set + if value == self.attention_type: + return + self.attention_type = value + self.self_attn.set_attention_type(value) + + +class BigBirdPegasusDecoderLayer(nn.Module): + def __init__(self, config: BigBirdPegasusConfig): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = BigBirdPegasusDecoderAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + bias=config.use_bias, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = BigBirdPegasusDecoderAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + bias=config.use_bias, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + ): + """ + Args: + hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + attention_mask (:obj:`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +# Copied from transformers.models.bart.modeling_bart.BartClassificationHead with Bart->BigBirdPegasus +class BigBirdPegasusClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, + input_dim: int, + inner_dim: int, + num_classes: int, + pooler_dropout: float, + ): + super().__init__() + self.dense = nn.Linear(input_dim, inner_dim) + self.dropout = nn.Dropout(p=pooler_dropout) + self.out_proj = nn.Linear(inner_dim, num_classes) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class BigBirdPegasusPreTrainedModel(PreTrainedModel): + config_class = BigBirdPegasusConfig + base_model_prefix = "model" + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + @property + def dummy_inputs(self): + pad_token = self.config.pad_token_id + input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) + dummy_inputs = { + "attention_mask": input_ids.ne(pad_token), + "input_ids": input_ids, + } + return dummy_inputs + + +BIGBIRD_PEGASUS_START_DOCSTRING = r""" + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings + etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Parameters: + config (:class:`~transformers.BigBirdPegasusConfig`): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. +""" + +BIGBIRD_PEGASUS_GENERATION_EXAMPLE = r""" + Summarization example:: + + >>> from transformers import PegasusTokenizer, BigBirdPegasusForConditionalGeneration, BigBirdPegasusConfig + + >>> model = BigBirdPegasusForConditionalGeneration.from_pretrained('bigbird-pegasus-large-arxiv') + >>> tokenizer = PegasusTokenizer.from_pretrained('bigbird-pegasus-large-arxiv') + + >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs." + >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=4096, return_tensors='pt', truncation=True) + + >>> # Generate Summary + >>> summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True) + >>> print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids]) +""" + +BIGBIRD_PEGASUS_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using :class:`~transformers.PegasusTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Provide for translation and summarization training. By default, the model will create this tensor by + shifting the :obj:`input_ids` to the right, following the paper. + decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will + also be used by default. + + If you want to change padding behavior, you should read + :func:`modeling_bigbird_pegasus._prepare_decoder_inputs` and modify to your needs. See diagram 1 in `the + paper `__ for more information on the default strategy. + + decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): + Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: + :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, + `optional`) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the + cross-attention of the decoder. + past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded + representation. If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_inputs_embeds` + have to be input (see :obj:`past_key_values`). This is useful if you want more control over how to convert + :obj:`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds` + takes the value of :obj:`inputs_embeds`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + +BIGBIRD_PEGASUS_STANDALONE_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + +class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + :class:`BigBirdPegasusEncoderLayer`. + + Args: + config: BigBirdPegasusConfig + embed_tokens (torch.nn.Embedding): output embedding + """ + + def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + + self.attention_type = config.attention_type + self.block_size = config.block_size + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + if embed_tokens is not None: + self.embed_tokens = embed_tokens + else: + self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) + + self.embed_positions = BigBirdPegasusLearnedPositionalEmbedding( + config.max_position_embeddings, + embed_dim, + ) + self.layers = nn.ModuleList([BigBirdPegasusEncoderLayer(config, seed=i) for i in range(config.encoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(embed_dim) + + self.init_weights() + + def forward( + self, + input_ids=None, + attention_mask=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using :class:`~transformers.PegasusTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` + for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded + representation. This is useful if you want more control over how to convert :obj:`input_ids` indices + into associated vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors + for more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + embed_pos = self.embed_positions(input_shape) + + hidden_states = inputs_embeds + embed_pos + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=hidden_states.device) + attention_mask = attention_mask.long() + + # in order to use block_sparse attention, sequence_length has to be at least + # bigger than all global attentions: 2 * block_size + # + sliding tokens: 3 * block_size + # + random tokens: 2 * num_random_blocks * block_size + max_tokens_to_attend = (5 + 2 * self.config.num_random_blocks) * self.config.block_size + if self.attention_type == "block_sparse" and input_shape[1] <= max_tokens_to_attend: + # change attention_type from block_sparse to original_full + sequence_length = input_shape[1] + logger.warning( + "Attention type 'block_sparse' is not possible if sequence_length: " + f"{sequence_length} <= num global tokens: 2 * config.block_size " + "+ min. num sliding tokens: 3 * config.block_size " + "+ config.num_random_blocks * config.block_size " + "+ additional buffer: config.num_random_blocks * config.block_size " + f"= {max_tokens_to_attend} with config.block_size " + f"= {self.config.block_size}, config.num_random_blocks " + f"= {self.config.num_random_blocks}." + "Changing attention type to 'original_full'..." + ) + self.set_attention_type("original_full") + + if self.attention_type == "block_sparse": + padding_len, hidden_states, attention_mask = self._pad_to_block_size(hidden_states, attention_mask) + else: + padding_len = 0 + + # expand attention_mask + if self.attention_type == "original_full": + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + blocked_encoder_mask = band_mask = from_mask = to_mask = None + elif self.attention_type == "block_sparse": + blocked_encoder_mask, band_mask, from_mask, to_mask = self.create_masks_for_block_sparse_attn( + attention_mask, self.block_size + ) + attention_mask = None + else: + raise ValueError( + f"attention_type can either be original_full or block_sparse, but is {self.attention_type}" + ) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == ( + len(self.layers) + ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): # skip the layer + layer_outputs = (None, None) + else: + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + band_mask, + from_mask, + to_mask, + blocked_encoder_mask, + blocked_encoder_mask, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + band_mask=band_mask, + from_mask=from_mask, + to_mask=to_mask, + from_blocked_mask=blocked_encoder_mask, + to_blocked_mask=blocked_encoder_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + hidden_states = self.layernorm_embedding(hidden_states) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if padding_len > 0: + # unpad `sequence_output` because the calling function is expecting a length == input_ids.size(1) + hidden_states = hidden_states[:, :-padding_len] + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + + self.encoder_o = hidden_states + + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + def set_attention_type(self, value: str): + if value not in ["original_full", "block_sparse"]: + raise ValueError( + f"attention_type can only be set to either 'original_full' or 'block_sparse', but is {value}" + ) + # attention type is already correctly set + if value == self.attention_type: + return + self.attention_type = value + for layer in self.layers: + layer.set_attention_type(value) + + @staticmethod # Copied from transformers.models.big_bird.modeling_big_bird.BigBirdModel.create_masks_for_block_sparse_attn + def create_masks_for_block_sparse_attn(attention_mask: torch.Tensor, block_size: int): + + batch_size, seq_length = attention_mask.size() + assert ( + seq_length % block_size == 0 + ), f"Sequence length must be multiple of block size, but sequence length is {seq_length}, while block size is {block_size}." + + def create_band_mask_from_inputs(from_blocked_mask, to_blocked_mask): + """ + Create 3D attention mask from a 2D tensor mask. + + Args: + from_blocked_mask: 2D Tensor of shape [batch_size, + from_seq_length//from_block_size, from_block_size]. + to_blocked_mask: int32 Tensor of shape [batch_size, + to_seq_length//to_block_size, to_block_size]. + + Returns: + float Tensor of shape [batch_size, 1, from_seq_length//from_block_size-4, from_block_size, + 3*to_block_size]. + """ + exp_blocked_to_pad = torch.cat( + [to_blocked_mask[:, 1:-3], to_blocked_mask[:, 2:-2], to_blocked_mask[:, 3:-1]], dim=2 + ) + band_mask = torch.einsum("blq,blk->blqk", from_blocked_mask[:, 2:-2], exp_blocked_to_pad) + band_mask.unsqueeze_(1) + return band_mask + + blocked_encoder_mask = attention_mask.view(batch_size, seq_length // block_size, block_size) + band_mask = create_band_mask_from_inputs(blocked_encoder_mask, blocked_encoder_mask) + + from_mask = attention_mask.view(batch_size, 1, seq_length, 1) + to_mask = attention_mask.view(batch_size, 1, 1, seq_length) + + return blocked_encoder_mask, band_mask, from_mask, to_mask + + def _pad_to_block_size(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor): + """A helper function to pad tokens and mask to work with implementation of BigBird block-sparse attention.""" + # padding + block_size = self.config.block_size + batch_size, seq_len = hidden_states.shape[:2] + + padding_len = (block_size - seq_len % block_size) % block_size + if padding_len > 0: + logger.info( + f"Input ids are automatically padded from {seq_len} to {seq_len + padding_len} to be a multiple of " + f"`config.block_size`: {block_size}" + ) + pad_id = self.config.pad_token_id + device = hidden_states.device + input_ids_padding = torch.ones((batch_size, padding_len), dtype=torch.long, device=device) * pad_id + inputs_embeds_padding = self.embed_tokens(input_ids_padding) + hidden_states = torch.cat([hidden_states, inputs_embeds_padding], dim=-2) + + attention_mask = F.pad(attention_mask, (0, padding_len), value=0) # no attention on the padding tokens + + return padding_len, hidden_states, attention_mask + + +class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a + :class:`BigBirdPegasusDecoderLayer` + + Args: + config: BigBirdPegasusConfig + embed_tokens (torch.nn.Embedding): output embedding + """ + + def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + if embed_tokens is not None: + self.embed_tokens = embed_tokens + else: + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + + self.embed_positions = BigBirdPegasusLearnedPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + ) + self.layers = nn.ModuleList([BigBirdPegasusDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + self.init_weights() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length + ).to(self.device) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using :class:`~transformers.BigBirdPegasusTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` + for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, encoder_sequence_length)`, `optional`): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules in decoder to avoid performing + cross-attention on hidden heads. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up + decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last + :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of + shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size, + sequence_length)`. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded + representation. This is useful if you want more control over how to convert :obj:`input_ids` indices + into associated vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors + for more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + + # embed positions + positions = self.embed_positions(input_shape, past_key_values_length) + + hidden_states = inputs_embeds + positions + + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + assert attn_mask.size()[0] == ( + len(self.layers) + ), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, use_cache) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, + None, + ) + else: + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + hidden_states = self.layernorm_embedding(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + "The bare BigBirdPegasus Model outputting raw hidden-states without any specific head on top.", + BIGBIRD_PEGASUS_START_DOCSTRING, +) +# Copied from transformers.models.bart.modeling_bart.BartModel with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS +class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel): + def __init__(self, config: BigBirdPegasusConfig): + super().__init__(config) + + padding_idx, vocab_size = config.pad_token_id, config.vocab_size + self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) + + self.encoder = BigBirdPegasusEncoder(config, self.shared) + self.decoder = BigBirdPegasusDecoder(config, self.shared) + + self.init_weights() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, value): + self.shared = value + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(BIGBIRD_PEGASUS_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=Seq2SeqModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + + # different to other models, BigBirdPegasus automatically creates decoder_input_ids from + # input_ids if no decoder_input_ids are provided + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right( + input_ids, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + "The BigBirdPegasus Model with a language modeling head. Can be used for summarization.", + BIGBIRD_PEGASUS_START_DOCSTRING, +) +# Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS +class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel): + base_model_prefix = "model" + _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"] + + def __init__(self, config: BigBirdPegasusConfig): + super().__init__(config) + self.model = BigBirdPegasusModel(config) + self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) + self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) + + self.init_weights() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens) + self._resize_final_logits_bias(new_num_tokens) + return new_embeddings + + def _resize_final_logits_bias(self, new_num_tokens: int) -> None: + old_num_tokens = self.final_logits_bias.shape[-1] + if new_num_tokens <= old_num_tokens: + new_bias = self.final_logits_bias[:, :new_num_tokens] + else: + extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) + new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) + self.register_buffer("final_logits_bias", new_bias) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + @add_start_docstrings_to_model_forward(BIGBIRD_PEGASUS_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + @add_end_docstrings(BIGBIRD_PEGASUS_GENERATION_EXAMPLE) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should either be in ``[0, ..., + config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``. + + Returns: + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if decoder_input_ids is None: + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past=None, + attention_mask=None, + head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs + ): + # cut decoder_input_ids if past is used + if past is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + + return { + "input_ids": None, # encoder_outputs is defined. input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "head_mask": head_mask, + "use_cache": use_cache, # change this to avoid caching (presumably for debugging) + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], + ) + return reordered_past + + +@add_start_docstrings( + """ + BigBirdPegasus model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. + for GLUE tasks. + """, + BIGBIRD_PEGASUS_START_DOCSTRING, +) +# Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS +class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel): + def __init__(self, config: BigBirdPegasusConfig, **kwargs): + super().__init__(config, **kwargs) + self.model = BigBirdPegasusModel(config) + self.classification_head = BigBirdPegasusClassificationHead( + config.d_model, + config.d_model, + config.num_labels, + config.classifier_dropout, + ) + self.model._init_weights(self.classification_head.dense) + self.model._init_weights(self.classification_head.out_proj) + + @add_start_docstrings_to_model_forward(BIGBIRD_PEGASUS_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=Seq2SeqSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + encoder_outputs=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + if input_ids is None and inputs_embeds is not None: + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__}" + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] # last hidden state + + eos_mask = input_ids.eq(self.config.eos_token_id) + + if len(torch.unique(eos_mask.sum(1))) > 1: + raise ValueError("All examples must have the same number of tokens.") + sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[ + :, -1, : + ] + logits = self.classification_head(sentence_representation) + + loss = None + if labels is not None: + if self.config.num_labels == 1: + # regression + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1), labels.view(-1)) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqSequenceClassifierOutput( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +@add_start_docstrings( + """ + BigBirdPegasus Model with a span classification head on top for extractive question-answering tasks like SQuAD (a + linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + BIGBIRD_PEGASUS_START_DOCSTRING, +) +# Copied from transformers.models.bart.modeling_bart.BartForQuestionAnswering with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS +class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + config.num_labels = 2 + self.num_labels = config.num_labels + + self.model = BigBirdPegasusModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + self.model._init_weights(self.qa_outputs) + + @add_start_docstrings_to_model_forward(BIGBIRD_PEGASUS_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=Seq2SeqQuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + encoder_outputs=None, + start_positions=None, + end_positions=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if start_positions is not None and end_positions is not None: + use_cache = False + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions.clamp_(0, ignored_index) + end_positions.clamp_(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = ( + start_logits, + end_logits, + ) + outputs[1:] + return ((total_loss,) + output) if total_loss is not None else output + + return Seq2SeqQuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +# Copied from transformers.models.pegasus.modeling_pegasus.PegasusDecoderWrapper with Pegasus->BigBirdPegasus +class BigBirdPegasusDecoderWrapper(BigBirdPegasusPreTrainedModel): + """ + This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is + used in combination with the :class:`~transformers.EncoderDecoderModel` framework. + """ + + def __init__(self, config): + super().__init__(config) + self.decoder = BigBirdPegasusDecoder(config) + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +# Copied from transformers.models.pegasus.modeling_pegasus.PegasusForCausalLM with Pegasus->BigBirdPegasus, 'facebook/bart-large'->"google/bigbird-pegasus-large-arxiv" +class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel): + def __init__(self, config): + super().__init__(config) + config = copy.deepcopy(config) + config.is_decoder = True + config.is_encoder_decoder = False + self.model = BigBirdPegasusDecoderWrapper(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.init_weights() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using :class:`~transformers.BigBirdPegasusTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` + for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + if the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used + in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up + decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last ``decoder_input_ids`` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should either be in ``[0, ..., + config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., + config.vocab_size]``. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors + for more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. + + Returns: + + Example:: + + >>> from transformers import BigBirdPegasusTokenizer, BigBirdPegasusForCausalLM + + >>> tokenizer = BigBirdPegasusTokenizer.from_pretrained("google/bigbird-pegasus-large-arxiv") + >>> model = BigBirdPegasusForCausalLM.from_pretrained("google/bigbird-pegasus-large-arxiv", add_cross_attention=False) + >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.lm_head(outputs[0]) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + + if past: + input_ids = input_ids[:, -1:] + # first step, decoder_cached_states are empty + return { + "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed + "attention_mask": attention_mask, + "past_key_values": past, + "use_cache": use_cache, + } + + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past diff --git a/src/transformers/models/pegasus/tokenization_pegasus.py b/src/transformers/models/pegasus/tokenization_pegasus.py index 7ced5672548..74671c98e3d 100644 --- a/src/transformers/models/pegasus/tokenization_pegasus.py +++ b/src/transformers/models/pegasus/tokenization_pegasus.py @@ -80,7 +80,6 @@ class PegasusTokenizer(PreTrainedTokenizer): """ vocab_files_names = VOCAB_FILES_NAMES - offset = 103 # entries 2 - 104 are only used for pretraining vocab_files_names = VOCAB_FILES_NAMES pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES @@ -95,8 +94,11 @@ class PegasusTokenizer(PreTrainedTokenizer): mask_token="", mask_token_sent="", additional_special_tokens=None, + offset=103, # entries 2 - 104 are only used for pretraining **kwargs ): + self.offset = offset + if additional_special_tokens is not None: assert isinstance( additional_special_tokens, list @@ -104,7 +106,7 @@ class PegasusTokenizer(PreTrainedTokenizer): additional_special_tokens_extended = ( ([mask_token_sent] + additional_special_tokens) - if mask_token_sent not in additional_special_tokens + if mask_token_sent not in additional_special_tokens and mask_token_sent is not None else additional_special_tokens ) # fill additional tokens with ..., in case not all additional tokens are already taken @@ -118,7 +120,7 @@ class PegasusTokenizer(PreTrainedTokenizer): ) additional_special_tokens = additional_special_tokens_extended else: - additional_special_tokens = [mask_token_sent] + additional_special_tokens = [mask_token_sent] if mask_token_sent is not None else [] additional_special_tokens += [f"" for i in range(2, self.offset)] super().__init__( @@ -127,24 +129,34 @@ class PegasusTokenizer(PreTrainedTokenizer): mask_token=mask_token, pad_token=pad_token, mask_token_sent=mask_token_sent, + offset=offset, additional_special_tokens=additional_special_tokens, **kwargs, ) + self.mask_token_sent = mask_token_sent self.vocab_file = vocab_file self.sp_model = spm.SentencePieceProcessor() self.sp_model.Load(vocab_file) - self.mask_token_sent = mask_token_sent # add special tokens to encoder dict self.encoder: Dict[int, str] = { 0: self.pad_token, 1: self.eos_token, - 2: self.mask_token_sent, - 3: self.mask_token, } - # entries 2-104 are only used for pretraining and called , , unk_2, ...unk_102 - # mask_token_sent is already added to list -> so start at 1 - self.encoder.update({i + 3: additional_special_tokens[i] for i in range(1, self.offset - 1)}) + + if self.mask_token_sent is not None: + self.encoder.update( + { + 2: self.mask_token_sent, + 3: self.mask_token, + } + ) + + if self.offset > 0: + # entries 2-104 are only used for pretraining and called , , unk_2, ...unk_102 + # mask_token_sent is already added to list -> so start at 1 + self.encoder.update({i + 3: additional_special_tokens[i] for i in range(1, self.offset - 1)}) + self.decoder: Dict[str, int] = {v: k for k, v in self.encoder.items()} @property @@ -206,10 +218,6 @@ class PegasusTokenizer(PreTrainedTokenizer): all_special_ids = set(self.all_special_ids) # call it once instead of inside list comp all_special_ids.remove(self.unk_token_id) # is only sometimes special - assert all_special_ids == set( - range(len(self.additional_special_tokens) + 3) - ), f"There should be 3 special tokens: mask_token, pad_token, and eos_token + {len(self.additional_special_tokens)} additional_special_tokens, but got {all_special_ids}" - return [1 if x in all_special_ids else 0 for x in seq] def get_special_tokens_mask( diff --git a/src/transformers/models/pegasus/tokenization_pegasus_fast.py b/src/transformers/models/pegasus/tokenization_pegasus_fast.py index 08bd4719333..4ca8018c5e4 100644 --- a/src/transformers/models/pegasus/tokenization_pegasus_fast.py +++ b/src/transformers/models/pegasus/tokenization_pegasus_fast.py @@ -90,7 +90,6 @@ class PegasusTokenizerFast(PreTrainedTokenizerFast): `__ that uses the tokens 2 - 104 only for pretraining """ - offset = 103 # entries 2-104 are only used for pretraining vocab_files_names = VOCAB_FILES_NAMES pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES @@ -107,8 +106,11 @@ class PegasusTokenizerFast(PreTrainedTokenizerFast): mask_token="", mask_token_sent="", additional_special_tokens=None, + offset=103, # entries 2 - 104 are only used for pretraining **kwargs ): + self.offset = offset + if additional_special_tokens is not None: assert isinstance( additional_special_tokens, list @@ -116,7 +118,7 @@ class PegasusTokenizerFast(PreTrainedTokenizerFast): additional_special_tokens_extended = ( ([mask_token_sent] + additional_special_tokens) - if mask_token_sent not in additional_special_tokens + if mask_token_sent not in additional_special_tokens and mask_token_sent is not None else additional_special_tokens ) # fill additional tokens with ..., in case not all additional tokens are already taken @@ -130,7 +132,7 @@ class PegasusTokenizerFast(PreTrainedTokenizerFast): ) additional_special_tokens = additional_special_tokens_extended else: - additional_special_tokens = [mask_token_sent] + additional_special_tokens = [mask_token_sent] if mask_token_sent is not None else [] additional_special_tokens += [f"" for i in range(2, self.offset)] super().__init__( @@ -141,10 +143,10 @@ class PegasusTokenizerFast(PreTrainedTokenizerFast): unk_token=unk_token, mask_token=mask_token, mask_token_sent=mask_token_sent, + offset=offset, additional_special_tokens=additional_special_tokens, **kwargs, ) - self.vocab_file = vocab_file def _special_token_mask(self, seq): diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 47c80380a83..b4b160dbe3a 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -721,6 +721,50 @@ def load_tf_weights_in_big_bird(*args, **kwargs): requires_backends(load_tf_weights_in_big_bird, ["torch"]) +BIGBIRD_PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class BigBirdPegasusForCausalLM: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class BigBirdPegasusForConditionalGeneration: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class BigBirdPegasusForQuestionAnswering: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class BigBirdPegasusForSequenceClassification: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class BigBirdPegasusModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/src/transformers/utils/modeling_auto_mapping.py b/src/transformers/utils/modeling_auto_mapping.py index 0a05ac24d79..ac4b8756fea 100644 --- a/src/transformers/utils/modeling_auto_mapping.py +++ b/src/transformers/utils/modeling_auto_mapping.py @@ -6,6 +6,7 @@ from collections import OrderedDict MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( [ + ("BigBirdPegasusConfig", "BigBirdPegasusForQuestionAnswering"), ("BigBirdConfig", "BigBirdForQuestionAnswering"), ("ConvBertConfig", "ConvBertForQuestionAnswering"), ("LEDConfig", "LEDForQuestionAnswering"), diff --git a/tests/test_generation_utils.py b/tests/test_generation_utils.py index 4a7140d2ca3..4830e07a2bd 100644 --- a/tests/test_generation_utils.py +++ b/tests/test_generation_utils.py @@ -310,19 +310,18 @@ class GenerationTesterMixin: logits_processor.append(InfNanRemoveLogitsProcessor()) with torch.no_grad(): - with torch.no_grad(): - output_sample = model.sample( - input_ids_clone, - attention_mask=attention_mask_clone, - max_length=max_length, - logits_processor=logits_processor, - logits_warper=logits_warper, - output_scores=output_scores, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict_in_generate=return_dict_in_generate, - **kwargs, - ) + output_sample = model.sample( + input_ids_clone, + attention_mask=attention_mask_clone, + max_length=max_length, + logits_processor=logits_processor, + logits_warper=logits_warper, + output_scores=output_scores, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict_in_generate=return_dict_in_generate, + **kwargs, + ) return output_sample, output_generate def _beam_search_generate( diff --git a/tests/test_modeling_bigbird_pegasus.py b/tests/test_modeling_bigbird_pegasus.py new file mode 100644 index 00000000000..bc0b44e8eb3 --- /dev/null +++ b/tests/test_modeling_bigbird_pegasus.py @@ -0,0 +1,762 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch BigBirdPegasus model. """ + + +import copy +import tempfile +import unittest + +from transformers import is_torch_available +from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device + +from .test_configuration_common import ConfigTester +from .test_generation_utils import GenerationTesterMixin +from .test_modeling_common import ModelTesterMixin, ids_tensor + + +if is_torch_available(): + import torch + + from transformers import ( + BigBirdPegasusConfig, + BigBirdPegasusForCausalLM, + BigBirdPegasusForConditionalGeneration, + BigBirdPegasusForQuestionAnswering, + BigBirdPegasusForSequenceClassification, + BigBirdPegasusModel, + PegasusTokenizer, + ) + from transformers.models.bigbird_pegasus.modeling_bigbird_pegasus import ( + BigBirdPegasusDecoder, + BigBirdPegasusEncoder, + ) + +MODEL_ID = "google/bigbird-pegasus-large-pubmed" + + +def prepare_bigbird_pegasus_inputs_dict( + config, + input_ids, + decoder_input_ids, + attention_mask=None, + decoder_attention_mask=None, +): + if attention_mask is None: + attention_mask = input_ids.ne(config.pad_token_id) + if decoder_attention_mask is None: + decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id) + + input_dict = { + "input_ids": input_ids, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "decoder_attention_mask": attention_mask, + } + input_dict = {k: input_dict[k].to(torch_device) for k in input_dict} + return input_dict + + +@require_torch +class BigBirdPegasusModelTester: + def __init__( + self, + parent, + batch_size=7, + seq_length=256, + is_training=True, + use_labels=False, + vocab_size=99, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=31, + hidden_act="gelu_fast", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=260, + eos_token_id=1, + pad_token_id=0, + bos_token_id=2, + attention_type="block_sparse", + use_bias=False, + block_size=16, + num_random_blocks=3, + scale_embedding=True, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + + self.attention_type = attention_type + self.use_bias = use_bias + self.block_size = block_size + self.num_random_blocks = num_random_blocks + self.scale_embedding = scale_embedding + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp( + 3, + ) + input_ids[:, -1] = self.eos_token_id # Eos Token + + decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + config = BigBirdPegasusConfig( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + encoder_layers=self.num_hidden_layers, + decoder_layers=self.num_hidden_layers, + encoder_attention_heads=self.num_attention_heads, + decoder_attention_heads=self.num_attention_heads, + encoder_ffn_dim=self.intermediate_size, + decoder_ffn_dim=self.intermediate_size, + dropout=self.hidden_dropout_prob, + attention_dropout=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + eos_token_id=self.eos_token_id, + bos_token_id=self.bos_token_id, + pad_token_id=self.pad_token_id, + attention_type=self.attention_type, + use_bias=self.use_bias, + block_size=self.block_size, + num_random_blocks=self.num_random_blocks, + scale_embedding=self.scale_embedding, + ) + inputs_dict = prepare_bigbird_pegasus_inputs_dict(config, input_ids, decoder_input_ids) + return config, inputs_dict + + def prepare_config_and_inputs_for_common(self): + config, inputs_dict = self.prepare_config_and_inputs() + return config, inputs_dict + + def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict): + model = BigBirdPegasusModel(config=config).get_decoder().to(torch_device).eval() + input_ids = inputs_dict["input_ids"] + attention_mask = inputs_dict["attention_mask"] + + # first forward pass + outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) + + output, past_key_values = outputs.to_tuple() + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_attn_mask = ids_tensor((self.batch_size, 3), 2) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1) + + output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"] + output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[ + "last_hidden_state" + ] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2)) + + def check_encoder_decoder_model_standalone(self, config, inputs_dict): + model = BigBirdPegasusModel(config=config).to(torch_device).eval() + outputs = model(**inputs_dict) + + encoder_last_hidden_state = outputs.encoder_last_hidden_state + last_hidden_state = outputs.last_hidden_state + + with tempfile.TemporaryDirectory() as tmpdirname: + encoder = model.get_encoder() + encoder.save_pretrained(tmpdirname) + encoder = BigBirdPegasusEncoder.from_pretrained(tmpdirname).to(torch_device) + + encoder_last_hidden_state_2 = encoder(inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"])[ + 0 + ] + + self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3) + + with tempfile.TemporaryDirectory() as tmpdirname: + decoder = model.get_decoder() + decoder.save_pretrained(tmpdirname) + decoder = BigBirdPegasusDecoder.from_pretrained(tmpdirname).to(torch_device) + + last_hidden_state_2 = decoder( + input_ids=inputs_dict["decoder_input_ids"], + attention_mask=inputs_dict["decoder_attention_mask"], + encoder_hidden_states=encoder_last_hidden_state, + encoder_attention_mask=inputs_dict["attention_mask"], + )[0] + + self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 1e-3) + + def create_and_check_model(self, config, inputs_dict): + model = BigBirdPegasusModel(config=config).to(torch_device).eval() + input_ids = inputs_dict["input_ids"] + decoder_input_ids = inputs_dict["decoder_input_ids"] + result = model(input_ids, decoder_input_ids=decoder_input_ids, use_cache=True) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + +@require_torch +class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = ( + ( + BigBirdPegasusModel, + BigBirdPegasusForConditionalGeneration, + BigBirdPegasusForSequenceClassification, + BigBirdPegasusForQuestionAnswering, + ) + if is_torch_available() + else () + ) + all_generative_model_classes = (BigBirdPegasusForConditionalGeneration,) if is_torch_available() else () + is_encoder_decoder = True + test_missing_keys = False + test_pruning = False + test_head_masking = False + + # torchscript tests are not passing for now. + # Also torchscript is not an important feature to have in the beginning. + test_torchscript = False + + # overwrite from GenerationTesterMixin to solve problem + # with conflicting random seeds + def _get_input_ids_and_config(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.attention_type = "original_full" + + input_ids = inputs_dict[self.input_name] + attention_mask = torch.ones_like(input_ids, dtype=torch.long) + + # cut to half length & take max batch_size 3 + max_batch_size = 2 + sequence_length = input_ids.shape[-1] // 2 + input_ids = input_ids[:max_batch_size, :sequence_length] + attention_mask = attention_mask[:max_batch_size, :sequence_length] + + # generate max 3 tokens + max_length = input_ids.shape[-1] + 3 + if config.eos_token_id is not None and config.pad_token_id is None: + # hack to allow generate for models such as GPT2 as is done in `generate()` + config.pad_token_id = config.eos_token_id + return config, input_ids, attention_mask, max_length + + def setUp(self): + self.model_tester = BigBirdPegasusModelTester(self) + self.config_tester = ConfigTester(self, config_class=BigBirdPegasusConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_save_load_strict(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs() + for model_class in self.all_model_classes: + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) + self.assertEqual(info["missing_keys"], []) + + def test_decoder_model_past_with_large_inputs(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) + + def test_encoder_decoder_model_standalone(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common() + self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs) + + def test_model_various_attn_type(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + for type in ["original_full", "block_sparse"]: + config_and_inputs[0].attention_type = type + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_generate_without_input_ids(self): + if self.model_tester.attention_type == "block_sparse": + # this test can never pass for BigBird-block-sparse attention since input_ids must be multiple of block_size + return + super().test_generate_without_input_ids() + + def test_retain_grad_hidden_states_attentions(self): + if self.model_tester.attention_type == "block_sparse": + # this test can't pass since attention matrix (which is getting returned) can't have gradients (& just 0 at many locations) + return + super().test_retain_grad_hidden_states_attentions() + + # BigBirdPegasusForSequenceClassification does not support inputs_embeds + def test_inputs_embeds(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in ( + BigBirdPegasusModel, + BigBirdPegasusForConditionalGeneration, + BigBirdPegasusForQuestionAnswering, + ): + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) + + if not self.is_encoder_decoder: + input_ids = inputs["input_ids"] + del inputs["input_ids"] + else: + encoder_input_ids = inputs["input_ids"] + decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids) + del inputs["input_ids"] + inputs.pop("decoder_input_ids", None) + + wte = model.get_input_embeddings() + if not self.is_encoder_decoder: + inputs["inputs_embeds"] = wte(input_ids) + else: + inputs["inputs_embeds"] = wte(encoder_input_ids) + inputs["decoder_inputs_embeds"] = wte(decoder_input_ids) + + with torch.no_grad(): + model(**inputs)[0] + + def test_generate_fp16(self): + config, input_dict = self.model_tester.prepare_config_and_inputs() + input_dict.pop("decoder_attention_mask") + input_dict.pop("decoder_input_ids") + model = BigBirdPegasusForConditionalGeneration(config).eval().to(torch_device) + if torch_device == "cuda": + model.half() + model.generate(**input_dict) + model.generate(**input_dict, do_sample=True, early_stopping=False, num_return_sequences=3) + + def test_batched_forward_original_full(self): + self._check_batched_forward(attn_type="original_full") + + def test_batched_forward_block_sparse(self): + self._check_batched_forward(attn_type="block_sparse", tolerance=1e-1) + + def _check_batched_forward(self, attn_type, tolerance=1e-3): + config = BigBirdPegasusConfig(block_size=16, attention_type=attn_type) + model = BigBirdPegasusForConditionalGeneration(config).to(torch_device) + model.eval() + + sample_with_padding = [3, 8, 11] * 128 + [0] * 128 + sample_without_padding = [4, 7, 9, 13] * 128 + target_ids_without_padding = [2, 3] * 8 + target_ids_with_padding = [7, 8] * 6 + 4 * [-100] + + attention_mask = torch.tensor( + [[1] * 3 * 128 + [0] * 128, [1] * 4 * 128], device=torch_device, dtype=torch.long + ) + + input_ids = torch.tensor([sample_with_padding, sample_without_padding], device=torch_device, dtype=torch.long) + labels = torch.tensor( + [target_ids_without_padding, target_ids_with_padding], device=torch_device, dtype=torch.long + ) + + with torch.no_grad(): + logits_batched = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).logits + + with torch.no_grad(): + logits_single_first = model(input_ids=input_ids[:1, :-128], labels=labels[:1]).logits + + self.assertTrue(torch.allclose(logits_batched[0, -3:], logits_single_first[0, -3:], atol=tolerance)) + + with torch.no_grad(): + logits_single_second = model(input_ids=input_ids[1:], labels=labels[1:, :-4]).logits + + self.assertTrue(torch.allclose(logits_batched[1, :3], logits_single_second[0, :3], atol=tolerance)) + + def test_auto_padding(self): + ids = [[7, 6, 9] * 65] + config, _ = self.model_tester.prepare_config_and_inputs() + input_ids = torch.tensor(ids, device=torch_device, dtype=torch.long) + attention_mask = input_ids.new_ones(input_ids.shape) + decoder_input_ids = torch.tensor([[33, 5, 8] * 3], device=torch_device, dtype=torch.long) + + config.block_size = 8 + model = BigBirdPegasusForConditionalGeneration(config).eval().to(torch_device) + output1 = model(input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids)[ + "logits" + ] + + ids = [[7, 6, 9] * 65 + [0] * 5] + input_ids = torch.tensor(ids, device=torch_device, dtype=torch.long) + attention_mask = torch.tensor([[1] * 3 * 65 + [0] * 5], device=torch_device, dtype=torch.long) + output2 = model(input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids)[ + "logits" + ] + + self.assertTrue(torch.allclose(output1, output2, atol=1e-5)) + + def test_for_change_to_full_attn(self): + self.model_tester.seq_length = 9 + config, input_dict = self.model_tester.prepare_config_and_inputs() + + # automatic switch will happen + config.attention_type = "block_sparse" + model = BigBirdPegasusForConditionalGeneration(config).eval().to(torch_device) + state_dict = model.state_dict() + outputs1 = model(**input_dict)["logits"] + + config.attention_type = "original_full" + model = BigBirdPegasusForConditionalGeneration(config).eval().to(torch_device) + model.load_state_dict(state_dict) + outputs2 = model(**input_dict)["logits"] + + self.assertTrue(torch.allclose(outputs1, outputs2, atol=1e-5)) + + +@require_torch +@require_sentencepiece +@require_tokenizers +@slow +class BigBirdPegasusModelIntegrationTests(unittest.TestCase): + def _get_dummy_input_ids(self): + # fmt: off + ids = torch.tensor( + [[685, 560, 630, 193, 836, 764, 708, 360, 10, 724, 278, 755, 805, 600, 71, 473, 601, 397, 315, 706, 487, 552, 88, 175, 601, 850, 678, 538, 846, 73, 778, 917, 116, 977, 756, 710, 1023, 848, 432, 449, 851, 100, 985, 178, 756, 798, 660, 148, 911, 424, 289, 962, 266, 698, 640, 545, 544, 715, 245, 152, 676, 511, 460, 883, 184, 29, 803, 129, 129, 933, 54, 902, 551, 489, 757, 274, 336, 389, 618, 43, 443, 544, 889, 258, 322, 1000, 938, 58, 292, 871, 120, 780, 431, 83, 92, 897, 399, 612, 566, 909, 634, 939, 85, 204, 325, 775, 965, 48, 640, 1013, 132, 973, 869, 181, 1001, 847, 144, 661, 228, 955, 792, 720, 910, 374, 854, 561, 306, 582, 170, 676, 449, 96, 198, 607, 257, 882, 691, 293, 931, 817, 862, 388, 611, 555, 974, 369, 1000, 918, 202, 384, 513, 907, 371, 556, 955, 384, 24, 700, 131, 378, 99, 575, 932, 735, 124, 964, 595, 943, 740, 149, 210, 563, 412, 783, 42, 59, 706, 37, 779, 87, 44, 873, 12, 771, 308, 81, 33, 183, 129, 807, 276, 175, 555, 372, 185, 445, 489, 590, 287, 281, 638, 771, 516, 95, 227, 876, 270, 881, 297, 329, 20, 608, 841, 411, 451, 249, 181, 324, 1005, 830, 783, 865, 261, 964, 750, 140, 1021, 599, 462, 890, 622, 844, 697, 529, 153, 926, 150, 111, 26, 465, 957, 890, 887, 118, 446, 596, 674, 873, 929, 229, 508, 764, 122, 327, 470, 288, 526, 840, 697, 153, 592, 42, 275, 553, 439, 208, 780, 167, 112, 350, 1018, 130, 736, 887, 813, 217, 382, 25, 68, 979, 1008, 772, 235, 717, 999, 292, 727, 1023, 702, 710, 728, 556, 33, 12, 617, 213, 139, 695, 1004, 422, 638, 669, 624, 489, 771, 540, 980, 218, 664, 822, 308, 175, 149, 950, 542, 580, 548, 808, 394, 74, 298, 920, 900, 815, 731, 947, 877, 772, 800, 778, 395, 540, 430, 200, 424, 62, 342, 866, 45, 803, 931, 89, 34, 646, 233, 768, 37, 769, 460, 291, 198, 895, 950, 255, 81, 447, 137, 190, 130, 210, 369, 292, 377, 348, 169, 885, 805, 177, 538, 324, 872, 509, 804, 115, 799, 30, 754, 290, 147, 274, 222, 341, 510, 515, 70, 358, 909, 557, 886, 766, 323, 624, 92, 342, 424, 552, 972, 663, 415, 658, 711, 968, 275, 861, 44, 84, 434, 810, 94, 175, 406, 202, 858, 499, 481, 988, 330, 541, 1004, 210, 618, 955, 897, 983, 576, 17, 107, 165, 607, 537, 629, 192, 196, 308, 137, 953, 860, 94, 892, 751, 88, 161, 148, 585, 456, 88, 14, 315, 594, 121, 885, 952, 833, 716, 733, 933, 282, 801, 427, 783, 471, 285, 277, 979, 325, 535, 228, 891, 596, 648, 969, 574, 654, 518, 257, 137, 208, 464, 950, 140, 5, 424, 349, 942, 283, 587, 821, 1007, 434, 220, 820, 740, 874, 787, 374, 291, 564, 671, 438, 827, 940, 824, 509, 1021, 787, 942, 856, 450, 327, 491, 54, 817, 95, 60, 337, 667, 637, 164, 571, 946, 107, 202, 301, 782, 890, 839, 551, 680, 649, 14, 1017, 904, 721, 1017, 535, 505, 848, 986, 777, 740, 775, 210, 456, 469, 474, 963, 573, 401, 57, 883, 750, 664, 281, 5, 613, 1005, 306, 344, 543, 567, 154, 789, 354, 358, 698, 408, 412, 30, 930, 372, 822, 632, 948, 855, 503, 8, 618, 1010, 138, 695, 897, 852, 377, 933, 722, 149, 886, 1009, 260, 127, 811, 578, 533, 805, 325, 977, 113, 944, 651, 238, 361, 991, 860, 556, 64, 928, 917, 455, 266, 445, 604, 624, 420, 340, 845, 275, 370, 843, 227, 226, 940, 644, 909, 229, 827, 898, 370, 129, 808, 25, 699, 293, 356, 838, 135, 4, 227, 890, 681, 445, 418, 285, 837, 27, 737, 249, 366, 948, 202, 438, 198, 930, 648, 638, 607, 73, 247, 853, 136, 708, 214, 476, 621, 324, 103, 853, 328, 596, 224, 257, 646, 348, 108, 927, 970, 980, 520, 150, 998, 477, 393, 684, 559, 1, 361, 692, 551, 90, 75, 500, 739, 636, 344, 97, 852, 283, 719, 33, 116, 455, 866, 429, 828, 826, 691, 174, 746, 133, 442, 94, 348, 402, 420, 707, 405, 942, 186, 976, 376, 677, 874, 703, 517, 498, 499, 206, 415, 366, 856, 739, 420, 586, 219, 952, 539, 375, 23, 461, 720, 355, 603, 52, 999, 815, 721, 574, 445, 816, 1019, 105, 641, 395, 972, 910, 328, 607, 519, 686, 246, 415, 528, 170, 167, 310, 940, 595, 392, 221, 834, 682, 835, 115, 861, 335, 742, 220, 247, 101, 416, 222, 179, 509, 175, 606, 627, 674, 781, 737, 746, 849, 67, 457, 1012, 126, 139, 625, 731, 156, 697, 121, 322, 449, 710, 857, 291, 976, 4, 701, 239, 678, 172, 724, 857, 583, 661, 903, 797, 628, 903, 835, 605, 989, 615, 870, 380, 710, 110, 330, 101, 695, 846, 918, 508, 672, 594, 36, 238, 244, 251, 393, 767, 282, 22, 430, 230, 983, 401, 154, 1007, 120, 678, 896, 386, 390, 711, 397, 347, 587, 1020, 951, 79, 831, 585, 200, 814, 134, 560, 700, 171, 452, 139, 755, 314, 476, 346, 388, 126, 719, 851, 198, 699, 901, 18, 710, 448, 351, 665, 644, 326, 425, 165, 571, 178, 440, 665, 674, 915, 866, 463, 754, 136, 950, 748, 47, 497, 1013, 640, 930, 338, 158, 525, 631, 815, 887, 289, 803, 116, 600, 637, 410, 175, 499, 876, 565, 1002, 623, 577, 333, 887, 586, 147, 773, 776, 644, 49, 77, 294, 117, 494, 561, 110, 979, 180, 562, 72, 859, 434, 1007, 286, 516, 75, 597, 491, 322, 888, 533, 209, 43, 499, 29, 411, 856, 181, 305, 963, 615, 778, 259, 373, 877, 746, 858, 381, 886, 613, 91, 69, 618, 523, 13, 617, 226, 422, 168, 929, 379, 290, 923, 100, 218, 307, 345, 211, 789, 735, 669, 585, 275, 410, 921, 552, 235, 636, 285, 665, 659, 708, 173, 724, 302, 823, 1, 139, 708, 903, 732, 868, 442, 967, 916, 163, 51, 243, 871]], # noqa: E231 + dtype=torch.long, + device=torch_device, + ) + # fmt: on + return ids + + def _get_dummy_target_ids(self): + # fmt: off + ids = torch.tensor( + [[13, 6, 1, 4, 12, 4, 8, 10, 4, 6, 3, 5, 8, 7, 9, 9]], # noqa: E231 + dtype=torch.long, + device=torch_device, + ) + # fmt: on + return ids + + def test_inference_block_sparse(self): + model = BigBirdPegasusForConditionalGeneration.from_pretrained( + MODEL_ID, attention_type="block_sparse", block_size=16, num_random_blocks=3 + ) + model.to(torch_device) + + input_ids = self._get_dummy_input_ids() + target_ids = self._get_dummy_target_ids() + + outputs = model(input_ids, labels=target_ids) + prediction_logits = outputs.logits + + self.assertEqual(prediction_logits.shape, torch.Size((1, 16, 96103))) + # fmt: off + expected_prediction_logits_slice = torch.tensor( + [[1.7769, 5.8479, 6.2375, 2.2745, 8.6157, 4.7483, 5.0647, 6.5358, 2.3393, 7.8333, 3.8403, 0.0255, 7.219, 5.2759, 3.097, 6.387, 4.9341, 7.1409, 5.1179, 0.1144, 6.8268, 0.7598, 0.6258, 2.373, 0.4627, -1.9919, 1.8422, 3.4578], [1.8026, 5.9604, 5.954, 2.8642, 9.0608, 4.394, 5.3779, 7.0216, 1.543, 7.8744, 4.4231, -0.0398, 7.6091, 5.6611, 3.3536, 6.8624, 4.7699, 6.5241, 4.8893, 0.5791, 6.8368, 0.1034, 0.0338, 2.9393, 0.5034, -2.5509, 2.0172, 3.2858], [1.8426, 5.9151, 5.5374, 3.0426, 9.1762, 3.6287, 5.3916, 7.4621, 1.2582, 7.9244, 4.694, -0.1308, 7.4725, 5.5385, 3.4598, 7.0422, 4.2455, 5.797, 4.5927, 0.7478, 6.7467, -0.2695, -0.3207, 3.0269, 0.4714, -2.8134, 2.0406, 3.1089], [1.6527, 5.8416, 5.4558, 3.0044, 9.3478, 3.2607, 5.3887, 7.52, 0.9362, 7.8877, 4.8465, -0.1705, 7.3932, 5.6352, 3.5744, 7.2623, 4.0485, 5.2788, 4.5859, 0.8325, 6.6088, -0.3676, -0.6287, 3.1731, 0.4483, -3.1573, 2.0522, 2.8868]], # noqa: E231 + device=torch_device, + ) + # fmt: on + self.assertTrue( + torch.allclose(prediction_logits[0, 4:8, 128:156], expected_prediction_logits_slice, atol=1e-4) + ) + + def test_inference_full_attn(self): + model = BigBirdPegasusForConditionalGeneration.from_pretrained(MODEL_ID, attention_type="original_full") + model.to(torch_device) + + input_ids = self._get_dummy_input_ids() + target_ids = self._get_dummy_target_ids() + + outputs = model(input_ids, labels=target_ids) + prediction_logits = outputs.logits + + self.assertEqual(prediction_logits.shape, torch.Size((1, 16, 96103))) + # fmt: off + expected_prediction_logits_slice = torch.tensor( + [[1.3418, 5.8304, 6.5662, 2.0448, 8.7702, 4.6579, 4.9947, 6.429, 2.4296, 7.9431, 4.217, 0.0672, 7.334, 5.1966, 2.9603, 6.0814, 4.6756, 7.5522, 5.076, 0.213, 6.6638, 0.6577, 0.244, 2.1221, 0.7531, -2.4076, 1.8731, 3.5594], [1.5525, 6.0524, 6.309, 2.6245, 9.229, 4.5213, 5.0913, 7.0622, 1.7992, 8.0962, 4.7994, -0.0248, 7.7168, 5.5878, 3.0883, 6.5248, 4.7895, 6.9974, 4.8787, 0.5445, 6.6686, 0.0102, -0.1659, 2.6195, 0.7389, -2.8956, 1.9928, 3.3777], [1.6407, 6.2104, 6.0331, 2.8076, 9.4074, 3.9772, 5.0574, 7.5316, 1.4201, 8.3035, 5.0212, -0.1031, 7.553, 5.5023, 3.1427, 6.7674, 4.4409, 6.457, 4.525, 0.728, 6.5422, -0.6234, -0.4726, 2.7486, 0.6985, -3.0804, 1.9669, 3.2365], [1.5065, 6.1271, 5.8296, 2.8405, 9.5649, 3.6834, 5.1214, 7.546, 0.9758, 8.3335, 5.1952, -0.1395, 7.4348, 5.6893, 3.2942, 7.0356, 4.1665, 5.9695, 4.3898, 0.8931, 6.3988, -0.8957, -0.7522, 2.8924, 0.6498, -3.4358, 1.8654, 2.9735]], # noqa: E231 + device=torch_device, + ) + # fmt: on + self.assertTrue( + torch.allclose(prediction_logits[0, 4:8, 128:156], expected_prediction_logits_slice, atol=1e-4) + ) + + def test_seq_to_seq_generation(self): + MODEL_ID = "google/bigbird-pegasus-large-arxiv" + model = BigBirdPegasusForConditionalGeneration.from_pretrained(MODEL_ID).to(torch_device) + tokenizer = PegasusTokenizer.from_pretrained(MODEL_ID) + + ARTICLE_LEP = r"""the lep experiments at the resonance of @xmath1-boson have tested the standard model ( sm ) at quantum level , measuring the @xmath1-decay into fermion pairs with an accuracy of one part in ten thousands . the good agreement of the lep data with the sm predictions have severely constrained the behavior of new physics at the @xmath1-pole . taking these achievements into account one can imagine that the physics of @xmath1-boson will again play the central role in the frontier of particle physics if the next generation @xmath1 factory comes true with the generated @xmath1 events several orders of magnitude higher than that of the lep . this factory can be realized in the gigaz option of the international linear collider ( ilc)@xcite . the ilc is a proposed electron - positron collider with tunable energy ranging from @xmath12 to @xmath13 and polarized beams in its first phase , and the gigaz option corresponds to its operation on top of the resonance of @xmath1 boson by adding a bypass to its main beam line . given the high luminosity , @xmath14 , and the cross section at the resonance of @xmath1 boson , @xmath15 , about @xmath16 @xmath1 events can be generated in an operational year of @xmath17 of gigaz , which implies that the expected sensitivity to the branching ratio of @xmath1-decay can be improved from @xmath18 at the lep to @xmath19 at the gigaz@xcite . in light of this , the @xmath1-boson properties , especially its exotic or rare decays which are widely believed to be sensitive to new physics , should be investigated comprehensively to evaluate their potential in probing new physics . among the rare @xmath1-decays , the flavor changing ( fc ) processes were most extensively studied to explore the flavor texture in new physics @xcite , and it was found that , although these processes are severely suppressed in the sm , their branching ratios in new physics models can be greatly enhanced to @xmath19 for lepton flavor violation decays @xcite and @xmath20 for quark flavor violation decays @xcite . besides the fc processes , the @xmath1-decay into light higgs boson(s ) is another type of rare process that was widely studied , e.g. the decay @xmath21 ( @xmath22 ) with the particle @xmath0 denoting a light higgs boson was studied in @xcite , the decay @xmath23 was studied in the two higgs doublet model ( 2hdm)@xcite and the minimal supersymmetric standard model ( mssm)@xcite , and the decay @xmath4 was studied in a model independent way @xcite , in 2hdm@xcite and also in mssm@xcite . these studies indicate that , in contrast with the kinematic forbidden of these decays in the sm , the rates of these decays can be as large as @xmath18 in new physics models , which lie within the expected sensitivity of the gigaz . in this work , we extend the previous studies of these decays to some new models and investigate these decays altogether . we are motivated by some recent studies on the singlet extension of the mssm , such as the next - to - minimal supersymmetric standard model ( nmssm ) @xcite and the nearly minimal supersymmetric standard model ( nmssm ) @xcite , where a light cp - odd higgs boson @xmath0 with singlet - dominant component may naturally arise from the spontaneous breaking of some approximate global symmetry like @xmath24 or peccei - quuin symmetry @xcite . these non - minimal supersymmetric models can not only avoid the @xmath25-problem , but also alleviate the little hierarchy by having such a light higgs boson @xmath0 @xcite . we are also motivated by that , with the latest experiments , the properties of the light higgs boson are more stringently constrained than before . so it is worth updating the previous studies . so far there is no model - independent lower bound on the lightest higgs boson mass . in the sm , it must be heavier than @xmath26 gev , obtained from the null observation of the higgs boson at lep experiments . however , due to the more complex structure of the higgs sector in the extensions of the sm , this lower bound can be significantly relaxed according to recent studies , e.g. , for the cp - odd higgs boson @xmath0 we have @xmath27 gev in the nmssm @xcite , @xmath28 gev in the nmssm @xcite , and @xmath29 gev in the lepton - specific 2hdm ( l2hdm ) @xcite . with such a light cp - odd higgs boson , the z - decay into one or more @xmath0 is open up . noting that the decay @xmath30 is forbidden due to bose symmetry , we in this work study the rare @xmath1-decays @xmath6 ( @xmath22 ) , @xmath31 and @xmath4 in a comparative way for four models , namely the type - ii 2hdm@xcite , the l2hdm @xcite , the nmssm and the nmssm . in our study , we examine carefully the constraints on the light @xmath0 from many latest experimental results . this work is organized as follows . in sec . ii we briefly describe the four new physics models . in sec . iii we present the calculations of the rare @xmath1-decays . in sec . iv we list the constraints on the four new physics models . in sec . v we show the numerical results for the branching ratios of the rare @xmath1-decays in various models . finally , the conclusion is given in sec . as the most economical way , the sm utilizes one higgs doublet to break the electroweak symmetry . as a result , the sm predicts only one physical higgs boson with its properties totally determined by two free parameters . in new physics models , the higgs sector is usually extended by adding higgs doublets and/or singlets , and consequently , more physical higgs bosons are predicted along with more free parameters involved in . the general 2hdm contains two @xmath32 doublet higgs fields @xmath33 and @xmath34 , and with the assumption of cp - conserving , its scalar potential can be parameterized as@xcite : @xmath35,\end{aligned}\ ] ] where @xmath36 ( @xmath37 ) are free dimensionless parameters , and @xmath38 ( @xmath39 ) are the parameters with mass dimension . after the electroweak symmetry breaking , the spectrum of this higgs sector includes three massless goldstone modes , which become the longitudinal modes of @xmath40 and @xmath1 bosons , and five massive physical states : two cp - even higgs bosons @xmath41 and @xmath42 , one neutral cp - odd higgs particle @xmath0 and a pair of charged higgs bosons @xmath43 . noting the constraint @xmath44 with @xmath45 and @xmath46 denoting the vacuum expectation values ( vev ) of @xmath33 and @xmath34 respectively , we choose @xmath47 as the input parameters with @xmath48 , and @xmath49 being the mixing angle that diagonalizes the mass matrix of the cp - even higgs fields . the difference between the type - ii 2hdm and the l2hdm comes from the yukawa coupling of the higgs bosons to quark / lepton . in the type - ii 2hdm , one higgs doublet @xmath34 generates the masses of up - type quarks and the other doublet @xmath33 generates the masses of down - type quarks and charged leptons ; while in the l2hdm one higgs doublet @xmath33 couples only to leptons and the other doublet @xmath34 couples only to quarks . so the yukawa interactions of @xmath0 to fermions in these two models are given by @xcite @xmath50 with @xmath51 denoting generation index . obviously , in the type - ii 2hdm the @xmath52 coupling and the @xmath53 coupling can be simultaneously enhanced by @xmath54 , while in the l2hdm only the @xmath53 coupling is enhanced by @xmath55 . the structures of the nmssm and the nmssm are described by their superpotentials and corresponding soft - breaking terms , which are given by @xcite @xmath56 where @xmath57 is the superpotential of the mssm without the @xmath25 term , @xmath58 and @xmath59 are higgs doublet and singlet superfields with @xmath60 and @xmath61 being their scalar component respectively , @xmath62 , @xmath63 , @xmath64 , @xmath65 , @xmath66 and @xmath67 are soft breaking parameters , and @xmath68 and @xmath69 are coefficients of the higgs self interactions . with the superpotentials and the soft - breaking terms , one can get the higgs potentials of the nmssm and the nmssm respectively . like the 2hdm , the higgs bosons with same cp property will mix and the mass eigenstates are obtained by diagonalizing the corresponding mass matrices : @xmath70 where the fields on the right hands of the equations are component fields of @xmath71 , @xmath72 and @xmath61 defined by @xmath73 @xmath74 and @xmath75 are respectively the cp - even and cp - odd neutral higgs bosons , @xmath76 and @xmath77 are goldstone bosons eaten by @xmath1 and @xmath78 , and @xmath79 is the charged higgs boson . so both the nmssm and nmssm predict three cp - even higgs bosons , two cp - odd higgs bosons and one pair of charged higgs bosons . in general , the lighter cp - odd higgs @xmath0 in these model is the mixture of the singlet field @xmath80 and the doublet field combination , @xmath81 , i.e. @xmath82 and its couplings to down - type quarks are then proportional to @xmath83 . so for singlet dominated @xmath0 , @xmath84 is small and the couplings are suppressed . as a comparison , the interactions of @xmath0 with the squarks are given by@xcite @xmath85 i.e. the interaction does not vanish when @xmath86 approaches zero . just like the 2hdm where we use the vevs of the higgs fields as fundamental parameters , we choose @xmath68 , @xmath69 , @xmath87 , @xmath88 , @xmath66 and @xmath89 as input parameters for the nmssm@xcite and @xmath68 , @xmath54 , @xmath88 , @xmath65 , @xmath90 and @xmath91 as input parameters for the nmssm@xcite . about the nmssm and the nmssm , three points should be noted . the first is for the two models , there is no explicit @xmath92term , and the effective @xmath25 parameter ( @xmath93 ) is generated when the scalar component of @xmath59 develops a vev . the second is , the nmssm is actually same as the nmssm with @xmath94@xcite , because the tadpole terms @xmath95 and its soft breaking term @xmath96 in the nmssm do not induce any interactions , except for the tree - level higgs boson masses and the minimization conditions . and the last is despite of the similarities , the nmssm has its own peculiarity , which comes from its neutralino sector . in the basis @xmath97 , its neutralino mass matrix is given by @xcite @xmath98 where @xmath99 and @xmath100 are @xmath101 and @xmath102 gaugino masses respectively , @xmath103 , @xmath104 , @xmath105 and @xmath106 . after diagonalizing this matrix one can get the mass eigenstate of the lightest neutralino @xmath107 with mass taking the following form @xcite @xmath108 this expression implies that @xmath107 must be lighter than about @xmath109 gev for @xmath110 ( from lower bound on chargnio mass ) and @xmath111 ( perturbativity bound ) . like the other supersymmetric models , @xmath107 as the lightest sparticle acts as the dark matter in the universe , but due to its singlino - dominated nature , it is difficult to annihilate sufficiently to get the correct density in the current universe . so the relic density of @xmath107 plays a crucial way in selecting the model parameters . for example , as shown in @xcite , for @xmath112 , there is no way to get the correct relic density , and for the other cases , @xmath107 mainly annihilates by exchanging @xmath1 boson for @xmath113 , or by exchanging a light cp - odd higgs boson @xmath0 with mass satisfying the relation @xmath114 for @xmath115 . for the annihilation , @xmath54 and @xmath25 are required to be less than 10 and @xmath116 respectively because through eq.([mass - exp ] ) a large @xmath87 or @xmath25 will suppress @xmath117 to make the annihilation more difficult . the properties of the lightest cp - odd higgs boson @xmath0 , such as its mass and couplings , are also limited tightly since @xmath0 plays an important role in @xmath107 annihilation . the phenomenology of the nmssm is also rather special , and this was discussed in detail in @xcite . in the type - ii 2hdm , l2hdm , nmssm and nmssm , the rare @xmath1-decays @xmath118 ( @xmath22 ) , @xmath3 and @xmath4 may proceed by the feynman diagrams shown in fig.[fig1 ] , fig.[fig2 ] and fig.[fig3 ] respectively . for these diagrams , the intermediate state @xmath119 represents all possible cp - even higgs bosons in the corresponding model , i.e. @xmath41 and @xmath42 in type - ii 2hdm and l2hdm and @xmath41 , @xmath42 and @xmath120 in nmssm and nmssm . in order to take into account the possible resonance effects of @xmath119 in fig.[fig1](c ) for @xmath2 and fig.[fig3 ] ( a ) for @xmath11 , we have calculated all the decay modes of @xmath119 and properly included the width effect in its propagator . as to the decay @xmath121 , two points should be noted . one is , unlike the decays @xmath6 and @xmath11 , this process proceeds only through loops mediated by quarks / leptons in the type - ii 2hdm and l2hdm , and additionally by sparticles in the nmssm and nmssm . so in most cases its rate should be much smaller than the other two . the other is due to cp - invariance , loops mediated by squarks / sleptons give no contribution to the decay@xcite . in actual calculation , this is reflected by the fact that the coupling coefficient of @xmath122 differs from that of @xmath123 by a minus sign ( see eq.([asqsq ] ) ) , and as a result , the squark - mediated contributions to @xmath121 are completely canceled out . with regard to the rare decay @xmath11 , we have more explanations . in the lowest order , this decay proceeds by the diagram shown in fig.[fig3 ] ( a ) , and hence one may think that , as a rough estimate , it is enough to only consider the contributions from fig.[fig3](a ) . however , we note that in some cases of the type - ii 2hdm and l2hdm , due to the cancelation of the contributions from different @xmath119 in fig.[fig3 ] ( a ) and also due to the potentially largeness of @xmath124 couplings ( i.e. larger than the electroweak scale @xmath125 ) , the radiative correction from the higgs - mediated loops may dominate over the tree level contribution even when the tree level prediction of the rate , @xmath126 , exceeds @xmath20 . on the other hand , we find the contribution from quark / lepton - mediated loops can be safely neglected if @xmath127 in the type - ii 2hdm and the l2hdm . in the nmssm and the nmssm , besides the corrections from the higgs- and quark / lepton - mediated loops , loops involving sparticles such as squarks , charginos and neutralinos can also contribute to the decay . we numerically checked that the contributions from squarks and charginos can be safely neglected if @xmath127 . we also calculated part of potentially large neutralino correction ( note that there are totally about @xmath128 diagrams for such correction ! ) and found they can be neglected too . since considering all the radiative corrections will make our numerical calculation rather slow , we only include the most important correction , namely that from higgs - mediated loops , in presenting our results for the four models . one can intuitively understand the relative smallness of the sparticle contribution to @xmath11 as follows . first consider the squark contribution which is induced by the @xmath129 interaction ( @xmath130 denotes the squark in chirality state ) and the @xmath131 interaction through box diagrams . because the @xmath132 interaction conserves the chirality of the squarks while the @xmath133 interaction violates the chirality , to get non - zero contribution to @xmath11 from the squark loops , at least four chiral flippings are needed , with three of them provided by @xmath131 interaction and the rest provided by the left - right squark mixing . this means that , if one calculates the amplitude in the chirality basis with the mass insertion method , the amplitude is suppressed by the mixing factor @xmath134 with @xmath135 being the off diagonal element in squark mass matrix . next consider the chargino / neutralino contributions . since for a light @xmath0 , its doublet component , parameterized by @xmath84 in eq.([mixing ] ) , is usually small , the couplings of @xmath0 with the sparticles will never be tremendously large@xcite . so the chargino / neutralino contributions are not important too . in our calculation of the decays , we work in the mass eigenstates of sparticles instead of in the chirality basis . for the type - ii 2hdm and the l2hdm , we consider the following constraints @xcite : * theoretical constraints on @xmath136 from perturbativity , unitarity and requirements that the scalar potential is finit at large field values and contains no flat directions @xcite , which imply that @xmath137 * the constraints from the lep search for neutral higgs bosons . we compute the signals from the higgs - strahlung production @xmath138 ( @xmath139 ) with @xmath140 @xcite and from the associated production @xmath141 with @xmath142 @xcite , and compare them with the corresponding lep data which have been inputted into our code . we also consider the constraints from @xmath138 by looking for a peak of @xmath143 recoil mass distribution of @xmath1-boson @xcite and the constraint of @xmath144 mev when @xmath145 @xcite . + these constraints limit the quantities such as @xmath146 \times br ( h_i \to \bar{b } b ) $ ] on the @xmath147 plane with the the subscript @xmath148 denoting the coupling coefficient of the @xmath149 interaction . they also impose a model - dependent lower bound on @xmath150 , e.g. , @xmath151 for the type - ii 2hdm ( from our scan results ) , @xmath152 for the l2hdm@xcite , and @xmath153 for the nmssm @xcite . these bounds are significantly lower than that of the sm , i.e. @xmath154 , partially because in new physics models , unconventional decay modes of @xmath155 such as @xmath156 are open up . as to the nmssm , another specific reason for allowing a significantly lighter cp - even higgs boson is that the boson may be singlet - dominated in this model . + with regard to the lightest cp - odd higgs boson @xmath0 , we checked that there is no lower bound on its mass so long as the @xmath157 interaction is weak or @xmath155 is sufficiently heavy . * the constraints from the lep search for a light higgs boson via the yukawa process @xmath158 with @xmath22 and @xmath61 denoting a scalar @xcite . these constraints can limit the @xmath159 coupling versus @xmath160 in new physics models . * the constraints from the cleo - iii limit on @xmath161 and the latest babar limits on @xmath162 . these constraints will put very tight constraints on the @xmath163 coupling for @xmath164 . in our analysis , we use the results of fig.8 in the second paper of @xcite to excluded the unfavored points . * the constraints from @xmath165 couplings . since the higgs sector can give sizable higher order corrections to @xmath165 couplings , we calculate them to one loop level and require the corrected @xmath165 couplings to lie within the @xmath166 range of their fitted value . the sm predictions for the couplings at @xmath1-pole are given by @xmath167 and @xmath168 @xcite , and the fitted values are given by @xmath169 and @xmath170 , respectively@xcite . we adopt the formula in @xcite to the 2hdm in our calculation . * the constraints from @xmath171 leptonic decay . we require the new physics correction to the branching ratio @xmath172 to be in the range of @xmath173 @xcite . we use the formula in @xcite in our calculation . + about the constraints ( 5 ) and ( 6 ) , two points should be noted . one is all higgs bosons are involved in the constraints by entering the self energy of @xmath171 lepton , the @xmath174 vertex correction or the @xmath175 vertex correction , and also the box diagrams for @xmath176@xcite . since the yukawa couplings of the higgs bosons to @xmath171 lepton get enhanced by @xmath54 and so do the corrections , @xmath54 must be upper bounded for given spectrum of the higgs sector . generally speaking , the lighter @xmath0 is , the more tightly @xmath54 is limited@xcite . the other point is in the type - ii 2hdm , @xmath177 , b - physics observables as well as @xmath178 decays discussed above can constraint the model in a tighter way than the constraints ( 5 ) and ( 6 ) since the yukawa couplings of @xmath171 lepton and @xmath179 quark are simultaneously enhanced by @xmath54 . but for the l2hdm , because only the yukawa couplings of @xmath171 lepton get enhanced ( see eq.[yukawa ] ) , the constraints ( 5 ) and ( 6 ) are more important in limiting @xmath54 . * indirect constraints from the precision electroweak observables such as @xmath180 , @xmath181 and @xmath182 , or their combinations @xmath183 @xcite . we require @xmath184 to be compatible with the lep / sld data at @xmath185 confidence level@xcite . we also require new physics prediction of @xmath186 is within the @xmath187 range of its experimental value . the latest results for @xmath188 are @xmath189 ( measured value ) and @xmath190 ( sm prediction ) for @xmath191 gev @xcite . in our code , we adopt the formula for these observables presented in @xcite to the type - ii 2hdm and the l2hdm respectively . + in calculating @xmath180 , @xmath181 and @xmath182 , we note that these observables get dominant contributions from the self energies of the gauge bosons @xmath1 , @xmath192 and @xmath193 . since there is no @xmath194 coupling or @xmath195 coupling , @xmath0 must be associated with the other higgs bosons to contribute to the self energies . so by the uv convergence of these quantities , one can infer that , for the case of a light @xmath0 and @xmath196 , these quantities depend on the spectrum of the higgs sector in a way like @xmath197 at leading order , which implies that a light @xmath0 can still survive the constraints from the precision electroweak observables given the splitting between @xmath150 and @xmath198 is moderate@xcite . * the constraints from b physics observables such as the branching ratios for @xmath199 , @xmath200 and @xmath201 , and the mass differences @xmath202 and @xmath203 . we require their theoretical predications to agree with the corresponding experimental values at @xmath187 level . + in the type - ii 2hdm and the l2hdm , only the charged higgs boson contributes to these observables by loops , so one can expect that @xmath198 versus @xmath54 is to be limited . combined analysis of the limits in the type - ii 2hdm has been done by the ckmfitter group , and the lower bound of @xmath204 as a function of @xmath87 was given in fig.11 of @xcite . this analysis indicates that @xmath198 must be heavier than @xmath205 at @xmath185 c.l . regardless the value of @xmath54 . in this work , we use the results of fig.11 in @xcite to exclude the unfavored points . as for the l2hdm , b physics actually can not put any constraints@xcite because in this model the couplings of the charged higgs boson to quarks are proportional to @xmath206 and in the case of large @xmath54 which we are interested in , they are suppressed . in our analysis of the l2hdm , we impose the lep bound on @xmath198 , i.e. @xmath207@xcite . * the constraints from the muon anomalous magnetic moment @xmath208 . now both the theoretical prediction and the experimental measured value of @xmath208 have reached a remarkable precision , but a significant deviation still exists : @xmath209 @xcite . in the 2hdm , @xmath208 gets additional contributions from the one - loop diagrams induced by the higgs bosons and also from the two - loop barr - zee diagrams mediated by @xmath0 and @xmath155@xcite . if the higgs bosons are much heavier than @xmath25 lepton mass , the contributions from the barr - zee diagrams are more important , and to efficiently alleviate the discrepancy of @xmath208 , one needs a light @xmath0 along with its enhanced couplings to @xmath25 lepton and also to heavy fermions such as bottom quark and @xmath171 lepton to push up the effects of the barr - zee diagram@xcite . the cp - even higgs bosons are usually preferred to be heavy since their contributions to @xmath208 are negative . + in the type - ii 2hdm , because @xmath54 is tightly constrained by the process @xmath210 at the lep@xcite and the @xmath178 decay@xcite , the barr - zee diagram contribution is insufficient to enhance @xmath208 to @xmath187 range around its measured value@xcite . so in our analysis , we require the type - ii 2hdm to explain @xmath208 at @xmath211 level . while for the l2hdm , @xmath54 is less constrained compared with the type - ii 2hdm , and the barr - zee diagram involving the @xmath171-loop is capable to push up greatly the theoretical prediction of @xmath208@xcite . therefore , we require the l2hdm to explain the discrepancy at @xmath187 level . + unlike the other constraints discussed above , the @xmath208 constraint will put a two - sided bound on @xmath54 since on the one hand , it needs a large @xmath54 to enhance the barr - zee contribution , but on the other hand , too large @xmath54 will result in an unacceptable large @xmath208 . * since this paper concentrates on a light @xmath0 , the decay @xmath212 is open up with a possible large decay width . we require the width of any higgs boson to be smaller than its mass to avoid a too fat higgs boson@xcite . we checked that for the scenario characterized by @xmath213 , the coefficient of @xmath214 interaction is usually larger than the electroweak scale @xmath125 , and consequently a large decay width is resulted . for the nmssm and nmssm , the above constraints become more complicated because in these models , not only more higgs bosons are involved in , but also sparticles enter the constraints . so it is not easy to understand some of the constraints intuitively . take the process @xmath199 as an example . in the supersymmetric models , besides the charged higgs contribution , chargino loops , gluino loops as well as neutralino loops also contribute to the process@xcite , and depending on the susy parameters , any of these contributions may become dominated over or be canceled by other contributions . as a result , although the charged higgs affects the process in the same way as that in the type - ii 2hdm , charged higgs as light as @xmath215 is still allowed even for @xmath216@xcite . since among the constraints , @xmath208 is rather peculiar in that it needs new physics to explain the discrepancy between @xmath217 and @xmath218 , we discuss more about its dependence on susy parameters . in the nmssm and the nmssm , @xmath208 receives contributions from higgs loops and neutralino / chargino loops . for the higgs contribution , it is quite similar to that of the type - ii 2hdm except that more higgs bosons are involved in@xcite . for the neutralino / chargino contribution , in the light bino limit ( i.e. @xmath219 ) , it can be approximated by@xcite @xmath220 for @xmath221 with @xmath222 being smuon mass . so combining the two contributions together , one can learn that a light @xmath0 along with large @xmath54 and/or light smuon with moderate @xmath87 are favored to dilute the discrepancy . because more parameters are involved in the constraints on the supersymmetric models , we consider following additional constraints to further limit their parameters : * direct bounds on sparticle masses from the lep1 , the lep2 and the tevatron experiments @xcite . * the lep1 bound on invisible z decay @xmath223 ; the lep2 bound on neutralino production @xmath224 and @xmath225@xcite . * dark matter constraints from the wmap relic density 0.0975 @xmath226 0.1213 @xcite . note that among the above constraints , the constraint ( 2 ) on higgs sector and the constraint ( c ) on neutralino sector are very important . this is because in the supersymmetric models , the sm - like higgs is upper bounded by about @xmath227 at tree level and by about @xmath228 at loop level , and that the relic density restricts the lsp annihilation cross section in a certain narrow range . in our analysis of the nmssm , we calculate the constraints ( 3 ) and ( 5 - 7 ) by ourselves and utilize the code nmssmtools @xcite to implement the rest constraints . we also extend nmssmtools to the nmssm to implement the constraints . for the extension , the most difficult thing we faced is how to adapt the code micromegas@xcite to the nmssm case . we solve this problem by noting the following facts : * as we mentioned before , the nmssm is actually same as the nmssm with the trilinear singlet term setting to zero . so we can utilize the model file of the nmssm as the input of the micromegas and set @xmath229 . * since in the nmssm , the lsp is too light to annihilate into higgs pairs , there is no need to reconstruct the effective higgs potential to calculate precisely the annihilation channel @xmath230 with @xmath61 denoting any of higgs bosons@xcite . we thank the authors of the nmssmtools for helpful discussion on this issue when we finish such extension@xcite . with the above constraints , we perform four independent random scans over the parameter space of the type - ii 2hdm , the l2hdm , the nmssm and the nmssm respectively . we vary the parameters in following ranges : @xmath231 for the type - ii 2hdm , @xmath232 for the l2hdm , @xmath233 for the nmssm , and @xmath234 for the nmssm . in performing the scans , we note that for the nmssm and the nmssm , some constraints also rely on the gaugino masses and the soft breaking parameters in the squark sector and the slepton sector . since these parameters affect little on the properties of @xmath0 , we fix them to reduce the number of free parameters in our scan . for the squark sector , we adopt the @xmath235 scenario which assumes that the soft mass parameters for the third generation squarks are degenerate : @xmath236 800 gev , and that the trilinear couplings of the third generation squarks are also degenerate , @xmath237 with @xmath238 . for the slepton sector , we assume all the soft - breaking masses and trilinear parameters to be 100 gev . this setting is necessary for the nmssm since this model is difficult to explain the muon anomalous moment at @xmath239 level for heavy sleptons@xcite . finally , we assume the grand unification relation @xmath240 for the gaugino masses with @xmath241 being fine structure constants of the different gauge group . with large number of random points in the scans , we finally get about @xmath242 , @xmath243 , @xmath244 and @xmath242 samples for the type - ii 2hdm , the l2hdm , the nmssm and the nmssm respectively which survive the constraints and satisfy @xmath245 . analyzing the properties of the @xmath0 indicates that for most of the surviving points in the nmssm and the nmssm , its dominant component is the singlet field ( numerically speaking , @xmath246 ) so that its couplings to the sm fermions are suppressed@xcite . our analysis also indicates that the main decay products of @xmath0 are @xmath247 for the l2hdm@xcite , @xmath248 ( dominant ) and @xmath247 ( subdominant ) for the type - ii 2hdm , the nmssm and the nmssm , and in some rare cases , neutralino pairs in the nmssm@xcite . in fig.[fig4 ] , we project the surviving samples on the @xmath249 plane . this figure shows that the allowed range of @xmath54 is from @xmath250 to @xmath251 in the type - ii 2hdm , and from @xmath252 to @xmath253 in the l2hdm . just as we introduced before , the lower bounds of @xmath254 come from the fact that we require the models to explain the muon anomalous moment , while the upper bound is due to we have imposed the constraint from the lep process @xmath255 , which have limited the upper reach of the @xmath256 coupling for light @xmath61 @xcite(for the dependence of @xmath256 coupling on @xmath54 , see sec . this figure also indicates that for the nmssm and the nmssm , @xmath54 is upper bounded by @xmath257 . for the nmssm , this is because large @xmath87 can suppress the dark matter mass to make its annihilation difficult ( see @xcite and also sec . ii ) , but for the nmssm , this is because we choose a light slepton mass so that large @xmath54 can enhance @xmath208 too significantly to be experimentally unacceptable . we checked that for the slepton mass as heavy as @xmath258 , @xmath259 is still allowed for the nmssm . in fig.[fig5 ] and fig.[fig6 ] , we show the branching ratios of @xmath260 and @xmath261 respectively . fig.[fig5 ] indicates , among the four models , the type - ii 2hdm predicts the largest ratio for @xmath260 with its value varying from @xmath262 to @xmath263 . the underlying reason is in the type - ii 2hdm , the @xmath264 coupling is enhanced by @xmath54 ( see fig.[fig4 ] ) , while in the other three model , the coupling is suppressed either by @xmath265 or by the singlet component of the @xmath0 . fig.[fig6 ] shows that the l2hdm predicts the largest rate for @xmath266 with its value reaching @xmath5 in optimum case , and for the other three models , the ratio of @xmath261 is at least about one order smaller than that of @xmath267 . this feature can be easily understood from the @xmath268 coupling introduced in sect . we emphasize that , if the nature prefers a light @xmath0 , @xmath260 and/or @xmath269 in the type - ii 2hdm and the l2hdm will be observable at the gigaz . then by the rates of the two decays , one can determine whether the type - ii 2hdm or the l2hdm is the right theory . on the other hand , if both decays are observed with small rates or fail to be observed , the singlet extensions of the mssm are favored . in fig.[fig7 ] , we show the rate of @xmath3 as the function of @xmath270 . this figure indicates that the branching ratio of @xmath121 can reach @xmath271 , @xmath272 , @xmath273 and @xmath274 for the optimal cases of the type - ii 2hdm , the l2hdm , the nmssm and the nmssm respectively , which implies that the decay @xmath121 will never be observable at the gigaz if the studied model is chosen by nature . the reason for the smallness is , as we pointed out before , that the decay @xmath121 proceeds only at loop level . comparing the optimum cases of the type - ii 2hdm , the nmssm and the nmssm shown in fig.5 - 7 , one may find that the relation @xmath275 holds for any of the decays . this is because the decays are all induced by the yukawa couplings with similar structure for the models . in the supersymmetric models , the large singlet component of the light @xmath0 is to suppress the yukawa couplings , and the @xmath0 in the nmssm has more singlet component than that in the nmssm . next we consider the decay @xmath11 , which , unlike the above decays , depends on the higgs self interactions . in fig.[fig8 ] we plot its rate as a function of @xmath270 and this figure indicates that the @xmath276 may be the largest among the ratios of the exotic @xmath1 decays , reaching @xmath277 in the optimum cases of the type - ii 2hdm , the l2hdm and the nmssm . the underlying reason is , in some cases , the intermediate state @xmath119 in fig.[fig3 ] ( a ) may be on - shell . in fact , we find this is one of the main differences between the nmssm and the nmssm , that is , in the nmssm , @xmath119 in fig.[fig3 ] ( a ) may be on - shell ( corresponds to the points with large @xmath278 ) while in the nmssm , this seems impossible . so we conclude that the decay @xmath11 may serve as an alternative channel to test new physics models , especially it may be used to distinguish the nmssm from the nmssm if the supersymmetry is found at the lhc and the @xmath11 is observed at the gigaz with large rate . before we end our discussion , we note that in the nmssm , the higgs boson @xmath0 may be lighter than @xmath279 without conflicting with low energy data from @xmath178 decays and the other observables ( see fig.[fig4]-[fig8 ] ) . in this case , @xmath0 is axion - like as pointed out in @xcite . we checked that , among the rare @xmath1 decays discussed in this paper , the largest branching ratio comes from @xmath280 which can reach @xmath281 . since in this case , the decay product of @xmath0 is highly collinear muon pair , detecting the decay @xmath280 may need some knowledge about detectors , which is beyond our discussion . in this paper , we studied the rare @xmath1-decays @xmath2 ( @xmath7 ) , @xmath282 and @xmath4 in the type - ii 2hdm , lepton - specific 2hdm , nmssm and nmssm , which predict a light cp - odd higgs boson @xmath0 . in the parameter space allowed by current experiments , the branching ratio can be as large as @xmath5 for @xmath118 , @xmath8 for @xmath3 and @xmath9 for @xmath4 , which implies that the decays @xmath2 and @xmath283 may be accessible at the gigaz option . since different models predict different size of branching ratios , these decays can be used to distinguish different model through the measurement of these rare decays . this work was supported in part by hastit under grant no . 2009hastit004 , by the national natural science foundation of china ( nnsfc ) under grant nos . 10821504 , 10725526 , 10635030 , 10775039 , 11075045 and by the project of knowledge innovation program ( pkip ) of chinese academy of sciences under grant no . . for some reviews , see , e.g. , m. a. perez , g. tavares - velasco and j. j. toscano , int . j. mod . a * 19 * , 159 ( 2004 ) ; j. m. yang , arxiv:1006.2594 . j. i. illana , m. masip , 67 , 035004 ( 2003 ) ; j. cao , z. xiong , j. m. yang , 32 , 245 ( 2004 ) . d. atwood _ et al_. , 66 , 093005 ( 2002 ) . j. kalinowski , and s. pokorski , 219 , 116 ( 1989 ) ; a. djouadi , p. m. zerwas and j. zunft , 259 , 175 ( 1991 ) ; a. djouadi , j. kalinowski , and p. m. zerwas , z. phys . c * 54 * , 255 ( 1992 ) . m. krawczyk , _ et al . _ , 19 , 463 ( 2001 ) ; 8 , 495 ( 1999 ) . j. f. gunion , g. gamberini and s. f. novaes , 38 , 3481 ( 1988 ) ; thomas j. weiler and tzu - chiang yuan , 318 , 337 ( 1989 ) ; a. djouadi , _ et al . _ , 1 , 163 ( 1998)[hep - ph/9701342 ] . d. chang and w. y. keung , phys . lett . * 77 * , 3732 ( 1996 ) . e. keith and e. ma , 57 , 2017 ( 1998 ) ; m. a. perez , g. tavares - velasco and j. j. toscano , int . j. mod.phys . a * 19 * , 159 ( 2004 ) . f. larios , g. tavares - velasco and c. p. yuan , 64 , 055004 ( 2001 ) ; 66 , 075006 ( 2002 ) . a. djouadi , _ et al . _ , 10 , 27 ( 1999 ) [ hep - ph/9903229 ] . for a detailed introduction of the nmssm , see f. franke and h. fraas , int . j. mod . a * 12 * ( 1997 ) 479 ; for a recent review of the nmssm , see for example , u. ellwanger , c. hugonie , and a. m. teixeira , arxiv : 0910.1785 . see , e.g. , j. r. ellis , j. f. gunion , h. e. haber , l. roszkowski and f. zwirner , phys . rev . d * 39 * ( 1989 ) 844 ; m. drees , int . j. mod . phys . a * 4 * ( 1989 ) 3635 ; u. ellwanger , m. rausch de traubenberg and c. a. savoy , phys . b * 315 * ( 1993 ) 331 ; nucl . b * 492 * ( 1997 ) 21 ; d.j . miller , r. nevzorov , p.m. zerwas , 681 , 3 ( 2004 ) . c. panagiotakopoulos , k. tamvakis , 446 , 224 ( 1999 ) ; 469 , 145 ( 1999 ) ; c. panagiotakopoulos , a. pilaftsis , 63 , 055003 ( 2001 ) ; a. dedes , _ et al . _ , 63 , 055009 ( 2001 ) ; a. menon , _ et al . _ , 70 , 035005 ( 2004 ) ; v. barger , _ et al . _ , 630 , 85 ( 2005 ) . c. balazs , _ et al . _ , 0706 , 066 ( 2007 ) . b. a. dobrescu , k. t. matchev , 0009 , 031 ( 2000 ) ; a. arhrib , k. cheung , t. j. hou , k. w. song , hep - ph/0611211 ; 0703 , 073 ( 2007 ) ; x. g. he , j. tandean , and g. valencia , 98 , 081802 ( 2007 ) ; 0806 , 002 ( 2008 ) ; f. domingo _ et al_. , 0901 , 061 ( 2009 ) ; gudrun hiller , 70 , 034018 ( 2004 ) ; r. dermisek , and john f. gunion , 75 , 075019 ( 2007 ) ; 79 , 055014 ( 2009 ) ; 81 , 055001 ( 2010 ) ; r. dermisek , john f. gunion , and b. mcelrath , 76 , 051105 ( 2007 ) ; z. heng , _ et al_. , 77 , 095012 ( 2008 ) ; a. belyaev _ et al_. , 81 , 075021 ( 2010 ) ; d. das and u. ellwanger , arxiv:1007.1151 [ hep - ph ] . s. andreas , o. lebedev , s. ramos - sanchez and a. ringwald , arxiv:1005.3978 [ hep - ph ] . j. f. gunion , jhep * 0908 * , 032 ( 2009 ) ; r. dermisek and j. f. gunion , phys . rev . d * 81 * , 075003 ( 2010 ) . r. dermisek and j. f. gunion , phys . lett . * 95 * , 041801 ( 2005 ) ; phys . d * 73 * , 111701 ( 2006 ) . j. cao , h. e. logan , j. m. yang , 79 , 091701 ( 2009 ) . j. cao , p. wan , l. wu , j. m. yang , 80 , 071701 ( 2009 ) . j. f. gunion and h. e. haber , 67 , 075019 ( 2003 ) . r. m. barnett , _ et al . _ , phys . b * 136 * , 191 ( 1984 ) ; r. m. barnett , g. senjanovic and d. wyler , phys . d * 30 * , 1529 ( 1984 ) ; y. grossman , nucl . b * 426 * , 355 ( 1994 ) . h. s. goh , l. j. hall and p. kumar , jhep * 0905 * , 097 ( 2009 ) ; a. g. akeroyd and w. j. stirling , nucl . b * 447 * , 3 ( 1995 ) ; a. g. akeroyd , phys . b * 377 * , 95 ( 1996 ) ; h. e. logan and d. maclennan , phys . rev . d * 79 * , 115022 ( 2009 ) ; m. aoki , _ et al . _ , arxiv:0902.4665 [ hep - ph ] . v. barger , p. langacker , h. s. lee and g. shaughnessy , phys . d * 73 * , 115010 ( 2006 ) . s. hesselbach , _ et . _ , arxiv:0810.0511v2 [ hep - ph ] . de vivie and p. janot [ aleph collaboration ] , pa13 - 027 contribution to the international conference on high energy physics , warsaw , poland , 2531 july 1996 ; j. kurowska , o. grajek and p. zalewski [ delphi collaboration ] , cern - open-99 - 385 . [ aleph collaboration and delphi collaboration and l3 collaboration ] , phys . rept . * 427 * , 257 ( 2006 ) . j. cao and j. m. yang , jhep * 0812 * , 006 ( 2008 ) . m. krawczyk and d. temes , eur . j. c * 44 * , 435 ( 2005 ) . g. altarelli and r. barbieri , 253 , 161 ( 1991 ) ; m. e. peskin , t. takeuchi , 46 , 381 ( 1992 ) . c. amsler , _ et al . _ , ( particle data group ) , 667 , 1 ( 2008 ) . o. deschamps , s. descotes - genon , s. monteil , v. niess , s. tjampens and v. tisserand , arxiv:0907.5135 [ hep - ph ] . s. su and b. thomas , phys . d * 79 * , 095014 ( 2009 ) . g. abbiendi , _ et al . _ , eur . phys . j. c * 32 * , 453 ( 2004 ) . m. davier , _ et al . _ , 66 , 1 ( 2010 ) . k. cheung , _ et al . _ , phys . d * 64 * , 111301 ( 2001 ) . k. cheung and o. c. w. kong , phys . d * 68 * , 053003 ( 2003 ) . t. besmer , c. greub , t.hurth , 609 , 359 ( 2001 ) ; f. borzumati , _ et al . _ , 62 , 075005(2000 ) . j. cao , k. i. hikasa , w. wang , j. m. yang and l. x. yu , phys . d * 82 * , 051701 ( 2010 ) [ arxiv:1006.4811 [ hep - ph ] ] . j. f. gunion , _ et . d * 73 * , 015011 ( 2006 ) . martin and j. d. wells , phys . d * 64 * , 035003 ( 2001 ) . j. abdallah _ et al . _ , eur . j. c * 31 * , 421 ( 2004 ) ; g. abbiendi _ et al . _ , eur . j. c * 35 * , 1 ( 2004 ) . j. dunkley _ et al . _ [ wmap collaboration ] , astrophys . j. suppl . * 180 * , 306 ( 2009 ) [ arxiv:0803.0586 [ astro - ph ] ] . u. ellwanger _ et al . _ , 02 , 066 ( 2005 ) . g. belanger , f. boudjema , a. pukhov and a. semenov , comput . commun . * 174 * , 577 ( 2006 ) ; comput . phys . commun . * 176 * , 367 ( 2007 ) . g. belanger , f. boudjema , c. hugonie , a. pukhov and a. semenov , jcap * 0509 * , 001 ( 2005 ) .""" + + ARTICLE_MAGNET = r"""it is well known that the classical magnetoresistance ( mr ) in metals or semiconductors with a closed free electron fermi surface increases quadratically with increasing magnetic field @xmath2 for @xmath3 and saturates when @xmath4 . here @xmath5 is the zero - magnetic - field mobility . hence , the extraordinarily high and linear mr ( lmr ) , which breaks this familiar rule , has been gaining much attention as soon as its discovery . in the past decade , this unexpected lmr has been reported in silver chalcogenide,@xcite indium antimonide,@xcite silicon,@xcite mnas - gaas composite material,@xcite and graphene.@xcite kapitza s linear law@xcite indicates that the metal shows a magnetoresistance linear in perpendicular magnetic field when it has an open fermi surface and a mean free path longer than the electronic larmor radius . recently , another two models , irrespective of the open fermi surface , have been constructed to provide possible mechanisms for the lmr phenomenon . abrikosov suggested a quantum - limit origin of lmr for the homogenous system with a gapless linear energy spectrum.@xcite his model requires that landau levels are well formed and the carrier concentration is small that all electrons occupy only the lowest landau band . alternatively , parish and littlewood developed a classical model without involving linear spectrum.@xcite ignoring the concrete microscopic mechanism , they attributed this unusual mr to the mobility fluctuations in a strongly inhomogenous system . topological insulators@xcite ( tis ) are novel materials with a full energy gap in bulk , while there are gapless surface states . due to its unique band structure with only one helical dirac cone and linear energy dispersion,@xcite the surface states of the ti bi@xmath0se@xmath1 become an excellent platform for the study of quantum - limit lmr . the recent experiment in this flat surface system , however , reported that a large positive mr , which becomes very linear above a characteristic field of @xmath6@xmath7@xmath8 t , was observed even in an opposite situation where the carrier sheet density is high that electrons occupy more than one landau levels.@xcite moreover , they found that raising temperature to room temperature almost has no influence on the observed lmr . it is striking that this observation is in conflict with abrikosov s model and also with the classical parish - littlewood model . so far a reliable theoretical scheme capable of explaining this novel experiment has still been lacking . in this paper , we generalize the balance - equation approach@xcite to a system modeling the surface states of a three - dimensional ti to investigate the two - dimensional magnetotransport in it . we find that a positive , nonsaturating and dominantly linear magnetoresistance can appear within quite wide magnetic - field range in the ti surface state having a positive and finite effective g - factor . this linear magnetoresistance shows up in the system of high carrier concentration and low mobility when electrons are in extended states and spread over many smeared landau levels , and persists up to room temperature , providing a possible mechanism for the recently observed linear magnetoresistance in topological insulator bi@xmath0se@xmath1 nanoribbons.@xcite we consider the surface state of a bi@xmath0se@xmath1-type large bulk gap ti in the @xmath9-@xmath10 plane under the influence of a uniform magnetic field @xmath11 applied along the @xmath12 direction.@xcite following the experimental observation,@xcite we assume that the fermi energy locates in the gap of the bulk band and above the dirac point , i.e. the surface carriers are electrons . further , the separations of the fermi energy from the bottom of bulk band and dirac point are much larger than the highest temperature ( @xmath13 ) considered in this work . hence , the contribution from the bulk band to the magnetotransport is negligible . these electrons , scattered by randomly distributed impurities and by phonons , are driven by a uniform in - plane electric field @xmath14 in the topological surface . the hamiltonian of this many - electron and phonon system consists of an electron part @xmath15 , a phonon part @xmath16 , and electron - impurity and electron - phonon interactions @xmath17 and @xmath18 : @xmath19 here , the electron hamiltonian is taken in the form @xmath20 , \ ] ] in which @xmath21 , @xmath22 , @xmath23 and @xmath24 , stand , respectively , for the canonical momentum , coordinate , momentum and spin operators of the @xmath25th electron having charge @xmath26 , @xmath27 is the vector potential of the perpendicular magnetic field @xmath28 in the landau gauge , @xmath29 is the fermi velocity , @xmath30 is the effective g - factor of the surface electron , and @xmath31 is the bohr magneton with @xmath32 the free electron mass . the sum index @xmath25 in eq.([helectron ] ) goes over all electrons of total number @xmath33 in the surface state of unit area . in the frame work of balance equation approach,@xcite the two - dimensional center - of - mass ( c.m . ) momentum and coordinate @xmath34 and @xmath35 , and the relative - electron momenta and coordinates @xmath36 and @xmath37 are introduced to write the hamiltonian @xmath15 into the sum of a single - particle c.m . part @xmath38 and a many - particle relative - electron part @xmath39 : @xmath40 , with @xmath41.\end{aligned}\ ] ] in this , @xmath42 is the canonical momentum of the center - of - mass and @xmath43 is the canonical momentum for the @xmath25th relative electron . here we have also introduced c.m . spin operators @xmath44 and @xmath45 . the commutation relations between the c.m . spin operators @xmath46 and @xmath47 and the spin operators @xmath48 , @xmath49 and @xmath50 of the @xmath25th electron are of order of @xmath51 : @xmath52= n^{-1}2\,{\rm i}\,\varepsi lon_{\beta_1\beta_2\beta_3}\sigma_j^{\beta_3}$ ] with @xmath53 . therefore , for a macroscopic large @xmath33 system , the c.m . part @xmath38 actually commutes with the relative - electron part @xmath54 in the hamiltonian , i.e. the c.m . motion and the relative motion of electrons are truly separated from each other . the couplings between the two emerge only through the electron impurity and electron phonon interactions . furthermore , the electric field @xmath55 shows up only in @xmath38 . and , in view of @xmath56={\rm i}\delta_{\alpha \beta}(\delta_{ij}-1/n)\simeq { \rm i}\delta_{\alpha\beta}\delta_{ij}$ ] , i.e. the relative - electron momenta and coordinates can be treated as canonical conjugate variables , the relative - motion part @xmath54 is just the hamiltonian of @xmath33 electrons in the surface state of ti in the magnetic field without the presence of the electric field . in terms of the c.m . coordinate @xmath57 and the relative electron density operator @xmath58 , the electron impurity and electron phonon interactions can be written as@xcite @xmath59 here @xmath60 and @xmath61 are respectively the impurity potential ( an impurity at randomly distributed position @xmath62 ) and electron phonon coupling matrix element in the plane - wave representation , and @xmath63 with @xmath64 and @xmath65 being the creation and annihilation operators for a phonon of wavevector @xmath66 in branch @xmath67 having frequency @xmath68 . velocity ( operator ) @xmath69 is the time variation of its coordinate : @xmath70= v_{\rm f}(\sigma_{\rm c}^y\ , \hat{i}-\sigma_{\rm c}^x\ , \hat{j})$ ] . to derive a force - balance equation for steady state transport we consider the heisenberg equation for the rate of change of the c.m . canonical momentum @xmath71 : @xmath72= - n e({\bm v}\times { \bm b})- n e{\bm e}+{\bm { f}}_{\rm i}+{\bm { f}}_{\rm p},\ ] ] in which the frictional forces @xmath73 and @xmath74 share the same expressions as given in ref .. the statistical average of the operator equation can be determined to linear order in the electron impurity and electron phonon interactions @xmath17 and @xmath18 with the initial density matrix @xmath75 at temperature @xmath76 when the in - plane electric field @xmath77 is not strong . for steady - transport states we have @xmath78 , leading to a force - balance equation of the form @xmath79 here @xmath80 , the statistically averaged velocity of the moving center - of - mass , is identified as the average rate of change of its position , i.e. the drift velocity of the electron system driven by the electric field @xmath77 , and @xmath81 and @xmath82 are frictional forces experienced by the center - of - mass due to impurity and phonon scatterings : @xmath83,\label{fp}\end{aligned}\ ] ] in which @xmath84 is the bose distribution function , @xmath85 , and @xmath86 stands for the imaginary part of the fourier spectrum of the relative - electron density correlation function defined by @xmath87\big\rangle_{0},\ ] ] where @xmath88 and @xmath89 denotes the statistical averaging over the initial density matrix @xmath90.@xcite the force - balance equation describes the steady - state two - dimensional magnetotransport in the surface state of a ti . note that the frictional forces @xmath81 and @xmath82 are in the opposite direction of the drift velocity @xmath91 and their magnitudes are functions of @xmath92 only . with the drift velocity @xmath93 in the @xmath9 direction , the force - balance equation eq . yields a transverse resistivity @xmath94 , and a longitudinal resistivity @xmath95 . the linear one is in the form @xmath96 for calculating the electron density correlation function @xmath97 we proceed in the landau representation.@xcite the landau levels of the single - particle hamiltonian @xmath98 of the relative - electron system in the absence of electric field are composed of a positive `` @xmath99 '' and a negative `` @xmath100 '' branch@xcite @xmath101 with @xmath102 and @xmath103 , and a zero ( @xmath104 ) level @xmath105 the corresponding landau wave functions are @xmath106 and @xmath107 for @xmath108 ; and @xmath109 for @xmath104 . here @xmath110 is the wavevector of the system along @xmath9 direction ; @xmath111 with @xmath112 ; and @xmath113 is the harmonic oscillator eigenfunction with @xmath114 being the hermite polynomial , @xmath115 , and @xmath116 . each landau level contains @xmath117 electron states for system of unit surface area . the positive branch @xmath118 and the @xmath104 level @xmath119 of the above energy spectra are indeed quite close to those of the surface states in the bulk gap of bi@xmath0se@xmath1-family materials derived from microscopic band calculation.@xcite the landau levels are broadened due to impurity , phonon and electron - electron scatterings . we model the imaginary part of the retarded green s function , or the density - of - states , of the broadened landau level @xmath120 ( written for `` + ' ' -branch and @xmath104 levels ) , using a gaussian - type form:@xcite @xmath121,\ ] ] with a half - width @xmath122 of the form:@xcite @xmath123^{1/2}$ ] . here @xmath124 is the single - particle lifetime and @xmath125 is the cyclotron frequency of linear - energy - dispersion system with @xmath126 being the zero - temperature fermi level . using a semi - empirical parameter @xmath127 to relate @xmath124 with the transport scattering time @xmath128 , and expressing @xmath129 with the zero - field mobility @xmath5 at finite temperature,@xcite we can write the landau - level broadening as @xmath130^{1/2}.\ ] ] in the present study we consider the case of @xmath120-doping , i.e. the fermi level is high enough above the energy zero of the dirac cone in the range of `` + ' ' -branch levels and the states of `` @xmath100''-branch levels are completely filled , that they are irrelevant to electron transport . special attention has to be paid to the @xmath104 level , since , depending on the direction of exchange potential the effective g - factor of a ti surface state , @xmath30 , can be positive , zero or negative.@xcite the sign and magnitude of the effective g - factor determines how many states of the zero level should be included in or excluded from the available states for electron occupation in the case of @xmath120-doping at a magnetic field . ( i ) if @xmath131 , the @xmath104 level center is exactly at @xmath132 and the system is electron - hole symmetric . the total number of negative energy states ( including the states of the lower half of the @xmath104 level and states of the @xmath100"-branch levels ) and that of positive energy states ( including the states of the upper half of the @xmath104 level and states of the @xmath99"-branch levels ) do not change when changing magnetic field . therefore , the lower - half negative energy states of this level are always filled and the upper - half positive - energy states of it are available for the occupation of particles which are counted as electrons participating in transport in the case of @xmath120-doping . ( ii ) for a finite positive @xmath133 , the @xmath104 level @xmath134 moves downward to negative energy and its distance to the nearest @xmath100"-branch level is @xmath135 closer than to the nearest + " -branch level at finite magnetic field strength @xmath2 . this is equivalent to the opening of an increasingly enlarged ( with increasing @xmath2 ) energy gap between the + " -branch states and the states of the zero - level and the @xmath100"-branch levels . the opening of a sufficient energy gap implies that with increasing magnetic field the states in the + " -branch levels would no longer shrink into the zero - level , and thus the @xmath104 level should be completely excluded from the conduction band , i.e. only particles occupying the + " -branch states are counted as electrons participating in transport in the case of @xmath120-doping , when the magnetic field @xmath2 gets larger than a certain value ( depending on the magnitude of @xmath30 ) . ( iii ) for a finite negative @xmath136 , the @xmath104 level @xmath134 moves upward to positive energy and an increasingly enlarged energy gap will be opened between the states of the zero - level and the + " -branch and the states of @xmath100"-branch levels , and particles occupying the @xmath104 level and + " -branch states are electrons participating in transport when the magnetic field @xmath2 gets larger than a certain value . as a result , the experimentally accessible sheet density @xmath33 of electrons participating in transport is related to the fermi energy @xmath137 by the following equation valid at finite @xmath30 for the magnetic field @xmath2 larger than a certain value : @xmath138 in which @xmath139 + 1\}^{-1}$ ] is the fermi distribution function at temperature @xmath76 and the summation index @xmath120 goes over @xmath140 for @xmath133 , or @xmath141 for @xmath136 . in the case of @xmath131 , @xmath142\ ] ] valid for arbitrary magnetic field , in which @xmath143 . the imaginary part of relative - electron density correlation function in the presence of a magnetic field , @xmath86 , can be expressed in the landau representation as@xcite @xmath144 in which the transform factor @xmath145 ^ 2,\end{aligned}\ ] ] with @xmath146 , @xmath147 , @xmath148 , and @xmath149 being associated laguerre polynomials . the landau - representation correlation function @xmath150 in eq.([piqw ] ) can be constructed with the imaginary part of the retarded green s function @xmath151 , or the density - of - states , of the @xmath120th landau level as@xcite @xmath152\nonumber\\ & \hspace{1.2cm}\times{\rm im}g_n(\epsilon+\omega){\rm im}g_{n'}(\epsilon).\end{aligned}\ ] ] the summation indices @xmath120 and @xmath153 in eq.([piqw ] ) are taken over @xmath140 for @xmath133 , or @xmath154 for @xmath136 . in the case of @xmath131 , eq.([piqw ] ) still works and the summation indices @xmath120 and @xmath153 go over @xmath154 but with @xmath155 replaced by @xmath156 in eq.([p2nn ] ) . numerical calculations are performed for the magnetoresistivity @xmath157 of surface state in a uniform ti bi@xmath0se@xmath1 . at zero temperature the elastic scattering contributing to the resistivity is modeled by a coulomb potential due to charged impurities:@xcite @xmath158 with @xmath159 being the impurity density , which is determined by the zero - magnetic - field mobility @xmath5 . at temperatures higher than @xmath160,@xcite phonon scatterings play increasingly important role and the dominant inelastic contribution comes from optical phonons . for this polar material , the scattering by optical phonons via the deformation potential can be neglected . hence , we take account of inelastic scattering from optical phonons via frhlich coupling : @xmath161 . in the numerical calculation we use the following parameters:@xcite fermi velocity @xmath162 , static dielectric constant @xmath163 , optical dielectric constant @xmath164 , and phonon energy @xmath165 . the broadening parameter is taken to be @xmath166 . as a function of the magnetic field @xmath2 having different effective g - factors : @xmath167 and @xmath168 for a ti surface system with electron sheet density @xmath169 in the cases of zero - magnetic - field mobility @xmath170 ( a ) and @xmath171 ( b ) . several integer - number positions of filling factor @xmath172 are marked in ( b).,scaledwidth=40.0% ] fig.[diffg ] shows the calculated magnetoresistivity @xmath157 versus the magnetic field strength @xmath2 for a ti surface system with electron sheet density @xmath169 but having different effective g - factors : @xmath167 and @xmath168 for two values of zero - magnetic - field mobility @xmath170 and @xmath171 , representing different degree of landau - level broadening . in the case without zeeman splitting ( @xmath131 ) the resistivity @xmath157 exhibits almost no change with changing magnetic field up to 10 t , except the shubnikov - de haas ( sdh ) oscillation showing up in the case of @xmath171 . this kind of magnetoresistance behavior was indeed seen experimentally in the electron - hole symmetrical massless system of single - layer graphene.@xcite in the case of a positive g - factor , @xmath173 , the magnetoresistivity increases linearly with increasing magnetic field ; while for a negative g - factor , @xmath174 , the magnetoresistivity decreases linearly with increasing magnetic field . is shown as a function of the magnetic field @xmath2 for different values of zero - magnetic - field mobility : ( a ) @xmath175 , ( b ) @xmath176 , ( c ) @xmath177 , ( d ) @xmath178 , ( e ) @xmath179 , and ( f ) @xmath180 . the inset of ( a ) illustrates the same for a larger magnetic - field range @xmath181 . the filling factor @xmath182 is plotted versus the magnetic field in ( f ) ; and several integer - number positions of @xmath182 are also marked in ( d ) and ( e ) . here the surface electron density @xmath169 and the lattice temperature @xmath183.,scaledwidth=47.0% ] in the following we will give more detailed examination on the linearly increasing magnetoresistance in the positive @xmath30 case . fig.[rhob ] shows the calculated resistivity @xmath157 versus the magnetic field strength @xmath2 at lattice temperature @xmath183 for system of carrier sheet density @xmath169 and @xmath173 , having different zero - field mobility @xmath184 and @xmath180 . all resistivity curves for mobility @xmath185 exhibit clear linearity in the magnetic - field range and appear no tendency of saturation at the highest field shown in the figure . especially , for the case @xmath170 , the linear behavior extends even up to the magnetic field of @xmath186 , as illustrated in the inset of fig.[rhob](a ) . this feature contradicts the classical mr which saturates at sufficiently large magnetic field @xmath187 . note that here we only present the calculated @xmath157 for magnetic field @xmath2 larger than @xmath188 t , for which a sufficient energy gap @xmath135 is assumed to open that with further increase of the magnetic field the states in the `` + ' ' -branch levels no longer shrink into the zero level and thus it should be excluded from the conduction band . this is of course not true for very weak magnetic field . when @xmath189 the energy gap @xmath190 , the situation becomes similar to the case of @xmath131 : the whole upper half of the zero - level states are available to electron occupation and we should have a flat resistivity @xmath157 when changing magnetic field . with increasing @xmath2 the portion of the zero - level states available to conduction electrons decreases until the magnetic field reaches @xmath191 . as a result the resistivity @xmath157 should exhibit a crossover from a flat changing at small @xmath2 to positively linear increasing at @xmath192 . this is just the behavior observed in the ti bi@xmath0se@xmath1.@xcite note that in the case of @xmath170 , the broadened landau - level widths are always larger than the neighboring level interval : @xmath193 , which requires @xmath194 ^ 2 $ ] , even for the lowest landau level @xmath195 , i.e. the whole landau - level spectrum is smeared . with increasing the zero - field mobility the magnitude of resistivity @xmath157 decreases , and when the broadened landau - level width becomes smaller than the neighboring level interval , @xmath196 , a weak sdh oscillation begin to occur around the linearly - dependent average value of @xmath157 at higher portion of the magnetic field range , as seen in fig.[rhob](c ) , ( d ) and ( e ) for @xmath197 and @xmath198 . on the other hand , in the case of large mobility , e.g. @xmath199 , where the broadened landau - level widths @xmath200 are much smaller than the neighboring level interval even for level index @xmath120 as large as @xmath201 , the magnetoresistivity shows pronounced sdh oscillation and the linear - dependent behavior disappears , before the appearance of quantum hall effect,@xcite as shown in fig.[rhob](f ) . abrikosov s model for the lmr requires the applied magnetic field large enough to reach the quantum limit at which all the carriers are within the lowest landau level,@xcite while it is obvious that more than one landau levels are occupied in the experimental samples in the field range in which the linear and non - saturating magnetoresistivity was observed.@xcite for the given electron surface density @xmath202 , the number of occupied landau levels , or the filling factor @xmath172 , at different magnetic fields is shown in fig.[rhob](f ) , as well as in the fig.[rhob](d ) and ( e ) , where the integer - number positions of @xmath203 , i.e. filling up to entire @xmath182 landau levels , coincide with the minima of the density - of - states or the dips of sdh oscillation . this is in contrast with @xmath131 case , where the integer number of @xmath203 , which implies a filling up to the center position of the @xmath182th landau levels , locates at a peak of sdh oscillation , as shown in fig.[diffg]b . the observed sdh oscillations in the bi@xmath0se@xmath1 nanoribbon exhibiting nonsaturating surface lmr in the experiment@xcite favor the former case : a finite positive effective @xmath133 . is plotted as a function of the surface electron density @xmath33 at magnetic field @xmath204 : ( a ) at different values of zero - field mobility @xmath5 , and ( b ) at different values of zero - field conductivity @xmath205.,scaledwidth=40.0% ] at various lattice temperatures . here the zero - magnetic - field mobility at zero temperature is @xmath206.,scaledwidth=35.0% ] next , we examine the density - dependence of the linear magnetoresistivity . to compare with abrikosov s quantum magnetoresistance which suggests a @xmath207 behavior,@xcite we show the calculated @xmath208 for above lmr versus the carrier sheet density @xmath33 in fig.[rhon ] at fixed magnetic field @xmath209 t . the mobility is taken respectively to be @xmath210 and @xmath211m@xmath212/vs to make the resistivity in the lmr regime . a clearly linear dependence of @xmath213 on the surface density @xmath33 is seen in all cases , indicating that this non - saturating linear resistivity is almost inversely proportional to the carrier density . in the figure we also show @xmath208 versus @xmath33 under the condition of different given conductivity @xmath214 and @xmath215 . in this case the half - width @xmath216 is independent of surface density . the linear dependence still holds , indicating that this linear behavior is not sensitive to the modest @xmath33-dependence of landau level broadening @xmath216 as long as the system is in the overlapped landau level regime . from the above discussion , it is obvious that lmr shows up in the system having overlapped landau levels and the separation of landau levels makes the mr departure from the linear increase . at high temperature , the thermal energy would smear the level separation and phonon scatterings further broaden landau levels . hence , it is believed that this lmr will be robust against raising temperature . this is indeed the case as seen in fig.[rhot ] , where we plot the calculated magnetoresistivity @xmath157 for the above system with zero - temperature linear mobility @xmath217m@xmath212/vs versus the magnetic field at different lattice temperatures . we can see that raising temperature to room temperature has little effect on the linearity of mr . due to the decreased mobility at higher temperature from phonon scattering , the weak sdh oscillation on the linear background tends to vanish . these features are in good agreement with the experimental report.@xcite in summary , we have studied the two - dimensional magnetotransport in the flat surface of a three - dimensional ti , which arises from the surface states with a wavevector - linear energy dispersion and a finite , positive zeeman splitting within the bulk energy gap . when the level broadening is comparable to or larger than the landau - level separation and the conduction electrons spread over many landau levels , a positive , dominantly linear and non - saturating magnetoresistance appears within a quite wide range of magnetic field and persists up to room temperature . this remarkable lmr provides a possible mechanism for the recently observed linear magnetoresistance in topological insulator bi@xmath0se@xmath1 nanoribbons.@xcite in contrast to quantum hall effect which appears in the case of well formed landau levels and to abrikosov s quantum magnetotransport,@xcite which is limited to the extreme quantum limit that all electrons coalesce into the lowest landau level , the discussed lmr is a phenomena of pure classical two - dimensional magnetotransport in a system having linear - energy - dispersion , appearing in the regime of overlapped landau levels , irrespective of its showing up in relatively high magnetic field range . furthermore , the present scheme deals with spatially uniform case without invoking the mobility fluctuation in a strongly inhomogeneous system , which is required in the classical parish and littlewood model to produce a lmr.@xcite the appearance of this significant positive - increasing linear magnetoresistance depends on the existence of a positive and sizable effective g - factor . if the zeeman energy splitting is quite small the resistivity @xmath157 would exhibit little change with changing magnetic field . in the case of a negative and sizable effective g - factor the magnetoresistivity would decrease linearly with increasing magnetic field . therefore , the behavior of the longitudinal resistivity versus magnetic field may provide a useful way for judging the direction and the size of the effective zeeman energy splitting in ti surface states . this work was supported by the national science foundation of china ( grant no . 11104002 ) , the national basic research program of china ( grant no . 2012cb927403 ) and by the program for science&technology innovation talents in universities of henan province ( grant no . 2012hastit029 ) .""" + + inputs = tokenizer( + [ARTICLE_LEP, ARTICLE_MAGNET], + max_length=1024, + padding="max_length", + truncation=True, + return_tensors="pt", + ) + inputs = {k: inputs[k].to(torch_device) for k in inputs} + + hypotheses_batch = model.generate(**inputs) + + EXPECTED_LEP = "motivated by some recent studies on the light cp - odd higgs boson @xmath0 in non - minimal supersymmetric models, we investigate the rare @xmath1-decays @xmath2 ( @xmath3 ) in the two higgs doublet model ( 2hdm ), the nearly minimal supersymmetric standard model ( nmssm ), the next - to - minimal supersymmetric standard model ( nmssm ) and the minimal supersymmetric standard model ( mssm ). we find that the branching ratios of @xmath4 can reach @xmath5 in 2hdm, @xmath6 in nmssm and @xmath7 in mssm, which are at the level of @xmath8 in 2hdm, @xmath9 in nmssm and @xmath10 in mssm, respectively. these rates can be significantly enhanced in new physics models which lie within the expected sensitivity of the gigaz option of the international linear collider ( ilc ). = # 1,nucl. phys. b * # 1" + + EXPECTED_MAGNET = "a positive, nonsaturating and dominantly linear magnetoresistance can appear within quite wide magnetic - field range in the surface state of a topological insulator having a positive and finite effective g - factor. this linear magnetoresistance shows up in the system of high carrier concentration and low mobility when electrons are in extended states and spread over many smeared landau levels, and persists up to room temperature, providing a possible mechanism for the recently observed linear magnetoresistance in topological insulator bi@xmath0se@xmath1 nanoribbons." + + generated = tokenizer.batch_decode( + hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True + ) + + self.assertTrue(generated == [EXPECTED_LEP, EXPECTED_MAGNET]) + + +class BigBirdPegasusStandaloneDecoderModelTester: + def __init__( + self, + parent, + vocab_size=99, + batch_size=7, + d_model=32, + decoder_seq_length=7, + is_training=True, + is_decoder=True, + use_attention_mask=True, + use_cache=False, + use_labels=True, + decoder_start_token_id=2, + decoder_ffn_dim=32, + decoder_layers=4, + encoder_attention_heads=4, + decoder_attention_heads=4, + max_position_embeddings=30, + is_encoder_decoder=False, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + scope=None, + attention_type="original_full", + use_bias=True, + block_size=16, + num_random_blocks=3, + ): + self.parent = parent + self.batch_size = batch_size + self.decoder_seq_length = decoder_seq_length + # For common tests + self.seq_length = self.decoder_seq_length + self.is_training = is_training + self.use_attention_mask = use_attention_mask + self.use_labels = use_labels + + self.vocab_size = vocab_size + self.d_model = d_model + self.hidden_size = d_model + self.num_hidden_layers = decoder_layers + self.decoder_layers = decoder_layers + self.decoder_ffn_dim = decoder_ffn_dim + self.encoder_attention_heads = encoder_attention_heads + self.decoder_attention_heads = decoder_attention_heads + self.num_attention_heads = decoder_attention_heads + self.eos_token_id = eos_token_id + self.bos_token_id = bos_token_id + self.pad_token_id = pad_token_id + self.decoder_start_token_id = decoder_start_token_id + self.use_cache = use_cache + self.max_position_embeddings = max_position_embeddings + self.is_encoder_decoder = is_encoder_decoder + + self.scope = None + self.decoder_key_length = decoder_seq_length + self.base_model_out_len = 2 + self.decoder_attention_idx = 1 + + self.attention_type = attention_type + self.use_bias = use_bias + self.block_size = block_size + self.num_random_blocks = num_random_blocks + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + attention_mask = None + if self.use_attention_mask: + attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2) + + lm_labels = None + if self.use_labels: + lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + config = BigBirdPegasusConfig( + vocab_size=self.vocab_size, + d_model=self.d_model, + decoder_layers=self.decoder_layers, + decoder_ffn_dim=self.decoder_ffn_dim, + encoder_attention_heads=self.encoder_attention_heads, + decoder_attention_heads=self.decoder_attention_heads, + eos_token_id=self.eos_token_id, + bos_token_id=self.bos_token_id, + use_cache=self.use_cache, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.decoder_start_token_id, + max_position_embeddings=self.max_position_embeddings, + is_encoder_decoder=self.is_encoder_decoder, + attention_type=self.attention_type, + use_bias=self.use_bias, + block_size=self.block_size, + num_random_blocks=self.num_random_blocks, + ) + + return ( + config, + input_ids, + attention_mask, + lm_labels, + ) + + def create_and_check_decoder_model_past( + self, + config, + input_ids, + attention_mask, + lm_labels, + ): + config.use_cache = True + model = BigBirdPegasusDecoder(config=config).to(torch_device).eval() + # first forward pass + outputs = model(input_ids, use_cache=True) + outputs_use_cache_conf = model(input_ids) + outputs_no_past = model(input_ids, use_cache=False) + + self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) + self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) + + past_key_values = outputs["past_key_values"] + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + + output_from_no_past = model(next_input_ids)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3) + + def create_and_check_decoder_model_attention_mask_past( + self, + config, + input_ids, + attention_mask, + lm_labels, + ): + model = BigBirdPegasusDecoder(config=config).to(torch_device).eval() + + # create attention mask + attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + + half_seq_length = input_ids.shape[-1] // 2 + attn_mask[:, half_seq_length:] = 0 + + # first forward pass + past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"] + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # change a random masked slice from input_ids + random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1 + random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1) + input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens + + # append to next input_ids and attn_mask + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + attn_mask = torch.cat( + [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], + dim=1, + ) + + # get two different outputs + output_from_no_past = model(next_input_ids)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + # big bird has extremely high logits which requires + # such a high error tolerance here + assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=5e-1) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, input_ids, attention_mask, lm_labels = config_and_inputs + + inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask} + return config, inputs_dict + + +@require_torch +class BigBirdPegasusStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = (BigBirdPegasusDecoder, BigBirdPegasusForCausalLM) if is_torch_available() else () + all_generative_model_classes = (BigBirdPegasusForCausalLM,) if is_torch_available() else () + test_pruning = False + is_encoder_decoder = False + + def setUp( + self, + ): + self.model_tester = BigBirdPegasusStandaloneDecoderModelTester(self, is_training=False) + self.config_tester = ConfigTester(self, config_class=BigBirdPegasusConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_decoder_model_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past(*config_and_inputs) + + def test_decoder_model_attn_mask_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs) + + def test_retain_grad_hidden_states_attentions(self): + # decoder cannot keep gradients + return diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index a98d406d2f9..19469075adc 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -1043,7 +1043,6 @@ class ModelTesterMixin: with tempfile.TemporaryDirectory() as temp_dir_name: model.base_model.save_pretrained(temp_dir_name) model, loading_info = model_class.from_pretrained(temp_dir_name, output_loading_info=True) - with self.subTest(msg=f"Missing keys for {model.__class__.__name__}"): self.assertGreater(len(loading_info["missing_keys"]), 0) diff --git a/tests/test_tokenization_pegasus.py b/tests/test_tokenization_pegasus.py index c9ee3ee09e1..0db2d34cd7f 100644 --- a/tests/test_tokenization_pegasus.py +++ b/tests/test_tokenization_pegasus.py @@ -55,7 +55,6 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase): raw_input_str = "Let's see which is the better one It seems like this was important " rust_ids = rust_tokenizer([raw_input_str], return_tensors=None, add_special_tokens=False).input_ids[0] py_ids = py_tokenizer([raw_input_str], return_tensors=None, add_special_tokens=False).input_ids[0] - # TODO: (Thom, Patrick) - this fails because the rust tokenizer does not know about the , , and those yet self.assertListEqual(py_ids, rust_ids) def test_large_mask_tokens(self): @@ -96,3 +95,81 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase): assert batch.attention_mask.shape == (2, 1024) assert targets["input_ids"].shape == (2, 5) assert len(batch) == 2 # input_ids, attention_mask. + + +@require_sentencepiece +@require_tokenizers +class BigBirdPegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase): + + tokenizer_class = PegasusTokenizer + rust_tokenizer_class = PegasusTokenizerFast + test_rust_tokenizer = True + + def setUp(self): + super().setUp() + + # We have a SentencePiece fixture for testing + tokenizer = PegasusTokenizer(SAMPLE_VOCAB, offset=0, mask_token_sent=None, mask_token="[MASK]") + tokenizer.save_pretrained(self.tmpdirname) + + @cached_property + def _large_tokenizer(self): + return PegasusTokenizer.from_pretrained("google/bigbird-pegasus-large-arxiv") + + def get_tokenizer(self, **kwargs) -> PegasusTokenizer: + return PegasusTokenizer.from_pretrained(self.tmpdirname, **kwargs) + + def get_input_output_texts(self, tokenizer): + return ("This is a test", "This is a test") + + def test_mask_tokens_rust_pegasus(self): + rust_tokenizer = self.rust_tokenizer_class.from_pretrained(self.tmpdirname) + py_tokenizer = self.tokenizer_class.from_pretrained(self.tmpdirname) + raw_input_str = "Let's see which is the better one [MASK] It seems like this [MASK] was important " + rust_ids = rust_tokenizer([raw_input_str], return_tensors=None, add_special_tokens=False).input_ids[0] + py_ids = py_tokenizer([raw_input_str], return_tensors=None, add_special_tokens=False).input_ids[0] + self.assertListEqual(py_ids, rust_ids) + + @require_torch + def test_large_seq2seq_truncation(self): + src_texts = ["This is going to be way too long." * 1000, "short example"] + tgt_texts = ["not super long but more than 5 tokens", "tiny"] + batch = self._large_tokenizer(src_texts, padding=True, truncation=True, return_tensors="pt") + with self._large_tokenizer.as_target_tokenizer(): + targets = self._large_tokenizer( + tgt_texts, max_length=5, padding=True, truncation=True, return_tensors="pt" + ) + + assert batch.input_ids.shape == (2, 4096) + assert batch.attention_mask.shape == (2, 4096) + assert targets["input_ids"].shape == (2, 5) + assert len(batch) == 2 # input_ids, attention_mask. + + def test_equivalence_to_orig_tokenizer(self): + """ + To run with original TF tokenizer: + + !wget https://github.com/google-research/bigbird/raw/master/bigbird/vocab/pegasus.model + !pip install tensorflow-text + + import tensorflow.compat.v2 as tf + import tensorflow_text as tft + + VOCAB_FILE = "./pegasus.model" + + tf.enable_v2_behavior() + + test_str = "This is an example string that is used to test the original TF implementation against the HF implementation" + tokenizer = tft.SentencepieceTokenizer(model=tf.io.gfile.GFile(VOCAB_FILE, "rb").read()) + + tokenizer.tokenize(test_str) + """ + + test_str = "This is an example string that is used to test the original TF implementation against the HF implementation" + + token_ids = self._large_tokenizer(test_str).input_ids + + self.assertListEqual( + token_ids, + [182, 117, 142, 587, 4211, 120, 117, 263, 112, 804, 109, 856, 25016, 3137, 464, 109, 26955, 3137, 1], + ) diff --git a/utils/check_repo.py b/utils/check_repo.py index c368ddd5b2e..0077fcc7e6b 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -35,6 +35,9 @@ PATH_TO_DOC = "docs/source" # Being in this list is an exception and should **not** be the rule. IGNORE_NON_TESTED = [ # models to ignore for not tested + "BigBirdPegasusEncoder", # Building part of bigger (tested) model. + "BigBirdPegasusDecoder", # Building part of bigger (tested) model. + "BigBirdPegasusDecoderWrapper", # Building part of bigger (tested) model. "M2M100Encoder", # Building part of bigger (tested) model. "M2M100Decoder", # Building part of bigger (tested) model. "Speech2TextEncoder", # Building part of bigger (tested) model.