mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Add RWKV-4 (#22797)
* First draft of RWKV-4
* Add support for generate
* Style post-rebase
* Properly use state
* Write doc
* Fix doc
* More math
* Add model to README, dummies and clean config
* Fix init
* multiple fixes:
  - fix common tests
  - fix configuration default values
  - add CI test for checking state computation
  - fix some CI tests
* correct tokenizer
* some tweaks
  - fix config docstring
  - fix failing tests
* fix CI tests
  - add output_attention / output_hidden_states
  - override test_initialization
  - fix failing CIs
* fix conversion script
  - fix sharded case
  - add new arguments
* add slow tests + more fixes on conversion script
* add another test
* final fixes
* change single name variable
* add mock attention mask for pipeline to work
* correct eos token id
* fix nits
* add checkpoints
* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* add `tie_word_embeddings` in docstring
* change tensor name
* fix final nits
* Trigger CI

---------

Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
9a50cb6195
commit
b4d4d6fe87
@@ -422,6 +422,7 @@ Current number of checkpoints:
1. **[RoBERTa-PreLayerNorm](https://huggingface.co/docs/transformers/model_doc/roberta-prelayernorm)** (from Facebook) released with the paper [fairseq: A Fast, Extensible Toolkit for Sequence Modeling](https://arxiv.org/abs/1904.01038) by Myle Ott, Sergey Edunov, Alexei Baevski, Angela Fan, Sam Gross, Nathan Ng, David Grangier, Michael Auli.
1. **[RoCBert](https://huggingface.co/docs/transformers/model_doc/roc_bert)** (from WeChatAI) released with the paper [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) by HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou.
1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
1. **[RWKV](https://huggingface.co/docs/transformers/main/model_doc/rwkv)** (from Bo Peng), released on [this repo](https://github.com/BlinkDL/RWKV-LM) by Bo Peng.
1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.
1. **[Segment Anything](https://huggingface.co/docs/transformers/main/model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick.
1. **[SEW](https://huggingface.co/docs/transformers/model_doc/sew)** (from ASAPP) released with the paper [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi.
@@ -410,6 +410,7 @@ Current number of checkpoints:
1. **[RoBERTa-PreLayerNorm](https://huggingface.co/docs/transformers/model_doc/roberta-prelayernorm)** (from Facebook) released with the paper [fairseq: A Fast, Extensible Toolkit for Sequence Modeling](https://arxiv.org/abs/1904.01038) by Myle Ott, Sergey Edunov, Alexei Baevski, Angela Fan, Sam Gross, Nathan Ng, David Grangier, Michael Auli.
1. **[RoCBert](https://huggingface.co/docs/transformers/model_doc/roc_bert)** (from WeChatAI) released with the paper [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) by HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou.
1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
1. **[RWKV](https://huggingface.co/docs/transformers/main/model_doc/rwkv)** (from Bo Peng), released on [this repo](https://github.com/BlinkDL/RWKV-LM) by Bo Peng.
1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.
1. **[Segment Anything](https://huggingface.co/docs/transformers/main/model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick.
1. **[SEW](https://huggingface.co/docs/transformers/model_doc/sew)** (from ASAPP) released with the paper [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi.
@@ -382,6 +382,7 @@ conda install -c huggingface transformers
1. **[RoBERTa-PreLayerNorm](https://huggingface.co/docs/transformers/model_doc/roberta-prelayernorm)** (from Facebook) released with the paper [fairseq: A Fast, Extensible Toolkit for Sequence Modeling](https://arxiv.org/abs/1904.01038) by Myle Ott, Sergey Edunov, Alexei Baevski, Angela Fan, Sam Gross, Nathan Ng, David Grangier, Michael Auli.
1. **[RoCBert](https://huggingface.co/docs/transformers/model_doc/roc_bert)** (from WeChatAI) released with the paper [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) by HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou.
1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from Zhuiyi Technology), released with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864v1.pdf) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
1. **[RWKV](https://huggingface.co/docs/transformers/main/model_doc/rwkv)** (from Bo Peng) released on [this repo](https://github.com/BlinkDL/RWKV-LM) by Bo Peng.
1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.
1. **[Segment Anything](https://huggingface.co/docs/transformers/main/model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick.
1. **[SEW](https://huggingface.co/docs/transformers/model_doc/sew)** (from ASAPP) released with the paper [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi.
@@ -444,6 +444,7 @@ For how to install Flax, PyTorch, and TensorFlow with conda, see their respective installation pages.
1. **[RoBERTa-PreLayerNorm](https://huggingface.co/docs/transformers/model_doc/roberta-prelayernorm)** (from Facebook) released with the paper [fairseq: A Fast, Extensible Toolkit for Sequence Modeling](https://arxiv.org/abs/1904.01038) by Myle Ott, Sergey Edunov, Alexei Baevski, Angela Fan, Sam Gross, Nathan Ng, David Grangier, Michael Auli.
1. **[RoCBert](https://huggingface.co/docs/transformers/model_doc/roc_bert)** (from WeChatAI) released with the paper [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) by HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou.
1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
1. **[RWKV](https://huggingface.co/docs/transformers/main/model_doc/rwkv)** (from Bo Peng) released on [this repo](https://github.com/BlinkDL/RWKV-LM) by Bo Peng.
1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.
1. **[Segment Anything](https://huggingface.co/docs/transformers/main/model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick.
1. **[SEW](https://huggingface.co/docs/transformers/model_doc/sew)** (from ASAPP) released with the paper [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi.
@@ -359,6 +359,7 @@ Check the Flax, PyTorch, and TensorFlow installation pages for how to install them with conda.
1. **[RoBERTa-PreLayerNorm](https://huggingface.co/docs/transformers/model_doc/roberta-prelayernorm)** (from Facebook) released with the paper [fairseq: A Fast, Extensible Toolkit for Sequence Modeling](https://arxiv.org/abs/1904.01038) by Myle Ott, Sergey Edunov, Alexei Baevski, Angela Fan, Sam Gross, Nathan Ng, David Grangier, Michael Auli.
1. **[RoCBert](https://huggingface.co/docs/transformers/model_doc/roc_bert)** (from WeChatAI) released with the paper [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) by HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou.
1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864v1.pdf) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
1. **[RWKV](https://huggingface.co/docs/transformers/main/model_doc/rwkv)** (from Bo Peng) released on [this repo](https://github.com/BlinkDL/RWKV-LM) by Bo Peng.
1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.
1. **[Segment Anything](https://huggingface.co/docs/transformers/main/model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick.
1. **[SEW](https://huggingface.co/docs/transformers/model_doc/sew)** (from ASAPP) released with the paper [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi.
@@ -383,6 +383,7 @@ conda install -c huggingface transformers
1. **[RoBERTa-PreLayerNorm](https://huggingface.co/docs/transformers/model_doc/roberta-prelayernorm)** (from Facebook) released with the paper [fairseq: A Fast, Extensible Toolkit for Sequence Modeling](https://arxiv.org/abs/1904.01038) by Myle Ott, Sergey Edunov, Alexei Baevski, Angela Fan, Sam Gross, Nathan Ng, David Grangier, Michael Auli.
1. **[RoCBert](https://huggingface.co/docs/transformers/model_doc/roc_bert)** (from WeChatAI) released with the paper [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) by HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou.
1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864v1.pdf) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
1. **[RWKV](https://huggingface.co/docs/transformers/main/model_doc/rwkv)** (from Bo Peng) released on [this repo](https://github.com/BlinkDL/RWKV-LM) by Bo Peng.
1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.
1. **[Segment Anything](https://huggingface.co/docs/transformers/main/model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick.
1. **[SEW](https://huggingface.co/docs/transformers/model_doc/sew)** (from ASAPP) released with the paper [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi.
@@ -395,6 +395,7 @@ conda install -c huggingface transformers
1. **[RoBERTa-PreLayerNorm](https://huggingface.co/docs/transformers/model_doc/roberta-prelayernorm)** (from Facebook) released with the paper [fairseq: A Fast, Extensible Toolkit for Sequence Modeling](https://arxiv.org/abs/1904.01038) by Myle Ott, Sergey Edunov, Alexei Baevski, Angela Fan, Sam Gross, Nathan Ng, David Grangier, Michael Auli.
1. **[RoCBert](https://huggingface.co/docs/transformers/model_doc/roc_bert)** (from WeChatAI) released with the paper [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) by HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou.
1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864v1.pdf) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
1. **[RWKV](https://huggingface.co/docs/transformers/main/model_doc/rwkv)** (from Bo Peng) released on [this repo](https://github.com/BlinkDL/RWKV-LM) by Bo Peng.
1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.
1. **[Segment Anything](https://huggingface.co/docs/transformers/main/model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick.
1. **[SEW](https://huggingface.co/docs/transformers/model_doc/sew)** (from ASAPP) released with the paper [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi.
@@ -399,6 +399,8 @@
        title: RoCBert
      - local: model_doc/roformer
        title: RoFormer
      - local: model_doc/rwkv
        title: RWKV
      - local: model_doc/splinter
        title: Splinter
      - local: model_doc/squeezebert
@@ -196,6 +196,7 @@ The documentation is organized into five sections:
1. **[RoBERTa-PreLayerNorm](model_doc/roberta-prelayernorm)** (from Facebook) released with the paper [fairseq: A Fast, Extensible Toolkit for Sequence Modeling](https://arxiv.org/abs/1904.01038) by Myle Ott, Sergey Edunov, Alexei Baevski, Angela Fan, Sam Gross, Nathan Ng, David Grangier, Michael Auli.
1. **[RoCBert](model_doc/roc_bert)** (from WeChatAI) released with the paper [RoCBert: Robust Chinese Bert with Multimodal Contrastive Pretraining](https://aclanthology.org/2022.acl-long.65.pdf) by HuiSu, WeiweiShi, XiaoyuShen, XiaoZhou, TuoJi, JiaruiFang, JieZhou.
1. **[RoFormer](model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
1. **[RWKV](model_doc/rwkv)** (from Bo Peng), released on [this repo](https://github.com/BlinkDL/RWKV-LM) by Bo Peng.
1. **[SegFormer](model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.
1. **[Segment Anything](model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick.
1. **[SEW](model_doc/sew)** (from ASAPP) released with the paper [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi.
@@ -396,6 +397,7 @@ Flax), PyTorch, and/or TensorFlow.
| RoBERTa-PreLayerNorm | ❌ | ❌ | ✅ | ✅ | ✅ |
| RoCBert | ✅ | ❌ | ✅ | ❌ | ❌ |
| RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ |
| RWKV | ❌ | ❌ | ✅ | ❌ | ❌ |
| SAM | ❌ | ❌ | ✅ | ❌ | ❌ |
| SegFormer | ❌ | ❌ | ✅ | ✅ | ❌ |
| SEW | ❌ | ❌ | ✅ | ❌ | ❌ |
129 docs/source/en/model_doc/rwkv.mdx Normal file
@@ -0,0 +1,129 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# RWKV

## Overview

The RWKV model was proposed in [this repo](https://github.com/BlinkDL/RWKV-LM).

It suggests a tweak to the traditional Transformer attention to make it linear. This way, the model can be used as a recurrent network: passing inputs for timestep 0 and timestep 1 together is the same as passing inputs at timestep 0, then inputs at timestep 1 along with the state of timestep 0 (see the example below).

This can be more efficient than a regular Transformer and can deal with sentences of any length (even if the model uses a fixed context length for training).

This model was contributed by [sgugger](https://huggingface.co/sgugger).
The original code can be found [here](https://github.com/BlinkDL/RWKV-LM).

Example of use as an RNN:

```py
import torch
from transformers import AutoTokenizer, RwkvModel

model = RwkvModel.from_pretrained("sgugger/rwkv-430M-pile")
tokenizer = AutoTokenizer.from_pretrained("sgugger/rwkv-430M-pile")

inputs = tokenizer("This is an example.", return_tensors="pt")
# Feed everything to the model
outputs = model(inputs["input_ids"])
output_whole = outputs.last_hidden_state

# Feed only the first two tokens
outputs = model(inputs["input_ids"][:, :2])
output_one = outputs.last_hidden_state

# Using the state computed on the first inputs, we will get the same output
outputs = model(inputs["input_ids"][:, 2:], state=outputs.state)
output_two = outputs.last_hidden_state

torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e-5)
```
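
You can also generate text with the causal-LM head added in this PR. A minimal sketch (reusing the checkpoint above; the prompt and generation settings are arbitrary):

```py
from transformers import AutoTokenizer, RwkvForCausalLM

model = RwkvForCausalLM.from_pretrained("sgugger/rwkv-430M-pile")
tokenizer = AutoTokenizer.from_pretrained("sgugger/rwkv-430M-pile")

inputs = tokenizer("This is an example.", return_tensors="pt")
# generate() carries the recurrent state between decoding steps internally
output_ids = model.generate(inputs["input_ids"], max_new_tokens=20)
print(tokenizer.decode(output_ids[0]))
```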

## RwkvConfig

[[autodoc]] RwkvConfig

## RwkvModel

[[autodoc]] RwkvModel
    - forward

## RwkvForCausalLM

[[autodoc]] RwkvForCausalLM
    - forward

## Rwkv attention and the recurrent formulas

In a traditional auto-regressive Transformer, attention is written as

$$O = \hbox{softmax}(QK^{T} / \sqrt{d}) V$$

where \\(Q\\), \\(K\\) and \\(V\\) are matrices of shape `seq_len x hidden_size` named query, key and value (they are actually bigger matrices with a batch dimension and an attention head dimension, but we're only interested in the last two, which is where the matrix product is taken, so for the sake of simplicity we only consider those two). The product \\(QK^{T}\\) then has shape `seq_len x seq_len` and we can take the matrix product with \\(V\\) to get the output \\(O\\) of the same shape as the others.

Replacing the softmax by its value gives:

$$O_{i} = \frac{\sum_{j=1}^{i} e^{Q_{i} K_{j}^{T} / \sqrt{d}} V_{j}}{\sum_{j=1}^{i} e^{Q_{i} K_{j}^{T} / \sqrt{d}}}$$

Note that the entries in \\(QK^{T}\\) corresponding to \\(j > i\\) are masked (the sum stops at \\(j = i\\)) because the attention is not allowed to look at future tokens (only past ones).

In comparison, the RWKV attention is given by

$$O_{i} = \sigma(R_{i}) \frac{\sum_{j=1}^{i} e^{W_{i-j} + K_{j}} V_{j}}{\sum_{j=1}^{i} e^{W_{i-j} + K_{j}}}$$

where \\(R\\) is a new matrix called receptance by the author, \\(K\\) and \\(V\\) are still the key and value (\\(\sigma\\) here is the sigmoid function). \\(W\\) is a new vector that represents the position of the token and is given by

$$W_{0} = u \hbox{  and  } W_{k} = (k-1)w \hbox{ for } k \geq 1$$

with \\(u\\) and \\(w\\) learnable parameters called in the code `time_first` and `time_decay` respectively. The numerator and denominator can both be expressed recursively. Naming them \\(N_{i}\\) and \\(D_{i}\\) we have:

$$N_{i} = e^{u + K_{i}} V_{i} + \hat{N}_{i} \hbox{  where  } \hat{N}_{i} = e^{K_{i-1}} V_{i-1} + e^{w + K_{i-2}} V_{i-2} + \cdots + e^{(i-2)w + K_{1}} V_{1}$$

so \\(\hat{N}_{i}\\) (called `numerator_state` in the code) satisfies

$$\hat{N}_{0} = 0 \hbox{  and  } \hat{N}_{j+1} = e^{K_{j}} V_{j} + e^{w} \hat{N}_{j}$$

and

$$D_{i} = e^{u + K_{i}} + \hat{D}_{i} \hbox{  where  } \hat{D}_{i} = e^{K_{i-1}} + e^{w + K_{i-2}} + \cdots + e^{(i-2)w + K_{1}}$$

so \\(\hat{D}_{i}\\) (called `denominator_state` in the code) satisfies

$$\hat{D}_{0} = 0 \hbox{  and  } \hat{D}_{j+1} = e^{K_{j}} + e^{w} \hat{D}_{j}$$

The actual recurrent formulas used are a tiny bit more complex, as for numerical stability we don't want to compute exponentials of big numbers. Usually the softmax is not computed as is: the exponential of the maximum term is divided out of the numerator and denominator:

$$\frac{e^{x_{i}}}{\sum_{j=1}^{n} e^{x_{j}}} = \frac{e^{x_{i} - M}}{\sum_{j=1}^{n} e^{x_{j} - M}}$$

with \\(M\\) the maximum of all \\(x_{j}\\). So here, on top of saving the numerator state (\\(\hat{N}\\)) and the denominator state (\\(\hat{D}\\)), we also keep track of the maximum of all terms encountered in the exponentials. We therefore actually use

$$\tilde{N}_{i} = e^{-M_{i}} \hat{N}_{i} \hbox{  and  } \tilde{D}_{i} = e^{-M_{i}} \hat{D}_{i}$$

defined by the following recurrent formulas:

$$\tilde{N}_{0} = 0 \hbox{  and  } \tilde{N}_{j+1} = e^{K_{j} - q} V_{j} + e^{w + M_{j} - q} \tilde{N}_{j} \hbox{  where  } q = \max(K_{j}, w + M_{j})$$

and

$$\tilde{D}_{0} = 0 \hbox{  and  } \tilde{D}_{j+1} = e^{K_{j} - q} + e^{w + M_{j} - q} \tilde{D}_{j} \hbox{  where  } q = \max(K_{j}, w + M_{j})$$

and \\(M_{j+1} = q\\). With those, we can then compute

$$N_{i} = e^{u + K_{i} - q} V_{i} + e^{M_{i} - q} \tilde{N}_{i} \hbox{  where  } q = \max(u + K_{i}, M_{i})$$

and

$$D_{i} = e^{u + K_{i} - q} + e^{M_{i} - q} \tilde{D}_{i} \hbox{  where  } q = \max(u + K_{i}, M_{i})$$

which finally gives us

$$O_{i} = \sigma(R_{i}) \frac{N_{i}}{D_{i}}$$
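
Putting the pieces together, a naive PyTorch port of this stable recurrence (one token at a time, for a single sequence) could look like the sketch below. Variable names are illustrative, not the ones used in the modeling code; the CUDA kernels in this PR run the same loop per channel, and the sigmoid gating by the receptance happens outside of them:

```py
import torch


def rwkv_linear_attention(r, k, v, w, u):
    # r, k, v: (seq_len, hidden_size); w = time_decay, u = time_first: (hidden_size,)
    seq_len, hidden_size = k.shape
    output = torch.zeros_like(v)
    # \tilde{N}, \tilde{D} and M from the formulas above, carried across timesteps
    num_state = torch.zeros(hidden_size)
    den_state = torch.zeros(hidden_size)
    max_state = torch.full((hidden_size,), -1e38)

    for t in range(seq_len):
        # N_t and D_t, with q = max(u + K_t, M_t) keeping the exponentials small
        q = torch.maximum(u + k[t], max_state)
        e1 = torch.exp(max_state - q)
        e2 = torch.exp(u + k[t] - q)
        numerator = e2 * v[t] + e1 * num_state
        denominator = e2 + e1 * den_state
        output[t] = torch.sigmoid(r[t]) * numerator / denominator

        # State update: \tilde{N}_{t+1}, \tilde{D}_{t+1} and M_{t+1} = q
        q = torch.maximum(k[t], w + max_state)
        e1 = torch.exp(w + max_state - q)
        e2 = torch.exp(k[t] - q)
        num_state = e2 * v[t] + e1 * num_state
        den_state = e2 + e1 * den_state
        max_state = q
    return output
```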
@@ -33,8 +33,8 @@ You can finetune other architectures for causal language modeling following the
Choose one of the following architectures:

<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeGen](../model_doc/codegen), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MVP](../model_doc/mvp), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod)

[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeGen](../model_doc/codegen), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MVP](../model_doc/mvp), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [Speech2Text2](../model_doc/speech_to_text_2), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod)

<!--End of the generated tip-->
@@ -431,6 +431,7 @@ _import_structure = {
    "models.roberta_prelayernorm": ["ROBERTA_PRELAYERNORM_PRETRAINED_CONFIG_ARCHIVE_MAP", "RobertaPreLayerNormConfig"],
    "models.roc_bert": ["ROC_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RoCBertConfig", "RoCBertTokenizer"],
    "models.roformer": ["ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "RoFormerConfig", "RoFormerTokenizer"],
    "models.rwkv": ["RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP", "RwkvConfig"],
    "models.sam": [
        "SAM_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "SamConfig",
@@ -2364,6 +2365,14 @@ else:
            "load_tf_weights_in_roformer",
        ]
    )
    _import_structure["models.rwkv"].extend(
        [
            "RWKV_PRETRAINED_MODEL_ARCHIVE_LIST",
            "RwkvForCausalLM",
            "RwkvModel",
            "RwkvPreTrainedModel",
        ]
    )
    _import_structure["models.sam"].extend(
        [
            "SAM_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -4169,6 +4178,7 @@ if TYPE_CHECKING:
    )
    from .models.roc_bert import ROC_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RoCBertConfig, RoCBertTokenizer
    from .models.roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig, RoFormerTokenizer
    from .models.rwkv import RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP, RwkvConfig
    from .models.sam import (
        SAM_PRETRAINED_CONFIG_ARCHIVE_MAP,
        SamConfig,
@@ -5783,6 +5793,12 @@ if TYPE_CHECKING:
        RoFormerPreTrainedModel,
        load_tf_weights_in_roformer,
    )
    from .models.rwkv import (
        RWKV_PRETRAINED_MODEL_ARCHIVE_LIST,
        RwkvForCausalLM,
        RwkvModel,
        RwkvPreTrainedModel,
    )
    from .models.sam import (
        SAM_PRETRAINED_MODEL_ARCHIVE_LIST,
        SamModel,
@@ -753,6 +753,8 @@ class GenerationMixin:
        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
            outputs, standardize_cache_format=standardize_cache_format
        )
        if getattr(outputs, "state", None) is not None:
            model_kwargs["state"] = outputs.state

        # update token_type_ids with last value
        if "token_type_ids" in model_kwargs:
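
With this hook in place, RWKV's recurrent `state` flows through `generate()` just like `past_key_values` does for Transformer caches. A hand-rolled greedy loop that does the same thing explicitly (a simplified sketch, not the actual generation code) might look like:

```py
import torch
from transformers import AutoTokenizer, RwkvForCausalLM

model = RwkvForCausalLM.from_pretrained("sgugger/rwkv-430M-pile")
tokenizer = AutoTokenizer.from_pretrained("sgugger/rwkv-430M-pile")

input_ids = tokenizer("This is an example.", return_tensors="pt").input_ids
state = None
for _ in range(10):
    # once a state exists, only the newest token needs to be fed
    outputs = model(input_ids=input_ids[:, -1:] if state is not None else input_ids, state=state)
    state = outputs.state  # what the hook above re-injects as model_kwargs["state"]
    next_token = outputs.logits[:, -1].argmax(dim=-1, keepdim=True)
    input_ids = torch.cat([input_ids, next_token], dim=-1)
print(tokenizer.decode(input_ids[0]))
```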
187 src/transformers/kernels/rwkv/wkv_cuda.cu Normal file
@@ -0,0 +1,187 @@
#include <stdio.h>
#include <assert.h>

#define MIN_VALUE (-1e38)

template <typename F>
__global__ void kernel_forward(
    const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u,
    const F *__restrict__ const _k, const F *__restrict__ const _v, F *__restrict__ const _y
) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    F *__restrict__ const y = _y + _offset;

    // aa and bb are running sums divided by exp(pp) (to avoid overflow)
    F aa = 0, bb = 0, pp = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const F kk = k[ii];
        const F vv = v[ii];

        F ww = u + kk;
        F p = max(pp, ww);
        F e1 = exp(pp - p);
        F e2 = exp(ww - p);
        y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);

        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
}

template <typename F>
__global__ void kernel_forward_with_state(
    const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u,
    const F *__restrict__ const _k, const F *__restrict__ const _v, F *__restrict__ const _y, F *__restrict__ const _s
) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset_s = _b * C * 3 + _c * 3;
    const int _offset = _b * T * C + _c;

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    F *__restrict__ const y = _y + _offset;
    F *__restrict__ const s = _s + _offset_s;

    // aa and bb are running sums divided by exp(pp) (to avoid overflow)
    F aa = s[0], bb = s[1], pp = s[2];
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const F kk = k[ii];
        const F vv = v[ii];

        F ww = u + kk;
        F p = max(pp, ww);
        F e1 = exp(pp - p);
        F e2 = exp(ww - p);
        y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);

        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
    s[0] = aa;
    s[1] = bb;
    s[2] = pp;
}

template <typename F>
__global__ void kernel_backward(
    const int B, const int T, const int C, const F *__restrict__ const _w, const F *__restrict__ const _u,
    const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _y,
    const F *__restrict__ const _gy, F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk,
    F *__restrict__ const _gv
) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    const F *__restrict__ const y = _y + _offset;
    const F *__restrict__ const gy = _gy + _offset;
    F *__restrict__ const gk = _gk + _offset;
    F *__restrict__ const gv = _gv + _offset;

    // Tmax (the maximum sequence length) is a macro supplied when the kernel is compiled
    F q[Tmax], r[Tmax];

    F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const F kk = k[ii];
        const F vv = v[ii];
        const F yy = y[ii];

        F ww = u + kk;
        F p = max(pp, ww);
        F e1 = exp(pp - p);
        F e2 = exp(ww - p);
        const F qq = gy[ii] / (e1 * bb + e2);
        gw += (ga - gb * yy) * e1 * qq;
        gu += (vv - yy) * e2 * qq;
        q[i] = qq;
        r[i] = ww - p;

        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        ga = e1 * (aa + ga);
        gb = e1 * (bb + gb);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
    const int _offsetBC = _b * C + _c;
    _gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward()
    _gu[_offsetBC] = gu;

    aa = 0, bb = 0, pp = MIN_VALUE;
    for (int i = T - 1; i >= 0; i--) {
        const int ii = i * C;
        const F kk = k[ii];
        const F vv = v[ii];
        const F yy = y[ii];
        const F qq = q[i];
        const F rr = r[i];

        F e1 = qq * exp(rr);
        F e2 = exp(kk + pp);
        gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb);
        gv[ii] = e1 + e2 * aa;

        const F ww = w + pp;
        const F www = rr - u - kk;
        const F p = max(ww, www);
        e1 = exp(ww - p);
        e2 = qq * exp(www - p);
        aa = e1 * aa + e2;
        bb = e1 * bb - e2 * yy;
        pp = p;
    }
}

void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
}

void cuda_forward_with_state(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *s) {
    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_forward_with_state<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, s);
}

void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) {
    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
}
186 src/transformers/kernels/rwkv/wkv_cuda_bf16.cu Normal file
@@ -0,0 +1,186 @@
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
#define MIN_VALUE (-1e38)
typedef at::BFloat16 bf16;

__global__ void kernel_forward_bf16(
    const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u,
    const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, bf16 *__restrict__ const _y
) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    float u = float(_u[_c]);
    float w = _w[_c];
    const bf16 *__restrict__ const k = _k + _offset;
    const bf16 *__restrict__ const v = _v + _offset;
    bf16 *__restrict__ const y = _y + _offset;

    // aa and bb are running sums divided by exp(pp) (to avoid overflow)
    float aa = 0, bb = 0, pp = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const float kk = float(k[ii]);
        const float vv = float(v[ii]);

        float ww = u + kk;
        float p = max(pp, ww);
        float e1 = exp(pp - p);
        float e2 = exp(ww - p);
        y[ii] = bf16((e1 * aa + e2 * vv) / (e1 * bb + e2));

        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
}

__global__ void kernel_forward_with_state_bf16(
    const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u,
    const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, bf16 *__restrict__ const _y,
    float *__restrict__ const _s
) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset_s = _b * C * 3 + _c * 3;
    const int _offset = _b * T * C + _c;

    float u = float(_u[_c]);
    float w = _w[_c];
    const bf16 *__restrict__ const k = _k + _offset;
    const bf16 *__restrict__ const v = _v + _offset;
    bf16 *__restrict__ const y = _y + _offset;
    float *__restrict__ const s = _s + _offset_s;

    // aa and bb are running sums divided by exp(pp) (to avoid overflow)
    float aa = s[0], bb = s[1], pp = s[2];
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const float kk = float(k[ii]);
        const float vv = float(v[ii]);

        float ww = u + kk;
        float p = max(pp, ww);
        float e1 = exp(pp - p);
        float e2 = exp(ww - p);
        // cast after the full division so the bf16 rounding happens once, as in the other kernels
        y[ii] = bf16((e1 * aa + e2 * vv) / (e1 * bb + e2));

        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
    s[0] = aa;
    s[1] = bb;
    s[2] = pp;
}

__global__ void kernel_backward_bf16(
    const int B, const int T, const int C, const float *__restrict__ const _w, const bf16 *__restrict__ const _u,
    const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, const bf16 *__restrict__ const _y,
    const bf16 *__restrict__ const _gy, bf16 *__restrict__ const _gw, bf16 *__restrict__ const _gu,
    bf16 *__restrict__ const _gk, bf16 *__restrict__ const _gv
) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    float u = float(_u[_c]);
    float w = _w[_c];
    const bf16 *__restrict__ const k = _k + _offset;
    const bf16 *__restrict__ const v = _v + _offset;
    const bf16 *__restrict__ const y = _y + _offset;
    const bf16 *__restrict__ const gy = _gy + _offset;
    bf16 *__restrict__ const gk = _gk + _offset;
    bf16 *__restrict__ const gv = _gv + _offset;

    float q[Tmax], r[Tmax];

    float gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const float kk = float(k[ii]);
        const float vv = float(v[ii]);
        const float yy = float(y[ii]);

        float ww = u + kk;
        float p = max(pp, ww);
        float e1 = exp(pp - p);
        float e2 = exp(ww - p);
        const float qq = float(gy[ii]) / (e1 * bb + e2);
        gw += (ga - gb * yy) * e1 * qq;
        gu += (vv - yy) * e2 * qq;
        q[i] = qq;
        r[i] = ww - p;

        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        ga = e1 * (aa + ga);
        gb = e1 * (bb + gb);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
    const int _offsetBC = _b * C + _c;
    _gw[_offsetBC] = bf16(gw * _w[_c]); // multiply by w because of w -> -exp(w) in python forward()
    _gu[_offsetBC] = bf16(gu);

    aa = 0, bb = 0, pp = MIN_VALUE;
    for (int i = T - 1; i >= 0; i--) {
        const int ii = i * C;
        const float kk = float(k[ii]);
        const float vv = float(v[ii]);
        const float yy = float(y[ii]);
        const float qq = q[i];
        const float rr = r[i];

        float e1 = qq * exp(rr);
        float e2 = exp(kk + pp);
        gk[ii] = bf16(e1 * (vv - yy) + e2 * (aa * vv + bb));
        gv[ii] = bf16(e1 + e2 * aa);

        const float ww = w + pp;
        const float www = rr - u - kk;
        const float p = max(ww, www);
        e1 = exp(ww - p);
        e2 = qq * exp(www - p);
        aa = e1 * aa + e2;
        bb = e1 * bb - e2 * yy;
        pp = p;
    }
}

void cuda_forward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y) {
    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_forward_bf16<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
}

void cuda_forward_with_state_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, float *s) {
    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_forward_with_state_bf16<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, s);
}

void cuda_backward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv) {
    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_backward_bf16<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
}
66 src/transformers/kernels/rwkv/wkv_op.cpp Normal file
@@ -0,0 +1,66 @@
#include <torch/extension.h>
#include "ATen/ATen.h"
typedef at::BFloat16 bf16;

void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
void cuda_forward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y);
void cuda_forward_with_state(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *s);
void cuda_forward_with_state_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, float *s);
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv);
void cuda_backward_bf16(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv);

void forward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
    const int B = k.size(0);
    const int T = k.size(1);
    const int C = k.size(2);
    cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
}
void forward_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
    const int B = k.size(0);
    const int T = k.size(1);
    const int C = k.size(2);
    cuda_forward_bf16(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void forward_with_state(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &s) {
    const int B = k.size(0);
    const int T = k.size(1);
    const int C = k.size(2);
    cuda_forward_with_state(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), s.data_ptr<float>());
}
void forward_with_state_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &s) {
    const int B = k.size(0);
    const int T = k.size(1);
    const int C = k.size(2);
    cuda_forward_with_state_bf16(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(), s.data_ptr<float>());
}
void backward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
    const int B = k.size(0);
    const int T = k.size(1);
    const int C = k.size(2);
    cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
}
void backward_bf16(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
    const int B = k.size(0);
    const int T = k.size(1);
    const int C = k.size(2);
    cuda_backward_bf16(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(),
                       gy.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>());
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "wkv forward");
    m.def("forward_bf16", &forward_bf16, "wkv forward bf16");
    m.def("forward_with_state", &forward_with_state, "wkv forward with state");
    m.def("forward_with_state_bf16", &forward_with_state_bf16, "wkv forward with state bf16");
    m.def("backward", &backward, "wkv backward");
    m.def("backward_bf16", &backward_bf16, "wkv backward bf16");
}

TORCH_LIBRARY(wkv, m) {
    m.def("forward", forward);
    m.def("forward_bf16", forward_bf16);
    m.def("forward_with_state", forward_with_state);
    m.def("forward_with_state_bf16", forward_with_state_bf16);
    m.def("backward", backward);
    m.def("backward_bf16", backward_bf16);
}
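
transformers compiles and loads these kernels lazily the first time they are needed on GPU. For reference, a manual build with PyTorch's extension loader might look like the sketch below; the source paths and the `Tmax` value (the compile-time bound on sequence length used by the backward kernels) are illustrative assumptions, while `--maxrregcount 60` comes from the comments in the kernels:

```py
from torch.utils.cpp_extension import load

wkv = load(
    name="wkv",
    sources=["wkv_op.cpp", "wkv_cuda.cu", "wkv_cuda_bf16.cu"],
    extra_cuda_cflags=["-O3", "--maxrregcount=60", "-DTmax=1024"],
    verbose=True,
)
# e.g. wkv.forward(w, u, k, v, y) fills the preallocated output tensor y in place
```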
@@ -162,6 +162,7 @@ from . import (
    roberta_prelayernorm,
    roc_bert,
    roformer,
    rwkv,
    sam,
    segformer,
    sew,
@@ -163,6 +163,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
        ("roberta-prelayernorm", "RobertaPreLayerNormConfig"),
        ("roc_bert", "RoCBertConfig"),
        ("roformer", "RoFormerConfig"),
        ("rwkv", "RwkvConfig"),
        ("sam", "SamConfig"),
        ("segformer", "SegformerConfig"),
        ("sew", "SEWConfig"),
@@ -343,6 +344,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
        ("roberta-prelayernorm", "ROBERTA_PRELAYERNORM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("roc_bert", "ROC_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("roformer", "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("rwkv", "RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("sam", "SAM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("segformer", "SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("sew", "SEW_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@@ -545,6 +547,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
        ("roberta-prelayernorm", "RoBERTa-PreLayerNorm"),
        ("roc_bert", "RoCBert"),
        ("roformer", "RoFormer"),
        ("rwkv", "RWKV"),
        ("sam", "SAM"),
        ("segformer", "SegFormer"),
        ("sew", "SEW"),
@@ -158,6 +158,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
        ("roberta-prelayernorm", "RobertaPreLayerNormModel"),
        ("roc_bert", "RoCBertModel"),
        ("roformer", "RoFormerModel"),
        ("rwkv", "RwkvModel"),
        ("sam", "SamModel"),
        ("segformer", "SegformerModel"),
        ("sew", "SEWModel"),
@@ -248,6 +249,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
        ("roberta", "RobertaForMaskedLM"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"),
        ("roc_bert", "RoCBertForPreTraining"),
        ("rwkv", "RwkvForCausalLM"),
        ("splinter", "SplinterForPreTraining"),
        ("squeezebert", "SqueezeBertForMaskedLM"),
        ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
@@ -332,6 +334,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
        ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"),
        ("roc_bert", "RoCBertForMaskedLM"),
        ("roformer", "RoFormerForMaskedLM"),
        ("rwkv", "RwkvForCausalLM"),
        ("speech_to_text", "Speech2TextForConditionalGeneration"),
        ("squeezebert", "SqueezeBertForMaskedLM"),
        ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
@@ -395,6 +398,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
        ("roberta-prelayernorm", "RobertaPreLayerNormForCausalLM"),
        ("roc_bert", "RoCBertForCausalLM"),
        ("roformer", "RoFormerForCausalLM"),
        ("rwkv", "RwkvForCausalLM"),
        ("speech_to_text_2", "Speech2Text2ForCausalLM"),
        ("transfo-xl", "TransfoXLLMHeadModel"),
        ("trocr", "TrOCRForCausalLM"),
@@ -297,6 +297,7 @@ else:
        ),
        ("roc_bert", ("RoCBertTokenizer", None)),
        ("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
        ("rwkv", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
        ("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
        ("speech_to_text_2", ("Speech2Text2Tokenizer", None)),
        ("speecht5", ("SpeechT5Tokenizer" if is_sentencepiece_available() else None, None)),
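
RWKV registers no slow tokenizer of its own: as the mapping above shows, `AutoTokenizer` resolves to the GPT-NeoX fast tokenizer for `rwkv` checkpoints. For example (checkpoint name taken from the config archive map added below):

```py
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
print(type(tokenizer).__name__)  # GPTNeoXTokenizerFast
```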
60 src/transformers/models/rwkv/__init__.py Normal file
@@ -0,0 +1,60 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_torch_available,
)


_import_structure = {
    "configuration_rwkv": ["RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP", "RwkvConfig", "RwkvOnnxConfig"],
}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_rwkv"] = [
        "RWKV_PRETRAINED_MODEL_ARCHIVE_LIST",
        "RwkvForCausalLM",
        "RwkvModel",
        "RwkvPreTrainedModel",
    ]


if TYPE_CHECKING:
    from .configuration_rwkv import RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP, RwkvConfig, RwkvOnnxConfig

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_rwkv import (
            RWKV_PRETRAINED_MODEL_ARCHIVE_LIST,
            RwkvForCausalLM,
            RwkvModel,
            RwkvPreTrainedModel,
        )
else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
130 src/transformers/models/rwkv/configuration_rwkv.py Normal file
@@ -0,0 +1,130 @@
# coding=utf-8
# Copyright 2023 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" RWKV configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)

RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "RWKV/rwkv-4-169m-pile": "https://huggingface.co/RWKV/rwkv-4-169m-pile/resolve/main/config.json",
    "RWKV/rwkv-4-430m-pile": "https://huggingface.co/RWKV/rwkv-4-430m-pile/resolve/main/config.json",
    "RWKV/rwkv-4-1b5-pile": "https://huggingface.co/RWKV/rwkv-4-1b5-pile/resolve/main/config.json",
    "RWKV/rwkv-4-3b-pile": "https://huggingface.co/RWKV/rwkv-4-3b-pile/resolve/main/config.json",
    "RWKV/rwkv-4-7b-pile": "https://huggingface.co/RWKV/rwkv-4-7b-pile/resolve/main/config.json",
    "RWKV/rwkv-4-14b-pile": "https://huggingface.co/RWKV/rwkv-4-14b-pile/resolve/main/config.json",
    "RWKV/rwkv-raven-1b5": "https://huggingface.co/RWKV/rwkv-raven-1b5/resolve/main/config.json",
    "RWKV/rwkv-raven-3b": "https://huggingface.co/RWKV/rwkv-raven-3b/resolve/main/config.json",
    "RWKV/rwkv-raven-7b": "https://huggingface.co/RWKV/rwkv-raven-7b/resolve/main/config.json",
    "RWKV/rwkv-raven-14b": "https://huggingface.co/RWKV/rwkv-raven-14b/resolve/main/config.json",
}


class RwkvConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`RwkvModel`]. It is used to instantiate a RWKV
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with
    the defaults will yield a similar configuration to that of the RWKV-4
    [RWKV/rwkv-4-169m-pile](https://huggingface.co/RWKV/rwkv-4-169m-pile) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 50277):
            Vocabulary size of the RWKV model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`RwkvModel`].
        context_length (`int`, *optional*, defaults to 1024):
            The maximum sequence length this model can be used with in a single forward pass (using it in RNN mode
            lets it handle any sequence length).
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimensionality of the embeddings and hidden states.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the model.
        attention_hidden_size (`int`, *optional*):
            Dimensionality of the attention hidden states. Will default to `hidden_size` if unset.
        intermediate_size (`int`, *optional*):
            Dimensionality of the inner feed-forward layers. Will default to 4 times `hidden_size` if unset.
        layer_norm_eps (`float`, *optional*, defaults to 1e-5):
            The epsilon to use in the layer normalization layers.
        bos_token_id (`int`, *optional*, defaults to 0):
            The id of the beginning of sentence token in the vocabulary. Defaults to 0 as RWKV uses the same tokenizer
            as GPTNeoX.
        eos_token_id (`int`, *optional*, defaults to 0):
|
||||
The id of the end of sentence token in the vocabulary. Defaults to 0 as RWKV uses the same tokenizer as
|
||||
GPTNeoX.
|
||||
rescale_every (`int`, *optional*, default to 6):
|
||||
At inference, the hidden states (and weights of the correponding output layers) are divided by 2 every
|
||||
`rescale_every` layer. If set to 0 or a negative number, no rescale is done.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to tie the word embeddings with the input token embeddings.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last state.
|
||||
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import RwkvConfig, RwkvModel
|
||||
|
||||
>>> # Initializing a Rwkv configuration
|
||||
>>> configuration = RwkvConfig()
|
||||
|
||||
>>> # Initializing a model (with random weights) from the configuration
|
||||
>>> model = RwkvModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "rwkv"
|
||||
attribute_map = {"max_position_embeddings": "context_length"}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=50277,
|
||||
context_length=1024,
|
||||
hidden_size=4096,
|
||||
num_hidden_layers=32,
|
||||
attention_hidden_size=None,
|
||||
intermediate_size=None,
|
||||
layer_norm_epsilon=1e-5,
|
||||
bos_token_id=0,
|
||||
eos_token_id=0,
|
||||
rescale_every=6,
|
||||
tie_word_embeddings=False,
|
||||
use_cache=True,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.context_length = context_length
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.attention_hidden_size = attention_hidden_size if attention_hidden_size is not None else hidden_size
|
||||
self.intermediate_size = intermediate_size if intermediate_size is not None else 4 * hidden_size
|
||||
self.layer_norm_epsilon = layer_norm_epsilon
|
||||
self.rescale_every = rescale_every
|
||||
self.use_cache = use_cache
|
||||
|
||||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
super().__init__(
|
||||
tie_word_embeddings=tie_word_embeddings, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs
|
||||
)
|
src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py (new file, 201 lines)
@ -0,0 +1,201 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert a RWKV checkpoint from BlinkDL to the Hugging Face format."""


import argparse
import gc
import json
import os
import re

import torch
from huggingface_hub import hf_hub_download

from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerFast, RwkvConfig
from transformers.modeling_utils import WEIGHTS_INDEX_NAME, shard_checkpoint


NUM_HIDDEN_LAYERS_MAPPING = {
    "169M": 12,
    "430M": 24,
    "1B5": 24,
    "3B": 32,
    "7B": 32,
    "14B": 40,
}

HIDDEN_SIZE_MAPPING = {
    "169M": 768,
    "430M": 1024,
    "1B5": 2048,
    "3B": 2560,
    "7B": 4096,
    "14B": 5120,
}


def convert_state_dict(state_dict):
    state_dict_keys = list(state_dict.keys())
    for name in state_dict_keys:
        weight = state_dict.pop(name)
        # emb -> embedding
        if name.startswith("emb."):
            name = name.replace("emb.", "embeddings.")
        # ln_0 -> pre_ln (only present at block 0)
        if name.startswith("blocks.0.ln0"):
            name = name.replace("blocks.0.ln0", "blocks.0.pre_ln")
        # att -> attention
        name = re.sub(r"blocks\.(\d+)\.att", r"blocks.\1.attention", name)
        # ffn -> feed_forward
        name = re.sub(r"blocks\.(\d+)\.ffn", r"blocks.\1.feed_forward", name)
        # time_mix_k -> time_mix_key and reshape
        if name.endswith(".time_mix_k"):
            name = name.replace(".time_mix_k", ".time_mix_key")
        # time_mix_v -> time_mix_value and reshape
        if name.endswith(".time_mix_v"):
            name = name.replace(".time_mix_v", ".time_mix_value")
        # time_mix_r -> time_mix_receptance and reshape
        if name.endswith(".time_mix_r"):
            name = name.replace(".time_mix_r", ".time_mix_receptance")

        if name != "head.weight":
            name = "rwkv." + name

        state_dict[name] = weight
    return state_dict


def convert_rwkv_checkpoint_to_hf_format(
    repo_id, checkpoint_file, output_dir, size=None, tokenizer_file=None, push_to_hub=False, model_name=None
):
    # 1. If possible, build the tokenizer.
    if tokenizer_file is None:
        print("No `--tokenizer_file` provided, we will use the default tokenizer.")
        vocab_size = 50277
        tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    else:
        tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file)
        vocab_size = len(tokenizer)
    tokenizer.save_pretrained(output_dir)

    # 2. Build the config
    possible_sizes = list(NUM_HIDDEN_LAYERS_MAPPING.keys())
    if size is None:
        # Try to infer size from the checkpoint name
        for candidate in possible_sizes:
            if candidate in checkpoint_file:
                size = candidate
                break
        if size is None:
            raise ValueError("Could not infer the size, please provide it with the `--size` argument.")
    if size not in possible_sizes:
        raise ValueError(f"`size` should be one of {possible_sizes}, got {size}.")

    config = RwkvConfig(
        vocab_size=vocab_size,
        num_hidden_layers=NUM_HIDDEN_LAYERS_MAPPING[size],
        hidden_size=HIDDEN_SIZE_MAPPING[size],
    )
    config.save_pretrained(output_dir)

    # 3. Download model file then convert state_dict
    model_file = hf_hub_download(repo_id, checkpoint_file)
    state_dict = torch.load(model_file, map_location="cpu")
    state_dict = convert_state_dict(state_dict)

    # 4. Split in shards and save
    shards, index = shard_checkpoint(state_dict)
    for shard_file, shard in shards.items():
        torch.save(shard, os.path.join(output_dir, shard_file))

    if index is not None:
        save_index_file = os.path.join(output_dir, WEIGHTS_INDEX_NAME)
        # Save the index as well
        with open(save_index_file, "w", encoding="utf-8") as f:
            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
            f.write(content)

    # 5. Clean up shards (for some reason the files PyTorch saves take the same space as the whole state_dict).
    print(
        "Cleaning up shards. This may raise an OOM error; if it does, don't worry, the model has still been converted."
    )
    shard_files = list(shards.keys())

    del state_dict
    del shards
    gc.collect()

    for shard_file in shard_files:
        state_dict = torch.load(os.path.join(output_dir, shard_file))
        torch.save({k: v.cpu().clone() for k, v in state_dict.items()}, os.path.join(output_dir, shard_file))

    del state_dict
    gc.collect()

    if push_to_hub:
        if model_name is None:
            raise ValueError("Please provide a `model_name` to push the model to the Hub.")
        model = AutoModelForCausalLM.from_pretrained(output_dir)
        model.push_to_hub(model_name, max_shard_size="2GB")
        tokenizer.push_to_hub(model_name)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--repo_id", default=None, type=str, required=True, help="Repo ID from which to pull the checkpoint."
    )
    parser.add_argument(
        "--checkpoint_file", default=None, type=str, required=True, help="Name of the checkpoint file in the repo."
    )
    parser.add_argument(
        "--output_dir", default=None, type=str, required=True, help="Where to save the converted model."
    )
    parser.add_argument(
        "--tokenizer_file",
        default=None,
        type=str,
        help="Path to the tokenizer file to use (if not provided, only the model is converted).",
    )
    parser.add_argument(
        "--size",
        default=None,
        type=str,
        help="Size of the model. Will be inferred from the `checkpoint_file` if not passed.",
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Push to the Hub the converted model.",
    )
    parser.add_argument(
        "--model_name",
        default=None,
        type=str,
        help="Name of the pushed model on the Hub, including the username / organization.",
    )

    args = parser.parse_args()
    convert_rwkv_checkpoint_to_hf_format(
        args.repo_id,
        args.checkpoint_file,
        args.output_dir,
        size=args.size,
        tokenizer_file=args.tokenizer_file,
        push_to_hub=args.push_to_hub,
        model_name=args.model_name,
    )
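For convenience, the converter can also be driven directly from Python. A minimal sketch (the repo id and checkpoint file name below are illustrative placeholders for a BlinkDL RWKV-4 checkpoint, not values shipped with this PR):

    from transformers.models.rwkv.convert_rwkv_checkpoint_to_hf import convert_rwkv_checkpoint_to_hf_format

    # Hypothetical example values: any Hub repo hosting a RWKV-4 .pth checkpoint should work.
    convert_rwkv_checkpoint_to_hf_format(
        repo_id="BlinkDL/rwkv-4-pile-169m",
        checkpoint_file="RWKV-4-Pile-169M-20220807-8023.pth",
        output_dir="./rwkv-4-169m-hf",
    )

Since "169M" appears in the assumed checkpoint file name, the `size` argument is inferred automatically.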
src/transformers/models/rwkv/modeling_rwkv.py (new file, 804 lines)
@ -0,0 +1,804 @@
# coding=utf-8
# Copyright 2023 Bo Peng and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch RWKV model."""

import math
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from ...modeling_utils import PreTrainedModel
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_ninja_available,
    is_torch_cuda_available,
    logging,
)
from .configuration_rwkv import RwkvConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "RWKV/rwkv-4-169m-pile"
_CONFIG_FOR_DOC = "RwkvConfig"

RWKV_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "RWKV/rwkv-4-169m-pile",
    "RWKV/rwkv-4-430m-pile",
    "RWKV/rwkv-4-1b5-pile",
    "RWKV/rwkv-4-3b-pile",
    "RWKV/rwkv-4-7b-pile",
    "RWKV/rwkv-4-14b-pile",
    "RWKV/rwkv-raven-1b5",
    "RWKV/rwkv-raven-3b",
    "RWKV/rwkv-raven-7b",
    "RWKV/rwkv-raven-14b",
    # See all RWKV models at https://huggingface.co/models?filter=rwkv
]


rwkv_cuda_kernel = None


def load_wkv_cuda_kernel(context_length):
    from torch.utils.cpp_extension import load as load_kernel

    global rwkv_cuda_kernel

    kernel_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "rwkv"
    cuda_kernel_files = [kernel_folder / f for f in ["wkv_op.cpp", "wkv_cuda.cu", "wkv_cuda_bf16.cu"]]

    # Only load the kernel if it's not been loaded yet or if we changed the context length
    if rwkv_cuda_kernel is not None and rwkv_cuda_kernel.max_seq_length == context_length:
        return

    logger.info(f"Loading CUDA kernel for RWKV at context length of {context_length}.")

    flags = [
        "-res-usage",
        "--maxrregcount 60",
        "--use_fast_math",
        "-O3",
        "-Xptxas -O3",
        "--extra-device-vectorization",
        f"-DTmax={context_length}",
    ]
    rwkv_cuda_kernel = load_kernel(
        name=f"wkv_{context_length}",
        sources=cuda_kernel_files,
        verbose=(logging.get_verbosity() == logging.DEBUG),
        extra_cuda_cflags=flags,
    )
    rwkv_cuda_kernel.max_seq_length = context_length
class RwkvLinearAttention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, time_decay, time_first, key, value, state=None, return_state=False):
        batch_size, seq_len, hidden_size = key.size()
        if seq_len > rwkv_cuda_kernel.max_seq_length:
            raise ValueError(
                f"Cannot process a batch with {seq_len} tokens at the same time, use a maximum of "
                f"{rwkv_cuda_kernel.max_seq_length} with this model."
            )
        if batch_size * hidden_size % min(hidden_size, 32) != 0:
            raise ValueError(
                f"The product of batch size ({batch_size}) and hidden size ({hidden_size}) needs to be a round "
                f"multiple of {min(hidden_size, 32)}."
            )

        ctx.input_dtype = key.dtype

        if (
            time_decay.device.type != "cuda"
            or time_first.device.type != "cuda"
            or key.device.type != "cuda"
            or value.device.type != "cuda"
        ):
            raise ValueError("Calling the CUDA kernel for wkv attention requires all tensors to be on CUDA devices.")

        time_decay = -torch.exp(time_decay.float().contiguous())
        if key.dtype == torch.float16:
            time_first = time_first.float()
            key = key.float()
            value = value.float()
        time_first = time_first.contiguous()
        key = key.contiguous()
        value = value.contiguous()
        # The CUDA kernel will fill this tensor.
        output = torch.empty_like(key, memory_format=torch.contiguous_format)
        if return_state or state is not None:
            if state is None:
                state = torch.zeros(
                    batch_size,
                    hidden_size,
                    3,
                    dtype=torch.float32,
                    device=key.device,
                    memory_format=torch.contiguous_format,
                )
                state[:, :, 2] -= 1e38
            else:
                state = torch.cat([s.unsqueeze(2) for s in state], dim=2).contiguous()
            if key.dtype == torch.bfloat16:
                forward_func = rwkv_cuda_kernel.forward_with_state_bf16
            else:
                forward_func = rwkv_cuda_kernel.forward_with_state
            forward_func(time_decay, time_first, key, value, output, state)
        else:
            forward_func = rwkv_cuda_kernel.forward_bf16 if key.dtype == torch.bfloat16 else rwkv_cuda_kernel.forward
            forward_func(time_decay, time_first, key, value, output)

        ctx.save_for_backward(time_decay, time_first, key, value, output)

        if state is not None:
            state = [s.squeeze(2) for s in torch.chunk(state, 3, dim=2)]

        return output.to(ctx.input_dtype), state

    @staticmethod
    # g stands for grad
    def backward(ctx, g_output):
        input_dtype = ctx.input_dtype

        time_decay, time_first, key, value, output = ctx.saved_tensors
        # The CUDA kernel will fill those tensors.
        g_time_decay = torch.empty_like(
            time_decay,
            memory_format=torch.contiguous_format,
            dtype=torch.bfloat16 if input_dtype == torch.bfloat16 else torch.float32,
        )
        g_time_first = torch.empty_like(time_first, memory_format=torch.contiguous_format)
        g_key = torch.empty_like(key, memory_format=torch.contiguous_format)
        g_value = torch.empty_like(value, memory_format=torch.contiguous_format)

        if input_dtype == torch.float16:
            g_output = g_output.float()
        backward_func = rwkv_cuda_kernel.backward_bf16 if input_dtype == torch.bfloat16 else rwkv_cuda_kernel.backward
        backward_func(
            time_decay,
            time_first,
            key,
            value,
            output,
            g_output.contiguous(),
            g_time_decay,
            g_time_first,
            g_key,
            g_value,
        )
        g_time_decay = torch.sum(g_time_decay, dim=0)
        g_time_first = torch.sum(g_time_first, dim=0)

        # Gradients are returned in the order of the forward inputs; `state` and `return_state`
        # are non-differentiable and get None.
        return (
            g_time_decay.to(input_dtype),
            g_time_first.to(input_dtype),
            g_key.to(input_dtype),
            g_value.to(input_dtype),
            None,
            None,
        )
def rwkv_linear_attention_cpu(time_decay, time_first, key, value, state=None, return_state=False):
    # For CPU fallback. Will be slower and probably take more memory than the custom CUDA kernel if not executed
    # within a torch.no_grad.
    _, seq_length, _ = key.size()
    output = torch.zeros_like(key)

    if state is None:
        num_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
        den_state = torch.zeros_like(key[:, 0], dtype=torch.float32)
        max_state = torch.zeros_like(key[:, 0], dtype=torch.float32) - 1e38
    else:
        num_state, den_state, max_state = state
    # For numerical stability
    #     real_numerator_state = num_state * torch.exp(max_state)
    #     real_denominator_state = den_state * torch.exp(max_state)

    time_decay = -torch.exp(time_decay)

    for current_index in range(seq_length):
        current_key = key[:, current_index].float()
        current_value = value[:, current_index]

        # wkv computation at time t
        max_for_output = torch.maximum(max_state, current_key + time_first)
        e1 = torch.exp(max_state - max_for_output)
        e2 = torch.exp(current_key + time_first - max_for_output)
        numerator = e1 * num_state + e2 * current_value
        denominator = e1 * den_state + e2
        output[:, current_index] = (numerator / denominator).to(output.dtype)

        # Update state for next iteration
        max_for_state = torch.maximum(max_state + time_decay, current_key)
        e1 = torch.exp(max_state + time_decay - max_for_state)
        e2 = torch.exp(current_key - max_for_state)
        num_state = e1 * num_state + e2 * current_value
        den_state = e1 * den_state + e2
        max_state = max_for_state

    if return_state or state is not None:
        state = [num_state, den_state, max_state]

    return output, state


def rwkv_linear_attention(time_decay, time_first, key, value, state=None, return_state=False):
    no_cuda = any(t.device.type != "cuda" for t in [time_decay, time_first, key, value])
    # Launching the CUDA kernel for just one token will actually be slower (there is no for loop in the CPU version
    # in this case).
    one_token = key.size(1) == 1
    if rwkv_cuda_kernel is None or no_cuda or one_token:
        return rwkv_linear_attention_cpu(time_decay, time_first, key, value, state=state, return_state=return_state)
    else:
        return RwkvLinearAttention.apply(time_decay, time_first, key, value, state, return_state)
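Both the CUDA kernel and the CPU fallback above compute the RWKV-4 linear-attention (wkv) recurrence. In closed form, the loop in `rwkv_linear_attention_cpu` evaluates

$$\mathrm{wkv}_t = \frac{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i}\, v_i + e^{u + k_t}\, v_t}{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i} + e^{u + k_t}}$$

where $w$ is the positive per-channel decay rate (the code stores `time_decay` in log space and applies `-torch.exp(time_decay)` once per step), $u$ is `time_first`, and $k_i$, $v_i$ are the key and value at step $i$. The `num_state` / `den_state` / `max_state` triple carries the numerator, the denominator, and the running maximum of the exponents, so every exponential is evaluated on a non-positive argument and cannot overflow.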
class RwkvSelfAttention(nn.Module):
    def __init__(self, config, layer_id=0):
        super().__init__()
        self.config = config
        kernel_loaded = rwkv_cuda_kernel is not None and rwkv_cuda_kernel.max_seq_length == config.context_length
        if is_ninja_available() and is_torch_cuda_available() and not kernel_loaded:
            try:
                load_wkv_cuda_kernel(config.context_length)
            except Exception:
                logger.info("Could not load the custom CUDA kernel for RWKV attention.")
        self.layer_id = layer_id
        hidden_size = config.hidden_size
        attention_hidden_size = (
            config.attention_hidden_size if config.attention_hidden_size is not None else hidden_size
        )
        self.attention_hidden_size = attention_hidden_size

        self.time_decay = nn.Parameter(torch.empty(attention_hidden_size))
        self.time_first = nn.Parameter(torch.empty(attention_hidden_size))

        self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size))
        self.time_mix_value = nn.Parameter(torch.empty(1, 1, hidden_size))
        self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, hidden_size))

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.key = nn.Linear(hidden_size, attention_hidden_size, bias=False)
        self.value = nn.Linear(hidden_size, attention_hidden_size, bias=False)
        self.receptance = nn.Linear(hidden_size, attention_hidden_size, bias=False)
        self.output = nn.Linear(attention_hidden_size, hidden_size, bias=False)

    # TODO: maybe jit, otherwise move inside forward
    def extract_key_value(self, hidden, state=None):
        # Mix hidden with the previous timestep to produce key, value, receptance
        if hidden.size(1) == 1 and state is not None:
            shifted = state[1][:, :, self.layer_id]
        else:
            shifted = self.time_shift(hidden)
            if state is not None:
                shifted[:, 0] = state[1][:, :, self.layer_id]
        key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
        value = hidden * self.time_mix_value + shifted * (1 - self.time_mix_value)
        receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)

        key = self.key(key)
        value = self.value(value)
        receptance = torch.sigmoid(self.receptance(receptance))
        if state is not None:
            state[1][:, :, self.layer_id] = hidden[:, -1]
        return receptance, key, value, state

    def forward(self, hidden, state=None, use_cache=False):
        receptance, key, value, state = self.extract_key_value(hidden, state=state)
        layer_state = tuple(s[:, :, self.layer_id] for s in state[2:]) if state is not None else None
        rwkv, layer_state = rwkv_linear_attention(
            self.time_decay,
            self.time_first,
            key,
            value,
            state=layer_state,
            return_state=use_cache,
        )

        if layer_state is not None:
            state[2][:, :, self.layer_id] = layer_state[0]
            state[3][:, :, self.layer_id] = layer_state[1]
            state[4][:, :, self.layer_id] = layer_state[2]

        return self.output(receptance * rwkv), state
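The `time_shift` / `time_mix_*` combination above is RWKV's token-shift: every projection input is a learned per-channel interpolation between the current and previous hidden states. For the key, for instance,

$$k_t = W_K\,\bigl(\mu_k \odot x_t + (1 - \mu_k) \odot x_{t-1}\bigr)$$

with $\mu_k$ = `time_mix_key`; in RNN mode (single-token input), $x_{t-1}$ is read from slot 1 of the cached state instead of the zero-padded shift.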
class RwkvFeedForward(nn.Module):
    def __init__(self, config, layer_id=0):
        super().__init__()
        self.config = config
        self.layer_id = layer_id
        hidden_size = config.hidden_size
        intermediate_size = (
            config.intermediate_size if config.intermediate_size is not None else 4 * config.hidden_size
        )

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size))
        self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, hidden_size))

        self.key = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.receptance = nn.Linear(hidden_size, hidden_size, bias=False)
        self.value = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, hidden, state=None):
        if hidden.size(1) == 1 and state is not None:
            shifted = state[0][:, :, self.layer_id]
        else:
            shifted = self.time_shift(hidden)
            if state is not None:
                shifted[:, 0] = state[0][:, :, self.layer_id]
        key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
        receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)

        key = torch.square(torch.relu(self.key(key)))
        value = self.value(key)
        receptance = torch.sigmoid(self.receptance(receptance))

        if state is not None:
            state[0][:, :, self.layer_id] = hidden[:, -1]

        return receptance * value, state


class RwkvBlock(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.config = config
        self.layer_id = layer_id

        if layer_id == 0:
            self.pre_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        self.attention = RwkvSelfAttention(config, layer_id)
        self.feed_forward = RwkvFeedForward(config, layer_id)

    def forward(self, hidden, state=None, use_cache=False, output_attentions=False):
        if self.layer_id == 0:
            hidden = self.pre_ln(hidden)

        attention, state = self.attention(self.ln1(hidden), state=state, use_cache=use_cache)
        hidden = hidden + attention

        feed_forward, state = self.feed_forward(self.ln2(hidden), state=state)
        hidden = hidden + feed_forward

        outputs = (hidden, state)
        if output_attentions:
            outputs += (attention,)
        else:
            outputs += (None,)

        return outputs
class RwkvPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = RwkvConfig
    base_model_prefix = "rwkv"
    _no_split_modules = ["RwkvBlock"]
    _keep_in_fp32_modules = ["time_decay", "time_first"]

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, RwkvSelfAttention):
            layer_id = module.layer_id
            num_hidden_layers = module.config.num_hidden_layers
            hidden_size = module.config.hidden_size
            attention_hidden_size = module.attention_hidden_size

            ratio_0_to_1 = layer_id / (num_hidden_layers - 1)  # 0 to 1
            ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers)  # 1 to ~0

            time_weight = torch.tensor(
                [i / hidden_size for i in range(hidden_size)],
                dtype=module.time_mix_key.dtype,
                device=module.time_mix_key.device,
            )
            time_weight = time_weight[None, None, :]

            decay_speed = [
                -5 + 8 * (h / (attention_hidden_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
                for h in range(attention_hidden_size)
            ]
            decay_speed = torch.tensor(decay_speed, dtype=module.time_decay.dtype, device=module.time_decay.device)
            zigzag = (
                torch.tensor(
                    [(i + 1) % 3 - 1 for i in range(attention_hidden_size)],
                    dtype=module.time_first.dtype,
                    device=module.time_first.device,
                )
                * 0.5
            )

            with torch.no_grad():
                module.time_decay.data = decay_speed
                module.time_first.data = torch.ones_like(module.time_first) * math.log(0.3) + zigzag

                module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
                module.time_mix_value.data = torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
                module.time_mix_receptance.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0)
        elif isinstance(module, RwkvFeedForward):
            layer_id = module.layer_id
            num_hidden_layers = module.config.num_hidden_layers
            hidden_size = module.config.hidden_size

            ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers)  # 1 to ~0

            time_weight = torch.tensor(
                [i / hidden_size for i in range(hidden_size)],
                dtype=module.time_mix_key.dtype,
                device=module.time_mix_key.device,
            )
            time_weight = time_weight[None, None, :]

            with torch.no_grad():
                module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
                module.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, RwkvModel):
            module.gradient_checkpointing = value
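As a reading aid, the `decay_speed` initialization above is, in closed form,

$$d_{l,h} = -5 + 8\left(\frac{h}{A-1}\right)^{0.7 + 1.3\,\frac{l}{L-1}}$$

for channel $h$ of layer $l$, with $A$ = `attention_hidden_size` and $L$ = `num_hidden_layers`: early channels decay fast, late channels slowly, and the exponent schedule shifts with depth. This just restates the list comprehension in the code.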
@dataclass
class RwkvOutput(ModelOutput):
    """
    Class for the RWKV model outputs.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    last_hidden_state: torch.FloatTensor = None
    state: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class RwkvCausalLMOutput(ModelOutput):
    """
    Base class for causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    state: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


RWKV_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
    heads, etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`RwkvConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

RWKV_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            Indices of input sequence tokens in the vocabulary.

            If `state` is passed along, only the `input_ids` that have not yet been folded into the state should be
            passed as `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        state (tuple of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`, *optional*):
            If passed along, the model uses the previous state in all the blocks (which will give the output for the
            `input_ids` provided as if the model were given `state_input_ids + input_ids` as context).
        use_cache (`bool`, *optional*):
            If set to `True`, the last state is returned and can be used to quickly generate the next logits.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
    "The bare RWKV Model transformer outputting raw hidden-states without any specific head on top.",
    RWKV_START_DOCSTRING,
)
class RwkvModel(RwkvPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.blocks = nn.ModuleList([RwkvBlock(config, layer_id=idx) for idx in range(config.num_hidden_layers)])
        self.ln_out = nn.LayerNorm(config.hidden_size)

        self.layers_are_rescaled = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, new_embeddings):
        self.embeddings = new_embeddings

    @add_start_docstrings_to_model_forward(RWKV_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=RwkvOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        state: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, RwkvOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.training == self.layers_are_rescaled:
            self._rescale_layers()

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        if use_cache and state is None:
            shape = (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers)
            state = [
                torch.zeros(
                    *shape, dtype=inputs_embeds.dtype if i <= 1 else torch.float32, device=inputs_embeds.device
                )
                for i in range(5)
            ]
            state[4] -= 1e30

        hidden_states = inputs_embeds

        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        for idx, block in enumerate(self.blocks):
            hidden_states, state, attentions = block(
                hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions
            )
            if (
                self.layers_are_rescaled
                and self.config.rescale_every > 0
                and (idx + 1) % self.config.rescale_every == 0
            ):
                hidden_states = hidden_states / 2

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if output_attentions:
                all_self_attentions = all_self_attentions + (attentions,)

        hidden_states = self.ln_out(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return (hidden_states, state, all_hidden_states, all_self_attentions)

        return RwkvOutput(
            last_hidden_state=hidden_states,
            state=state,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    def _rescale_layers(self):
        # Layers should be rescaled for inference only.
        if self.layers_are_rescaled == (not self.training):
            return
        if self.config.rescale_every > 0:
            with torch.no_grad():
                for block_id, block in enumerate(self.blocks):
                    if self.training:
                        block.attention.output.weight.mul_(2 ** int(block_id // self.config.rescale_every))
                        block.feed_forward.value.weight.mul_(2 ** int(block_id // self.config.rescale_every))
                    else:
                        block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every))
                        block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every))

        self.layers_are_rescaled = not self.training
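The `state` round-trip that `RwkvModel` supports can be sanity-checked with a short sketch (using the smallest checkpoint from the archive list above; the split point is arbitrary). Running the same tokens in two chunks, feeding the first chunk's returned state into the second, should reproduce the single-pass output, which is exactly what `create_and_check_state_equivalency` in the test file below asserts on a random model:

    import torch
    from transformers import AutoTokenizer, RwkvModel

    model = RwkvModel.from_pretrained("RWKV/rwkv-4-169m-pile").eval()
    tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
    input_ids = tokenizer("This is an example.", return_tensors="pt").input_ids

    with torch.no_grad():
        whole = model(input_ids).last_hidden_state
        # Same computation, split in two, reusing the returned state for the second half.
        first = model(input_ids[:, :2])
        second = model(input_ids[:, 2:], state=first.state)
    chunked = torch.cat([first.last_hidden_state, second.last_hidden_state], dim=1)
    print(torch.allclose(whole, chunked, atol=1e-5))  # expected: True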
@add_start_docstrings(
    """
    The RWKV Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """,
    RWKV_START_DOCSTRING,
)
class RwkvForCausalLM(RwkvPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.rwkv = RwkvModel(config)
        self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.head

    def set_output_embeddings(self, new_embeddings):
        self.head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, **kwargs):
        # only last token for input_ids if the state is passed along.
        if state is not None:
            input_ids = input_ids[:, -1].unsqueeze(-1)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and state is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs["state"] = state
        return model_inputs

    @add_start_docstrings_to_model_forward(RWKV_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=RwkvCausalLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        state: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, RwkvCausalLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to
            `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        rwkv_outputs = self.rwkv(
            input_ids,
            inputs_embeds=inputs_embeds,
            state=state,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = rwkv_outputs[0]

        logits = self.head(hidden_states)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (logits,) + rwkv_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return RwkvCausalLMOutput(
            loss=loss,
            logits=logits,
            state=rwkv_outputs.state,
            hidden_states=rwkv_outputs.hidden_states,
            attentions=rwkv_outputs.attentions,
        )
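End to end, generation with the class above looks like any other causal LM in the library; `prepare_inputs_for_generation` threads the returned state through `generate`, so after the prompt is processed only the newest token is fed at each step. A minimal sketch with one of the checkpoints registered above (prompt and length are arbitrary):

    from transformers import AutoTokenizer, RwkvForCausalLM

    model = RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile")
    tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")

    input_ids = tokenizer("In a shocking finding, scientists", return_tensors="pt").input_ids
    # Greedy decoding; thanks to the recurrent state each new token costs O(1) in sequence length.
    output_ids = model.generate(input_ids, max_new_tokens=20)
    print(tokenizer.decode(output_ids[0]))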
@ -6105,6 +6105,30 @@ def load_tf_weights_in_roformer(*args, **kwargs):
    requires_backends(load_tf_weights_in_roformer, ["torch"])


RWKV_PRETRAINED_MODEL_ARCHIVE_LIST = None


class RwkvForCausalLM(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class RwkvModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class RwkvPreTrainedModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


SAM_PRETRAINED_MODEL_ARCHIVE_LIST = None
tests/models/rwkv/__init__.py (new file, empty)
tests/models/rwkv/test_modeling_rwkv.py (new file, 451 lines)
@ -0,0 +1,451 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
from unittest.util import safe_repr
|
||||
|
||||
from transformers import AutoTokenizer, RwkvConfig, is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
RWKV_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
RwkvForCausalLM,
|
||||
RwkvModel,
|
||||
)
|
||||
|
||||
|
||||
class RwkvModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=14,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_token_type_ids=False,
|
||||
use_input_mask=True,
|
||||
use_labels=True,
|
||||
use_mc_token_ids=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_labels = use_labels
|
||||
self.use_mc_token_ids = use_mc_token_ids
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.scope = scope
|
||||
self.bos_token_id = vocab_size - 1
|
||||
self.eos_token_id = vocab_size - 1
|
||||
self.pad_token_id = vocab_size - 1
|
||||
|
||||
def get_large_model_config(self):
|
||||
return RwkvConfig.from_pretrained("sgugger/rwkv-4-pile-7b")
|
||||
|
||||
def prepare_config_and_inputs(
|
||||
self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
|
||||
):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||
|
||||
mc_token_ids = None
|
||||
if self.use_mc_token_ids:
|
||||
mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length)
|
||||
|
||||
sequence_labels = None
|
||||
token_labels = None
|
||||
choice_labels = None
|
||||
if self.use_labels:
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||
|
||||
config = self.get_config(
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx,
|
||||
reorder_and_upcast_attn=reorder_and_upcast_attn,
|
||||
)
|
||||
|
||||
return (
|
||||
config,
|
||||
input_ids,
|
||||
input_mask,
|
||||
None,
|
||||
token_type_ids,
|
||||
mc_token_ids,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
)
|
||||
|
||||
def get_config(
|
||||
self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
|
||||
):
|
||||
return RwkvConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
intermediate_size=self.intermediate_size,
|
||||
activation_function=self.hidden_act,
|
||||
resid_pdrop=self.hidden_dropout_prob,
|
||||
attn_pdrop=self.attention_probs_dropout_prob,
|
||||
n_positions=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
use_cache=True,
|
||||
bos_token_id=self.bos_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx,
|
||||
reorder_and_upcast_attn=reorder_and_upcast_attn,
|
||||
)
|
||||
|
||||
def get_pipeline_config(self):
|
||||
config = self.get_config()
|
||||
config.vocab_size = 300
|
||||
return config
|
||||
|
||||
def prepare_config_and_inputs_for_decoder(self):
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
input_mask,
|
||||
head_mask,
|
||||
token_type_ids,
|
||||
mc_token_ids,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = self.prepare_config_and_inputs()
|
||||
|
||||
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
|
||||
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||
|
||||
return (
|
||||
config,
|
||||
input_ids,
|
||||
input_mask,
|
||||
head_mask,
|
||||
token_type_ids,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
|
||||
def create_and_check_rwkv_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
config.output_hidden_states = True
|
||||
model = RwkvModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
result = model(input_ids)
|
||||
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
self.parent.assertEqual(len(result.hidden_states), config.num_hidden_layers + 1)
|
||||
|
||||
def create_and_check_causl_lm(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
model = RwkvForCausalLM(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
result = model(input_ids, labels=input_ids)
|
||||
self.parent.assertEqual(result.loss.shape, ())
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
    def create_and_check_state_equivalency(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
        model = RwkvModel(config=config)
        model.to(torch_device)
        model.eval()

        outputs = model(input_ids)
        output_whole = outputs.last_hidden_state

        outputs = model(input_ids[:, :2])
        output_one = outputs.last_hidden_state

        # Using the state computed on the first inputs, we will get the same output
        outputs = model(input_ids[:, 2:], state=outputs.state)
        output_two = outputs.last_hidden_state

        self.parent.assertTrue(torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e-5))
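The state round-trip checked above is also how chunked or streaming inference works with RWKV. A minimal sketch (editor's illustration; the checkpoint name is taken from the integration tests below, and the split point is arbitrary):

import torch
from transformers import AutoTokenizer, RwkvModel

tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
model = RwkvModel.from_pretrained("RWKV/rwkv-4-169m-pile").eval()

ids = tokenizer("The quick brown fox jumps over the lazy dog", return_tensors="pt").input_ids

with torch.no_grad():
    # Run the first chunk, then carry its recurrent state into the second chunk.
    first = model(ids[:, :4])
    second = model(ids[:, 4:], state=first.state)

# Up to numerical tolerance, this matches processing the whole sequence at once.
full = torch.cat([first.last_hidden_state, second.last_hidden_state], dim=1)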
    def create_and_check_forward_and_backwards(
        self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False
    ):
        model = RwkvForCausalLM(config)
        model.to(torch_device)
        if gradient_checkpointing:
            model.gradient_checkpointing_enable()

        result = model(input_ids, labels=input_ids)
        self.parent.assertEqual(result.loss.shape, ())
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
        result.loss.backward()

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()

        (
            config,
            input_ids,
            input_mask,
            head_mask,
            token_type_ids,
            mc_token_ids,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs

        inputs_dict = {"input_ids": input_ids}

        return config, inputs_dict

@require_torch
class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (RwkvModel, RwkvForCausalLM) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "feature-extraction": RwkvModel,
            "text-generation": RwkvForCausalLM,
        }
        if is_torch_available()
        else {}
    )
    # all_generative_model_classes = (RwkvForCausalLM,) if is_torch_available() else ()
    fx_compatible = False
    test_missing_keys = False
    test_model_parallel = False
    test_pruning = False
    test_head_masking = False  # Rwkv does not support head masking

    def setUp(self):
        self.model_tester = RwkvModelTester(self)
        self.config_tester = ConfigTester(
            self, config_class=RwkvConfig, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"]
        )
    def assertInterval(self, member, container, msg=None):
        r"""
        Simple utility function to check if a member is inside an interval.
        """
        if isinstance(member, torch.Tensor):
            max_value, min_value = member.max().item(), member.min().item()
        elif isinstance(member, (list, tuple)):
            max_value, min_value = max(member), min(member)

        if not isinstance(container, (list, tuple)):
            raise TypeError("container should be a list or tuple")
        elif len(container) != 2:
            raise ValueError("container should have 2 elements")

        expected_min, expected_max = container

        is_inside_interval = (min_value >= expected_min) and (max_value <= expected_max)

        if not is_inside_interval:
            standardMsg = "%s not found in %s" % (safe_repr(member), safe_repr(container))
            self.fail(self._formatMessage(msg, standardMsg))
    def test_config(self):
        self.config_tester.run_common_tests()

    def test_rwkv_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_rwkv_model(*config_and_inputs)

    def test_rwkv_lm_head_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_causal_lm(*config_and_inputs)

    def test_state_equivalency(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_state_equivalency(*config_and_inputs)
    def test_initialization(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config=config)
            for name, param in model.named_parameters():
                if "time_decay" in name:
                    if param.requires_grad:
                        self.assertTrue(param.data.max().item() == 3.0)
                        self.assertTrue(param.data.min().item() == -5.0)
                elif "time_first" in name:
                    if param.requires_grad:
                        # check that it is initialized to ones
                        self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5))
                elif any(x in name for x in ["time_mix_key", "time_mix_receptance"]):
                    if param.requires_grad:
                        self.assertInterval(
                            param.data,
                            [0.0, 1.0],
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )
                elif "time_mix_value" in name:
                    if param.requires_grad:
                        self.assertInterval(
                            param.data,
                            [0.0, 1.3],
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )
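To make the expected initialization concrete, a small sketch (editor's illustration, reusing a tiny config like the tester's) that surfaces the same bounds the assertions above encode:

import torch
from transformers import RwkvConfig, RwkvModel

model = RwkvModel(RwkvConfig(vocab_size=99, hidden_size=32, num_hidden_layers=4))
for name, param in model.named_parameters():
    if "time_decay" in name and param.requires_grad:
        # Expected per the test above: values initialized within [-5.0, 3.0].
        print(name, param.data.min().item(), param.data.max().item())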
    def test_attention_outputs(self):
        r"""
        Overriding the common test_attention_outputs test, as the attention outputs of Rwkv differ from those of
        other models: each one has shape `(batch_size, seq_len, hidden_size)`.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        seq_len = getattr(self.model_tester, "seq_length", None)

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            config.return_dict = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = self._prepare_for_class(inputs_dict, model_class)
            batch_size = inputs["input_ids"].shape[0]
            with torch.no_grad():
                outputs = model(**inputs)
            attentions = outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            # check that output_attentions also work using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = self._prepare_for_class(inputs_dict, model_class)
            batch_size = inputs["input_ids"].shape[0]
            with torch.no_grad():
                outputs = model(**inputs)
            attentions = outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            self.assertListEqual(
                list(attentions[0].shape[-3:]),
                [batch_size, seq_len, config.hidden_size],
            )
            out_len = len(outputs)

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = self._prepare_for_class(inputs_dict, model_class)
            batch_size = inputs["input_ids"].shape[0]
            with torch.no_grad():
                outputs = model(**inputs)

            added_hidden_states = 1
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            self_attentions = outputs.attentions

            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(self_attentions[0].shape[-3:]),
                [batch_size, seq_len, config.hidden_size],
            )
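A short sketch of what the override verifies: with `output_attentions=True`, each per-layer tensor in `outputs.attentions` is shaped `(batch_size, seq_len, hidden_size)` rather than the usual square attention map (illustrative; `model` and `input_ids` as in the tests above):

with torch.no_grad():
    outputs = model(input_ids, output_attentions=True)
# One tensor per layer, each of shape (batch_size, seq_len, hidden_size).
print(len(outputs.attentions), outputs.attentions[0].shape)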
    @slow
    def test_model_from_pretrained(self):
        for model_name in RWKV_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = RwkvModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

@slow
class RWKVIntegrationTests(unittest.TestCase):
    def setUp(self):
        self.model_id = "RWKV/rwkv-4-169m-pile"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)

    def test_simple_generate(self):
        expected_output = "Hello my name is Jasmine and I am a newbie to the"
        model = RwkvForCausalLM.from_pretrained(self.model_id).to(torch_device)

        input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device)
        output = model.generate(input_ids, max_new_tokens=10)
        output_sentence = self.tokenizer.decode(output[0].tolist())

        self.assertEqual(output_sentence, expected_output)

    def test_simple_generate_bf16(self):
        expected_output = "Hello my name is Jasmine and I am a newbie to the"

        input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device)
        model = RwkvForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(torch_device)

        output = model.generate(input_ids, max_new_tokens=10)
        output_sentence = self.tokenizer.decode(output[0].tolist())

        self.assertEqual(output_sentence, expected_output)
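For reference, the integration test's happy path can be reproduced outside the harness with a few lines (a sketch mirroring `test_simple_generate`; greedy decoding is the assumed default):

from transformers import AutoTokenizer, RwkvForCausalLM

tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
model = RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile")

input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids
output = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.decode(output[0].tolist()))
# Expected (per the test): "Hello my name is Jasmine and I am a newbie to the"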
@@ -93,15 +93,20 @@ config_common_kwargs = {
 
 
 class ConfigTester(object):
-    def __init__(self, parent, config_class=None, has_text_modality=True, **kwargs):
+    def __init__(self, parent, config_class=None, has_text_modality=True, common_properties=None, **kwargs):
         self.parent = parent
         self.config_class = config_class
         self.has_text_modality = has_text_modality
         self.inputs_dict = kwargs
+        self.common_properties = common_properties
 
     def create_and_test_config_common_properties(self):
         config = self.config_class(**self.inputs_dict)
-        common_properties = ["hidden_size", "num_attention_heads", "num_hidden_layers"]
+        common_properties = (
+            ["hidden_size", "num_attention_heads", "num_hidden_layers"]
+            if self.common_properties is None
+            else self.common_properties
+        )
 
         # Add common fields for text models
         if self.has_text_modality: