[Add Mamba] Adds support for the Mamba models (#28094)

* initial-commit

* start cleaning

* small nits

* small nits

* current updates

* add kernels

* small refactoring little step

* add comments

* styling

* nit

* nits

* Style

* Small changes

* Push dummy mambda simple slow

* nit

* Use original names

* Use original names and remove norm

* Updates for inference params

* Style nd updates

* nits

* Match logits

* Add a test

* Add expected generated text

* nits doc, imports and styling

* style

* oups

* dont install kernels, invite users to install the required kernels

* let use use the original packages

* styling

* nits

* fix some copieds

* update doc

* fix-copies

* styling done

* nits

* fix import check

* run but wrong cuda ress

* mamba CUDA works :)

* fix the fast path

* config naming nits

* conversion script is not required at this stage

* finish fixing the fast path: generation make sense now!

* nit

* Let's start working on the CIs

* style

* better style

* more nits

* test nit

* quick fix for now

* nits

* nit

* nit

* nit

* nits

* update test rest

* fixup

* update test

* nit

* some fixes

* nits

* update test values

* fix styling

* nit

* support peft

* integrations tests require torchg

* also add slow markers

* styling

* chose forward wisely

* nits

* update tests

* fix gradient checkpointing

* fixup

* nit

* fix doc

* check copies

* fix the docstring

* fix some more tests

* style

* fix beam search

* add init schene

* update

* nit

* fix

* fixup the doc

* fix the doc

* fixup

* tentative update but slow is no longer good

* nit

* should we always use float32?

* nits

* revert wrong changes

* res in float32

* cleanup

* skip fmt for now

* update generation values

* update test values running original model

* fixup

* update tests + rename inference_params to cache_params + make sure training does not use cache_params

* small nits

* more nits

* fix final CIs

* style

* nit doc

* I hope final doc nits

* nit

* 🫠

* final touch!

* fix torch import

* Apply suggestions from code review

Co-authored-by: Lysandre Debut <hi@lysand.re>

* Apply suggestions from code review

* fix fix and fix

* fix base model prefix!

* nit

* Update src/transformers/models/mamba/__init__.py

* Update docs/source/en/model_doc/mamba.md

Co-authored-by: Lysandre Debut <hi@lysand.re>

* nit

---------

Co-authored-by: Lysandre Debut <hi@lysand.re>
This commit is contained in:
Arthur 2024-03-05 12:01:06 +01:00 committed by GitHub
parent 87a0783dde
commit fb1c62e973
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 1583 additions and 5 deletions

View File

