mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add YOSO (#15091)
* Add cookiecutter files
* Add cuda kernels and cpp files
* Update modeling_yoso.py
* Add .h files
* Update configuration_yoso.py
* Updates
* Remove tokenizer
* Code quality
* Update modeling_yoso.py
* Update modeling_yoso.py
* Fix failing test
* Update modeling_yoso.py
* Fix code quality
* Apply suggestions from code review

  Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
* Apply suggestions from code review
* Apply suggestions from code review

  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Apply suggestions from code review and fix integration tests
* Update src/transformers/models/yoso/modeling_yoso.py

  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Apply suggestions from code review
* Fix copied from statement
* Fix docstring
* Fix code quality
* Apply suggestions from code review

  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Apply suggestions and fix mask
* Apply suggestions from code review
* Fix code quality
* Apply suggestions from code review

  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Fix docstrings
* Fix code quality
* Remove trailing whitespace
* Update yoso.mdx
* Move kernel loading to YosoEncoder
* make style
* Apply suggestions from code review

  Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
* Update src/transformers/models/yoso/modeling_yoso.py

  Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
* Add short summary to docs
* Update docs/source/model_doc/yoso.mdx

  Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
* Update yoso.mdx
* Update docs/source/model_doc/yoso.mdx

  Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
* Remove CausalLM model and add copied from
* Remove autoregressive code
* Remove unused imports
* add copied from for embeddings
* Fix code quality
* Update docs/source/model_doc/yoso.mdx

  Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
* Apply suggestion from code review

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
6292532fd1
commit
99a2771189
@@ -325,6 +325,7 @@ AWARE PRE-TRAINING](https://arxiv.org/abs/2110.05752) by Sanyuan Chen, Yu Wu, Ch
1. **[XLNet](https://huggingface.co/docs/transformers/model_doc/xlnet)** (from Google/CMU) released with the paper [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
1. **[XLSR-Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/xlsr_wav2vec2)** (from Facebook AI) released with the paper [Unsupervised Cross-Lingual Representation Learning For Speech Recognition](https://arxiv.org/abs/2006.13979) by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
1. **[XLS-R](https://huggingface.co/docs/transformers/master/model_doc/xls_r)** (from Facebook AI) released with the paper [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale](https://arxiv.org/abs/2111.09296) by Arun Babu, Changhan Wang, Andros Tjandra, Kushal Lakhotia, Qiantong Xu, Naman Goyal, Kritika Singh, Patrick von Platen, Yatharth Saraf, Juan Pino, Alexei Baevski, Alexis Conneau, Michael Auli.
1. **[YOSO](https://huggingface.co/docs/transformers/master/model_doc/yoso)** (from the University of Wisconsin - Madison) released with the paper [You Only Sample (Almost) Once: Linear Cost Self-Attention Via Bernoulli Sampling](https://arxiv.org/abs/2111.09714) by Zhanpeng Zeng, Yunyang Xiong, Sathya N. Ravi, Shailesh Acharya, Glenn Fung, Vikas Singh.
1. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedback before starting your PR.

To check if each model has an implementation in Flax, PyTorch or TensorFlow, or has an associated tokenizer backed by the 🤗 Tokenizers library, refer to [this table](https://huggingface.co/docs/transformers/index#supported-frameworks).
@@ -303,6 +303,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
1. **[XLNet](https://huggingface.co/docs/transformers/model_doc/xlnet)** (from Google/CMU) released with the paper [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
1. **[XLS-R](https://huggingface.co/docs/transformers/master/model_doc/xls_r)** (from Facebook AI) released with the paper [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale](https://arxiv.org/abs/2111.09296) by Arun Babu, Changhan Wang, Andros Tjandra, Kushal Lakhotia, Qiantong Xu, Naman Goyal, Kritika Singh, Patrick von Platen, Yatharth Saraf, Juan Pino, Alexei Baevski, Alexis Conneau, Michael Auli.
1. **[XLSR-Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/xlsr_wav2vec2)** (from Facebook AI) released with the paper [Unsupervised Cross-Lingual Representation Learning For Speech Recognition](https://arxiv.org/abs/2006.13979) by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
1. **[YOSO](https://huggingface.co/docs/transformers/master/model_doc/yoso)** (from the University of Wisconsin - Madison) released with the paper [You Only Sample (Almost) Once: Linear Cost Self-Attention Via Bernoulli Sampling](https://arxiv.org/abs/2111.09714) by Zhanpeng Zeng, Yunyang Xiong, Sathya N. Ravi, Shailesh Acharya, Glenn Fung, Vikas Singh.
1. Want to contribute a new model? We have added a **detailed guide and templates** to help you add a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedback before opening your PR.

To check whether each model has an implementation in Flax, PyTorch or TensorFlow, or has an associated tokenizer backed by the 🤗 Tokenizers library, refer to [this table](https://huggingface.co/docs/transformers/index#supported-frameworks).
@@ -327,6 +327,7 @@ conda install -c huggingface transformers
1. **[XLNet](https://huggingface.co/docs/transformers/model_doc/xlnet)** (from Google/CMU) released with the paper [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
1. **[XLS-R](https://huggingface.co/docs/transformers/master/model_doc/xls_r)** (from Facebook AI) released with the paper [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale](https://arxiv.org/abs/2111.09296) by Arun Babu, Changhan Wang, Andros Tjandra, Kushal Lakhotia, Qiantong Xu, Naman Goyal, Kritika Singh, Patrick von Platen, Yatharth Saraf, Juan Pino, Alexei Baevski, Alexis Conneau, Michael Auli.
1. **[XLSR-Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/xlsr_wav2vec2)** (from Facebook AI) released with the paper [Unsupervised Cross-Lingual Representation Learning For Speech Recognition](https://arxiv.org/abs/2006.13979) by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
1. **[YOSO](https://huggingface.co/docs/transformers/master/model_doc/yoso)** (from the University of Wisconsin - Madison) released with the paper [You Only Sample (Almost) Once: Linear Cost Self-Attention Via Bernoulli Sampling](https://arxiv.org/abs/2111.09714) by Zhanpeng Zeng, Yunyang Xiong, Sathya N. Ravi, Shailesh Acharya, Glenn Fung, Vikas Singh.
1. Want to contribute a new model? We have a **detailed guide and templates** to walk you through adding a new model. You can find them in the [`templates`](./templates) directory. Remember to check the [contributing guide](./CONTRIBUTING.md) and to contact the maintainers or open an issue for feedback before starting your PR.

To check whether a model already has a Flax, PyTorch or TensorFlow implementation, or whether it has an associated tokenizer in the 🤗 Tokenizers library, refer to [this table](https://huggingface.co/docs/transformers/index#supported-frameworks).
@@ -339,6 +339,7 @@ conda install -c huggingface transformers
1. **[XLNet](https://huggingface.co/docs/transformers/model_doc/xlnet)** (from Google/CMU) released with the paper [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
1. **[XLS-R](https://huggingface.co/docs/transformers/master/model_doc/xls_r)** (from Facebook AI) released with the paper [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale](https://arxiv.org/abs/2111.09296) by Arun Babu, Changhan Wang, Andros Tjandra, Kushal Lakhotia, Qiantong Xu, Naman Goyal, Kritika Singh, Patrick von Platen, Yatharth Saraf, Juan Pino, Alexei Baevski, Alexis Conneau, Michael Auli.
1. **[XLSR-Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/xlsr_wav2vec2)** (from Facebook AI) released with the paper [Unsupervised Cross-Lingual Representation Learning For Speech Recognition](https://arxiv.org/abs/2006.13979) by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
1. **[YOSO](https://huggingface.co/docs/transformers/master/model_doc/yoso)** (from the University of Wisconsin - Madison) released with the paper [You Only Sample (Almost) Once: Linear Cost Self-Attention Via Bernoulli Sampling](https://arxiv.org/abs/2111.09714) by Zhanpeng Zeng, Yunyang Xiong, Sathya N. Ravi, Shailesh Acharya, Glenn Fung, Vikas Singh.
1. Want to contribute a new model? We have a **detailed guide and templates** to walk you through adding a new model. You can find them in the [`templates`](./templates) directory. Remember to check the [contributing guide](./CONTRIBUTING.md) and to contact the maintainers or open an issue for feedback before starting your PR.

To check whether a model already has a Flax, PyTorch or TensorFlow implementation, or whether it has an associated tokenizer in the 🤗 Tokenizers library, refer to [this table](https://huggingface.co/docs/transformers/index#supported-frameworks).
@@ -316,6 +316,8 @@
      title: XLSR-Wav2Vec2
    - local: model_doc/xls_r
      title: XLS-R
    - local: model_doc/yoso
      title: YOSO
    title: Models
  - sections:
    - local: internal/modeling_utils
@@ -184,6 +184,7 @@ conversion utilities for the following models.
1. **[XLNet](model_doc/xlnet)** (from Google/CMU) released with the paper [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
1. **[XLSR-Wav2Vec2](model_doc/xlsr_wav2vec2)** (from Facebook AI) released with the paper [Unsupervised Cross-Lingual Representation Learning For Speech Recognition](https://arxiv.org/abs/2006.13979) by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
1. **[XLS-R](https://huggingface.co/docs/transformers/master/model_doc/xls_r)** (from Facebook AI) released with the paper [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale](https://arxiv.org/abs/2111.09296) by Arun Babu, Changhan Wang, Andros Tjandra, Kushal Lakhotia, Qiantong Xu, Naman Goyal, Kritika Singh, Patrick von Platen, Yatharth Saraf, Juan Pino, Alexei Baevski, Alexis Conneau, Michael Auli.
1. **[YOSO](model_doc/yoso)** (from the University of Wisconsin - Madison) released with the paper [You Only Sample (Almost) Once: Linear Cost Self-Attention Via Bernoulli Sampling](https://arxiv.org/abs/2111.09714) by Zhanpeng Zeng, Yunyang Xiong, Sathya N. Ravi, Shailesh Acharya, Glenn Fung, Vikas Singh.

### Supported frameworks

@@ -281,5 +282,6 @@ Flax), PyTorch, and/or TensorFlow.
| XLM-RoBERTa | ✅ | ✅ | ✅ | ✅ | ❌ |
| XLMProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ |
| XLNet | ✅ | ✅ | ✅ | ✅ | ❌ |
| YOSO | ❌ | ❌ | ✅ | ❌ | ❌ |

<!-- End table-->
docs/source/model_doc/yoso.mdx (new file, 91 lines)
@@ -0,0 +1,91 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# YOSO

## Overview

The YOSO model was proposed in [You Only Sample (Almost) Once: Linear Cost Self-Attention Via Bernoulli Sampling](https://arxiv.org/abs/2111.09714)
by Zhanpeng Zeng, Yunyang Xiong, Sathya N. Ravi, Shailesh Acharya, Glenn Fung, Vikas Singh. YOSO approximates standard softmax self-attention
via a Bernoulli sampling scheme based on Locality Sensitive Hashing (LSH). In principle, all the Bernoulli random variables can be sampled with
a single hash.
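
In rough terms (a sketch of the paper's estimator, not text from this diff): for unit-norm queries and keys, a single LSH bit from a random hyperplane collides with probability $1 - \arccos(q_i^\top k_j)/\pi$, so a $\tau$-bit code collides with probability

$$
P\big(h(q_i) = h(k_j)\big) = \left(1 - \frac{\arccos(q_i^\top k_j)}{\pi}\right)^{\tau},
$$

which YOSO estimates by Bernoulli sampling and uses in place of the softmax attention weights; $\tau$ corresponds to `hash_code_len` below.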

The abstract from the paper is the following:

*Transformer-based models are widely used in natural language processing (NLP). Central to the transformer model is
the self-attention mechanism, which captures the interactions of token pairs in the input sequences and depends quadratically
on the sequence length. Training such models on longer sequences is expensive. In this paper, we show that a Bernoulli sampling
attention mechanism based on Locality Sensitive Hashing (LSH), decreases the quadratic complexity of such models to linear.
We bypass the quadratic cost by considering self-attention as a sum of individual tokens associated with Bernoulli random
variables that can, in principle, be sampled at once by a single hash (although in practice, this number may be a small constant).
This leads to an efficient sampling scheme to estimate self-attention which relies on specific modifications of
LSH (to enable deployment on GPU architectures). We evaluate our algorithm on the GLUE benchmark with standard 512 sequence
length where we see favorable performance relative to a standard pretrained Transformer. On the Long Range Arena (LRA) benchmark,
for evaluating performance on long sequences, our method achieves results consistent with softmax self-attention but with sizable
speed-ups and memory savings and often outperforms other efficient self-attention methods. Our code is available at this https URL*

Tips:

- The YOSO attention algorithm is implemented through custom CUDA kernels, functions written in CUDA C++ that can be executed multiple times
in parallel on a GPU.
- The kernels provide a `fast_hash` function, which approximates the random projections of the queries and keys using the Fast Hadamard Transform. Using these
hash codes, the `lsh_cumulation` function approximates self-attention via LSH-based Bernoulli sampling.
- To use the custom kernels, the user should set `config.use_expectation = False`. To ensure that the kernels are compiled successfully,
the user must install the correct version of PyTorch and cudatoolkit. By default, `config.use_expectation = True`, which uses YOSO-E and
does not require compiling CUDA kernels. A minimal configuration sketch follows these tips.
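
A minimal sketch of enabling the kernels (illustrative, not part of the original documentation; it assumes the `uw-madison/yoso-4096` checkpoint named elsewhere in this commit, the standard `from_pretrained` API, and a working CUDA setup):

```python
from transformers import AutoTokenizer, YosoConfig, YosoForMaskedLM

# use_expectation=False selects the sampled LSH attention backed by the custom
# CUDA kernels; the default (True) uses YOSO-E and needs no kernel compilation.
config = YosoConfig.from_pretrained("uw-madison/yoso-4096", use_expectation=False)
model = YosoForMaskedLM.from_pretrained("uw-madison/yoso-4096", config=config).cuda()

tokenizer = AutoTokenizer.from_pretrained("uw-madison/yoso-4096")
inputs = tokenizer("Paris is the [MASK] of France.", return_tensors="pt").to("cuda")
logits = model(**inputs).logits  # masked-token scores over the vocabulary
```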

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/yoso_architecture.jpg"
alt="drawing" width="600"/>

<small> YOSO Attention Algorithm. Taken from the <a href="https://arxiv.org/abs/2111.09714">original paper</a>.</small>

This model was contributed by [novice03](https://huggingface.co/novice03). The original code can be found [here](https://github.com/mlpen/YOSO).


## YosoConfig

[[autodoc]] YosoConfig


## YosoModel

[[autodoc]] YosoModel
    - forward


## YosoForMaskedLM

[[autodoc]] YosoForMaskedLM
    - forward


## YosoForSequenceClassification

[[autodoc]] YosoForSequenceClassification
    - forward

## YosoForMultipleChoice

[[autodoc]] YosoForMultipleChoice
    - forward


## YosoForTokenClassification

[[autodoc]] YosoForTokenClassification
    - forward


## YosoForQuestionAnswering

[[autodoc]] YosoForQuestionAnswering
    - forward
@@ -333,6 +333,7 @@ _import_structure = {
    "models.xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"],
    "models.xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"],
    "models.xlnet": ["XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLNetConfig"],
    "models.yoso": ["YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP", "YosoConfig"],
    "onnx": [],
    "pipelines": [
        "AudioClassificationPipeline",
@@ -1510,6 +1511,19 @@ if is_torch_available():
            "load_tf_weights_in_xlnet",
        ]
    )
    _import_structure["models.yoso"].extend(
        [
            "YOSO_PRETRAINED_MODEL_ARCHIVE_LIST",
            "YosoForMaskedLM",
            "YosoForMultipleChoice",
            "YosoForQuestionAnswering",
            "YosoForSequenceClassification",
            "YosoForTokenClassification",
            "YosoLayer",
            "YosoModel",
            "YosoPreTrainedModel",
        ]
    )
    _import_structure["optimization"] = [
        "Adafactor",
        "AdamW",
@@ -2454,6 +2468,7 @@ if TYPE_CHECKING:
    from .models.xlm_prophetnet import XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMProphetNetConfig
    from .models.xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
    from .models.xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
    from .models.yoso import YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP, YosoConfig

    # Pipelines
    from .pipelines import (
@@ -3431,6 +3446,17 @@ if TYPE_CHECKING:
        XLNetPreTrainedModel,
        load_tf_weights_in_xlnet,
    )
    from .models.yoso import (
        YOSO_PRETRAINED_MODEL_ARCHIVE_LIST,
        YosoForMaskedLM,
        YosoForMultipleChoice,
        YosoForQuestionAnswering,
        YosoForSequenceClassification,
        YosoForTokenClassification,
        YosoLayer,
        YosoModel,
        YosoPreTrainedModel,
    )

    # Optimization
    from .optimization import (
@@ -119,4 +119,5 @@ from . import (
    xlm_prophetnet,
    xlm_roberta,
    xlnet,
    yoso,
)
@@ -30,6 +30,7 @@ logger = logging.get_logger(__name__)
CONFIG_MAPPING_NAMES = OrderedDict(
    [
        # Add configs here
        ("yoso", "YosoConfig"),
        ("swin", "SwinConfig"),
        ("vilt", "ViltConfig"),
        ("vit_mae", "ViTMAEConfig"),
@@ -121,6 +122,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
    [
        # Add archive maps here
        ("yoso", "YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("swin", "SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("vilt", "VILT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("vit_mae", "VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@@ -200,6 +202,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
MODEL_NAMES_MAPPING = OrderedDict(
    [
        # Add full (and cased) model names here
        ("yoso", "YOSO"),
        ("swin", "Swin"),
        ("vilt", "ViLT"),
        ("vit_mae", "ViTMAE"),
@@ -28,6 +28,7 @@ logger = logging.get_logger(__name__)
MODEL_MAPPING_NAMES = OrderedDict(
    [
        # Base model mapping
        ("yoso", "YosoModel"),
        ("swin", "SwinModel"),
        ("vilt", "ViltModel"),
        ("vit_mae", "ViTMAEModel"),
@@ -155,6 +156,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
    [
        # Model with LM heads mapping
        ("yoso", "YosoForMaskedLM"),
        ("nystromformer", "NystromformerForMaskedLM"),
        ("qdqbert", "QDQBertForMaskedLM"),
        ("fnet", "FNetForMaskedLM"),
@@ -284,6 +286,7 @@ MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Masked LM mapping
        ("yoso", "YosoForMaskedLM"),
        ("nystromformer", "NystromformerForMaskedLM"),
        ("perceiver", "PerceiverForMaskedLM"),
        ("qdqbert", "QDQBertForMaskedLM"),
@@ -357,6 +360,7 @@ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Sequence Classification mapping
        ("yoso", "YosoForSequenceClassification"),
        ("nystromformer", "NystromformerForSequenceClassification"),
        ("perceiver", "PerceiverForSequenceClassification"),
        ("qdqbert", "QDQBertForSequenceClassification"),
@@ -405,6 +409,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Question Answering mapping
        ("yoso", "YosoForQuestionAnswering"),
        ("nystromformer", "NystromformerForQuestionAnswering"),
        ("qdqbert", "QDQBertForQuestionAnswering"),
        ("fnet", "FNetForQuestionAnswering"),
@@ -454,6 +459,7 @@ MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Token Classification mapping
        ("yoso", "YosoForTokenClassification"),
        ("nystromformer", "NystromformerForTokenClassification"),
        ("qdqbert", "QDQBertForTokenClassification"),
        ("fnet", "FNetForTokenClassification"),
@@ -490,6 +496,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    [
        # Model for Multiple Choice mapping
        ("yoso", "YosoForMultipleChoice"),
        ("nystromformer", "NystromformerForMultipleChoice"),
        ("qdqbert", "QDQBertForMultipleChoice"),
        ("fnet", "FNetForMultipleChoice"),
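
These registrations are what allow the `Auto*` factories to resolve YOSO from a configuration. A short sketch of the effect (standard Auto API; the checkpoint name is the one used throughout this commit):

```python
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("uw-madison/yoso-4096")  # CONFIG_MAPPING_NAMES -> YosoConfig
model = AutoModel.from_config(config)                        # MODEL_MAPPING_NAMES -> YosoModel
```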
src/transformers/models/yoso/__init__.py (new file, 62 lines)
@@ -0,0 +1,62 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

# rely on isort to merge the imports
from ...file_utils import _LazyModule, is_tokenizers_available, is_torch_available


_import_structure = {
    "configuration_yoso": ["YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP", "YosoConfig"],
}

if is_torch_available():
    _import_structure["modeling_yoso"] = [
        "YOSO_PRETRAINED_MODEL_ARCHIVE_LIST",
        "YosoForMaskedLM",
        "YosoForMultipleChoice",
        "YosoForQuestionAnswering",
        "YosoForSequenceClassification",
        "YosoForTokenClassification",
        "YosoLayer",
        "YosoModel",
        "YosoPreTrainedModel",
    ]


if TYPE_CHECKING:
    from .configuration_yoso import YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP, YosoConfig

    if is_torch_available():
        from .modeling_yoso import (
            YOSO_PRETRAINED_MODEL_ARCHIVE_LIST,
            YosoForMaskedLM,
            YosoForMultipleChoice,
            YosoForQuestionAnswering,
            YosoForSequenceClassification,
            YosoForTokenClassification,
            YosoLayer,
            YosoModel,
            YosoPreTrainedModel,
        )


else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
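
The `_LazyModule` assignment above defers the torch-gated imports until a symbol is first touched. A toy sketch of the idea (hypothetical stand-in, not the actual `_LazyModule` implementation):

```python
import importlib
import types


class LazyModule(types.ModuleType):
    """Resolve exported symbols to their submodules only on first access."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # map each exported symbol back to the submodule that defines it
        self._symbol_to_module = {
            symbol: module for module, symbols in import_structure.items() for symbol in symbols
        }

    def __getattr__(self, symbol):
        module_name = self._symbol_to_module.get(symbol)
        if module_name is None:
            raise AttributeError(f"module {self.__name__!r} has no attribute {symbol!r}")
        submodule = importlib.import_module(f".{module_name}", self.__name__)
        value = getattr(submodule, symbol)
        setattr(self, symbol, value)  # cache so later lookups skip __getattr__
        return value
```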
src/transformers/models/yoso/common.h (new file, 10 lines)
@@ -0,0 +1,10 @@
#define min(a, b) ((a)<(b)?(a):(b))
#define max(a, b) ((a)>(b)?(a):(b))
#define ceil_divide(a, b) ((a)/(b)+((a)%(b)!=0))
#define select(cond, a, b) ((cond)?(a):(b))
#define PI 3.141592
#define EPSILON 1e-8
#define MAX_VAL 1e12
#define MIN_VAL -1e12
#define EMPTY_VALUE -1
src/transformers/models/yoso/common_cuda.h (new file, 9 lines)
@@ -0,0 +1,9 @@
#define MAX_THREADS_PER_BLOCK 1024
#define OPTIMAL_THREADS_PER_BLOCK 256
#define WARP_SIZE 32
#define MAX_NUM_BLOCK_X 2147483647
#define MAX_NUM_BLOCK_Y 65535
#define MAX_NUM_BLOCK_Z 65535
#define MAX_SHARED_MEM_PER_BLOCK 48000
#define FULL_MASK 0xffffffff
src/transformers/models/yoso/common_cuda_device.h (new file, 79 lines)
@@ -0,0 +1,79 @@
#include "common.h"

template<typename T>
__device__ int set_insert(T *set, int set_size, T value) {
  int slot = value % set_size;
  int start_slot = slot;
  while (true) {
    T prev = atomicCAS(&set[slot], EMPTY_VALUE, value);
    if (prev == EMPTY_VALUE || prev == value) {
      return slot;
    }
    slot = (slot + 1) % set_size;
    if (slot == start_slot) {
      return -1;
    }
  }
  return -1;
}

template<typename T>
__device__ int set_lookup(T *set, int set_size, T value) {
  int slot = value % set_size;
  int start_slot = slot;
  while (true) {
    if (set[slot] == value) {
      return slot;
    }
    slot = (slot + 1) % set_size;
    if (slot == start_slot) {
      return -1;
    }
  }
  return -1;
}

template<typename T>
__device__ void init_buffer(T init_value, T *buffer, int buffer_size, int num_threads, int thread_id) {
  __syncthreads();
  for (int i = 0; i < buffer_size; i = i + num_threads) {
    int offset_idx = i + thread_id;
    if (offset_idx < buffer_size) {
      buffer[offset_idx] = init_value;
    }
  }
  __syncthreads();
}

template<typename T>
__device__ void copy_data(T *src_pt, T *dist_pt, int data_length, int num_threads, int thread_id) {
  __syncthreads();
  for (int i = 0; i < data_length; i = i + num_threads) {
    int offset_idx = i + thread_id;
    if (offset_idx < data_length) {
      dist_pt[offset_idx] = src_pt[offset_idx];
    }
  }
  __syncthreads();
}

template<typename T>
__device__ void init_buffer_nonblocking(T init_value, T *buffer, int buffer_size, int num_threads, int thread_id) {
  for (int i = 0; i < buffer_size; i = i + num_threads) {
    int offset_idx = i + thread_id;
    if (offset_idx < buffer_size) {
      buffer[offset_idx] = init_value;
    }
  }
}

template<typename T>
__device__ void copy_data_nonblocking(T *src_pt, T *dist_pt, int data_length, int num_threads, int thread_id) {
  for (int i = 0; i < data_length; i = i + num_threads) {
    int offset_idx = i + thread_id;
    if (offset_idx < data_length) {
      dist_pt[offset_idx] = src_pt[offset_idx];
    }
  }
}
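
`set_insert` and `set_lookup` above implement a fixed-capacity open-addressing hash set with linear probing; off the GPU the `atomicCAS` step reduces to a plain compare-and-set. A sequential Python sketch of the same probing logic (illustrative only):

```python
EMPTY_VALUE = -1  # mirrors the sentinel in common.h


def set_insert(table, value):
    """Insert `value`, returning its slot, or -1 once every slot has been probed."""
    size = len(table)
    slot = value % size
    start_slot = slot
    while True:
        if table[slot] in (EMPTY_VALUE, value):  # free slot or already present
            table[slot] = value
            return slot
        slot = (slot + 1) % size                 # linear probing
        if slot == start_slot:
            return -1                            # table is full


table = [EMPTY_VALUE] * 8
assert set_insert(table, 21) == 21 % 8
```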
src/transformers/models/yoso/configuration_yoso.py (new file, 145 lines)
@@ -0,0 +1,145 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" YOSO model configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)

YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "uw-madison/yoso-4096": "https://huggingface.co/uw-madison/yoso-4096/resolve/main/config.json",
    # See all YOSO models at https://huggingface.co/models?filter=yoso
}


class YosoConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`YosoModel`]. It is used to instantiate a YOSO
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the YOSO
    [uw-madison/yoso-4096](https://huggingface.co/uw-madison/yoso-4096) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 50265):
            Vocabulary size of the YOSO model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`YosoModel`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimension of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 4096):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        type_vocab_size (`int`, *optional*, defaults to 1):
            The vocabulary size of the `token_type_ids` passed when calling [`YosoModel`].
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
            Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`.
        use_expectation (`bool`, *optional*, defaults to `True`):
            Whether or not to use YOSO Expectation. Overrides any effect of `num_hash`.
        hash_code_len (`int`, *optional*, defaults to 9):
            The length of hashes generated by the hash functions.
        num_hash (`int`, *optional*, defaults to 64):
            Number of hash functions used in [`YosoSelfAttention`].
        conv_window (`int`, *optional*, defaults to `None`):
            Kernel size of depth-wise convolution.
        use_fast_hash (`bool`, *optional*, defaults to `True`):
            Whether or not to use custom CUDA kernels which perform fast random projection via the Hadamard transform.
        lsh_backward (`bool`, *optional*, defaults to `True`):
            Whether or not to perform backpropagation using Locality Sensitive Hashing.

    Example:

    ```python
    >>> from transformers import YosoModel, YosoConfig

    >>> # Initializing a YOSO uw-madison/yoso-4096 style configuration
    >>> configuration = YosoConfig()

    >>> # Initializing a model from the uw-madison/yoso-4096 style configuration
    >>> model = YosoModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "yoso"

    def __init__(
        self,
        vocab_size=50265,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=4096,
        type_vocab_size=1,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        position_embedding_type="absolute",
        use_expectation=True,
        hash_code_len=9,
        num_hash=64,
        conv_window=None,
        use_fast_hash=True,
        lsh_backward=True,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        **kwargs
    ):
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.type_vocab_size = type_vocab_size
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.use_expectation = use_expectation
        self.hash_code_len = hash_code_len
        self.num_hash = num_hash
        self.conv_window = conv_window
        self.use_fast_hash = use_fast_hash
        self.lsh_backward = lsh_backward
src/transformers/models/yoso/convert_yoso_pytorch_to_pytorch.py (new file, 109 lines)
@@ -0,0 +1,109 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert YOSO checkpoints from the original repository. URL: https://github.com/mlpen/YOSO"""

import argparse

import torch

from transformers import YosoConfig, YosoForMaskedLM


def rename_key(orig_key):
    if "model" in orig_key:
        orig_key = orig_key.replace("model.", "")
    if "norm1" in orig_key:
        orig_key = orig_key.replace("norm1", "attention.output.LayerNorm")
    if "norm2" in orig_key:
        orig_key = orig_key.replace("norm2", "output.LayerNorm")
    if "norm" in orig_key:
        orig_key = orig_key.replace("norm", "LayerNorm")
    if "transformer" in orig_key:
        layer_num = orig_key.split(".")[0].split("_")[-1]
        orig_key = orig_key.replace(f"transformer_{layer_num}", f"encoder.layer.{layer_num}")
    if "mha.attn" in orig_key:
        orig_key = orig_key.replace("mha.attn", "attention.self")
    if "mha" in orig_key:
        orig_key = orig_key.replace("mha", "attention")
    if "W_q" in orig_key:
        orig_key = orig_key.replace("W_q", "self.query")
    if "W_k" in orig_key:
        orig_key = orig_key.replace("W_k", "self.key")
    if "W_v" in orig_key:
        orig_key = orig_key.replace("W_v", "self.value")
    if "ff1" in orig_key:
        orig_key = orig_key.replace("ff1", "intermediate.dense")
    if "ff2" in orig_key:
        orig_key = orig_key.replace("ff2", "output.dense")
    if "ff" in orig_key:
        orig_key = orig_key.replace("ff", "output.dense")
    if "mlm_class" in orig_key:
        orig_key = orig_key.replace("mlm.mlm_class", "cls.predictions.decoder")
    if "mlm" in orig_key:
        orig_key = orig_key.replace("mlm", "cls.predictions.transform")
    if "cls" not in orig_key:
        orig_key = "yoso." + orig_key

    return orig_key


def convert_checkpoint_helper(max_position_embeddings, orig_state_dict):
    for key in orig_state_dict.copy().keys():
        val = orig_state_dict.pop(key)

        if ("pooler" in key) or ("sen_class" in key):
            continue
        else:
            orig_state_dict[rename_key(key)] = val

    orig_state_dict["cls.predictions.bias"] = orig_state_dict["cls.predictions.decoder.bias"]
    orig_state_dict["yoso.embeddings.position_ids"] = torch.arange(max_position_embeddings).expand((1, -1)) + 2

    return orig_state_dict


def convert_yoso_checkpoint(checkpoint_path, yoso_config_file, pytorch_dump_path):

    orig_state_dict = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]
    config = YosoConfig.from_json_file(yoso_config_file)
    model = YosoForMaskedLM(config)

    new_state_dict = convert_checkpoint_helper(config.max_position_embeddings, orig_state_dict)

    print(model.load_state_dict(new_state_dict))
    model.eval()
    model.save_pretrained(pytorch_dump_path)

    print(f"Checkpoint successfully converted. Model saved at {pytorch_dump_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--pytorch_model_path", default=None, type=str, required=True, help="Path to YOSO pytorch checkpoint."
    )
    parser.add_argument(
        "--config_file",
        default=None,
        type=str,
        required=True,
        help="The json file for YOSO model config.",
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_yoso_checkpoint(args.pytorch_model_path, args.config_file, args.pytorch_dump_path)
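
For reference, the script is invoked as `python convert_yoso_pytorch_to_pytorch.py --pytorch_model_path <checkpoint> --config_file <config.json> --pytorch_dump_path <output_dir>`; the resulting dump directory can then be loaded with `YosoForMaskedLM.from_pretrained`.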
src/transformers/models/yoso/fast_lsh_cumulation.cu (new file, 588 lines)
@@ -0,0 +1,588 @@
// File from https://github.com/mlpen/YOSO/blob/main/encoders/backbones/efficient_attentions/yoso/yoso_v1/cuda/fast_lsh_cumulation.cu

#include <torch/extension.h>
#include <ATen/ATen.h>
#include "fast_lsh_cumulation.h"
#include "fast_lsh_cumulation_cuda.h"
#include "common_cuda.h"
#include "common.h"
#include <vector>
//////////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////////

std::vector<at::Tensor> fast_hash_ver1_kernel(
  at::Tensor query_mask,
  at::Tensor query_vector,
  at::Tensor key_mask,
  at::Tensor key_vector,
  int num_hash_f,
  int hash_code_len,
  bool use_cuda
) {

  int batch_size = query_vector.size(0);
  int num_query = query_vector.size(1);
  int num_key = key_vector.size(1);
  int vector_dim = query_vector.size(2);

  int num_hash_per_part = vector_dim / hash_code_len;
  int num_part = max(1, ceil_divide(num_hash_f, num_hash_per_part));

  at::Tensor Dmat = 2 * at::randint(0, 2, {batch_size, 3, num_part, vector_dim}, query_mask.options()) - 1;
  at::Tensor query_hash_code = at::zeros({batch_size, num_query, num_hash_f}, query_mask.options());
  at::Tensor key_hash_code = at::zeros({batch_size, num_key, num_hash_f}, key_mask.options());

  int *query_mask_ptr = query_mask.data_ptr<int>();
  float *query_vector_ptr = query_vector.data_ptr<float>();
  int *key_mask_ptr = key_mask.data_ptr<int>();
  float *key_vector_ptr = key_vector.data_ptr<float>();

  int *Dmat_ptr = Dmat.data_ptr<int>();

  int *query_hash_code_ptr = query_hash_code.data_ptr<int>();
  int *key_hash_code_ptr = key_hash_code.data_ptr<int>();

  if (use_cuda) {
    {
      dim3 threads(vector_dim);
      dim3 blocks(num_part, num_query, batch_size);
      int shared_mem = vector_dim * sizeof(float);
      fast_hash_ver1_cuda_kernel<<<blocks, threads, shared_mem>>>(
        query_mask_ptr,
        query_vector_ptr,
        Dmat_ptr,
        query_hash_code_ptr,
        batch_size,
        num_query,
        vector_dim,
        num_part,
        num_hash_f,
        hash_code_len
      );
    }
    {
      dim3 threads(vector_dim);
      dim3 blocks(num_part, num_key, batch_size);
      int shared_mem = vector_dim * sizeof(float);
      fast_hash_ver1_cuda_kernel<<<blocks, threads, shared_mem>>>(
        key_mask_ptr,
        key_vector_ptr,
        Dmat_ptr,
        key_hash_code_ptr,
        batch_size,
        num_key,
        vector_dim,
        num_part,
        num_hash_f,
        hash_code_len
      );
    }
  }

  return {query_hash_code, key_hash_code};

}
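
A rough NumPy sketch of the hashing idea behind `fast_hash_ver1_cuda_kernel`, assuming a single part and one random sign diagonal (the real kernel cycles through the diagonals in `Dmat` and works in shared memory; `scipy.linalg.hadamard` stands in for the in-kernel fast transform):

```python
import numpy as np
from scipy.linalg import hadamard


def fast_hash(vectors, hash_code_len, rng):
    """One LSH code per row of `vectors`; dim must be a power of two."""
    n, dim = vectors.shape
    d = rng.choice([-1.0, 1.0], size=dim)          # random sign diagonal (Dmat's role)
    projected = (vectors * d) @ hadamard(dim)      # randomized Hadamard projection
    bits = projected[:, :hash_code_len] > 0        # sign bits become the hash code
    return bits.astype(int) @ (1 << np.arange(hash_code_len))  # pack bits to int


rng = np.random.default_rng(0)
codes = fast_hash(rng.standard_normal((4, 64)), hash_code_len=9, rng=rng)
```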

at::Tensor lsh_cumulation_ver1_kernel(
  at::Tensor query_mask,
  at::Tensor query_hash_code,
  at::Tensor key_mask,
  at::Tensor key_hash_code,
  at::Tensor value,
  int hashtable_capacity,
  bool use_cuda
) {

  int batch_size = query_hash_code.size(0);
  int num_hash_f = query_hash_code.size(2);

  int num_query = query_hash_code.size(1);
  int num_key = key_hash_code.size(1);
  int value_dim = value.size(2);

  at::Tensor hashtable_value = at::empty({batch_size, num_hash_f, hashtable_capacity, WARP_SIZE}, value.options());
  at::Tensor cumulation_value = at::zeros({batch_size, num_query, value_dim}, value.options());

  if (use_cuda) {
    int threads_x = WARP_SIZE;
    int threads_y = OPTIMAL_THREADS_PER_BLOCK / WARP_SIZE;
    int block_x_step1 = num_key / threads_y;
    int block_x_step2 = num_query / threads_y;
    int block_y = batch_size;

    dim3 threads(threads_x, threads_y);
    dim3 blocks_step1(block_x_step1, block_y);
    dim3 blocks_step2(block_x_step2, block_y);

    int *query_mask_ptr = query_mask.data_ptr<int>();
    int *query_hash_code_ptr = query_hash_code.data_ptr<int>();
    int *key_mask_ptr = key_mask.data_ptr<int>();
    int *key_hash_code_ptr = key_hash_code.data_ptr<int>();
    float *value_ptr = value.data_ptr<float>();
    float *hashtable_value_ptr = hashtable_value.data_ptr<float>();
    float *cumulation_value_ptr = cumulation_value.data_ptr<float>();

    for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) {

      cudaMemset(hashtable_value_ptr, 0, (batch_size * num_hash_f * hashtable_capacity * WARP_SIZE) * sizeof(float));

      lsh_cumulation_ver1_step1_cuda_kernel<<<blocks_step1, threads>>>(
        key_mask_ptr,
        key_hash_code_ptr,
        value_ptr,
        hashtable_value_ptr,
        batch_size,
        num_hash_f,
        hashtable_capacity,
        num_key,
        value_dim,
        value_offset
      );

      lsh_cumulation_ver1_step2_cuda_kernel<<<blocks_step2, threads>>>(
        query_mask_ptr,
        query_hash_code_ptr,
        hashtable_value_ptr,
        cumulation_value_ptr,
        batch_size,
        num_hash_f,
        hashtable_capacity,
        num_query,
        value_dim,
        value_offset
      );
    }

  }

  return cumulation_value;

}
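
Conceptually, the two launches above scatter key values into a per-hash-function table and gather them back per query. A sequential sketch of that computation (hypothetical helper; the averaging over hash functions is an assumption of this sketch):

```python
import numpy as np


def lsh_cumulation(query_codes, key_codes, value, hashtable_capacity):
    """query_codes/key_codes: (num_hash_f, n) int arrays; value: (num_key, dim)."""
    num_hash_f, num_query = query_codes.shape
    out = np.zeros((num_query, value.shape[1]))
    for f in range(num_hash_f):
        table = np.zeros((hashtable_capacity, value.shape[1]))
        np.add.at(table, key_codes[f], value)  # step 1: scatter keys into buckets
        out += table[query_codes[f]]           # step 2: gather each query's bucket
    return out / num_hash_f
```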
|
||||
|
||||
at::Tensor lsh_weighted_cumulation_ver1_kernel(
|
||||
at::Tensor query_mask,
|
||||
at::Tensor query_hash_code,
|
||||
at::Tensor query_weight,
|
||||
at::Tensor key_mask,
|
||||
at::Tensor key_hash_code,
|
||||
at::Tensor key_weight,
|
||||
at::Tensor value,
|
||||
int hashtable_capacity,
|
||||
bool use_cuda
|
||||
) {
|
||||
|
||||
int batch_size = query_hash_code.size(0);
|
||||
int num_hash_f = query_hash_code.size(2);
|
||||
|
||||
int num_query = query_hash_code.size(1);
|
||||
int num_key = key_hash_code.size(1);
|
||||
int value_dim = value.size(2);
|
||||
int weight_dim = query_weight.size(2);
|
||||
|
||||
at::Tensor hashtable_value = at::zeros({batch_size, num_hash_f, hashtable_capacity, WARP_SIZE}, value.options());
|
||||
at::Tensor cumulation_value = at::zeros({batch_size, num_query, value_dim}, value.options());
|
||||
|
||||
if (use_cuda) {
|
||||
int threads_x = WARP_SIZE;
|
||||
int threads_y = OPTIMAL_THREADS_PER_BLOCK / WARP_SIZE;
|
||||
int block_x_step1 = num_key / threads_y;
|
||||
int block_x_step2 = num_query / threads_y;
|
||||
int block_y = batch_size;
|
||||
|
||||
dim3 threads(threads_x, threads_y);
|
||||
dim3 blocks_step1(block_x_step1, block_y);
|
||||
dim3 blocks_step2(block_x_step2, block_y);
|
||||
|
||||
int *query_mask_ptr = query_mask.data_ptr<int>();
|
||||
int *query_hash_code_ptr = query_hash_code.data_ptr<int>();
|
||||
float *query_weight_ptr = query_weight.data_ptr<float>();
|
||||
int *key_mask_ptr = key_mask.data_ptr<int>();
|
||||
int *key_hash_code_ptr = key_hash_code.data_ptr<int>();
|
||||
float *key_weight_ptr = key_weight.data_ptr<float>();
|
||||
float *value_ptr = value.data_ptr<float>();
|
||||
float *hashtable_value_ptr = hashtable_value.data_ptr<float>();
|
||||
float *cumulation_value_ptr = cumulation_value.data_ptr<float>();
|
||||
|
||||
for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) {
|
||||
for (int weight_idx = 0; weight_idx < weight_dim; weight_idx++) {
|
||||
|
||||
cudaMemset(hashtable_value_ptr, 0, (batch_size * num_hash_f * hashtable_capacity * WARP_SIZE) * sizeof(float));
|
||||
|
||||
lsh_weighted_cumulation_ver1_step1_cuda_kernel<<<blocks_step1, threads>>>(
|
||||
key_mask_ptr,
|
||||
key_hash_code_ptr,
|
||||
key_weight_ptr,
|
||||
value_ptr,
|
||||
hashtable_value_ptr,
|
||||
batch_size,
|
||||
num_hash_f,
|
||||
hashtable_capacity,
|
||||
num_key,
|
||||
value_dim,
|
||||
weight_dim,
|
||||
value_offset,
|
||||
weight_idx
|
||||
);
|
||||
|
||||
lsh_weighted_cumulation_ver1_step2_cuda_kernel<<<blocks_step2, threads>>>(
|
||||
query_mask_ptr,
|
||||
query_hash_code_ptr,
|
||||
query_weight_ptr,
|
||||
hashtable_value_ptr,
|
||||
cumulation_value_ptr,
|
||||
batch_size,
|
||||
num_hash_f,
|
||||
hashtable_capacity,
|
||||
num_query,
|
||||
value_dim,
|
||||
weight_dim,
|
||||
value_offset,
|
||||
weight_idx
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return cumulation_value;
|
||||
|
||||
}
|
||||
|
||||
at::Tensor lsh_weighted_cumulation_ver2_kernel(
|
||||
at::Tensor query_mask,
|
||||
at::Tensor query_hash_code,
|
||||
at::Tensor query_weight,
|
||||
at::Tensor key_mask,
|
||||
at::Tensor key_hash_code,
|
||||
at::Tensor key_weight,
|
||||
at::Tensor value,
|
||||
int hashtable_capacity,
|
||||
bool use_cuda
|
||||
) {
|
||||
|
||||
int batch_size = query_hash_code.size(0);
|
||||
int num_hash_f = query_hash_code.size(2);
|
||||
|
||||
int num_query = query_hash_code.size(1);
|
||||
int num_key = key_hash_code.size(1);
|
||||
int value_dim = value.size(2);
|
||||
int weight_dim = query_weight.size(2);
|
||||
|
||||
at::Tensor count_sort_table = at::zeros({batch_size, num_hash_f, hashtable_capacity}, query_hash_code.options());
|
||||
at::Tensor key_sorted_idxes = at::zeros({batch_size, num_hash_f, num_key}, query_hash_code.options());
|
||||
at::Tensor query_info = at::zeros({batch_size, num_query, 2, num_hash_f}, query_hash_code.options());
|
||||
at::Tensor cumulation_value = at::zeros({batch_size, num_query, value_dim}, value.options());
|
||||
|
||||
if (use_cuda) {
|
||||
|
||||
int *query_mask_ptr = query_mask.data_ptr<int>();
|
||||
int *query_hash_code_ptr = query_hash_code.data_ptr<int>();
|
||||
float *query_weight_ptr = query_weight.data_ptr<float>();
|
||||
int *key_mask_ptr = key_mask.data_ptr<int>();
|
||||
int *key_hash_code_ptr = key_hash_code.data_ptr<int>();
|
||||
float *key_weight_ptr = key_weight.data_ptr<float>();
|
||||
float *value_ptr = value.data_ptr<float>();
|
||||
|
||||
int *count_sort_table_ptr = count_sort_table.data_ptr<int>();
|
||||
int *key_sorted_idxes_ptr = key_sorted_idxes.data_ptr<int>();
|
||||
int *query_info_ptr = query_info.data_ptr<int>();
|
||||
|
||||
float *cumulation_value_ptr = cumulation_value.data_ptr<float>();
|
||||
|
||||
{
|
||||
dim3 threads_step13(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f));
|
||||
dim3 blocks_step13(num_key / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size);
|
||||
dim3 threads_step2(min(hashtable_capacity, OPTIMAL_THREADS_PER_BLOCK));
|
||||
dim3 blocks_step2(num_hash_f, batch_size);
|
||||
int shared_mem = hashtable_capacity * sizeof(float);
|
||||
count_sort_step1_cuda_kernel<<<blocks_step13, threads_step13>>>(
|
||||
key_mask_ptr,
|
||||
key_hash_code_ptr,
|
||||
count_sort_table_ptr,
|
||||
batch_size,
|
||||
num_hash_f,
|
||||
hashtable_capacity,
|
||||
num_key
|
||||
);
|
||||
count_sort_step2_cuda_kernel<<<blocks_step2, threads_step2, shared_mem>>>(
|
||||
count_sort_table_ptr,
|
||||
batch_size,
|
||||
num_hash_f,
|
||||
hashtable_capacity
|
||||
);
|
||||
count_sort_step3_cuda_kernel<<<blocks_step13, threads_step13>>>(
|
||||
key_mask_ptr,
|
||||
key_hash_code_ptr,
|
||||
count_sort_table_ptr,
|
||||
key_sorted_idxes_ptr,
|
||||
batch_size,
|
||||
num_hash_f,
|
||||
hashtable_capacity,
|
||||
num_key
|
||||
);
|
||||
}
|
||||
{
|
||||
dim3 threads(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f));
|
||||
dim3 blocks(num_query / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size);
|
||||
extract_query_info_cuda_kernel<<<blocks, threads>>>(
|
||||
query_mask_ptr,
|
||||
query_hash_code_ptr,
|
||||
count_sort_table_ptr,
|
||||
query_info_ptr,
|
||||
batch_size,
|
||||
num_hash_f,
|
||||
hashtable_capacity,
|
||||
num_query
|
||||
);
|
||||
}
|
||||
{
|
||||
dim3 threads(WARP_SIZE, OPTIMAL_THREADS_PER_BLOCK / WARP_SIZE);
|
||||
dim3 blocks(num_query, num_hash_f, batch_size);
|
||||
int shared_mem = (weight_dim + WARP_SIZE) * sizeof(float);
|
||||
lsh_weighted_cumulation_ver2_step2_cuda_kernel<<<blocks, threads, shared_mem>>>(
|
||||
query_mask_ptr,
|
||||
query_info_ptr,
|
||||
key_sorted_idxes_ptr,
|
||||
query_weight_ptr,
|
||||
key_weight_ptr,
|
||||
value_ptr,
|
||||
cumulation_value_ptr,
|
||||
batch_size,
|
||||
num_hash_f,
|
||||
num_query,
|
||||
num_key,
|
||||
value_dim,
|
||||
weight_dim
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return cumulation_value;
|
||||
|
||||
}
|
||||
|
||||
at::Tensor lsh_weighted_cumulation_ver3_kernel(
|
||||
at::Tensor query_mask,
|
||||
at::Tensor query_hash_code,
|
||||
at::Tensor query_weight,
|
||||
at::Tensor key_mask,
|
||||
at::Tensor key_hash_code,
|
||||
at::Tensor key_weight,
|
||||
at::Tensor value,
|
||||
int hashtable_capacity,
|
||||
bool use_cuda
|
||||
) {
|
||||
|
||||
int batch_size = query_hash_code.size(0);
|
||||
int num_hash_f = query_hash_code.size(2);
|
||||
|
||||
int num_query = query_hash_code.size(1);
|
||||
int num_key = key_hash_code.size(1);
|
||||
int value_dim = value.size(2);
|
||||
int weight_dim = query_weight.size(2);
|
||||
|
||||
at::Tensor count_sort_table = at::zeros({batch_size, num_hash_f, hashtable_capacity}, query_hash_code.options());
|
||||
at::Tensor query_sorted_idxes = at::zeros({batch_size, num_hash_f, num_query}, query_hash_code.options());
|
||||
at::Tensor key_info = at::zeros({batch_size, num_key, 2, num_hash_f}, query_hash_code.options());
|
||||
at::Tensor cumulation_value = at::zeros({batch_size, num_query, value_dim}, value.options());
|
||||
|
||||
if (use_cuda) {
|
||||
|
||||
int *query_mask_ptr = query_mask.data_ptr<int>();
|
||||
int *query_hash_code_ptr = query_hash_code.data_ptr<int>();
|
||||
float *query_weight_ptr = query_weight.data_ptr<float>();
|
||||
int *key_mask_ptr = key_mask.data_ptr<int>();
|
||||
int *key_hash_code_ptr = key_hash_code.data_ptr<int>();
|
||||
float *key_weight_ptr = key_weight.data_ptr<float>();
|
||||
float *value_ptr = value.data_ptr<float>();
|
||||
|
||||
int *count_sort_table_ptr = count_sort_table.data_ptr<int>();
|
||||
int *query_sorted_idxes_ptr = query_sorted_idxes.data_ptr<int>();
|
||||
int *key_info_ptr = key_info.data_ptr<int>();
|
||||
|
||||
float *cumulation_value_ptr = cumulation_value.data_ptr<float>();
|
||||
|
||||
{
|
||||
dim3 threads_step13(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f));
|
||||
dim3 blocks_step13(num_query / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size);
|
||||
dim3 threads_step2(min(hashtable_capacity, OPTIMAL_THREADS_PER_BLOCK));
|
||||
dim3 blocks_step2(num_hash_f, batch_size);
|
||||
int shared_mem = hashtable_capacity * sizeof(float);
|
||||
count_sort_step1_cuda_kernel<<<blocks_step13, threads_step13>>>(
|
||||
query_mask_ptr,
|
||||
query_hash_code_ptr,
|
||||
count_sort_table_ptr,
|
||||
batch_size,
|
||||
num_hash_f,
|
||||
hashtable_capacity,
|
||||
num_query
|
||||
);
|
||||
count_sort_step2_cuda_kernel<<<blocks_step2, threads_step2, shared_mem>>>(
|
||||
count_sort_table_ptr,
|
||||
batch_size,
|
||||
num_hash_f,
|
||||
hashtable_capacity
|
||||
);
|
||||
count_sort_step3_cuda_kernel<<<blocks_step13, threads_step13>>>(
|
||||
query_mask_ptr,
|
||||
query_hash_code_ptr,
|
||||
count_sort_table_ptr,
|
||||
query_sorted_idxes_ptr,
|
||||
batch_size,
|
||||
num_hash_f,
|
||||
hashtable_capacity,
|
||||
num_query
|
||||
);
|
||||
}
|
||||
{
|
||||
dim3 threads(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f));
|
||||
dim3 blocks(num_key / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size);
|
||||
extract_query_info_cuda_kernel<<<blocks, threads>>>(
|
||||
key_mask_ptr,
|
||||
key_hash_code_ptr,
|
||||
count_sort_table_ptr,
|
||||
key_info_ptr,
|
||||
batch_size,
|
||||
num_hash_f,
|
||||
hashtable_capacity,
|
||||
num_key
|
||||
);
|
||||
}
|
||||
{
|
||||
dim3 threads(WARP_SIZE, OPTIMAL_THREADS_PER_BLOCK / WARP_SIZE);
|
||||
dim3 blocks(num_key, num_hash_f, batch_size);
|
||||
int shared_mem = (weight_dim + value_dim + WARP_SIZE) * sizeof(float);
|
||||
lsh_weighted_cumulation_ver3_step2_cuda_kernel<<<blocks, threads, shared_mem>>>(
|
||||
query_sorted_idxes_ptr,
|
||||
key_mask_ptr,
|
||||
key_info_ptr,
|
||||
query_weight_ptr,
|
||||
key_weight_ptr,
|
||||
value_ptr,
|
||||
cumulation_value_ptr,
|
||||
batch_size,
|
||||
num_hash_f,
|
||||
num_query,
|
||||
num_key,
|
||||
value_dim,
|
||||
weight_dim
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return cumulation_value;
|
||||
|
||||
}

at::Tensor lsh_weighted_cumulation_ver4_kernel(
  at::Tensor query_mask,
  at::Tensor query_hash_code,
  at::Tensor query_weight,
  at::Tensor key_mask,
  at::Tensor key_hash_code,
  at::Tensor key_weight,
  at::Tensor value,
  int hashtable_capacity,
  bool use_cuda
) {

  int batch_size = query_hash_code.size(0);
  int num_hash_f = query_hash_code.size(2);

  int num_query = query_hash_code.size(1);
  int num_key = key_hash_code.size(1);
  int value_dim = value.size(2);
  int weight_dim = query_weight.size(2);

  at::Tensor count_sort_table = at::zeros({batch_size, num_hash_f, hashtable_capacity}, query_hash_code.options());
  at::Tensor query_sorted_idxes = at::zeros({batch_size, num_hash_f, num_query}, query_hash_code.options());
  at::Tensor key_info = at::zeros({batch_size, num_key, 2, num_hash_f}, query_hash_code.options());
  at::Tensor cumulation_value = at::zeros({batch_size, num_query, value_dim}, value.options());

  if (use_cuda) {

    int *query_mask_ptr = query_mask.data_ptr<int>();
    int *query_hash_code_ptr = query_hash_code.data_ptr<int>();
    float *query_weight_ptr = query_weight.data_ptr<float>();
    int *key_mask_ptr = key_mask.data_ptr<int>();
    int *key_hash_code_ptr = key_hash_code.data_ptr<int>();
    float *key_weight_ptr = key_weight.data_ptr<float>();
    float *value_ptr = value.data_ptr<float>();

    int *count_sort_table_ptr = count_sort_table.data_ptr<int>();
    int *query_sorted_idxes_ptr = query_sorted_idxes.data_ptr<int>();
    int *key_info_ptr = key_info.data_ptr<int>();

    float *cumulation_value_ptr = cumulation_value.data_ptr<float>();

    {
      dim3 threads_step13(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f));
      dim3 blocks_step13(num_query / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size);
      dim3 threads_step2(min(hashtable_capacity, OPTIMAL_THREADS_PER_BLOCK));
      dim3 blocks_step2(num_hash_f, batch_size);
      int shared_mem = hashtable_capacity * sizeof(float);
      count_sort_step1_cuda_kernel<<<blocks_step13, threads_step13>>>(
        query_mask_ptr,
        query_hash_code_ptr,
        count_sort_table_ptr,
        batch_size,
        num_hash_f,
        hashtable_capacity,
        num_query
      );
      count_sort_step2_cuda_kernel<<<blocks_step2, threads_step2, shared_mem>>>(
        count_sort_table_ptr,
        batch_size,
        num_hash_f,
        hashtable_capacity
      );
      count_sort_step3_cuda_kernel<<<blocks_step13, threads_step13>>>(
        query_mask_ptr,
        query_hash_code_ptr,
        count_sort_table_ptr,
        query_sorted_idxes_ptr,
        batch_size,
        num_hash_f,
        hashtable_capacity,
        num_query
      );
    }
    {
      dim3 threads(num_hash_f, max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f));
      dim3 blocks(num_key / max(1, OPTIMAL_THREADS_PER_BLOCK / num_hash_f), batch_size);
      extract_query_info_cuda_kernel<<<blocks, threads>>>(
        key_mask_ptr,
        key_hash_code_ptr,
        count_sort_table_ptr,
        key_info_ptr,
        batch_size,
        num_hash_f,
        hashtable_capacity,
        num_key
      );
    }
    {
      dim3 threads(WARP_SIZE, OPTIMAL_THREADS_PER_BLOCK / WARP_SIZE);
      dim3 blocks(num_key, batch_size);
      int shared_mem = (weight_dim + value_dim + 2 * num_hash_f) * sizeof(float);
      lsh_weighted_cumulation_ver4_step2_cuda_kernel<<<blocks, threads, shared_mem>>>(
        query_sorted_idxes_ptr,
        key_mask_ptr,
        key_info_ptr,
        query_weight_ptr,
        key_weight_ptr,
        value_ptr,
        cumulation_value_ptr,
        batch_size,
        num_hash_f,
        num_query,
        num_key,
        value_dim,
        weight_dim
      );
    }
  }

  return cumulation_value;
}

src/transformers/models/yoso/fast_lsh_cumulation.h (new file, 71 lines)
@@ -0,0 +1,71 @@
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <vector>

std::vector<at::Tensor> fast_hash_ver1_kernel(
  at::Tensor query_mask,
  at::Tensor query_vector,
  at::Tensor key_mask,
  at::Tensor key_vector,
  int num_hash_f,
  int hash_code_len,
  bool use_cuda
);

at::Tensor lsh_cumulation_ver1_kernel(
  at::Tensor query_mask,
  at::Tensor query_hash_code,
  at::Tensor key_mask,
  at::Tensor key_hash_code,
  at::Tensor value,
  int hashtable_capacity,
  bool use_cuda
);

at::Tensor lsh_weighted_cumulation_ver1_kernel(
  at::Tensor query_mask,
  at::Tensor query_hash_code,
  at::Tensor query_weight,
  at::Tensor key_mask,
  at::Tensor key_hash_code,
  at::Tensor key_weight,
  at::Tensor value,
  int hashtable_capacity,
  bool use_cuda
);

at::Tensor lsh_weighted_cumulation_ver2_kernel(
  at::Tensor query_mask,
  at::Tensor query_hash_code,
  at::Tensor query_weight,
  at::Tensor key_mask,
  at::Tensor key_hash_code,
  at::Tensor key_weight,
  at::Tensor value,
  int hashtable_capacity,
  bool use_cuda
);

at::Tensor lsh_weighted_cumulation_ver3_kernel(
  at::Tensor query_mask,
  at::Tensor query_hash_code,
  at::Tensor query_weight,
  at::Tensor key_mask,
  at::Tensor key_hash_code,
  at::Tensor key_weight,
  at::Tensor value,
  int hashtable_capacity,
  bool use_cuda
);

at::Tensor lsh_weighted_cumulation_ver4_kernel(
  at::Tensor query_mask,
  at::Tensor query_hash_code,
  at::Tensor query_weight,
  at::Tensor key_mask,
  at::Tensor key_hash_code,
  at::Tensor key_weight,
  at::Tensor value,
  int hashtable_capacity,
  bool use_cuda
);

src/transformers/models/yoso/fast_lsh_cumulation_cuda.cu (new file, 825 lines)
@@ -0,0 +1,825 @@
// File from https://github.com/mlpen/YOSO/blob/main/encoders/backbones/efficient_attentions/yoso/yoso_v1/cuda/fast_lsh_cumulation_cuda.cu

#include "fast_lsh_cumulation_cuda.h"
#include "common_cuda_device.h"
#include "common_cuda.h"
#include "common.h"
#include <stdio.h>
//////////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////////
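
// fast_hadamard_transform: in-place Walsh-Hadamard butterfly over a vector held in shared
// memory. Strides larger than half a warp go through shared memory with __syncthreads();
// the final log2(WARP_SIZE) stages stay in registers via __shfl_xor_sync.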

inline __device__ void fast_hadamard_transform(float *vector_buffer, int vector_dim, int dim_idx) {
  int stride = vector_dim / 2;
  while (stride > (WARP_SIZE / 2)) {
    __syncthreads();
    int sign = 1 - ((dim_idx / stride) % 2) * 2;
    float val1 = vector_buffer[dim_idx];
    float val2 = vector_buffer[dim_idx + sign * stride];
    __syncthreads();
    vector_buffer[dim_idx] = float(sign) * val1 + val2;
    stride = stride / 2;
  }

  float val = vector_buffer[dim_idx];
  #pragma unroll
  for (stride = (WARP_SIZE / 2); stride > 0; stride = stride / 2) {
    int sign = 1 - ((dim_idx / stride) % 2) * 2;
    val = float(sign) * val + __shfl_xor_sync(FULL_MASK, val, stride);
  }
  vector_buffer[dim_idx] = val;
}
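
// fast_hash_ver1_cuda_kernel: computes LSH codes by multiplying each masked vector with three
// random +/-1 diagonal matrices (Dmat), interleaved with Hadamard transforms, then packing the
// sign bits of hash_code_len consecutive entries into one integer code per hash function.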

__global__ void fast_hash_ver1_cuda_kernel(
  int *mask,        // [batch_size, num_vector]
  float *vector,    // [batch_size, num_vector, vector_dim]
  int *Dmat,        // [batch_size, 3, num_part, vector_dim]
  int *hash_code,   // [batch_size, num_vector, num_hash_f]
  int batch_size,
  int num_vector,
  int vector_dim,
  int num_part,
  int num_hash_f,
  int hash_code_len
) {

  int batch_idx = blockIdx.z;
  int vector_idx = blockIdx.y;
  int part_idx = blockIdx.x;

  int dim_idx = threadIdx.x;

  int batch_idx__vector_idx = batch_idx * num_vector + vector_idx;
  if (mask[batch_idx__vector_idx] == 0) {
    return;
  }

  extern __shared__ float buffer[];
  float *vector_buffer = buffer;

  vector_buffer[dim_idx] = vector[batch_idx__vector_idx * vector_dim + dim_idx];

  vector_buffer[dim_idx] = vector_buffer[dim_idx] * (float)Dmat[((batch_idx * 3 + 0) * num_part + part_idx) * vector_dim + dim_idx];
  fast_hadamard_transform(vector_buffer, vector_dim, dim_idx);
  vector_buffer[dim_idx] = vector_buffer[dim_idx] * (float)Dmat[((batch_idx * 3 + 1) * num_part + part_idx) * vector_dim + dim_idx];
  fast_hadamard_transform(vector_buffer, vector_dim, dim_idx);
  vector_buffer[dim_idx] = vector_buffer[dim_idx] * (float)Dmat[((batch_idx * 3 + 2) * num_part + part_idx) * vector_dim + dim_idx];
  fast_hadamard_transform(vector_buffer, vector_dim, dim_idx);

  int num_hash_per_part = vector_dim / hash_code_len;
  if (hash_code_len == 8 || hash_code_len == 16) {
    int code = select(vector_buffer[dim_idx] > 0, 1 << (dim_idx % hash_code_len), 0);
    for (int offset = 1; offset < hash_code_len; offset = offset * 2) {
      code += __shfl_xor_sync(FULL_MASK, code, offset);
    }
    if (dim_idx % hash_code_len == 0) {
      int hash_f_idx = part_idx * num_hash_per_part + dim_idx / hash_code_len;
      if (hash_f_idx < num_hash_f) {
        hash_code[batch_idx__vector_idx * num_hash_f + hash_f_idx] = code;
      }
    }
  } else {
    vector_buffer[dim_idx] = select(vector_buffer[dim_idx] > 0, 1 << (dim_idx % hash_code_len), 0);
    __syncthreads();
    if (dim_idx < num_hash_per_part) {
      int code = 0;
      for (int i = 0; i < hash_code_len; i++) {
        code += vector_buffer[dim_idx * hash_code_len + i];
      }
      int hash_f_idx = part_idx * num_hash_per_part + dim_idx;
      if (hash_f_idx < num_hash_f) {
        hash_code[batch_idx__vector_idx * num_hash_f + hash_f_idx] = code;
      }
    }
  }
}
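
// lsh_cumulation_ver1_step1: scatter phase. Each key adds a WARP_SIZE-wide slice of its value
// vector (starting at offset_warp) into the hashtable bucket addressed by its hash code, once
// per hash function, using atomicAdd to resolve collisions.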

__global__ void lsh_cumulation_ver1_step1_cuda_kernel(
  int *key_mask,           // [batch_size, num_key]
  int *key_hash_code,      // [batch_size, num_key, num_hash_f]
  float *value,            // [batch_size, num_key, value_dim]
  float *hashtable_value,  // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_key,
  int value_dim,
  int offset_warp
) {

  int warp_thread_idx = threadIdx.x;

  int batch_idx = blockIdx.y;
  int key_idx = blockIdx.x * blockDim.y + threadIdx.y;

  int batch_idx__key_idx = batch_idx * num_key + key_idx;
  if (key_mask[batch_idx__key_idx] == 0) {
    return;
  }

  if (num_hash_f > WARP_SIZE) {
    float warp_value = value[batch_idx__key_idx * value_dim + offset_warp + warp_thread_idx];
    for (int hash_f_start = 0; hash_f_start < num_hash_f; hash_f_start = hash_f_start + WARP_SIZE) {
      int warp_hashcode = key_hash_code[batch_idx__key_idx * num_hash_f + hash_f_start + warp_thread_idx];
      #pragma unroll
      for (int hash_f_offset = 0; hash_f_offset < WARP_SIZE; hash_f_offset++) {
        int current_hashcode = warp_hashcode;
        current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_offset);
        int hashtable_idx = (batch_idx * num_hash_f + (hash_f_start + hash_f_offset)) * hashtable_capacity + current_hashcode;
        atomicAdd(&hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx], warp_value);
      }
    }
  } else {
    float warp_value = value[batch_idx__key_idx * value_dim + offset_warp + warp_thread_idx];
    int warp_hashcode = 0;
    if (warp_thread_idx < num_hash_f) {
      warp_hashcode = key_hash_code[batch_idx__key_idx * num_hash_f + warp_thread_idx];
    }
    for (int hash_f_idx = 0; hash_f_idx < num_hash_f; hash_f_idx++) {
      int current_hashcode = warp_hashcode;
      current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_idx);
      int hashtable_idx = (batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + current_hashcode;
      atomicAdd(&hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx], warp_value);
    }
  }

}
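
// lsh_cumulation_ver1_step2: gather phase. Each query reads the buckets addressed by its own
// hash codes and averages the accumulated values over the num_hash_f hash functions.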

__global__ void lsh_cumulation_ver1_step2_cuda_kernel(
  int *query_mask,         // [batch_size, num_query]
  int *query_hash_code,    // [batch_size, num_query, num_hash_f]
  float *hashtable_value,  // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE]
  float *cumulation_value, // [batch_size, num_query, value_dim]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_query,
  int value_dim,
  int offset_warp
) {

  int warp_thread_idx = threadIdx.x;

  int batch_idx = blockIdx.y;
  int query_idx = blockIdx.x * blockDim.y + threadIdx.y;

  int batch_idx__query_idx = batch_idx * num_query + query_idx;
  if (query_mask[batch_idx__query_idx] == 0) {
    return;
  }

  if (num_hash_f > WARP_SIZE) {
    float warp_value = 0;
    for (int hash_f_start = 0; hash_f_start < num_hash_f; hash_f_start = hash_f_start + WARP_SIZE) {
      int warp_hashcode = query_hash_code[batch_idx__query_idx * num_hash_f + hash_f_start + warp_thread_idx];
      #pragma unroll
      for (int hash_f_offset = 0; hash_f_offset < WARP_SIZE; hash_f_offset++) {
        int current_hashcode = warp_hashcode;
        current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_offset);
        int hashtable_idx = (batch_idx * num_hash_f + (hash_f_start + hash_f_offset)) * hashtable_capacity + current_hashcode;
        warp_value = warp_value + hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx];
      }
    }
    cumulation_value[batch_idx__query_idx * value_dim + offset_warp + warp_thread_idx] = warp_value / float(num_hash_f);
  } else {
    float warp_value = 0;
    int warp_hashcode = 0;
    if (warp_thread_idx < num_hash_f) {
      warp_hashcode = query_hash_code[batch_idx__query_idx * num_hash_f + warp_thread_idx];
    }
    for (int hash_f_idx = 0; hash_f_idx < num_hash_f; hash_f_idx++) {
      int current_hashcode = warp_hashcode;
      current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_idx);
      int hashtable_idx = (batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + current_hashcode;
      warp_value = warp_value + hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx];
    }
    cumulation_value[batch_idx__query_idx * value_dim + offset_warp + warp_thread_idx] = warp_value / float(num_hash_f);
  }

}
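
// Weighted variants of the two kernels above, parameterized by a single weight index
// (weight_idx), which suggests the host launches them once per weight dimension: step 1
// scatters key_weight[..., weight_idx] * value, and step 2 scales the gathered average by
// query_weight[..., weight_idx] before accumulating into cumulation_value.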

__global__ void lsh_weighted_cumulation_ver1_step1_cuda_kernel(
  int *key_mask,           // [batch_size, num_key]
  int *key_hash_code,      // [batch_size, num_key, num_hash_f]
  float *key_weight,       // [batch_size, num_key, weight_dim]
  float *value,            // [batch_size, num_key, value_dim]
  float *hashtable_value,  // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_key,
  int value_dim,
  int weight_dim,
  int offset_warp,
  int weight_idx
) {

  int warp_thread_idx = threadIdx.x;

  int batch_idx = blockIdx.y;
  int key_idx = blockIdx.x * blockDim.y + threadIdx.y;

  int batch_idx__key_idx = batch_idx * num_key + key_idx;
  if (key_mask[batch_idx__key_idx] == 0) {
    return;
  }

  if (num_hash_f > WARP_SIZE) {
    float warp_value = key_weight[batch_idx__key_idx * weight_dim + weight_idx] * value[batch_idx__key_idx * value_dim + offset_warp + warp_thread_idx];
    for (int hash_f_start = 0; hash_f_start < num_hash_f; hash_f_start = hash_f_start + WARP_SIZE) {
      int warp_hashcode = key_hash_code[batch_idx__key_idx * num_hash_f + hash_f_start + warp_thread_idx];
      #pragma unroll
      for (int hash_f_offset = 0; hash_f_offset < WARP_SIZE; hash_f_offset++) {
        int current_hashcode = warp_hashcode;
        current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_offset);
        int hashtable_idx = (batch_idx * num_hash_f + (hash_f_start + hash_f_offset)) * hashtable_capacity + current_hashcode;
        atomicAdd(&hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx], warp_value);
      }
    }
  } else {
    float warp_value = key_weight[batch_idx__key_idx * weight_dim + weight_idx] * value[batch_idx__key_idx * value_dim + offset_warp + warp_thread_idx];
    int warp_hashcode = 0;
    if (warp_thread_idx < num_hash_f) {
      warp_hashcode = key_hash_code[batch_idx__key_idx * num_hash_f + warp_thread_idx];
    }
    for (int hash_f_idx = 0; hash_f_idx < num_hash_f; hash_f_idx++) {
      int current_hashcode = warp_hashcode;
      current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_idx);
      int hashtable_idx = (batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + current_hashcode;
      atomicAdd(&hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx], warp_value);
    }
  }

}

__global__ void lsh_weighted_cumulation_ver1_step2_cuda_kernel(
  int *query_mask,         // [batch_size, num_query]
  int *query_hash_code,    // [batch_size, num_query, num_hash_f]
  float *query_weight,     // [batch_size, num_query, weight_dim]
  float *hashtable_value,  // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE]
  float *cumulation_value, // [batch_size, num_query, value_dim]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_query,
  int value_dim,
  int weight_dim,
  int offset_warp,
  int weight_idx
) {

  int warp_thread_idx = threadIdx.x;

  int batch_idx = blockIdx.y;
  int query_idx = blockIdx.x * blockDim.y + threadIdx.y;

  int batch_idx__query_idx = batch_idx * num_query + query_idx;
  if (query_mask[batch_idx__query_idx] == 0) {
    return;
  }

  if (num_hash_f > WARP_SIZE) {
    float warp_value = 0;
    for (int hash_f_start = 0; hash_f_start < num_hash_f; hash_f_start = hash_f_start + WARP_SIZE) {
      int warp_hashcode = query_hash_code[batch_idx__query_idx * num_hash_f + hash_f_start + warp_thread_idx];
      #pragma unroll
      for (int hash_f_offset = 0; hash_f_offset < WARP_SIZE; hash_f_offset++) {
        int current_hashcode = warp_hashcode;
        current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_offset);
        int hashtable_idx = (batch_idx * num_hash_f + (hash_f_start + hash_f_offset)) * hashtable_capacity + current_hashcode;
        warp_value = warp_value + hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx];
      }
    }
    float warp_weight = query_weight[batch_idx__query_idx * weight_dim + weight_idx];
    cumulation_value[batch_idx__query_idx * value_dim + offset_warp + warp_thread_idx] += warp_weight * warp_value / float(num_hash_f);
  } else {
    float warp_value = 0;
    int warp_hashcode = 0;
    if (warp_thread_idx < num_hash_f) {
      warp_hashcode = query_hash_code[batch_idx__query_idx * num_hash_f + warp_thread_idx];
    }
    for (int hash_f_idx = 0; hash_f_idx < num_hash_f; hash_f_idx++) {
      int current_hashcode = warp_hashcode;
      current_hashcode = __shfl_sync(FULL_MASK, current_hashcode, hash_f_idx);
      int hashtable_idx = (batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + current_hashcode;
      warp_value = warp_value + hashtable_value[hashtable_idx * WARP_SIZE + warp_thread_idx];
    }
    float warp_weight = query_weight[batch_idx__query_idx * weight_dim + weight_idx];
    cumulation_value[batch_idx__query_idx * value_dim + offset_warp + warp_thread_idx] += warp_weight * warp_value / float(num_hash_f);
  }

}
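
// count_sort_step1/2/3 implement a counting sort of items by hash code, independently for
// every (batch, hash function) pair: step 1 histograms the codes with atomicAdd, step 2 turns
// the histogram into shifted prefix sums in shared memory, and step 3 scatters each item index
// into its sorted position.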

__global__ void count_sort_step1_cuda_kernel(
  int *key_mask,         // [batch_size, num_key]
  int *key_hash_code,    // [batch_size, num_key, num_hash_f]
  int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_key
) {

  int batch_idx = blockIdx.y;
  int key_idx = blockIdx.x * blockDim.y + threadIdx.y;
  int hash_f_idx = threadIdx.x;

  int batch_idx__key_idx = batch_idx * num_key + key_idx;
  if (key_mask[batch_idx__key_idx] == 0) {
    return;
  }

  int hash_code = key_hash_code[batch_idx__key_idx * num_hash_f + hash_f_idx];
  atomicAdd(&count_sort_table[(batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + hash_code], 1);

}

__global__ void count_sort_step2_cuda_kernel(
  int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity
) {

  int batch_idx = blockIdx.y;
  int hash_f_idx = blockIdx.x;

  int num_threads = blockDim.x;
  int thread_id = threadIdx.x;

  int batch_idx__hash_f_idx = batch_idx * num_hash_f + hash_f_idx;

  extern __shared__ float buffer[];
  int *table_buffer = (int*)buffer;

  if (thread_id == 0) {
    table_buffer[0] = 0;
  }
  copy_data<int>(&count_sort_table[batch_idx__hash_f_idx * hashtable_capacity], &table_buffer[1], hashtable_capacity - 1, num_threads, thread_id);

  for (int table_idx_start = 0; table_idx_start < hashtable_capacity; table_idx_start = table_idx_start + num_threads) {
    int thread_value = table_buffer[table_idx_start + thread_id];
    int next_thread_value = 0;
    for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
      next_thread_value = __shfl_up_sync(FULL_MASK, thread_value, offset);
      if (thread_id % WARP_SIZE >= offset) {
        thread_value = thread_value + next_thread_value;
      }
    }
    table_buffer[table_idx_start + thread_id] = thread_value;
  }
  __syncthreads();

  if (hashtable_capacity > WARP_SIZE) {
    if (thread_id < WARP_SIZE) {
      for (int table_idx_start = WARP_SIZE; table_idx_start < hashtable_capacity; table_idx_start = table_idx_start + WARP_SIZE) {
        table_buffer[table_idx_start + thread_id] += table_buffer[table_idx_start - 1];
      }
    }
  }

  copy_data<int>(table_buffer, &count_sort_table[batch_idx__hash_f_idx * hashtable_capacity], hashtable_capacity, num_threads, thread_id);

}

__global__ void count_sort_step3_cuda_kernel(
  int *key_mask,         // [batch_size, num_key]
  int *key_hash_code,    // [batch_size, num_key, num_hash_f]
  int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
  int *key_sorted_idxes, // [batch_size, num_hash_f, num_key]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_key
) {

  int batch_idx = blockIdx.y;
  int key_idx = blockIdx.x * blockDim.y + threadIdx.y;
  int hash_f_idx = threadIdx.x;

  int batch_idx__key_idx = batch_idx * num_key + key_idx;
  if (key_mask[batch_idx__key_idx] == 0) {
    return;
  }

  int batch_idx__hash_f_idx = batch_idx * num_hash_f + hash_f_idx;

  int hash_code = key_hash_code[batch_idx__key_idx * num_hash_f + hash_f_idx];
  int sort_idx = atomicAdd(&count_sort_table[batch_idx__hash_f_idx * hashtable_capacity + hash_code], 1);
  key_sorted_idxes[batch_idx__hash_f_idx * num_key + sort_idx] = key_idx;

}
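
// extract_query_info: despite its parameter names, the version 3 and 4 host functions call
// this kernel on the *key* tensors. For each masked item it looks up, per hash function, the
// offset and count of matching entries in the count-sorted table and stores them as an
// (offset, count) info pair.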

__global__ void extract_query_info_cuda_kernel(
  int *query_mask,       // [batch_size, num_query]
  int *query_hash_code,  // [batch_size, num_query, num_hash_f]
  int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
  int *query_info,       // [batch_size, num_query, 2, num_hash_f]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_query
) {

  int batch_idx = blockIdx.y;
  int query_idx = blockIdx.x * blockDim.y + threadIdx.y;
  int hash_f_idx = threadIdx.x;

  int batch_idx__query_idx = batch_idx * num_query + query_idx;
  if (query_mask[batch_idx__query_idx] == 0) {
    return;
  }

  int hash_code = query_hash_code[batch_idx__query_idx * num_hash_f + hash_f_idx];
  int batch_idx__hash_f_idx__hash_code = (batch_idx * num_hash_f + hash_f_idx) * hashtable_capacity + hash_code;

  int key_offset = select(hash_code == 0, 0, count_sort_table[batch_idx__hash_f_idx__hash_code - 1]);
  int key_count = count_sort_table[batch_idx__hash_f_idx__hash_code] - key_offset;

  query_info[batch_idx__query_idx * 2 * num_hash_f + hash_f_idx] = key_offset;
  query_info[(batch_idx__query_idx * 2 + 1) * num_hash_f + hash_f_idx] = key_count;

}
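
// lsh_weighted_cumulation_ver2_step2: one block per (query, hash function, batch) triple. The
// block caches the query weight in shared memory, then walks the sorted list of colliding keys
// in WARP_SIZE chunks; each warp computes one key's query-key weight dot product and atomically
// accumulates weight * value into the query's output row.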

__global__ void lsh_weighted_cumulation_ver2_step2_cuda_kernel(
  int *query_mask,         // [batch_size, num_query]
  int *query_info,         // [batch_size, num_query, 2, num_hash_f]
  int *key_sorted_idxes,   // [batch_size, num_hash_f, num_key]
  float *query_weight,     // [batch_size, num_query, weight_dim]
  float *key_weight,       // [batch_size, num_key, weight_dim]
  float *value,            // [batch_size, num_key, value_dim]
  float *cumulation_value, // [batch_size, num_query, value_dim]
  int batch_size,
  int num_hash_f,
  int num_query,
  int num_key,
  int value_dim,
  int weight_dim
) {

  int batch_idx = blockIdx.z;
  int hash_f_idx = blockIdx.y;
  int query_idx = blockIdx.x;

  int num_threads = blockDim.y * blockDim.x;
  int thread_id = threadIdx.y * blockDim.x + threadIdx.x;

  int num_warps = blockDim.y;
  int warp_idx = threadIdx.y;
  int warp_thread_idx = threadIdx.x;

  int batch_idx__query_idx = batch_idx * num_query + query_idx;
  if (query_mask[batch_idx__query_idx] == 0) {
    return;
  }

  int key_offset = query_info[batch_idx__query_idx * 2 * num_hash_f + hash_f_idx];
  int key_count = query_info[(batch_idx__query_idx * 2 + 1) * num_hash_f + hash_f_idx];

  if (key_count == 0) {
    return;
  }

  extern __shared__ float buffer[];

  if (key_count == 1) {
    if (warp_idx == 0) {
      int key_idx = key_sorted_idxes[(batch_idx * num_hash_f + hash_f_idx) * num_key + key_offset];
      int batch_idx__key_idx = batch_idx * num_key + key_idx;
      float weight = 0;
      for (int weight_offset = 0; weight_offset < weight_dim; weight_offset = weight_offset + WARP_SIZE) {
        int weight_dim_idx = weight_offset + warp_thread_idx;
        float val = query_weight[batch_idx__query_idx * weight_dim + weight_dim_idx] * key_weight[batch_idx__key_idx * weight_dim + weight_dim_idx];
        #pragma unroll
        for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
          val += __shfl_xor_sync(FULL_MASK, val, offset);
        }
        weight = weight + val;
      }
      weight = weight / float(num_hash_f);
      for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) {
        int value_dim_idx = value_offset + warp_thread_idx;
        float val = value[batch_idx__key_idx * value_dim + value_dim_idx];
        atomicAdd(&cumulation_value[batch_idx__query_idx * value_dim + value_dim_idx], weight * val);
      }
    }
  } else {
    float *weight_buffer = buffer;
    int *key_idxes_buffer = (int*)&buffer[weight_dim];

    copy_data_nonblocking<float>(&query_weight[batch_idx__query_idx * weight_dim], weight_buffer, weight_dim, num_threads, thread_id);

    while (key_count > 0) {
      int work_size = min(WARP_SIZE, key_count);
      copy_data_nonblocking<int>(&key_sorted_idxes[(batch_idx * num_hash_f + hash_f_idx) * num_key + key_offset], key_idxes_buffer, work_size, num_threads, thread_id);
      __syncthreads();
      for (int work_offset = 0; work_offset < WARP_SIZE; work_offset = work_offset + num_warps) {
        int work_idx = work_offset + warp_idx;
        if (work_idx < key_count) {
          int key_idx = key_idxes_buffer[work_idx];
          int batch_idx__key_idx = batch_idx * num_key + key_idx;
          float weight = 0;
          for (int weight_offset = 0; weight_offset < weight_dim; weight_offset = weight_offset + WARP_SIZE) {
            int weight_dim_idx = weight_offset + warp_thread_idx;
            float val = weight_buffer[weight_dim_idx] * key_weight[batch_idx__key_idx * weight_dim + weight_dim_idx];
            #pragma unroll
            for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
              val += __shfl_xor_sync(FULL_MASK, val, offset);
            }
            weight = weight + val;
          }
          weight = weight / float(num_hash_f);
          for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) {
            int value_dim_idx = value_offset + warp_thread_idx;
            float val = value[batch_idx__key_idx * value_dim + value_dim_idx];
            atomicAdd(&cumulation_value[batch_idx__query_idx * value_dim + value_dim_idx], weight * val);
          }
        }
      }
      key_count = key_count - work_size;
      key_offset = key_offset + work_size;
    }
  }

}
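
// lsh_weighted_cumulation_ver3_step2: the mirror image of version 2. Blocks iterate over keys
// instead of queries, caching the key's weight and value vectors in shared memory and
// scattering contributions to every colliding query.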

__global__ void lsh_weighted_cumulation_ver3_step2_cuda_kernel(
  int *query_sorted_idxes, // [batch_size, num_hash_f, num_query]
  int *key_mask,           // [batch_size, num_key]
  int *key_info,           // [batch_size, num_key, 2, num_hash_f]
  float *query_weight,     // [batch_size, num_query, weight_dim]
  float *key_weight,       // [batch_size, num_key, weight_dim]
  float *value,            // [batch_size, num_key, value_dim]
  float *cumulation_value, // [batch_size, num_query, value_dim]
  int batch_size,
  int num_hash_f,
  int num_query,
  int num_key,
  int value_dim,
  int weight_dim
) {

  int batch_idx = blockIdx.z;
  int hash_f_idx = blockIdx.y;
  int key_idx = blockIdx.x;

  int num_threads = blockDim.y * blockDim.x;
  int thread_id = threadIdx.y * blockDim.x + threadIdx.x;

  int num_warps = blockDim.y;
  int warp_idx = threadIdx.y;
  int warp_thread_idx = threadIdx.x;

  int batch_idx__key_idx = batch_idx * num_key + key_idx;
  if (key_mask[batch_idx__key_idx] == 0) {
    return;
  }

  int query_offset = key_info[batch_idx__key_idx * 2 * num_hash_f + hash_f_idx];
  int query_count = key_info[(batch_idx__key_idx * 2 + 1) * num_hash_f + hash_f_idx];

  if (query_count == 0) {
    return;
  }

  extern __shared__ float buffer[];

  if (query_count == 1) {
    if (warp_idx == 0) {
      int query_idx = query_sorted_idxes[(batch_idx * num_hash_f + hash_f_idx) * num_query + query_offset];
      int batch_idx__query_idx = batch_idx * num_query + query_idx;
      float weight = 0;
      for (int weight_offset = 0; weight_offset < weight_dim; weight_offset = weight_offset + WARP_SIZE) {
        int weight_dim_idx = weight_offset + warp_thread_idx;
        float val = key_weight[batch_idx__key_idx * weight_dim + weight_dim_idx] * query_weight[batch_idx__query_idx * weight_dim + weight_dim_idx];
        #pragma unroll
        for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
          val += __shfl_xor_sync(FULL_MASK, val, offset);
        }
        weight = weight + val;
      }
      weight = weight / float(num_hash_f);
      for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) {
        int value_dim_idx = value_offset + warp_thread_idx;
        float val = value[batch_idx__key_idx * value_dim + value_dim_idx];
        atomicAdd(&cumulation_value[batch_idx__query_idx * value_dim + value_dim_idx], weight * val);
      }
    }
  } else {
    float *weight_buffer = buffer;
    float *value_buffer = &buffer[weight_dim];
    int *query_idxes_buffer = (int*)&buffer[weight_dim + value_dim];

    copy_data_nonblocking<float>(&key_weight[batch_idx__key_idx * weight_dim], weight_buffer, weight_dim, num_threads, thread_id);
    copy_data_nonblocking<float>(&value[batch_idx__key_idx * value_dim], value_buffer, value_dim, num_threads, thread_id);

    while (query_count > 0) {
      int work_size = min(WARP_SIZE, query_count);
      copy_data_nonblocking<int>(&query_sorted_idxes[(batch_idx * num_hash_f + hash_f_idx) * num_query + query_offset], query_idxes_buffer, work_size, num_threads, thread_id);
      __syncthreads();
      for (int work_offset = 0; work_offset < WARP_SIZE; work_offset = work_offset + num_warps) {
        int work_idx = work_offset + warp_idx;
        if (work_idx < query_count) {
          int query_idx = query_idxes_buffer[work_idx];
          int batch_idx__query_idx = batch_idx * num_query + query_idx;
          float weight = 0;
          for (int weight_offset = 0; weight_offset < weight_dim; weight_offset = weight_offset + WARP_SIZE) {
            int weight_dim_idx = weight_offset + warp_thread_idx;
            float val = weight_buffer[weight_dim_idx] * query_weight[batch_idx__query_idx * weight_dim + weight_dim_idx];
            #pragma unroll
            for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
              val += __shfl_xor_sync(FULL_MASK, val, offset);
            }
            weight = weight + val;
          }
          weight = weight / float(num_hash_f);
          for (int value_offset = 0; value_offset < value_dim; value_offset = value_offset + WARP_SIZE) {
            int value_dim_idx = value_offset + warp_thread_idx;
            float val = value_buffer[value_dim_idx];
            atomicAdd(&cumulation_value[batch_idx__query_idx * value_dim + value_dim_idx], weight * val);
          }
        }
      }
      query_count = query_count - work_size;
      query_offset = query_offset + work_size;
    }
  }

}
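
// lsh_weighted_cumulation_ver4_step2: one block per key. All hash functions are processed
// together; colliding queries are inserted into a shared-memory hash set (set_insert) so each
// distinct query is accumulated only once, with its dot-product weight scaled by the number of
// hash functions under which it collided (duplicate_count).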

__global__ void lsh_weighted_cumulation_ver4_step2_cuda_kernel(
  int *query_sorted_idxes, // [batch_size, num_hash_f, num_query]
  int *key_mask,           // [batch_size, num_key]
  int *key_info,           // [batch_size, num_key, 2, num_hash_f]
  float *query_weight,     // [batch_size, num_query, weight_dim]
  float *key_weight,       // [batch_size, num_key, weight_dim]
  float *value,            // [batch_size, num_key, value_dim]
  float *cumulation_value, // [batch_size, num_query, value_dim]
  int batch_size,
  int num_hash_f,
  int num_query,
  int num_key,
  int value_dim,
  int weight_dim
) {

  int batch_idx = blockIdx.y;
  int key_idx = blockIdx.x;

  int num_threads = blockDim.y * blockDim.x;
  int thread_id = threadIdx.y * blockDim.x + threadIdx.x;

  int num_warps = blockDim.y;
  int warp_idx = threadIdx.y;
  int warp_thread_idx = threadIdx.x;

  int batch_idx__key_idx = batch_idx * num_key + key_idx;
  if (key_mask[batch_idx__key_idx] == 0) {
    return;
  }

  extern __shared__ float buffer[];
  float *weight_buffer = buffer;
  float *value_buffer = &buffer[weight_dim];
  int *key_info_buffer = (int*)&buffer[weight_dim + value_dim];

  copy_data_nonblocking<float>(&key_weight[batch_idx__key_idx * weight_dim], weight_buffer, weight_dim, num_threads, thread_id);
  copy_data_nonblocking<float>(&value[batch_idx__key_idx * value_dim], value_buffer, value_dim, num_threads, thread_id);
  copy_data_nonblocking<int>(&key_info[batch_idx__key_idx * 2 * num_hash_f], key_info_buffer, 2 * num_hash_f, num_threads, thread_id);

  int *query_offset_buffer = key_info_buffer;
  int *query_count_buffer = &key_info_buffer[num_hash_f];

  const int hashtable_size = 1024 + OPTIMAL_THREADS_PER_BLOCK;
  __shared__ int hashtable_query[hashtable_size];
  __shared__ int hashtable_count[hashtable_size];
  __shared__ int inserted_query[hashtable_size];
  __shared__ int query_counter[1];

  int hash_f_idx_base = 0;

  while (true) {

    init_buffer_nonblocking<int>(EMPTY_VALUE, hashtable_query, hashtable_size, num_threads, thread_id);
    init_buffer_nonblocking<int>(0, hashtable_count, hashtable_size, num_threads, thread_id);
    init_buffer_nonblocking<int>(EMPTY_VALUE, inserted_query, hashtable_size, num_threads, thread_id);
    init_buffer_nonblocking<int>(0, query_counter, 1, num_threads, thread_id);
    __syncthreads();

    while (hash_f_idx_base < num_hash_f) {

      int hash_f_idx = hash_f_idx_base + warp_idx;
      int batch_idx__hash_f_idx = batch_idx * num_hash_f + hash_f_idx;

      int stop_flag = 0;

      int query_offset = query_offset_buffer[hash_f_idx];
      int query_count = query_count_buffer[hash_f_idx];

      while (query_count > 0) {

        int work_size = min(query_count, WARP_SIZE);

        // try inserting query to set and check whether the query is new
        int found_new_query = 0;
        int query_idx = -1;
        if (warp_thread_idx < work_size) {
          query_idx = query_sorted_idxes[batch_idx__hash_f_idx * num_query + query_offset + warp_thread_idx];
          int slot = set_insert<int>(hashtable_query, hashtable_size, query_idx);
          if (slot >= 0) {
            found_new_query = atomicAdd(&hashtable_count[slot], 1) == 0;
          }
        }

        // compute cumulative offset
        int position_offset = found_new_query;
        int next_position_offset = 0;
        #pragma unroll
        for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
          next_position_offset = __shfl_up_sync(FULL_MASK, position_offset, offset);
          if (thread_id % WARP_SIZE >= offset) {
            position_offset = position_offset + next_position_offset;
          }
        }

        // get the inserted query list end index
        int inserted_query_base = 0;
        if (thread_id % WARP_SIZE == WARP_SIZE - 1) {
          inserted_query_base = atomicAdd(query_counter, position_offset);
        }
        inserted_query_base = __shfl_sync(FULL_MASK, inserted_query_base, WARP_SIZE - 1);

        // insert new queries to list
        int insert_idx = inserted_query_base + position_offset - 1;
        if (found_new_query) {
          inserted_query[insert_idx] = query_idx;
        }

        // remove inserted queries from list
        query_offset_buffer[hash_f_idx] += work_size;
        query_count_buffer[hash_f_idx] -= work_size;
        query_offset += work_size;
        query_count -= work_size;

        // if list is almost full, stop inserting
        if (inserted_query_base + OPTIMAL_THREADS_PER_BLOCK > hashtable_size) {
          stop_flag = 1;
          break;
        }

      }

      if (stop_flag) {
        break;
      }

      hash_f_idx_base = hash_f_idx_base + num_warps;

    }

    __syncthreads();

    int num_distinct_query = query_counter[0];

    if (num_distinct_query > 0) {
      for (int idx_base = 0; idx_base < num_distinct_query; idx_base = idx_base + num_warps) {
        int idx = idx_base + warp_idx;
        if (idx < num_distinct_query) {
          int query_idx = inserted_query[idx];
          int batch_idx__query_idx = batch_idx * num_query + query_idx;

          int slot = set_lookup<int>(hashtable_query, hashtable_size, query_idx);
          int duplicate_count = hashtable_count[slot];

          float weight = 0;
          for (int weight_idx_base = 0; weight_idx_base < weight_dim; weight_idx_base = weight_idx_base + WARP_SIZE) {
            int weight_dim_idx = weight_idx_base + warp_thread_idx;
            float val = weight_buffer[weight_dim_idx] * query_weight[batch_idx__query_idx * weight_dim + weight_dim_idx];
            #pragma unroll
            for (int offset = 1; offset < WARP_SIZE; offset = offset << 1) {
              val += __shfl_xor_sync(FULL_MASK, val, offset);
            }
            weight = weight + val;
          }

          weight = (float)duplicate_count * weight / float(num_hash_f);

          for (int value_idx_base = 0; value_idx_base < value_dim; value_idx_base = value_idx_base + WARP_SIZE) {
            int value_dim_idx = value_idx_base + warp_thread_idx;
            float val = value_buffer[value_dim_idx];
            atomicAdd(&cumulation_value[batch_idx__query_idx * value_dim + value_dim_idx], weight * val);
          }
        }
      }
    } else {

      // all computation is completed if num_distinct_query == 0
      break;

    }

    __syncthreads();

  }

}
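
To make the data flow of these kernels concrete, here is a minimal NumPy reference of the unweighted estimator that lsh_cumulation_ver1_step1/step2 implement (values scattered into per-hash-function buckets, then gathered at the query codes and averaged). This sketch is an illustration, not part of the patch; it assumes integer hash codes in [0, hashtable_capacity) and int masks, matching the tensor comments above.

```python
import numpy as np

def lsh_cumulation_reference(query_mask, query_hash_code, key_mask, key_hash_code, value, hashtable_capacity):
    """Reference for lsh_cumulation_ver1_step1/step2."""
    batch_size, num_query, num_hash_f = query_hash_code.shape
    value_dim = value.shape[-1]
    out = np.zeros((batch_size, num_query, value_dim), dtype=value.dtype)
    for b in range(batch_size):
        # step 1: table[h, c] accumulates the values of keys whose h-th hash code equals c
        table = np.zeros((num_hash_f, hashtable_capacity, value_dim), dtype=value.dtype)
        for k in np.nonzero(key_mask[b])[0]:
            for h in range(num_hash_f):
                table[h, key_hash_code[b, k, h]] += value[b, k]
        # step 2: each query averages the buckets addressed by its own codes
        for q in np.nonzero(query_mask[b])[0]:
            acc = np.zeros(value_dim, dtype=value.dtype)
            for h in range(num_hash_f):
                acc += table[h, query_hash_code[b, q, h]]
            out[b, q] = acc / num_hash_f
    return out
```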

src/transformers/models/yoso/fast_lsh_cumulation_cuda.h (new file, 157 lines)
@@ -0,0 +1,157 @@
__global__ void fast_hash_ver1_cuda_kernel(
  int *mask,       // [batch_size, num_vector]
  float *vector,   // [batch_size, num_vector, vector_dim]
  int *Dmat,       // [batch_size, 3, num_part, vector_dim]
  int *hash_code,  // [batch_size, num_vector, num_hash_f]
  int batch_size,
  int num_vector,
  int vector_dim,
  int num_part,
  int num_hash_f,
  int hash_code_len
);

__global__ void lsh_cumulation_ver1_step1_cuda_kernel(
  int *key_mask,           // [batch_size, num_key]
  int *key_hash_code,      // [batch_size, num_key, num_hash_f]
  float *value,            // [batch_size, num_key, value_dim]
  float *hashtable_value,  // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_key,
  int value_dim,
  int offset_warp
);

__global__ void lsh_cumulation_ver1_step2_cuda_kernel(
  int *query_mask,         // [batch_size, num_query]
  int *query_hash_code,    // [batch_size, num_query, num_hash_f]
  float *hashtable_value,  // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE]
  float *cumulation_value, // [batch_size, num_query, value_dim]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_query,
  int value_dim,
  int offset_warp
);

__global__ void lsh_weighted_cumulation_ver1_step1_cuda_kernel(
  int *key_mask,           // [batch_size, num_key]
  int *key_hash_code,      // [batch_size, num_key, num_hash_f]
  float *key_weight,       // [batch_size, num_key, weight_dim]
  float *value,            // [batch_size, num_key, value_dim]
  float *hashtable_value,  // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_key,
  int value_dim,
  int weight_dim,
  int offset_warp,
  int weight_idx
);

__global__ void lsh_weighted_cumulation_ver1_step2_cuda_kernel(
  int *query_mask,         // [batch_size, num_query]
  int *query_hash_code,    // [batch_size, num_query, num_hash_f]
  float *query_weight,     // [batch_size, num_query, weight_dim]
  float *hashtable_value,  // [batch_size, num_hash_f, hashtable_capacity, WARP_SIZE]
  float *cumulation_value, // [batch_size, num_query, value_dim]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_query,
  int value_dim,
  int weight_dim,
  int offset_warp,
  int weight_idx
);

__global__ void count_sort_step1_cuda_kernel(
  int *key_mask,         // [batch_size, num_key]
  int *key_hash_code,    // [batch_size, num_key, num_hash_f]
  int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_key
);

__global__ void count_sort_step2_cuda_kernel(
  int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity
);

__global__ void count_sort_step3_cuda_kernel(
  int *key_mask,         // [batch_size, num_key]
  int *key_hash_code,    // [batch_size, num_key, num_hash_f]
  int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
  int *key_sorted_idxes, // [batch_size, num_hash_f, num_key]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_key
);

__global__ void extract_query_info_cuda_kernel(
  int *query_mask,       // [batch_size, num_query]
  int *query_hash_code,  // [batch_size, num_query, num_hash_f]
  int *count_sort_table, // [batch_size, num_hash_f, hashtable_capacity]
  int *query_info,       // [batch_size, num_query, 2, num_hash_f]
  int batch_size,
  int num_hash_f,
  int hashtable_capacity,
  int num_query
);

__global__ void lsh_weighted_cumulation_ver2_step2_cuda_kernel(
  int *query_mask,         // [batch_size, num_query]
  int *query_info,         // [batch_size, num_query, 2, num_hash_f]
  int *key_sorted_idxes,   // [batch_size, num_hash_f, num_key]
  float *query_weight,     // [batch_size, num_query, weight_dim]
  float *key_weight,       // [batch_size, num_key, weight_dim]
  float *value,            // [batch_size, num_key, value_dim]
  float *cumulation_value, // [batch_size, num_query, value_dim]
  int batch_size,
  int num_hash_f,
  int num_query,
  int num_key,
  int value_dim,
  int weight_dim
);

__global__ void lsh_weighted_cumulation_ver3_step2_cuda_kernel(
  int *query_sorted_idxes, // [batch_size, num_hash_f, num_query]
  int *key_mask,           // [batch_size, num_key]
  int *key_info,           // [batch_size, num_key, 2, num_hash_f]
  float *query_weight,     // [batch_size, num_query, weight_dim]
  float *key_weight,       // [batch_size, num_key, weight_dim]
  float *value,            // [batch_size, num_key, value_dim]
  float *cumulation_value, // [batch_size, num_query, value_dim]
  int batch_size,
  int num_hash_f,
  int num_query,
  int num_key,
  int value_dim,
  int weight_dim
);

__global__ void lsh_weighted_cumulation_ver4_step2_cuda_kernel(
  int *query_sorted_idxes, // [batch_size, num_hash_f, num_query]
  int *key_mask,           // [batch_size, num_key]
  int *key_info,           // [batch_size, num_key, 2, num_hash_f]
  float *query_weight,     // [batch_size, num_query, weight_dim]
  float *key_weight,       // [batch_size, num_key, weight_dim]
  float *value,            // [batch_size, num_key, value_dim]
  float *cumulation_value, // [batch_size, num_query, value_dim]
  int batch_size,
  int num_hash_f,
  int num_query,
  int num_key,
  int value_dim,
  int weight_dim
);

src/transformers/models/yoso/fast_lsh_cumulation_torch.cpp (new file, 128 lines)
@@ -0,0 +1,128 @@
#include <torch/extension.h>
#include <ATen/ATen.h>
#include "fast_lsh_cumulation.h"
#include "common_cuda.h"
#include <vector>

std::vector<at::Tensor> fast_hash(
  at::Tensor query_mask,
  at::Tensor query_vector,
  at::Tensor key_mask,
  at::Tensor key_vector,
  int num_hash_f,
  int hash_code_len,
  bool use_cuda,
  int version
) {
  return fast_hash_ver1_kernel(
    query_mask,
    query_vector,
    key_mask,
    key_vector,
    num_hash_f,
    hash_code_len,
    use_cuda
  );
}

at::Tensor lsh_cumulation(
  at::Tensor query_mask,      // [batch_size, num_query]
  at::Tensor query_hash_code, // [batch_size, num_query, num_hash_f]
  at::Tensor key_mask,        // [batch_size, num_key]
  at::Tensor key_hash_code,   // [batch_size, num_key, num_hash_f]
  at::Tensor value,           // [batch_size, num_key, value_dim]
  int hashtable_capacity,
  bool use_cuda,
  int version
) {
  return lsh_cumulation_ver1_kernel(
    query_mask,
    query_hash_code,
    key_mask,
    key_hash_code,
    value,
    hashtable_capacity,
    use_cuda
  );
}

at::Tensor lsh_weighted_cumulation(
  at::Tensor query_mask,      // [batch_size, num_query]
  at::Tensor query_hash_code, // [batch_size, num_query, num_hash_f]
  at::Tensor query_weight,    // [batch_size, num_query, weight_dim]
  at::Tensor key_mask,        // [batch_size, num_key]
  at::Tensor key_hash_code,   // [batch_size, num_key, num_hash_f]
  at::Tensor key_weight,      // [batch_size, num_key, weight_dim]
  at::Tensor value,           // [batch_size, num_key, value_dim]
  int hashtable_capacity,
  bool use_cuda,
  int version
) {
  if (version == 1) {
    return lsh_weighted_cumulation_ver1_kernel(
      query_mask,
      query_hash_code,
      query_weight,
      key_mask,
      key_hash_code,
      key_weight,
      value,
      hashtable_capacity,
      use_cuda
    );
  } else if (version == 2) {
    return lsh_weighted_cumulation_ver2_kernel(
      query_mask,
      query_hash_code,
      query_weight,
      key_mask,
      key_hash_code,
      key_weight,
      value,
      hashtable_capacity,
      use_cuda
    );
  } else if (version == 3) {
    return lsh_weighted_cumulation_ver3_kernel(
      query_mask,
      query_hash_code,
      query_weight,
      key_mask,
      key_hash_code,
      key_weight,
      value,
      hashtable_capacity,
      use_cuda
    );
  } else if (version == 4) {
    return lsh_weighted_cumulation_ver4_kernel(
      query_mask,
      query_hash_code,
      query_weight,
      key_mask,
      key_hash_code,
      key_weight,
      value,
      hashtable_capacity,
      use_cuda
    );
  } else {
    return lsh_weighted_cumulation_ver3_kernel(
      query_mask,
      query_hash_code,
      query_weight,
      key_mask,
      key_hash_code,
      key_weight,
      value,
      hashtable_capacity,
      use_cuda
    );
  }
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("fast_hash", &fast_hash, "Fast Hash (CUDA)");
  m.def("lsh_cumulation", &lsh_cumulation, "LSH Cumulation (CUDA)");
  m.def("lsh_weighted_cumulation", &lsh_weighted_cumulation, "LSH Weighted Cumulation (CUDA)");
}
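
The bindings above expose three entry points. Note that fast_hash and lsh_cumulation ignore their version argument and always dispatch to the ver1 kernels, while lsh_weighted_cumulation falls back to ver3 for unrecognized versions. Below is a minimal sketch of how the extension might be built and called with torch.utils.cpp_extension.load; the source file list is illustrative, the 2**hash_code_len table capacity is an inference from the sign-bit packing in fast_hash_ver1_cuda_kernel, and the assumption that fast_hash returns the query and key hash codes (in that order) follows from its std::vector<at::Tensor> return type rather than anything stated in the patch.

```python
import torch
from torch.utils.cpp_extension import load

# Hypothetical build call; file names follow this commit, but the exact list may differ.
kernels = load(
    name="fast_lsh_cumulation",
    sources=[
        "src/transformers/models/yoso/fast_lsh_cumulation_torch.cpp",
        "src/transformers/models/yoso/fast_lsh_cumulation.cu",
        "src/transformers/models/yoso/fast_lsh_cumulation_cuda.cu",
    ],
    verbose=True,
)

batch_size, num_query, num_key, dim = 2, 64, 64, 64
num_hash_f, hash_code_len = 32, 8
hashtable_capacity = 2 ** hash_code_len  # codes pack hash_code_len sign bits (assumption)

query = torch.randn(batch_size, num_query, dim, device="cuda")
key = torch.randn(batch_size, num_key, dim, device="cuda")
value = torch.randn(batch_size, num_key, dim, device="cuda")
query_mask = torch.ones(batch_size, num_query, dtype=torch.int32, device="cuda")
key_mask = torch.ones(batch_size, num_key, dtype=torch.int32, device="cuda")

# Assumed return order: [query_hash_code, key_hash_code].
query_code, key_code = kernels.fast_hash(
    query_mask, query, key_mask, key, num_hash_f, hash_code_len, True, 1
)
out = kernels.lsh_cumulation(
    query_mask, query_code, key_mask, key_code, value, hashtable_capacity, True, 1
)
```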

src/transformers/models/yoso/modeling_yoso.py (new file, 1324 lines)
File diff suppressed because it is too large
@@ -4060,6 +4060,65 @@ def load_tf_weights_in_xlnet(*args, **kwargs):
    requires_backends(load_tf_weights_in_xlnet, ["torch"])


YOSO_PRETRAINED_MODEL_ARCHIVE_LIST = None


class YosoForMaskedLM(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class YosoForMultipleChoice(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class YosoForQuestionAnswering(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class YosoForSequenceClassification(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class YosoForTokenClassification(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class YosoLayer(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class YosoModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class YosoPreTrainedModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class Adafactor(metaclass=DummyObject):
    _backends = ["torch"]
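
These placeholders follow the library's dummy-object pattern: when PyTorch is unavailable, importing a YOSO class yields a stub whose constructor calls requires_backends and raises an informative error instead of an opaque AttributeError. A simplified sketch of the mechanism (the real DummyObject and requires_backends live in transformers.utils and do an actual availability check):

```python
def requires_backends(obj, backends):
    # Simplified: the real helper only raises when the backend is missing.
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    raise ImportError(f"{name} requires the following backends: {', '.join(backends)}")

class DummyObject(type):
    """Metaclass sketch: class-level attribute access reports the missing backend."""
    def __getattr__(cls, key):
        requires_backends(cls, cls._backends)
```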

tests/test_modeling_yoso.py (new file, 401 lines)
@@ -0,0 +1,401 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch YOSO model. """
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
from tests.test_modeling_common import floats_tensor
|
||||
from transformers import YosoConfig, is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
YosoForMaskedLM,
|
||||
YosoForMultipleChoice,
|
||||
YosoForQuestionAnswering,
|
||||
YosoForSequenceClassification,
|
||||
YosoForTokenClassification,
|
||||
YosoModel,
|
||||
)
|
||||
from transformers.models.yoso.modeling_yoso import YOSO_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
|
||||
class YosoModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_token_type_ids=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=5,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_token_type_ids = use_token_type_ids
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = scope

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

    def get_config(self):
        return YosoConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
        )
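
    # Same inputs as above, plus encoder hidden states and mask, for exercising
    # the cross-attention (decoder) configuration.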
    def prepare_config_and_inputs_for_decoder(self):
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = self.prepare_config_and_inputs()

        config.is_decoder = True
        encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
        encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        return (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        )

    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = YosoModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        result = model(input_ids, token_type_ids=token_type_ids)
        result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
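
    # Enables cross-attention and checks the output shape when the model is used as a decoder.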
    def create_and_check_model_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        config.add_cross_attention = True
        model = YosoModel(config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            encoder_hidden_states=encoder_hidden_states,
        )
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def create_and_check_for_masked_lm(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = YosoForMaskedLM(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_for_question_answering(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = YosoForQuestionAnswering(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            start_positions=sequence_labels,
            end_positions=sequence_labels,
        )
        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))

    def create_and_check_for_sequence_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.num_labels = self.num_labels
        model = YosoForSequenceClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

    def create_and_check_for_token_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.num_labels = self.num_labels
        model = YosoForTokenClassification(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
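
    # The multiple-choice head expects inputs of shape [batch_size, num_choices, seq_length],
    # so each tensor is repeated across a new choice dimension.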
    def create_and_check_for_multiple_choice(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.num_choices = self.num_choices
        model = YosoForMultipleChoice(config=config)
        model.to(torch_device)
        model.eval()
        multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
        multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
        multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
        result = model(
            multiple_choice_inputs_ids,
            attention_mask=multiple_choice_input_mask,
            token_type_ids=multiple_choice_token_type_ids,
            labels=choice_labels,
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
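
    # Packs the config and inputs into the dict format expected by ModelTesterMixin.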
    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
        return config, inputs_dict


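# Runs the shared ModelTesterMixin checks against all YOSO model classes.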
@require_torch
class YosoModelTest(ModelTesterMixin, unittest.TestCase):

    all_model_classes = (
        (
            YosoModel,
            YosoForMaskedLM,
            YosoForMultipleChoice,
            YosoForQuestionAnswering,
            YosoForSequenceClassification,
            YosoForTokenClassification,
        )
        if is_torch_available()
        else ()
    )
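    # Pruning, head masking and torchscript are not exercised for YOSO, and no
    # generative (causal LM) classes are registered.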
    test_pruning = False
    test_headmasking = False
    test_torchscript = False

    all_generative_model_classes = ()

    def setUp(self):
        self.model_tester = YosoModelTester(self)
        self.config_tester = ConfigTester(self, config_class=YosoConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_various_embeddings(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        for embedding_type in ["absolute", "relative_key", "relative_key_query"]:
            config_and_inputs[0].position_embedding_type = embedding_type
            self.model_tester.create_and_check_model(*config_and_inputs)

    def test_for_masked_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)

    def test_for_multiple_choice(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)

    def test_for_question_answering(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_question_answering(*config_and_inputs)

    def test_for_sequence_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)

    def test_for_token_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_token_classification(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        for model_name in YOSO_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = YosoModel.from_pretrained(model_name)
            self.assertIsNotNone(model)
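
    # Overridden as a no-op so the common attention-output checks are skipped;
    # they are not meaningful for YOSO's sampling-based attention approximation.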
    def test_attention_outputs(self):
        return


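# Integration tests comparing outputs of the released uw-madison/yoso-4096
# checkpoint against precomputed reference values.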
@require_torch
class YosoModelIntegrationTest(unittest.TestCase):
    @slow
    def test_inference_no_head(self):
        model = YosoModel.from_pretrained("uw-madison/yoso-4096")
        input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])

        with torch.no_grad():
            output = model(input_ids)[0]

        expected_shape = torch.Size((1, 6, 768))
        self.assertEqual(output.shape, expected_shape)

        expected_slice = torch.tensor(
            [[[-0.0611, 0.1242, 0.0840], [0.0280, -0.0048, 0.1125], [0.0106, 0.0226, 0.0751]]]
        )

        self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))

    @slow
    def test_inference_masked_lm(self):
        model = YosoForMaskedLM.from_pretrained("uw-madison/yoso-4096")
        input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])

        with torch.no_grad():
            output = model(input_ids)[0]

        vocab_size = 50265

        expected_shape = torch.Size((1, 6, vocab_size))
        self.assertEqual(output.shape, expected_shape)

        expected_slice = torch.tensor(
            [[[-2.1313, -3.7285, -2.2407], [-2.7047, -3.3314, -2.6408], [0.0629, -2.5166, -0.3356]]]
        )

        self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))

    @slow
    def test_inference_masked_lm_long_input(self):
        model = YosoForMaskedLM.from_pretrained("uw-madison/yoso-4096")
        input_ids = torch.arange(4096).unsqueeze(0)

        with torch.no_grad():
            output = model(input_ids)[0]

        vocab_size = 50265

        expected_shape = torch.Size((1, 4096, vocab_size))
        self.assertEqual(output.shape, expected_shape)

        expected_slice = torch.tensor(
            [[[-2.3914, -4.3742, -5.0956], [-4.0988, -4.2384, -7.0406], [-3.1427, -3.7192, -6.6800]]]
        )

        self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))