mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)
Add DBRX Model (#29921)
* wip
* fix __init__.py
* add docs
* Apply suggestions from code review
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* address comments 1
* work on make fixup
* pass configs down
* add sdpa attention
* remove DbrxBlock
* add to configuration_auto
* docstring now passes formatting test
* fix style
* update READMEs
* add dbrx to modeling_auto
* make fix-copies generated this
* add DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP
* config docstring passes formatting test
* rename moe_loss_weight to router_aux_loss_coef
* add to flash-attn documentation
* fix model-path in tests
* Explicitly make `"silu"` the default `ffn_act_fn`
  Co-authored-by: Wing Lian <wing.lian@gmail.com>
* default to using router_aux_loss_coef over ffn_config[moe_loss_weight]
* fix _flash_attn_uses_top_left_mask and is_causal
* fix tests path
* don't use token type IDs
* follow Llama and remove token_type_ids from test
* init ConfigTester differently so tests pass
* remove multiple choice test
* remove question + answer test
* remove sequence classification test
* remove token classification test
* copy Llama tests and remove token_type_ids from test inputs
* do not test pruning or headmasking; style code
* add _tied_weights_keys parameter to pass test
* add type hints
* fix type check
* update config tester
* remove masked_lm test
* remove encoder tests
* initialize DbrxModelTester with correct params
* style
* torch_dtype does not rely on torch
* run make fixup, fix-copies
* use https://huggingface.co/v2ray/dbrx-base-fixed/blob/main/modeling_dbrx.py
* add copyright info
* fix imports and DbrxRotaryEmbedding
* update DbrxModel docstring
* use copies
* change model path in docstring
* use config in DbrxFFN
* fix flashattention2, sdpaattention
* input config to DbrxAttention, DbrxNormAttentionNorm
* more fixes
* fix
* fix again!
* add informative comment
* fix ruff?
* remove print statement + style
* change doc-test
* fix doc-test
* fix docstring
* delete commented out text
* make defaults match dbrx-instruct
* replace `router_aux_loss_coef` with `moe_loss_weight`
* is_decoder=True
* remove is_decoder from configtester
* implement sdpa properly
* make is_decoder pass tests
* start on the GenerationTesterMixin tests
* add dbrx to sdpa documentation
* skip weight typing test
* style
* initialize smaller model
  Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
* Add DBRX to toctree
* skip test_new_cache_format
* make config defaults smaller again
* add pad_token_id
* remove pad_token_id from config
* Remove all references to DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP
* Update src/transformers/models/dbrx/__init__.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update src/transformers/models/dbrx/modeling_dbrx.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update docs/source/en/model_doc/dbrx.md
  Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
* Update src/transformers/models/dbrx/configuration_dbrx.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update docs/source/en/model_doc/dbrx.md
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* fix typo
* Apply suggestions from code review
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* update docs, fix configuration_auto.py
* address pr comments
* remove is_decoder flag
* slice
* fix requires grad
* remove grad
* disconnect differently
* remove grad
* enable grads
* patch
* detach expert
* nissan al ghaib
* Update modeling_dbrx.py
* Update src/transformers/models/dbrx/modeling_dbrx.py
  Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
* replace "Gemma" with "Dbrx"
* remove # type: ignore
* don't hardcode vocab_size
* remove ToDo
* Re-add removed idefics2 line
* Update test to use tiny-random!
* Remove TODO
* Remove one more case of loading the entire dbrx-instruct in the tests
* Update src/transformers/models/dbrx/modeling_dbrx.py
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* address some comments
* small model
* add dbrx to tokenization_auto
* More docstrings with add_start_docstrings
* Dbrx for now
* add PipelineTesterMixin
* Update src/transformers/models/dbrx/configuration_dbrx.py
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* remove flash-attn2 import error
* fix docstring
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* add usage example
* put on one line
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* fix ffn_act_fn
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* change "dbrx" to "DBRX" for display purposes.
* fix __init__.py?
* fix __init__.py
* fix README
* return the aux_loss
* remove extra spaces
* fix configuration_auto.py
* fix format in tokenization_auto
* remove new line
* add more usage examples

---------

Co-authored-by: Abhi Venigalla <abhi.venigalla@databricks.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Eitan Turok <eitan.turok@databricks.com>
Co-authored-by: Eitan Turok <150733043+eitanturok@users.noreply.github.com>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: Eitan Turok <eitanturok@gmail.com>
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
Co-authored-by: Matt <rocketknight1@gmail.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent 63c5e27efb
commit 005b957fb8
@ -341,6 +341,7 @@ Current number of checkpoints: **
1. **[CTRL](https://huggingface.co/docs/transformers/model_doc/ctrl)** (from Salesforce) released with the paper [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher.
|
||||
1. **[CvT](https://huggingface.co/docs/transformers/model_doc/cvt)** (from Microsoft) released with the paper [CvT: Introducing Convolutions to Vision Transformers](https://arxiv.org/abs/2103.15808) by Haiping Wu, Bin Xiao, Noel Codella, Mengchen Liu, Xiyang Dai, Lu Yuan, Lei Zhang.
|
||||
1. **[Data2Vec](https://huggingface.co/docs/transformers/model_doc/data2vec)** (from Facebook) released with the paper [Data2Vec: A General Framework for Self-supervised Learning in Speech, Vision and Language](https://arxiv.org/abs/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu, Michael Auli.
|
||||
1. **[DBRX](https://huggingface.co/docs/transformers/main/model_doc/dbrx)** (from Databricks) released with the paper [Introducing DBRX: A New State-of-the-Art Open LLM](https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm) by the Mosaic Research Team.
|
||||
1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
1. **[Decision Transformer](https://huggingface.co/docs/transformers/model_doc/decision_transformer)** (from Berkeley/Facebook/Google) released with the paper [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://arxiv.org/abs/2106.01345) by Lili Chen, Kevin Lu, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Michael Laskin, Pieter Abbeel, Aravind Srinivas, Igor Mordatch.
|
||||
|
@ -337,6 +337,7 @@ Aktuelle Anzahl der Checkpoints: **
1. **[CTRL](https://huggingface.co/docs/transformers/model_doc/ctrl)** (from Salesforce) released with the paper [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher.
|
||||
1. **[CvT](https://huggingface.co/docs/transformers/model_doc/cvt)** (from Microsoft) released with the paper [CvT: Introducing Convolutions to Vision Transformers](https://arxiv.org/abs/2103.15808) by Haiping Wu, Bin Xiao, Noel Codella, Mengchen Liu, Xiyang Dai, Lu Yuan, Lei Zhang.
|
||||
1. **[Data2Vec](https://huggingface.co/docs/transformers/model_doc/data2vec)** (from Facebook) released with the paper [Data2Vec: A General Framework for Self-supervised Learning in Speech, Vision and Language](https://arxiv.org/abs/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu, Michael Auli.
|
||||
1. **[DBRX](https://huggingface.co/docs/transformers/main/model_doc/dbrx)** (from Databricks) released with the paper [Introducing DBRX: A New State-of-the-Art Open LLM](https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm) by the Mosaic Research Team.
|
||||
1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
1. **[Decision Transformer](https://huggingface.co/docs/transformers/model_doc/decision_transformer)** (from Berkeley/Facebook/Google) released with the paper [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://arxiv.org/abs/2106.01345) by Lili Chen, Kevin Lu, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Michael Laskin, Pieter Abbeel, Aravind Srinivas, Igor Mordatch.
|
||||
|
@ -314,6 +314,7 @@ Número actual de puntos de control: **
1. **[CTRL](https://huggingface.co/docs/transformers/model_doc/ctrl)** (from Salesforce) released with the paper [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher.
|
||||
1. **[CvT](https://huggingface.co/docs/transformers/model_doc/cvt)** (from Microsoft) released with the paper [CvT: Introducing Convolutions to Vision Transformers](https://arxiv.org/abs/2103.15808) by Haiping Wu, Bin Xiao, Noel Codella, Mengchen Liu, Xiyang Dai, Lu Yuan, Lei Zhang.
|
||||
1. **[Data2Vec](https://huggingface.co/docs/transformers/model_doc/data2vec)** (from Facebook) released with the paper [Data2Vec: A General Framework for Self-supervised Learning in Speech, Vision and Language](https://arxiv.org/abs/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu, Michael Auli.
|
||||
1. **[DBRX](https://huggingface.co/docs/transformers/main/model_doc/dbrx)** (from Databricks) released with the paper [Introducing DBRX: A New State-of-the-Art Open LLM](https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm) by the Mosaic Research Team.
|
||||
1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
1. **[Decision Transformer](https://huggingface.co/docs/transformers/model_doc/decision_transformer)** (from Berkeley/Facebook/Google) released with the paper [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://arxiv.org/abs/2106.01345) by Lili Chen, Kevin Lu, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Michael Laskin, Pieter Abbeel, Aravind Srinivas, Igor Mordatch.
|
||||
@ -477,7 +478,7 @@ Número actual de puntos de control: **
1. **[Splinter](https://huggingface.co/docs/transformers/model_doc/splinter)** (from Tel Aviv University), released together with the paper [Few-Shot Question Answering by Pretraining Span Selection](https://arxiv.org/abs/2101.00438) by Ori Ram, Yuval Kirstain, Jonathan Berant, Amir Globerson, Omer Levy.
|
||||
1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer.
|
||||
1. **[StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm)** (from Stability AI) released with the paper [StableLM 3B 4E1T (Technical Report)](https://stability.wandb.io/stability-llm/stable-lm/reports/StableLM-3B-4E1T--VmlldzoyMjU4?accessToken=u3zujipenkx5g7rtcj9qojjgxpconyjktjkli2po09nffrffdhhchq045vp0wyfo) by Jonathan Tow, Marco Bellagente, Dakota Mahan, Carlos Riquelme Ruiz, Duy Phung, Maksym Zhuravinskyi, Nathan Cooper, Nikhil Pinnaparaju, Reshinth Adithyan, and James Baicoianu.
|
||||
1. **[Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2)** (from BigCode team) released with a coming soon paper.
|
||||
1. **[Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2)** (from BigCode team) released with the paper [StarCoder 2 and The Stack v2: The Next Generation](https://arxiv.org/abs/2402.19173) by Anton Lozhkov, Raymond Li, Loubna Ben Allal, Federico Cassano, Joel Lamy-Poirier, Nouamane Tazi, Ao Tang, Dmytro Pykhtar, Jiawei Liu, Yuxiang Wei, Tianyang Liu, Max Tian, Denis Kocetkov, Arthur Zucker, Younes Belkada, Zijian Wang, Qian Liu, Dmitry Abulkhanov, Indraneil Paul, Zhuang Li, Wen-Ding Li, Megan Risdal, Jia Li, Jian Zhu, Terry Yue Zhuo, Evgenii Zheltonozhskii, Nii Osae Osae Dade, Wenhao Yu, Lucas Krauß, Naman Jain, Yixuan Su, Xuanli He, Manan Dey, Edoardo Abati, Yekun Chai, Niklas Muennighoff, Xiangru Tang, Muhtasham Oblokulov, Christopher Akiki, Marc Marone, Chenghao Mou, Mayank Mishra, Alex Gu, Binyuan Hui, Tri Dao, Armel Zebaze, Olivier Dehaene, Nicolas Patry, Canwen Xu, Julian McAuley, Han Hu, Torsten Scholak, Sebastien Paquet, Jennifer Robinson, Carolyn Jane Anderson, Nicolas Chapados, Mostofa Patwary, Nima Tajbakhsh, Yacine Jernite, Carlos Muñoz Ferrandis, Lingming Zhang, Sean Hughes, Thomas Wolf, Arjun Guha, Leandro von Werra, and Harm de Vries.
|
||||
1. **[SuperPoint](https://huggingface.co/docs/transformers/model_doc/superpoint)** (from MagicLeap) released with the paper [SuperPoint: Self-Supervised Interest Point Detection and Description](https://arxiv.org/abs/1712.07629) by Daniel DeTone, Tomasz Malisiewicz and Andrew Rabinovich.
|
||||
1. **[SwiftFormer](https://huggingface.co/docs/transformers/model_doc/swiftformer)** (from MBZUAI) released with the paper [SwiftFormer: Efficient Additive Attention for Transformer-based Real-time Mobile Vision Applications](https://arxiv.org/abs/2303.15446) by Abdelrahman Shaker, Muhammad Maaz, Hanoona Rasheed, Salman Khan, Ming-Hsuan Yang, Fahad Shahbaz Khan.
|
||||
1. **[Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin)** (from Microsoft) released with the paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030) by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo.
|
||||
|
@ -335,6 +335,7 @@ Nombre actuel de points de contrôle : ** (de Salesforce) publié dans l'article [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) par Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong et Richard Socher.
|
||||
1. **[CvT](https://huggingface.co/docs/transformers/model_doc/cvt)** (de Microsoft) publié dans l'article [CvT: Introducing Convolutions to Vision Transformers](https://arxiv.org/abs/2103.15808) par Haiping Wu, Bin Xiao, Noel Codella, Mengchen Liu, Xiyang Dai, Lu Yuan, Lei Zhang.
|
||||
1. **[Data2Vec](https://huggingface.co/docs/transformers/model_doc/data2vec)** (de Facebook) publié dans l'article [Data2Vec: A General Framework for Self-supervised Learning in Speech, Vision and Language](https://arxiv.org/abs/2202.03555) par Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu, Michael Auli.
|
||||
1. **[DBRX](https://huggingface.co/docs/transformers/main/model_doc/dbrx)** (from Databricks) released with the paper [Introducing DBRX: A New State-of-the-Art Open LLM](https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm) by the Mosaic Research Team.
|
||||
1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (de Microsoft) publié dans l'article [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) par Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (de Microsoft) publié dans l'article [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) par Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
1. **[Decision Transformer](https://huggingface.co/docs/transformers/model_doc/decision_transformer)** (de Berkeley/Facebook/Google) publié dans l'article [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://arxiv.org/abs/2106.01345) par Lili Chen, Kevin Lu, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Michael Laskin, Pieter Abbeel, Aravind Srinivas, Igor Mordatch.
|
||||
|
@ -288,6 +288,7 @@ conda install conda-forge::transformers
|
||||
1. **[CTRL](https://huggingface.co/docs/transformers/model_doc/ctrl)** (सेल्सफोर्स से) साथ में पेपर [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) नीतीश शिरीष केसकर*, ब्रायन मैककैन*, लव आर. वार्ष्णेय, कैमिंग जिओंग और रिचर्ड द्वारा सोचर द्वारा जारी किया गया।
|
||||
1. **[CvT](https://huggingface.co/docs/transformers/model_doc/cvt)** (Microsoft से) साथ में दिया गया पेपर [CvT: Introducing Convolutions to Vision Transformers](https://arxiv.org/abs/2103.15808) हैपिंग वू, बिन जिओ, नोएल कोडेला, मेंगचेन लियू, जियांग दाई, लू युआन, लेई झांग द्वारा।
|
||||
1. **[Data2Vec](https://huggingface.co/docs/transformers/model_doc/data2vec)** (फेसबुक से) साथ में कागज [Data2Vec: A General Framework for Self-supervised Learning in Speech, Vision and Language](https://arxiv.org/abs/2202.03555) एलेक्सी बाएव्स्की, वेई-निंग सू, कियानटोंग जू, अरुण बाबू, जियाताओ गु, माइकल औली द्वारा पोस्ट किया गया।
|
||||
1. **[DBRX](https://huggingface.co/docs/transformers/main/model_doc/dbrx)** (from Databricks) released with the paper [Introducing DBRX: A New State-of-the-Art Open LLM](https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm) by the Mosaic Research Team.
|
||||
1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (Microsoft से) साथ में दिया गया पेपर [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) पेंगचेंग हे, ज़ियाओडोंग लियू, जियानफेंग गाओ, वीज़ू चेन द्वारा।
|
||||
1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (Microsoft से) साथ में दिया गया पेपर [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) पेंगचेंग हे, ज़ियाओडोंग लियू, जियानफेंग गाओ, वीज़ू चेन द्वारा पोस्ट किया गया।
|
||||
1. **[Decision Transformer](https://huggingface.co/docs/transformers/model_doc/decision_transformer)** (बर्कले/फेसबुक/गूगल से) पेपर के साथ [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://arxiv.org/abs/2106.01345) लिली चेन, केविन लू, अरविंद राजेश्वरन, किमिन ली, आदित्य ग्रोवर, माइकल लास्किन, पीटर एबील, अरविंद श्रीनिवास, इगोर मोर्डच द्वारा पोस्ट किया गया।
|
||||
|
@ -348,6 +348,7 @@ Flax、PyTorch、TensorFlowをcondaでインストールする方法は、それ
|
||||
1. **[CTRL](https://huggingface.co/docs/transformers/model_doc/ctrl)** (Salesforce から) Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher から公開された研究論文: [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858)
|
||||
1. **[CvT](https://huggingface.co/docs/transformers/model_doc/cvt)** (Microsoft から) Haiping Wu, Bin Xiao, Noel Codella, Mengchen Liu, Xiyang Dai, Lu Yuan, Lei Zhang から公開された研究論文: [CvT: Introducing Convolutions to Vision Transformers](https://arxiv.org/abs/2103.15808)
|
||||
1. **[Data2Vec](https://huggingface.co/docs/transformers/model_doc/data2vec)** (Facebook から) Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu, Michael Auli から公開された研究論文: [Data2Vec: A General Framework for Self-supervised Learning in Speech, Vision and Language](https://arxiv.org/abs/2202.03555)
|
||||
1. **[DBRX](https://huggingface.co/docs/transformers/main/model_doc/dbrx)** (from Databricks) released with the paper [Introducing DBRX: A New State-of-the-Art Open LLM](https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm) by the Mosaic Research Team.
|
||||
1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (Microsoft から) Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen から公開された研究論文: [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654)
|
||||
1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (Microsoft から) Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen から公開された研究論文: [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654)
|
||||
1. **[Decision Transformer](https://huggingface.co/docs/transformers/model_doc/decision_transformer)** (Berkeley/Facebook/Google から) Lili Chen, Kevin Lu, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Michael Laskin, Pieter Abbeel, Aravind Srinivas, Igor Mordatch から公開された研究論文: [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://arxiv.org/abs/2106.01345)
|
||||
|
@ -263,6 +263,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
|
||||
1. **[CTRL](https://huggingface.co/docs/transformers/model_doc/ctrl)** (Salesforce 에서) Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher 의 [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) 논문과 함께 발표했습니다.
|
||||
1. **[CvT](https://huggingface.co/docs/transformers/model_doc/cvt)** (Microsoft 에서) Haiping Wu, Bin Xiao, Noel Codella, Mengchen Liu, Xiyang Dai, Lu Yuan, Lei Zhang 의 [CvT: Introducing Convolutions to Vision Transformers](https://arxiv.org/abs/2103.15808) 논문과 함께 발표했습니다.
|
||||
1. **[Data2Vec](https://huggingface.co/docs/transformers/model_doc/data2vec)** (Facebook 에서) Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu, Michael Auli 의 [Data2Vec: A General Framework for Self-supervised Learning in Speech, Vision and Language](https://arxiv.org/abs/2202.03555) 논문과 함께 발표했습니다.
|
||||
1. **[DBRX](https://huggingface.co/docs/transformers/main/model_doc/dbrx)** (from Databricks) released with the paper [Introducing DBRX: A New State-of-the-Art Open LLM](https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm) by the Mosaic Research Team.
|
||||
1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (Microsoft 에서) Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen 의 [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) 논문과 함께 발표했습니다.
|
||||
1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (Microsoft 에서) Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen 의 [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) 논문과 함께 발표했습니다.
|
||||
1. **[Decision Transformer](https://huggingface.co/docs/transformers/model_doc/decision_transformer)** (Berkeley/Facebook/Google 에서) Lili Chen, Kevin Lu, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Michael Laskin, Pieter Abbeel, Aravind Srinivas, Igor Mordatch 의 [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://arxiv.org/abs/2106.01345) 논문과 함께 발표했습니다.
|
||||
|
@ -346,6 +346,7 @@ Número atual de pontos de verificação: **
1. **[CTRL](https://huggingface.co/docs/transformers/model_doc/ctrl)** (from Salesforce) released with the paper [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher.
|
||||
1. **[CvT](https://huggingface.co/docs/transformers/model_doc/cvt)** (from Microsoft) released with the paper [CvT: Introducing Convolutions to Vision Transformers](https://arxiv.org/abs/2103.15808) by Haiping Wu, Bin Xiao, Noel Codella, Mengchen Liu, Xiyang Dai, Lu Yuan, Lei Zhang.
|
||||
1. **[Data2Vec](https://huggingface.co/docs/transformers/model_doc/data2vec)** (from Facebook) released with the paper [Data2Vec: A General Framework for Self-supervised Learning in Speech, Vision and Language](https://arxiv.org/abs/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu, Michael Auli.
|
||||
1. **[DBRX](https://huggingface.co/docs/transformers/main/model_doc/dbrx)** (from Databricks) released with the paper [Introducing DBRX: A New State-of-the-Art Open LLM](https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm) by the Mosaic Research Team.
|
||||
1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
1. **[Decision Transformer](https://huggingface.co/docs/transformers/model_doc/decision_transformer)** (from Berkeley/Facebook/Google) released with the paper [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://arxiv.org/abs/2106.01345) by Lili Chen, Kevin Lu, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Michael Laskin, Pieter Abbeel, Aravind Srinivas, Igor Mordatch.
|
||||
|
@ -336,6 +336,7 @@ conda install conda-forge::transformers
|
||||
1. **[CTRL](https://huggingface.co/docs/transformers/model_doc/ctrl)** (from Salesforce) released with the paper [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher.
|
||||
1. **[CvT](https://huggingface.co/docs/transformers/model_doc/cvt)** (from Microsoft) released with the paper [CvT: Introducing Convolutions to Vision Transformers](https://arxiv.org/abs/2103.15808) by Haiping Wu, Bin Xiao, Noel Codella, Mengchen Liu, Xiyang Dai, Lu Yuan, Lei Zhang.
|
||||
1. **[Data2Vec](https://huggingface.co/docs/transformers/model_doc/data2vec)** (from Facebook) released with the paper [Data2Vec: A General Framework for Self-supervised Learning in Speech, Vision and Language](https://arxiv.org/abs/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu, Michael Auli.
|
||||
1. **[DBRX](https://huggingface.co/docs/transformers/main/model_doc/dbrx)** (from Databricks) released with the paper [Introducing DBRX: A New State-of-the-Art Open LLM](https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm) by the Mosaic Research Team.
|
||||
1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
1. **[Decision Transformer](https://huggingface.co/docs/transformers/model_doc/decision_transformer)** (from Berkeley/Facebook/Google) released with the paper [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://arxiv.org/abs/2106.01345) by Lili Chen, Kevin Lu, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Michael Laskin, Pieter Abbeel, Aravind Srinivas, Igor Mordatch.
|
||||
|
@ -338,6 +338,7 @@ Flax, PyTorch లేదా TensorFlow యొక్క ఇన్స్టా
|
||||
1. **[CTRL](https://huggingface.co/docs/transformers/model_doc/ctrl)** (from Salesforce) released with the paper [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher.
|
||||
1. **[CvT](https://huggingface.co/docs/transformers/model_doc/cvt)** (from Microsoft) released with the paper [CvT: Introducing Convolutions to Vision Transformers](https://arxiv.org/abs/2103.15808) by Haiping Wu, Bin Xiao, Noel Codella, Mengchen Liu, Xiyang Dai, Lu Yuan, Lei Zhang.
|
||||
1. **[Data2Vec](https://huggingface.co/docs/transformers/model_doc/data2vec)** (from Facebook) released with the paper [Data2Vec: A General Framework for Self-supervised Learning in Speech, Vision and Language](https://arxiv.org/abs/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu, Michael Auli.
|
||||
1. **[DBRX](https://huggingface.co/docs/transformers/main/model_doc/dbrx)** (from Databricks) released with the paper [Introducing DBRX: A New State-of-the-Art Open LLM](https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm) by the Mosaic Research Team.
|
||||
1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
1. **[Decision Transformer](https://huggingface.co/docs/transformers/model_doc/decision_transformer)** (from Berkeley/Facebook/Google) released with the paper [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://arxiv.org/abs/2106.01345) by Lili Chen, Kevin Lu, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Michael Laskin, Pieter Abbeel, Aravind Srinivas, Igor Mordatch.
|
||||
|
@ -337,6 +337,7 @@ Số lượng điểm kiểm tra hiện tại: **
1. **[CTRL](https://huggingface.co/docs/transformers/model_doc/ctrl)** (từ Salesforce) được phát hành với bài báo [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher.
|
||||
1. **[CvT](https://huggingface.co/docs/transformers/model_doc/cvt)** (từ Microsoft) được phát hành với bài báo [CvT: Introducing Convolutions to Vision Transformers](https://arxiv.org/abs/2103.15808) by Haiping Wu, Bin Xiao, Noel Codella, Mengchen Liu, Xiyang Dai, Lu Yuan, Lei Zhang.
|
||||
1. **[Data2Vec](https://huggingface.co/docs/transformers/model_doc/data2vec)** (từ Facebook) được phát hành với bài báo [Data2Vec: A General Framework for Self-supervised Learning in Speech, Vision and Language](https://arxiv.org/abs/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu, Michael Auli.
|
||||
1. **[DBRX](https://huggingface.co/docs/transformers/main/model_doc/dbrx)** (from Databricks) released with the paper [Introducing DBRX: A New State-of-the-Art Open LLM](https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm) by the Mosaic Research Team.
|
||||
1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (từ Microsoft) được phát hành với bài báo [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (từ Microsoft) được phát hành với bài báo [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
1. **[Decision Transformer](https://huggingface.co/docs/transformers/model_doc/decision_transformer)** (từ Berkeley/Facebook/Google) được phát hành với bài báo [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://arxiv.org/abs/2106.01345) by Lili Chen, Kevin Lu, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Michael Laskin, Pieter Abbeel, Aravind Srinivas, Igor Mordatch.
|
||||
|
@ -287,6 +287,7 @@ conda install conda-forge::transformers
|
||||
1. **[CTRL](https://huggingface.co/docs/transformers/model_doc/ctrl)** (来自 Salesforce) 伴随论文 [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) 由 Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher 发布。
|
||||
1. **[CvT](https://huggingface.co/docs/transformers/model_doc/cvt)** (来自 Microsoft) 伴随论文 [CvT: Introducing Convolutions to Vision Transformers](https://arxiv.org/abs/2103.15808) 由 Haiping Wu, Bin Xiao, Noel Codella, Mengchen Liu, Xiyang Dai, Lu Yuan, Lei Zhang 发布。
|
||||
1. **[Data2Vec](https://huggingface.co/docs/transformers/model_doc/data2vec)** (来自 Facebook) 伴随论文 [Data2Vec: A General Framework for Self-supervised Learning in Speech, Vision and Language](https://arxiv.org/abs/2202.03555) 由 Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu, Michael Auli 发布。
|
||||
1. **[DBRX](https://huggingface.co/docs/transformers/main/model_doc/dbrx)** (from Databricks) released with the paper [Introducing DBRX: A New State-of-the-Art Open LLM](https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm) by the Mosaic Research Team.
|
||||
1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (来自 Microsoft) 伴随论文 [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) 由 Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen 发布。
|
||||
1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (来自 Microsoft) 伴随论文 [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) 由 Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen 发布。
|
||||
1. **[Decision Transformer](https://huggingface.co/docs/transformers/model_doc/decision_transformer)** (来自 Berkeley/Facebook/Google) 伴随论文 [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://arxiv.org/abs/2106.01345) 由 Lili Chen, Kevin Lu, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Michael Laskin, Pieter Abbeel, Aravind Srinivas, Igor Mordatch 发布。
|
||||
|
@ -299,6 +299,7 @@ conda install conda-forge::transformers
|
||||
1. **[CTRL](https://huggingface.co/docs/transformers/model_doc/ctrl)** (from Salesforce) released with the paper [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher.
|
||||
1. **[CvT](https://huggingface.co/docs/transformers/model_doc/cvt)** (from Microsoft) released with the paper [CvT: Introducing Convolutions to Vision Transformers](https://arxiv.org/abs/2103.15808) by Haiping Wu, Bin Xiao, Noel Codella, Mengchen Liu, Xiyang Dai, Lu Yuan, Lei Zhang.
|
||||
1. **[Data2Vec](https://huggingface.co/docs/transformers/model_doc/data2vec)** (from Facebook) released with the paper [Data2Vec: A General Framework for Self-supervised Learning in Speech, Vision and Language](https://arxiv.org/abs/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu, Michael Auli.
|
||||
1. **[DBRX](https://huggingface.co/docs/transformers/main/model_doc/dbrx)** (from Databricks) released with the paper [Introducing DBRX: A New State-of-the-Art Open LLM](https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm) by the Mosaic Research Team.
|
||||
1. **[DeBERTa](https://huggingface.co/docs/transformers/model_doc/deberta)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
1. **[DeBERTa-v2](https://huggingface.co/docs/transformers/model_doc/deberta-v2)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
1. **[Decision Transformer](https://huggingface.co/docs/transformers/model_doc/decision_transformer)** (from Berkeley/Facebook/Google) released with the paper [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://arxiv.org/abs/2106.01345) by Lili Chen, Kevin Lu, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Michael Laskin, Pieter Abbeel, Aravind Srinivas, Igor Mordatch.
|
||||
|
@ -320,6 +320,8 @@
      title: CPMANT
    - local: model_doc/ctrl
      title: CTRL
    - local: model_doc/dbrx
      title: DBRX
    - local: model_doc/deberta
      title: DeBERTa
    - local: model_doc/deberta-v2
@ -107,6 +107,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Data2VecAudio](model_doc/data2vec) | ✅ | ❌ | ❌ |
| [Data2VecText](model_doc/data2vec) | ✅ | ❌ | ❌ |
| [Data2VecVision](model_doc/data2vec) | ✅ | ✅ | ❌ |
| [DBRX](model_doc/dbrx) | ✅ | ❌ | ❌ |
| [DeBERTa](model_doc/deberta) | ✅ | ✅ | ❌ |
| [DeBERTa-v2](model_doc/deberta-v2) | ✅ | ✅ | ❌ |
| [Decision Transformer](model_doc/decision_transformer) | ✅ | ❌ | ❌ |
docs/source/en/model_doc/dbrx.md (new file, 120 lines)
@ -0,0 +1,120 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# DBRX

## Overview

DBRX is a [transformer-based](https://www.isattentionallyouneed.com/) decoder-only large language model (LLM) that was trained using next-token prediction.
It uses a *fine-grained* mixture-of-experts (MoE) architecture with 132B total parameters, of which 36B parameters are active on any input.
It was pre-trained on 12T tokens of text and code data.
Compared to other open MoE models like Mixtral-8x7B and Grok-1, DBRX is fine-grained, meaning it uses a larger number of smaller experts. DBRX has 16 experts and chooses 4, while Mixtral-8x7B and Grok-1 have 8 experts and choose 2.
This provides 65x more possible combinations of experts, and we found that this improves model quality.
DBRX uses rotary position encodings (RoPE), gated linear units (GLU), and grouped query attention (GQA).
It is a BPE-based model and uses the GPT-4 tokenizer as described in the [tiktoken](https://github.com/openai/tiktoken) repository.
We made these choices based on exhaustive evaluation and scaling experiments.
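As a quick check of the expert-combination figures above (a minimal sketch added for illustration, not part of the original model card), choosing 4 of 16 experts versus 2 of 8 gives exactly the quoted 65x ratio:

```python
from math import comb

# 16 experts, top-4 routing (DBRX) vs. 8 experts, top-2 routing (Mixtral-8x7B, Grok-1)
dbrx_combos = comb(16, 4)    # 1820 possible expert subsets per token
coarse_combos = comb(8, 2)   # 28 possible expert subsets per token
print(dbrx_combos, coarse_combos, dbrx_combos / coarse_combos)  # 1820 28 65.0
```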
DBRX was pretrained on 12T tokens of carefully curated data with a maximum context length of 32K tokens.
We estimate that this data is at least 2x better token-for-token than the data we used to pretrain the MPT family of models.
This new dataset was developed using the full suite of Databricks tools, including Apache Spark™ and Databricks notebooks for data processing, and Unity Catalog for data management and governance.
We used curriculum learning for pretraining, changing the data mix during training in ways we found to substantially improve model quality.

More detailed information about DBRX Instruct and DBRX Base can be found in our [technical blog post](https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm).

This model was contributed by [eitan-turok](https://huggingface.co/eitanturok) and [abhi-db](https://huggingface.co/abhi-db). The original code can be found [here](https://github.com/databricks/dbrx-instruct).
## Usage Examples

The `generate()` method can be used to generate text using DBRX. You can generate using the standard attention implementation, flash-attention, and the PyTorch scaled dot product attention. The last two attention implementations give speed-ups.
```python
from transformers import DbrxForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct", token="YOUR_HF_TOKEN")
model = DbrxForCausalLM.from_pretrained(
    "databricks/dbrx-instruct",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    token="YOUR_HF_TOKEN",
)

input_text = "What does it take to build a great LLM?"
messages = [{"role": "user", "content": input_text}]
input_ids = tokenizer.apply_chat_template(messages, return_dict=True, tokenize=True, add_generation_prompt=True, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=200)
print(tokenizer.decode(outputs[0]))
```
If you have flash-attention installed (`pip install flash-attn`), it is possible to generate faster. (The HuggingFace documentation for flash-attention can be found [here](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2).)

```python
from transformers import DbrxForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct", token="YOUR_HF_TOKEN")
model = DbrxForCausalLM.from_pretrained(
    "databricks/dbrx-instruct",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    token="YOUR_HF_TOKEN",
    attn_implementation="flash_attention_2",
)

input_text = "What does it take to build a great LLM?"
messages = [{"role": "user", "content": input_text}]
input_ids = tokenizer.apply_chat_template(messages, return_dict=True, tokenize=True, add_generation_prompt=True, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=200)
print(tokenizer.decode(outputs[0]))
```
You can also generate faster using the PyTorch scaled dot product attention. (The HuggingFace documentation for scaled dot product attention can be found [here](https://huggingface.co/docs/transformers/perf_infer_gpu_one#pytorch-scaled-dot-product-attention).)

```python
from transformers import DbrxForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct", token="YOUR_HF_TOKEN")
model = DbrxForCausalLM.from_pretrained(
    "databricks/dbrx-instruct",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    token="YOUR_HF_TOKEN",
    attn_implementation="sdpa",
)

input_text = "What does it take to build a great LLM?"
messages = [{"role": "user", "content": input_text}]
input_ids = tokenizer.apply_chat_template(messages, return_dict=True, tokenize=True, add_generation_prompt=True, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=200)
print(tokenizer.decode(outputs[0]))
```
## DbrxConfig

[[autodoc]] DbrxConfig

## DbrxModel

[[autodoc]] DbrxModel
    - forward

## DbrxForCausalLM

[[autodoc]] DbrxForCausalLM
    - forward
@ -40,6 +40,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
* [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)
@ -184,9 +185,10 @@ PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.o
For now, Transformers supports SDPA inference and training for the following architectures:
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
* [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
@ -37,7 +37,7 @@ You can finetune other architectures for causal language modeling following the
Choose one of the following architectures:

<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeLlama](../model_doc/code_llama), [CodeGen](../model_doc/codegen), [Cohere](../model_doc/cohere), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [Falcon](../model_doc/falcon), [Fuyu](../model_doc/fuyu), [Gemma](../model_doc/gemma), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [Jamba](../model_doc/jamba), [LLaMA](../model_doc/llama), [Mamba](../model_doc/mamba), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [Mixtral](../model_doc/mixtral), [MPT](../model_doc/mpt), [MusicGen](../model_doc/musicgen), [MusicGen Melody](../model_doc/musicgen_melody), [MVP](../model_doc/mvp), [OLMo](../model_doc/olmo), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [Persimmon](../model_doc/persimmon), [Phi](../model_doc/phi), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Qwen2](../model_doc/qwen2), [Qwen2MoE](../model_doc/qwen2_moe), [RecurrentGemma](../model_doc/recurrent_gemma), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [StableLm](../model_doc/stablelm), [Starcoder2](../model_doc/starcoder2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [Whisper](../model_doc/whisper), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod)
[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeLlama](../model_doc/code_llama), [CodeGen](../model_doc/codegen), [Cohere](../model_doc/cohere), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DBRX](../model_doc/dbrx), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [Falcon](../model_doc/falcon), [Fuyu](../model_doc/fuyu), [Gemma](../model_doc/gemma), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [Jamba](../model_doc/jamba), [LLaMA](../model_doc/llama), [Mamba](../model_doc/mamba), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [Mixtral](../model_doc/mixtral), [MPT](../model_doc/mpt), [MusicGen](../model_doc/musicgen), [MusicGen Melody](../model_doc/musicgen_melody), [MVP](../model_doc/mvp), [OLMo](../model_doc/olmo), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [Persimmon](../model_doc/persimmon), [Phi](../model_doc/phi), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Qwen2](../model_doc/qwen2), [Qwen2MoE](../model_doc/qwen2_moe), [RecurrentGemma](../model_doc/recurrent_gemma), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [StableLm](../model_doc/stablelm), [Starcoder2](../model_doc/starcoder2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [Whisper](../model_doc/whisper), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod)
@ -328,6 +328,7 @@ _import_structure = {
        "Data2VecTextConfig",
        "Data2VecVisionConfig",
    ],
    "models.dbrx": ["DbrxConfig"],
    "models.deberta": [
        "DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "DebertaConfig",

@ -1943,6 +1944,13 @@ else:
            "Data2VecVisionPreTrainedModel",
        ]
    )
    _import_structure["models.dbrx"].extend(
        [
            "DbrxForCausalLM",
            "DbrxModel",
            "DbrxPreTrainedModel",
        ]
    )
    _import_structure["models.deberta"].extend(
        [
            "DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",

@ -5268,6 +5276,7 @@ if TYPE_CHECKING:
        Data2VecTextConfig,
        Data2VecVisionConfig,
    )
    from .models.dbrx import DbrxConfig
    from .models.deberta import (
        DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
        DebertaConfig,

@ -6782,6 +6791,13 @@ if TYPE_CHECKING:
        Data2VecVisionModel,
        Data2VecVisionPreTrainedModel,
    )

    # PyTorch model imports
    from .models.dbrx import (
        DbrxForCausalLM,
        DbrxModel,
        DbrxPreTrainedModel,
    )
    from .models.deberta import (
        DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
        DebertaForMaskedLM,
@ -59,6 +59,7 @@ from . import (
    ctrl,
    cvt,
    data2vec,
    dbrx,
    deberta,
    deberta_v2,
    decision_transformer,
@ -77,6 +77,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
        ("data2vec-audio", "Data2VecAudioConfig"),
        ("data2vec-text", "Data2VecTextConfig"),
        ("data2vec-vision", "Data2VecVisionConfig"),
        ("dbrx", "DbrxConfig"),
        ("deberta", "DebertaConfig"),
        ("deberta-v2", "DebertaV2Config"),
        ("decision_transformer", "DecisionTransformerConfig"),

@ -340,6 +341,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
        ("data2vec-audio", "Data2VecAudio"),
        ("data2vec-text", "Data2VecText"),
        ("data2vec-vision", "Data2VecVision"),
        ("dbrx", "DBRX"),
        ("deberta", "DeBERTa"),
        ("deberta-v2", "DeBERTa-v2"),
        ("decision_transformer", "Decision Transformer"),
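With the `("dbrx", "DbrxConfig")` and `("dbrx", "DBRX")` entries registered above, the Auto API can resolve DBRX checkpoints by model type. A minimal sketch (it assumes access to the gated `databricks/dbrx-instruct` repository, with the same token placeholder as in the model card above):

```python
from transformers import AutoConfig

# Resolves to DbrxConfig through the CONFIG_MAPPING_NAMES entry added in this commit.
config = AutoConfig.from_pretrained("databricks/dbrx-instruct", token="YOUR_HF_TOKEN")
print(type(config).__name__)  # DbrxConfig
print(config.model_type)      # dbrx
```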
@ -75,6 +75,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
        ("data2vec-audio", "Data2VecAudioModel"),
        ("data2vec-text", "Data2VecTextModel"),
        ("data2vec-vision", "Data2VecVisionModel"),
        ("dbrx", "DbrxModel"),
        ("deberta", "DebertaModel"),
        ("deberta-v2", "DebertaV2Model"),
        ("decision_transformer", "DecisionTransformerModel"),

@ -439,6 +440,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
        ("cpmant", "CpmAntForCausalLM"),
        ("ctrl", "CTRLLMHeadModel"),
        ("data2vec-text", "Data2VecTextForCausalLM"),
        ("dbrx", "DbrxForCausalLM"),
        ("electra", "ElectraForCausalLM"),
        ("ernie", "ErnieForCausalLM"),
        ("falcon", "FalconForCausalLM"),
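To confirm the new registrations without downloading any weights, the mapping tables above can be read directly. A sketch against the internal names shown in the hunks (these OrderedDicts are not a stable public API):

```python
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_MAPPING_NAMES,
)

print(MODEL_MAPPING_NAMES["dbrx"])                # DbrxModel
print(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES["dbrx"])  # DbrxForCausalLM
```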
@ -150,6 +150,7 @@ else:
        ("ctrl", ("CTRLTokenizer", None)),
        ("data2vec-audio", ("Wav2Vec2CTCTokenizer", None)),
        ("data2vec-text", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),
        ("dbrx", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
        ("deberta", ("DebertaTokenizer", "DebertaTokenizerFast" if is_tokenizers_available() else None)),
        (
            "deberta-v2",
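Per the mapping above, DBRX falls back to the GPT-2 tokenizer classes for its model type. A minimal sketch (checkpoint and token placeholder as in the model card above; the exact class returned can also depend on the checkpoint's own `tokenizer_config.json`):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct", token="YOUR_HF_TOKEN")
print(type(tokenizer).__name__)  # typically GPT2TokenizerFast when the `tokenizers` library is installed
print(tokenizer("What does it take to build a great LLM?")["input_ids"][:5])
```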
51
src/transformers/models/dbrx/__init__.py
Normal file
51
src/transformers/models/dbrx/__init__.py
Normal file
@ -0,0 +1,51 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available


_import_structure = {
    "configuration_dbrx": ["DbrxConfig"],
}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_dbrx"] = [
        "DbrxForCausalLM",
        "DbrxModel",
        "DbrxPreTrainedModel",
    ]


if TYPE_CHECKING:
    from .configuration_dbrx import DbrxConfig

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_dbrx import DbrxForCausalLM, DbrxModel, DbrxPreTrainedModel


else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

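The `_LazyModule` indirection above means neither `configuration_dbrx` nor `modeling_dbrx` is imported until one of their symbols is first accessed. A small sketch of the resulting top-level behaviour (accessing the modeling class assumes torch is installed):

```python
import transformers

# Nothing DBRX-specific has been imported yet; the package only knows the import structure.
config_cls = transformers.DbrxConfig       # first access triggers import of configuration_dbrx
model_cls = transformers.DbrxForCausalLM   # triggers import of modeling_dbrx (requires torch)

print(config_cls.model_type)  # "dbrx"
```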
src/transformers/models/dbrx/configuration_dbrx.py (new file, 257 lines)
@ -0,0 +1,257 @@
# coding=utf-8
# Copyright 2024 Databricks Mosaic Research and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" DBRX model configuration """

from typing import Any, Optional

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class DbrxAttentionConfig(PretrainedConfig):
    """Configuration class for Dbrx Attention.

    This is the configuration class for the [`DbrxAttention`] class. It is used to instantiate attention layers
    according to the specified arguments, defining the layers' architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        attn_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability for the attention layers.
        clip_qkv (`float`, *optional*):
            If set, clip the queries, keys, and values in the attention layer to this value.
        kv_n_heads (`int`, *optional*, defaults to 1):
            For grouped-query attention, the number of key/value heads to use.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base frequency for rotary position embeddings (RoPE).
    """

    def __init__(
        self,
        attn_pdrop: float = 0.0,
        clip_qkv: Optional[float] = None,
        kv_n_heads: int = 1,
        rope_theta: float = 10000.0,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.attn_pdrop = attn_pdrop
        self.clip_qkv = clip_qkv
        self.kv_n_heads = kv_n_heads
        self.rope_theta = rope_theta

        for k in ["model_type"]:
            if k in kwargs:
                kwargs.pop(k)
        if len(kwargs) != 0:
            raise ValueError(f"Found unknown {kwargs=}")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs: Any) -> "PretrainedConfig":
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        if config_dict.get("model_type") == "dbrx":
            config_dict = config_dict["attn_config"]

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)

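The overridden `from_pretrained` lets this sub-config be loaded from a full DBRX checkpoint: when the fetched config dict has `model_type == "dbrx"`, only its nested `attn_config` section is kept. A minimal sketch of that extraction on a plain dict, with no hub access (the dict values here are illustrative, not taken from a real checkpoint):

```python
from transformers.models.dbrx.configuration_dbrx import DbrxAttentionConfig

# Hypothetical full config dict, shaped like a DBRX config.json
full_config = {
    "model_type": "dbrx",
    "d_model": 2048,
    "attn_config": {"attn_pdrop": 0.0, "clip_qkv": 8.0, "kv_n_heads": 8, "rope_theta": 500000.0},
    "ffn_config": {"ffn_hidden_size": 3584, "moe_num_experts": 4, "moe_top_k": 1},
}

config_dict = full_config
if config_dict.get("model_type") == "dbrx":
    config_dict = config_dict["attn_config"]  # keep only the attention section

attn_config = DbrxAttentionConfig(**config_dict)
print(attn_config.kv_n_heads, attn_config.rope_theta)  # 8 500000.0
```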
class DbrxFFNConfig(PretrainedConfig):
    """Configuration class for Dbrx FFN.

    This is the configuration class for the [`DbrxFFN`] class. It is used to instantiate feedforward layers according to
    the specified arguments, defining the layers' architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        ffn_act_fn (`dict`, *optional*, defaults to `None`): A dict specifying the activation function for the FFN.
            The dict should have a key `'name'` whose value is the name of the activation function, along with
            any additional keyword arguments for that function. If `None`, then set to `{"name": "silu"}`.
        ffn_hidden_size (`int`, defaults to 3584): The hidden size of the feedforward network.
        moe_num_experts (`int`, defaults to 4): The number of experts in the mixture-of-experts layer.
        moe_top_k (`int`, defaults to 1): The number of experts each token is routed to in the mixture-of-experts layer.
        moe_jitter_eps (`float`, *optional*, defaults to `None`): If not `None`, the jitter epsilon for the mixture-of-experts layer.
        moe_loss_weight (`float`, defaults to 0.01): The weight of the auxiliary router loss for the mixture-of-experts layer.
        moe_normalize_expert_weights (`float`, *optional*, defaults to 1.0): The normalization factor for the expert weights.
    """

    def __init__(
        self,
        ffn_act_fn: dict = None,
        ffn_hidden_size: int = 3584,
        moe_num_experts: int = 4,
        moe_top_k: int = 1,
        moe_jitter_eps: Optional[float] = None,
        moe_loss_weight: float = 0.01,
        moe_normalize_expert_weights: Optional[float] = 1.0,
        **kwargs: Any,
    ):
        super().__init__()
        if ffn_act_fn is None:
            ffn_act_fn = {"name": "silu"}
        self.ffn_act_fn = ffn_act_fn
        self.ffn_hidden_size = ffn_hidden_size
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.moe_jitter_eps = moe_jitter_eps
        self.moe_loss_weight = moe_loss_weight
        self.moe_normalize_expert_weights = moe_normalize_expert_weights

        for k in ["model_type"]:
            if k in kwargs:
                kwargs.pop(k)
        if len(kwargs) != 0:
            raise ValueError(f"Found unknown {kwargs=}")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs: Any) -> "PretrainedConfig":
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        if config_dict.get("model_type") == "dbrx":
            config_dict = config_dict["ffn_config"]

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)

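For orientation, these fields fit together as follows: the router selects `moe_top_k` of the `moe_num_experts` experts per token, `moe_loss_weight` scales the auxiliary router (load-balancing) loss, and `ffn_act_fn` is resolved by name. A small sketch constructing the sub-config directly (values are illustrative, not DBRX defaults):

```python
from transformers.models.dbrx.configuration_dbrx import DbrxFFNConfig

ffn_config = DbrxFFNConfig(
    ffn_act_fn={"name": "gelu"},  # resolved by name; extra keys become kwargs of the activation
    ffn_hidden_size=1024,
    moe_num_experts=8,            # experts available per layer
    moe_top_k=2,                  # experts each token is routed to
    moe_loss_weight=0.05,         # weight of the auxiliary router (load-balancing) loss
)
print(ffn_config.moe_num_experts, ffn_config.moe_top_k)  # 8 2
```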
class DbrxConfig(PretrainedConfig):
    r"""

    This is the configuration class to store the configuration of a [`DbrxModel`]. It is used to instantiate a Dbrx model according to the
    specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a different configuration from that of the [databricks/dbrx-instruct](https://huggingface.co/databricks/dbrx-instruct) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        d_model (`int`, *optional*, defaults to 2048):
            Dimensionality of the embeddings and hidden states.
        n_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        n_layers (`int`, *optional*, defaults to 24):
            Number of hidden layers in the Transformer encoder.
        max_seq_len (`int`, *optional*, defaults to 2048):
            The maximum sequence length of the model.
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the Dbrx model. Defines the maximum number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`DbrxModel`].
        resid_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability applied to the attention output before combining with the residual.
        emb_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability for the embedding layer.
        attn_config (`dict`, *optional*):
            A dictionary used to configure the model's attention module.
        ffn_config (`dict`, *optional*):
            A dictionary used to configure the model's FFN module.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model. Enabling this will also
            allow the model to output the auxiliary router loss.


    Example:
    ```python
    >>> from transformers import DbrxConfig, DbrxModel

    >>> # Initializing a Dbrx configuration
    >>> configuration = DbrxConfig(n_layers=2, d_model=256, n_heads=8, vocab_size=128)

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = DbrxModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "dbrx"
    attribute_map = {
        "num_attention_heads": "n_heads",
        "hidden_size": "d_model",
        "num_hidden_layers": "n_layers",
        "max_position_embeddings": "max_seq_len",
    }

    def __init__(
        self,
        d_model: int = 2048,
        n_heads: int = 16,
        n_layers: int = 24,
        max_seq_len: int = 2048,
        vocab_size: int = 32000,
        resid_pdrop: float = 0.0,
        emb_pdrop: float = 0.0,
        attn_config: Optional[DbrxAttentionConfig] = None,
        ffn_config: Optional[DbrxFFNConfig] = None,
        use_cache: bool = True,
        initializer_range: float = 0.02,
        output_router_logits: bool = False,
        **kwargs: Any,
    ):
        if attn_config is None:
            self.attn_config = DbrxAttentionConfig()
        elif isinstance(attn_config, dict):
            self.attn_config = DbrxAttentionConfig(**attn_config)
        else:
            self.attn_config = attn_config

        if ffn_config is None:
            self.ffn_config = DbrxFFNConfig()
        elif isinstance(ffn_config, dict):
            self.ffn_config = DbrxFFNConfig(**ffn_config)
        else:
            self.ffn_config = ffn_config

        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.resid_pdrop = resid_pdrop
        self.emb_pdrop = emb_pdrop
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.output_router_logits = output_router_logits

        tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
        if tie_word_embeddings:
            raise ValueError("tie_word_embeddings is not supported for DBRX models.")

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
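Because of the `attribute_map` above, the same configuration can be addressed either through DBRX's native field names or through the generic names the rest of the library expects, and the nested sub-configs can be passed as plain dicts. A minimal sketch (sizes are illustrative):

```python
from transformers import DbrxConfig

config = DbrxConfig(
    d_model=256,
    n_heads=8,
    n_layers=2,
    max_seq_len=512,
    vocab_size=128,
    attn_config={"kv_n_heads": 2, "rope_theta": 10000.0},
    ffn_config={"ffn_hidden_size": 512, "moe_num_experts": 4, "moe_top_k": 1},
)

# Generic aliases resolve through attribute_map
assert config.hidden_size == config.d_model == 256
assert config.num_attention_heads == config.n_heads == 8
assert config.num_hidden_layers == config.n_layers == 2
assert config.max_position_embeddings == config.max_seq_len == 512
```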
src/transformers/models/dbrx/modeling_dbrx.py (new file, 1523 lines; diff not shown here because it is too large)
@ -2457,6 +2457,27 @@ class Data2VecVisionPreTrainedModel(metaclass=DummyObject):
        requires_backends(self, ["torch"])


class DbrxForCausalLM(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class DbrxModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class DbrxPreTrainedModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = None

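These are the placeholder classes that stand in for the real DBRX models when PyTorch is not installed, so the top-level import still succeeds and only actual use fails with a helpful message. A self-contained sketch of the pattern (toy stand-ins, not the actual transformers implementation):

```python
_torch_available = False  # pretend torch is not installed


def requires_backends(obj, backends):
    """Raise an informative error if a required backend is unavailable."""
    if "torch" in backends and not _torch_available:
        raise ImportError(f"{obj.__class__.__name__} requires the PyTorch backend, which is not installed.")


class DbrxForCausalLM:
    """Dummy placeholder: importable without torch, but unusable."""

    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


try:
    DbrxForCausalLM()
except ImportError as err:
    print(err)  # DbrxForCausalLM requires the PyTorch backend, which is not installed.
```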
@ -97,8 +97,8 @@ src/transformers/models/<model_name>/configuration_<model_name>.py
src/transformers/models/<model_name>/modeling_<model_name>.py
src/transformers/models/<model_name>/modeling_tf_<model_name>.py
src/transformers/models/<model_name>/tokenization_<model_name>.py
tests/test_modeling_<model_name>.py
tests/test_modeling_tf_<model_name>.py
tests/models/<model_name>/test_modeling_<model_name>.py
tests/models/<model_name>/test_modeling_tf_<model_name>.py
```

You can run the tests to ensure that they all pass:

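The hunk ends just before the command itself; for the DBRX tests added in this diff, running them would look something like the snippet below (illustrative only; the equivalent CLI is `python -m pytest tests/models/dbrx/test_modeling_dbrx.py`):

```python
# Illustrative: run the new DBRX test module via pytest's Python entry point.
import pytest

pytest.main(["tests/models/dbrx/test_modeling_dbrx.py", "-q"])
```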
tests/models/dbrx/__init__.py (new file, 0 lines)
tests/models/dbrx/test_modeling_dbrx.py (new file, 387 lines)
@ -0,0 +1,387 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch DBRX model. """


import unittest

from parameterized import parameterized

from transformers import DbrxConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch

    from transformers import DbrxForCausalLM, DbrxModel


class DbrxModelTester:
    def __init__(
        self,
        parent,
        hidden_size=32,
        ffn_hidden_size=32,
        num_attention_heads=4,
        kv_n_heads=4,
        num_hidden_layers=5,
        max_position_embeddings=512,
        type_vocab_size=16,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=False,
        use_labels=True,
        use_cache=True,
        type_sequence_label_size=2,
        num_labels=3,
        num_choices=4,
        scope=None,
        clip_qkv=8,
        rope_theta=500000,
        attn_config_model_type="",
        emb_pdrop=0.0,
        moe_jitter_eps=0,
        moe_loss_weight=0.05,
        moe_num_experts=16,
        moe_top_k=4,
        ffn_config_model_type="",
        ffn_act_fn_name="gelu",
        initializer_range=0.02,
        output_router_logits=False,
        resid_pdrop=0.0,
        tie_word_embeddings=False,
        torch_dtype="bfloat16",
        vocab_size=99,
        is_decoder=True,
        pad_token_id=0,
    ):
        # Parameters unique to testing
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = scope
        self.parent = parent
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels

        # attn_config params
        self.clip_qkv = clip_qkv
        self.kv_n_heads = kv_n_heads
        self.rope_theta = rope_theta
        self.attn_config_model_type = attn_config_model_type

        # ffn_config params
        self.ffn_hidden_size = ffn_hidden_size
        self.moe_jitter_eps = moe_jitter_eps
        self.moe_loss_weight = moe_loss_weight
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.ffn_config_model_type = ffn_config_model_type
        self.ffn_act_fn_name = ffn_act_fn_name

        # Other model params
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.max_position_embeddings = max_position_embeddings
        self.vocab_size = vocab_size
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.emb_pdrop = emb_pdrop
        self.output_router_logits = output_router_logits
        self.resid_pdrop = resid_pdrop
        self.tie_word_embeddings = tie_word_embeddings
        self.torch_dtype = torch_dtype
        self.is_decoder = is_decoder
        self.pad_token_id = pad_token_id

        # Make the dictionaries
        self.ffn_config = {
            "ffn_hidden_size": self.ffn_hidden_size,
            "moe_jitter_eps": self.moe_jitter_eps,
            "moe_loss_weight": self.moe_loss_weight,
            "moe_num_experts": self.moe_num_experts,
            "moe_top_k": self.moe_top_k,
            "model_type": self.ffn_config_model_type,
            "ffn_act_fn": {"name": self.ffn_act_fn_name},
        }
        self.attn_config = {
            "clip_qkv": self.clip_qkv,
            "kv_n_heads": self.kv_n_heads,
            "model_type": self.attn_config_model_type,
            "rope_theta": self.rope_theta,
        }
    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

    def get_config(self):
        # Behind the scenes, `DbrxConfig` maps the parameters `hidden_size`, `num_hidden_layers`,
        # `num_attention_heads`, `max_position_embeddings` to the parameters `d_model`, `n_layers`,
        # `n_heads`, `max_seq_len` respectively. We use the first group of parameters because
        # other tests expect every model to have these parameters with these specific names.
        config = DbrxConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,  # mapped to `d_model`
            num_hidden_layers=self.num_hidden_layers,  # mapped to `n_layers`
            num_attention_heads=self.num_attention_heads,  # mapped to `n_heads`
            max_position_embeddings=self.max_position_embeddings,  # mapped to `max_seq_len`
            attn_config=self.attn_config,
            ffn_config=self.ffn_config,
            resid_pdrop=self.resid_pdrop,
            emb_pdrop=self.emb_pdrop,
            use_cache=self.use_cache,
            initializer_range=self.initializer_range,
            output_router_logits=self.output_router_logits,
            is_decoder=self.is_decoder,
            pad_token_id=self.pad_token_id,
        )
        return config

    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Dbrx
    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = DbrxModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask)
        result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model_as_decoder with Llama->Dbrx
    def create_and_check_model_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        config.add_cross_attention = True
        model = DbrxModel(config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        result = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
        )
        result = model(input_ids, attention_mask=input_mask)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_for_causal_lm with Llama->Dbrx
    def create_and_check_for_causal_lm(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        model = DbrxForCausalLM(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_decoder_model_past_large_inputs(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        config.is_decoder = True
        config.add_cross_attention = True
        model = DbrxForCausalLM(config=config)
        model.to(torch_device)
        model.eval()

        # first forward pass
        outputs = model(
            input_ids,
            attention_mask=input_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values

        # create hypothetical multiple next tokens and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)

        # append to next input_ids and attention mask
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)

        output_from_no_past = model(
            next_input_ids,
            attention_mask=next_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_hidden_states=True,
        )["hidden_states"][0]
        output_from_past = model(
            next_tokens,
            attention_mask=next_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            output_hidden_states=True,
        )["hidden_states"][0]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common with Llama->Dbrx
    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_torch
class DbrxModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (DbrxModel, DbrxForCausalLM) if is_torch_available() else ()
    all_generative_model_classes = (DbrxForCausalLM,) if is_torch_available() else ()
    pipeline_model_mapping = {"text-generation": DbrxForCausalLM} if is_torch_available() else {}
    test_headmasking = False
    test_pruning = False

    def setUp(self):
        self.model_tester = DbrxModelTester(self)
        self.config_tester = ConfigTester(self, config_class=DbrxConfig, d_model=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_various_embeddings(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        for type in ["absolute", "relative_key", "relative_key_query"]:
            config_and_inputs[0].position_embedding_type = type
            self.model_tester.create_and_check_model(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        model_name = "eitanturok/dbrx-tiny"
        model = DbrxModel.from_pretrained(model_name)
        self.assertIsNotNone(model)

    @unittest.skip("Dbrx models have weight tying disabled.")
    def test_tied_weights_keys(self):
        pass

    @unittest.skip("TODO @gante fix this for Llama")
    @parameterized.expand([(1, False), (1, True), (4, False)])
    def test_new_cache_format(self, num_beams, do_sample):
        pass


@require_torch
class DbrxModelIntegrationTest(unittest.TestCase):
    @slow
    def test_tiny_model_logits(self):
        model = DbrxForCausalLM.from_pretrained("Rocketknight1/dbrx-tiny-random")
        input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
        output = model(input_ids)[0]
        vocab_size = model.vocab_size

        expected_shape = torch.Size((1, 6, vocab_size))
        self.assertEqual(output.shape, expected_shape)

        expected_slice = torch.tensor(
            [
                [
                    [-1.6300e-04, 5.0118e-04, 2.5437e-04],
                    [2.0422e-05, 2.7210e-04, -1.5125e-04],
                    [-1.5105e-04, 4.6879e-04, 3.3309e-04],
                ]
            ]
        )
        self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
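Outside the test suite, the same tiny random checkpoint can be used for a quick end-to-end smoke check. A hedged sketch (the weights are random, so the generated tokens are meaningless; this only verifies shapes and that generation runs):

```python
import torch

from transformers import DbrxForCausalLM

model = DbrxForCausalLM.from_pretrained("Rocketknight1/dbrx-tiny-random")
model.eval()

input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
with torch.no_grad():
    logits = model(input_ids).logits  # shape: (1, 6, model.vocab_size)
    generated = model.generate(input_ids, max_new_tokens=5, do_sample=False)

print(logits.shape, generated.shape)
```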