diff --git a/README.md b/README.md index 1995e92665c..fce047a1fa1 100644 --- a/README.md +++ b/README.md @@ -211,6 +211,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (from Google Research and the Toyota Technological Institute at Chicago) released with the paper [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942), by Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut. 1. **[BART](https://huggingface.co/transformers/model_doc/bart.html)** (from Facebook) released with the paper [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/pdf/1910.13461.pdf) by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer. 1. **[BARThez](https://huggingface.co/transformers/model_doc/barthez.html)** (from École polytechnique) released with the paper [BARThez: a Skilled Pretrained French Sequence-to-Sequence Model](https://arxiv.org/abs/2010.12321) by Moussa Kamal Eddine, Antoine J.-P. Tixier, Michalis Vazirgiannis. +1. **[BEiT](https://huggingface.co/transformers/master/model_doc/beit.html)** (from Microsoft) released with the paper [BEiT: BERT Pre-Training of Image Transformers](https://arxiv.org/abs/2106.08254) by Hangbo Bao, Li Dong, Furu Wei. 1. **[BERT](https://huggingface.co/transformers/model_doc/bert.html)** (from Google) released with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. 1. **[BERT For Sequence Generation](https://huggingface.co/transformers/model_doc/bertgeneration.html)** (from Google) released with the paper [Leveraging Pre-trained Checkpoints for Sequence Generation Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. 1. **[BigBird-RoBERTa](https://huggingface.co/transformers/model_doc/bigbird.html)** (from Google Research) released with the paper [Big Bird: Transformers for Longer Sequences](https://arxiv.org/abs/2007.14062) by Manzil Zaheer, Guru Guruganesh, Avinava Dubey, Joshua Ainslie, Chris Alberti, Santiago Ontanon, Philip Pham, Anirudh Ravula, Qifan Wang, Li Yang, Amr Ahmed. diff --git a/docs/source/index.rst b/docs/source/index.rst index c6dad77df6b..8bf77a679f6 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -105,190 +105,193 @@ Supported models 3. :doc:`BARThez ` (from École polytechnique) released with the paper `BARThez: a Skilled Pretrained French Sequence-to-Sequence Model `__ by Moussa Kamal Eddine, Antoine J.-P. Tixier, Michalis Vazirgiannis. -4. :doc:`BERT ` (from Google) released with the paper `BERT: Pre-training of Deep Bidirectional +4. `BEiT `__ (from Microsoft) released with the paper + `BEiT: BERT Pre-Training of Image Transformers `__ by Hangbo Bao, Li Dong, Furu + Wei. +5. :doc:`BERT ` (from Google) released with the paper `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding `__ by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. -5. :doc:`BERT For Sequence Generation ` (from Google) released with the paper `Leveraging +6. :doc:`BERT For Sequence Generation ` (from Google) released with the paper `Leveraging Pre-trained Checkpoints for Sequence Generation Tasks `__ by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. -6. :doc:`BigBird-RoBERTa ` (from Google Research) released with the paper `Big Bird: Transformers +7. :doc:`BigBird-RoBERTa ` (from Google Research) released with the paper `Big Bird: Transformers for Longer Sequences `__ by Manzil Zaheer, Guru Guruganesh, Avinava Dubey, Joshua Ainslie, Chris Alberti, Santiago Ontanon, Philip Pham, Anirudh Ravula, Qifan Wang, Li Yang, Amr Ahmed. -7. :doc:`BigBird-Pegasus ` (from Google Research) released with the paper `Big Bird: +8. :doc:`BigBird-Pegasus ` (from Google Research) released with the paper `Big Bird: Transformers for Longer Sequences `__ by Manzil Zaheer, Guru Guruganesh, Avinava Dubey, Joshua Ainslie, Chris Alberti, Santiago Ontanon, Philip Pham, Anirudh Ravula, Qifan Wang, Li Yang, Amr Ahmed. -8. :doc:`Blenderbot ` (from Facebook) released with the paper `Recipes for building an +9. :doc:`Blenderbot ` (from Facebook) released with the paper `Recipes for building an open-domain chatbot `__ by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston. -9. :doc:`BlenderbotSmall ` (from Facebook) released with the paper `Recipes for building an - open-domain chatbot `__ by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary - Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston. -10. :doc:`BORT ` (from Alexa) released with the paper `Optimal Subarchitecture Extraction For BERT +10. :doc:`BlenderbotSmall ` (from Facebook) released with the paper `Recipes for building + an open-domain chatbot `__ by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, + Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston. +11. :doc:`BORT ` (from Alexa) released with the paper `Optimal Subarchitecture Extraction For BERT `__ by Adrian de Wynter and Daniel J. Perry. -11. :doc:`ByT5 ` (from Google Research) released with the paper `ByT5: Towards a token-free future with +12. :doc:`ByT5 ` (from Google Research) released with the paper `ByT5: Towards a token-free future with pre-trained byte-to-byte models `__ by Linting Xue, Aditya Barua, Noah Constant, Rami Al-Rfou, Sharan Narang, Mihir Kale, Adam Roberts, Colin Raffel. -12. :doc:`CamemBERT ` (from Inria/Facebook/Sorbonne) released with the paper `CamemBERT: a Tasty +13. :doc:`CamemBERT ` (from Inria/Facebook/Sorbonne) released with the paper `CamemBERT: a Tasty French Language Model `__ by Louis Martin*, Benjamin Muller*, Pedro Javier Ortiz Suárez*, Yoann Dupont, Laurent Romary, Éric Villemonte de la Clergerie, Djamé Seddah and Benoît Sagot. -13. :doc:`CANINE ` (from Google Research) released with the paper `CANINE: Pre-training an Efficient +14. :doc:`CANINE ` (from Google Research) released with the paper `CANINE: Pre-training an Efficient Tokenization-Free Encoder for Language Representation `__ by Jonathan H. Clark, Dan Garrette, Iulia Turc, John Wieting. -14. :doc:`CLIP ` (from OpenAI) released with the paper `Learning Transferable Visual Models From +15. :doc:`CLIP ` (from OpenAI) released with the paper `Learning Transferable Visual Models From Natural Language Supervision `__ by Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh, Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, Ilya Sutskever. -15. :doc:`ConvBERT ` (from YituTech) released with the paper `ConvBERT: Improving BERT with +16. :doc:`ConvBERT ` (from YituTech) released with the paper `ConvBERT: Improving BERT with Span-based Dynamic Convolution `__ by Zihang Jiang, Weihao Yu, Daquan Zhou, Yunpeng Chen, Jiashi Feng, Shuicheng Yan. -16. :doc:`CPM ` (from Tsinghua University) released with the paper `CPM: A Large-scale Generative +17. :doc:`CPM ` (from Tsinghua University) released with the paper `CPM: A Large-scale Generative Chinese Pre-trained Language Model `__ by Zhengyan Zhang, Xu Han, Hao Zhou, Pei Ke, Yuxian Gu, Deming Ye, Yujia Qin, Yusheng Su, Haozhe Ji, Jian Guan, Fanchao Qi, Xiaozhi Wang, Yanan Zheng, Guoyang Zeng, Huanqi Cao, Shengqi Chen, Daixuan Li, Zhenbo Sun, Zhiyuan Liu, Minlie Huang, Wentao Han, Jie Tang, Juanzi Li, Xiaoyan Zhu, Maosong Sun. -17. :doc:`CTRL ` (from Salesforce) released with the paper `CTRL: A Conditional Transformer Language +18. :doc:`CTRL ` (from Salesforce) released with the paper `CTRL: A Conditional Transformer Language Model for Controllable Generation `__ by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher. -18. :doc:`DeBERTa ` (from Microsoft) released with the paper `DeBERTa: Decoding-enhanced BERT with +19. :doc:`DeBERTa ` (from Microsoft) released with the paper `DeBERTa: Decoding-enhanced BERT with Disentangled Attention `__ by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. -19. :doc:`DeBERTa-v2 ` (from Microsoft) released with the paper `DeBERTa: Decoding-enhanced BERT +20. :doc:`DeBERTa-v2 ` (from Microsoft) released with the paper `DeBERTa: Decoding-enhanced BERT with Disentangled Attention `__ by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. -20. :doc:`DeiT ` (from Facebook) released with the paper `Training data-efficient image transformers & +21. :doc:`DeiT ` (from Facebook) released with the paper `Training data-efficient image transformers & distillation through attention `__ by Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, Hervé Jégou. -21. :doc:`DETR ` (from Facebook) released with the paper `End-to-End Object Detection with Transformers +22. :doc:`DETR ` (from Facebook) released with the paper `End-to-End Object Detection with Transformers `__ by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov, Sergey Zagoruyko. -22. :doc:`DialoGPT ` (from Microsoft Research) released with the paper `DialoGPT: Large-Scale +23. :doc:`DialoGPT ` (from Microsoft Research) released with the paper `DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation `__ by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan. -23. :doc:`DistilBERT ` (from HuggingFace), released together with the paper `DistilBERT, a +24. :doc:`DistilBERT ` (from HuggingFace), released together with the paper `DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter `__ by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into `DistilGPT2 `__, RoBERTa into `DistilRoBERTa `__, Multilingual BERT into `DistilmBERT `__ and a German version of DistilBERT. -24. :doc:`DPR ` (from Facebook) released with the paper `Dense Passage Retrieval for Open-Domain +25. :doc:`DPR ` (from Facebook) released with the paper `Dense Passage Retrieval for Open-Domain Question Answering `__ by Vladimir Karpukhin, Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. -25. :doc:`ELECTRA ` (from Google Research/Stanford University) released with the paper `ELECTRA: +26. :doc:`ELECTRA ` (from Google Research/Stanford University) released with the paper `ELECTRA: Pre-training text encoders as discriminators rather than generators `__ by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning. -26. :doc:`FlauBERT ` (from CNRS) released with the paper `FlauBERT: Unsupervised Language Model +27. :doc:`FlauBERT ` (from CNRS) released with the paper `FlauBERT: Unsupervised Language Model Pre-training for French `__ by Hang Le, Loïc Vial, Jibril Frej, Vincent Segonne, Maximin Coavoux, Benjamin Lecouteux, Alexandre Allauzen, Benoît Crabbé, Laurent Besacier, Didier Schwab. -27. :doc:`Funnel Transformer ` (from CMU/Google Brain) released with the paper `Funnel-Transformer: +28. :doc:`Funnel Transformer ` (from CMU/Google Brain) released with the paper `Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing `__ by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le. -28. :doc:`GPT ` (from OpenAI) released with the paper `Improving Language Understanding by Generative +29. :doc:`GPT ` (from OpenAI) released with the paper `Improving Language Understanding by Generative Pre-Training `__ by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever. -29. :doc:`GPT-2 ` (from OpenAI) released with the paper `Language Models are Unsupervised Multitask +30. :doc:`GPT-2 ` (from OpenAI) released with the paper `Language Models are Unsupervised Multitask Learners `__ by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**. -30. :doc:`GPT Neo ` (from EleutherAI) released in the repository `EleutherAI/gpt-neo +31. :doc:`GPT Neo ` (from EleutherAI) released in the repository `EleutherAI/gpt-neo `__ by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy. -31. :doc:`Hubert ` (from Facebook) released with the paper `HuBERT: Self-Supervised Speech +32. :doc:`Hubert ` (from Facebook) released with the paper `HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units `__ by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, Ruslan Salakhutdinov, Abdelrahman Mohamed. -32. :doc:`I-BERT ` (from Berkeley) released with the paper `I-BERT: Integer-only BERT Quantization +33. :doc:`I-BERT ` (from Berkeley) released with the paper `I-BERT: Integer-only BERT Quantization `__ by Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney, Kurt Keutzer -33. :doc:`LayoutLM ` (from Microsoft Research Asia) released with the paper `LayoutLM: Pre-training +34. :doc:`LayoutLM ` (from Microsoft Research Asia) released with the paper `LayoutLM: Pre-training of Text and Layout for Document Image Understanding `__ by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou. -34. :doc:`LED ` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer +35. :doc:`LED ` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer `__ by Iz Beltagy, Matthew E. Peters, Arman Cohan. -35. :doc:`Longformer ` (from AllenAI) released with the paper `Longformer: The Long-Document +36. :doc:`Longformer ` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer `__ by Iz Beltagy, Matthew E. Peters, Arman Cohan. -36. :doc:`LUKE ` (from Studio Ousia) released with the paper `LUKE: Deep Contextualized Entity +37. :doc:`LUKE ` (from Studio Ousia) released with the paper `LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention `__ by Ikuya Yamada, Akari Asai, Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto. -37. :doc:`LXMERT ` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality +38. :doc:`LXMERT ` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering `__ by Hao Tan and Mohit Bansal. -38. :doc:`M2M100 ` (from Facebook) released with the paper `Beyond English-Centric Multilingual +39. :doc:`M2M100 ` (from Facebook) released with the paper `Beyond English-Centric Multilingual Machine Translation `__ by by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. -39. :doc:`MarianMT ` Machine translation models trained using `OPUS `__ data by +40. :doc:`MarianMT ` Machine translation models trained using `OPUS `__ data by Jörg Tiedemann. The `Marian Framework `__ is being developed by the Microsoft Translator Team. -40. :doc:`MBart ` (from Facebook) released with the paper `Multilingual Denoising Pre-training for +41. :doc:`MBart ` (from Facebook) released with the paper `Multilingual Denoising Pre-training for Neural Machine Translation `__ by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. -41. :doc:`MBart-50 ` (from Facebook) released with the paper `Multilingual Translation with Extensible +42. :doc:`MBart-50 ` (from Facebook) released with the paper `Multilingual Translation with Extensible Multilingual Pretraining and Finetuning `__ by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan. -42. :doc:`Megatron-BERT ` (from NVIDIA) released with the paper `Megatron-LM: Training +43. :doc:`Megatron-BERT ` (from NVIDIA) released with the paper `Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism `__ by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro. -43. :doc:`Megatron-GPT2 ` (from NVIDIA) released with the paper `Megatron-LM: Training +44. :doc:`Megatron-GPT2 ` (from NVIDIA) released with the paper `Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism `__ by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro. -44. :doc:`MPNet ` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted +45. :doc:`MPNet ` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted Pre-training for Language Understanding `__ by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu. -45. :doc:`MT5 ` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained +46. :doc:`MT5 ` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained text-to-text transformer `__ by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel. -46. :doc:`Pegasus ` (from Google) released with the paper `PEGASUS: Pre-training with Extracted +47. :doc:`Pegasus ` (from Google) released with the paper `PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization `__> by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. -47. :doc:`ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting +48. :doc:`ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training `__ by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. -48. :doc:`Reformer ` (from Google Research) released with the paper `Reformer: The Efficient +49. :doc:`Reformer ` (from Google Research) released with the paper `Reformer: The Efficient Transformer `__ by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. -49. :doc:`RemBERT ` (from Google Research) released with the paper `Rethinking embedding coupling in +50. :doc:`RemBERT ` (from Google Research) released with the paper `Rethinking embedding coupling in pre-trained language models `__ by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder. -50. :doc:`RoBERTa ` (from Facebook), released together with the paper a `Robustly Optimized BERT +51. :doc:`RoBERTa ` (from Facebook), released together with the paper a `Robustly Optimized BERT Pretraining Approach `__ by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. -51. :doc:`RoFormer ` (from ZhuiyiTechnology), released together with the paper a `RoFormer: +52. :doc:`RoFormer ` (from ZhuiyiTechnology), released together with the paper a `RoFormer: Enhanced Transformer with Rotary Position Embedding `__ by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu. -52. :doc:`SpeechToTextTransformer ` (from Facebook), released together with the paper +53. :doc:`SpeechToTextTransformer ` (from Facebook), released together with the paper `fairseq S2T: Fast Speech-to-Text Modeling with fairseq `__ by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino. -53. :doc:`SqueezeBert ` released with the paper `SqueezeBERT: What can computer vision teach NLP +54. :doc:`SqueezeBert ` released with the paper `SqueezeBERT: What can computer vision teach NLP about efficient neural networks? `__ by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. -54. :doc:`T5 ` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a +55. :doc:`T5 ` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer `__ by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. -55. :doc:`TAPAS ` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via +56. :doc:`TAPAS ` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via Pre-training `__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. -56. :doc:`Transformer-XL ` (from Google/CMU) released with the paper `Transformer-XL: +57. :doc:`Transformer-XL ` (from Google/CMU) released with the paper `Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context `__ by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov. -57. :doc:`Vision Transformer (ViT) ` (from Google AI) released with the paper `An Image is Worth 16x16 +58. :doc:`Vision Transformer (ViT) ` (from Google AI) released with the paper `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `__ by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. -58. :doc:`VisualBERT ` (from UCLA NLP) released with the paper `VisualBERT: A Simple and +59. :doc:`VisualBERT ` (from UCLA NLP) released with the paper `VisualBERT: A Simple and Performant Baseline for Vision and Language `__ by Liunian Harold Li, Mark Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang. -59. :doc:`Wav2Vec2 ` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for +60. :doc:`Wav2Vec2 ` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations `__ by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli. -60. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model +61. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model Pretraining `__ by Guillaume Lample and Alexis Conneau. -61. :doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: +62. :doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training `__ by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. -62. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised +63. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised Cross-lingual Representation Learning at Scale `__ by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov. -63. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive +64. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive Pretraining for Language Understanding `__ by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le. -64. :doc:`XLSR-Wav2Vec2 ` (from Facebook AI) released with the paper `Unsupervised +65. :doc:`XLSR-Wav2Vec2 ` (from Facebook AI) released with the paper `Unsupervised Cross-Lingual Representation Learning For Speech Recognition `__ by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli. @@ -314,6 +317,8 @@ Flax), PyTorch, and/or TensorFlow. +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | BERT | ✅ | ✅ | ✅ | ✅ | ✅ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| BeiT | ❌ | ❌ | ✅ | ❌ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | Bert Generation | ✅ | ❌ | ✅ | ❌ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | BigBird | ✅ | ✅ | ✅ | ❌ | ✅ | @@ -508,6 +513,7 @@ Flax), PyTorch, and/or TensorFlow. model_doc/auto model_doc/bart model_doc/barthez + model_doc/beit model_doc/bert model_doc/bertweet model_doc/bertgeneration diff --git a/docs/source/model_doc/beit.rst b/docs/source/model_doc/beit.rst new file mode 100644 index 00000000000..ad6c6795b8b --- /dev/null +++ b/docs/source/model_doc/beit.rst @@ -0,0 +1,97 @@ +.. + Copyright 2021 The HuggingFace Team. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + specific language governing permissions and limitations under the License. + +BEiT +----------------------------------------------------------------------------------------------------------------------- + +Overview +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The BEiT model was proposed in `BEiT: BERT Pre-Training of Image Transformers `__ by +Hangbo Bao, Li Dong and Furu Wei. Inspired by BERT, BEiT is the first paper that makes self-supervised pre-training of +Vision Transformers (ViTs) outperform supervised pre-training. Rather than pre-training the model to predict the class +of an image (as done in the `original ViT paper `__), BEiT models are pre-trained to +predict visual tokens from the codebook of OpenAI's `DALL-E model `__ given masked +patches. + +The abstract from the paper is the following: + +*We introduce a self-supervised vision representation model BEiT, which stands for Bidirectional Encoder representation +from Image Transformers. Following BERT developed in the natural language processing area, we propose a masked image +modeling task to pretrain vision Transformers. Specifically, each image has two views in our pre-training, i.e, image +patches (such as 16x16 pixels), and visual tokens (i.e., discrete tokens). We first "tokenize" the original image into +visual tokens. Then we randomly mask some image patches and fed them into the backbone Transformer. The pre-training +objective is to recover the original visual tokens based on the corrupted image patches. After pre-training BEiT, we +directly fine-tune the model parameters on downstream tasks by appending task layers upon the pretrained encoder. +Experimental results on image classification and semantic segmentation show that our model achieves competitive results +with previous pre-training methods. For example, base-size BEiT achieves 83.2% top-1 accuracy on ImageNet-1K, +significantly outperforming from-scratch DeiT training (81.8%) with the same setup. Moreover, large-size BEiT obtains +86.3% only using ImageNet-1K, even outperforming ViT-L with supervised pre-training on ImageNet-22K (85.2%).* + +Tips: + +- BEiT models are regular Vision Transformers, but pre-trained in a self-supervised way rather than supervised. They + outperform both the original model (ViT) as well as Data-efficient Image Transformers (DeiT) when fine-tuned on + ImageNet-1K and CIFAR-100. +- As the BEiT models expect each image to be of the same size (resolution), one can use + :class:`~transformers.BeitFeatureExtractor` to resize (or rescale) and normalize images for the model. +- Both the patch resolution and image resolution used during pre-training or fine-tuning are reflected in the name of + each checkpoint. For example, :obj:`microsoft/beit-base-patch16-224` refers to a base-sized architecture with patch + resolution of 16x16 and fine-tuning resolution of 224x224. All checkpoints can be found on the `hub + `__. +- The available checkpoints are either (1) pre-trained on `ImageNet-22k `__ (a collection of + 14 million images and 22k classes) only, (2) also fine-tuned on ImageNet-22k or (3) also fine-tuned on `ImageNet-1k + `__ (also referred to as ILSVRC 2012, a collection of 1.3 million + images and 1,000 classes). +- BEiT uses relative position embeddings, inspired by the T5 model. During pre-training, the authors shared the + relative position bias among the several self-attention layers. During fine-tuning, each layer's relative position + bias is initialized with the shared relative position bias obtained after pre-training. Note that, if one wants to + pre-train a model from scratch, one needs to either set the :obj:`use_relative_position_bias` or the + :obj:`use_relative_position_bias` attribute of :class:`~transformers.BeitConfig` to :obj:`True` in order to add + position embeddings. + +This model was contributed by `nielsr `__. The original code can be found `here +`__. + +BeitConfig +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.BeitConfig + :members: + + +BeitFeatureExtractor +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.BeitFeatureExtractor + :members: __call__ + + +BeitModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.BeitModel + :members: forward + + +BeitForMaskedImageModeling +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.BeitForMaskedImageModeling + :members: forward + + +BeitForImageClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.BeitForImageClassification + :members: forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 4b495aad3d0..e776393e37a 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -147,6 +147,7 @@ _import_structure = { ], "models.bart": ["BartConfig", "BartTokenizer"], "models.barthez": [], + "models.beit": ["BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BeitConfig"], "models.bert": [ "BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BasicTokenizer", @@ -412,6 +413,7 @@ else: # Vision-specific objects if is_vision_available(): _import_structure["image_utils"] = ["ImageFeatureExtractionMixin"] + _import_structure["models.beit"].append("BeitFeatureExtractor") _import_structure["models.clip"].append("CLIPFeatureExtractor") _import_structure["models.clip"].append("CLIPProcessor") _import_structure["models.deit"].append("DeiTFeatureExtractor") @@ -510,7 +512,6 @@ if is_torch_available(): "load_tf_weights_in_albert", ] ) - _import_structure["models.auto"].extend( [ "MODEL_FOR_CAUSAL_LM_MAPPING", @@ -542,7 +543,6 @@ if is_torch_available(): "AutoModelWithLMHead", ] ) - _import_structure["models.bart"].extend( [ "BART_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -555,6 +555,15 @@ if is_torch_available(): "PretrainedBartModel", ] ) + _import_structure["models.beit"].extend( + [ + "BEIT_PRETRAINED_MODEL_ARCHIVE_LIST", + "BeitForImageClassification", + "BeitForMaskedImageModeling", + "BeitModel", + "BeitPreTrainedModel", + ] + ) _import_structure["models.bert"].extend( [ "BERT_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -1814,6 +1823,7 @@ if TYPE_CHECKING: AutoTokenizer, ) from .models.bart import BartConfig, BartTokenizer + from .models.beit import BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, BeitConfig from .models.bert import ( BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BasicTokenizer, @@ -2049,6 +2059,7 @@ if TYPE_CHECKING: if is_vision_available(): from .image_utils import ImageFeatureExtractionMixin + from .models.beit import BeitFeatureExtractor from .models.clip import CLIPFeatureExtractor, CLIPProcessor from .models.deit import DeiTFeatureExtractor from .models.detr import DetrFeatureExtractor @@ -2171,6 +2182,13 @@ if TYPE_CHECKING: BartPretrainedModel, PretrainedBartModel, ) + from .models.beit import ( + BEIT_PRETRAINED_MODEL_ARCHIVE_LIST, + BeitForImageClassification, + BeitForMaskedImageModeling, + BeitModel, + BeitPreTrainedModel, + ) from .models.bert import ( BERT_PRETRAINED_MODEL_ARCHIVE_LIST, BertForMaskedLM, diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index add2ccac8d1..58846b32d67 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -21,6 +21,8 @@ from .file_utils import _is_torch, is_torch_available IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406] IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225] +IMAGENET_STANDARD_MEAN = [0.5, 0.5, 0.5] +IMAGENET_STANDARD_STD = [0.5, 0.5, 0.5] def is_torch_tensor(obj): diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 7bfa22953ff..0d01a7680f3 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -21,6 +21,7 @@ from . import ( auto, bart, barthez, + beit, bert, bert_generation, bert_japanese, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index d3f3d2aa005..aa7ccaa1632 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -20,6 +20,7 @@ from collections import OrderedDict from ...configuration_utils import PretrainedConfig from ..albert.configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig from ..bart.configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig +from ..beit.configuration_beit import BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, BeitConfig from ..bert.configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig from ..bert_generation.configuration_bert_generation import BertGenerationConfig from ..big_bird.configuration_big_bird import BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdConfig @@ -97,6 +98,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict( (key, value) for pretrained_map in [ # Add archive maps here + BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, VISUAL_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CANINE_PRETRAINED_CONFIG_ARCHIVE_MAP, @@ -158,6 +160,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict( CONFIG_MAPPING = OrderedDict( [ # Add configs here + ("beit", BeitConfig), ("rembert", RemBertConfig), ("visual_bert", VisualBertConfig), ("canine", CanineConfig), @@ -225,6 +228,7 @@ CONFIG_MAPPING = OrderedDict( MODEL_NAMES_MAPPING = OrderedDict( [ # Add full (and cased) model names here + ("beit", "BeiT"), ("rembert", "RemBERT"), ("visual_bert", "VisualBert"), ("canine", "Canine"), diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index 6c6a3a70511..6d853a131a1 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -17,9 +17,9 @@ import os from collections import OrderedDict -from transformers import DeiTFeatureExtractor, Speech2TextFeatureExtractor, ViTFeatureExtractor +from transformers import BeitFeatureExtractor, DeiTFeatureExtractor, Speech2TextFeatureExtractor, ViTFeatureExtractor -from ... import DeiTConfig, PretrainedConfig, Speech2TextConfig, ViTConfig, Wav2Vec2Config +from ... import BeitConfig, DeiTConfig, PretrainedConfig, Speech2TextConfig, ViTConfig, Wav2Vec2Config from ...feature_extraction_utils import FeatureExtractionMixin # Build the list of all feature extractors @@ -30,6 +30,7 @@ from .configuration_auto import AutoConfig, replace_list_option_in_docstrings FEATURE_EXTRACTOR_MAPPING = OrderedDict( [ + (BeitConfig, BeitFeatureExtractor), (DeiTConfig, DeiTFeatureExtractor), (Speech2TextConfig, Speech2TextFeatureExtractor), (ViTConfig, ViTFeatureExtractor), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 44967ed548a..92b4d132568 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -37,6 +37,7 @@ from ..bart.modeling_bart import ( BartForSequenceClassification, BartModel, ) +from ..beit.modeling_beit import BeitForImageClassification, BeitModel from ..bert.modeling_bert import ( BertForMaskedLM, BertForMultipleChoice, @@ -321,6 +322,7 @@ from .auto_factory import _BaseAutoModelClass, auto_class_update from .configuration_auto import ( AlbertConfig, BartConfig, + BeitConfig, BertConfig, BertGenerationConfig, BigBirdConfig, @@ -388,6 +390,7 @@ logger = logging.get_logger(__name__) MODEL_MAPPING = OrderedDict( [ # Base model mapping + (BeitConfig, BeitModel), (RemBertConfig, RemBertModel), (VisualBertConfig, VisualBertModel), (CanineConfig, CanineModel), @@ -579,6 +582,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = OrderedDict( # Model for Image Classification mapping (ViTConfig, ViTForImageClassification), (DeiTConfig, (DeiTForImageClassification, DeiTForImageClassificationWithTeacher)), + (BeitConfig, BeitForImageClassification), ] ) diff --git a/src/transformers/models/beit/__init__.py b/src/transformers/models/beit/__init__.py new file mode 100644 index 00000000000..0ca6ddd27cb --- /dev/null +++ b/src/transformers/models/beit/__init__.py @@ -0,0 +1,59 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...file_utils import _LazyModule, is_torch_available, is_vision_available + + +_import_structure = { + "configuration_beit": ["BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BeitConfig"], +} + +if is_vision_available(): + _import_structure["feature_extraction_beit"] = ["BeitFeatureExtractor"] + +if is_torch_available(): + _import_structure["modeling_beit"] = [ + "BEIT_PRETRAINED_MODEL_ARCHIVE_LIST", + "BeitForImageClassification", + "BeitForMaskedImageModeling", + "BeitModel", + "BeitPreTrainedModel", + ] + +if TYPE_CHECKING: + from .configuration_beit import BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, BeitConfig + + if is_vision_available(): + from .feature_extraction_beit import BeitFeatureExtractor + + if is_torch_available(): + from .modeling_beit import ( + BEIT_PRETRAINED_MODEL_ARCHIVE_LIST, + BeitForImageClassification, + BeitForMaskedImageModeling, + BeitModel, + BeitPreTrainedModel, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) diff --git a/src/transformers/models/beit/configuration_beit.py b/src/transformers/models/beit/configuration_beit.py new file mode 100644 index 00000000000..08ecc606469 --- /dev/null +++ b/src/transformers/models/beit/configuration_beit.py @@ -0,0 +1,146 @@ +# coding=utf-8 +# Copyright Microsoft Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" BEiT model configuration """ + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "microsoft/beit-base-patch16-224-in22k": "https://huggingface.co/microsoft/beit-base-patch16-224-in22k/resolve/main/config.json", + # See all BEiT models at https://huggingface.co/models?filter=beit +} + + +class BeitConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class:`~transformers.BeitModel`. It is used to + instantiate an BEiT model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the BEiT + `microsoft/beit-base-patch16-224-in22k `__ + architecture. + + Args: + vocab_size (:obj:`int`, `optional`, defaults to 8092): + Vocabulary size of the BEiT model. Defines the number of different image tokens that can be used during + pre-training. + hidden_size (:obj:`int`, `optional`, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (:obj:`int`, `optional`, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (:obj:`int`, `optional`, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (:obj:`int`, `optional`, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, + :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported. + hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): + The dropout ratio for the attention probabilities. + initializer_range (:obj:`float`, `optional`, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): + The epsilon used by the layer normalization layers. + gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): + If True, use gradient checkpointing to save memory at the expense of slower backward pass. + image_size (:obj:`int`, `optional`, defaults to :obj:`224`): + The size (resolution) of each image. + patch_size (:obj:`int`, `optional`, defaults to :obj:`16`): + The size (resolution) of each patch. + num_channels (:obj:`int`, `optional`, defaults to :obj:`3`): + The number of input channels. + use_mask_token (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to use a mask token for masked image modeling. + use_absolute_position_embeddings (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to use BERT-style absolute position embeddings. + use_relative_position_bias (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to use T5-style relative position embeddings in the self-attention layers. + use_shared_relative_position_bias (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to use the same relative position embeddings across all self-attention layers of the Transformer. + layer_scale_init_value (:obj:`float`, `optional`, defaults to 0.1): + Scale to use in the self-attention layers. 0.1 for base, 1e-5 for large. Set 0 to disable layer scale. + drop_path_rate (:obj:`float`, `optional`, defaults to 0.1): + Stochastic depth rate per sample (when applied in the main path of residual layers). + use_mean_pooling (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the + CLS token, before applying the classification head. + + Example:: + + >>> from transformers import BeitModel, BeitConfig + + >>> # Initializing a BEiT beit-base-patch16-224-in22k style configuration + >>> configuration = BeitConfig() + + >>> # Initializing a model from the beit-base-patch16-224-in22k style configuration + >>> model = BeitModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + """ + model_type = "beit" + + def __init__( + self, + vocab_size=8192, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-12, + is_encoder_decoder=False, + image_size=224, + patch_size=16, + num_channels=3, + use_mask_token=False, + use_absolute_position_embeddings=False, + use_relative_position_bias=False, + use_shared_relative_position_bias=False, + layer_scale_init_value=0.1, + drop_path_rate=0.1, + use_mean_pooling=True, + **kwargs + ): + super().__init__(**kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.use_mask_token = use_mask_token + self.use_absolute_position_embeddings = use_absolute_position_embeddings + self.use_relative_position_bias = use_relative_position_bias + self.use_shared_relative_position_bias = use_shared_relative_position_bias + self.layer_scale_init_value = layer_scale_init_value + self.drop_path_rate = drop_path_rate + self.use_mean_pooling = use_mean_pooling diff --git a/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py b/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py new file mode 100644 index 00000000000..c550a56db36 --- /dev/null +++ b/src/transformers/models/beit/convert_beit_unilm_to_pytorch.py @@ -0,0 +1,290 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert BEiT checkpoints from the unilm repository.""" + + +import argparse +import json +from pathlib import Path + +import torch +from PIL import Image + +import requests +from huggingface_hub import cached_download, hf_hub_url +from transformers import BeitConfig, BeitFeatureExtractor, BeitForImageClassification, BeitForMaskedImageModeling +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +# here we list all keys to be renamed (original name on the left, our name on the right) +def create_rename_keys(config, has_lm_head=False): + rename_keys = [] + for i in range(config.num_hidden_layers): + # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms + rename_keys.append((f"blocks.{i}.norm1.weight", f"beit.encoder.layer.{i}.layernorm_before.weight")) + rename_keys.append((f"blocks.{i}.norm1.bias", f"beit.encoder.layer.{i}.layernorm_before.bias")) + rename_keys.append((f"blocks.{i}.attn.proj.weight", f"beit.encoder.layer.{i}.attention.output.dense.weight")) + rename_keys.append((f"blocks.{i}.attn.proj.bias", f"beit.encoder.layer.{i}.attention.output.dense.bias")) + rename_keys.append((f"blocks.{i}.norm2.weight", f"beit.encoder.layer.{i}.layernorm_after.weight")) + rename_keys.append((f"blocks.{i}.norm2.bias", f"beit.encoder.layer.{i}.layernorm_after.bias")) + rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"beit.encoder.layer.{i}.intermediate.dense.weight")) + rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"beit.encoder.layer.{i}.intermediate.dense.bias")) + rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"beit.encoder.layer.{i}.output.dense.weight")) + rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"beit.encoder.layer.{i}.output.dense.bias")) + + # projection layer + position embeddings + rename_keys.extend( + [ + ("cls_token", "beit.embeddings.cls_token"), + ("patch_embed.proj.weight", "beit.embeddings.patch_embeddings.projection.weight"), + ("patch_embed.proj.bias", "beit.embeddings.patch_embeddings.projection.bias"), + ] + ) + + if has_lm_head: + # mask token + shared relative position bias + layernorm + rename_keys.extend( + [ + ("mask_token", "beit.embeddings.mask_token"), + ( + "rel_pos_bias.relative_position_bias_table", + "beit.encoder.relative_position_bias.relative_position_bias_table", + ), + ( + "rel_pos_bias.relative_position_index", + "beit.encoder.relative_position_bias.relative_position_index", + ), + ("norm.weight", "layernorm.weight"), + ("norm.bias", "layernorm.bias"), + ] + ) + else: + # layernorm + classification head + rename_keys.extend( + [ + ("fc_norm.weight", "beit.pooler.layernorm.weight"), + ("fc_norm.bias", "beit.pooler.layernorm.bias"), + ("head.weight", "classifier.weight"), + ("head.bias", "classifier.bias"), + ] + ) + + return rename_keys + + +# we split up the matrix of each encoder layer into queries, keys and values +def read_in_q_k_v(state_dict, config, has_lm_head=False): + for i in range(config.num_hidden_layers): + prefix = "beit." + # queries, keys and values + in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight") + q_bias = state_dict.pop(f"blocks.{i}.attn.q_bias") + v_bias = state_dict.pop(f"blocks.{i}.attn.v_bias") + + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[ + : config.hidden_size, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = q_bias + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[ + config.hidden_size : config.hidden_size * 2, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[ + -config.hidden_size :, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = v_bias + + # gamma_1 and gamma_2 + # we call them lambda because otherwise they are renamed when using .from_pretrained + gamma_1 = state_dict.pop(f"blocks.{i}.gamma_1") + gamma_2 = state_dict.pop(f"blocks.{i}.gamma_2") + + state_dict[f"{prefix}encoder.layer.{i}.lambda_1"] = gamma_1 + state_dict[f"{prefix}encoder.layer.{i}.lambda_2"] = gamma_2 + + # relative_position bias table + index + if not has_lm_head: + # each layer has its own relative position bias + table = state_dict.pop(f"blocks.{i}.attn.relative_position_bias_table") + index = state_dict.pop(f"blocks.{i}.attn.relative_position_index") + + state_dict[ + f"{prefix}encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table" + ] = table + state_dict[ + f"{prefix}encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index" + ] = index + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + im = Image.open(requests.get(url, stream=True).raw) + return im + + +@torch.no_grad() +def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path): + """ + Copy/paste/tweak model's weights to our BEiT structure. + """ + + # define default BEiT configuration + config = BeitConfig() + has_lm_head = False + repo_id = "datasets/huggingface/label-files" + # set config parameters based on URL + if checkpoint_url[-9:-4] == "pt22k": + # masked image modeling + config.use_shared_relative_position_bias = True + config.use_mask_token = True + has_lm_head = True + elif checkpoint_url[-9:-4] == "ft22k": + # intermediate fine-tuning on ImageNet-22k + config.use_relative_position_bias = True + config.num_labels = 21841 + filename = "imagenet-22k-id2label.json" + id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r")) + id2label = {int(k): v for k, v in id2label.items()} + # this dataset contains 21843 labels but the model only has 21841 + # we delete the classes as mentioned in https://github.com/google-research/big_transfer/issues/18 + del id2label[9205] + del id2label[15027] + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + elif checkpoint_url[-8:-4] == "to1k": + # fine-tuning on ImageNet-1k + config.use_relative_position_bias = True + config.num_labels = 1000 + filename = "imagenet-1k-id2label.json" + id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r")) + id2label = {int(k): v for k, v in id2label.items()} + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + if "384" in checkpoint_url: + config.image_size = 384 + if "512" in checkpoint_url: + config.image_size = 512 + else: + raise ValueError("Checkpoint not supported, URL should either end with 'pt22k', 'ft22k' or 'to1k'") + + # size of the architecture + if "base" in checkpoint_url: + pass + elif "large" in checkpoint_url: + config.hidden_size = 1024 + config.intermediate_size = 4096 + config.num_hidden_layers = 24 + config.num_attention_heads = 16 + else: + raise ValueError("Should either find 'base' or 'large' in checkpoint URL") + + # load state_dict of original model, remove and rename some keys + state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)["model"] + rename_keys = create_rename_keys(config, has_lm_head=has_lm_head) + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + read_in_q_k_v(state_dict, config, has_lm_head=has_lm_head) + + # load HuggingFace model + if checkpoint_url[-9:-4] == "pt22k": + model = BeitForMaskedImageModeling(config) + else: + model = BeitForImageClassification(config) + model.eval() + model.load_state_dict(state_dict) + + # Check outputs on an image + feature_extractor = BeitFeatureExtractor(size=config.image_size, resample=Image.BILINEAR, do_center_crop=False) + encoding = feature_extractor(images=prepare_img(), return_tensors="pt") + pixel_values = encoding["pixel_values"] + + outputs = model(pixel_values) + logits = outputs.logits + + # verify logits + expected_shape = torch.Size([1, 1000]) + if checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k"): + expected_shape = torch.Size([1, 196, 8192]) + elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k"): + expected_shape = torch.Size([1, 196, 8192]) + elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft22k"): + expected_shape = torch.Size([1, 21841]) + expected_logits = torch.tensor([2.2288, 2.4671, 0.7395]) + expected_class_idx = 2397 + elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft22k"): + expected_shape = torch.Size([1, 21841]) + expected_logits = torch.tensor([1.6881, -0.2787, 0.5901]) + expected_class_idx = 2396 + elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft1k"): + expected_logits = torch.tensor([0.1241, 0.0798, -0.6569]) + expected_class_idx = 285 + elif checkpoint_url[:-4].endswith("beit_base_patch16_224_pt22k_ft22kto1k"): + expected_logits = torch.tensor([-1.2385, -1.0987, -1.0108]) + expected_class_idx = 281 + elif checkpoint_url[:-4].endswith("beit_base_patch16_384_pt22k_ft22kto1k"): + expected_logits = torch.tensor([-1.5303, -0.9484, -0.3147]) + expected_class_idx = 761 + elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft1k"): + expected_logits = torch.tensor([0.4610, -0.0928, 0.2086]) + expected_class_idx = 761 + elif checkpoint_url[:-4].endswith("beit_large_patch16_224_pt22k_ft22kto1k"): + expected_logits = torch.tensor([-0.4804, 0.6257, -0.1837]) + expected_class_idx = 761 + elif checkpoint_url[:-4].endswith("beit_large_patch16_384_pt22k_ft22kto1k"): + expected_logits = torch.tensor([[-0.5122, 0.5117, -0.2113]]) + expected_class_idx = 761 + elif checkpoint_url[:-4].endswith("beit_large_patch16_512_pt22k_ft22kto1k"): + expected_logits = torch.tensor([-0.3062, 0.7261, 0.4852]) + expected_class_idx = 761 + else: + raise ValueError("Can't verify logits as model is not supported") + + assert logits.shape == expected_shape, "Shape of logits not as expected" + print("Shape of logits:", logits.shape) + if not has_lm_head: + print("Predicted class idx:", logits.argmax(-1).item()) + assert torch.allclose(logits[0, :3], expected_logits, atol=1e-3), "First elements of logits not as expected" + assert logits.argmax(-1).item() == expected_class_idx, "Predicted class index not as expected" + + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + print(f"Saving feature extractor to {pytorch_dump_folder_path}") + feature_extractor.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--checkpoint_url", + default="https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth", + type=str, + help="URL to the original PyTorch checkpoint (.pth file).", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model." + ) + args = parser.parse_args() + convert_beit_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path) diff --git a/src/transformers/models/beit/feature_extraction_beit.py b/src/transformers/models/beit/feature_extraction_beit.py new file mode 100644 index 00000000000..4bca0a14c8b --- /dev/null +++ b/src/transformers/models/beit/feature_extraction_beit.py @@ -0,0 +1,159 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for BEiT.""" + +from typing import List, Optional, Union + +import numpy as np +from PIL import Image + +from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin +from ...file_utils import TensorType +from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, ImageFeatureExtractionMixin, is_torch_tensor +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class BeitFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): + r""" + Constructs a BEiT feature extractor. + + This feature extractor inherits from :class:`~transformers.FeatureExtractionMixin` which contains most of the main + methods. Users should refer to this superclass for more information regarding those methods. + + Args: + do_resize (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to resize the input to a certain :obj:`size`. + size (:obj:`int` or :obj:`Tuple(int)`, `optional`, defaults to 256): + Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an + integer is provided, then the input will be resized to (size, size). Only has an effect if :obj:`do_resize` + is set to :obj:`True`. + resample (:obj:`int`, `optional`, defaults to :obj:`PIL.Image.BICUBIC`): + An optional resampling filter. This can be one of :obj:`PIL.Image.NEAREST`, :obj:`PIL.Image.BOX`, + :obj:`PIL.Image.BILINEAR`, :obj:`PIL.Image.HAMMING`, :obj:`PIL.Image.BICUBIC` or :obj:`PIL.Image.LANCZOS`. + Only has an effect if :obj:`do_resize` is set to :obj:`True`. + do_center_crop (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to crop the input at the center. If the input size is smaller than :obj:`crop_size` along any edge, + the image is padded with 0's and then center cropped. + crop_size (:obj:`int`, `optional`, defaults to 224): + Desired output size when applying center-cropping. Only has an effect if :obj:`do_center_crop` is set to + :obj:`True`. + do_normalize (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to normalize the input with :obj:`image_mean` and :obj:`image_std`. + image_mean (:obj:`List[int]`, defaults to :obj:`[0.5, 0.5, 0.5]`): + The sequence of means for each channel, to be used when normalizing images. + image_std (:obj:`List[int]`, defaults to :obj:`[0.5, 0.5, 0.5]`): + The sequence of standard deviations for each channel, to be used when normalizing images. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize=True, + size=256, + resample=Image.BICUBIC, + do_center_crop=True, + crop_size=224, + do_normalize=True, + image_mean=None, + image_std=None, + **kwargs + ): + super().__init__(**kwargs) + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD + + def __call__( + self, + images: Union[ + Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"] # noqa + ], + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs + ) -> BatchFeature: + """ + Main method to prepare for the model one or several image(s). + + .. warning:: + + NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass + PIL images. + + Args: + images (:obj:`PIL.Image.Image`, :obj:`np.ndarray`, :obj:`torch.Tensor`, :obj:`List[PIL.Image.Image]`, :obj:`List[np.ndarray]`, :obj:`List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + + return_tensors (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`, defaults to :obj:`'np'`): + If set, will return tensors of a particular framework. Acceptable values are: + + * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects. + * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects. + * :obj:`'np'`: Return NumPy :obj:`np.ndarray` objects. + * :obj:`'jax'`: Return JAX :obj:`jnp.ndarray` objects. + + Returns: + :class:`~transformers.BatchFeature`: A :class:`~transformers.BatchFeature` with the following fields: + + - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height, + width). + """ + # Input type checking for clearer error + valid_images = False + + # Check that images has a valid type + if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images): + valid_images = True + elif isinstance(images, (list, tuple)): + if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]): + valid_images = True + + if not valid_images: + raise ValueError( + "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example)," + "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)." + ) + + is_batched = bool( + isinstance(images, (list, tuple)) + and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0])) + ) + + if not is_batched: + images = [images] + + # transformations (resizing + center cropping + normalization) + if self.do_resize and self.size is not None and self.resample is not None: + images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images] + if self.do_center_crop and self.crop_size is not None: + images = [self.center_crop(image, self.crop_size) for image in images] + if self.do_normalize: + images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images] + + # return as BatchFeature + data = {"pixel_values": images} + encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) + + return encoded_inputs diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py new file mode 100755 index 00000000000..0eefac892c9 --- /dev/null +++ b/src/transformers/models/beit/modeling_beit.py @@ -0,0 +1,842 @@ +# coding=utf-8 +# Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch BEiT model. """ + + +import collections.abc +import math + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput +from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import logging +from .configuration_beit import BeitConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "BeitConfig" +_CHECKPOINT_FOR_DOC = "microsoft/beit-base-patch16-224" + +BEIT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "microsoft/beit-base-patch16-224", + # See all BEiT models at https://huggingface.co/models?filter=beit +] + + +# Inspired by +# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py +# From PyTorch internals +def to_2tuple(x): + if isinstance(x, collections.abc.Iterable): + return x + return (x, x) + + +# Based on https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super().__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +# Based on timm implementation, which can be found here: +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +class BeitEmbeddings(nn.Module): + """ + Construct the CLS token, position and patch embeddings. Optionally, also the mask token. + + """ + + def __init__(self, config): + super().__init__() + + self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + if config.use_mask_token: + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + else: + self.mask_token = None + self.patch_embeddings = PatchEmbeddings( + image_size=config.image_size, + patch_size=config.patch_size, + num_channels=config.num_channels, + embed_dim=config.hidden_size, + ) + num_patches = self.patch_embeddings.num_patches + if config.use_absolute_position_embeddings: + self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size)) + else: + self.position_embeddings = None + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, pixel_values, bool_masked_pos=None): + + embeddings = self.patch_embeddings(pixel_values) + batch_size, seq_len, _ = embeddings.size() + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + if bool_masked_pos is not None: + mask_tokens = self.mask_token.expand(batch_size, seq_len, -1) + # replace the masked visual tokens by mask_tokens + w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1 - w) + mask_tokens * w + + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + if self.position_embeddings is not None: + embeddings = embeddings + self.position_embeddings + embeddings = self.dropout(embeddings) + + return embeddings + + +# Based on timm implementation, which can be found here: +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +class PatchEmbeddings(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768): + super().__init__() + image_size = to_2tuple(image_size) + patch_size = to_2tuple(patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) + self.image_size = image_size + self.patch_size = patch_size + self.num_patches = num_patches + self.patch_shape = patch_shape + + self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values): + batch_size, num_channels, height, width = pixel_values.shape + # FIXME look at relaxing size constraints + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." + ) + x = self.projection(pixel_values).flatten(2).transpose(1, 2) + return x + + +class BeitSelfAttention(nn.Module): + def __init__(self, config, window_size=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size,} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + if window_size: + self.relative_position_bias = BeitRelativePositionBias(config, window_size=window_size) + else: + self.relative_position_bias = None + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, head_mask=None, output_attentions=False, relative_position_bias=None): + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + # Add relative position bias if present. + if self.relative_position_bias is not None: + attention_scores = attention_scores + self.relative_position_bias().unsqueeze(0) + + # Add shared relative position bias if provided. + if relative_position_bias is not None: + attention_scores = attention_scores + relative_position_bias + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class BeitSelfOutput(nn.Module): + """ + The residual connection is defined in BeitLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. + """ + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor, gamma=None): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +class BeitAttention(nn.Module): + def __init__(self, config, window_size=None): + super().__init__() + self.attention = BeitSelfAttention(config, window_size=window_size) + self.output = BeitSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward(self, hidden_states, head_mask=None, output_attentions=False, relative_position_bias=None): + self_outputs = self.attention(hidden_states, head_mask, output_attentions, relative_position_bias) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BeitIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + +class BeitOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +class BeitLayer(nn.Module): + """This corresponds to the Block class in the timm implementation.""" + + def __init__(self, config, window_size=None, drop_path_rate=0.0): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BeitAttention(config, window_size=window_size) + self.intermediate = BeitIntermediate(config) + self.output = BeitOutput(config) + self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + init_values = config.layer_scale_init_value + if init_values > 0: + self.lambda_1 = nn.Parameter(init_values * torch.ones((config.hidden_size)), requires_grad=True) + self.lambda_2 = nn.Parameter(init_values * torch.ones((config.hidden_size)), requires_grad=True) + else: + self.lambda_1, self.lambda_2 = None, None + + def forward(self, hidden_states, head_mask=None, output_attentions=False, relative_position_bias=None): + self_attention_outputs = self.attention( + self.layernorm_before(hidden_states), # in BEiT, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + relative_position_bias=relative_position_bias, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # apply lambda_1 if present + if self.lambda_1 is not None: + attention_output = self.lambda_1 * attention_output + + # first residual connection + hidden_states = self.drop_path(attention_output) + hidden_states + + # in BEiT, layernorm is also applied after self-attention + layer_output = self.layernorm_after(hidden_states) + + layer_output = self.intermediate(layer_output) + layer_output = self.output(layer_output) + + if self.lambda_2 is not None: + layer_output = self.lambda_2 * layer_output + + # second residual connection + layer_output = self.drop_path(layer_output) + hidden_states + + outputs = (layer_output,) + outputs + + return outputs + + +class BeitRelativePositionBias(nn.Module): + def __init__(self, config, window_size): + super().__init__() + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, config.num_attention_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = torch.zeros( + size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype + ) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + + def forward(self): + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1 + ) # Wh*Ww,Wh*Ww,nH + + return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + +class BeitEncoder(nn.Module): + def __init__(self, config, window_size=None): + super().__init__() + self.config = config + if config.use_shared_relative_position_bias: + self.relative_position_bias = BeitRelativePositionBias(config, window_size=window_size) + else: + self.relative_position_bias = None + + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)] + self.layer = nn.ModuleList( + [ + BeitLayer( + config, + window_size=window_size if config.use_relative_position_bias else None, + drop_path_rate=dpr[i], + ) + for i in range(config.num_hidden_layers) + ] + ) + + def forward( + self, + hidden_states, + head_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + layer_head_mask, + ) + else: + relative_position_bias = ( + self.relative_position_bias() if self.relative_position_bias is not None else None + ) + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions, relative_position_bias) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class BeitPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BeitConfig + base_model_prefix = "beit" + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +BEIT_START_DOCSTRING = r""" + This model is a PyTorch `torch.nn.Module `_ subclass. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config (:class:`~transformers.BeitConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. +""" + +BEIT_INPUTS_DOCSTRING = r""" + Args: + pixel_values (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using :class:`~transformers.BeitFeatureExtractor`. See + :meth:`transformers.BeitFeatureExtractor.__call__` for details. + + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Beit Model transformer outputting raw hidden-states without any specific head on top.", + BEIT_START_DOCSTRING, +) +class BeitModel(BeitPreTrainedModel): + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = BeitEmbeddings(config) + self.encoder = BeitEncoder(config, window_size=self.embeddings.patch_embeddings.patch_shape) + + self.layernorm = ( + nn.Identity() if config.use_mean_pooling else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + ) + self.pooler = BeitPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values=None, + bool_masked_pos=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Returns: + + Examples:: + + >>> from transformers import BeitFeatureExtractor, BeitModel + >>> from PIL import Image + >>> import requests + + >>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg' + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224-pt22k-ft22k') + >>> model = BeitModel.from_pretrained('microsoft/beit-base-patch16-224-pt22k-ft22k') + + >>> inputs = feature_extractor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> last_hidden_states = outputs.last_hidden_state + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings(pixel_values, bool_masked_pos) + + encoder_outputs = self.encoder( + embedding_output, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class BeitPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.layernorm = ( + nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.use_mean_pooling else None + ) + + def forward(self, hidden_states): + if self.layernorm is not None: + # Mean pool the final hidden states of the patch tokens + patch_tokens = hidden_states[:, 1:, :] + pooled_output = self.layernorm(patch_tokens.mean(1)) + else: + # Pool by simply taking the final hidden state of the [CLS] token + pooled_output = hidden_states[:, 0] + + return pooled_output + + +@add_start_docstrings( + "Beit Model transformer with a 'language' modeling head on top (to predict visual tokens).", BEIT_START_DOCSTRING +) +class BeitForMaskedImageModeling(BeitPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.num_labels = config.num_labels + self.beit = BeitModel(config, add_pooling_layer=False) + + # Classifier head + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size) + + self.init_weights() + + @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values=None, + bool_masked_pos=None, + head_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the image classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Returns: + + Examples:: + + >>> from transformers import BeitFeatureExtractor, BeitForMaskedImageModeling + >>> from PIL import Image + >>> import requests + + >>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg' + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224-pt22k') + >>> model = BeitForMaskedImageModeling.from_pretrained('microsoft/beit-base-patch16-224-pt22k') + + >>> inputs = feature_extractor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.beit( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + sequence_output = self.layernorm(sequence_output) + prediction_scores = self.lm_head(sequence_output[:, 1:]) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores[bool_masked_pos], labels) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Beit Model transformer with an image classification head on top (a linear layer on top of the average of the final + hidden states of the patch tokens) e.g. for ImageNet. + """, + BEIT_START_DOCSTRING, +) +class BeitForImageClassification(BeitPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.num_labels = config.num_labels + self.beit = BeitModel(config, add_pooling_layer=True) + + # Classifier head + self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() + + self.init_weights() + + @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values=None, + head_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the image classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Returns: + + Examples:: + + >>> from transformers import BeitFeatureExtractor, BeitForImageClassification + >>> from PIL import Image + >>> import requests + + >>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg' + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224') + >>> model = BeitForImageClassification.from_pretrained('microsoft/beit-base-patch16-224') + + >>> inputs = feature_extractor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_class_idx = logits.argmax(-1).item() + >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.beit( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.num_labels == 1: + # We are doing regression + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1), labels.view(-1)) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/deit/convert_deit_timm_to_pytorch.py b/src/transformers/models/deit/convert_deit_timm_to_pytorch.py index f866b90a80d..dc6437f6eeb 100644 --- a/src/transformers/models/deit/convert_deit_timm_to_pytorch.py +++ b/src/transformers/models/deit/convert_deit_timm_to_pytorch.py @@ -16,6 +16,7 @@ import argparse +import json from pathlib import Path import torch @@ -23,9 +24,9 @@ from PIL import Image import requests import timm +from huggingface_hub import cached_download, hf_hub_url from transformers import DeiTConfig, DeiTFeatureExtractor, DeiTForImageClassificationWithTeacher from transformers.utils import logging -from transformers.utils.imagenet_classes import id2label logging.set_verbosity_info() @@ -139,6 +140,10 @@ def convert_deit_checkpoint(deit_name, pytorch_dump_folder_path): base_model = False # dataset (fine-tuned on ImageNet 2012), patch_size and image_size config.num_labels = 1000 + repo_id = "datasets/huggingface/label-files" + filename = "imagenet-1k-id2label.json" + id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r")) + id2label = {int(k): v for k, v in id2label.items()} config.id2label = id2label config.label2id = {v: k for k, v in id2label.items()} config.patch_size = int(deit_name[-6:-4]) diff --git a/src/transformers/models/detr/convert_detr_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/detr/convert_detr_original_pytorch_checkpoint_to_pytorch.py index 66165809759..ded8037eb2d 100644 --- a/src/transformers/models/detr/convert_detr_original_pytorch_checkpoint_to_pytorch.py +++ b/src/transformers/models/detr/convert_detr_original_pytorch_checkpoint_to_pytorch.py @@ -16,6 +16,7 @@ import argparse +import json from collections import OrderedDict from pathlib import Path @@ -23,9 +24,9 @@ import torch from PIL import Image import requests +from huggingface_hub import cached_download, hf_hub_url from transformers import DetrConfig, DetrFeatureExtractor, DetrForObjectDetection, DetrForSegmentation from transformers.utils import logging -from transformers.utils.coco_classes import id2label logging.set_verbosity_info() @@ -193,6 +194,10 @@ def convert_detr_checkpoint(model_name, pytorch_dump_folder_path): config.num_labels = 250 else: config.num_labels = 91 + repo_id = "datasets/huggingface/label-files" + filename = "coco-detection-id2label.json" + id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r")) + id2label = {int(k): v for k, v in id2label.items()} config.id2label = id2label config.label2id = {v: k for k, v in id2label.items()} diff --git a/src/transformers/models/vit/convert_vit_timm_to_pytorch.py b/src/transformers/models/vit/convert_vit_timm_to_pytorch.py index 88d75f6e403..98986a6bd36 100644 --- a/src/transformers/models/vit/convert_vit_timm_to_pytorch.py +++ b/src/transformers/models/vit/convert_vit_timm_to_pytorch.py @@ -16,6 +16,7 @@ import argparse +import json from pathlib import Path import torch @@ -23,9 +24,9 @@ from PIL import Image import requests import timm +from huggingface_hub import cached_download, hf_hub_url from transformers import DeiTFeatureExtractor, ViTConfig, ViTFeatureExtractor, ViTForImageClassification, ViTModel from transformers.utils import logging -from transformers.utils.imagenet_classes import id2label logging.set_verbosity_info() @@ -146,6 +147,10 @@ def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path): config.image_size = int(vit_name[-9:-6]) else: config.num_labels = 1000 + repo_id = "datasets/huggingface/label-files" + filename = "imagenet-1k-id2label.json" + id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r")) + id2label = {int(k): v for k, v in id2label.items()} config.id2label = id2label config.label2id = {v: k for k, v in id2label.items()} config.patch_size = int(vit_name[-6:-4]) diff --git a/src/transformers/models/vit/feature_extraction_vit.py b/src/transformers/models/vit/feature_extraction_vit.py index a5177a15b4b..f700088372c 100644 --- a/src/transformers/models/vit/feature_extraction_vit.py +++ b/src/transformers/models/vit/feature_extraction_vit.py @@ -21,7 +21,7 @@ from PIL import Image from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin from ...file_utils import TensorType -from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor +from ...image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD, ImageFeatureExtractionMixin, is_torch_tensor from ...utils import logging @@ -71,8 +71,8 @@ class ViTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): self.size = size self.resample = resample self.do_normalize = do_normalize - self.image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5] - self.image_std = image_std if image_std is not None else [0.5, 0.5, 0.5] + self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD def __call__( self, diff --git a/src/transformers/utils/coco_classes.py b/src/transformers/utils/coco_classes.py deleted file mode 100644 index cc540052aef..00000000000 --- a/src/transformers/utils/coco_classes.py +++ /dev/null @@ -1,94 +0,0 @@ -# COCO object detection id's to class names -id2label = { - 0: "N/A", - 1: "person", - 2: "bicycle", - 3: "car", - 4: "motorcycle", - 5: "airplane", - 6: "bus", - 7: "train", - 8: "truck", - 9: "boat", - 10: "traffic light", - 11: "fire hydrant", - 12: "N/A", - 13: "stop sign", - 14: "parking meter", - 15: "bench", - 16: "bird", - 17: "cat", - 18: "dog", - 19: "horse", - 20: "sheep", - 21: "cow", - 22: "elephant", - 23: "bear", - 24: "zebra", - 25: "giraffe", - 26: "N/A", - 27: "backpack", - 28: "umbrella", - 29: "N/A", - 30: "N/A", - 31: "handbag", - 32: "tie", - 33: "suitcase", - 34: "frisbee", - 35: "skis", - 36: "snowboard", - 37: "sports ball", - 38: "kite", - 39: "baseball bat", - 40: "baseball glove", - 41: "skateboard", - 42: "surfboard", - 43: "tennis racket", - 44: "bottle", - 45: "N/A", - 46: "wine glass", - 47: "cup", - 48: "fork", - 49: "knife", - 50: "spoon", - 51: "bowl", - 52: "banana", - 53: "apple", - 54: "sandwich", - 55: "orange", - 56: "broccoli", - 57: "carrot", - 58: "hot dog", - 59: "pizza", - 60: "donut", - 61: "cake", - 62: "chair", - 63: "couch", - 64: "potted plant", - 65: "bed", - 66: "N/A", - 67: "dining table", - 68: "N/A", - 69: "N/A", - 70: "toilet", - 71: "N/A", - 72: "tv", - 73: "laptop", - 74: "mouse", - 75: "remote", - 76: "keyboard", - 77: "cell phone", - 78: "microwave", - 79: "oven", - 80: "toaster", - 81: "sink", - 82: "refrigerator", - 83: "N/A", - 84: "book", - 85: "clock", - 86: "vase", - 87: "scissors", - 88: "teddy bear", - 89: "hair drier", - 90: "toothbrush", -} diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index e15c454f4de..53346f1f1e8 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -588,6 +588,41 @@ class PretrainedBartModel: requires_backends(cls, ["torch"]) +BEIT_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class BeitForImageClassification: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class BeitForMaskedImageModeling: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class BeitModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class BeitPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + BERT_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index b03bc232538..6868ae65a73 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -7,6 +7,11 @@ class ImageFeatureExtractionMixin: requires_backends(self, ["vision"]) +class BeitFeatureExtractor: + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + class CLIPFeatureExtractor: def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) diff --git a/src/transformers/utils/imagenet_classes.py b/src/transformers/utils/imagenet_classes.py deleted file mode 100644 index 73d831095c5..00000000000 --- a/src/transformers/utils/imagenet_classes.py +++ /dev/null @@ -1,1003 +0,0 @@ -# ImageNet 2012 id's to class names -id2label = { - 0: "tench, Tinca tinca", - 1: "goldfish, Carassius auratus", - 2: "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias", - 3: "tiger shark, Galeocerdo cuvieri", - 4: "hammerhead, hammerhead shark", - 5: "electric ray, crampfish, numbfish, torpedo", - 6: "stingray", - 7: "cock", - 8: "hen", - 9: "ostrich, Struthio camelus", - 10: "brambling, Fringilla montifringilla", - 11: "goldfinch, Carduelis carduelis", - 12: "house finch, linnet, Carpodacus mexicanus", - 13: "junco, snowbird", - 14: "indigo bunting, indigo finch, indigo bird, Passerina cyanea", - 15: "robin, American robin, Turdus migratorius", - 16: "bulbul", - 17: "jay", - 18: "magpie", - 19: "chickadee", - 20: "water ouzel, dipper", - 21: "kite", - 22: "bald eagle, American eagle, Haliaeetus leucocephalus", - 23: "vulture", - 24: "great grey owl, great gray owl, Strix nebulosa", - 25: "European fire salamander, Salamandra salamandra", - 26: "common newt, Triturus vulgaris", - 27: "eft", - 28: "spotted salamander, Ambystoma maculatum", - 29: "axolotl, mud puppy, Ambystoma mexicanum", - 30: "bullfrog, Rana catesbeiana", - 31: "tree frog, tree-frog", - 32: "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui", - 33: "loggerhead, loggerhead turtle, Caretta caretta", - 34: "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea", - 35: "mud turtle", - 36: "terrapin", - 37: "box turtle, box tortoise", - 38: "banded gecko", - 39: "common iguana, iguana, Iguana iguana", - 40: "American chameleon, anole, Anolis carolinensis", - 41: "whiptail, whiptail lizard", - 42: "agama", - 43: "frilled lizard, Chlamydosaurus kingi", - 44: "alligator lizard", - 45: "Gila monster, Heloderma suspectum", - 46: "green lizard, Lacerta viridis", - 47: "African chameleon, Chamaeleo chamaeleon", - 48: "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis", - 49: "African crocodile, Nile crocodile, Crocodylus niloticus", - 50: "American alligator, Alligator mississipiensis", - 51: "triceratops", - 52: "thunder snake, worm snake, Carphophis amoenus", - 53: "ringneck snake, ring-necked snake, ring snake", - 54: "hognose snake, puff adder, sand viper", - 55: "green snake, grass snake", - 56: "king snake, kingsnake", - 57: "garter snake, grass snake", - 58: "water snake", - 59: "vine snake", - 60: "night snake, Hypsiglena torquata", - 61: "boa constrictor, Constrictor constrictor", - 62: "rock python, rock snake, Python sebae", - 63: "Indian cobra, Naja naja", - 64: "green mamba", - 65: "sea snake", - 66: "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus", - 67: "diamondback, diamondback rattlesnake, Crotalus adamanteus", - 68: "sidewinder, horned rattlesnake, Crotalus cerastes", - 69: "trilobite", - 70: "harvestman, daddy longlegs, Phalangium opilio", - 71: "scorpion", - 72: "black and gold garden spider, Argiope aurantia", - 73: "barn spider, Araneus cavaticus", - 74: "garden spider, Aranea diademata", - 75: "black widow, Latrodectus mactans", - 76: "tarantula", - 77: "wolf spider, hunting spider", - 78: "tick", - 79: "centipede", - 80: "black grouse", - 81: "ptarmigan", - 82: "ruffed grouse, partridge, Bonasa umbellus", - 83: "prairie chicken, prairie grouse, prairie fowl", - 84: "peacock", - 85: "quail", - 86: "partridge", - 87: "African grey, African gray, Psittacus erithacus", - 88: "macaw", - 89: "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita", - 90: "lorikeet", - 91: "coucal", - 92: "bee eater", - 93: "hornbill", - 94: "hummingbird", - 95: "jacamar", - 96: "toucan", - 97: "drake", - 98: "red-breasted merganser, Mergus serrator", - 99: "goose", - 100: "black swan, Cygnus atratus", - 101: "tusker", - 102: "echidna, spiny anteater, anteater", - 103: "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus", - 104: "wallaby, brush kangaroo", - 105: "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus", - 106: "wombat", - 107: "jellyfish", - 108: "sea anemone, anemone", - 109: "brain coral", - 110: "flatworm, platyhelminth", - 111: "nematode, nematode worm, roundworm", - 112: "conch", - 113: "snail", - 114: "slug", - 115: "sea slug, nudibranch", - 116: "chiton, coat-of-mail shell, sea cradle, polyplacophore", - 117: "chambered nautilus, pearly nautilus, nautilus", - 118: "Dungeness crab, Cancer magister", - 119: "rock crab, Cancer irroratus", - 120: "fiddler crab", - 121: "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica", - 122: "American lobster, Northern lobster, Maine lobster, Homarus americanus", - 123: "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish", - 124: "crayfish, crawfish, crawdad, crawdaddy", - 125: "hermit crab", - 126: "isopod", - 127: "white stork, Ciconia ciconia", - 128: "black stork, Ciconia nigra", - 129: "spoonbill", - 130: "flamingo", - 131: "little blue heron, Egretta caerulea", - 132: "American egret, great white heron, Egretta albus", - 133: "bittern", - 134: "crane", - 135: "limpkin, Aramus pictus", - 136: "European gallinule, Porphyrio porphyrio", - 137: "American coot, marsh hen, mud hen, water hen, Fulica americana", - 138: "bustard", - 139: "ruddy turnstone, Arenaria interpres", - 140: "red-backed sandpiper, dunlin, Erolia alpina", - 141: "redshank, Tringa totanus", - 142: "dowitcher", - 143: "oystercatcher, oyster catcher", - 144: "pelican", - 145: "king penguin, Aptenodytes patagonica", - 146: "albatross, mollymawk", - 147: "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus", - 148: "killer whale, killer, orca, grampus, sea wolf, Orcinus orca", - 149: "dugong, Dugong dugon", - 150: "sea lion", - 151: "Chihuahua", - 152: "Japanese spaniel", - 153: "Maltese dog, Maltese terrier, Maltese", - 154: "Pekinese, Pekingese, Peke", - 155: "Shih-Tzu", - 156: "Blenheim spaniel", - 157: "papillon", - 158: "toy terrier", - 159: "Rhodesian ridgeback", - 160: "Afghan hound, Afghan", - 161: "basset, basset hound", - 162: "beagle", - 163: "bloodhound, sleuthhound", - 164: "bluetick", - 165: "black-and-tan coonhound", - 166: "Walker hound, Walker foxhound", - 167: "English foxhound", - 168: "redbone", - 169: "borzoi, Russian wolfhound", - 170: "Irish wolfhound", - 171: "Italian greyhound", - 172: "whippet", - 173: "Ibizan hound, Ibizan Podenco", - 174: "Norwegian elkhound, elkhound", - 175: "otterhound, otter hound", - 176: "Saluki, gazelle hound", - 177: "Scottish deerhound, deerhound", - 178: "Weimaraner", - 179: "Staffordshire bullterrier, Staffordshire bull terrier", - 180: "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier", - 181: "Bedlington terrier", - 182: "Border terrier", - 183: "Kerry blue terrier", - 184: "Irish terrier", - 185: "Norfolk terrier", - 186: "Norwich terrier", - 187: "Yorkshire terrier", - 188: "wire-haired fox terrier", - 189: "Lakeland terrier", - 190: "Sealyham terrier, Sealyham", - 191: "Airedale, Airedale terrier", - 192: "cairn, cairn terrier", - 193: "Australian terrier", - 194: "Dandie Dinmont, Dandie Dinmont terrier", - 195: "Boston bull, Boston terrier", - 196: "miniature schnauzer", - 197: "giant schnauzer", - 198: "standard schnauzer", - 199: "Scotch terrier, Scottish terrier, Scottie", - 200: "Tibetan terrier, chrysanthemum dog", - 201: "silky terrier, Sydney silky", - 202: "soft-coated wheaten terrier", - 203: "West Highland white terrier", - 204: "Lhasa, Lhasa apso", - 205: "flat-coated retriever", - 206: "curly-coated retriever", - 207: "golden retriever", - 208: "Labrador retriever", - 209: "Chesapeake Bay retriever", - 210: "German short-haired pointer", - 211: "vizsla, Hungarian pointer", - 212: "English setter", - 213: "Irish setter, red setter", - 214: "Gordon setter", - 215: "Brittany spaniel", - 216: "clumber, clumber spaniel", - 217: "English springer, English springer spaniel", - 218: "Welsh springer spaniel", - 219: "cocker spaniel, English cocker spaniel, cocker", - 220: "Sussex spaniel", - 221: "Irish water spaniel", - 222: "kuvasz", - 223: "schipperke", - 224: "groenendael", - 225: "malinois", - 226: "briard", - 227: "kelpie", - 228: "komondor", - 229: "Old English sheepdog, bobtail", - 230: "Shetland sheepdog, Shetland sheep dog, Shetland", - 231: "collie", - 232: "Border collie", - 233: "Bouvier des Flandres, Bouviers des Flandres", - 234: "Rottweiler", - 235: "German shepherd, German shepherd dog, German police dog, alsatian", - 236: "Doberman, Doberman pinscher", - 237: "miniature pinscher", - 238: "Greater Swiss Mountain dog", - 239: "Bernese mountain dog", - 240: "Appenzeller", - 241: "EntleBucher", - 242: "boxer", - 243: "bull mastiff", - 244: "Tibetan mastiff", - 245: "French bulldog", - 246: "Great Dane", - 247: "Saint Bernard, St Bernard", - 248: "Eskimo dog, husky", - 249: "malamute, malemute, Alaskan malamute", - 250: "Siberian husky", - 251: "dalmatian, coach dog, carriage dog", - 252: "affenpinscher, monkey pinscher, monkey dog", - 253: "basenji", - 254: "pug, pug-dog", - 255: "Leonberg", - 256: "Newfoundland, Newfoundland dog", - 257: "Great Pyrenees", - 258: "Samoyed, Samoyede", - 259: "Pomeranian", - 260: "chow, chow chow", - 261: "keeshond", - 262: "Brabancon griffon", - 263: "Pembroke, Pembroke Welsh corgi", - 264: "Cardigan, Cardigan Welsh corgi", - 265: "toy poodle", - 266: "miniature poodle", - 267: "standard poodle", - 268: "Mexican hairless", - 269: "timber wolf, grey wolf, gray wolf, Canis lupus", - 270: "white wolf, Arctic wolf, Canis lupus tundrarum", - 271: "red wolf, maned wolf, Canis rufus, Canis niger", - 272: "coyote, prairie wolf, brush wolf, Canis latrans", - 273: "dingo, warrigal, warragal, Canis dingo", - 274: "dhole, Cuon alpinus", - 275: "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus", - 276: "hyena, hyaena", - 277: "red fox, Vulpes vulpes", - 278: "kit fox, Vulpes macrotis", - 279: "Arctic fox, white fox, Alopex lagopus", - 280: "grey fox, gray fox, Urocyon cinereoargenteus", - 281: "tabby, tabby cat", - 282: "tiger cat", - 283: "Persian cat", - 284: "Siamese cat, Siamese", - 285: "Egyptian cat", - 286: "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor", - 287: "lynx, catamount", - 288: "leopard, Panthera pardus", - 289: "snow leopard, ounce, Panthera uncia", - 290: "jaguar, panther, Panthera onca, Felis onca", - 291: "lion, king of beasts, Panthera leo", - 292: "tiger, Panthera tigris", - 293: "cheetah, chetah, Acinonyx jubatus", - 294: "brown bear, bruin, Ursus arctos", - 295: "American black bear, black bear, Ursus americanus, Euarctos americanus", - 296: "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus", - 297: "sloth bear, Melursus ursinus, Ursus ursinus", - 298: "mongoose", - 299: "meerkat, mierkat", - 300: "tiger beetle", - 301: "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle", - 302: "ground beetle, carabid beetle", - 303: "long-horned beetle, longicorn, longicorn beetle", - 304: "leaf beetle, chrysomelid", - 305: "dung beetle", - 306: "rhinoceros beetle", - 307: "weevil", - 308: "fly", - 309: "bee", - 310: "ant, emmet, pismire", - 311: "grasshopper, hopper", - 312: "cricket", - 313: "walking stick, walkingstick, stick insect", - 314: "cockroach, roach", - 315: "mantis, mantid", - 316: "cicada, cicala", - 317: "leafhopper", - 318: "lacewing, lacewing fly", - 319: "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", - 320: "damselfly", - 321: "admiral", - 322: "ringlet, ringlet butterfly", - 323: "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus", - 324: "cabbage butterfly", - 325: "sulphur butterfly, sulfur butterfly", - 326: "lycaenid, lycaenid butterfly", - 327: "starfish, sea star", - 328: "sea urchin", - 329: "sea cucumber, holothurian", - 330: "wood rabbit, cottontail, cottontail rabbit", - 331: "hare", - 332: "Angora, Angora rabbit", - 333: "hamster", - 334: "porcupine, hedgehog", - 335: "fox squirrel, eastern fox squirrel, Sciurus niger", - 336: "marmot", - 337: "beaver", - 338: "guinea pig, Cavia cobaya", - 339: "sorrel", - 340: "zebra", - 341: "hog, pig, grunter, squealer, Sus scrofa", - 342: "wild boar, boar, Sus scrofa", - 343: "warthog", - 344: "hippopotamus, hippo, river horse, Hippopotamus amphibius", - 345: "ox", - 346: "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis", - 347: "bison", - 348: "ram, tup", - 349: "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis", - 350: "ibex, Capra ibex", - 351: "hartebeest", - 352: "impala, Aepyceros melampus", - 353: "gazelle", - 354: "Arabian camel, dromedary, Camelus dromedarius", - 355: "llama", - 356: "weasel", - 357: "mink", - 358: "polecat, fitch, foulmart, foumart, Mustela putorius", - 359: "black-footed ferret, ferret, Mustela nigripes", - 360: "otter", - 361: "skunk, polecat, wood pussy", - 362: "badger", - 363: "armadillo", - 364: "three-toed sloth, ai, Bradypus tridactylus", - 365: "orangutan, orang, orangutang, Pongo pygmaeus", - 366: "gorilla, Gorilla gorilla", - 367: "chimpanzee, chimp, Pan troglodytes", - 368: "gibbon, Hylobates lar", - 369: "siamang, Hylobates syndactylus, Symphalangus syndactylus", - 370: "guenon, guenon monkey", - 371: "patas, hussar monkey, Erythrocebus patas", - 372: "baboon", - 373: "macaque", - 374: "langur", - 375: "colobus, colobus monkey", - 376: "proboscis monkey, Nasalis larvatus", - 377: "marmoset", - 378: "capuchin, ringtail, Cebus capucinus", - 379: "howler monkey, howler", - 380: "titi, titi monkey", - 381: "spider monkey, Ateles geoffroyi", - 382: "squirrel monkey, Saimiri sciureus", - 383: "Madagascar cat, ring-tailed lemur, Lemur catta", - 384: "indri, indris, Indri indri, Indri brevicaudatus", - 385: "Indian elephant, Elephas maximus", - 386: "African elephant, Loxodonta africana", - 387: "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens", - 388: "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca", - 389: "barracouta, snoek", - 390: "eel", - 391: "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch", - 392: "rock beauty, Holocanthus tricolor", - 393: "anemone fish", - 394: "sturgeon", - 395: "gar, garfish, garpike, billfish, Lepisosteus osseus", - 396: "lionfish", - 397: "puffer, pufferfish, blowfish, globefish", - 398: "abacus", - 399: "abaya", - 400: "academic gown, academic robe, judge's robe", - 401: "accordion, piano accordion, squeeze box", - 402: "acoustic guitar", - 403: "aircraft carrier, carrier, flattop, attack aircraft carrier", - 404: "airliner", - 405: "airship, dirigible", - 406: "altar", - 407: "ambulance", - 408: "amphibian, amphibious vehicle", - 409: "analog clock", - 410: "apiary, bee house", - 411: "apron", - 412: "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin", - 413: "assault rifle, assault gun", - 414: "backpack, back pack, knapsack, packsack, rucksack, haversack", - 415: "bakery, bakeshop, bakehouse", - 416: "balance beam, beam", - 417: "balloon", - 418: "ballpoint, ballpoint pen, ballpen, Biro", - 419: "Band Aid", - 420: "banjo", - 421: "bannister, banister, balustrade, balusters, handrail", - 422: "barbell", - 423: "barber chair", - 424: "barbershop", - 425: "barn", - 426: "barometer", - 427: "barrel, cask", - 428: "barrow, garden cart, lawn cart, wheelbarrow", - 429: "baseball", - 430: "basketball", - 431: "bassinet", - 432: "bassoon", - 433: "bathing cap, swimming cap", - 434: "bath towel", - 435: "bathtub, bathing tub, bath, tub", - 436: "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon", - 437: "beacon, lighthouse, beacon light, pharos", - 438: "beaker", - 439: "bearskin, busby, shako", - 440: "beer bottle", - 441: "beer glass", - 442: "bell cote, bell cot", - 443: "bib", - 444: "bicycle-built-for-two, tandem bicycle, tandem", - 445: "bikini, two-piece", - 446: "binder, ring-binder", - 447: "binoculars, field glasses, opera glasses", - 448: "birdhouse", - 449: "boathouse", - 450: "bobsled, bobsleigh, bob", - 451: "bolo tie, bolo, bola tie, bola", - 452: "bonnet, poke bonnet", - 453: "bookcase", - 454: "bookshop, bookstore, bookstall", - 455: "bottlecap", - 456: "bow", - 457: "bow tie, bow-tie, bowtie", - 458: "brass, memorial tablet, plaque", - 459: "brassiere, bra, bandeau", - 460: "breakwater, groin, groyne, mole, bulwark, seawall, jetty", - 461: "breastplate, aegis, egis", - 462: "broom", - 463: "bucket, pail", - 464: "buckle", - 465: "bulletproof vest", - 466: "bullet train, bullet", - 467: "butcher shop, meat market", - 468: "cab, hack, taxi, taxicab", - 469: "caldron, cauldron", - 470: "candle, taper, wax light", - 471: "cannon", - 472: "canoe", - 473: "can opener, tin opener", - 474: "cardigan", - 475: "car mirror", - 476: "carousel, carrousel, merry-go-round, roundabout, whirligig", - 477: "carpenter's kit, tool kit", - 478: "carton", - 479: "car wheel", - 480: "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM", - 481: "cassette", - 482: "cassette player", - 483: "castle", - 484: "catamaran", - 485: "CD player", - 486: "cello, violoncello", - 487: "cellular telephone, cellular phone, cellphone, cell, mobile phone", - 488: "chain", - 489: "chainlink fence", - 490: "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour", - 491: "chain saw, chainsaw", - 492: "chest", - 493: "chiffonier, commode", - 494: "chime, bell, gong", - 495: "china cabinet, china closet", - 496: "Christmas stocking", - 497: "church, church building", - 498: "cinema, movie theater, movie theatre, movie house, picture palace", - 499: "cleaver, meat cleaver, chopper", - 500: "cliff dwelling", - 501: "cloak", - 502: "clog, geta, patten, sabot", - 503: "cocktail shaker", - 504: "coffee mug", - 505: "coffeepot", - 506: "coil, spiral, volute, whorl, helix", - 507: "combination lock", - 508: "computer keyboard, keypad", - 509: "confectionery, confectionary, candy store", - 510: "container ship, containership, container vessel", - 511: "convertible", - 512: "corkscrew, bottle screw", - 513: "cornet, horn, trumpet, trump", - 514: "cowboy boot", - 515: "cowboy hat, ten-gallon hat", - 516: "cradle", - 517: "crane", - 518: "crash helmet", - 519: "crate", - 520: "crib, cot", - 521: "Crock Pot", - 522: "croquet ball", - 523: "crutch", - 524: "cuirass", - 525: "dam, dike, dyke", - 526: "desk", - 527: "desktop computer", - 528: "dial telephone, dial phone", - 529: "diaper, nappy, napkin", - 530: "digital clock", - 531: "digital watch", - 532: "dining table, board", - 533: "dishrag, dishcloth", - 534: "dishwasher, dish washer, dishwashing machine", - 535: "disk brake, disc brake", - 536: "dock, dockage, docking facility", - 537: "dogsled, dog sled, dog sleigh", - 538: "dome", - 539: "doormat, welcome mat", - 540: "drilling platform, offshore rig", - 541: "drum, membranophone, tympan", - 542: "drumstick", - 543: "dumbbell", - 544: "Dutch oven", - 545: "electric fan, blower", - 546: "electric guitar", - 547: "electric locomotive", - 548: "entertainment center", - 549: "envelope", - 550: "espresso maker", - 551: "face powder", - 552: "feather boa, boa", - 553: "file, file cabinet, filing cabinet", - 554: "fireboat", - 555: "fire engine, fire truck", - 556: "fire screen, fireguard", - 557: "flagpole, flagstaff", - 558: "flute, transverse flute", - 559: "folding chair", - 560: "football helmet", - 561: "forklift", - 562: "fountain", - 563: "fountain pen", - 564: "four-poster", - 565: "freight car", - 566: "French horn, horn", - 567: "frying pan, frypan, skillet", - 568: "fur coat", - 569: "garbage truck, dustcart", - 570: "gasmask, respirator, gas helmet", - 571: "gas pump, gasoline pump, petrol pump, island dispenser", - 572: "goblet", - 573: "go-kart", - 574: "golf ball", - 575: "golfcart, golf cart", - 576: "gondola", - 577: "gong, tam-tam", - 578: "gown", - 579: "grand piano, grand", - 580: "greenhouse, nursery, glasshouse", - 581: "grille, radiator grille", - 582: "grocery store, grocery, food market, market", - 583: "guillotine", - 584: "hair slide", - 585: "hair spray", - 586: "half track", - 587: "hammer", - 588: "hamper", - 589: "hand blower, blow dryer, blow drier, hair dryer, hair drier", - 590: "hand-held computer, hand-held microcomputer", - 591: "handkerchief, hankie, hanky, hankey", - 592: "hard disc, hard disk, fixed disk", - 593: "harmonica, mouth organ, harp, mouth harp", - 594: "harp", - 595: "harvester, reaper", - 596: "hatchet", - 597: "holster", - 598: "home theater, home theatre", - 599: "honeycomb", - 600: "hook, claw", - 601: "hoopskirt, crinoline", - 602: "horizontal bar, high bar", - 603: "horse cart, horse-cart", - 604: "hourglass", - 605: "iPod", - 606: "iron, smoothing iron", - 607: "jack-o'-lantern", - 608: "jean, blue jean, denim", - 609: "jeep, landrover", - 610: "jersey, T-shirt, tee shirt", - 611: "jigsaw puzzle", - 612: "jinrikisha, ricksha, rickshaw", - 613: "joystick", - 614: "kimono", - 615: "knee pad", - 616: "knot", - 617: "lab coat, laboratory coat", - 618: "ladle", - 619: "lampshade, lamp shade", - 620: "laptop, laptop computer", - 621: "lawn mower, mower", - 622: "lens cap, lens cover", - 623: "letter opener, paper knife, paperknife", - 624: "library", - 625: "lifeboat", - 626: "lighter, light, igniter, ignitor", - 627: "limousine, limo", - 628: "liner, ocean liner", - 629: "lipstick, lip rouge", - 630: "Loafer", - 631: "lotion", - 632: "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system", - 633: "loupe, jeweler's loupe", - 634: "lumbermill, sawmill", - 635: "magnetic compass", - 636: "mailbag, postbag", - 637: "mailbox, letter box", - 638: "maillot", - 639: "maillot, tank suit", - 640: "manhole cover", - 641: "maraca", - 642: "marimba, xylophone", - 643: "mask", - 644: "matchstick", - 645: "maypole", - 646: "maze, labyrinth", - 647: "measuring cup", - 648: "medicine chest, medicine cabinet", - 649: "megalith, megalithic structure", - 650: "microphone, mike", - 651: "microwave, microwave oven", - 652: "military uniform", - 653: "milk can", - 654: "minibus", - 655: "miniskirt, mini", - 656: "minivan", - 657: "missile", - 658: "mitten", - 659: "mixing bowl", - 660: "mobile home, manufactured home", - 661: "Model T", - 662: "modem", - 663: "monastery", - 664: "monitor", - 665: "moped", - 666: "mortar", - 667: "mortarboard", - 668: "mosque", - 669: "mosquito net", - 670: "motor scooter, scooter", - 671: "mountain bike, all-terrain bike, off-roader", - 672: "mountain tent", - 673: "mouse, computer mouse", - 674: "mousetrap", - 675: "moving van", - 676: "muzzle", - 677: "nail", - 678: "neck brace", - 679: "necklace", - 680: "nipple", - 681: "notebook, notebook computer", - 682: "obelisk", - 683: "oboe, hautboy, hautbois", - 684: "ocarina, sweet potato", - 685: "odometer, hodometer, mileometer, milometer", - 686: "oil filter", - 687: "organ, pipe organ", - 688: "oscilloscope, scope, cathode-ray oscilloscope, CRO", - 689: "overskirt", - 690: "oxcart", - 691: "oxygen mask", - 692: "packet", - 693: "paddle, boat paddle", - 694: "paddlewheel, paddle wheel", - 695: "padlock", - 696: "paintbrush", - 697: "pajama, pyjama, pj's, jammies", - 698: "palace", - 699: "panpipe, pandean pipe, syrinx", - 700: "paper towel", - 701: "parachute, chute", - 702: "parallel bars, bars", - 703: "park bench", - 704: "parking meter", - 705: "passenger car, coach, carriage", - 706: "patio, terrace", - 707: "pay-phone, pay-station", - 708: "pedestal, plinth, footstall", - 709: "pencil box, pencil case", - 710: "pencil sharpener", - 711: "perfume, essence", - 712: "Petri dish", - 713: "photocopier", - 714: "pick, plectrum, plectron", - 715: "pickelhaube", - 716: "picket fence, paling", - 717: "pickup, pickup truck", - 718: "pier", - 719: "piggy bank, penny bank", - 720: "pill bottle", - 721: "pillow", - 722: "ping-pong ball", - 723: "pinwheel", - 724: "pirate, pirate ship", - 725: "pitcher, ewer", - 726: "plane, carpenter's plane, woodworking plane", - 727: "planetarium", - 728: "plastic bag", - 729: "plate rack", - 730: "plow, plough", - 731: "plunger, plumber's helper", - 732: "Polaroid camera, Polaroid Land camera", - 733: "pole", - 734: "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria", - 735: "poncho", - 736: "pool table, billiard table, snooker table", - 737: "pop bottle, soda bottle", - 738: "pot, flowerpot", - 739: "potter's wheel", - 740: "power drill", - 741: "prayer rug, prayer mat", - 742: "printer", - 743: "prison, prison house", - 744: "projectile, missile", - 745: "projector", - 746: "puck, hockey puck", - 747: "punching bag, punch bag, punching ball, punchball", - 748: "purse", - 749: "quill, quill pen", - 750: "quilt, comforter, comfort, puff", - 751: "racer, race car, racing car", - 752: "racket, racquet", - 753: "radiator", - 754: "radio, wireless", - 755: "radio telescope, radio reflector", - 756: "rain barrel", - 757: "recreational vehicle, RV, R.V.", - 758: "reel", - 759: "reflex camera", - 760: "refrigerator, icebox", - 761: "remote control, remote", - 762: "restaurant, eating house, eating place, eatery", - 763: "revolver, six-gun, six-shooter", - 764: "rifle", - 765: "rocking chair, rocker", - 766: "rotisserie", - 767: "rubber eraser, rubber, pencil eraser", - 768: "rugby ball", - 769: "rule, ruler", - 770: "running shoe", - 771: "safe", - 772: "safety pin", - 773: "saltshaker, salt shaker", - 774: "sandal", - 775: "sarong", - 776: "sax, saxophone", - 777: "scabbard", - 778: "scale, weighing machine", - 779: "school bus", - 780: "schooner", - 781: "scoreboard", - 782: "screen, CRT screen", - 783: "screw", - 784: "screwdriver", - 785: "seat belt, seatbelt", - 786: "sewing machine", - 787: "shield, buckler", - 788: "shoe shop, shoe-shop, shoe store", - 789: "shoji", - 790: "shopping basket", - 791: "shopping cart", - 792: "shovel", - 793: "shower cap", - 794: "shower curtain", - 795: "ski", - 796: "ski mask", - 797: "sleeping bag", - 798: "slide rule, slipstick", - 799: "sliding door", - 800: "slot, one-armed bandit", - 801: "snorkel", - 802: "snowmobile", - 803: "snowplow, snowplough", - 804: "soap dispenser", - 805: "soccer ball", - 806: "sock", - 807: "solar dish, solar collector, solar furnace", - 808: "sombrero", - 809: "soup bowl", - 810: "space bar", - 811: "space heater", - 812: "space shuttle", - 813: "spatula", - 814: "speedboat", - 815: "spider web, spider's web", - 816: "spindle", - 817: "sports car, sport car", - 818: "spotlight, spot", - 819: "stage", - 820: "steam locomotive", - 821: "steel arch bridge", - 822: "steel drum", - 823: "stethoscope", - 824: "stole", - 825: "stone wall", - 826: "stopwatch, stop watch", - 827: "stove", - 828: "strainer", - 829: "streetcar, tram, tramcar, trolley, trolley car", - 830: "stretcher", - 831: "studio couch, day bed", - 832: "stupa, tope", - 833: "submarine, pigboat, sub, U-boat", - 834: "suit, suit of clothes", - 835: "sundial", - 836: "sunglass", - 837: "sunglasses, dark glasses, shades", - 838: "sunscreen, sunblock, sun blocker", - 839: "suspension bridge", - 840: "swab, swob, mop", - 841: "sweatshirt", - 842: "swimming trunks, bathing trunks", - 843: "swing", - 844: "switch, electric switch, electrical switch", - 845: "syringe", - 846: "table lamp", - 847: "tank, army tank, armored combat vehicle, armoured combat vehicle", - 848: "tape player", - 849: "teapot", - 850: "teddy, teddy bear", - 851: "television, television system", - 852: "tennis ball", - 853: "thatch, thatched roof", - 854: "theater curtain, theatre curtain", - 855: "thimble", - 856: "thresher, thrasher, threshing machine", - 857: "throne", - 858: "tile roof", - 859: "toaster", - 860: "tobacco shop, tobacconist shop, tobacconist", - 861: "toilet seat", - 862: "torch", - 863: "totem pole", - 864: "tow truck, tow car, wrecker", - 865: "toyshop", - 866: "tractor", - 867: "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi", - 868: "tray", - 869: "trench coat", - 870: "tricycle, trike, velocipede", - 871: "trimaran", - 872: "tripod", - 873: "triumphal arch", - 874: "trolleybus, trolley coach, trackless trolley", - 875: "trombone", - 876: "tub, vat", - 877: "turnstile", - 878: "typewriter keyboard", - 879: "umbrella", - 880: "unicycle, monocycle", - 881: "upright, upright piano", - 882: "vacuum, vacuum cleaner", - 883: "vase", - 884: "vault", - 885: "velvet", - 886: "vending machine", - 887: "vestment", - 888: "viaduct", - 889: "violin, fiddle", - 890: "volleyball", - 891: "waffle iron", - 892: "wall clock", - 893: "wallet, billfold, notecase, pocketbook", - 894: "wardrobe, closet, press", - 895: "warplane, military plane", - 896: "washbasin, handbasin, washbowl, lavabo, wash-hand basin", - 897: "washer, automatic washer, washing machine", - 898: "water bottle", - 899: "water jug", - 900: "water tower", - 901: "whiskey jug", - 902: "whistle", - 903: "wig", - 904: "window screen", - 905: "window shade", - 906: "Windsor tie", - 907: "wine bottle", - 908: "wing", - 909: "wok", - 910: "wooden spoon", - 911: "wool, woolen, woollen", - 912: "worm fence, snake fence, snake-rail fence, Virginia fence", - 913: "wreck", - 914: "yawl", - 915: "yurt", - 916: "web site, website, internet site, site", - 917: "comic book", - 918: "crossword puzzle, crossword", - 919: "street sign", - 920: "traffic light, traffic signal, stoplight", - 921: "book jacket, dust cover, dust jacket, dust wrapper", - 922: "menu", - 923: "plate", - 924: "guacamole", - 925: "consomme", - 926: "hot pot, hotpot", - 927: "trifle", - 928: "ice cream, icecream", - 929: "ice lolly, lolly, lollipop, popsicle", - 930: "French loaf", - 931: "bagel, beigel", - 932: "pretzel", - 933: "cheeseburger", - 934: "hotdog, hot dog, red hot", - 935: "mashed potato", - 936: "head cabbage", - 937: "broccoli", - 938: "cauliflower", - 939: "zucchini, courgette", - 940: "spaghetti squash", - 941: "acorn squash", - 942: "butternut squash", - 943: "cucumber, cuke", - 944: "artichoke, globe artichoke", - 945: "bell pepper", - 946: "cardoon", - 947: "mushroom", - 948: "Granny Smith", - 949: "strawberry", - 950: "orange", - 951: "lemon", - 952: "fig", - 953: "pineapple, ananas", - 954: "banana", - 955: "jackfruit, jak, jack", - 956: "custard apple", - 957: "pomegranate", - 958: "hay", - 959: "carbonara", - 960: "chocolate sauce, chocolate syrup", - 961: "dough", - 962: "meat loaf, meatloaf", - 963: "pizza, pizza pie", - 964: "potpie", - 965: "burrito", - 966: "red wine", - 967: "espresso", - 968: "cup", - 969: "eggnog", - 970: "alp", - 971: "bubble", - 972: "cliff, drop, drop-off", - 973: "coral reef", - 974: "geyser", - 975: "lakeside, lakeshore", - 976: "promontory, headland, head, foreland", - 977: "sandbar, sand bar", - 978: "seashore, coast, seacoast, sea-coast", - 979: "valley, vale", - 980: "volcano", - 981: "ballplayer, baseball player", - 982: "groom, bridegroom", - 983: "scuba diver", - 984: "rapeseed", - 985: "daisy", - 986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", - 987: "corn", - 988: "acorn", - 989: "hip, rose hip, rosehip", - 990: "buckeye, horse chestnut, conker", - 991: "coral fungus", - 992: "agaric", - 993: "gyromitra", - 994: "stinkhorn, carrion fungus", - 995: "earthstar", - 996: "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa", - 997: "bolete", - 998: "ear, spike, capitulum", - 999: "toilet tissue, toilet paper, bathroom tissue", -} diff --git a/src/transformers/utils/modeling_auto_mapping.py b/src/transformers/utils/modeling_auto_mapping.py index 562124bdf5d..309ba38449e 100644 --- a/src/transformers/utils/modeling_auto_mapping.py +++ b/src/transformers/utils/modeling_auto_mapping.py @@ -76,6 +76,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ ("ViTConfig", "ViTForImageClassification"), ("DeiTConfig", "('DeiTForImageClassification', 'DeiTForImageClassificationWithTeacher')"), + ("BeitConfig", "BeitForImageClassification"), ] ) @@ -261,6 +262,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( MODEL_MAPPING_NAMES = OrderedDict( [ + ("BeitConfig", "BeitModel"), ("RemBertConfig", "RemBertModel"), ("VisualBertConfig", "VisualBertModel"), ("CanineConfig", "CanineModel"), diff --git a/tests/test_feature_extraction_beit.py b/tests/test_feature_extraction_beit.py new file mode 100644 index 00000000000..8ced1580a2b --- /dev/null +++ b/tests/test_feature_extraction_beit.py @@ -0,0 +1,199 @@ +# coding=utf-8 +# Copyright 2021 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import numpy as np + +from transformers.file_utils import is_torch_available, is_vision_available +from transformers.testing_utils import require_torch, require_vision + +from .test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs + + +if is_torch_available(): + import torch + +if is_vision_available(): + from PIL import Image + + from transformers import BeitFeatureExtractor + + +class BeitFeatureExtractionTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + image_size=18, + min_resolution=30, + max_resolution=400, + do_resize=True, + size=20, + do_center_crop=True, + crop_size=18, + do_normalize=True, + image_mean=[0.5, 0.5, 0.5], + image_std=[0.5, 0.5, 0.5], + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.do_resize = do_resize + self.size = size + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + + def prepare_feat_extract_dict(self): + return { + "do_resize": self.do_resize, + "size": self.size, + "do_center_crop": self.do_center_crop, + "crop_size": self.crop_size, + "do_normalize": self.do_normalize, + "image_mean": self.image_mean, + "image_std": self.image_std, + } + + +@require_torch +@require_vision +class BeitFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase): + + feature_extraction_class = BeitFeatureExtractor if is_vision_available() else None + + def setUp(self): + self.feature_extract_tester = BeitFeatureExtractionTester(self) + + @property + def feat_extract_dict(self): + return self.feature_extract_tester.prepare_feat_extract_dict() + + def test_feat_extract_properties(self): + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + self.assertTrue(hasattr(feature_extractor, "do_resize")) + self.assertTrue(hasattr(feature_extractor, "size")) + self.assertTrue(hasattr(feature_extractor, "do_center_crop")) + self.assertTrue(hasattr(feature_extractor, "center_crop")) + self.assertTrue(hasattr(feature_extractor, "do_normalize")) + self.assertTrue(hasattr(feature_extractor, "image_mean")) + self.assertTrue(hasattr(feature_extractor, "image_std")) + + def test_batch_feature(self): + pass + + def test_call_pil(self): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + # create random PIL images + image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False) + for image in image_inputs: + self.assertIsInstance(image, Image.Image) + + # Test not batched input + encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + 1, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.crop_size, + self.feature_extract_tester.crop_size, + ), + ) + + # Test batched + encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + self.feature_extract_tester.batch_size, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.crop_size, + self.feature_extract_tester.crop_size, + ), + ) + + def test_call_numpy(self): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + # create random numpy tensors + image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, numpify=True) + for image in image_inputs: + self.assertIsInstance(image, np.ndarray) + + # Test not batched input + encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + 1, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.crop_size, + self.feature_extract_tester.crop_size, + ), + ) + + # Test batched + encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + self.feature_extract_tester.batch_size, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.crop_size, + self.feature_extract_tester.crop_size, + ), + ) + + def test_call_pytorch(self): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + # create random PyTorch tensors + image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True) + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + + # Test not batched input + encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + 1, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.crop_size, + self.feature_extract_tester.crop_size, + ), + ) + + # Test batched + encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + self.feature_extract_tester.batch_size, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.crop_size, + self.feature_extract_tester.crop_size, + ), + ) diff --git a/tests/test_modeling_beit.py b/tests/test_modeling_beit.py new file mode 100644 index 00000000000..d0cc5f6eb49 --- /dev/null +++ b/tests/test_modeling_beit.py @@ -0,0 +1,428 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch BEiT model. """ + + +import inspect +import unittest + +from transformers import BeitConfig +from transformers.file_utils import cached_property, is_torch_available, is_vision_available +from transformers.models.auto import get_values +from transformers.testing_utils import require_torch, require_vision, slow, torch_device + +from .test_configuration_common import ConfigTester +from .test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor + + +if is_torch_available(): + import torch + from torch import nn + + from transformers import MODEL_MAPPING, BeitForImageClassification, BeitForMaskedImageModeling, BeitModel + from transformers.models.beit.modeling_beit import BEIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple + + +if is_vision_available(): + from PIL import Image + + from transformers import BeitFeatureExtractor + + +class BeitModelTester: + def __init__( + self, + parent, + vocab_size=100, + batch_size=13, + image_size=30, + patch_size=2, + num_channels=3, + is_training=True, + use_labels=True, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + type_sequence_label_size=10, + initializer_range=0.02, + num_labels=3, + scope=None, + ): + self.parent = parent + self.vocab_size = 100 + self.batch_size = batch_size + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.is_training = is_training + self.use_labels = use_labels + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.scope = scope + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + labels = None + if self.use_labels: + labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + + config = self.get_config() + + return config, pixel_values, labels + + def get_config(self): + return BeitConfig( + vocab_size=self.vocab_size, + image_size=self.image_size, + patch_size=self.patch_size, + num_channels=self.num_channels, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + is_decoder=False, + initializer_range=self.initializer_range, + ) + + def create_and_check_model(self, config, pixel_values, labels): + model = BeitModel(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token) + image_size = to_2tuple(self.image_size) + patch_size = to_2tuple(self.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size)) + + def create_and_check_for_masked_lm(self, config, pixel_values, labels): + model = BeitForMaskedImageModeling(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + # expected sequence length = num_patches + image_size = to_2tuple(self.image_size) + patch_size = to_2tuple(self.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, self.vocab_size)) + + def create_and_check_for_image_classification(self, config, pixel_values, labels): + config.num_labels = self.type_sequence_label_size + model = BeitForImageClassification(config) + model.to(torch_device) + model.eval() + result = model(pixel_values, labels=labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + pixel_values, + labels, + ) = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class BeitModelTest(ModelTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_common.py, as BEiT does not use input_ids, inputs_embeds, + attention_mask and seq_length. + """ + + all_model_classes = ( + (BeitModel, BeitForImageClassification, BeitForMaskedImageModeling) if is_torch_available() else () + ) + + test_pruning = False + test_torchscript = False + test_resize_embeddings = False + test_head_masking = False + + def setUp(self): + self.model_tester = BeitModelTester(self) + self.config_tester = ConfigTester(self, config_class=BeitConfig, has_text_modality=False, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_inputs_embeds(self): + # BEiT does not use inputs_embeds + pass + + def test_model_common_attributes(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertIsInstance(model.get_input_embeddings(), (nn.Module)) + x = model.get_output_embeddings() + self.assertTrue(x is None or isinstance(x, nn.Linear)) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_training(self): + if not self.model_tester.is_training: + return + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + for model_class in self.all_model_classes: + if model_class in get_values(MODEL_MAPPING): + continue + # we don't test BeitForMaskedImageModeling + if model_class.__name__ == "BeitForMaskedImageModeling": + continue + model = model_class(config) + model.to(torch_device) + model.train() + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + loss = model(**inputs).loss + loss.backward() + + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + # we skip lambda parameters as these require special initial values + # determined by config.layer_scale_init_value + if "lambda" in name: + continue + if param.requires_grad: + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + # in BEiT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token) + image_size = to_2tuple(self.model_tester.image_size) + patch_size = to_2tuple(self.model_tester.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + seq_len = num_patches + 1 + encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) + encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) + chunk_length = getattr(self.model_tester, "chunk_length", None) + if chunk_length is not None and hasattr(self.model_tester, "num_hashes"): + encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + self.assertListEqual( + list(attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + out_len = len(outputs) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if hasattr(self.model_tester, "num_hidden_states_types"): + added_hidden_states = self.model_tester.num_hidden_states_types + elif self.is_encoder_decoder: + added_hidden_states = 2 + else: + added_hidden_states = 1 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + + self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states + + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1 + ) + self.assertEqual(len(hidden_states), expected_num_layers) + + # BEiT has a different seq_length + image_size = to_2tuple(self.model_tester.image_size) + patch_size = to_2tuple(self.model_tester.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + seq_length = num_patches + 1 + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [seq_length, self.model_tester.hidden_size], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + def test_for_masked_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_masked_lm(*config_and_inputs) + + def test_for_image_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_image_classification(*config_and_inputs) + + @slow + def test_model_from_pretrained(self): + for model_name in BEIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = BeitModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + +# We will verify our results on an image of cute cats +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png") + return image + + +@require_vision +class BeitModelIntegrationTest(unittest.TestCase): + @cached_property + def default_feature_extractor(self): + return ( + BeitFeatureExtractor.from_pretrained("microsoft/beit-base-patch16-224") if is_vision_available() else None + ) + + @slow + def test_inference_image_classification_head_imagenet_1k(self): + model = BeitForImageClassification.from_pretrained("microsoft/beit-base-patch16-224").to(torch_device) + + feature_extractor = self.default_feature_extractor + image = prepare_img() + inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device) + + # forward pass + outputs = model(**inputs) + logits = outputs.logits + + # verify the logits + expected_shape = torch.Size((1, 1000)) + self.assertEqual(logits.shape, expected_shape) + + expected_slice = torch.tensor([-1.2385, -1.0987, -1.0108]).to(torch_device) + + self.assertTrue(torch.allclose(logits[0, :3], expected_slice, atol=1e-4)) + + expected_class_idx = 281 + self.assertEqual(logits.argmax(-1).item(), expected_class_idx) + + @slow + def test_inference_image_classification_head_imagenet_22k(self): + model = BeitForImageClassification.from_pretrained("microsoft/beit-large-patch16-224-pt22k-ft22k").to( + torch_device + ) + + feature_extractor = self.default_feature_extractor + image = prepare_img() + inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device) + + # forward pass + outputs = model(**inputs) + logits = outputs.logits + + # verify the logits + expected_shape = torch.Size((1, 21841)) + self.assertEqual(logits.shape, expected_shape) + + expected_slice = torch.tensor([1.6881, -0.2787, 0.5901]).to(torch_device) + + self.assertTrue(torch.allclose(logits[0, :3], expected_slice, atol=1e-4)) + + expected_class_idx = 2396 + self.assertEqual(logits.argmax(-1).item(), expected_class_idx) diff --git a/utils/check_repo.py b/utils/check_repo.py index 38afb6f55a3..47cf8fd2175 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -97,6 +97,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [ # should **not** be the rule. IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [ # models to ignore for model xxx mapping + "BeitForMaskedImageModeling", "CLIPTextModel", "CLIPVisionModel", "FlaxCLIPTextModel",