@ -415,6 +415,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h
1. **[M-CTC-T](https://huggingface.co/docs/transformers/model_doc/mctct)** (from Facebook) released with the paper [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert.
1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin.
1. **[MADLAD-400](https://huggingface.co/docs/transformers/model_doc/madlad-400)** (from Google) released with the paper [MADLAD-400: A Multilingual And Document-Level Large Audited Dataset](https://arxiv.org/abs/2309.04662) by Sneha Kudugunta, Isaac Caswell, Biao Zhang, Xavier Garcia, Christopher A. Choquette-Choo, Katherine Lee, Derrick Xin, Aditya Kusupati, Romi Stella, Ankur Bapna, Orhan Firat.
1. **[Mamba](https://huggingface.co/docs/transformers/main/model_doc/mamba)** (from Albert Gu and Tri Dao) released with the paper [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) by Albert Gu and Tri Dao.
1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team.
1. **[MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm)** (from Microsoft Research Asia) released with the paper [MarkupLM: Pre-training of Text and Markup Language for Visually-rich Document Understanding](https://arxiv.org/abs/2110.08518) by Junlong Li, Yiheng Xu, Lei Cui, Furu Wei.
1. **[Mask2Former](https://huggingface.co/docs/transformers/model_doc/mask2former)** (from FAIR and UIUC) released with the paper [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) by Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar.

View File

@ -388,6 +388,7 @@ Número actual de puntos de control: ![](https://img.shields.io/endpoint?url=htt
1. **[M-CTC-T](https://huggingface.co/docs/transformers/model_doc/mctct)** (from Facebook) released with the paper [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert.
1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin.
1. **[MADLAD-400](https://huggingface.co/docs/transformers/model_doc/madlad-400)** (from Google) released with the paper [MADLAD-400: A Multilingual And Document-Level Large Audited Dataset](https://arxiv.org/abs/2309.04662) by Sneha Kudugunta, Isaac Caswell, Biao Zhang, Xavier Garcia, Christopher A. Choquette-Choo, Katherine Lee, Derrick Xin, Aditya Kusupati, Romi Stella, Ankur Bapna, Orhan Firat.
1. **[Mamba](https://huggingface.co/docs/transformers/main/model_doc/mamba)** (from Albert Gu and Tri Dao) released with the paper [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) by Albert Gu and Tri Dao.
1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team.
1. **[MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm)** (from Microsoft Research Asia) released with the paper [MarkupLM: Pre-training of Text and Markup Language for Visually-rich Document Understanding](https://arxiv.org/abs/2110.08518) by Junlong Li, Yiheng Xu, Lei Cui, Furu Wei.
1. **[Mask2Former](https://huggingface.co/docs/transformers/model_doc/mask2former)** (from FAIR and UIUC) released with the paper [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) by Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar.

View File

@ -409,6 +409,7 @@ Nombre actuel de points de contrôle : ![](https://img.shields.io/endpoint?url=h
1. **[M-CTC-T](https://huggingface.co/docs/transformers/model_doc/mctct)** (de Facebook) a été publié dans l'article [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) de Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve et Ronan Collobert.
1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (de Facebook) a été publié dans l'article [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) de Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin.
1. **[MADLAD-400](https://huggingface.co/docs/transformers/model_doc/madlad-400)** (de Google) a été publié dans l'article [MADLAD-400 : Un ensemble de données multilingue et de niveau document](https://arxiv.org/abs/2309.04662) de Sneha Kudugunta, Isaac Caswell, Biao Zhang, Xavier Garcia, Christopher A. Choquette-Choo, Katherine Lee, Derrick Xin, Aditya Kusupati, Romi Stella, Ankur Bapna, Orhan Firat.
1. **[Mamba](https://huggingface.co/docs/transformers/main/model_doc/mamba)** (de Albert Gu and Tri Dao) publié dans l'article [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) parAlbert Gu and Tri Dao.
1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Des modèles de traduction automatique formés avec les données [OPUS](http://opus.nlpl.eu/) par Jörg Tiedemann. Le [cadre Marian](https://marian-nmt.github.io/) est en cours de développement par l'équipe Microsoft Translator.
1. **[MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm)** (de Microsoft Research Asia) a été publié dans l'article [MarkupLM : Pré-entraînement de texte et de langage de balisage pour la compréhension visuellement riche de documents](https://arxiv.org/abs/2110.08518) de Junlong Li, Yiheng Xu, Lei Cui, Furu Wei.
1. **[Mask2Former](https://huggingface.co/docs/transformers/model_doc/mask2former)** (de FAIR et UIUC) a été publié dans l'article [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) de Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar.

View File

@ -362,6 +362,7 @@ conda install conda-forge::transformers
1. **[M-CTC-T](https://huggingface.co/docs/transformers/model_doc/mctct)** (from Facebook) released with the paper [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert.
1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (फेसबुक से) साथ देने वाला पेपर [बियॉन्ड इंग्लिश-सेंट्रिक मल्टीलिंगुअल मशीन ट्रांसलेशन](https://arxiv.org/एब्स/2010.11125) एंजेला फैन, श्रुति भोसले, होल्गर श्वेन्क, झी मा, अहमद अल-किश्की, सिद्धार्थ गोयल, मनदीप बैनेस, ओनूर सेलेबी, गुइल्लाम वेन्जेक, विश्रव चौधरी, नमन गोयल, टॉम बर्च, विटाली लिपचिंस्की, सर्गेई एडुनोव, एडौर्ड द्वारा ग्रेव, माइकल औली, आर्मंड जौलिन द्वारा पोस्ट किया गया।
1. **[MADLAD-400](https://huggingface.co/docs/transformers/model_doc/madlad-400)** (from Google) released with the paper [MADLAD-400: A Multilingual And Document-Level Large Audited Dataset](https://arxiv.org/abs/2309.04662) by Sneha Kudugunta, Isaac Caswell, Biao Zhang, Xavier Garcia, Christopher A. Choquette-Choo, Katherine Lee, Derrick Xin, Aditya Kusupati, Romi Stella, Ankur Bapna, Orhan Firat.
1. **[Mamba](https://huggingface.co/docs/transformers/main/model_doc/mamba)** (Albert Gu and Tri Dao से) Albert Gu and Tri Dao. द्वाराअनुसंधान पत्र [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) के साथ जारी किया गया
1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Jörg द्वारा [OPUS](http://opus.nlpl.eu/) डेटा से प्रशिक्षित मशीनी अनुवाद मॉडल पोस्ट किया गया टाइडेमैन द्वारा। [मैरियन फ्रेमवर्क](https://marian-nmt.github.io/) माइक्रोसॉफ्ट ट्रांसलेटर टीम द्वारा विकसित।
1. **[MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm)** (माइक्रोसॉफ्ट रिसर्च एशिया से) साथ में पेपर [मार्कअपएलएम: विजुअली-रिच डॉक्यूमेंट अंडरस्टैंडिंग के लिए टेक्स्ट और मार्कअप लैंग्वेज का प्री-ट्रेनिंग](https://arxiv.org/abs/2110.08518) जुनलॉन्ग ली, यिहेंग जू, लेई कुई, फुरु द्वारा वी द्वारा पोस्ट किया गया।
1. **[Mask2Former](https://huggingface.co/docs/transformers/model_doc/mask2former)** (FAIR and UIUC से) Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar. द्वाराअनुसंधान पत्र [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) के साथ जारी किया गया

View File

@ -422,6 +422,7 @@ Flax、PyTorch、TensorFlowをcondaでインストールする方法は、それ
1. **[M-CTC-T](https://huggingface.co/docs/transformers/model_doc/mctct)** (Facebook から) Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert から公開された研究論文: [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161)
1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (Facebook から) Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin から公開された研究論文: [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125)
1. **[MADLAD-400](https://huggingface.co/docs/transformers/model_doc/madlad-400)** (from Google) released with the paper [MADLAD-400: A Multilingual And Document-Level Large Audited Dataset](https://arxiv.org/abs/2309.04662) by Sneha Kudugunta, Isaac Caswell, Biao Zhang, Xavier Garcia, Christopher A. Choquette-Choo, Katherine Lee, Derrick Xin, Aditya Kusupati, Romi Stella, Ankur Bapna, Orhan Firat.
1. **[Mamba](https://huggingface.co/docs/transformers/main/model_doc/mamba)** (Albert Gu and Tri Dao から) Albert Gu and Tri Dao. から公開された研究論文 [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752)
1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Jörg Tiedemann から. [OPUS](http://opus.nlpl.eu/) を使いながら学習された "Machine translation" (マシントランスレーション) モデル. [Marian Framework](https://marian-nmt.github.io/) はMicrosoft Translator Team が現在開発中です.
1. **[MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm)** (Microsoft Research Asia から) Junlong Li, Yiheng Xu, Lei Cui, Furu Wei から公開された研究論文: [MarkupLM: Pre-training of Text and Markup Language for Visually-rich Document Understanding](https://arxiv.org/abs/2110.08518)
1. **[Mask2Former](https://huggingface.co/docs/transformers/model_doc/mask2former)** (FAIR and UIUC から) Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar. から公開された研究論文 [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527)

View File

@ -337,6 +337,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
1. **[M-CTC-T](https://huggingface.co/docs/transformers/model_doc/mctct)** (Facebook 에서) Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert 의 [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) 논문과 함께 발표했습니다.
1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (Facebook 에서) Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin 의 [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) 논문과 함께 발표했습니다.
1. **[MADLAD-400](https://huggingface.co/docs/transformers/model_doc/madlad-400)** (from Google) released with the paper [MADLAD-400: A Multilingual And Document-Level Large Audited Dataset](https://arxiv.org/abs/2309.04662) by Sneha Kudugunta, Isaac Caswell, Biao Zhang, Xavier Garcia, Christopher A. Choquette-Choo, Katherine Lee, Derrick Xin, Aditya Kusupati, Romi Stella, Ankur Bapna, Orhan Firat.
1. **[Mamba](https://huggingface.co/docs/transformers/main/model_doc/mamba)** (Albert Gu and Tri Dao 에서 제공)은 Albert Gu and Tri Dao.의 [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752)논문과 함께 발표했습니다.
1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team.
1. **[MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm)** (Microsoft Research Asia 에서) Junlong Li, Yiheng Xu, Lei Cui, Furu Wei 의 [MarkupLM: Pre-training of Text and Markup Language for Visually-rich Document Understanding](https://arxiv.org/abs/2110.08518) 논문과 함께 발표했습니다.
1. **[Mask2Former](https://huggingface.co/docs/transformers/model_doc/mask2former)** (FAIR and UIUC 에서 제공)은 Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar.의 [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527)논문과 함께 발표했습니다.

View File

@ -361,6 +361,7 @@ conda install conda-forge::transformers
1. **[M-CTC-T](https://huggingface.co/docs/transformers/model_doc/mctct)** (来自 Facebook) 伴随论文 [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) 由 Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert 发布。
1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (来自 Facebook) 伴随论文 [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) 由 Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin 发布。
1. **[MADLAD-400](https://huggingface.co/docs/transformers/model_doc/madlad-400)** (from Google) released with the paper [MADLAD-400: A Multilingual And Document-Level Large Audited Dataset](https://arxiv.org/abs/2309.04662) by Sneha Kudugunta, Isaac Caswell, Biao Zhang, Xavier Garcia, Christopher A. Choquette-Choo, Katherine Lee, Derrick Xin, Aditya Kusupati, Romi Stella, Ankur Bapna, Orhan Firat.
1. **[Mamba](https://huggingface.co/docs/transformers/main/model_doc/mamba)** (来自 Albert Gu and Tri Dao) 伴随论文 [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) 由 Albert Gu and Tri Dao 发布。
1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** 用 [OPUS](http://opus.nlpl.eu/) 数据训练的机器翻译模型由 Jörg Tiedemann 发布。[Marian Framework](https://marian-nmt.github.io/) 由微软翻译团队开发。
1. **[MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm)** (来自 Microsoft Research Asia) 伴随论文 [MarkupLM: Pre-training of Text and Markup Language for Visually-rich Document Understanding](https://arxiv.org/abs/2110.08518) 由 Junlong Li, Yiheng Xu, Lei Cui, Furu Wei 发布。
1. **[Mask2Former](https://huggingface.co/docs/transformers/model_doc/mask2former)** (来自 FAIR and UIUC) 伴随论文 [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) 由 Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar 发布。

View File

@ -373,6 +373,7 @@ conda install conda-forge::transformers
1. **[M-CTC-T](https://huggingface.co/docs/transformers/model_doc/mctct)** (from Facebook) released with the paper [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert.
1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin.
1. **[MADLAD-400](https://huggingface.co/docs/transformers/model_doc/madlad-400)** (from Google) released with the paper [MADLAD-400: A Multilingual And Document-Level Large Audited Dataset](https://arxiv.org/abs/2309.04662) by Sneha Kudugunta, Isaac Caswell, Biao Zhang, Xavier Garcia, Christopher A. Choquette-Choo, Katherine Lee, Derrick Xin, Aditya Kusupati, Romi Stella, Ankur Bapna, Orhan Firat.
1. **[Mamba](https://huggingface.co/docs/transformers/main/model_doc/mamba)** (from Albert Gu and Tri Dao) released with the paper [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) by Albert Gu and Tri Dao.
1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team.
1. **[MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm)** (from Microsoft Research Asia) released with the paper [MarkupLM: Pre-training of Text and Markup Language for Visually-rich Document Understanding](https://arxiv.org/abs/2110.08518) by Junlong Li, Yiheng Xu, Lei Cui, Furu Wei.
1. **[Mask2Former](https://huggingface.co/docs/transformers/model_doc/mask2former)** (from FAIR and UIUC) released with the paper [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) by Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar.

View File

@ -398,6 +398,8 @@
title: M2M100
- local: model_doc/madlad-400
title: MADLAD-400
- local: model_doc/mamba
title: Mamba
- local: model_doc/marian
title: MarianMT
- local: model_doc/markuplm

View File

@ -180,6 +180,7 @@ Flax), PyTorch, and/or TensorFlow.
| [M-CTC-T](model_doc/mctct) | ✅ | ❌ | ❌ |
| [M2M100](model_doc/m2m_100) | ✅ | ❌ | ❌ |
| [MADLAD-400](model_doc/madlad-400) | ✅ | ✅ | ✅ |
| [Mamba](model_doc/mamba) | ✅ | ❌ | ❌ |
| [Marian](model_doc/marian) | ✅ | ✅ | ✅ |
| [MarkupLM](model_doc/markuplm) | ✅ | ❌ | ❌ |
| [Mask2Former](model_doc/mask2former) | ✅ | ❌ | ❌ |

View File

@ -0,0 +1,107 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Mamba
## Overview
The Mamba model was proposed in [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) by Albert Gu and Tri Dao.
This model is a new paradigm architecture based on `state-space-models`. You can read more about the intuition behind these [here](https://srush.github.io/annotated-s4/).
The abstract from the paper is the following:
*Foundation models, now powering most of the exciting applications in deep learning, are almost universally based on the Transformer architecture and its core attention module. Many subquadratic-time architectures such as linear attention, gated convolution and recurrent models, and structured state space models (SSMs) have been developed to address Transformers' computational inefficiency on long sequences, but they have not performed as well as attention on important modalities such as language. We identify that a key weakness of such models is their inability to perform content-based reasoning, and make several improvements. First, simply letting the SSM parameters be functions of the input addresses their weakness with discrete modalities, allowing the model to selectively propagate or forget information along the sequence length dimension depending on the current token. Second, even though this change prevents the use of efficient convolutions, we design a hardware-aware parallel algorithm in recurrent mode. We integrate these selective SSMs into a simplified end-to-end neural network architecture without attention or even MLP blocks (Mamba). Mamba enjoys fast inference (5× higher throughput than Transformers) and linear scaling in sequence length, and its performance improves on real data up to million-length sequences. As a general sequence model backbone, Mamba achieves state-of-the-art performance across several modalities such as language, audio, and genomics. On language modeling, our Mamba-3B model outperforms Transformers of the same size and matches Transformers twice its size, both in pretraining and downstream evaluation.*
Tips:
- Mamba is a new `state space model` architecture that rivals the classic Transformers. It is based on the line of progress on structured state space models, with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention).
- Mamba stacks `mixer` layers, which are the equivalent of `Attention` layers. The core logic of `mamba` is held in the `MambaMixer` class.
- Two implementations cohabit: one is optimized and uses fast cuda kernels, while the other one is naive but can run on any device!
- The current implementation leverages the original cuda kernels: the equivalent of flash attention for Mamba are hosted in the [`mamba-ssm`](https://github.com/state-spaces/mamba) and the [`causal_conv1d`](https://github.com/Dao-AILab/causal-conv1d) repositories. Make sure to install them if your hardware supports them!
- Contributions to make the naive path faster are welcome 🤗
This model was contributed by [ArthurZ](https://huggingface.co/ArthurZ).
The original code can be found [here](https://github.com/state-spaces/mamba).
# Usage
### A simple generation example:
```python
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
import torch
tokenizer = AutoTokenizer.from_pretrained("ArthurZ/mamba-130m")
tokenizer.pad_token = tokenizer.eos_token
model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-130m", vocab_size=50280, num_hidden_layers=24, torch_dtype=torch.float32)
model.config.use_cache = True
input_ids = tokenizer("Hey how are you doing?", return_tensors= "pt")["input_ids"]
out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))
```
### Peft finetuning
The slow version is not very stable for training, and the fast one needs `float32`!
```python
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
model_id = "ArthurZ/mamba-2.8b"
tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token ="<s>")
model = AutoModelForCausalLM.from_pretrained(model_id)
dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = TrainingArguments(
output_dir="./results",
num_train_epochs=3,
per_device_train_batch_size=4,
logging_dir='./logs',
logging_steps=10,
learning_rate=2e-3
)
lora_config = LoraConfig(
r=8,
target_modules="all-linear",
task_type="CAUSAL_LM",
bias="none"
)
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
peft_config=lora_config,
train_dataset=dataset,
dataset_text_field="quote",
)
trainer.train()
```
## MambaConfig
[[autodoc]] MambaConfig
## MambaModel
[[autodoc]] MambaModel
- forward
## MambaLMHeadModel
[[autodoc]] MambaForCausalLM
- forward

View File

@ -37,7 +37,7 @@ You can finetune other architectures for causal language modeling following the
Choose one of the following architectures:
<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeLlama](../model_doc/code_llama), [CodeGen](../model_doc/codegen), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [Falcon](../model_doc/falcon), [Fuyu](../model_doc/fuyu), [Gemma](../model_doc/gemma), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [Mixtral](../model_doc/mixtral), [MPT](../model_doc/mpt), [MusicGen](../model_doc/musicgen), [MVP](../model_doc/mvp), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [Persimmon](../model_doc/persimmon), [Phi](../model_doc/phi), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Qwen2](../model_doc/qwen2), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [StableLm](../model_doc/stablelm), [Starcoder2](../model_doc/starcoder2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [Whisper](../model_doc/whisper), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod)
[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeLlama](../model_doc/code_llama), [CodeGen](../model_doc/codegen), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [Falcon](../model_doc/falcon), [Fuyu](../model_doc/fuyu), [Gemma](../model_doc/gemma), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Mamba](../model_doc/mamba), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [Mixtral](../model_doc/mixtral), [MPT](../model_doc/mpt), [MusicGen](../model_doc/musicgen), [MVP](../model_doc/mvp), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [Persimmon](../model_doc/persimmon), [Phi](../model_doc/phi), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Qwen2](../model_doc/qwen2), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [StableLm](../model_doc/stablelm), [Starcoder2](../model_doc/starcoder2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [Whisper](../model_doc/whisper), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod)

View File

@ -571,6 +571,7 @@ _import_structure = {
"LxmertTokenizer",
],
"models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"],
"models.mamba": ["MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP", "MambaConfig"],
"models.marian": ["MarianConfig"],
"models.markuplm": [
"MARKUPLM_PRETRAINED_CONFIG_ARCHIVE_MAP",
@ -2578,6 +2579,14 @@ else:
"M2M100PreTrainedModel",
]
)
_import_structure["models.mamba"].extend(
[
"MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST",
"MambaForCausalLM",
"MambaModel",
"MambaPreTrainedModel",
]
)
_import_structure["models.marian"].extend(["MarianForCausalLM", "MarianModel", "MarianMTModel"])
_import_structure["models.markuplm"].extend(
[
@ -5370,6 +5379,7 @@ if TYPE_CHECKING:
LxmertTokenizer,
)
from .models.m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config
from .models.mamba import MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP, MambaConfig
from .models.marian import MarianConfig
from .models.markuplm import (
MARKUPLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
@ -7160,6 +7170,12 @@ if TYPE_CHECKING:
M2M100Model,
M2M100PreTrainedModel,
)
from .models.mamba import (
MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST,
MambaForCausalLM,
MambaModel,
MambaPreTrainedModel,
)
from .models.marian import MarianForCausalLM, MarianModel, MarianMTModel
from .models.markuplm import (
MARKUPLM_PRETRAINED_MODEL_ARCHIVE_LIST,

View File

@ -3183,7 +3183,7 @@ class GenerationMixin:
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
)
if model_kwargs["past_key_values"] is not None:
if model_kwargs.get("past_key_values", None) is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
model_kwargs["past_key_values"], beam_idx
)
@ -3537,7 +3537,7 @@ class GenerationMixin:
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
)
if model_kwargs["past_key_values"] is not None:
if model_kwargs.get("past_key_values", None) is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
model_kwargs["past_key_values"], beam_idx
)
@ -3943,7 +3943,7 @@ class GenerationMixin:
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
)
if model_kwargs["past_key_values"] is not None:
if model_kwargs.get("past_key_values", None) is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
model_kwargs["past_key_values"], reordering_indices
)
@ -4302,7 +4302,7 @@ class GenerationMixin:
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
)
if model_kwargs["past_key_values"] is not None:
if model_kwargs.get("past_key_values", None) is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
model_kwargs["past_key_values"], beam_idx
)

View File

@ -128,6 +128,7 @@ from . import (
luke,
lxmert,
m2m_100,
mamba,
marian,
markuplm,
mask2former,

View File

@ -137,6 +137,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("luke", "LukeConfig"),
("lxmert", "LxmertConfig"),
("m2m_100", "M2M100Config"),
("mamba", "MambaConfig"),
("marian", "MarianConfig"),
("markuplm", "MarkupLMConfig"),
("mask2former", "Mask2FormerConfig"),
@ -373,6 +374,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("lxmert", "LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("m2m_100", "M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mamba", "MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("markuplm", "MARKUPLM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mask2former", "MASK2FORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("maskformer", "MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@ -614,6 +616,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("lxmert", "LXMERT"),
("m2m_100", "M2M100"),
("madlad-400", "MADLAD-400"),
("mamba", "Mamba"),
("marian", "Marian"),
("markuplm", "MarkupLM"),
("mask2former", "Mask2Former"),

View File

@ -134,6 +134,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("luke", "LukeModel"),
("lxmert", "LxmertModel"),
("m2m_100", "M2M100Model"),
("mamba", "MambaModel"),
("marian", "MarianModel"),
("markuplm", "MarkupLMModel"),
("mask2former", "Mask2FormerModel"),
@ -286,6 +287,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
("longformer", "LongformerForMaskedLM"),
("luke", "LukeForMaskedLM"),
("lxmert", "LxmertForPreTraining"),
("mamba", "MambaForCausalLM"),
("mega", "MegaForMaskedLM"),
("megatron-bert", "MegatronBertForPreTraining"),
("mobilebert", "MobileBertForPreTraining"),
@ -367,6 +369,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
("longt5", "LongT5ForConditionalGeneration"),
("luke", "LukeForMaskedLM"),
("m2m_100", "M2M100ForConditionalGeneration"),
("mamba", "MambaForCausalLM"),
("marian", "MarianMTModel"),
("mega", "MegaForMaskedLM"),
("megatron-bert", "MegatronBertForCausalLM"),
@ -439,6 +442,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
("gptj", "GPTJForCausalLM"),
("llama", "LlamaForCausalLM"),
("mamba", "MambaForCausalLM"),
("marian", "MarianForCausalLM"),
("mbart", "MBartForCausalLM"),
("mega", "MegaForCausalLM"),

View File

@ -233,6 +233,7 @@ else:
("luke", ("LukeTokenizer", None)),
("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)),
("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)),
("mamba", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)),
(
"mbart",

View File

@ -0,0 +1,60 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)
_import_structure = {
"configuration_mamba": ["MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP", "MambaConfig", "MambaOnnxConfig"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_mamba"] = [
"MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST",
"MambaForCausalLM",
"MambaModel",
"MambaPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_mamba import MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP, MambaConfig, MambaOnnxConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_mamba import (
MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST,
MambaForCausalLM,
MambaModel,
MambaPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

View File

@ -0,0 +1,156 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" MAMBA configuration"""
import math
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"state-spaces/mamba-2.8b": "https://huggingface.co/state-spaces/mamba-2.8b/resolve/main/config.json",
}
class MambaConfig(PretrainedConfig):
"""
This is the configuration class to store the configuration of a [`MambaModel`]. It is used to instantiate a MAMBA
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the MAMBA
[state-spaces/mamba-2.8b](https://huggingface.co/state-spaces/mamba-2.8b) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 50280):
Vocabulary size of the MAMBA model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`MambaModel`].
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the embeddings and hidden states.
state_size (`int`, *optional*, defaults to 16): shape of the state space latents.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the model.
layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
The epsilon to use in the layer normalization layers.
pad_token_id (`int`, *optional*, defaults to 0):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 0):
The id of the beginning of sentence token in the vocabulary.
eos_token_id (`int`, *optional*, defaults to 0):
The id of the end of sentence token in the vocabulary.
expand (`int`, *optional*, defaults to 2): Expanding factor used to determin the intermediate size.
conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel.
use_bias (`bool`, *optional*, defaults to `False`):
Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block
use_conv_bias (`bool`, *optional*, defaults to `True`):
Whether or not to use bias in the convolution layer of the mixer block.
hidden_act (`str`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
initializer_range (`float`, *optional*, defaults to 0.1):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
residual_in_fp32 (`bool`, *optional*, defaults to `True`):
Whether or not residuals should be in `float32`. If set to `False` residuals will keep the same `dtype` as the rest of the model
time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
Rank of the the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`
time_step_scale (`float`, *optional*, defaults to 1.0):
Scale used used to scale `dt_proj.bias`.
time_step_min (`float`, *optional*, defaults to 0.001):
Minimum `time_step` used to bound `dt_proj.bias`.
time_step_max (`float`, *optional*, defaults to 0.1):
Maximum `time_step` used to bound `dt_proj.bias`.
time_step_init_scheme (`float`, *optional*, defaults to `"random"`):
Init scheme used for `dt_proj.weight`. Should be one of `["random","uniform"]`
time_step_floor (`float`, *optional*, defaults to 0.0001):
Minimum clamping value of the `dt_proj.bias` layer initialization.
rescale_prenorm_residual (`bool`, *optional*, defaults to `False`):
Whether or not to rescale `out_proj` weights when initializing.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the cache should be used.
Example:
```python
>>> from transformers import MambaConfig, MambaModel
>>> # Initializing a Mamba configuration
>>> configuration = MambaConfig()
>>> # Initializing a model (with random weights) from the configuration
>>> model = MambaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "mamba"
def __init__(
self,
vocab_size=50280,
hidden_size=768,
state_size=16,
num_hidden_layers=32,
layer_norm_epsilon=1e-5,
pad_token_id=0,
bos_token_id=0,
eos_token_id=0,
expand=2,
conv_kernel=4,
use_bias=False,
use_conv_bias=True,
hidden_act="silu",
initializer_range=0.1,
residual_in_fp32=True,
time_step_rank="auto",
time_step_scale=1.0,
time_step_min=0.001,
time_step_max=0.1,
time_step_init_scheme="random",
time_step_floor=1e-4,
rescale_prenorm_residual=False,
use_cache=True,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.state_size = state_size
self.num_hidden_layers = num_hidden_layers
self.layer_norm_epsilon = layer_norm_epsilon
self.conv_kernel = conv_kernel
self.expand = expand
self.intermediate_size = int(expand * self.hidden_size)
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.use_bias = use_bias
self.use_conv_bias = use_conv_bias
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank
self.time_step_scale = time_step_scale
self.time_step_min = time_step_min
self.time_step_max = time_step_max
self.time_step_init_scheme = time_step_init_scheme
self.time_step_floor = time_step_floor
self.rescale_prenorm_residual = rescale_prenorm_residual
self.residual_in_fp32 = residual_in_fp32
self.use_cache = use_cache
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs)

View File

@ -0,0 +1,681 @@
# coding=utf-8
# Copyright 2024 state-spaces/mamba org and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch MAMBA model."""
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...modeling_utils import PreTrainedModel
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
)
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available
from .configuration_mamba import MambaConfig
logger = logging.get_logger(__name__)
if is_mamba_ssm_available():
from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
else:
selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
if is_causal_conv1d_available():
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
causal_conv1d_update, causal_conv1d_fn = None, None
is_fast_path_available = all(
(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)
_CHECKPOINT_FOR_DOC = "ArthurZ/mamba-130m"
_CONFIG_FOR_DOC = "MambaConfig"
MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST = [] # See all Mamba models at https://huggingface.co/models?filter=mamba
class MambaMixer(nn.Module):
"""
Compute , A, B, C, and D the state space parameters and compute the `contextualized_states`.
A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
and is why Mamba is called **selective** state spaces)
"""
def __init__(self, config, layer_idx):
super().__init__()
self.hidden_size = config.hidden_size
self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel
self.intermediate_size = config.intermediate_size
self.time_step_rank = config.time_step_rank
self.layer_idx = layer_idx
self.use_conv_bias = config.use_conv_bias
self.conv1d = nn.Conv1d(
in_channels=self.intermediate_size,
out_channels=self.intermediate_size,
bias=config.use_conv_bias,
kernel_size=config.conv_kernel,
groups=self.intermediate_size,
padding=config.conv_kernel - 1,
)
self.activation = config.hidden_act
self.act = ACT2FN[config.hidden_act]
# projection of the input hidden states
self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
# selective projection used to make dt, B and C input dependant
self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
# time step projection (discretization)
self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
# S4D real initialization. These are not discretized!
# The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
A = A.expand(self.intermediate_size, -1).contiguous()
self.A_log = nn.Parameter(torch.log(A))
self.D = nn.Parameter(torch.ones(self.intermediate_size))
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
self.use_bias = config.use_bias
if not is_fast_path_available:
logger.warning_once(
"The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
" is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and"
" https://github.com/Dao-AILab/causal-conv1d"
)
def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params=None):
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states).transpose(1, 2)
if self.training and cache_params is None: # Doesn't support outputting the states -> used for training
contextualized_states = mamba_inner_fn(
projected_states,
self.conv1d.weight,
self.conv1d.bias if self.use_conv_bias else None,
self.x_proj.weight,
self.dt_proj.weight,
self.out_proj.weight,
self.out_proj.bias.float() if self.use_bias else None,
-torch.exp(self.A_log.float()),
None, # input-dependent B
None, # input-dependent C
self.D.float(),
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
)
else:
hidden_states, gate = projected_states.chunk(2, dim=1)
# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
if cache_params is not None and cache_params.seqlen_offset > 0:
hidden_states = causal_conv1d_update(
hidden_states.squeeze(-1),
cache_params.conv_states[self.layer_idx],
conv_weights,
self.conv1d.bias,
self.activation,
)
hidden_states = hidden_states.unsqueeze(-1)
else:
if cache_params is not None:
conv_states = nn.functional.pad(
hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
)
cache_params.conv_states[self.layer_idx].copy_(conv_states)
hidden_states = causal_conv1d_fn(
hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
)
# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
time_step, B, C = torch.split(
ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
)
discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)
A = -torch.exp(self.A_log.float())
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None
if cache_params is not None and cache_params.seqlen_offset > 0:
scan_outputs = selective_state_update(
cache_params.ssm_states[self.layer_idx],
hidden_states[..., 0],
discrete_time_step[..., 0],
A,
B[:, 0],
C[:, 0],
self.D,
gate[..., 0],
time_proj_bias,
dt_softplus=True,
).unsqueeze(-1)
else:
scan_outputs, ssm_state = selective_scan_fn(
hidden_states,
discrete_time_step,
A,
B.transpose(1, 2),
C.transpose(1, 2),
self.D.float(),
gate,
time_proj_bias,
delta_softplus=True,
return_last_state=True,
)
if ssm_state is not None and cache_params is not None:
cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
# 4. Final linear projection
contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
return contextualized_states
# fmt: off
def slow_forward(self, input_states, cache_params=None):
batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
# 1. Gated MLP's linear projection
projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
hidden_states, gate = projected_states.chunk(2, dim=1)
# 2. Convolution sequence transformation
if cache_params is not None:
ssm_state = cache_params.ssm_states[self.layer_idx]
if cache_params.seqlen_offset > 0:
conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size]
conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
conv_state[:, :, -1] = hidden_states[:, :, 0]
cache_params.conv_states[self.layer_idx].copy_(conv_state)
hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
if self.use_conv_bias:
hidden_states += self.conv1d.bias
hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) # [batch, intermediate_size, 1] : decoding
else:
conv_state = nn.functional.pad(
hidden_states,
(self.conv_kernel_size - hidden_states.shape[-1], 0)
)
cache_params.conv_states[self.layer_idx].copy_(conv_state)
hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
else:
ssm_state = torch.zeros(
(batch_size, self.intermediate_size, self.ssm_state_size),
device=hidden_states.device, dtype=dtype
)
hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
# 3. State Space Model sequence transformation
# 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
time_step, B, C = torch.split(
ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
)
discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size]
discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) # [batch, intermediate_size, seq_len]
# 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
A = -torch.exp(self.A_log.float()) # [intermediate_size, ssm_state_size]
discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None]) # [batch, intermediate_size, seq_len, ssm_state_size]
discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float() # [batch, intermediade_size, seq_len, ssm_state_size]
deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
scan_outputs = []
for i in range(seq_len):
ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediade_size, ssm_state]
scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediade_size, 1]
scan_outputs.append(scan_output[:, :, 0])
scan_output = torch.stack(scan_outputs, dim=-1) # [batch, seq_len, intermediade_size]
scan_output = scan_output + (hidden_states * self.D[None, :, None])
scan_output = (scan_output * self.act(gate))
if cache_params is not None:
cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
# 4. Final linear projection
contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size]
return contextualized_states
# fmt: on
def forward(self, hidden_states, cache_params=None):
if is_fast_path_available and "cuda" in self.x_proj.weight.device.type:
return self.cuda_kernels_forward(hidden_states, cache_params)
return self.slow_forward(hidden_states, cache_params)
class MambaCache:
def __init__(self, config, batch_size, dtype=torch.float16, device=None):
self.seqlen_offset = 0
self.dtype = dtype
intermediate_size = config.intermediate_size
ssm_state_size = config.state_size
conv_kernel_size = config.conv_kernel
self.conv_states = {
i: torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
for i in range(config.num_hidden_layers)
}
self.ssm_states = {
i: torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
for i in range(config.num_hidden_layers)
}
class MambaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class MambaBlock(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.residual_in_fp32 = config.residual_in_fp32
self.norm = MambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.mixer = MambaMixer(config, layer_idx=layer_idx)
def forward(self, hidden_states, cache_params=None):
residual = hidden_states
hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
hidden_states = self.mixer(hidden_states, cache_params=cache_params)
hidden_states = residual + hidden_states
return hidden_states
class MambaPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = MambaConfig
base_model_prefix = "backbone"
_no_split_modules = ["MambaBlock"]
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, MambaMixer):
module.A_log._no_weight_decay = True
module.D._no_weight_decay = True
dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale
if self.config.time_step_init_scheme == "constant":
nn.init.constant_(module.dt_proj.weight, dt_init_std)
elif self.config.time_step_init_scheme == "random":
nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std)
dt = torch.exp(
torch.rand(self.config.intermediate_size)
* (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
+ math.log(self.config.time_step_min)
).clamp(min=self.config.time_step_floor)
# # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
with torch.no_grad():
module.dt_proj.bias.copy_(inv_dt)
module.dt_proj.bias._no_reinit = True
if isinstance(module, nn.Linear):
if module.bias is not None:
if not getattr(module.bias, "_no_reinit", False):
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, std=self.config.initializer_range)
if self.config.rescale_prenorm_residual:
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if name in ["out_proj.weight"]:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
# We need to reinit p since this code could be called multiple times
# Having just p *= scale would repeatedly scale it down
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
with torch.no_grad():
p /= math.sqrt(self.config.num_layers)
@dataclass
class MambaOutput(ModelOutput):
"""
Class for the MAMBA model outputs.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
cache_params (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
avoid providing the old `input_ids`.
Includes both the State space model states weights after the selective scan, and the Convolutional states
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
"""
last_hidden_state: torch.FloatTensor = None
cache_params: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class MambaCausalLMOutput(ModelOutput):
"""
Base class for causal language model (or autoregressive) outputs.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
cache_params (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
avoid providing the old `input_ids`.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
cache_params: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
MAMBA_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`MambaConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
MAMBA_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
Indices of input sequence tokens in the vocabulary.
If `cache_params.seqlen_offset>0`, only `input_ids` that do not have their past calculated should be passed as
`input_ids`.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
cache_params (`MambaCache`, *optional*):
If passed along, the model uses the previous state in all the blocks (which will give the output for the
`input_ids` provided as if the model add `state_input_ids + input_ids` as context).
use_cache (`bool`, *optional*):
If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"The bare MAMBA Model transformer outputting raw hidden-states without any specific head on top.",
MAMBA_START_DOCSTRING,
)
class MambaModel(MambaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList([MambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
self.norm_f = MambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embeddings
def set_input_embeddings(self, new_embeddings):
self.embeddings = new_embeddings
@add_start_docstrings_to_model_forward(MAMBA_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=MambaOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
cache_params: Optional[List[torch.FloatTensor]] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs, # `attention_mask` is passed by the tokenizer and we don't want it
) -> Union[Tuple, MambaOutput]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids)
if self.gradient_checkpointing and self.training and use_cache:
use_cache = False
if cache_params is None and use_cache:
cache_params = MambaCache(
self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
)
hidden_states = inputs_embeds
all_hidden_states = () if output_hidden_states else None
for mixer_block in self.layers:
if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(mixer_block.__call__, hidden_states, cache_params)
else:
hidden_states = mixer_block(hidden_states, cache_params=cache_params)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if use_cache:
cache_params.seqlen_offset += inputs_embeds.shape[1]
hidden_states = self.norm_f(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)
return MambaOutput(
last_hidden_state=hidden_states,
cache_params=cache_params if use_cache else None,
hidden_states=all_hidden_states,
)
@add_start_docstrings(
"""
The MAMBA Model transformer with a language modeling head on top (linear layer with weights tied to the input
embeddings).
""",
MAMBA_START_DOCSTRING,
)
class MambaForCausalLM(MambaPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.backbone = MambaModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def get_input_embeddings(self):
return self.backbone.get_input_embeddings()
def set_input_embeddings(self, new_embeddings):
return self.backbone.set_input_embeddings(new_embeddings)
def _update_model_kwargs_for_generation(
self, outputs: ModelOutput, model_kwargs: Dict[str, Any], **kwargs
) -> Dict[str, Any]:
model_kwargs["cache_params"] = outputs["cache_params"]
return model_kwargs
def prepare_inputs_for_generation(
self, input_ids, cache_params=None, inputs_embeds=None, attention_mask=None, **kwargs
):
# only last token for inputs_ids if the state is passed along.
if cache_params is not None:
input_ids = input_ids[:, -1].unsqueeze(-1)
if inputs_embeds is not None and cache_params is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs["cache_params"] = cache_params
return model_inputs
@add_start_docstrings_to_model_forward(MAMBA_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=MambaCausalLMOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_params: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs, # for now we need this for generation
) -> Union[Tuple, MambaCausalLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
mamba_outputs = self.backbone(
input_ids,
cache_params=cache_params,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = mamba_outputs[0]
logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float()
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
if not return_dict:
output = (logits,) + mamba_outputs[1:]
return ((loss,) + output) if loss is not None else output
return MambaCausalLMOutput(
loss=loss,
logits=logits,
cache_params=mamba_outputs.cache_params,
hidden_states=mamba_outputs.hidden_states,
)

View File

@ -5022,6 +5022,30 @@ class M2M100PreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST = None
class MambaForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MambaModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MambaPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MarianForCausalLM(metaclass=DummyObject):
_backends = ["torch"]

View File

@ -307,6 +307,27 @@ def is_torch_cuda_available():
return False
def is_mamba_ssm_available():
if is_torch_available():
import torch
if not torch.cuda.is_available():
return False
else:
return _is_package_available("mamba_ssm")
return False
def is_causal_conv1d_available():
if is_torch_available():
import torch
if not torch.cuda.is_available():
return False
return _is_package_available("causal_conv1d")
return False
def is_torch_mps_available():
if is_torch_available():
import torch

View File

View File

@ -0,0 +1,491 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import unittest
from typing import Dict, List, Tuple
from unittest.util import safe_repr
from parameterized import parameterized
from transformers import AutoTokenizer, MambaConfig, is_torch_available
from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import (
MambaForCausalLM,
MambaModel,
)
from transformers.models.mamba.modeling_mamba import MambaCache
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0
else:
is_torch_greater_or_equal_than_2_0 = False
class MambaModelTester:
def __init__(
self,
parent,
batch_size=14,
seq_length=7,
is_training=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=2,
intermediate_size=32,
hidden_act="silu",
hidden_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
num_labels=3,
num_choices=4,
scope=None,
tie_word_embeddings=True,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
self.bos_token_id = vocab_size - 1
self.eos_token_id = vocab_size - 1
self.pad_token_id = vocab_size - 1
self.tie_word_embeddings = tie_word_embeddings
def get_large_model_config(self):
return MambaConfig.from_pretrained("hf-internal-testing/mamba-2.8b")
def prepare_config_and_inputs(
self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = self.get_config(
gradient_checkpointing=gradient_checkpointing,
scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx,
reorder_and_upcast_attn=reorder_and_upcast_attn,
)
return (
config,
input_ids,
None,
sequence_labels,
token_labels,
choice_labels,
)
def get_config(
self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
):
return MambaConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
intermediate_size=self.intermediate_size,
activation_function=self.hidden_act,
n_positions=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
use_cache=True,
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
gradient_checkpointing=gradient_checkpointing,
tie_word_embeddings=self.tie_word_embeddings,
)
def get_pipeline_config(self):
config = self.get_config()
config.vocab_size = 300
return config
def prepare_config_and_inputs_for_decoder(self):
(
config,
input_ids,
sequence_labels,
token_labels,
choice_labels,
) = self.prepare_config_and_inputs()
return (
config,
input_ids,
sequence_labels,
token_labels,
choice_labels,
)
def create_and_check_mamba_model(self, config, input_ids, *args):
config.output_hidden_states = True
model = MambaModel(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(len(result.hidden_states), config.num_hidden_layers + 1)
def create_and_check_causl_lm(self, config, input_ids, *args):
model = MambaForCausalLM(config)
model.to(torch_device)
model.eval()
result = model(input_ids, labels=input_ids)
self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_state_equivalency(self, config, input_ids, *args):
model = MambaModel(config=config)
model.to(torch_device)
model.eval()
outputs = model(input_ids)
output_whole = outputs.last_hidden_state
outputs = model(input_ids[:, :-1], use_cache=True)
output_one = outputs.last_hidden_state
# Using the state computed on the first inputs, we will get the same output
outputs = model(input_ids[:, -1:], cache_params=outputs.cache_params)
output_two = outputs.last_hidden_state
self.parent.assertTrue(torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e-5))
# TODO the orignal mamba does not support decoding more than 1 token neither do we
def create_and_check_forward_and_backwards(self, config, input_ids, *args, gradient_checkpointing=False):
model = MambaForCausalLM(config)
model.to(torch_device)
if gradient_checkpointing:
model.gradient_checkpointing_enable()
result = model(input_ids, labels=input_ids)
self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
result.loss.backward()
def prepare_config_and_inputs_for_common(self):
(
config,
input_ids,
_,
sequence_labels,
token_labels,
choice_labels,
) = self.prepare_config_and_inputs()
inputs_dict = {"input_ids": input_ids}
return config, inputs_dict
@unittest.skipIf(
not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"
)
@require_torch
class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (MambaModel, MambaForCausalLM) if is_torch_available() else ()
fx_compatible = False # FIXME let's try to support this @ArthurZucker
test_torchscript = False # FIXME let's try to support this @ArthurZucker
test_missing_keys = False
test_model_parallel = False
test_pruning = False
test_head_masking = False # Mamba does not have attention heads
test_model_parallel = False
pipeline_model_mapping = (
{"feature-extraction": MambaModel, "text-generation": MambaForCausalLM} if is_torch_available() else {}
)
def setUp(self):
self.model_tester = MambaModelTester(self)
self.config_tester = ConfigTester(
self, config_class=MambaConfig, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"]
)
def assertInterval(self, member, container, msg=None):
r"""
Simple utility function to check if a member is inside an interval.
"""
if isinstance(member, torch.Tensor):
max_value, min_value = member.max().item(), member.min().item()
elif isinstance(member, list) or isinstance(member, tuple):
max_value, min_value = max(member), min(member)
if not isinstance(container, list):
raise TypeError("container should be a list or tuple")
elif len(container) != 2:
raise ValueError("container should have 2 elements")
expected_min, expected_max = container
is_inside_interval = (min_value >= expected_min) and (max_value <= expected_max)
if not is_inside_interval:
standardMsg = "%s not found in %s" % (safe_repr(member), safe_repr(container))
self.fail(self._formatMessage(msg, standardMsg))
def test_config(self):
self.config_tester.run_common_tests()
@unittest.skip("No attention in mamba")
def test_retain_grad_hidden_states_attentions(self):
pass
@require_torch_multi_gpu
def test_multi_gpu_data_parallel_forward(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# some params shouldn't be scattered by nn.DataParallel
# so just remove them if they are present.
blacklist_non_batched_params = ["cache_params"]
for k in blacklist_non_batched_params:
inputs_dict.pop(k, None)
# move input tensors to cuda:O
for k, v in inputs_dict.items():
if torch.is_tensor(v):
inputs_dict[k] = v.to(0)
for model_class in self.all_model_classes:
model = model_class(config=config)
model.to(0)
model.eval()
# Wrap model in nn.DataParallel
model = torch.nn.DataParallel(model)
with torch.no_grad():
_ = model(**self._prepare_for_class(inputs_dict, model_class))
def test_mamba_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_mamba_model(*config_and_inputs)
def test_mamba_lm_head_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_causl_lm(*config_and_inputs)
def test_state_equivalency(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_state_equivalency(*config_and_inputs)
def test_initialization(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config=config)
for name, param in model.named_parameters():
if "dt_proj.bias" in name:
dt = torch.exp(
torch.tensor([0, 1]) * (math.log(config.time_step_max) - math.log(config.time_step_min))
+ math.log(config.time_step_min)
).clamp(min=config.time_step_floor)
inv_dt = dt + torch.log(-torch.expm1(-dt))
if param.requires_grad:
self.assertTrue(param.data.max().item() <= inv_dt[1])
self.assertTrue(param.data.min().item() >= inv_dt[0])
elif "A_log" in name:
A = torch.arange(1, config.state_size + 1, dtype=torch.float32)[None, :]
self.assertTrue(torch.allclose(param.data, torch.log(A), atol=1e-5, rtol=1e-5))
elif "D" in name:
if param.requires_grad:
# check if it's a ones like
self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5))
@unittest.skip("Mamba does not use attention")
def test_attention_outputs(self):
r"""
Overriding the test_attention_outputs test as the attention outputs of Mamba are different from other models
it has a shape `batch_size, seq_len, hidden_size`.
"""
pass
@slow
def test_model_from_pretrained(self):
model = MambaModel.from_pretrained("hf-internal-testing/mamba-130m")
self.assertIsNotNone(model)
def test_model_outputs_equivalence(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
with torch.no_grad():
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, MambaCache): # MODIFIED PART START
recursive_check(tuple_object.conv_states, dict_object.conv_states)
recursive_check(tuple_object.ssm_states, dict_object.ssm_states)
elif isinstance(tuple_object, (List, Tuple)): # MODIFIED PART END
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif isinstance(tuple_object, Dict):
for tuple_iterable_value, dict_iterable_value in zip(
tuple_object.values(), dict_object.values()
):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
self.assertTrue(
torch.allclose(tuple_object, dict_object, atol=1e-5),
msg=(
"Tuple and dict output are not equal. Difference:"
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
),
)
recursive_check(tuple_output, dict_output)
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
@require_torch
class MambaIntegrationTests(unittest.TestCase):
def setUp(self):
self.model_id = "ArthurZ/mamba-2.8b"
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
@parameterized.expand([(torch_device,), ("cpu",)])
def test_simple_generate(self, device):
tokenizer = AutoTokenizer.from_pretrained("ArthurZ/mamba-130m")
tokenizer.pad_token = tokenizer.eos_token
model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-130m", torch_dtype=torch.float16)
model.to(device)
model.config.use_cache = True
input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"].to(device)
out = model.generate(input_ids, do_sample=False, max_new_tokens=10)
output_sentence = tokenizer.decode(out[0, :])
self.assertEqual(output_sentence, "Hey how are you doing?\n\nI'm so glad you're here.")
with torch.no_grad():
logits = model(input_ids=input_ids).logits
EXPECTED_LOGITS_NO_GRAD = torch.tensor(
[
-55.6875, -69.8750, -49.9062, -51.7500, -57.6875, -57.9375, -56.9688,
-57.9375, -54.6875, -55.9375, -55.3125, -58.0938, -60.5625, -47.0000,
-52.0312, -49.7812, -55.9375, -57.9062, -56.7812, -57.1250, -57.3438,
-58.3125, -57.8125, -58.7812, -59.6250, -59.0938, -58.7188, -52.9375,
-53.4688, -57.3750, -56.9375, -55.7500, -53.3125, -55.8438, -57.0000,
-56.9062, -56.2188, -54.7188, -56.4375, -57.5000
]
,dtype=torch.float32) # fmt: skip
torch.testing.assert_close(logits[0, 0, :40].cpu(), EXPECTED_LOGITS_NO_GRAD, rtol=1e-3, atol=1e-3)
@parameterized.expand([(torch_device,), ("cpu",)])
def test_simple_generate_cuda_kernels_tiny(self, device):
expected_output = "Hello my name is John and I am a newbie to the world"
input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(device)
model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-130m", torch_dtype=torch.float16).to(device)
output = model.generate(input_ids, max_new_tokens=10)
output_sentence = self.tokenizer.decode(output[0].tolist())
self.assertEqual(output_sentence, expected_output)
@parameterized.expand([(torch_device,), ("cpu",)])
@slow
def test_simple_generate_cuda_kernels_small(self, device):
expected_output = "Hello my name is\n\nI am a\n\nI am a"
input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(device)
model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-790m", torch_dtype=torch.float16).to(device)
output = model.generate(input_ids, max_new_tokens=10)
output_sentence = self.tokenizer.decode(output[0].tolist())
self.assertEqual(output_sentence, expected_output)
@parameterized.expand([(torch_device,), ("cpu",)])
@slow
def test_simple_generate_cuda_kernels_mid(self, device):
expected_output = "Hello my name is John and I am a\n\nI am a single father of a beautiful daughter. I am a"
input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(device)
model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-1.4b", torch_dtype=torch.float16).to(device)
output = model.generate(input_ids, max_new_tokens=20)
output_sentence = self.tokenizer.decode(output[0].tolist())
self.assertEqual(output_sentence, expected_output)
@parameterized.expand([(torch_device,), ("cpu",)])
@slow
def test_simple_generate_cuda_kernels_big(self, device):
expected_output = "Hello my name is John and I am a new member of this forum. I am a retired Marine and I am a member of the Marine Corps League. I am a"
input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(device)
model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-2.8b", torch_dtype=torch.float16).to(device)
output = model.generate(input_ids, max_new_tokens=30)
output_sentence = self.tokenizer.decode(output[0].tolist())
self.assertEqual(output_sentence, expected_output)

View File

@ -34,6 +34,8 @@ CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING
SPECIAL_CASES_TO_ALLOW = {
# used to compute the property `self.chunk_length`
"EncodecConfig": ["overlap"],
# used as in the config to define `intermediate_size`
"MambaConfig": ["expand"],
# used as `self.bert_model = BertModel(config, ...)`
"DPRConfig": True,
"FuyuConfig": True,