From 99a2771189321c826ff55d161a7cfedadd4023c7 Mon Sep 17 00:00:00 2001 From: novice <44259234+novice03@users.noreply.github.com> Date: Wed, 26 Jan 2022 12:18:29 -0600 Subject: [PATCH] Add YOSO (#15091) * Add cookiecutter files * Add cuda kernels and cpp files * Update modeling_yoso.py * Add .h files * Update configuration_yoso.py * Updates * Remove tokenizer * Code quality * Update modeling_yoso.py * Update modeling_yoso.py * Fix failing test * Update modeling_yoso.py * Fix code quality * Apply suggestions from code review Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Apply suggestions from code review and fix integration tests * Update src/transformers/models/yoso/modeling_yoso.py Co-authored-by: Patrick von Platen * Apply suggestions from code review * Fix copied from statement * Fix docstring * Fix code quality * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Apply suggestions and fix mask * Apply suggestions from code review * Fix code quality * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Fix docstrings * Fix code quality * Remove trailing whitespace * Update yoso.mdx * Move kernel loading to YosoEncoder * make style * Apply suggestions from code review Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/yoso/modeling_yoso.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Add short summary to docs * Update docs/source/model_doc/yoso.mdx Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update yoso.mdx * Update docs/source/model_doc/yoso.mdx Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Remove CausalLM model and add copied from * Remove autoregressive code * Remove unused imports * add copied from for embeddings * Fix code quality * Update docs/source/model_doc/yoso.mdx Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Apply suggestion from code review Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Patrick von Platen --- README.md | 1 + README_ko.md | 1 + README_zh-hans.md | 1 + README_zh-hant.md | 1 + docs/source/_toctree.yml | 2 + docs/source/index.mdx | 2 + docs/source/model_doc/yoso.mdx | 91 ++ src/transformers/__init__.py | 26 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 3 + src/transformers/models/auto/modeling_auto.py | 7 + src/transformers/models/yoso/__init__.py | 62 + src/transformers/models/yoso/common.h | 10 + src/transformers/models/yoso/common_cuda.h | 9 + .../models/yoso/common_cuda_device.h | 79 + .../models/yoso/configuration_yoso.py | 145 ++ .../yoso/convert_yoso_pytorch_to_pytorch.py | 109 ++ .../models/yoso/fast_lsh_cumulation.cu | 588 ++++++++ .../models/yoso/fast_lsh_cumulation.h | 71 + .../models/yoso/fast_lsh_cumulation_cuda.cu | 825 ++++++++++ .../models/yoso/fast_lsh_cumulation_cuda.h | 157 ++ .../models/yoso/fast_lsh_cumulation_torch.cpp | 128 ++ src/transformers/models/yoso/modeling_yoso.py | 1324 +++++++++++++++++ src/transformers/utils/dummy_pt_objects.py | 59 + tests/test_modeling_yoso.py | 401 +++++ 25 files changed, 4103 insertions(+) create mode 100644 docs/source/model_doc/yoso.mdx create mode 100644 src/transformers/models/yoso/__init__.py create mode 100644 src/transformers/models/yoso/common.h create mode 100644 src/transformers/models/yoso/common_cuda.h create mode 100644 src/transformers/models/yoso/common_cuda_device.h create mode 100644 src/transformers/models/yoso/configuration_yoso.py create mode 100644 src/transformers/models/yoso/convert_yoso_pytorch_to_pytorch.py create mode 100644 src/transformers/models/yoso/fast_lsh_cumulation.cu create mode 100644 src/transformers/models/yoso/fast_lsh_cumulation.h create mode 100644 src/transformers/models/yoso/fast_lsh_cumulation_cuda.cu create mode 100644 src/transformers/models/yoso/fast_lsh_cumulation_cuda.h create mode 100644 src/transformers/models/yoso/fast_lsh_cumulation_torch.cpp create mode 100644 src/transformers/models/yoso/modeling_yoso.py create mode 100644 tests/test_modeling_yoso.py diff --git a/README.md b/README.md index 0e1baa4ebdb..e8ec2e51507 100644 --- a/README.md +++ b/README.md @@ -325,6 +325,7 @@ AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752) by Sanyuan Chen, Yu Wu, Ch 1. **[XLNet](https://huggingface.co/docs/transformers/model_doc/xlnet)** (from Google/CMU) released with the paper [​XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le. 1. **[XLSR-Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/xlsr_wav2vec2)** (from Facebook AI) released with the paper [Unsupervised Cross-Lingual Representation Learning For Speech Recognition](https://arxiv.org/abs/2006.13979) by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli. 1. **[XLS-R](https://huggingface.co/docs/master/transformers/model_doc/xls_r)** (from Facebook AI) released with the paper [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale](https://arxiv.org/abs/2111.09296) by Arun Babu, Changhan Wang, Andros Tjandra, Kushal Lakhotia, Qiantong Xu, Naman Goyal, Kritika Singh, Patrick von Platen, Yatharth Saraf, Juan Pino, Alexei Baevski, Alexis Conneau, Michael Auli. +1. **[YOSO](https://huggingface.co/docs/transformers/master/model_doc/yoso)** (from the University of Wisconsin - Madison) released with the paper [You Only Sample (Almost) Once: Linear Cost Self-Attention Via Bernoulli Sampling](https://arxiv.org/abs/2111.09714) by Zhanpeng Zeng, Yunyang Xiong, Sathya N. Ravi, Shailesh Acharya, Glenn Fung, Vikas Singh. 1. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedbacks before starting your PR. To check if each model has an implementation in Flax, PyTorch or TensorFlow, or has an associated tokenizer backed by the 🤗 Tokenizers library, refer to [this table](https://huggingface.co/docs/transformers/index#supported-frameworks). diff --git a/README_ko.md b/README_ko.md index 2aa78c99701..9e0cdde6914 100644 --- a/README_ko.md +++ b/README_ko.md @@ -303,6 +303,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[XLNet](https://huggingface.co/docs/transformers/model_doc/xlnet)** (from Google/CMU) released with the paper [​XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le. 1. **[XLS-R](https://huggingface.co/docs/master/transformers/model_doc/xls_r)** (from Facebook AI) released with the paper [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale](https://arxiv.org/abs/2111.09296) by Arun Babu, Changhan Wang, Andros Tjandra, Kushal Lakhotia, Qiantong Xu, Naman Goyal, Kritika Singh, Patrick von Platen, Yatharth Saraf, Juan Pino, Alexei Baevski, Alexis Conneau, Michael Auli. 1. **[XLSR-Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/xlsr_wav2vec2)** (from Facebook AI) released with the paper [Unsupervised Cross-Lingual Representation Learning For Speech Recognition](https://arxiv.org/abs/2006.13979) by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli. +1. **[YOSO](https://huggingface.co/docs/transformers/master/model_doc/yoso)** (from the University of Wisconsin - Madison) released with the paper [You Only Sample (Almost) by Zhanpeng Zeng, Yunyang Xiong, Sathya N. Ravi, Shailesh Acharya, Glenn Fung, Vikas Singh. 1. 새로운 모델을 올리고 싶나요? 우리가 **상세한 가이드와 템플릿** 으로 새로운 모델을 올리도록 도와드릴게요. 가이드와 템플릿은 이 저장소의 [`templates`](./templates) 폴더에서 확인하실 수 있습니다. [컨트리뷰션 가이드라인](./CONTRIBUTING.md)을 꼭 확인해주시고, PR을 올리기 전에 메인테이너에게 연락하거나 이슈를 오픈해 피드백을 받으시길 바랍니다. 각 모델이 Flax, PyTorch, TensorFlow으로 구현되었는지 또는 🤗 Tokenizers 라이브러리가 지원하는 토크나이저를 사용하는지 확인하려면, [이 표](https://huggingface.co/docs/transformers/index#supported-frameworks)를 확인하세요. diff --git a/README_zh-hans.md b/README_zh-hans.md index c3fe02ba1db..65ec8c57414 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -327,6 +327,7 @@ conda install -c huggingface transformers 1. **[XLNet](https://huggingface.co/docs/transformers/model_doc/xlnet)** (来自 Google/CMU) 伴随论文 [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) 由 Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le 发布。 1. **[XLS-R](https://huggingface.co/docs/master/transformers/model_doc/xls_r)** (来自 Facebook AI) 伴随论文 [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale](https://arxiv.org/abs/2111.09296) 由 Arun Babu, Changhan Wang, Andros Tjandra, Kushal Lakhotia, Qiantong Xu, Naman Goyal, Kritika Singh, Patrick von Platen, Yatharth Saraf, Juan Pino, Alexei Baevski, Alexis Conneau, Michael Auli 发布。 1. **[XLSR-Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/xlsr_wav2vec2)** (来自 Facebook AI) 伴随论文 [Unsupervised Cross-Lingual Representation Learning For Speech Recognition](https://arxiv.org/abs/2006.13979) 由 Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli 发布。 +1. **[YOSO](https://huggingface.co/docs/transformers/master/model_doc/yoso)** (来自 the University of Wisconsin - Madison) 伴随论文 [You Only Sample (Almost) 由 Zhanpeng Zeng, Yunyang Xiong, Sathya N. Ravi, Shailesh Acharya, Glenn Fung, Vikas Singh 发布。 1. 想要贡献新的模型?我们这里有一份**详细指引和模板**来引导你添加新的模型。你可以在 [`templates`](./templates) 目录中找到他们。记得查看 [贡献指南](./CONTRIBUTING.md) 并在开始写 PR 前联系维护人员或开一个新的 issue 来获得反馈。 要检查某个模型是否已有 Flax、PyTorch 或 TensorFlow 的实现,或其是否在 🤗 Tokenizers 库中有对应词符化器(tokenizer),敬请参阅[此表](https://huggingface.co/docs/transformers/index#supported-frameworks)。 diff --git a/README_zh-hant.md b/README_zh-hant.md index ff2ed9eb4cc..236d3791461 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -339,6 +339,7 @@ conda install -c huggingface transformers 1. **[XLNet](https://huggingface.co/docs/transformers/model_doc/xlnet)** (from Google/CMU) released with the paper [​XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le. 1. **[XLS-R](https://huggingface.co/docs/master/transformers/model_doc/xls_r)** (from Facebook AI) released with the paper [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale](https://arxiv.org/abs/2111.09296) by Arun Babu, Changhan Wang, Andros Tjandra, Kushal Lakhotia, Qiantong Xu, Naman Goyal, Kritika Singh, Patrick von Platen, Yatharth Saraf, Juan Pino, Alexei Baevski, Alexis Conneau, Michael Auli. 1. **[XLSR-Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/xlsr_wav2vec2)** (from Facebook AI) released with the paper [Unsupervised Cross-Lingual Representation Learning For Speech Recognition](https://arxiv.org/abs/2006.13979) by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli. +1. **[YOSO](https://huggingface.co/docs/transformers/master/model_doc/yoso)** (from the University of Wisconsin - Madison) released with the paper [You Only Sample (Almost) by Zhanpeng Zeng, Yunyang Xiong, Sathya N. Ravi, Shailesh Acharya, Glenn Fung, Vikas Singh. 1. 想要貢獻新的模型?我們這裡有一份**詳細指引和模板**來引導你加入新的模型。你可以在 [`templates`](./templates) 目錄中找到它們。記得查看[貢獻指引](./CONTRIBUTING.md)並在開始寫 PR 前聯繫維護人員或開一個新的 issue 來獲得 feedbacks。 要檢查某個模型是否已有 Flax、PyTorch 或 TensorFlow 的實作,或其是否在🤗 Tokenizers 函式庫中有對應的 tokenizer,敬請參閱[此表](https://huggingface.co/docs/transformers/index#supported-frameworks)。 diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 2ed4764d0e4..f21d92ce4ca 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -316,6 +316,8 @@ title: XLSR-Wav2Vec2 - local: model_doc/xls_r title: XLS-R + - local: model_doc/yoso + title: YOSO title: Models - sections: - local: internal/modeling_utils diff --git a/docs/source/index.mdx b/docs/source/index.mdx index 7d9c686e2d3..abc0f996d93 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -184,6 +184,7 @@ conversion utilities for the following models. 1. **[XLNet](model_doc/xlnet)** (from Google/CMU) released with the paper [​XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le. 1. **[XLSR-Wav2Vec2](model_doc/xlsr_wav2vec2)** (from Facebook AI) released with the paper [Unsupervised Cross-Lingual Representation Learning For Speech Recognition](https://arxiv.org/abs/2006.13979) by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli. 1. **[XLS-R](https://huggingface.co/docs/master/transformers/model_doc/xls_r)** (from Facebook AI) released with the paper [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale](https://arxiv.org/abs/2111.09296) by Arun Babu, Changhan Wang, Andros Tjandra, Kushal Lakhotia, Qiantong Xu, Naman Goyal, Kritika Singh, Patrick von Platen, Yatharth Saraf, Juan Pino, Alexei Baevski, Alexis Conneau, Michael Auli. +1. **[YOSO](model_doc/yoso)** (from the University of Wisconsin - Madison) released with the paper [You Only Sample (Almost) Once: Linear Cost Self-Attention Via Bernoulli Sampling](https://arxiv.org/abs/2111.09714) by Zhanpeng Zeng, Yunyang Xiong, Sathya N. Ravi, Shailesh Acharya, Glenn Fung, Vikas Singh. ### Supported frameworks @@ -281,5 +282,6 @@ Flax), PyTorch, and/or TensorFlow. | XLM-RoBERTa | ✅ | ✅ | ✅ | ✅ | ❌ | | XLMProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ | | XLNet | ✅ | ✅ | ✅ | ✅ | ❌ | +| YOSO | ❌ | ❌ | ✅ | ❌ | ❌ | diff --git a/docs/source/model_doc/yoso.mdx b/docs/source/model_doc/yoso.mdx new file mode 100644 index 00000000000..997ab4d0941 --- /dev/null +++ b/docs/source/model_doc/yoso.mdx @@ -0,0 +1,91 @@ + + +# YOSO + +## Overview + +The YOSO model was proposed in [You Only Sample (Almost) Once: Linear Cost Self-Attention Via Bernoulli Sampling](https://arxiv.org/abs/2111.09714) +by Zhanpeng Zeng, Yunyang Xiong, Sathya N. Ravi, Shailesh Acharya, Glenn Fung, Vikas Singh. YOSO approximates standard softmax self-attention +via a Bernoulli sampling scheme based on Locality Sensitive Hashing (LSH). In principle, all the Bernoulli random variables can be sampled with +a single hash. + +The abstract from the paper is the following: + +*Transformer-based models are widely used in natural language processing (NLP). Central to the transformer model is +the self-attention mechanism, which captures the interactions of token pairs in the input sequences and depends quadratically +on the sequence length. Training such models on longer sequences is expensive. In this paper, we show that a Bernoulli sampling +attention mechanism based on Locality Sensitive Hashing (LSH), decreases the quadratic complexity of such models to linear. +We bypass the quadratic cost by considering self-attention as a sum of individual tokens associated with Bernoulli random +variables that can, in principle, be sampled at once by a single hash (although in practice, this number may be a small constant). +This leads to an efficient sampling scheme to estimate self-attention which relies on specific modifications of +LSH (to enable deployment on GPU architectures). We evaluate our algorithm on the GLUE benchmark with standard 512 sequence +length where we see favorable performance relative to a standard pretrained Transformer. On the Long Range Arena (LRA) benchmark, +for evaluating performance on long sequences, our method achieves results consistent with softmax self-attention but with sizable +speed-ups and memory savings and often outperforms other efficient self-attention methods. Our code is available at this https URL* + +Tips: + +- The YOSO attention algorithm is implemented through custom CUDA kernels, functions written in CUDA C++ that can be executed multiple times +in parallel on a GPU. +- The kernels provide a `fast_hash` function, which approximates the random projections of the queries and keys using the Fast Hadamard Transform. Using these +hash codes, the `lsh_cumulation` function approximates self-attention via LSH-based Bernoulli sampling. +- To use the custom kernels, the user should set `config.use_expectation = False`. To ensure that the kernels are compiled successfully, +the user must install the correct version of PyTorch and cudatoolkit. By default, `config.use_expectation = True`, which uses YOSO-E and +does not require compiling CUDA kernels. + + + + YOSO Attention Algorithm. Taken from the original paper. + +This model was contributed by [novice03](https://huggingface.co/novice03). The original code can be found [here](https://github.com/mlpen/YOSO). + + +## YosoConfig + +[[autodoc]] YosoConfig + + +## YosoModel + +[[autodoc]] YosoModel + - forward + + +## YosoForMaskedLM + +[[autodoc]] YosoForMaskedLM + - forward + + +## YosoForSequenceClassification + +[[autodoc]] YosoForSequenceClassification + - forward + +## YosoForMultipleChoice + +[[autodoc]] YosoForMultipleChoice + - forward + + +## YosoForTokenClassification + +[[autodoc]] YosoForTokenClassification + - forward + + +## YosoForQuestionAnswering + +[[autodoc]] YosoForQuestionAnswering + - forward \ No newline at end of file diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 95bbdd32850..190a1a41a58 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -333,6 +333,7 @@ _import_structure = { "models.xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"], "models.xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"], "models.xlnet": ["XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLNetConfig"], + "models.yoso": ["YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP", "YosoConfig"], "onnx": [], "pipelines": [ "AudioClassificationPipeline", @@ -1510,6 +1511,19 @@ if is_torch_available(): "load_tf_weights_in_xlnet", ] ) + _import_structure["models.yoso"].extend( + [ + "YOSO_PRETRAINED_MODEL_ARCHIVE_LIST", + "YosoForMaskedLM", + "YosoForMultipleChoice", + "YosoForQuestionAnswering", + "YosoForSequenceClassification", + "YosoForTokenClassification", + "YosoLayer", + "YosoModel", + "YosoPreTrainedModel", + ] + ) _import_structure["optimization"] = [ "Adafactor", "AdamW", @@ -2454,6 +2468,7 @@ if TYPE_CHECKING: from .models.xlm_prophetnet import XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMProphetNetConfig from .models.xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig from .models.xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig + from .models.yoso import YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP, YosoConfig # Pipelines from .pipelines import ( @@ -3431,6 +3446,17 @@ if TYPE_CHECKING: XLNetPreTrainedModel, load_tf_weights_in_xlnet, ) + from .models.yoso import ( + YOSO_PRETRAINED_MODEL_ARCHIVE_LIST, + YosoForMaskedLM, + YosoForMultipleChoice, + YosoForQuestionAnswering, + YosoForSequenceClassification, + YosoForTokenClassification, + YosoLayer, + YosoModel, + YosoPreTrainedModel, + ) # Optimization from .optimization import ( diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index db7e1d5db21..aff62b017ca 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -119,4 +119,5 @@ from . import ( xlm_prophetnet, xlm_roberta, xlnet, + yoso, ) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index cbc2e57a069..2a88de258cc 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -30,6 +30,7 @@ logger = logging.get_logger(__name__) CONFIG_MAPPING_NAMES = OrderedDict( [ # Add configs here + ("yoso", "YosoConfig"), ("swin", "SwinConfig"), ("vilt", "ViltConfig"), ("vit_mae", "ViTMAEConfig"), @@ -121,6 +122,7 @@ CONFIG_MAPPING_NAMES = OrderedDict( CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict( [ # Add archive maps here + ("yoso", "YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("swin", "SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("vilt", "VILT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("vit_mae", "VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -200,6 +202,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict( MODEL_NAMES_MAPPING = OrderedDict( [ # Add full (and cased) model names here + ("yoso", "YOSO"), ("swin", "Swin"), ("vilt", "ViLT"), ("vit_mae", "ViTMAE"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index c35d4a647f0..5e114a7b120 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -28,6 +28,7 @@ logger = logging.get_logger(__name__) MODEL_MAPPING_NAMES = OrderedDict( [ # Base model mapping + ("yoso", "YosoModel"), ("swin", "SwinModel"), ("vilt", "ViltModel"), ("vit_mae", "ViTMAEModel"), @@ -155,6 +156,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict( [ # Model with LM heads mapping + ("yoso", "YosoForMaskedLM"), ("nystromformer", "NystromformerForMaskedLM"), ("qdqbert", "QDQBertForMaskedLM"), ("fnet", "FNetForMaskedLM"), @@ -284,6 +286,7 @@ MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict( MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( [ # Model for Masked LM mapping + ("yoso", "YosoForMaskedLM"), ("nystromformer", "NystromformerForMaskedLM"), ("perceiver", "PerceiverForMaskedLM"), ("qdqbert", "QDQBertForMaskedLM"), @@ -357,6 +360,7 @@ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict( MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Sequence Classification mapping + ("yoso", "YosoForSequenceClassification"), ("nystromformer", "NystromformerForSequenceClassification"), ("perceiver", "PerceiverForSequenceClassification"), ("qdqbert", "QDQBertForSequenceClassification"), @@ -405,6 +409,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( [ # Model for Question Answering mapping + ("yoso", "YosoForQuestionAnswering"), ("nystromformer", "NystromformerForQuestionAnswering"), ("qdqbert", "QDQBertForQuestionAnswering"), ("fnet", "FNetForQuestionAnswering"), @@ -454,6 +459,7 @@ MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Token Classification mapping + ("yoso", "YosoForTokenClassification"), ("nystromformer", "NystromformerForTokenClassification"), ("qdqbert", "QDQBertForTokenClassification"), ("fnet", "FNetForTokenClassification"), @@ -490,6 +496,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict( [ # Model for Multiple Choice mapping + ("yoso", "YosoForMultipleChoice"), ("nystromformer", "NystromformerForMultipleChoice"), ("qdqbert", "QDQBertForMultipleChoice"), ("fnet", "FNetForMultipleChoice"), diff --git a/src/transformers/models/yoso/__init__.py b/src/transformers/models/yoso/__init__.py new file mode 100644 index 00000000000..6b1c7eb5aed --- /dev/null +++ b/src/transformers/models/yoso/__init__.py @@ -0,0 +1,62 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +# rely on isort to merge the imports +from ...file_utils import _LazyModule, is_tokenizers_available, is_torch_available + + +_import_structure = { + "configuration_yoso": ["YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP", "YosoConfig"], +} + +if is_torch_available(): + _import_structure["modeling_yoso"] = [ + "YOSO_PRETRAINED_MODEL_ARCHIVE_LIST", + "YosoForMaskedLM", + "YosoForMultipleChoice", + "YosoForQuestionAnswering", + "YosoForSequenceClassification", + "YosoForTokenClassification", + "YosoLayer", + "YosoModel", + "YosoPreTrainedModel", + ] + + +if TYPE_CHECKING: + from .configuration_yoso import YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP, YosoConfig + + if is_torch_available(): + from .modeling_yoso import ( + YOSO_PRETRAINED_MODEL_ARCHIVE_LIST, + YosoForMaskedLM, + YosoForMultipleChoice, + YosoForQuestionAnswering, + YosoForSequenceClassification, + YosoForTokenClassification, + YosoLayer, + YosoModel, + YosoPreTrainedModel, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) diff --git a/src/transformers/models/yoso/common.h b/src/transformers/models/yoso/common.h new file mode 100644 index 00000000000..e5085c88dd3 --- /dev/null +++ b/src/transformers/models/yoso/common.h @@ -0,0 +1,10 @@ + +#define min(a, b) ((a)<(b)?(a):(b)) +#define max(a, b) ((a)>(b)?(a):(b)) +#define ceil_divide(a, b) ((a)/(b)+((a)%(b)!=0)) +#define select(cond, a, b) ((cond)?(a):(b)) +#define PI 3.141592 +#define EPSILON 1e-8 +#define MAX_VAL 1e12 +#define MIN_VAL -1e12 +#define EMPTY_VALUE -1 diff --git a/src/transformers/models/yoso/common_cuda.h b/src/transformers/models/yoso/common_cuda.h new file mode 100644 index 00000000000..97030870649 --- /dev/null +++ b/src/transformers/models/yoso/common_cuda.h @@ -0,0 +1,9 @@ + +#define MAX_THREADS_PER_BLOCK 1024 +#define OPTIMAL_THREADS_PER_BLOCK 256 +#define WARP_SIZE 32 +#define MAX_NUM_BLOCK_X 2147483647 +#define MAX_NUM_BLOCK_Y 65535 +#define MAX_NUM_BLOCK_Z 65535 +#define MAX_SHARED_MEM_PER_BLOCK 48000 +#define FULL_MASK 0xffffffff diff --git a/src/transformers/models/yoso/common_cuda_device.h b/src/transformers/models/yoso/common_cuda_device.h new file mode 100644 index 00000000000..6674f93afdc --- /dev/null +++ b/src/transformers/models/yoso/common_cuda_device.h @@ -0,0 +1,79 @@ + +#include "common.h" + +template +__device__ int set_insert(T *set, int set_size, T value) { + int slot = value % set_size; + int start_slot = slot; + while (true) { + T prev = atomicCAS(&set[slot], EMPTY_VALUE, value); + if (prev == EMPTY_VALUE || prev == value) { + return slot; + } + slot = (slot + 1) % set_size; + if (slot == start_slot) { + return -1; + } + } + return -1; +} + +template +__device__ int set_lookup(T *set, int set_size, T value) { + int slot = value % set_size; + int start_slot = slot; + while (true) { + if (set[slot] == value) { + return slot; + } + slot = (slot + 1) % set_size; + if (slot == start_slot) { + return -1; + } + } + return -1; +} + +template +__device__ void init_buffer(T init_value, T *buffer, int buffer_size, int num_threads, int thread_id) { + __syncthreads(); + for (int i = 0; i < buffer_size; i = i + num_threads) { + int offset_idx = i + thread_id; + if (offset_idx < buffer_size) { + buffer[offset_idx] = init_value; + } + } + __syncthreads(); +} + +template +__device__ void copy_data(T *src_pt, T *dist_pt, int data_length, int num_threads, int thread_id) { + __syncthreads(); + for (int i = 0; i < data_length; i = i + num_threads) { + int offset_idx = i + thread_id; + if (offset_idx < data_length) { + dist_pt[offset_idx] = src_pt[offset_idx]; + } + } + __syncthreads(); +} + +template +__device__ void init_buffer_nonblocking(T init_value, T *buffer, int buffer_size, int num_threads, int thread_id) { + for (int i = 0; i < buffer_size; i = i + num_threads) { + int offset_idx = i + thread_id; + if (offset_idx < buffer_size) { + buffer[offset_idx] = init_value; + } + } +} + +template +__device__ void copy_data_nonblocking(T *src_pt, T *dist_pt, int data_length, int num_threads, int thread_id) { + for (int i = 0; i < data_length; i = i + num_threads) { + int offset_idx = i + thread_id; + if (offset_idx < data_length) { + dist_pt[offset_idx] = src_pt[offset_idx]; + } + } +} diff --git a/src/transformers/models/yoso/configuration_yoso.py b/src/transformers/models/yoso/configuration_yoso.py new file mode 100644 index 00000000000..55e338ac5ac --- /dev/null +++ b/src/transformers/models/yoso/configuration_yoso.py @@ -0,0 +1,145 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" YOSO model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "uw-madison/yoso-4096": "https://huggingface.co/uw-madison/yoso-4096/resolve/main/config.json", + # See all YOSO models at https://huggingface.co/models?filter=yoso +} + + +class YosoConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`YosoModel`]. It is used to instantiate an YOSO + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the YOSO + [uw-madison/yoso-4096](https://huggingface.co/uw-madison/yoso-4096) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50265): + Vocabulary size of the YOSO model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`YosoModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimension of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`YosoModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. + use_expectation (*bool*, *optional*, defaults to *True*): + Whether or not to use YOSO Expectation. Overrides any effect of num_hash. + hash_code_len (`int`, *optional*, defaults to 9): + The length of hashes generated by the hash functions. + num_hash (`int`, *optional*, defaults to 64): + Number of hash functions used in [`YosoSelfAttention`]. + conv_window (`int`, *optional*, defaults to None): + Kernel size of depth-wise convolution. + use_fast_hash (*bool*, *optional*, defaults to *False*): + Whether or not to use custom cuda kernels which perform fast random projection via hadamard transform. + lsh_backward (*bool*, *optional*, defaults to *True*): + Whether or not to perform backpropagation using Locality Sensitive Hashing. + + Example: + + ```python + >>> from transformers import YosoModel, YosoConfig + + >>> # Initializing a YOSO uw-madison/yoso-4096 style configuration + >>> configuration = YosoConfig() + + >>> # Initializing a model from the uw-madison/yoso-4096 style configuration + >>> model = YosoModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "yoso" + + def __init__( + self, + vocab_size=50265, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=4096, + type_vocab_size=1, + initializer_range=0.02, + layer_norm_eps=1e-12, + position_embedding_type="absolute", + use_expectation=True, + hash_code_len=9, + num_hash=64, + conv_window=None, + use_fast_hash=True, + lsh_backward=True, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + **kwargs + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.type_vocab_size = type_vocab_size + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.use_expectation = use_expectation + self.hash_code_len = hash_code_len + self.num_hash = num_hash + self.conv_window = conv_window + self.use_fast_hash = use_fast_hash + self.lsh_backward = lsh_backward diff --git a/src/transformers/models/yoso/convert_yoso_pytorch_to_pytorch.py b/src/transformers/models/yoso/convert_yoso_pytorch_to_pytorch.py new file mode 100644 index 00000000000..2b9a2c7cd85 --- /dev/null +++ b/src/transformers/models/yoso/convert_yoso_pytorch_to_pytorch.py @@ -0,0 +1,109 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert YOSO checkpoints from the original repository. URL: https://github.com/mlpen/YOSO""" + +import argparse + +import torch + +from transformers import YosoConfig, YosoForMaskedLM + + +def rename_key(orig_key): + if "model" in orig_key: + orig_key = orig_key.replace("model.", "") + if "norm1" in orig_key: + orig_key = orig_key.replace("norm1", "attention.output.LayerNorm") + if "norm2" in orig_key: + orig_key = orig_key.replace("norm2", "output.LayerNorm") + if "norm" in orig_key: + orig_key = orig_key.replace("norm", "LayerNorm") + if "transformer" in orig_key: + layer_num = orig_key.split(".")[0].split("_")[-1] + orig_key = orig_key.replace(f"transformer_{layer_num}", f"encoder.layer.{layer_num}") + if "mha.attn" in orig_key: + orig_key = orig_key.replace("mha.attn", "attention.self") + if "mha" in orig_key: + orig_key = orig_key.replace("mha", "attention") + if "W_q" in orig_key: + orig_key = orig_key.replace("W_q", "self.query") + if "W_k" in orig_key: + orig_key = orig_key.replace("W_k", "self.key") + if "W_v" in orig_key: + orig_key = orig_key.replace("W_v", "self.value") + if "ff1" in orig_key: + orig_key = orig_key.replace("ff1", "intermediate.dense") + if "ff2" in orig_key: + orig_key = orig_key.replace("ff2", "output.dense") + if "ff" in orig_key: + orig_key = orig_key.replace("ff", "output.dense") + if "mlm_class" in orig_key: + orig_key = orig_key.replace("mlm.mlm_class", "cls.predictions.decoder") + if "mlm" in orig_key: + orig_key = orig_key.replace("mlm", "cls.predictions.transform") + if "cls" not in orig_key: + orig_key = "yoso." + orig_key + + return orig_key + + +def convert_checkpoint_helper(max_position_embeddings, orig_state_dict): + for key in orig_state_dict.copy().keys(): + val = orig_state_dict.pop(key) + + if ("pooler" in key) or ("sen_class" in key): + continue + else: + orig_state_dict[rename_key(key)] = val + + orig_state_dict["cls.predictions.bias"] = orig_state_dict["cls.predictions.decoder.bias"] + orig_state_dict["yoso.embeddings.position_ids"] = torch.arange(max_position_embeddings).expand((1, -1)) + 2 + + return orig_state_dict + + +def convert_yoso_checkpoint(checkpoint_path, yoso_config_file, pytorch_dump_path): + + orig_state_dict = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"] + config = YosoConfig.from_json_file(yoso_config_file) + model = YosoForMaskedLM(config) + + new_state_dict = convert_checkpoint_helper(config.max_position_embeddings, orig_state_dict) + + print(model.load_state_dict(new_state_dict)) + model.eval() + model.save_pretrained(pytorch_dump_path) + + print(f"Checkpoint successfuly converted. Model saved at {pytorch_dump_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--pytorch_model_path", default=None, type=str, required=True, help="Path to YOSO pytorch checkpoint." + ) + parser.add_argument( + "--config_file", + default=None, + type=str, + required=True, + help="The json file for YOSO model config.", + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + convert_yoso_checkpoint(args.pytorch_model_path, args.config_file, args.pytorch_dump_path) diff --git a/src/transformers/models/yoso/fast_lsh_cumulation.cu b/src/transformers/models/yoso/fast_lsh_cumulation.cu new file mode 100644 index 00000000000..c6b13e6cb5f --- /dev/null +++ b/src/transformers/models/yoso/fast_lsh_cumulation.cu @@ -0,0 +1,588 @@ +// File from https://github.com/mlpen/YOSO/blob/main/encoders/backbones/efficient_attentions/yoso/yoso_v1/cuda/fast_lsh_cumulation.cu + +#include +#include +#include "fast_lsh_cumulation.h" +#include "fast_lsh_cumulation_cuda.h" +#include "common_cuda.h" +#include "common.h" +#include +////////////////////////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////////////////////////// + +std::vector fast_hash_ver1_kernel( + at::Tensor query_mask, + at::Tensor query_vector, + at::Tensor key_mask, + at::Tensor key_vector, + int num_hash_f, + int hash_code_len, + bool use_cuda +) { + + int batch_size = query_vector.size(0); + int num_query = query_vector.size(1); + int num_key = key_vector.size(1); + int vector_dim = query_vector.size(2); + + int num_hash_per_part = vector_dim / hash_code_len; + int num_part = max(1, ceil_divide(num_hash_f, num_hash_per_part)); + + at::Tensor Dmat = 2 * at::randint(0, 2, {batch_size, 3, num_part, vector_dim}, query_mask.options()) - 1; + at::Tensor query_hash_code = at::zeros({batch_size, num_query, num_hash_f}, query_mask.options()); + at::Tensor key_hash_code = at::zeros({batch_size, num_key, num_hash_f}, key_mask.options()); + + int *query_mask_ptr = query_mask.data_ptr(); + float *query_vector_ptr = query_vector.data_ptr(); + int *key_mask_ptr = key_mask.data_ptr(); + float *key_vector_ptr = key_vector.data_ptr(); + + int *Dmat_ptr = Dmat.data_ptr(); + + int *query_hash_code_ptr = query_hash_code.data_ptr(); + int *key_hash_code_ptr = key_hash_code.data_ptr(); + + if (use_cuda) { + { + dim3 threads(vector_dim); + dim3 blocks(num_part, num_query, batch_size); + int shared_mem = vector_dim * sizeof(float); + fast_hash_ver1_cuda_kernel<<>>( + query_mask_ptr, + query_vector_ptr, + Dmat_ptr, + query_hash_code_ptr, + batch_size, + num_query, + vector_dim, + num_part, + num_hash_f, + hash_code_len + ); + } + { + dim3 threads(vector_dim); + dim3 blocks(num_part, num_key, batch_size); + int shared_mem = vector_dim * sizeof(float); + fast_hash_ver1_cuda_kernel<<>>( + key_mask_ptr, + key_vector_ptr, + Dmat_ptr, + key_hash_code_ptr, + batch_size, + num_key, + vector_dim, + num_part, + num_hash_f, + hash_code_len + ); + } + } + + return {query_hash_code, key_hash_code}; + +} + +at::Tensor lsh_cumulation_ver1_kernel( + at::Tensor query_mask, + at::Tensor query_hash_code, + at::Tensor key_mask, + at::Tensor key_hash_code, + at::Tensor value, + int hashtable_capacity, + bool use_cuda +) { + + int batch_size = query_hash_code.size(0); + int num_hash_f = query_hash_code.size(2); + + int num_query = query_hash_code.size(1); + int num_key = key_hash_code.size(1); + int value_dim = value.size(2); + + at::Tensor hashtable_value = at::empty({batch_size, num_hash_f, hashtable_capacity, WARP_SIZE}, value.options()); + at::Tensor cumulation_value = at::zeros({batch_size, num_query, value_dim}, value.options()); + + if (use_cuda) { + int threads_x = WARP_SIZE; + int threads_y = OPTIMAL_THREADS_PER_BLOCK / WARP_SIZE; + int block_x_step1 = num_key / threads_y; + int block_x_step2 = num_query / threads_y; + int block_y = batch_size; + + dim3 threads(threads_x, threads_y); + dim3 blocks_step1(block_x_step1, block_y); + dim3 blocks_step2(block_x_step2, block_y); + + int *query_mask_ptr = query_mask.data_ptr(); + int *query_hash_code_ptr = query_hash_code.data_ptr(); + int *key_mask_ptr = key_mask.data_ptr(); + int *key_hash_code_ptr = key_hash_code.data_ptr(); + float *value_ptr = value.data_ptr(); + float *hashtable_value_ptr = hashtable_value.data_ptr(); + float *cumulation_value_ptr = cumulation_value.data_ptr(); + + for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) { + + cudaMemset(hashtable_value_ptr, 0, (batch_size * num_hash_f * hashtable_capacity * WARP_SIZE) * sizeof(float)); + + lsh_cumulation_ver1_step1_cuda_kernel<<>>( + key_mask_ptr, + key_hash_code_ptr, + value_ptr, + hashtable_value_ptr, + batch_size, + num_hash_f, + hashtable_capacity, + num_key, + value_dim, + value_offset + ); + + lsh_cumulation_ver1_step2_cuda_kernel<<>>( + query_mask_ptr, + query_hash_code_ptr, + hashtable_value_ptr, + cumulation_value_ptr, + batch_size, + num_hash_f, + hashtable_capacity, + num_query, + value_dim, + value_offset + ); + } + + } + + return cumulation_value; + +} + +at::Tensor lsh_weighted_cumulation_ver1_kernel( + at::Tensor query_mask, + at::Tensor query_hash_code, + at::Tensor query_weight, + at::Tensor key_mask, + at::Tensor key_hash_code, + at::Tensor key_weight, + at::Tensor value, + int hashtable_capacity, + bool use_cuda +) { + + int batch_size = query_hash_code.size(0); + int num_hash_f = query_hash_code.size(2); + + int num_query = query_hash_code.size(1); + int num_key = key_hash_code.size(1); + int value_dim = value.size(2); + int weight_dim = query_weight.size(2); + + at::Tensor hashtable_value = at::zeros({batch_size, num_hash_f, hashtable_capacity, WARP_SIZE}, value.options()); + at::Tensor cumulation_value = at::zeros({batch_size, num_query, value_dim}, value.options()); + + if (use_cuda) { + int threads_x = WARP_SIZE; + int threads_y = OPTIMAL_THREADS_PER_BLOCK / WARP_SIZE; + int block_x_step1 = num_key / threads_y; + int block_x_step2 = num_query / threads_y; + int block_y = batch_size; + + dim3 threads(threads_x, threads_y); + dim3 blocks_step1(block_x_step1, block_y); + dim3 blocks_step2(block_x_step2, block_y); + + int *query_mask_ptr = query_mask.data_ptr(); + int *query_hash_code_ptr = query_hash_code.data_ptr(); + float *query_weight_ptr = query_weight.data_ptr(); + int *key_mask_ptr = key_mask.data_ptr(); + int *key_hash_code_ptr = key_hash_code.data_ptr(); + float *key_weight_ptr = key_weight.data_ptr(); + float *value_ptr = value.data_ptr(); + float *hashtable_value_ptr = hashtable_value.data_ptr(); + float *cumulation_value_ptr = cumulation_value.data_ptr(); + + for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) { + for (int weight_idx = 0; weight_idx < weight_dim; weight_idx++) { + + cudaMemset(hashtable_value_ptr, 0, (batch_size * num_hash_f * hashtable_capacity * WARP_SIZE) * sizeof(float)); + + lsh_weighted_cumulation_ver1_step1_cuda_kernel<<>>( + key_mask_ptr, + key_hash_code_ptr, + key_weight_ptr, + value_ptr, + hashtable_value_ptr, + batch_size, + num_hash_f, + hashtable_capacity, + num_key, + value_dim, + weight_dim, + value_offset, + weight_idx + ); + + lsh_weighted_cumulation_ver1_step2_cuda_kernel<<>>( + query_mask_ptr, + query_hash_code_ptr, + query_weight_ptr, + hashtable_value_ptr, + cumulation_value_ptr, + batch_size, + num_hash_f, + hashtable_capacity, + num_query, + value_dim, + weight_dim, + value_offset, + weight_idx + ); + } + } + + } + + return cumulation_value; + +} + +at::Tensor lsh_weighted_cumulation_ver2_kernel( + at::Tensor query_mask, + at::Tensor query_hash_code, + at::Tensor query_weight, + at::Tensor key_mask, + at::Tensor key_hash_code, + at::Tensor key_weight, + at::Tensor value, + int hashtable_capacity, + bool use_cuda +) { + + int batch_size = query_hash_code.size(0); + int num_hash_f = query_hash_code.size(2); + + int num_query = query_hash_code.size(1); + int num_key = key_hash_code.size(1); + int value_dim = value.size(2); + int weight_dim = query_weight.size(2); + + at::Tensor count_sort_table = at::zeros({batch_size, num_hash_f, hashtable_capacity}, query_hash_code.options()); + at::Tensor key_sorted_idxes = at::zeros({batch_size, num_hash_f, num_key}, query_hash_code.options()); + at::Tensor query_info = at::zeros({batch_size, num_query, 2, num_hash_f}, query_hash_code.options()); + at::Tensor cumulation_value = at::zeros({batch_size, num_query, value_dim}, value.options()); + + if (use_cuda) { + + int *query_mask_ptr = query_mask.data_ptr(); + int *query_hash_code_ptr = query_hash_code.data_ptr(); + float *query_weight_ptr = query_weight.data_ptr(); + int *key_mask_ptr = key_mask.data_ptr(); + int *key_hash_code_ptr = key_hash_code.data_ptr(); + float *key_weight_ptr = key_weight.data_ptr(); + float *value_ptr = value.data_ptr(); + + int *count_sort_table_ptr = count_sort_table.data_ptr(); + int *key_sorted_idxes_ptr = key_sorted_idxes.data_ptr(); + int *query_info_ptr = query_info.data_ptr(); + + float *cumulation_value_ptr = cumulation_value.data_ptr(); + + { + dim3 threads_step13(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f)); + dim3 blocks_step13(num_key / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size); + dim3 threads_step2(min(hashtable_capacity, OPTIMAL_THREADS_PER_BLOCK)); + dim3 blocks_step2(num_hash_f, batch_size); + int shared_mem = hashtable_capacity * sizeof(float); + count_sort_step1_cuda_kernel<<>>( + key_mask_ptr, + key_hash_code_ptr, + count_sort_table_ptr, + batch_size, + num_hash_f, + hashtable_capacity, + num_key + ); + count_sort_step2_cuda_kernel<<>>( + count_sort_table_ptr, + batch_size, + num_hash_f, + hashtable_capacity + ); + count_sort_step3_cuda_kernel<<>>( + key_mask_ptr, + key_hash_code_ptr, + count_sort_table_ptr, + key_sorted_idxes_ptr, + batch_size, + num_hash_f, + hashtable_capacity, + num_key + ); + } + { + dim3 threads(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f)); + dim3 blocks(num_query / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size); + extract_query_info_cuda_kernel<<>>( + query_mask_ptr, + query_hash_code_ptr, + count_sort_table_ptr, + query_info_ptr, + batch_size, + num_hash_f, + hashtable_capacity, + num_query + ); + } + { + dim3 threads(WARP_SIZE, OPTIMAL_THREADS_PER_BLOCK / WARP_SIZE); + dim3 blocks(num_query, num_hash_f, batch_size); + int shared_mem = (weight_dim + WARP_SIZE) * sizeof(float); + lsh_weighted_cumulation_ver2_step2_cuda_kernel<<>>( + query_mask_ptr, + query_info_ptr, + key_sorted_idxes_ptr, + query_weight_ptr, + key_weight_ptr, + value_ptr, + cumulation_value_ptr, + batch_size, + num_hash_f, + num_query, + num_key, + value_dim, + weight_dim + ); + } + } + + return cumulation_value; + +} + +at::Tensor lsh_weighted_cumulation_ver3_kernel( + at::Tensor query_mask, + at::Tensor query_hash_code, + at::Tensor query_weight, + at::Tensor key_mask, + at::Tensor key_hash_code, + at::Tensor key_weight, + at::Tensor value, + int hashtable_capacity, + bool use_cuda +) { + + int batch_size = query_hash_code.size(0); + int num_hash_f = query_hash_code.size(2); + + int num_query = query_hash_code.size(1); + int num_key = key_hash_code.size(1); + int value_dim = value.size(2); + int weight_dim = query_weight.size(2); + + at::Tensor count_sort_table = at::zeros({batch_size, num_hash_f, hashtable_capacity}, query_hash_code.options()); + at::Tensor query_sorted_idxes = at::zeros({batch_size, num_hash_f, num_query}, query_hash_code.options()); + at::Tensor key_info = at::zeros({batch_size, num_key, 2, num_hash_f}, query_hash_code.options()); + at::Tensor cumulation_value = at::zeros({batch_size, num_query, value_dim}, value.options()); + + if (use_cuda) { + + int *query_mask_ptr = query_mask.data_ptr(); + int *query_hash_code_ptr = query_hash_code.data_ptr(); + float *query_weight_ptr = query_weight.data_ptr(); + int *key_mask_ptr = key_mask.data_ptr(); + int *key_hash_code_ptr = key_hash_code.data_ptr(); + float *key_weight_ptr = key_weight.data_ptr(); + float *value_ptr = value.data_ptr(); + + int *count_sort_table_ptr = count_sort_table.data_ptr(); + int *query_sorted_idxes_ptr = query_sorted_idxes.data_ptr(); + int *key_info_ptr = key_info.data_ptr(); + + float *cumulation_value_ptr = cumulation_value.data_ptr(); + + { + dim3 threads_step13(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f)); + dim3 blocks_step13(num_query / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size); + dim3 threads_step2(min(hashtable_capacity, OPTIMAL_THREADS_PER_BLOCK)); + dim3 blocks_step2(num_hash_f, batch_size); + int shared_mem = hashtable_capacity * sizeof(float); + count_sort_step1_cuda_kernel<<>>( + query_mask_ptr, + query_hash_code_ptr, + count_sort_table_ptr, + batch_size, + num_hash_f, + hashtable_capacity, + num_query + ); + count_sort_step2_cuda_kernel<<>>( + count_sort_table_ptr, + batch_size, + num_hash_f, + hashtable_capacity + ); + count_sort_step3_cuda_kernel<<>>( + query_mask_ptr, + query_hash_code_ptr, + count_sort_table_ptr, + query_sorted_idxes_ptr, + batch_size, + num_hash_f, + hashtable_capacity, + num_query + ); + } + { + dim3 threads(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f)); + dim3 blocks(num_key / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size); + extract_query_info_cuda_kernel<<>>( + key_mask_ptr, + key_hash_code_ptr, + count_sort_table_ptr, + key_info_ptr, + batch_size, + num_hash_f, + hashtable_capacity, + num_key + ); + } + { + dim3 threads(WARP_SIZE, OPTIMAL_THREADS_PER_BLOCK / WARP_SIZE); + dim3 blocks(num_key, num_hash_f, batch_size); + int shared_mem = (weight_dim + value_dim + WARP_SIZE) * sizeof(float); + lsh_weighted_cumulation_ver3_step2_cuda_kernel<<>>( + query_sorted_idxes_ptr, + key_mask_ptr, + key_info_ptr, + query_weight_ptr, + key_weight_ptr, + value_ptr, + cumulation_value_ptr, + batch_size, + num_hash_f, + num_query, + num_key, + value_dim, + weight_dim + ); + } + } + + return cumulation_value; + +} + +at::Tensor lsh_weighted_cumulation_ver4_kernel( + at::Tensor query_mask, + at::Tensor query_hash_code, + at::Tensor query_weight, + at::Tensor key_mask, + at::Tensor key_hash_code, + at::Tensor key_weight, + at::Tensor value, + int hashtable_capacity, + bool use_cuda +) { + + int batch_size = query_hash_code.size(0); + int num_hash_f = query_hash_code.size(2); + + int num_query = query_hash_code.size(1); + int num_key = key_hash_code.size(1); + int value_dim = value.size(2); + int weight_dim = query_weight.size(2); + + at::Tensor count_sort_table = at::zeros({batch_size, num_hash_f, hashtable_capacity}, query_hash_code.options()); + at::Tensor query_sorted_idxes = at::zeros({batch_size, num_hash_f, num_query}, query_hash_code.options()); + at::Tensor key_info = at::zeros({batch_size, num_key, 2, num_hash_f}, query_hash_code.options()); + at::Tensor cumulation_value = at::zeros({batch_size, num_query, value_dim}, value.options()); + + if (use_cuda) { + + int *query_mask_ptr = query_mask.data_ptr(); + int *query_hash_code_ptr = query_hash_code.data_ptr(); + float *query_weight_ptr = query_weight.data_ptr(); + int *key_mask_ptr = key_mask.data_ptr(); + int *key_hash_code_ptr = key_hash_code.data_ptr(); + float *key_weight_ptr = key_weight.data_ptr(); + float *value_ptr = value.data_ptr(); + + int *count_sort_table_ptr = count_sort_table.data_ptr(); + int *query_sorted_idxes_ptr = query_sorted_idxes.data_ptr(); + int *key_info_ptr = key_info.data_ptr(); + + float *cumulation_value_ptr = cumulation_value.data_ptr(); + + { + dim3 threads_step13(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f)); + dim3 blocks_step13(num_query / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size); + dim3 threads_step2(min(hashtable_capacity, OPTIMAL_THREADS_PER_BLOCK)); + dim3 blocks_step2(num_hash_f, batch_size); + int shared_mem = hashtable_capacity * sizeof(float); + count_sort_step1_cuda_kernel<<>>( + query_mask_ptr, + query_hash_code_ptr, + count_sort_table_ptr, + batch_size, + num_hash_f, + hashtable_capacity, + num_query + ); + count_sort_step2_cuda_kernel<<>>( + count_sort_table_ptr, + batch_size, + num_hash_f, + hashtable_capacity + ); + count_sort_step3_cuda_kernel<<>>( + query_mask_ptr, + query_hash_code_ptr, + count_sort_table_ptr, + query_sorted_idxes_ptr, + batch_size, + num_hash_f, + hashtable_capacity, + num_query + ); + } + { + dim3 threads(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f)); + dim3 blocks(num_key / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size); + extract_query_info_cuda_kernel<<>>( + key_mask_ptr, + key_hash_code_ptr, + count_sort_table_ptr, + key_info_ptr, + batch_size, + num_hash_f, + hashtable_capacity, + num_key + ); + } + { + dim3 threads(WARP_SIZE, OPTIMAL_THREADS_PER_BLOCK / WARP_SIZE); + dim3 blocks(num_key, batch_size); + int shared_mem = (weight_dim + value_dim + 2 * num_hash_f) * sizeof(float); + lsh_weighted_cumulation_ver4_step2_cuda_kernel<<>>( + query_sorted_idxes_ptr, + key_mask_ptr, + key_info_ptr, + query_weight_ptr, + key_weight_ptr, + value_ptr, + cumulation_value_ptr, + batch_size, + num_hash_f, + num_query, + num_key, + value_dim, + weight_dim + ); + } + } + + return cumulation_value; + +} diff --git a/src/transformers/models/yoso/fast_lsh_cumulation.h b/src/transformers/models/yoso/fast_lsh_cumulation.h new file mode 100644 index 00000000000..dd48de0ed15 --- /dev/null +++ b/src/transformers/models/yoso/fast_lsh_cumulation.h @@ -0,0 +1,71 @@ +#include +#include +#include + +std::vector fast_hash_ver1_kernel( + at::Tensor query_mask, + at::Tensor query_vector, + at::Tensor key_mask, + at::Tensor key_vector, + int num_hash_f, + int hash_code_len, + bool use_cuda +); + +at::Tensor lsh_cumulation_ver1_kernel( + at::Tensor query_mask, + at::Tensor query_hash_code, + at::Tensor key_mask, + at::Tensor key_hash_code, + at::Tensor value, + int hashtable_capacity, + bool use_cuda +); + +at::Tensor lsh_weighted_cumulation_ver1_kernel( + at::Tensor query_mask, + at::Tensor query_hash_code, + at::Tensor query_weight, + at::Tensor key_mask, + at::Tensor key_hash_code, + at::Tensor key_weight, + at::Tensor value, + int hashtable_capacity, + bool use_cuda +); + +at::Tensor lsh_weighted_cumulation_ver2_kernel( + at::Tensor query_mask, + at::Tensor query_hash_code, + at::Tensor query_weight, + at::Tensor key_mask, + at::Tensor key_hash_code, + at::Tensor key_weight, + at::Tensor value, + int hashtable_capacity, + bool use_cuda +); + +at::Tensor lsh_weighted_cumulation_ver3_kernel( + at::Tensor query_mask, + at::Tensor query_hash_code, + at::Tensor query_weight, + at::Tensor key_mask, + at::Tensor key_hash_code, + at::Tensor key_weight, + at::Tensor value, + int hashtable_capacity, + bool use_cuda +); + +at::Tensor lsh_weighted_cumulation_ver4_kernel( + at::Tensor query_mask, + at::Tensor query_hash_code, + at::Tensor query_weight, + at::Tensor key_mask, + at::Tensor key_hash_code, + at::Tensor key_weight, + at::Tensor value, + int hashtable_capacity, + bool use_cuda +); diff --git a/src/transformers/models/yoso/fast_lsh_cumulation_cuda.cu b/src/transformers/models/yoso/fast_lsh_cumulation_cuda.cu new file mode 100644 index 00000000000..ebc6260dd6d --- /dev/null +++ b/src/transformers/models/yoso/fast_lsh_cumulation_cuda.cu @@ -0,0 +1,825 @@ +// File from https://github.com/mlpen/YOSO/blob/main/encoders/backbones/efficient_attentions/yoso/yoso_v1/cuda/fast_lsh_cumulation_cuda.cu + +#include "fast_lsh_cumulation_cuda.h" +#include "common_cuda_device.h" +#include "common_cuda.h" +#include "common.h" +#include +////////////////////////////////////////////////////////////////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void fast_hadamard_transform(float *vector_buffer, int vector_dim, int dim_idx) { + int stride = vector_dim / 2; + while (stride > (WARP_SIZE / 2)) { + __syncthreads(); + int sign = 1 - ((dim_idx / stride) % 2) * 2; + float val1 = vector_buffer[dim_idx]; + float val2 = vector_buffer[dim_idx + sign * stride]; + __syncthreads(); + vector_buffer[dim_idx] = float(sign) * val1 + val2; + stride = stride / 2; + } + + float val = vector_buffer[dim_idx]; + #pragma unroll + for (stride = (WARP_SIZE / 2); stride > 0; stride = stride / 2) { + int sign = 1 - ((dim_idx / stride) % 2) * 2; + val = float(sign) * val + __shfl_xor_sync(FULL_MASK, val, stride); + } + vector_buffer[dim_idx] = val; +} + +__global__ void fast_hash_ver1_cuda_kernel( + int *mask, // [batch_size, num_vector] + float *vector, // [batch_size, num_vector, vector_dim] + int *Dmat, // [batch_size, 3, num_part, vector_dim] + int *hash_code, // [batch_size, num_vector, num_hash_f] + int batch_size, + int num_vector, + int vector_dim, + int num_part, + int num_hash_f, + int hash_code_len +) { + + int batch_idx = blockIdx.z; + int vector_idx = blockIdx.y; + int part_idx = blockIdx.x; + + int dim_idx = threadIdx.x; + + int batch_idx__vector_idx = batch_idx * num_vector + vector_idx; + if (mask[batch_idx__vector_idx] == 0) { + return; + } + + extern __shared__ float buffer[]; + float *vector_buffer = buffer; + + vector_buffer[dim_idx] = vector[batch_idx__vector_idx * vector_dim + dim_idx]; + + vector_buffer[dim_idx] = vector_buffer[dim_idx] * (float)Dmat[((batch_idx * 3 + 0) * num_part + part_idx) * vector_dim + dim_idx]; + fast_hadamard_transform(vector_buffer, vector_dim, dim_idx); + vector_buffer[dim_idx] = vector_buffer[dim_idx] * (float)Dmat[((batch_idx * 3 + 1) * num_part + part_idx) * vector_dim + dim_idx]; + fast_hadamard_transform(vector_buffer, vector_dim, dim_idx); + vector_buffer[dim_idx] = vector_buffer[dim_idx] * (float)Dmat[((batch_idx * 3 + 2) * num_part + part_idx) * vector_dim + dim_idx]; + fast_hadamard_transform(vector_buffer, vector_dim, dim_idx); + + int num_hash_per_part = vector_dim / hash_code_len; + if (hash_code_len == 8 || hash_code_len == 16) { + int code = select(vector_buffer[dim_idx] > 0, 1 << (dim_idx % hash_code_len), 0); + for (int offset = 1; offset < hash_code_len; offset = offset * 2) { + code += __shfl_xor_sync(FULL_MASK, code, offset); + } + if (dim_idx % hash_code_len == 0) { + int hash_f_idx = part_idx * num_hash_per_part + dim_idx / hash_code_len; + if (hash_f_idx < num_hash_f) { + hash_code[batch_idx__vector_idx * num_hash_f + hash_f_idx] = code; + } + } + } else { + vector_buffer[dim_idx] = select(vector_buffer[dim_idx] > 0, 1 << (dim_idx % hash_code_len), 0); + __syncthreads(); + if (dim_idx < num_hash_per_part) { + int code = 0; + for (int i = 0; i < hash_code_len; i++) { + code += vector_buffer[dim_idx * hash_code_len + i]; + } + int hash_f_idx = part_idx * num_hash_per_part + dim_idx; + if (hash_f_idx < num_hash_f) { + hash_code[batch_idx__vector_idx * num_hash_f + hash_f_idx] = code; + } + } + } +} + +__global__ void lsh_cumulation_ver1_step1_cuda_kernel( + int *key_mask, // [batch_size, num_key] + int *key_hash_code, // [batch_size, num_key, num_hash_f] + float *value, // [batch_size, num_key, value_dim] + float *hashtable_value, // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE] + int batch_size, + int num_hash_f, + int hashtable_capacity, + int num_key, + int value_dim, + int offset_warp +) { + + int warp_thread_idx = threadIdx.x; + + int batch_idx = blockIdx.y; + int key_idx = blockIdx.x * blockDim.y + threadIdx.y; + + int batch_idx__key_idx = batch_idx * num_key + key_idx; + if (key_mask[batch_idx__key_idx] == 0) { + return; + } + + if (num_hash_f > WARP_SIZE) { + float warp_value = value[batch_idx__key_idx * value_dim + offset_warp + warp_thread_idx]; + for (int hash_f_start = 0; hash_f_start < num_hash_f; hash_f_start = hash_f_start + WARP_SIZE) { + int warp_hashcode = key_hash_code[batch_idx__key_idx * num_hash_f + hash_f_start + warp_thread_idx]; + #pragma unroll + for (int hash_f_offset = 0; hash_f_offset < WARP_SIZE; hash_f_offset++) { + int current_hashcode = warp_hashcode; + current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_offset); + int hashtable_idx = (batch_idx * num_hash_f + (hash_f_start + hash_f_offset)) * hashtable_capacity + current_hashcode; + atomicAdd(&hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx], warp_value); + } + } + } else { + float warp_value = value[batch_idx__key_idx * value_dim + offset_warp + warp_thread_idx]; + int warp_hashcode = 0; + if (warp_thread_idx < num_hash_f) { + warp_hashcode = key_hash_code[batch_idx__key_idx * num_hash_f + warp_thread_idx]; + } + for (int hash_f_idx = 0; hash_f_idx < num_hash_f; hash_f_idx++) { + int current_hashcode = warp_hashcode; + current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_idx); + int hashtable_idx = (batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + current_hashcode; + atomicAdd(&hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx], warp_value); + } + } + +} + +__global__ void lsh_cumulation_ver1_step2_cuda_kernel( + int *query_mask, // [batch_size, num_query] + int *query_hash_code, // [batch_size, num_query, num_hash_f] + float *hashtable_value, // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE] + float *cumulation_value, // [batch_size, num_query, value_dim] + int batch_size, + int num_hash_f, + int hashtable_capacity, + int num_query, + int value_dim, + int offset_warp +) { + + int warp_thread_idx = threadIdx.x; + + int batch_idx = blockIdx.y; + int query_idx = blockIdx.x * blockDim.y + threadIdx.y; + + int batch_idx__query_idx = batch_idx * num_query + query_idx; + if (query_mask[batch_idx__query_idx] == 0) { + return; + } + + if (num_hash_f > WARP_SIZE) { + float warp_value = 0; + for (int hash_f_start = 0; hash_f_start < num_hash_f; hash_f_start = hash_f_start + WARP_SIZE) { + int warp_hashcode = query_hash_code[batch_idx__query_idx * num_hash_f + hash_f_start + warp_thread_idx]; + #pragma unroll + for (int hash_f_offset = 0; hash_f_offset < WARP_SIZE; hash_f_offset++) { + int current_hashcode = warp_hashcode; + current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_offset); + int hashtable_idx = (batch_idx * num_hash_f + (hash_f_start + hash_f_offset)) * hashtable_capacity + current_hashcode; + warp_value = warp_value + hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx]; + } + } + cumulation_value[batch_idx__query_idx * value_dim + offset_warp + warp_thread_idx] = warp_value / float(num_hash_f); + } else { + float warp_value = 0; + int warp_hashcode = 0; + if (warp_thread_idx < num_hash_f) { + warp_hashcode = query_hash_code[batch_idx__query_idx * num_hash_f + warp_thread_idx]; + } + for (int hash_f_idx = 0; hash_f_idx < num_hash_f; hash_f_idx++) { + int current_hashcode = warp_hashcode; + current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_idx); + int hashtable_idx = (batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + current_hashcode; + warp_value = warp_value + hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx]; + } + cumulation_value[batch_idx__query_idx * value_dim + offset_warp + warp_thread_idx] = warp_value / float(num_hash_f); + } + +} + +__global__ void lsh_weighted_cumulation_ver1_step1_cuda_kernel( + int *key_mask, // [batch_size, num_key] + int *key_hash_code, // [batch_size, num_key, num_hash_f] + float *key_weight, // [batch_size, num_key, weight_dim] + float *value, // [batch_size, num_key, value_dim] + float *hashtable_value, // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE] + int batch_size, + int num_hash_f, + int hashtable_capacity, + int num_key, + int value_dim, + int weight_dim, + int offset_warp, + int weight_idx +) { + + int warp_thread_idx = threadIdx.x; + + int batch_idx = blockIdx.y; + int key_idx = blockIdx.x * blockDim.y + threadIdx.y; + + int batch_idx__key_idx = batch_idx * num_key + key_idx; + if (key_mask[batch_idx__key_idx] == 0) { + return; + } + + if (num_hash_f > WARP_SIZE) { + float warp_value = key_weight[batch_idx__key_idx * weight_dim + weight_idx] * value[batch_idx__key_idx * value_dim + offset_warp + warp_thread_idx]; + for (int hash_f_start = 0; hash_f_start < num_hash_f; hash_f_start = hash_f_start + WARP_SIZE) { + int warp_hashcode = key_hash_code[batch_idx__key_idx * num_hash_f + hash_f_start + warp_thread_idx]; + #pragma unroll + for (int hash_f_offset = 0; hash_f_offset < WARP_SIZE; hash_f_offset++) { + int current_hashcode = warp_hashcode; + current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_offset); + int hashtable_idx = (batch_idx * num_hash_f + (hash_f_start + hash_f_offset)) * hashtable_capacity + current_hashcode; + atomicAdd(&hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx], warp_value); + } + } + } else { + float warp_value = key_weight[batch_idx__key_idx * weight_dim + weight_idx] * value[batch_idx__key_idx * value_dim + offset_warp + warp_thread_idx]; + int warp_hashcode = 0; + if (warp_thread_idx < num_hash_f) { + warp_hashcode = key_hash_code[batch_idx__key_idx * num_hash_f + warp_thread_idx]; + } + for (int hash_f_idx = 0; hash_f_idx < num_hash_f; hash_f_idx++) { + int current_hashcode = warp_hashcode; + current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_idx); + int hashtable_idx = (batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + current_hashcode; + atomicAdd(&hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx], warp_value); + } + } + +} + +__global__ void lsh_weighted_cumulation_ver1_step2_cuda_kernel( + int *query_mask, // [batch_size, num_query] + int *query_hash_code, // [batch_size, num_query, num_hash_f] + float *query_weight, // [batch_size, num_query, weight_dim] + float *hashtable_value, // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE] + float *cumulation_value, // [batch_size, num_query, value_dim] + int batch_size, + int num_hash_f, + int hashtable_capacity, + int num_query, + int value_dim, + int weight_dim, + int offset_warp, + int weight_idx +) { + + int warp_thread_idx = threadIdx.x; + + int batch_idx = blockIdx.y; + int query_idx = blockIdx.x * blockDim.y + threadIdx.y; + + int batch_idx__query_idx = batch_idx * num_query + query_idx; + if (query_mask[batch_idx__query_idx] == 0) { + return; + } + + if (num_hash_f > WARP_SIZE) { + float warp_value = 0; + for (int hash_f_start = 0; hash_f_start < num_hash_f; hash_f_start = hash_f_start + WARP_SIZE) { + int warp_hashcode = query_hash_code[batch_idx__query_idx * num_hash_f + hash_f_start + warp_thread_idx]; + #pragma unroll + for (int hash_f_offset = 0; hash_f_offset < WARP_SIZE; hash_f_offset++) { + int current_hashcode = warp_hashcode; + current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_offset); + int hashtable_idx = (batch_idx * num_hash_f + (hash_f_start + hash_f_offset)) * hashtable_capacity + current_hashcode; + warp_value = warp_value + hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx]; + } + } + float warp_weight = query_weight[batch_idx__query_idx * weight_dim + weight_idx]; + cumulation_value[batch_idx__query_idx * value_dim + offset_warp + warp_thread_idx] += warp_weight * warp_value / float(num_hash_f); + } else { + float warp_value = 0; + int warp_hashcode = 0; + if (warp_thread_idx < num_hash_f) { + warp_hashcode = query_hash_code[batch_idx__query_idx * num_hash_f + warp_thread_idx]; + } + for (int hash_f_idx = 0; hash_f_idx < num_hash_f; hash_f_idx++) { + int current_hashcode = warp_hashcode; + current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_idx); + int hashtable_idx = (batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + current_hashcode; + warp_value = warp_value + hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx]; + } + float warp_weight = query_weight[batch_idx__query_idx * weight_dim + weight_idx]; + cumulation_value[batch_idx__query_idx * value_dim + offset_warp + warp_thread_idx] += warp_weight * warp_value / float(num_hash_f); + } + +} + +__global__ void count_sort_step1_cuda_kernel( + int *key_mask, // [batch_size, num_key] + int *key_hash_code, // [batch_size, num_key, num_hash_f] + int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity] + int batch_size, + int num_hash_f, + int hashtable_capacity, + int num_key +) { + + int batch_idx = blockIdx.y; + int key_idx = blockIdx.x * blockDim.y + threadIdx.y; + int hash_f_idx = threadIdx.x; + + int batch_idx__key_idx = batch_idx * num_key + key_idx; + if (key_mask[batch_idx__key_idx] == 0) { + return; + } + + int hash_code = key_hash_code[batch_idx__key_idx * num_hash_f + hash_f_idx]; + atomicAdd(&count_sort_table[(batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + hash_code], 1); + +} + +__global__ void count_sort_step2_cuda_kernel( + int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity] + int batch_size, + int num_hash_f, + int hashtable_capacity +) { + + int batch_idx = blockIdx.y; + int hash_f_idx = blockIdx.x; + + int num_threads = blockDim.x; + int thread_id = threadIdx.x; + + int batch_idx__hash_f_idx = batch_idx * num_hash_f + hash_f_idx; + + extern __shared__ float buffer[]; + int *table_buffer = (int*)buffer; + + if (thread_id == 0) { + table_buffer[0] = 0; + } + copy_data(&count_sort_table[batch_idx__hash_f_idx * hashtable_capacity], &table_buffer[1], hashtable_capacity - 1, num_threads, thread_id); + + for (int table_idx_start = 0; table_idx_start < hashtable_capacity; table_idx_start = table_idx_start + num_threads) { + int thread_value = table_buffer[table_idx_start + thread_id]; + int next_thread_value = 0; + for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) { + next_thread_value = __shfl_up_sync(FULL_MASK, thread_value, offset); + if (thread_id % WARP_SIZE >= offset) { + thread_value = thread_value + next_thread_value; + } + } + table_buffer[table_idx_start + thread_id] = thread_value; + } + __syncthreads(); + + if (hashtable_capacity > WARP_SIZE) { + if (thread_id < WARP_SIZE) { + for (int table_idx_start = WARP_SIZE; table_idx_start < hashtable_capacity; table_idx_start = table_idx_start + WARP_SIZE) { + table_buffer[table_idx_start + thread_id] += table_buffer[table_idx_start - 1]; + } + } + } + + copy_data(table_buffer, &count_sort_table[batch_idx__hash_f_idx * hashtable_capacity], hashtable_capacity, num_threads, thread_id); + +} + + +__global__ void count_sort_step3_cuda_kernel( + int *key_mask, // [batch_size, num_key] + int *key_hash_code, // [batch_size, num_key, num_hash_f] + int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity] + int *key_sorted_idxes, // [batch_size, num_hash_f, num_key] + int batch_size, + int num_hash_f, + int hashtable_capacity, + int num_key +) { + + int batch_idx = blockIdx.y; + int key_idx = blockIdx.x * blockDim.y + threadIdx.y; + int hash_f_idx = threadIdx.x; + + int batch_idx__key_idx = batch_idx * num_key + key_idx; + if (key_mask[batch_idx__key_idx] == 0) { + return; + } + + int batch_idx__hash_f_idx = batch_idx * num_hash_f + hash_f_idx; + + int hash_code = key_hash_code[batch_idx__key_idx * num_hash_f + hash_f_idx]; + int sort_idx = atomicAdd(&count_sort_table[batch_idx__hash_f_idx * hashtable_capacity + hash_code], 1); + key_sorted_idxes[batch_idx__hash_f_idx * num_key + sort_idx] = key_idx; + +} + +__global__ void extract_query_info_cuda_kernel( + int *query_mask, // [batch_size, num_query] + int *query_hash_code, // [batch_size, num_query, num_hash_f] + int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity] + int *query_info, // [batch_size, num_query, 2, num_hash_f] + int batch_size, + int num_hash_f, + int hashtable_capacity, + int num_query +) { + + int batch_idx = blockIdx.y; + int query_idx = blockIdx.x * blockDim.y + threadIdx.y; + int hash_f_idx = threadIdx.x; + + int batch_idx__query_idx = batch_idx * num_query + query_idx; + if (query_mask[batch_idx__query_idx] == 0) { + return; + } + + int hash_code = query_hash_code[batch_idx__query_idx * num_hash_f + hash_f_idx]; + int batch_idx__hash_f_idx__hash_code = (batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + hash_code; + + int key_offset = select(hash_code == 0, 0, count_sort_table[batch_idx__hash_f_idx__hash_code - 1]); + int key_count = count_sort_table[batch_idx__hash_f_idx__hash_code] - key_offset; + + query_info[batch_idx__query_idx * 2 * num_hash_f + hash_f_idx] = key_offset; + query_info[(batch_idx__query_idx * 2 + 1) * num_hash_f + hash_f_idx] = key_count; + +} + +__global__ void lsh_weighted_cumulation_ver2_step2_cuda_kernel( + int *query_mask, // [batch_size, num_query] + int *query_info, // [batch_size, num_query, 2, num_hash_f] + int *key_sorted_idxes, // [batch_size, num_hash_f, num_key] + float *query_weight, // [batch_size, num_query, weight_dim] + float *key_weight, // [batch_size, num_key, weight_dim] + float *value, // [batch_size, num_key, value_dim] + float *cumulation_value, // [batch_size, num_query, value_dim] + int batch_size, + int num_hash_f, + int num_query, + int num_key, + int value_dim, + int weight_dim +) { + + int batch_idx = blockIdx.z; + int hash_f_idx = blockIdx.y; + int query_idx = blockIdx.x; + + int num_threads = blockDim.y * blockDim.x; + int thread_id = threadIdx.y * blockDim.x + threadIdx.x; + + int num_warps = blockDim.y; + int warp_idx = threadIdx.y; + int warp_thread_idx = threadIdx.x; + + int batch_idx__query_idx = batch_idx * num_query + query_idx; + if (query_mask[batch_idx__query_idx] == 0) { + return; + } + + int key_offset = query_info[batch_idx__query_idx * 2 * num_hash_f + hash_f_idx]; + int key_count = query_info[(batch_idx__query_idx * 2 + 1) * num_hash_f + hash_f_idx]; + + if (key_count == 0) { + return; + } + + extern __shared__ float buffer[]; + + if (key_count == 1) { + if (warp_idx == 0) { + int key_idx = key_sorted_idxes[(batch_idx * num_hash_f + hash_f_idx) * num_key + key_offset]; + int batch_idx__key_idx = batch_idx * num_key + key_idx; + float weight = 0; + for (int weight_offset = 0; weight_offset < weight_dim; weight_offset = weight_offset + WARP_SIZE) { + int weight_dim_idx = weight_offset + warp_thread_idx; + float val = query_weight[batch_idx__query_idx * weight_dim + weight_dim_idx] * key_weight[batch_idx__key_idx * weight_dim + weight_dim_idx]; + #pragma unroll + for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) { + val += __shfl_xor_sync(FULL_MASK, val, offset); + } + weight = weight + val; + } + weight = weight / float(num_hash_f); + for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) { + int value_dim_idx = value_offset + warp_thread_idx; + float val = value[batch_idx__key_idx * value_dim + value_dim_idx]; + atomicAdd(&cumulation_value[batch_idx__query_idx * value_dim + value_dim_idx], weight * val); + } + } + } else { + float *weight_buffer = buffer; + int *key_idxes_buffer = (int*)&buffer[weight_dim]; + + copy_data_nonblocking(&query_weight[batch_idx__query_idx * weight_dim], weight_buffer, weight_dim, num_threads, thread_id); + + while (key_count > 0) { + int work_size = min(WARP_SIZE, key_count); + copy_data_nonblocking(&key_sorted_idxes[(batch_idx * num_hash_f + hash_f_idx) * num_key + key_offset], key_idxes_buffer, work_size, num_threads, thread_id); + __syncthreads(); + for (int work_offset = 0; work_offset < WARP_SIZE; work_offset = work_offset + num_warps) { + int work_idx = work_offset + warp_idx; + if (work_idx < key_count) { + int key_idx = key_idxes_buffer[work_idx]; + int batch_idx__key_idx = batch_idx * num_key + key_idx; + float weight = 0; + for (int weight_offset = 0; weight_offset < weight_dim; weight_offset = weight_offset + WARP_SIZE) { + int weight_dim_idx = weight_offset + warp_thread_idx; + float val = weight_buffer[weight_dim_idx] * key_weight[batch_idx__key_idx * weight_dim + weight_dim_idx]; + #pragma unroll + for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) { + val += __shfl_xor_sync(FULL_MASK, val, offset); + } + weight = weight + val; + } + weight = weight / float(num_hash_f); + for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) { + int value_dim_idx = value_offset + warp_thread_idx; + float val = value[batch_idx__key_idx * value_dim + value_dim_idx]; + atomicAdd(&cumulation_value[batch_idx__query_idx * value_dim + value_dim_idx], weight * val); + } + } + } + key_count = key_count - work_size; + key_offset = key_offset + work_size; + } + } + +} + +__global__ void lsh_weighted_cumulation_ver3_step2_cuda_kernel( + int *query_sorted_idxes, // [batch_size, num_hash_f, num_query] + int *key_mask, // [batch_size, num_key] + int *key_info, // [batch_size, num_key, 2, num_hash_f] + float *query_weight, // [batch_size, num_query, weight_dim] + float *key_weight, // [batch_size, num_key, weight_dim] + float *value, // [batch_size, num_key, value_dim] + float *cumulation_value, // [batch_size, num_query, value_dim] + int batch_size, + int num_hash_f, + int num_query, + int num_key, + int value_dim, + int weight_dim +) { + + int batch_idx = blockIdx.z; + int hash_f_idx = blockIdx.y; + int key_idx = blockIdx.x; + + int num_threads = blockDim.y * blockDim.x; + int thread_id = threadIdx.y * blockDim.x + threadIdx.x; + + int num_warps = blockDim.y; + int warp_idx = threadIdx.y; + int warp_thread_idx = threadIdx.x; + + int batch_idx__key_idx = batch_idx * num_key + key_idx; + if (key_mask[batch_idx__key_idx] == 0) { + return; + } + + int query_offset = key_info[batch_idx__key_idx * 2 * num_hash_f + hash_f_idx]; + int query_count = key_info[(batch_idx__key_idx * 2 + 1) * num_hash_f + hash_f_idx]; + + if (query_count == 0) { + return; + } + + extern __shared__ float buffer[]; + + if (query_count == 1) { + if (warp_idx == 0) { + int query_idx = query_sorted_idxes[(batch_idx * num_hash_f + hash_f_idx) * num_query + query_offset]; + int batch_idx__query_idx = batch_idx * num_query + query_idx; + float weight = 0; + for (int weight_offset = 0; weight_offset < weight_dim; weight_offset = weight_offset + WARP_SIZE) { + int weight_dim_idx = weight_offset + warp_thread_idx; + float val = key_weight[batch_idx__key_idx * weight_dim + weight_dim_idx] * query_weight[batch_idx__query_idx * weight_dim + weight_dim_idx]; + #pragma unroll + for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) { + val += __shfl_xor_sync(FULL_MASK, val, offset); + } + weight = weight + val; + } + weight = weight / float(num_hash_f); + for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) { + int value_dim_idx = value_offset + warp_thread_idx; + float val = value[batch_idx__key_idx * value_dim + value_dim_idx]; + atomicAdd(&cumulation_value[batch_idx__query_idx * value_dim + value_dim_idx], weight * val); + } + } + } else { + float *weight_buffer = buffer; + float *value_buffer = &buffer[weight_dim]; + int *query_idxes_buffer = (int*)&buffer[weight_dim + value_dim]; + + copy_data_nonblocking(&key_weight[batch_idx__key_idx * weight_dim], weight_buffer, weight_dim, num_threads, thread_id); + copy_data_nonblocking(&value[batch_idx__key_idx * value_dim], value_buffer, value_dim, num_threads, thread_id); + + while (query_count > 0) { + int work_size = min(WARP_SIZE, query_count); + copy_data_nonblocking(&query_sorted_idxes[(batch_idx * num_hash_f + hash_f_idx) * num_query + query_offset], query_idxes_buffer, work_size, num_threads, thread_id); + __syncthreads(); + for (int work_offset = 0; work_offset < WARP_SIZE; work_offset = work_offset + num_warps) { + int work_idx = work_offset + warp_idx; + if (work_idx < query_count) { + int query_idx = query_idxes_buffer[work_idx]; + int batch_idx__query_idx = batch_idx * num_query + query_idx; + float weight = 0; + for (int weight_offset = 0; weight_offset < weight_dim; weight_offset = weight_offset + WARP_SIZE) { + int weight_dim_idx = weight_offset + warp_thread_idx; + float val = weight_buffer[weight_dim_idx] * query_weight[batch_idx__query_idx * weight_dim + weight_dim_idx]; + #pragma unroll + for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) { + val += __shfl_xor_sync(FULL_MASK, val, offset); + } + weight = weight + val; + } + weight = weight / float(num_hash_f); + for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) { + int value_dim_idx = value_offset + warp_thread_idx; + float val = value_buffer[value_dim_idx]; + atomicAdd(&cumulation_value[batch_idx__query_idx * value_dim + value_dim_idx], weight * val); + } + } + } + query_count = query_count - work_size; + query_offset = query_offset + work_size; + } + } + +} + +__global__ void lsh_weighted_cumulation_ver4_step2_cuda_kernel( + int *query_sorted_idxes, // [batch_size, num_hash_f, num_query] + int *key_mask, // [batch_size, num_key] + int *key_info, // [batch_size, num_key, 2, num_hash_f] + float *query_weight, // [batch_size, num_query, weight_dim] + float *key_weight, // [batch_size, num_key, weight_dim] + float *value, // [batch_size, num_key, value_dim] + float *cumulation_value, // [batch_size, num_query, value_dim] + int batch_size, + int num_hash_f, + int num_query, + int num_key, + int value_dim, + int weight_dim +) { + + int batch_idx = blockIdx.y; + int key_idx = blockIdx.x; + + int num_threads = blockDim.y * blockDim.x; + int thread_id = threadIdx.y * blockDim.x + threadIdx.x; + + int num_warps = blockDim.y; + int warp_idx = threadIdx.y; + int warp_thread_idx = threadIdx.x; + + int batch_idx__key_idx = batch_idx * num_key + key_idx; + if (key_mask[batch_idx__key_idx] == 0) { + return; + } + + extern __shared__ float buffer[]; + float *weight_buffer = buffer; + float *value_buffer = &buffer[weight_dim]; + int *key_info_buffer = (int*)&buffer[weight_dim + value_dim]; + + copy_data_nonblocking(&key_weight[batch_idx__key_idx * weight_dim], weight_buffer, weight_dim, num_threads, thread_id); + copy_data_nonblocking(&value[batch_idx__key_idx * value_dim], value_buffer, value_dim, num_threads, thread_id); + copy_data_nonblocking(&key_info[batch_idx__key_idx * 2 * num_hash_f], key_info_buffer, 2 * num_hash_f, num_threads, thread_id); + + int *query_offset_buffer = key_info_buffer; + int *query_count_buffer = &key_info_buffer[num_hash_f]; + + const int hashtable_size = 1024 + OPTIMAL_THREADS_PER_BLOCK; + __shared__ int hashtable_query[hashtable_size]; + __shared__ int hashtable_count[hashtable_size]; + __shared__ int inserted_query[hashtable_size]; + __shared__ int query_counter[1]; + + int hash_f_idx_base = 0; + + while (true) { + + init_buffer_nonblocking(EMPTY_VALUE, hashtable_query, hashtable_size, num_threads, thread_id); + init_buffer_nonblocking(0, hashtable_count, hashtable_size, num_threads, thread_id); + init_buffer_nonblocking(EMPTY_VALUE, inserted_query, hashtable_size, num_threads, thread_id); + init_buffer_nonblocking(0, query_counter, 1, num_threads, thread_id); + __syncthreads(); + + while (hash_f_idx_base < num_hash_f) { + + int hash_f_idx = hash_f_idx_base + warp_idx; + int batch_idx__hash_f_idx = batch_idx * num_hash_f + hash_f_idx; + + int stop_flag = 0; + + int query_offset = query_offset_buffer[hash_f_idx]; + int query_count = query_count_buffer[hash_f_idx]; + + while (query_count > 0) { + + int work_size = min(query_count, WARP_SIZE); + + // try inserting query to set and check whether the query is new + int found_new_query = 0; + int query_idx = -1; + if (warp_thread_idx < work_size) { + query_idx = query_sorted_idxes[batch_idx__hash_f_idx * num_query + query_offset + warp_thread_idx]; + int slot = set_insert(hashtable_query, hashtable_size, query_idx); + if (slot >= 0) { + found_new_query = atomicAdd(&hashtable_count[slot], 1) == 0; + } + } + + // compute cumulative offset + int position_offset = found_new_query; + int next_position_offset = 0; + #pragma unroll + for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) { + next_position_offset = __shfl_up_sync(FULL_MASK, position_offset, offset); + if (thread_id % WARP_SIZE >= offset) { + position_offset = position_offset + next_position_offset; + } + } + + // get the inserted query list end index + int inserted_query_base = 0; + if (thread_id % WARP_SIZE == WARP_SIZE - 1) { + inserted_query_base = atomicAdd(query_counter, position_offset); + } + inserted_query_base = __shfl_sync(FULL_MASK, inserted_query_base, WARP_SIZE - 1); + + // insert new queries to list + int insert_idx = inserted_query_base + position_offset - 1; + if (found_new_query) { + inserted_query[insert_idx] = query_idx; + } + + // remove inserted queries from list + query_offset_buffer[hash_f_idx] += work_size; + query_count_buffer[hash_f_idx] -= work_size; + query_offset += work_size; + query_count -= work_size; + + // if list is almost full, stop inserting + if (inserted_query_base + OPTIMAL_THREADS_PER_BLOCK > hashtable_size) { + stop_flag = 1; + break; + } + + } + + if (stop_flag) { + break; + } + + hash_f_idx_base = hash_f_idx_base + num_warps; + + } + + __syncthreads(); + + int num_distint_query = query_counter[0]; + + if (num_distint_query > 0) { + for (int idx_base = 0; idx_base < num_distint_query; idx_base = idx_base + num_warps) { + int idx = idx_base + warp_idx; + if (idx < num_distint_query) { + int query_idx = inserted_query[idx]; + int batch_idx__query_idx = batch_idx * num_query + query_idx; + + int slot = set_lookup(hashtable_query, hashtable_size, query_idx); + int duplicate_count = hashtable_count[slot]; + + float weight = 0; + for (int weight_idx_base = 0; weight_idx_base < weight_dim; weight_idx_base = weight_idx_base + WARP_SIZE) { + int weight_dim_idx = weight_idx_base + warp_thread_idx; + float val = weight_buffer[weight_dim_idx] * query_weight[batch_idx__query_idx * weight_dim + weight_dim_idx]; + #pragma unroll + for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) { + val += __shfl_xor_sync(FULL_MASK, val, offset); + } + weight = weight + val; + } + + weight = (float)duplicate_count * weight / float(num_hash_f); + + for (int value_idx_base = 0; value_idx_base < value_dim; value_idx_base = value_idx_base + WARP_SIZE) { + int value_dim_idx = value_idx_base + warp_thread_idx; + float val = value_buffer[value_dim_idx]; + atomicAdd(&cumulation_value[batch_idx__query_idx * value_dim + value_dim_idx], weight * val); + } + } + } + } else { + + // all computation is completed if num_distint_query == 0 + break; + + } + + __syncthreads(); + + } + +} diff --git a/src/transformers/models/yoso/fast_lsh_cumulation_cuda.h b/src/transformers/models/yoso/fast_lsh_cumulation_cuda.h new file mode 100644 index 00000000000..b2adc0f7353 --- /dev/null +++ b/src/transformers/models/yoso/fast_lsh_cumulation_cuda.h @@ -0,0 +1,157 @@ +__global__ void fast_hash_ver1_cuda_kernel( + int *mask, // [batch_size, num_vector] + float *vector, // [batch_size, num_vector, vector_dim] + int *Dmat, // [3, num_part, vector_dim] + int *hash_code, // [batch_size, num_vector, num_hash_f] + int batch_size, + int num_vector, + int vector_dim, + int num_part, + int num_hash_f, + int hash_code_len +); + +__global__ void lsh_cumulation_ver1_step1_cuda_kernel( + int *key_mask, // [batch_size, num_key] + int *key_hash_code, // [batch_size, num_key, num_hash_f] + float *value, // [batch_size, num_key, value_dim] + float *hashtable_value, // [batch_size, num_hash_f, hashtable_capacity, value_dim] + int batch_size, + int num_hash_f, + int hashtable_capacity, + int num_key, + int value_dim, + int offset_warp +); + +__global__ void lsh_cumulation_ver1_step2_cuda_kernel( + int *query_mask, // [batch_size, num_query] + int *query_hash_code, // [batch_size, num_query, num_hash_f] + float *hashtable_value, // [batch_size, num_hash_f, hashtable_capacity, value_dim] + float *cumulation_value, // [batch_size, num_query, value_dim] + int batch_size, + int num_hash_f, + int hashtable_capacity, + int num_query, + int value_dim, + int offset_warp +); + +__global__ void lsh_weighted_cumulation_ver1_step1_cuda_kernel( + int *key_mask, // [batch_size, num_key] + int *key_hash_code, // [batch_size, num_key, num_hash_f] + float *key_weight, // [batch_size, num_key, weight_dim] + float *value, // [batch_size, num_key, value_dim] + float *hashtable_value, // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE] + int batch_size, + int num_hash_f, + int hashtable_capacity, + int num_key, + int value_dim, + int weight_dim, + int offset_warp, + int weight_idx +); + +__global__ void lsh_weighted_cumulation_ver1_step2_cuda_kernel( + int *query_mask, // [batch_size, num_query] + int *query_hash_code, // [batch_size, num_query, num_hash_f] + float *query_weight, // [batch_size, num_query, weight_dim] + float *hashtable_value, // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE] + float *cumulation_value, // [batch_size, num_query, value_dim] + int batch_size, + int num_hash_f, + int hashtable_capacity, + int num_query, + int value_dim, + int weight_dim, + int offset_warp, + int weight_idx +); + +__global__ void count_sort_step1_cuda_kernel( + int *key_mask, // [batch_size, num_key] + int *key_hash_code, // [batch_size, num_key, num_hash_f] + int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity] + int batch_size, + int num_hash_f, + int hashtable_capacity, + int num_key +); + +__global__ void count_sort_step2_cuda_kernel( + int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity] + int batch_size, + int num_hash_f, + int hashtable_capacity +); + +__global__ void count_sort_step3_cuda_kernel( + int *key_mask, // [batch_size, num_key] + int *key_hash_code, // [batch_size, num_key, num_hash_f] + int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity] + int *key_sorted_idxes, // [batch_size, num_hash_f, num_key] + int batch_size, + int num_hash_f, + int hashtable_capacity, + int num_key +); + +__global__ void extract_query_info_cuda_kernel( + int *query_mask, // [batch_size, num_query] + int *query_hash_code, // [batch_size, num_query, num_hash_f] + int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity] + int *query_info, // [batch_size, num_query, 2, num_hash_f] + int batch_size, + int num_hash_f, + int hashtable_capacity, + int num_query +); + +__global__ void lsh_weighted_cumulation_ver2_step2_cuda_kernel( + int *query_mask, // [batch_size, num_query] + int *query_info, // [batch_size, num_query, 2, num_hash_f] + int *key_sorted_idxes, // [batch_size, num_hash_f, num_key] + float *query_weight, // [batch_size, num_query, weight_dim] + float *key_weight, // [batch_size, num_key, weight_dim] + float *value, // [batch_size, num_key, value_dim] + float *cumulation_value, // [batch_size, num_query, value_dim] + int batch_size, + int num_hash_f, + int num_query, + int num_key, + int value_dim, + int weight_dim +); + +__global__ void lsh_weighted_cumulation_ver3_step2_cuda_kernel( + int *query_sorted_idxes, // [batch_size, num_hash_f, num_query] + int *key_mask, // [batch_size, num_key] + int *key_info, // [batch_size, num_key, 2, num_hash_f] + float *query_weight, // [batch_size, num_query, weight_dim] + float *key_weight, // [batch_size, num_key, weight_dim] + float *value, // [batch_size, num_key, value_dim] + float *cumulation_value, // [batch_size, num_query, value_dim] + int batch_size, + int num_hash_f, + int num_query, + int num_key, + int value_dim, + int weight_dim +); + +__global__ void lsh_weighted_cumulation_ver4_step2_cuda_kernel( + int *query_sorted_idxes, // [batch_size, num_hash_f, num_query] + int *key_mask, // [batch_size, num_key] + int *key_info, // [batch_size, num_key, 2, num_hash_f] + float *query_weight, // [batch_size, num_query, weight_dim] + float *key_weight, // [batch_size, num_key, weight_dim] + float *value, // [batch_size, num_key, value_dim] + float *cumulation_value, // [batch_size, num_query, value_dim] + int batch_size, + int num_hash_f, + int num_query, + int num_key, + int value_dim, + int weight_dim +); diff --git a/src/transformers/models/yoso/fast_lsh_cumulation_torch.cpp b/src/transformers/models/yoso/fast_lsh_cumulation_torch.cpp new file mode 100644 index 00000000000..e150a2be604 --- /dev/null +++ b/src/transformers/models/yoso/fast_lsh_cumulation_torch.cpp @@ -0,0 +1,128 @@ +#include +#include +#include "fast_lsh_cumulation.h" +#include "common_cuda.h" +#include + +std::vector fast_hash( + at::Tensor query_mask, + at::Tensor query_vector, + at::Tensor key_mask, + at::Tensor key_vector, + int num_hash_f, + int hash_code_len, + bool use_cuda, + int version +) { + return fast_hash_ver1_kernel( + query_mask, + query_vector, + key_mask, + key_vector, + num_hash_f, + hash_code_len, + use_cuda + ); +} + +at::Tensor lsh_cumulation( + at::Tensor query_mask, // [batch_size, num_query] + at::Tensor query_hash_code, // [batch_size, num_query, num_hash_f] + at::Tensor key_mask, // [batch_size, num_key] + at::Tensor key_hash_code, // [batch_size, num_key, num_hash_f] + at::Tensor value, // [batch_size, num_key, value_dim] + int hashtable_capacity, + bool use_cuda, + int version +) { + return lsh_cumulation_ver1_kernel( + query_mask, + query_hash_code, + key_mask, + key_hash_code, + value, + hashtable_capacity, + use_cuda + ); +} + +at::Tensor lsh_weighted_cumulation( + at::Tensor query_mask, // [batch_size, num_query] + at::Tensor query_hash_code, // [batch_size, num_query, num_hash_f] + at::Tensor query_weight, // [batch_size, num_query, weight_dim] + at::Tensor key_mask, // [batch_size, num_key] + at::Tensor key_hash_code, // [batch_size, num_key, num_hash_f] + at::Tensor key_weight, // [batch_size, num_key, weight_dim] + at::Tensor value, // [batch_size, num_key, value_dim] + int hashtable_capacity, + bool use_cuda, + int version +) { + if (version == 1) { + return lsh_weighted_cumulation_ver1_kernel( + query_mask, + query_hash_code, + query_weight, + key_mask, + key_hash_code, + key_weight, + value, + hashtable_capacity, + use_cuda + ); + } else if (version == 2) { + return lsh_weighted_cumulation_ver2_kernel( + query_mask, + query_hash_code, + query_weight, + key_mask, + key_hash_code, + key_weight, + value, + hashtable_capacity, + use_cuda + ); + } else if (version == 3) { + return lsh_weighted_cumulation_ver3_kernel( + query_mask, + query_hash_code, + query_weight, + key_mask, + key_hash_code, + key_weight, + value, + hashtable_capacity, + use_cuda + ); + } else if (version == 4) { + return lsh_weighted_cumulation_ver4_kernel( + query_mask, + query_hash_code, + query_weight, + key_mask, + key_hash_code, + key_weight, + value, + hashtable_capacity, + use_cuda + ); + } else { + return lsh_weighted_cumulation_ver3_kernel( + query_mask, + query_hash_code, + query_weight, + key_mask, + key_hash_code, + key_weight, + value, + hashtable_capacity, + use_cuda + ); + } +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fast_hash", &fast_hash, "Fast Hash (CUDA)"); + m.def("lsh_cumulation", &lsh_cumulation, "LSH Cumulation (CUDA)"); + m.def("lsh_weighted_cumulation", &lsh_weighted_cumulation, "LSH Weighted Cumulation (CUDA)"); +} diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py new file mode 100644 index 00000000000..d9fd5def821 --- /dev/null +++ b/src/transformers/models/yoso/modeling_yoso.py @@ -0,0 +1,1324 @@ +# coding=utf-8 +# Copyright 2022 University of Wisconsin-Madison and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch YOSO model.""" + + +import math +import os + +import torch +import torch.utils.checkpoint +from packaging import version +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward +from ...modeling_outputs import ( + BaseModelOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from ...utils import logging +from .configuration_yoso import YosoConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "uw-madison/yoso-4096" +_CONFIG_FOR_DOC = "YosoConfig" +_TOKENIZER_FOR_DOC = "AutoTokenizer" + +YOSO_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "uw-madison/yoso-4096", + # See all YOSO models at https://huggingface.co/models?filter=yoso +] + + +def load_cuda_kernels(): + global lsh_cumulation + try: + from torch.utils.cpp_extension import load + + def append_root(files): + src_folder = os.path.dirname(os.path.realpath(__file__)) + return [os.path.join(src_folder, file) for file in files] + + src_files = append_root( + ["fast_lsh_cumulation_torch.cpp", "fast_lsh_cumulation.cu", "fast_lsh_cumulation_cuda.cu"] + ) + + load("fast_lsh_cumulation", src_files, verbose=True) + + import fast_lsh_cumulation as lsh_cumulation + + return True + except Exception: + lsh_cumulation = None + return False + + +def to_contiguous(input_tensors): + if isinstance(input_tensors, list): + out = [] + for tensor in input_tensors: + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + out.append(tensor) + return out + else: + if not input_tensors.is_contiguous(): + input_tensors = input_tensors.contiguous() + return input_tensors + + +def normalize(input_tensors): + if type(input_tensors) is list: + out = [] + for tensor in input_tensors: + out.append(nn.functional.normalize(tensor, p=2, dim=-1)) + return out + else: + return nn.functional.normalize(input_tensors, p=2, dim=-1) + + +def hashing(query, key, num_hash, hash_len): + + if len(query.size()) != 3: + raise ValueError("Query has incorrect size.") + if len(key.size()) != 3: + raise ValueError("Key has incorrect size.") + + rmat = torch.randn(query.size(0), query.size(2), num_hash * hash_len, device=query.device) + raise_pow = 2 ** torch.arange(hash_len, device=query.device) + + query_projection = torch.matmul(query, rmat).reshape(query.size(0), query.size(1), num_hash, hash_len) + key_projection = torch.matmul(key, rmat).reshape(key.size(0), key.size(1), num_hash, hash_len) + query_binary = (query_projection > 0).int() + key_binary = (key_projection > 0).int() + query_hash = torch.sum(query_binary * raise_pow, dim=-1) + query_hash = torch.sum(key_binary * raise_pow, dim=-1) + + return query_hash.int(), query_hash.int() + + +class YosoCumulation(torch.autograd.Function): + @staticmethod + def forward(ctx, query_mask, key_mask, query, key, value, config): + hash_code_len = config["hash_code_len"] + + expectation = (1 - torch.acos(torch.matmul(query, key.transpose(-1, -2))) / math.pi) ** hash_code_len + expectation = expectation * query_mask[:, :, None] * key_mask[:, None, :] + cumulation_value = torch.matmul(expectation, value) + + ctx.save_for_backward(query_mask, key_mask, expectation, query, key, value) + ctx.config = config + + return cumulation_value + + @staticmethod + def backward(ctx, grad): + grad = to_contiguous(grad) + + query_mask, key_mask, expectation, query, key, value = ctx.saved_tensors + config = ctx.config + + hash_code_len = config["hash_code_len"] + + weighted_exp = torch.matmul(grad, value.transpose(-1, -2)) * expectation + grad_query = torch.matmul(weighted_exp, (hash_code_len / 2) * key) + grad_key = torch.matmul(weighted_exp.transpose(-1, -2), (hash_code_len / 2) * query) + grad_value = torch.matmul(expectation.transpose(-1, -2), grad) + + return None, None, grad_query, grad_key, grad_value, None + + +class YosoLSHCumulation(torch.autograd.Function): + @staticmethod + def forward(ctx, query_mask, key_mask, query, key, value, config): + if query_mask.size(0) != key_mask.size(0): + raise ValueError("Query mask and Key mask differ in sizes in dimension 0") + if query_mask.size(0) != query.size(0): + raise ValueError("Query mask and Query differ in sizes in dimension 0") + if query_mask.size(0) != key.size(0): + raise ValueError("Query mask and Key differ in sizes in dimension 0") + if query_mask.size(0) != value.size(0): + raise ValueError("Query mask and Value mask differ in sizes in dimension 0") + if key.size(1) != value.size(1): + raise ValueError("Key and Value differ in sizes in dimension 1") + if query.size(2) != key.size(2): + raise ValueError("Query and Key differ in sizes in dimension 2") + + query_mask, key_mask, query, key, value = to_contiguous([query_mask, key_mask, query, key, value]) + + use_cuda = query_mask.is_cuda + num_hash = config["num_hash"] + hash_code_len = config["hash_code_len"] + hashtable_capacity = int(2 ** hash_code_len) + + if config["use_fast_hash"]: + query_hash_code, key_hash_code = lsh_cumulation.fast_hash( + query_mask, query, key_mask, key, num_hash, hash_code_len, use_cuda, 1 + ) + else: + query_hash_code, key_hash_code = hashing(query, key, num_hash, hash_code_len) + + cumulation_value = lsh_cumulation.lsh_cumulation( + query_mask, query_hash_code, key_mask, key_hash_code, value, hashtable_capacity, use_cuda, 1 + ) + + ctx.save_for_backward(query_mask, key_mask, query_hash_code, key_hash_code, query, key, value) + ctx.config = config + + return cumulation_value + + @staticmethod + def backward(ctx, grad): + grad = to_contiguous(grad) + + query_mask, key_mask, query_hash_code, key_hash_code, query, key, value = ctx.saved_tensors + config = ctx.config + + use_cuda = grad.is_cuda + hash_code_len = config["hash_code_len"] + hashtable_capacity = int(2 ** hash_code_len) + + if config["lsh_backward"]: + grad_value = lsh_cumulation.lsh_cumulation( + key_mask, key_hash_code, query_mask, query_hash_code, grad, hashtable_capacity, use_cuda, 1 + ) + grad_query = lsh_cumulation.lsh_weighted_cumulation( + query_mask, + query_hash_code, + grad, + key_mask, + key_hash_code, + value, + (hash_code_len / 2) * key, + hashtable_capacity, + use_cuda, + 4, + ) + grad_key = lsh_cumulation.lsh_weighted_cumulation( + key_mask, + key_hash_code, + value, + query_mask, + query_hash_code, + grad, + (hash_code_len / 2) * query, + hashtable_capacity, + use_cuda, + 4, + ) + else: + expectation = (1 - torch.acos(torch.matmul(query, key.transpose(-1, -2))) / math.pi) ** hash_code_len + expectation = expectation * query_mask[:, :, None] * key_mask[:, None, :] + weighted_exp = torch.matmul(grad, value.transpose(-1, -2)) * expectation + grad_query = torch.matmul(weighted_exp, (hash_code_len / 2) * key) + grad_key = torch.matmul(weighted_exp.transpose(-1, -2), (hash_code_len / 2) * query) + grad_value = torch.matmul(expectation.transpose(-1, -2), grad) + + return None, None, grad_query, grad_key, grad_value, None + + +# Copied from transformers.models.nystromformer.modeling_nystromformer.NystromformerEmbeddings +class YosoEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings + 2, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + 2) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + if version.parse(torch.__version__) > version.parse("1.6.0"): + self.register_buffer( + "token_type_ids", + torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device), + persistent=False, + ) + + def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class YosoSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = ( + position_embedding_type if position_embedding_type is not None else config.position_embedding_type + ) + + self.use_expectation = config.use_expectation + self.hash_code_len = config.hash_code_len + self.use_conv = config.conv_window is not None + self.use_fast_hash = config.use_fast_hash + self.num_hash = config.num_hash + self.lsh_backward = config.lsh_backward + + self.lsh_config = { + "hash_code_len": self.hash_code_len, + "use_fast_hash": self.use_fast_hash, + "num_hash": self.num_hash, + "lsh_backward": self.lsh_backward, + } + + if config.conv_window is not None: + self.conv = nn.Conv2d( + in_channels=config.num_attention_heads, + out_channels=config.num_attention_heads, + kernel_size=(config.conv_window, 1), + padding=(config.conv_window // 2, 0), + bias=False, + groups=config.num_attention_heads, + ) + + def transpose_for_scores(self, layer): + new_layer_shape = layer.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + layer = layer.view(*new_layer_shape) + return layer.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask=None, output_attentions=False): + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + if self.use_conv: + conv_value_layer = self.conv(value_layer * attention_mask[:, None, :, None]) + + batch_size, num_heads, seq_len, head_dim = query_layer.size() + + query_layer = query_layer.reshape(batch_size * num_heads, seq_len, head_dim) + key_layer = key_layer.reshape(batch_size * num_heads, seq_len, head_dim) + value_layer = value_layer.reshape(batch_size * num_heads, seq_len, head_dim) + + # revert changes made by get_extended_attention_mask + attention_mask = 1.0 + attention_mask / 10000.0 + attention_mask = ( + attention_mask.squeeze().repeat(1, num_heads, 1).reshape(batch_size * num_heads, seq_len).int() + ) + + # The CUDA kernels are most efficient with inputs whose size is a multiple of a GPU's warp size (32). Inputs + # smaller than this are padded with zeros. + gpu_warp_size = 32 + + if (not self.use_expectation) and head_dim < gpu_warp_size: + pad_size = batch_size * num_heads, seq_len, gpu_warp_size - head_dim + + query_layer = torch.cat( + [ + query_layer, + torch.zeros(pad_size, device=query_layer.device), + ], + dim=-1, + ) + key_layer = torch.cat( + [ + key_layer, + torch.zeros(pad_size, device=key_layer.device), + ], + dim=-1, + ) + value_layer = torch.cat( + [ + value_layer, + torch.zeros(pad_size, device=value_layer.device), + ], + dim=-1, + ) + + if self.use_expectation or self.training: + query_layer, key_layer = normalize([query_layer, key_layer]) + + if self.use_expectation: + context_layer = YosoCumulation.apply( + attention_mask, attention_mask, query_layer, key_layer, value_layer, self.lsh_config + ) + else: + context_layer = YosoLSHCumulation.apply( + attention_mask, attention_mask, query_layer, key_layer, value_layer, self.lsh_config + ) + + if (not self.use_expectation) and head_dim < gpu_warp_size: + context_layer = context_layer[:, :, :head_dim] + + context_layer = normalize(context_layer) + + context_layer = context_layer.reshape(batch_size, num_heads, seq_len, head_dim) + + if self.use_conv: + context_layer += conv_value_layer + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, context_layer) if output_attentions else (context_layer,) + + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput +class YosoSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class YosoAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + self.self = YosoSelfAttention(config, position_embedding_type=position_embedding_type) + self.output = YosoSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward(self, hidden_states, attention_mask=None, output_attentions=False): + self_outputs = self.self(hidden_states, attention_mask, output_attentions) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate +class YosoIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput +class YosoOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class YosoLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = YosoAttention(config) + self.add_cross_attention = config.add_cross_attention + self.intermediate = YosoIntermediate(config) + self.output = YosoOutput(config) + + def forward(self, hidden_states, attention_mask=None, output_attentions=False): + self_attention_outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class YosoEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([YosoLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + ) + else: + layer_outputs = layer_module(hidden_states, attention_mask, output_attentions) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutputWithCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform +class YosoPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->Yoso +class YosoLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = YosoPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->Yoso +class YosoOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = YosoLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class YosoPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = YosoConfig + base_model_prefix = "yoso" + supports_gradient_checkpointing = True + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, YosoEncoder): + module.gradient_checkpointing = value + + +YOSO_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`YosoConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +YOSO_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert *input_ids* indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare YOSO Model transformer outputting raw hidden-states without any specific head on top.", + YOSO_START_DOCSTRING, +) +class YosoModel(YosoPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + + self.embeddings = YosoEmbeddings(config) + self.encoder = YosoEncoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(YOSO_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + + if not return_dict: + return (sequence_output,) + encoder_outputs[1:] + + return BaseModelOutputWithCrossAttentions( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@add_start_docstrings("""YOSO Model with a `language modeling` head on top.""", YOSO_START_DOCSTRING) +class YosoForMaskedLM(YosoPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.yoso = YosoModel(config) + self.cls = YosoOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(YOSO_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.yoso( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class YosoClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + self.config = config + + def forward(self, features, **kwargs): + x = features[:, 0, :] # take token (equiv. to [CLS]) + x = self.dropout(x) + x = self.dense(x) + x = ACT2FN[self.config.hidden_act](x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +@add_start_docstrings( + """YOSO Model transformer with a sequence classification/regression head on top (a linear layer on top of + the pooled output) e.g. for GLUE tasks.""", + YOSO_START_DOCSTRING, +) +class YosoForSequenceClassification(YosoPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.yoso = YosoModel(config) + self.classifier = YosoClassificationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(YOSO_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.yoso( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """YOSO Model with a multiple choice classification head on top (a linear layer on top of + the pooled output and a softmax) e.g. for RocStories/SWAG tasks.""", + YOSO_START_DOCSTRING, +) +class YosoForMultipleChoice(YosoPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.yoso = YosoModel(config) + self.pre_classifier = nn.Linear(config.hidden_size, config.hidden_size) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(YOSO_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.yoso( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_state = outputs[0] # (bs * num_choices, seq_len, dim) + pooled_output = hidden_state[:, 0] # (bs * num_choices, dim) + pooled_output = self.pre_classifier(pooled_output) # (bs * num_choices, dim) + pooled_output = nn.ReLU()(pooled_output) # (bs * num_choices, dim) + logits = self.classifier(pooled_output) + + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """YOSO Model with a token classification head on top (a linear layer on top of + the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks.""", + YOSO_START_DOCSTRING, +) +class YosoForTokenClassification(YosoPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.yoso = YosoModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(YOSO_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.yoso( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels) + active_labels = torch.where( + active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) + ) + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """YOSO Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`).""", + YOSO_START_DOCSTRING, +) +class YosoForQuestionAnswering(YosoPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + config.num_labels = 2 + self.num_labels = config.num_labels + + self.yoso = YosoModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(YOSO_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + start_positions=None, + end_positions=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.yoso( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[1:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 58302460d18..4e05f4a522a 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -4060,6 +4060,65 @@ def load_tf_weights_in_xlnet(*args, **kwargs): requires_backends(load_tf_weights_in_xlnet, ["torch"]) +YOSO_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class YosoForMaskedLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class YosoForMultipleChoice(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class YosoForQuestionAnswering(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class YosoForSequenceClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class YosoForTokenClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class YosoLayer(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class YosoModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class YosoPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class Adafactor(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/test_modeling_yoso.py b/tests/test_modeling_yoso.py new file mode 100644 index 00000000000..9cd00856cec --- /dev/null +++ b/tests/test_modeling_yoso.py @@ -0,0 +1,401 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch YOSO model. """ + + +import unittest + +from tests.test_modeling_common import floats_tensor +from transformers import YosoConfig, is_torch_available +from transformers.testing_utils import require_torch, slow, torch_device + +from .test_configuration_common import ConfigTester +from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask + + +if is_torch_available(): + import torch + + from transformers import ( + YosoForMaskedLM, + YosoForMultipleChoice, + YosoForQuestionAnswering, + YosoForSequenceClassification, + YosoForTokenClassification, + YosoModel, + ) + from transformers.models.yoso.modeling_yoso import YOSO_PRETRAINED_MODEL_ARCHIVE_LIST + + +class YosoModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=True, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config() + + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + + def get_config(self): + return YosoConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=False, + initializer_range=self.initializer_range, + ) + + def prepare_config_and_inputs_for_decoder(self): + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = self.prepare_config_and_inputs() + + config.is_decoder = True + encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size]) + encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + return ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) + + def create_and_check_model( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = YosoModel(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) + result = model(input_ids, token_type_ids=token_type_ids) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_model_as_decoder( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.add_cross_attention = True + model = YosoModel(config) + model.to(torch_device) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + ) + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + encoder_hidden_states=encoder_hidden_states, + ) + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_for_masked_lm( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = YosoForMaskedLM(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_for_question_answering( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = YosoForQuestionAnswering(config=config) + model.to(torch_device) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + start_positions=sequence_labels, + end_positions=sequence_labels, + ) + self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length)) + self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length)) + + def create_and_check_for_sequence_classification( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_labels = self.num_labels + model = YosoForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) + + def create_and_check_for_token_classification( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_labels = self.num_labels + model = YosoForTokenClassification(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels)) + + def create_and_check_for_multiple_choice( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_choices = self.num_choices + model = YosoForMultipleChoice(config=config) + model.to(torch_device) + model.eval() + multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + result = model( + multiple_choice_inputs_ids, + attention_mask=multiple_choice_input_mask, + token_type_ids=multiple_choice_token_type_ids, + labels=choice_labels, + ) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask} + return config, inputs_dict + + +@require_torch +class YosoModelTest(ModelTesterMixin, unittest.TestCase): + + all_model_classes = ( + ( + YosoModel, + YosoForMaskedLM, + YosoForMultipleChoice, + YosoForQuestionAnswering, + YosoForSequenceClassification, + YosoForTokenClassification, + ) + if is_torch_available() + else () + ) + test_pruning = False + test_headmasking = False + test_torchscript = False + + all_generative_model_classes = () + + def setUp(self): + self.model_tester = YosoModelTester(self) + self.config_tester = ConfigTester(self, config_class=YosoConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_various_embeddings(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + for type in ["absolute", "relative_key", "relative_key_query"]: + config_and_inputs[0].position_embedding_type = type + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_for_masked_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_masked_lm(*config_and_inputs) + + def test_for_multiple_choice(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs) + + def test_for_question_answering(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_question_answering(*config_and_inputs) + + def test_for_sequence_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs) + + def test_for_token_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_token_classification(*config_and_inputs) + + @slow + def test_model_from_pretrained(self): + for model_name in YOSO_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = YosoModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + def test_attention_outputs(self): + return + + +@require_torch +class YosoModelIntegrationTest(unittest.TestCase): + @slow + def test_inference_no_head(self): + model = YosoModel.from_pretrained("uw-madison/yoso-4096") + input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]]) + + with torch.no_grad(): + output = model(input_ids)[0] + + expected_shape = torch.Size((1, 6, 768)) + self.assertEqual(output.shape, expected_shape) + + expected_slice = torch.tensor( + [[[-0.0611, 0.1242, 0.0840], [0.0280, -0.0048, 0.1125], [0.0106, 0.0226, 0.0751]]] + ) + + self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4)) + + @slow + def test_inference_masked_lm(self): + model = YosoForMaskedLM.from_pretrained("uw-madison/yoso-4096") + input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]]) + + with torch.no_grad(): + output = model(input_ids)[0] + + vocab_size = 50265 + + expected_shape = torch.Size((1, 6, vocab_size)) + self.assertEqual(output.shape, expected_shape) + + expected_slice = torch.tensor( + [[[-2.1313, -3.7285, -2.2407], [-2.7047, -3.3314, -2.6408], [0.0629, -2.5166, -0.3356]]] + ) + + self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4)) + + @slow + def test_inference_masked_lm_long_input(self): + model = YosoForMaskedLM.from_pretrained("uw-madison/yoso-4096") + input_ids = torch.arange(4096).unsqueeze(0) + + with torch.no_grad(): + output = model(input_ids)[0] + + vocab_size = 50265 + + expected_shape = torch.Size((1, 4096, vocab_size)) + self.assertEqual(output.shape, expected_shape) + + expected_slice = torch.tensor( + [[[-2.3914, -4.3742, -5.0956], [-4.0988, -4.2384, -7.0406], [-3.1427, -3.7192, -6.6800]]] + ) + + self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))