Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Add REALM (#13292)
* REALM initial commit
* Retriever OK (Update new_gelu).
* Encoder prediction score OK
* Encoder pretrained model OK
* Update retriever comments
* Update docs, tests, and imports
* Prune unused models
* Make embedder as a module `RealmEmbedder`
* Add RealmRetrieverOutput
* Update tokenization
* Pass all tests in test_modeling_realm.py
* Prune RealmModel
* Update docs
* Add training test.
* Remove completed TODO
* Style & Quality
* Prune `RealmModel`
* Fixup
* Changes:
  1. Remove RealmTokenizerFast
  2. Update docstrings
  3. Add a method to RealmTokenizer to handle candidates tokenization.
* Fix up
* Style
* Add tokenization tests
* Update `from_pretrained` tests
* Apply suggestions
* Style & Quality
* Copy BERT model
* Fix comment to avoid docstring copying
* Make RealmBertModel private
* Fix bug
* Style
* Basic QA
* Save
* Complete reader logits
* Add searcher
* Complete searcher & reader
* Move block records init to constructor
* Fix training bug
* Add some outputs to RealmReader
* Add finetuned checkpoint variable names parsing
* Fix bug
* Update REALM config
* Add RealmForOpenQA
* Update convert_tfrecord logits
* Fix bugs
* Complete imports
* Update docs
* Update naming
* Add brute-force searcher
* Pass realm model tests
* Style
* Exclude RealmReader from common tests
* Fix
* Fix
* convert docs
* up
* up
* more make style
* up
* upload
* up
* Fix
* Update src/transformers/__init__.py
* adapt testing
* change modeling code
* fix test
* up
* up
* up
* correct more
* make retriever work
* update
* make style
* finish main structure
* Resolve merge conflict
* Make everything work
* Style
* Fixup
* Fixup
* Update training test
* fix retriever
* remove hardcoded path
* Fix
* Fix modeling test
* Update model links
* Initial retrieval test
* Fix modeling test
* Complete retrieval tests
* Fix
* style
* Fix tests
* Fix docstring example
* Minor fix of retrieval test
* Update license headers and docs
* Apply suggestions from code review
* Style
* Apply suggestions from code review
* Add an example to RealmEmbedder
* Fix

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent b25067d807
commit 22454ae492
@@ -291,6 +291,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
1. **[PhoBERT](https://huggingface.co/docs/transformers/model_doc/phobert)** (from VinAI Research) released with the paper [PhoBERT: Pre-trained language models for Vietnamese](https://www.aclweb.org/anthology/2020.findings-emnlp.92/) by Dat Quoc Nguyen and Anh Tuan Nguyen.
1. **[ProphetNet](https://huggingface.co/docs/transformers/model_doc/prophetnet)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
1. **[QDQBert](https://huggingface.co/docs/transformers/model_doc/qdqbert)** (from NVIDIA) released with the paper [Integer Quantization for Deep Learning Inference: Principles and Empirical Evaluation](https://arxiv.org/abs/2004.09602) by Hao Wu, Patrick Judd, Xiaojie Zhang, Mikhail Isaev and Paulius Micikevicius.
1. **[REALM](https://huggingface.co/transformers/master/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang.
1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/pdf/2010.12821.pdf) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder.
1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper a [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
@@ -270,6 +270,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
1. **[PhoBERT](https://huggingface.co/docs/transformers/model_doc/phobert)** (from VinAI Research) released with the paper [PhoBERT: Pre-trained language models for Vietnamese](https://www.aclweb.org/anthology/2020.findings-emnlp.92/) by Dat Quoc Nguyen and Anh Tuan Nguyen.
1. **[ProphetNet](https://huggingface.co/docs/transformers/model_doc/prophetnet)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
1. **[QDQBert](https://huggingface.co/docs/transformers/model_doc/qdqbert)** (from NVIDIA) released with the paper [Integer Quantization for Deep Learning Inference: Principles and Empirical Evaluation](https://arxiv.org/abs/2004.09602) by Hao Wu, Patrick Judd, Xiaojie Zhang, Mikhail Isaev and Paulius Micikevicius.
1. **[REALM](https://huggingface.co/transformers/master/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang.
1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/pdf/2010.12821.pdf) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder.
1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper a [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
@@ -294,6 +294,7 @@ conda install -c huggingface transformers
1. **[PhoBERT](https://huggingface.co/docs/transformers/model_doc/phobert)** (来自 VinAI Research) 伴随论文 [PhoBERT: Pre-trained language models for Vietnamese](https://www.aclweb.org/anthology/2020.findings-emnlp.92/) 由 Dat Quoc Nguyen and Anh Tuan Nguyen 发布。
1. **[ProphetNet](https://huggingface.co/docs/transformers/model_doc/prophetnet)** (来自 Microsoft Research) 伴随论文 [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) 由 Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou 发布。
1. **[QDQBert](https://huggingface.co/docs/transformers/model_doc/qdqbert)** (来自 NVIDIA) 伴随论文 [Integer Quantization for Deep Learning Inference: Principles and Empirical Evaluation](https://arxiv.org/abs/2004.09602) 由 Hao Wu, Patrick Judd, Xiaojie Zhang, Mikhail Isaev and Paulius Micikevicius 发布。
1. **[REALM](https://huggingface.co/transformers/master/model_doc/realm.html)** (来自 Google Research) 伴随论文 [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) 由 Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang 发布。
1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (来自 Google Research) 伴随论文 [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) 由 Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya 发布。
1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (来自 Google Research) 伴随论文 [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/pdf/2010.12821.pdf) 由 Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder 发布。
1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (来自 Facebook), 伴随论文 [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) 由 Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov 发布。
@@ -306,6 +306,7 @@ conda install -c huggingface transformers
1. **[PhoBERT](https://huggingface.co/docs/transformers/model_doc/phobert)** (from VinAI Research) released with the paper [PhoBERT: Pre-trained language models for Vietnamese](https://www.aclweb.org/anthology/2020.findings-emnlp.92/) by Dat Quoc Nguyen and Anh Tuan Nguyen.
1. **[ProphetNet](https://huggingface.co/docs/transformers/model_doc/prophetnet)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
1. **[QDQBert](https://huggingface.co/docs/transformers/model_doc/qdqbert)** (from NVIDIA) released with the paper [Integer Quantization for Deep Learning Inference: Principles and Empirical Evaluation](https://arxiv.org/abs/2004.09602) by Hao Wu, Patrick Judd, Xiaojie Zhang, Mikhail Isaev and Paulius Micikevicius.
1. **[REALM](https://huggingface.co/transformers/master/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang.
1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/pdf/2010.12821.pdf) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder.
1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper a [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
@@ -240,6 +240,8 @@
    title: QDQBert
  - local: model_doc/rag
    title: RAG
  - local: model_doc/realm
    title: REALM
  - local: model_doc/reformer
    title: Reformer
  - local: model_doc/rembert
@@ -151,6 +151,7 @@ conversion utilities for the following models.
1. **[PhoBERT](model_doc/phobert)** (from VinAI Research) released with the paper [PhoBERT: Pre-trained language models for Vietnamese](https://www.aclweb.org/anthology/2020.findings-emnlp.92/) by Dat Quoc Nguyen and Anh Tuan Nguyen.
1. **[ProphetNet](model_doc/prophetnet)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
1. **[QDQBert](model_doc/qdqbert)** (from NVIDIA) released with the paper [Integer Quantization for Deep Learning Inference: Principles and Empirical Evaluation](https://arxiv.org/abs/2004.09602) by Hao Wu, Patrick Judd, Xiaojie Zhang, Mikhail Isaev and Paulius Micikevicius.
1. **[REALM](https://huggingface.co/transformers/master/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang.
1. **[Reformer](model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
1. **[RemBERT](model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/pdf/2010.12821.pdf) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder.
1. **[RoBERTa](model_doc/roberta)** (from Facebook), released together with the paper a [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
@@ -244,6 +245,7 @@ Flax), PyTorch, and/or TensorFlow.
| ProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ |
| QDQBert | ❌ | ❌ | ✅ | ❌ | ❌ |
| RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
| Realm | ✅ | ❌ | ✅ | ❌ | ❌ |
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
docs/source/model_doc/realm.mdx (new file, 80 lines)
@@ -0,0 +1,80 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# REALM

## Overview

The REALM model was proposed in [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909)
by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang. It's a retrieval-augmented language model
that first retrieves documents from a textual knowledge corpus and then uses the retrieved documents to process
question answering tasks.

The abstract from the paper is the following:

*Language model pre-training has been shown to capture a surprising amount of world knowledge, crucial for NLP tasks
such as question answering. However, this knowledge is stored implicitly in the parameters of a neural network,
requiring ever-larger networks to cover more facts. To capture knowledge in a more modular and interpretable way, we
augment language model pre-training with a latent knowledge retriever, which allows the model to retrieve and attend
over documents from a large corpus such as Wikipedia, used during pre-training, fine-tuning and inference. For the
first time, we show how to pre-train such a knowledge retriever in an unsupervised manner, using masked language
modeling as the learning signal and backpropagating through a retrieval step that considers millions of documents. We
demonstrate the effectiveness of Retrieval-Augmented Language Model pre-training (REALM) by fine-tuning on the
challenging task of Open-domain Question Answering (Open-QA). We compare against state-of-the-art models for both
explicit and implicit knowledge storage on three popular Open-QA benchmarks, and find that we outperform all previous
methods by a significant margin (4-16% absolute accuracy), while also providing qualitative benefits such as
interpretability and modularity.*

This model was contributed by [qqaatw](https://huggingface.co/qqaatw). The original code can be found
[here](https://github.com/google-research/language/tree/master/language/realm).

## RealmConfig

[[autodoc]] RealmConfig

## RealmTokenizer

[[autodoc]] RealmTokenizer
    - build_inputs_with_special_tokens
    - get_special_tokens_mask
    - create_token_type_ids_from_sequences
    - save_vocabulary
    - batch_encode_candidates

## RealmRetriever

[[autodoc]] RealmRetriever

## RealmEmbedder

[[autodoc]] RealmEmbedder
    - forward

## RealmScorer

[[autodoc]] RealmScorer
    - forward

## RealmKnowledgeAugEncoder

[[autodoc]] RealmKnowledgeAugEncoder
    - forward

## RealmReader

[[autodoc]] RealmReader
    - forward

## RealmForOpenQA

[[autodoc]] RealmForOpenQA
    - forward
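For orientation beyond the autodoc stubs, here is a minimal end-to-end sketch of the `RealmForOpenQA` flow this commit adds. The checkpoint name matches the `qqaatw/realm-orqa-nq-openqa` weights referenced elsewhere in the diff, but the exact forward signature lives in the (suppressed) modeling_realm.py, so treat this as illustrative rather than verbatim:

```python
from transformers import RealmForOpenQA, RealmRetriever, RealmTokenizer

# The retriever bundles the evidence block records with the tokenizer.
retriever = RealmRetriever.from_pretrained("qqaatw/realm-orqa-nq-openqa")
tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-orqa-nq-openqa")
model = RealmForOpenQA.from_pretrained("qqaatw/realm-orqa-nq-openqa", retriever=retriever)

question = "Who is the pioneer in modern computer science?"
question_ids = tokenizer([question], return_tensors="pt")
# answer_ids are optional; when provided, the reader can be scored against them.
answer_ids = tokenizer(
    ["alan mathison turing"],
    add_special_tokens=False,
    return_token_type_ids=False,
    return_attention_mask=False,
).input_ids

reader_output, predicted_answer_ids = model(**question_ids, answer_ids=answer_ids, return_dict=False)
predicted_answer = tokenizer.decode(predicted_answer_ids)
```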
@@ -265,6 +265,7 @@ _import_structure = {
    "models.prophetnet": ["PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ProphetNetConfig", "ProphetNetTokenizer"],
    "models.qdqbert": ["QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "QDQBertConfig"],
    "models.rag": ["RagConfig", "RagRetriever", "RagTokenizer"],
    "models.realm": ["REALM_PRETRAINED_CONFIG_ARCHIVE_MAP", "RealmConfig", "RealmTokenizer"],
    "models.reformer": ["REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "ReformerConfig"],
    "models.rembert": ["REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RemBertConfig"],
    "models.retribert": ["RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RetriBertConfig", "RetriBertTokenizer"],
@@ -1199,6 +1200,19 @@ if is_torch_available():
    _import_structure["models.rag"].extend(
        ["RagModel", "RagPreTrainedModel", "RagSequenceForGeneration", "RagTokenForGeneration"]
    )
    _import_structure["models.realm"].extend(
        [
            "REALM_PRETRAINED_MODEL_ARCHIVE_LIST",
            "RealmEmbedder",
            "RealmForOpenQA",
            "RealmKnowledgeAugEncoder",
            "RealmPreTrainedModel",
            "RealmReader",
            "RealmRetriever",
            "RealmScorer",
            "load_tf_weights_in_realm",
        ]
    )
    _import_structure["models.reformer"].extend(
        [
            "REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -2353,6 +2367,7 @@ if TYPE_CHECKING:
    from .models.prophetnet import PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ProphetNetConfig, ProphetNetTokenizer
    from .models.qdqbert import QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, QDQBertConfig
    from .models.rag import RagConfig, RagRetriever, RagTokenizer
    from .models.realm import REALM_PRETRAINED_CONFIG_ARCHIVE_MAP, RealmConfig, RealmTokenizer
    from .models.reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
    from .models.rembert import REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RemBertConfig
    from .models.retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig, RetriBertTokenizer
@@ -3128,6 +3143,17 @@ if TYPE_CHECKING:
        ProphetNetPreTrainedModel,
    )
    from .models.rag import RagModel, RagPreTrainedModel, RagSequenceForGeneration, RagTokenForGeneration
    from .models.realm import (
        REALM_PRETRAINED_MODEL_ARCHIVE_LIST,
        RealmEmbedder,
        RealmForOpenQA,
        RealmKnowledgeAugEncoder,
        RealmPreTrainedModel,
        RealmReader,
        RealmRetriever,
        RealmScorer,
        load_tf_weights_in_realm,
    )
    from .models.reformer import (
        REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
        ReformerAttention,
@@ -84,6 +84,7 @@ from . import (
    prophetnet,
    qdqbert,
    rag,
    realm,
    reformer,
    rembert,
    retribert,
@@ -30,6 +30,7 @@ logger = logging.get_logger(__name__)
CONFIG_MAPPING_NAMES = OrderedDict(
    [
        # Add configs here
        ("realm", "RealmConfig"),
        ("nystromformer", "NystromformerConfig"),
        ("imagegpt", "ImageGPTConfig"),
        ("qdqbert", "QDQBertConfig"),
@@ -117,6 +118,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
    [
        # Add archive maps here
        ("realm", "REALM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("nystromformer", "NYSTROMFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("imagegpt", "IMAGEGPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("qdqbert", "QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@@ -192,6 +194,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
MODEL_NAMES_MAPPING = OrderedDict(
    [
        # Add full (and cased) model names here
        ("realm", "Realm"),
        ("nystromformer", "Nystromformer"),
        ("imagegpt", "ImageGPT"),
        ("qdqbert", "QDQBert"),
src/transformers/models/realm/__init__.py (new file, 64 lines)
@@ -0,0 +1,64 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...file_utils import _LazyModule, is_tokenizers_available, is_torch_available


_import_structure = {
    "configuration_realm": ["REALM_PRETRAINED_CONFIG_ARCHIVE_MAP", "RealmConfig"],
    "tokenization_realm": ["RealmTokenizer"],
}


if is_torch_available():
    _import_structure["modeling_realm"] = [
        "REALM_PRETRAINED_MODEL_ARCHIVE_LIST",
        "RealmEmbedder",
        "RealmForOpenQA",
        "RealmKnowledgeAugEncoder",
        "RealmPreTrainedModel",
        "RealmReader",
        "RealmScorer",
        "load_tf_weights_in_realm",
    ]
    _import_structure["retrieval_realm"] = ["RealmRetriever"]


if TYPE_CHECKING:
    from .configuration_realm import REALM_PRETRAINED_CONFIG_ARCHIVE_MAP, RealmConfig
    from .tokenization_realm import RealmTokenizer

    if is_torch_available():
        from .modeling_realm import (
            REALM_PRETRAINED_MODEL_ARCHIVE_LIST,
            RealmEmbedder,
            RealmForOpenQA,
            RealmKnowledgeAugEncoder,
            RealmPreTrainedModel,
            RealmReader,
            RealmScorer,
            load_tf_weights_in_realm,
        )
        from .retrieval_realm import RealmRetriever


else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
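A note on the structure above: `_LazyModule` defers the heavy submodule imports until an attribute is actually accessed, and the `is_torch_available()` gate keeps the torch-only classes out of the import structure on torch-less installs. A rough illustration of the resulting behavior (not part of the diff):

```python
# Illustrative: nothing torch-related is imported until the attribute is touched.
import transformers

config = transformers.RealmConfig()  # resolves models.realm.configuration_realm on demand
# transformers.RealmEmbedder would resolve modeling_realm (and therefore torch) on access;
# without torch installed, the dummy object from utils/dummy_pt_objects.py is exported instead.
```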
src/transformers/models/realm/configuration_realm.py (new file, 180 lines)
@@ -0,0 +1,180 @@
# coding=utf-8
# Copyright 2022 The REALM authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" REALM model configuration."""

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)

REALM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "realm-cc-news-pretrained-embedder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder/resolve/main/config.json",
    "realm-cc-news-pretrained-encoder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-encoder/resolve/main/config.json",
    "realm-cc-news-pretrained-scorer": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-scorer/resolve/main/config.json",
    "realm-cc-news-pretrained-openqa": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-openqa/resolve/main/config.json",
    "realm-orqa-nq-openqa": "https://huggingface.co/qqaatw/realm-orqa-nq-openqa/resolve/main/config.json",
    "realm-orqa-nq-reader": "https://huggingface.co/qqaatw/realm-orqa-nq-reader/resolve/main/config.json",
    "realm-orqa-wq-openqa": "https://huggingface.co/qqaatw/realm-orqa-wq-openqa/resolve/main/config.json",
    "realm-orqa-wq-reader": "https://huggingface.co/qqaatw/realm-orqa-wq-reader/resolve/main/config.json",
    # See all REALM models at https://huggingface.co/models?filter=realm
}


class RealmConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of

    1. [`RealmEmbedder`]
    2. [`RealmScorer`]
    3. [`RealmKnowledgeAugEncoder`]
    4. [`RealmRetriever`]
    5. [`RealmReader`]
    6. [`RealmForOpenQA`]

    It is used to instantiate a REALM model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the REALM
    [realm-cc-news-pretrained](https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 30522):
            Vocabulary size of the REALM model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`RealmEmbedder`], [`RealmScorer`], [`RealmKnowledgeAugEncoder`], or
            [`RealmReader`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimension of the encoder layers and the pooler layer.
        retriever_proj_size (`int`, *optional*, defaults to 128):
            Dimension of the retriever (embedder) projection.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_candidates (`int`, *optional*, defaults to 8):
            Number of candidates passed to the RealmScorer or RealmKnowledgeAugEncoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_new"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        type_vocab_size (`int`, *optional*, defaults to 2):
            The vocabulary size of the `token_type_ids` passed when calling [`RealmEmbedder`], [`RealmScorer`],
            [`RealmKnowledgeAugEncoder`], or [`RealmReader`].
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        span_hidden_size (`int`, *optional*, defaults to 256):
            Dimension of the reader's spans.
        max_span_width (`int`, *optional*, defaults to 10):
            Max span width of the reader.
        reader_layer_norm_eps (`float`, *optional*, defaults to 1e-3):
            The epsilon used by the reader's layer normalization layers.
        reader_beam_size (`int`, *optional*, defaults to 5):
            Beam size of the reader.
        reader_seq_len (`int`, *optional*, defaults to 320 (= 288 + 32)):
            Maximum sequence length of the reader.
        num_block_records (`int`, *optional*, defaults to 13353718):
            Number of block records.
        searcher_beam_size (`int`, *optional*, defaults to 5000):
            Beam size of the searcher. Note that when eval mode is enabled, *searcher_beam_size* will be the same as
            *reader_beam_size*.
        searcher_seq_len (`int`, *optional*, defaults to 64):
            Maximum sequence length of the searcher.

    Example:

    ```python
    >>> from transformers import RealmEmbedder, RealmConfig

    >>> # Initializing a REALM realm-cc-news-pretrained-* style configuration
    >>> configuration = RealmConfig()

    >>> # Initializing a model from the qqaatw/realm-cc-news-pretrained-embedder style configuration
    >>> model = RealmEmbedder(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
    model_type = "realm"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        retriever_proj_size=128,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_candidates=8,
        intermediate_size=3072,
        hidden_act="gelu_new",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        span_hidden_size=256,
        max_span_width=10,
        reader_layer_norm_eps=1e-3,
        reader_beam_size=5,
        reader_seq_len=320,  # 288 + 32
        num_block_records=13353718,
        searcher_beam_size=5000,
        searcher_seq_len=64,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        **kwargs
    ):
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

        # Common config
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.retriever_proj_size = retriever_proj_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_candidates = num_candidates
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.type_vocab_size = type_vocab_size
        self.layer_norm_eps = layer_norm_eps

        # Reader config
        self.span_hidden_size = span_hidden_size
        self.max_span_width = max_span_width
        self.reader_layer_norm_eps = reader_layer_norm_eps
        self.reader_beam_size = reader_beam_size
        self.reader_seq_len = reader_seq_len

        # Retrieval config
        self.num_block_records = num_block_records
        self.searcher_beam_size = searcher_beam_size
        self.searcher_seq_len = searcher_seq_len
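As a quick sanity check on the knobs above, a short sketch (the parameter names are from the config; the override values are arbitrary illustrations):

```python
from transformers import RealmConfig

# Shrink the retrieval/reading budget for a quick experiment.
config = RealmConfig(num_candidates=4, reader_beam_size=2, searcher_beam_size=100)

assert config.reader_seq_len == 320  # default: 288 + 32
assert config.retriever_proj_size == 128  # embedder projection dimension
```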
src/transformers/models/realm/modeling_realm.py (new file, 1842 lines)
(File diff suppressed because it is too large.)
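The suppressed file implements the seven model classes documented above (plus the private `RealmBertModel` backbone mentioned in the commit message). Per the "Add an example to RealmEmbedder" commit entry, its docstring example is along these lines (a sketch, not the verbatim file contents):

```python
import torch
from transformers import RealmEmbedder, RealmTokenizer

tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-cc-news-pretrained-embedder")
model = RealmEmbedder.from_pretrained("qqaatw/realm-cc-news-pretrained-embedder")

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# One projected vector per sequence: (batch_size, retriever_proj_size) == (1, 128).
projected_score = outputs.projected_score
```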
src/transformers/models/realm/retrieval_realm.py (new file, 162 lines)
@@ -0,0 +1,162 @@
# coding=utf-8
# Copyright 2022 The REALM authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""REALM Retriever model implementation."""

import os
from typing import Optional, Union

import numpy as np

from huggingface_hub import hf_hub_download

from ...utils import logging
from .tokenization_realm import RealmTokenizer


_REALM_BLOCK_RECORDS_FILENAME = "block_records.npy"


logger = logging.get_logger(__name__)


def convert_tfrecord_to_np(block_records_path: str, num_block_records: int) -> np.ndarray:
    import tensorflow.compat.v1 as tf

    blocks_dataset = tf.data.TFRecordDataset(block_records_path, buffer_size=512 * 1024 * 1024)
    blocks_dataset = blocks_dataset.batch(num_block_records, drop_remainder=True)
    np_record = next(blocks_dataset.take(1).as_numpy_iterator())

    return np_record


class ScaNNSearcher:
    """Note that ScaNNSearcher cannot currently be used within the model. In future versions, it might however be included."""

    def __init__(
        self,
        db,
        num_neighbors,
        dimensions_per_block=2,
        num_leaves=1000,
        num_leaves_to_search=100,
        training_sample_size=100000,
    ):
        """Build scann searcher."""

        from scann.scann_ops.py.scann_ops_pybind import builder as Builder

        builder = Builder(db=db, num_neighbors=num_neighbors, distance_measure="dot_product")
        builder = builder.tree(
            num_leaves=num_leaves, num_leaves_to_search=num_leaves_to_search, training_sample_size=training_sample_size
        )
        builder = builder.score_ah(dimensions_per_block=dimensions_per_block)

        self.searcher = builder.build()

    def search_batched(self, question_projection):
        retrieved_block_ids, _ = self.searcher.search_batched(question_projection.detach().cpu())
        return retrieved_block_ids.astype("int64")


class RealmRetriever:
    """The retriever of REALM outputting the retrieved evidence block and whether the block has answers as well as
    answer positions.

    Parameters:
        block_records (`np.ndarray`):
            A numpy array which contains evidence texts.
        tokenizer ([`RealmTokenizer`]):
            The tokenizer to encode retrieved texts.
    """

    def __init__(self, block_records, tokenizer):
        super().__init__()
        self.block_records = block_records
        self.tokenizer = tokenizer

    def __call__(self, retrieved_block_ids, question_input_ids, answer_ids, max_length=None, return_tensors="pt"):
        retrieved_blocks = np.take(self.block_records, indices=retrieved_block_ids, axis=0)

        question = self.tokenizer.decode(question_input_ids[0], skip_special_tokens=True)

        text = []
        text_pair = []
        for retrieved_block in retrieved_blocks:
            text.append(question)
            text_pair.append(retrieved_block.decode())

        concat_inputs = self.tokenizer(text, text_pair, padding=True, truncation=True, max_length=max_length)
        concat_inputs_tensors = concat_inputs.convert_to_tensors(return_tensors)

        if answer_ids is not None:
            return self.block_has_answer(concat_inputs, answer_ids) + (concat_inputs_tensors,)
        else:
            return (None, None, None, concat_inputs_tensors)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *init_inputs, **kwargs):
        if os.path.isdir(pretrained_model_name_or_path):
            block_records_path = os.path.join(pretrained_model_name_or_path, _REALM_BLOCK_RECORDS_FILENAME)
        else:
            block_records_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename=_REALM_BLOCK_RECORDS_FILENAME, **kwargs
            )
        block_records = np.load(block_records_path, allow_pickle=True)

        tokenizer = RealmTokenizer.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)

        return cls(block_records, tokenizer)

    def save_pretrained(self, save_directory):
        # save block records
        np.save(os.path.join(save_directory, _REALM_BLOCK_RECORDS_FILENAME), self.block_records)
        # save tokenizer
        self.tokenizer.save_pretrained(save_directory)

    def block_has_answer(self, concat_inputs, answer_ids):
        """check if retrieved_blocks has answers."""
        has_answers = []
        start_pos = []
        end_pos = []
        max_answers = 0

        for input_id in concat_inputs.input_ids:
            start_pos.append([])
            end_pos.append([])
            input_id_list = input_id.tolist()
            # Checking answers after the [SEP] token
            sep_idx = input_id_list.index(self.tokenizer.sep_token_id)
            for answer in answer_ids:
                for idx in range(sep_idx, len(input_id)):
                    if answer[0] == input_id_list[idx]:
                        if input_id_list[idx : idx + len(answer)] == answer:
                            start_pos[-1].append(idx)
                            end_pos[-1].append(idx + len(answer) - 1)

            if len(start_pos[-1]) == 0:
                has_answers.append(False)
            else:
                has_answers.append(True)
                if len(start_pos[-1]) > max_answers:
                    max_answers = len(start_pos[-1])

        # Pad -1 to max_answers
        for start_pos_, end_pos_ in zip(start_pos, end_pos):
            if len(start_pos_) < max_answers:
                padded = [-1] * (max_answers - len(start_pos_))
                start_pos_ += padded
                end_pos_ += padded

        return has_answers, start_pos, end_pos
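To make the `__call__` contract concrete, a hedged sketch with a tiny in-memory corpus (the two-block array, question, and answer are invented for illustration; `block_records` normally comes from the `block_records.npy` file downloaded in `from_pretrained`):

```python
import numpy as np
from transformers import RealmRetriever, RealmTokenizer

# Two evidence blocks stored as byte strings, matching the layout of the real block records file.
block_records = np.array([b"Alan Turing was an English mathematician.", b"The cat sat."], dtype=object)
tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-cc-news-pretrained-encoder")
retriever = RealmRetriever(block_records, tokenizer)

question_input_ids = tokenizer(["who was alan turing?"], return_tensors="pt").input_ids
answer_ids = [tokenizer(["english mathematician"], add_special_tokens=False).input_ids[0]]

# Pretend the searcher returned both blocks; the retriever re-tokenizes question+block pairs and
# reports, per block, whether and where the answer span occurs after the first [SEP].
has_answers, start_pos, end_pos, concat_inputs = retriever(
    np.array([0, 1]), question_input_ids, answer_ids, max_length=32
)
```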
src/transformers/models/realm/tokenization_realm.py (new file, 149 lines)
@@ -0,0 +1,149 @@
# coding=utf-8
# Copyright 2022 The REALM authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for REALM."""

from ...file_utils import PaddingStrategy
from ...tokenization_utils_base import BatchEncoding
from ...utils import logging
from ..bert.tokenization_bert import BertTokenizer


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "realm-cc-news-pretrained-embedder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-embedder/resolve/main/vocab.txt",
        "realm-cc-news-pretrained-encoder": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-encoder/resolve/main/vocab.txt",
        "realm-cc-news-pretrained-scorer": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-scorer/resolve/main/vocab.txt",
        "realm-cc-news-pretrained-openqa": "https://huggingface.co/qqaatw/realm-cc-news-pretrained-openqa/resolve/main/vocab.txt",
        "realm-orqa-nq-openqa": "https://huggingface.co/qqaatw/realm-orqa-nq-openqa/resolve/main/vocab.txt",
        "realm-orqa-nq-reader": "https://huggingface.co/qqaatw/realm-orqa-nq-reader/resolve/main/vocab.txt",
        "realm-orqa-wq-openqa": "https://huggingface.co/qqaatw/realm-orqa-wq-openqa/resolve/main/vocab.txt",
        "realm-orqa-wq-reader": "https://huggingface.co/qqaatw/realm-orqa-wq-reader/resolve/main/vocab.txt",
    }
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "realm-cc-news-pretrained-embedder": 512,
    "realm-cc-news-pretrained-encoder": 512,
    "realm-cc-news-pretrained-scorer": 512,
    "realm-cc-news-pretrained-openqa": 512,
    "realm-orqa-nq-openqa": 512,
    "realm-orqa-nq-reader": 512,
    "realm-orqa-wq-openqa": 512,
    "realm-orqa-wq-reader": 512,
}

PRETRAINED_INIT_CONFIGURATION = {
    "realm-cc-news-pretrained-embedder": {"do_lower_case": True},
    "realm-cc-news-pretrained-encoder": {"do_lower_case": True},
    "realm-cc-news-pretrained-scorer": {"do_lower_case": True},
    "realm-cc-news-pretrained-openqa": {"do_lower_case": True},
    "realm-orqa-nq-openqa": {"do_lower_case": True},
    "realm-orqa-nq-reader": {"do_lower_case": True},
    "realm-orqa-wq-openqa": {"do_lower_case": True},
    "realm-orqa-wq-reader": {"do_lower_case": True},
}


class RealmTokenizer(BertTokenizer):
    r"""
    Construct a REALM tokenizer.

    [`RealmTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation splitting and
    wordpiece.

    Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION

    def batch_encode_candidates(self, text, **kwargs):
        r"""
        Encode a batch of text or text pair. This method is similar to the regular `__call__` method but has the
        following differences:

        1. Handle the additional num_candidates axis: (batch_size, num_candidates, text).
        2. Always pad the sequences to *max_length*.
        3. Must specify *max_length* in order to stack packs of candidates into a batch.

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            text (`List[List[str]]`):
                The batch of sequences to be encoded. Each sequence must be in this format: (batch_size,
                num_candidates, text).
            text_pair (`List[List[str]]`, *optional*):
                The batch of sequences to be encoded. Each sequence must be in this format: (batch_size,
                num_candidates, text).
            **kwargs:
                Keyword arguments of the `__call__` method.

        Returns:
            [`BatchEncoding`]: Encoded text or text pair.

        Example:

        ```python
        >>> from transformers import RealmTokenizer

        >>> # batch_size = 2, num_candidates = 2
        >>> text = [["Hello world!", "Nice to meet you!"], ["The cute cat.", "The adorable dog."]]

        >>> tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-cc-news-pretrained-encoder")
        >>> tokenized_text = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt")
        ```"""

        # Always using a fixed sequence length to encode in order to stack candidates into a batch.
        kwargs["padding"] = PaddingStrategy.MAX_LENGTH

        batch_text = text
        batch_text_pair = kwargs.pop("text_pair", None)
        return_tensors = kwargs.pop("return_tensors", None)

        output_data = {
            "input_ids": [],
            "attention_mask": [],
            "token_type_ids": [],
        }

        for idx, candidate_text in enumerate(batch_text):
            if batch_text_pair is not None:
                candidate_text_pair = batch_text_pair[idx]
            else:
                candidate_text_pair = None

            encoded_candidates = super().__call__(candidate_text, candidate_text_pair, return_tensors=None, **kwargs)

            encoded_input_ids = encoded_candidates.get("input_ids")
            encoded_attention_mask = encoded_candidates.get("attention_mask")
            encoded_token_type_ids = encoded_candidates.get("token_type_ids")

            if encoded_input_ids is not None:
                output_data["input_ids"].append(encoded_input_ids)
            if encoded_attention_mask is not None:
                output_data["attention_mask"].append(encoded_attention_mask)
            if encoded_token_type_ids is not None:
                output_data["token_type_ids"].append(encoded_token_type_ids)

        output_data = dict((key, item) for key, item in output_data.items() if len(item) != 0)

        return BatchEncoding(output_data, tensor_type=return_tensors)
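One follow-up on the docstring example: because every candidate is padded to `max_length`, the stacked output keeps the candidates axis, which is the layout `RealmScorer` and `RealmKnowledgeAugEncoder` consume. Continuing the example above, the expected shape would be:

```python
# Continuing the docstring example (shape implied by the fixed-length padding contract).
tokenized_text = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt")
print(tokenized_text["input_ids"].shape)  # torch.Size([2, 2, 10]): (batch_size, num_candidates, max_length)
```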
@@ -2783,6 +2783,62 @@ class RagTokenForGeneration(metaclass=DummyObject):
        requires_backends(self, ["torch"])


REALM_PRETRAINED_MODEL_ARCHIVE_LIST = None


class RealmEmbedder(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class RealmForOpenQA(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class RealmKnowledgeAugEncoder(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class RealmPreTrainedModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class RealmReader(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class RealmRetriever(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class RealmScorer(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


def load_tf_weights_in_realm(*args, **kwargs):
    requires_backends(load_tf_weights_in_realm, ["torch"])


REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
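These dummies keep `from transformers import RealmEmbedder` importable on torch-less installs; only instantiation fails. Roughly (illustrative):

```python
# Illustrative behavior in an environment without PyTorch installed:
from transformers import RealmEmbedder

model = RealmEmbedder()  # requires_backends raises ImportError with install instructions for torch
```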
545
tests/test_modeling_realm.py
Normal file
545
tests/test_modeling_realm.py
Normal file
@ -0,0 +1,545 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch REALM model. """
|
||||
|
||||
import copy
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tests.test_modeling_common import floats_tensor
|
||||
from transformers import RealmConfig, is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
RealmEmbedder,
|
||||
RealmForOpenQA,
|
||||
RealmKnowledgeAugEncoder,
|
||||
RealmReader,
|
||||
RealmRetriever,
|
||||
RealmScorer,
|
||||
RealmTokenizer,
|
||||
)
|
||||
|
||||
|
||||
class RealmModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
retriever_proj_size=128,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-12,
|
||||
span_hidden_size=50,
|
||||
max_span_width=10,
|
||||
reader_layer_norm_eps=1e-3,
|
||||
reader_beam_size=4,
|
||||
reader_seq_len=288 + 32,
|
||||
num_block_records=13353718,
|
||||
searcher_beam_size=8,
|
||||
searcher_seq_len=64,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
num_candidates=10,
|
||||
scope=None,
|
||||
):
|
||||
# General config
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.retriever_proj_size = retriever_proj_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
|
||||
# Reader config
|
||||
self.span_hidden_size = span_hidden_size
|
||||
self.max_span_width = max_span_width
|
||||
self.reader_layer_norm_eps = reader_layer_norm_eps
|
||||
self.reader_beam_size = reader_beam_size
|
||||
self.reader_seq_len = reader_seq_len
|
||||
|
||||
# Searcher config
|
||||
self.num_block_records = num_block_records
|
||||
self.searcher_beam_size = searcher_beam_size
|
||||
self.searcher_seq_len = searcher_seq_len
|
||||
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.num_candidates = num_candidates
|
||||
self.scope = scope
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
candiate_input_ids = ids_tensor([self.batch_size, self.num_candidates, self.seq_length], self.vocab_size)
|
||||
reader_input_ids = ids_tensor([self.reader_beam_size, self.reader_seq_len], self.vocab_size)
|
||||
|
||||
input_mask = None
|
||||
candiate_input_mask = None
|
||||
reader_input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||
candiate_input_mask = random_attention_mask([self.batch_size, self.num_candidates, self.seq_length])
|
||||
reader_input_mask = random_attention_mask([self.reader_beam_size, self.reader_seq_len])
|
||||
|
||||
token_type_ids = None
|
||||
candidate_token_type_ids = None
|
||||
reader_token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||
candidate_token_type_ids = ids_tensor(
|
||||
[self.batch_size, self.num_candidates, self.seq_length], self.type_vocab_size
|
||||
)
|
||||
reader_token_type_ids = ids_tensor([self.reader_beam_size, self.reader_seq_len], self.type_vocab_size)
|
||||
|
||||
sequence_labels = None
|
||||
token_labels = None
|
||||
choice_labels = None
|
||||
if self.use_labels:
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
# inputs with additional num_candidates axis.
|
||||
scorer_encoder_inputs = (candiate_input_ids, candiate_input_mask, candidate_token_type_ids)
|
||||
# reader inputs
|
||||
reader_inputs = (reader_input_ids, reader_input_mask, reader_token_type_ids)
|
||||
|
||||
return (
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
scorer_encoder_inputs,
|
||||
reader_inputs,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
)
|
||||
|
||||
def get_config(self):
|
||||
return RealmConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
retriever_proj_size=self.retriever_proj_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
num_candidates=self.num_candidates,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
initializer_range=self.initializer_range,
|
||||
)
|
||||
|
||||
def create_and_check_embedder(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
scorer_encoder_inputs,
|
||||
reader_inputs,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
):
|
||||
model = RealmEmbedder(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
self.parent.assertEqual(result.projected_score.shape, (self.batch_size, self.retriever_proj_size))
|
||||
|
||||
def create_and_check_encoder(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
scorer_encoder_inputs,
|
||||
reader_inputs,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
):
|
||||
model = RealmKnowledgeAugEncoder(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
        relevance_score = floats_tensor([self.batch_size, self.num_candidates])
        result = model(
            scorer_encoder_inputs[0],
            attention_mask=scorer_encoder_inputs[1],
            token_type_ids=scorer_encoder_inputs[2],
            relevance_score=relevance_score,
            labels=token_labels,
        )
        self.parent.assertEqual(
            result.logits.shape, (self.batch_size * self.num_candidates, self.seq_length, self.vocab_size)
        )

    def create_and_check_reader(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        scorer_encoder_inputs,
        reader_inputs,
        sequence_labels,
        token_labels,
        choice_labels,
    ):
        model = RealmReader(config=config)
        model.to(torch_device)
        model.eval()
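        # One relevance score per block in the reader beam; the reader selects a
        # single best block and answer span, so the outputs checked below are scalars.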
        relevance_score = floats_tensor([self.reader_beam_size])
        result = model(
            reader_inputs[0],
            attention_mask=reader_inputs[1],
            token_type_ids=reader_inputs[2],
            relevance_score=relevance_score,
        )
        self.parent.assertEqual(result.block_idx.shape, ())
        self.parent.assertEqual(result.candidate.shape, ())
        self.parent.assertEqual(result.start_pos.shape, ())
        self.parent.assertEqual(result.end_pos.shape, ())

    def create_and_check_scorer(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        scorer_encoder_inputs,
        reader_inputs,
        sequence_labels,
        token_labels,
        choice_labels,
    ):
        model = RealmScorer(config=config)
        model.to(torch_device)
        model.eval()
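        # The scorer embeds the query and each candidate with two RealmEmbedders and
        # scores them with an inner product, giving one relevance score per candidate.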
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            candidate_input_ids=scorer_encoder_inputs[0],
            candidate_attention_mask=scorer_encoder_inputs[1],
            candidate_token_type_ids=scorer_encoder_inputs[2],
        )
        self.parent.assertEqual(result.relevance_score.shape, (self.batch_size, self.num_candidates))
        self.parent.assertEqual(result.query_score.shape, (self.batch_size, self.retriever_proj_size))
        self.parent.assertEqual(
            result.candidate_score.shape, (self.batch_size, self.num_candidates, self.retriever_proj_size)
        )

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            scorer_encoder_inputs,
            reader_inputs,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_torch
class RealmModelTest(ModelTesterMixin, unittest.TestCase):

    all_model_classes = (
        (
            RealmEmbedder,
            RealmKnowledgeAugEncoder,
            # RealmScorer is excluded from common tests as it is a container model
            # consisting of two RealmEmbedders & a simple inner product calculation.
            # RealmScorer
        )
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = ()

    # Disable these tests because there is no base_model in Realm.
    test_save_load_fast_init_from_base = False
    test_save_load_fast_init_to_base = False

    def setUp(self):
        self.test_pruning = False
        self.model_tester = RealmModelTester(self)
        self.config_tester = ConfigTester(self, config_class=RealmConfig)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_embedder(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_embedder(*config_and_inputs)

    def test_encoder(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_encoder(*config_and_inputs)

    def test_model_various_embeddings(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        for embedding_type in ["absolute", "relative_key", "relative_key_query"]:
            config_and_inputs[0].position_embedding_type = embedding_type
            self.model_tester.create_and_check_embedder(*config_and_inputs)
            self.model_tester.create_and_check_encoder(*config_and_inputs)

    def test_retriever(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_scorer(*config_and_inputs)

    def test_training(self):
        if not self.model_tester.is_training:
            return

        config, *inputs = self.model_tester.prepare_config_and_inputs()
        input_ids, token_type_ids, input_mask, scorer_encoder_inputs = inputs[0:4]
        config.return_dict = True

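        # A real tokenizer is needed to build the RealmRetriever used by
        # RealmForOpenQA below.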
        tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-orqa-nq-openqa")

        # RealmKnowledgeAugEncoder training
        model = RealmKnowledgeAugEncoder(config)
        model.to(torch_device)
        model.train()

        inputs_dict = {
            "input_ids": scorer_encoder_inputs[0].to(torch_device),
            "attention_mask": scorer_encoder_inputs[1].to(torch_device),
            "token_type_ids": scorer_encoder_inputs[2].to(torch_device),
            "relevance_score": floats_tensor([self.model_tester.batch_size, self.model_tester.num_candidates]),
        }
        inputs_dict["labels"] = torch.zeros(
            (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
        )
        inputs = inputs_dict
        loss = model(**inputs).loss
        loss.backward()

        # RealmForOpenQA training
        openqa_config = copy.deepcopy(config)
        openqa_config.vocab_size = 30522  # the retrieved texts will inevitably contain token ids beyond the tester's tiny vocabulary.
        openqa_config.num_block_records = 5
        openqa_config.searcher_beam_size = 2

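        # A tiny in-memory corpus standing in for the full set of block records.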
        block_records = np.array(
            [
                b"This is the first record.",
                b"This is the second record.",
                b"This is the third record.",
                b"This is the fourth record.",
                b"This is the fifth record.",
            ],
            dtype=object,
        )
        retriever = RealmRetriever(block_records, tokenizer)
        model = RealmForOpenQA(openqa_config, retriever)
        model.to(torch_device)
        model.train()

        inputs_dict = {
            "input_ids": input_ids[:1].to(torch_device),
            "attention_mask": input_mask[:1].to(torch_device),
            "token_type_ids": token_type_ids[:1].to(torch_device),
            "answer_ids": input_ids[:1].tolist(),
        }
        inputs = self._prepare_for_class(inputs_dict, RealmForOpenQA)
        loss = model(**inputs).reader_output.loss
        loss.backward()

    @slow
    def test_embedder_from_pretrained(self):
        model = RealmEmbedder.from_pretrained("qqaatw/realm-cc-news-pretrained-embedder")
        self.assertIsNotNone(model)

    @slow
    def test_encoder_from_pretrained(self):
        model = RealmKnowledgeAugEncoder.from_pretrained("qqaatw/realm-cc-news-pretrained-encoder")
        self.assertIsNotNone(model)

    @slow
    def test_open_qa_from_pretrained(self):
        model = RealmForOpenQA.from_pretrained("qqaatw/realm-orqa-nq-openqa")
        self.assertIsNotNone(model)

    @slow
    def test_reader_from_pretrained(self):
        model = RealmReader.from_pretrained("qqaatw/realm-orqa-nq-reader")
        self.assertIsNotNone(model)

    @slow
    def test_scorer_from_pretrained(self):
        model = RealmScorer.from_pretrained("qqaatw/realm-cc-news-pretrained-scorer")
        self.assertIsNotNone(model)


@require_torch
class RealmModelIntegrationTest(unittest.TestCase):
    @slow
    def test_inference_embedder(self):
        retriever_projected_size = 128

        model = RealmEmbedder.from_pretrained("qqaatw/realm-cc-news-pretrained-embedder")
        input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
        output = model(input_ids)[0]

        expected_shape = torch.Size((1, retriever_projected_size))
        self.assertEqual(output.shape, expected_shape)

        expected_slice = torch.tensor([[-0.0714, -0.0837, -0.1314]])
        self.assertTrue(torch.allclose(output[:, :3], expected_slice, atol=1e-4))

    @slow
    def test_inference_encoder(self):
        num_candidates = 2
        vocab_size = 30522

        model = RealmKnowledgeAugEncoder.from_pretrained(
            "qqaatw/realm-cc-news-pretrained-encoder", num_candidates=num_candidates
        )
        input_ids = torch.tensor([[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]])
        relevance_score = torch.tensor([[0.3, 0.7]], dtype=torch.float32)
        output = model(input_ids, relevance_score=relevance_score)[0]

        expected_shape = torch.Size((2, 6, vocab_size))
        self.assertEqual(output.shape, expected_shape)

        expected_slice = torch.tensor([[[-11.0888, -11.2544], [-10.2170, -10.3874]]])

        self.assertTrue(torch.allclose(output[1, :2, :2], expected_slice, atol=1e-4))

    @slow
    def test_inference_open_qa(self):
        from transformers.models.realm.retrieval_realm import RealmRetriever

        config = RealmConfig()

        tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-orqa-nq-openqa")
        retriever = RealmRetriever.from_pretrained("qqaatw/realm-orqa-nq-openqa")

        model = RealmForOpenQA.from_pretrained(
            "qqaatw/realm-orqa-nq-openqa",
            retriever=retriever,
            config=config,
        )

        question = "Who is the pioneer in modern computer science?"

        question = tokenizer(
            [question],
            padding=True,
            truncation=True,
            max_length=model.config.searcher_seq_len,
            return_tensors="pt",
        ).to(model.device)

        predicted_answer_ids = model(**question).predicted_answer_ids

        predicted_answer = tokenizer.decode(predicted_answer_ids)
        self.assertEqual(predicted_answer, "alan mathison turing")

    @slow
    def test_inference_reader(self):
        config = RealmConfig(reader_beam_size=2, max_span_width=3)
        model = RealmReader.from_pretrained("qqaatw/realm-orqa-nq-reader", config=config)
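        # token_type_ids mark the question segment (0) and the block segment (1)
        # of each concatenated reader input.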
        concat_input_ids = torch.arange(10).view((2, 5))
        concat_token_type_ids = torch.tensor([[0, 0, 1, 1, 1], [0, 0, 1, 1, 1]], dtype=torch.int64)
        relevance_score = torch.tensor([0.3, 0.7], dtype=torch.float32)

        output = model(
            concat_input_ids, token_type_ids=concat_token_type_ids, relevance_score=relevance_score, return_dict=True
        )

        block_idx_expected_shape = torch.Size(())
        start_pos_expected_shape = torch.Size((1,))
        end_pos_expected_shape = torch.Size((1,))
        self.assertEqual(output.block_idx.shape, block_idx_expected_shape)
        self.assertEqual(output.start_pos.shape, start_pos_expected_shape)
        self.assertEqual(output.end_pos.shape, end_pos_expected_shape)

        expected_block_idx = torch.tensor(1)
        expected_start_pos = torch.tensor(3)
        expected_end_pos = torch.tensor(3)

        self.assertTrue(torch.allclose(output.block_idx, expected_block_idx, atol=1e-4))
        self.assertTrue(torch.allclose(output.start_pos, expected_start_pos, atol=1e-4))
        self.assertTrue(torch.allclose(output.end_pos, expected_end_pos, atol=1e-4))

    @slow
    def test_inference_scorer(self):
        num_candidates = 2

        model = RealmScorer.from_pretrained("qqaatw/realm-cc-news-pretrained-scorer", num_candidates=num_candidates)

        input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]])
        candidate_input_ids = torch.tensor([[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]])
        output = model(input_ids, candidate_input_ids=candidate_input_ids)[0]

        expected_shape = torch.Size((1, 2))
        self.assertEqual(output.shape, expected_shape)

        expected_slice = torch.tensor([[0.7410, 0.7170]])
        self.assertTrue(torch.allclose(output, expected_slice, atol=1e-4))
tests/test_retrieval_realm.py (new file, 185 lines)
@@ -0,0 +1,185 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import tempfile
from unittest import TestCase
from unittest.mock import patch

import numpy as np
from datasets import Dataset

from transformers.models.realm.configuration_realm import RealmConfig
from transformers.models.realm.retrieval_realm import _REALM_BLOCK_RECORDS_FILENAME, RealmRetriever
from transformers.models.realm.tokenization_realm import VOCAB_FILES_NAMES, RealmTokenizer


class RealmRetrieverTest(TestCase):
    def setUp(self):
        self.tmpdirname = tempfile.mkdtemp()
        self.num_block_records = 5

        # Realm tokenizer vocabulary
        vocab_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "[PAD]",
            "[MASK]",
            "test",
            "question",
            "this",
            "is",
            "the",
            "first",
            "second",
            "third",
            "fourth",
            "fifth",
            "record",
            "want",
            "##want",
            "##ed",
            "wa",
            "un",
            "runn",
            "##ing",
            ",",
            "low",
            "lowest",
        ]
        realm_tokenizer_path = os.path.join(self.tmpdirname, "realm_tokenizer")
        os.makedirs(realm_tokenizer_path, exist_ok=True)
        self.vocab_file = os.path.join(realm_tokenizer_path, VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

        realm_block_records_path = os.path.join(self.tmpdirname, "realm_block_records")
        os.makedirs(realm_block_records_path, exist_ok=True)

    def get_tokenizer(self) -> RealmTokenizer:
        return RealmTokenizer.from_pretrained(os.path.join(self.tmpdirname, "realm_tokenizer"))

    def tearDown(self):
        shutil.rmtree(self.tmpdirname)

    def get_config(self):
        config = RealmConfig(num_block_records=self.num_block_records)
        return config

    def get_dummy_dataset(self):
        dataset = Dataset.from_dict(
            {
                "id": ["0", "1"],
                "question": ["foo", "bar"],
                "answers": [["Foo", "Bar"], ["Bar"]],
            }
        )
        return dataset

    def get_dummy_block_records(self):
        block_records = np.array(
            [
                b"This is the first record",
                b"This is the second record",
                b"This is the third record",
                b"This is the fourth record",
                b"This is the fifth record",
            ],
            dtype=object,
        )
        return block_records

    def get_dummy_retriever(self):
        retriever = RealmRetriever(
            block_records=self.get_dummy_block_records(),
            tokenizer=self.get_tokenizer(),
        )
        return retriever

    def test_retrieve(self):
        config = self.get_config()
        retriever = self.get_dummy_retriever()
        tokenizer = retriever.tokenizer

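        # Retrieve blocks 0 and 3; the retriever concatenates the question with
        # each retrieved block and locates the answer span within it.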
        retrieved_block_ids = np.array([0, 3], dtype=np.int64)
        question_input_ids = tokenizer(["Test question"]).input_ids
        answer_ids = tokenizer(
            ["the fourth"],
            add_special_tokens=False,
            return_token_type_ids=False,
            return_attention_mask=False,
        ).input_ids
        max_length = config.reader_seq_len

        has_answers, start_pos, end_pos, concat_inputs = retriever(
            retrieved_block_ids, question_input_ids, answer_ids=answer_ids, max_length=max_length, return_tensors="np"
        )

        self.assertEqual(len(has_answers), 2)
        self.assertEqual(len(start_pos), 2)
        self.assertEqual(len(end_pos), 2)
        self.assertEqual(concat_inputs.input_ids.shape, (2, 10))
        self.assertEqual(concat_inputs.attention_mask.shape, (2, 10))
        self.assertEqual(concat_inputs.token_type_ids.shape, (2, 10))
        self.assertEqual(
            tokenizer.convert_ids_to_tokens(concat_inputs.input_ids[0]),
            ["[CLS]", "test", "question", "[SEP]", "this", "is", "the", "first", "record", "[SEP]"],
        )
        self.assertEqual(
            tokenizer.convert_ids_to_tokens(concat_inputs.input_ids[1]),
            ["[CLS]", "test", "question", "[SEP]", "this", "is", "the", "fourth", "record", "[SEP]"],
        )

    def test_block_has_answer(self):
        config = self.get_config()
        retriever = self.get_dummy_retriever()
        tokenizer = retriever.tokenizer

        retrieved_block_ids = np.array([0, 3], dtype=np.int64)
        question_input_ids = tokenizer(["Test question"]).input_ids
        answer_ids = tokenizer(
            ["the fourth"],
            add_special_tokens=False,
            return_token_type_ids=False,
            return_attention_mask=False,
        ).input_ids
        max_length = config.reader_seq_len

        has_answers, start_pos, end_pos, _ = retriever(
            retrieved_block_ids, question_input_ids, answer_ids=answer_ids, max_length=max_length, return_tensors="np"
        )
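        # Block 0 does not contain the answer, so it gets has_answers=False and
        # sentinel positions of -1; "the fourth" sits at token positions 6-7 of block 3.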

        self.assertEqual([False, True], has_answers)
        self.assertEqual([[-1], [6]], start_pos)
        self.assertEqual([[-1], [7]], end_pos)

    def test_save_load_pretrained(self):
        retriever = self.get_dummy_retriever()
        retriever.save_pretrained(os.path.join(self.tmpdirname, "realm_block_records"))

        # Test local path
        retriever = retriever.from_pretrained(os.path.join(self.tmpdirname, "realm_block_records"))
        self.assertEqual(retriever.block_records[0], b"This is the first record")

        # Test mocked remote path
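        # hf_hub_download is patched so from_pretrained resolves to the locally
        # saved block records file instead of downloading from the Hub.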
        with patch("transformers.models.realm.retrieval_realm.hf_hub_download") as mock_hf_hub_download:
            mock_hf_hub_download.return_value = os.path.join(
                os.path.join(self.tmpdirname, "realm_block_records"), _REALM_BLOCK_RECORDS_FILENAME
            )
            retriever = RealmRetriever.from_pretrained("qqaatw/realm-cc-news-pretrained-openqa")

            self.assertEqual(retriever.block_records[0], b"This is the first record")
tests/test_tokenization_realm.py (new file, 314 lines)
@@ -0,0 +1,314 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest

from transformers.models.bert.tokenization_bert import (
    VOCAB_FILES_NAMES,
    BasicTokenizer,
    WordpieceTokenizer,
    _is_control,
    _is_punctuation,
    _is_whitespace,
)
from transformers.models.realm.tokenization_realm import RealmTokenizer
from transformers.testing_utils import require_tokenizers, slow

from .test_tokenization_common import TokenizerTesterMixin, filter_non_english


@require_tokenizers
class RealmTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

    tokenizer_class = RealmTokenizer
    rust_tokenizer_class = None
    test_rust_tokenizer = False
    space_between_special_tokens = True
    from_pretrained_filter = filter_non_english

    def setUp(self):
        super().setUp()

        vocab_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "[PAD]",
            "[MASK]",
            "want",
            "##want",
            "##ed",
            "wa",
            "un",
            "runn",
            "##ing",
            ",",
            "low",
            "lowest",
        ]
        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

    def get_input_output_texts(self, tokenizer):
        input_text = "UNwant\u00E9d,running"
        output_text = "unwanted, running"
        return input_text, output_text

    def test_full_tokenizer(self):
        tokenizer = self.tokenizer_class(self.vocab_file)

        tokens = tokenizer.tokenize("UNwant\u00E9d,running")
        self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [9, 6, 7, 12, 10, 11])

    def test_rust_and_python_full_tokenizers(self):
        if not self.test_rust_tokenizer:
            return

        tokenizer = self.get_tokenizer()
        rust_tokenizer = self.get_rust_tokenizer()

        sequence = "UNwant\u00E9d,running"

        tokens = tokenizer.tokenize(sequence)
        rust_tokens = rust_tokenizer.tokenize(sequence)
        self.assertListEqual(tokens, rust_tokens)

        ids = tokenizer.encode(sequence, add_special_tokens=False)
        rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
        self.assertListEqual(ids, rust_ids)

        rust_tokenizer = self.get_rust_tokenizer()
        ids = tokenizer.encode(sequence)
        rust_ids = rust_tokenizer.encode(sequence)
        self.assertListEqual(ids, rust_ids)

        # With lower casing
        tokenizer = self.get_tokenizer(do_lower_case=True)
        rust_tokenizer = self.get_rust_tokenizer(do_lower_case=True)

        sequence = "UNwant\u00E9d,running"

        tokens = tokenizer.tokenize(sequence)
        rust_tokens = rust_tokenizer.tokenize(sequence)
        self.assertListEqual(tokens, rust_tokens)

        ids = tokenizer.encode(sequence, add_special_tokens=False)
        rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
        self.assertListEqual(ids, rust_ids)

        rust_tokenizer = self.get_rust_tokenizer()
        ids = tokenizer.encode(sequence)
        rust_ids = rust_tokenizer.encode(sequence)
        self.assertListEqual(ids, rust_ids)

    def test_chinese(self):
        tokenizer = BasicTokenizer()

        self.assertListEqual(tokenizer.tokenize("ah\u535A\u63A8zz"), ["ah", "\u535A", "\u63A8", "zz"])

    def test_basic_tokenizer_lower(self):
        tokenizer = BasicTokenizer(do_lower_case=True)

        self.assertListEqual(
            tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "), ["hello", "!", "how", "are", "you", "?"]
        )
        self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])

    def test_basic_tokenizer_lower_strip_accents_false(self):
        tokenizer = BasicTokenizer(do_lower_case=True, strip_accents=False)

        self.assertListEqual(
            tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["hällo", "!", "how", "are", "you", "?"]
        )
        self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["h\u00E9llo"])

    def test_basic_tokenizer_lower_strip_accents_true(self):
        tokenizer = BasicTokenizer(do_lower_case=True, strip_accents=True)

        self.assertListEqual(
            tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["hallo", "!", "how", "are", "you", "?"]
        )
        self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])

    def test_basic_tokenizer_lower_strip_accents_default(self):
        tokenizer = BasicTokenizer(do_lower_case=True)

        self.assertListEqual(
            tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["hallo", "!", "how", "are", "you", "?"]
        )
        self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])

    def test_basic_tokenizer_no_lower(self):
        tokenizer = BasicTokenizer(do_lower_case=False)

        self.assertListEqual(
            tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "), ["HeLLo", "!", "how", "Are", "yoU", "?"]
        )

    def test_basic_tokenizer_no_lower_strip_accents_false(self):
        tokenizer = BasicTokenizer(do_lower_case=False, strip_accents=False)

        self.assertListEqual(
            tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["HäLLo", "!", "how", "Are", "yoU", "?"]
        )

    def test_basic_tokenizer_no_lower_strip_accents_true(self):
        tokenizer = BasicTokenizer(do_lower_case=False, strip_accents=True)

        self.assertListEqual(
            tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["HaLLo", "!", "how", "Are", "yoU", "?"]
        )

    def test_basic_tokenizer_respects_never_split_tokens(self):
        tokenizer = BasicTokenizer(do_lower_case=False, never_split=["[UNK]"])

        self.assertListEqual(
            tokenizer.tokenize(" \tHeLLo!how \n Are yoU? [UNK]"), ["HeLLo", "!", "how", "Are", "yoU", "?", "[UNK]"]
        )

    def test_wordpiece_tokenizer(self):
        vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"]

        vocab = {}
        for (i, token) in enumerate(vocab_tokens):
            vocab[token] = i
        tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")

        self.assertListEqual(tokenizer.tokenize(""), [])

        self.assertListEqual(tokenizer.tokenize("unwanted running"), ["un", "##want", "##ed", "runn", "##ing"])

        self.assertListEqual(tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])

    def test_is_whitespace(self):
        self.assertTrue(_is_whitespace(" "))
        self.assertTrue(_is_whitespace("\t"))
        self.assertTrue(_is_whitespace("\r"))
        self.assertTrue(_is_whitespace("\n"))
        self.assertTrue(_is_whitespace("\u00A0"))

        self.assertFalse(_is_whitespace("A"))
        self.assertFalse(_is_whitespace("-"))

    def test_is_control(self):
        self.assertTrue(_is_control("\u0005"))

        self.assertFalse(_is_control("A"))
        self.assertFalse(_is_control(" "))
        self.assertFalse(_is_control("\t"))
        self.assertFalse(_is_control("\r"))

    def test_is_punctuation(self):
        self.assertTrue(_is_punctuation("-"))
        self.assertTrue(_is_punctuation("$"))
        self.assertTrue(_is_punctuation("`"))
        self.assertTrue(_is_punctuation("."))

        self.assertFalse(_is_punctuation("A"))
        self.assertFalse(_is_punctuation(" "))

    def test_clean_text(self):
        tokenizer = self.get_tokenizer()

        # Example taken from the issue https://github.com/huggingface/tokenizers/issues/340
        self.assertListEqual([tokenizer.tokenize(t) for t in ["Test", "\xad", "test"]], [["[UNK]"], [], ["[UNK]"]])

        if self.test_rust_tokenizer:
            rust_tokenizer = self.get_rust_tokenizer()
            self.assertListEqual(
                [rust_tokenizer.tokenize(t) for t in ["Test", "\xad", "test"]], [["[UNK]"], [], ["[UNK]"]]
            )

    @slow
    def test_sequence_builders(self):
        tokenizer = self.tokenizer_class.from_pretrained("bert-base-uncased")

        text = tokenizer.encode("sequence builders", add_special_tokens=False)
        text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)

        encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
        encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)

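        # 101 and 102 are the [CLS] and [SEP] token ids in the BERT vocabulary.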
        assert encoded_sentence == [101] + text + [102]
        assert encoded_pair == [101] + text + [102] + text_2 + [102]

    def test_offsets_with_special_characters(self):
        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
                tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)

                sentence = f"A, naïve {tokenizer_r.mask_token} AllenNLP sentence."
                tokens = tokenizer_r.encode_plus(
                    sentence,
                    return_attention_mask=False,
                    return_token_type_ids=False,
                    return_offsets_mapping=True,
                    add_special_tokens=True,
                )

                do_lower_case = tokenizer_r.do_lower_case if hasattr(tokenizer_r, "do_lower_case") else False
                expected_results = (
                    [
                        ((0, 0), tokenizer_r.cls_token),
                        ((0, 1), "A"),
                        ((1, 2), ","),
                        ((3, 5), "na"),
                        ((5, 6), "##ï"),
                        ((6, 8), "##ve"),
                        ((9, 15), tokenizer_r.mask_token),
                        ((16, 21), "Allen"),
                        ((21, 23), "##NL"),
                        ((23, 24), "##P"),
                        ((25, 33), "sentence"),
                        ((33, 34), "."),
                        ((0, 0), tokenizer_r.sep_token),
                    ]
                    if not do_lower_case
                    else [
                        ((0, 0), tokenizer_r.cls_token),
                        ((0, 1), "a"),
                        ((1, 2), ","),
                        ((3, 8), "naive"),
                        ((9, 15), tokenizer_r.mask_token),
                        ((16, 21), "allen"),
                        ((21, 23), "##nl"),
                        ((23, 24), "##p"),
                        ((25, 33), "sentence"),
                        ((33, 34), "."),
                        ((0, 0), tokenizer_r.sep_token),
                    ]
                )

                self.assertEqual(
                    [e[1] for e in expected_results], tokenizer_r.convert_ids_to_tokens(tokens["input_ids"])
                )
                self.assertEqual([e[0] for e in expected_results], tokens["offset_mapping"])

    @slow
    def test_batch_encode_candidates(self):
        tokenizer = self.tokenizer_class.from_pretrained("bert-base-uncased")

        text = [["Hello world!", "Nice to meet you!"], ["The cute cat.", "The adorable dog."]]

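        # batch_encode_candidates pads/truncates every candidate to max_length,
        # yielding tensors of shape (batch_size, num_candidates, max_length).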
        encoded_sentence = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt")

        expected_shape = (2, 2, 10)

        assert encoded_sentence["input_ids"].shape == expected_shape
        assert encoded_sentence["attention_mask"].shape == expected_shape
        assert encoded_sentence["token_type_ids"].shape == expected_shape
utils/check_repo.py
@@ -35,6 +35,7 @@ PATH_TO_DOC = "docs/source"
# Update this list with models that are supposed to be private.
PRIVATE_MODELS = [
    "DPRSpanPredictor",
    "RealmBertModel",
    "T5Stack",
    "TFDPRSpanPredictor",
]
@@ -73,6 +74,10 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
    "PegasusDecoderWrapper",  # Building part of bigger (tested) model.
    "DPREncoder",  # Building part of bigger (tested) model.
    "ProphetNetDecoderWrapper",  # Building part of bigger (tested) model.
    "RealmBertModel",  # Building part of bigger (tested) model.
    "RealmReader",  # Not a regular model.
    "RealmScorer",  # Not a regular model.
    "RealmForOpenQA",  # Not a regular model.
    "ReformerForMaskedLM",  # Needs to be setup as decoder.
    "Speech2Text2DecoderWrapper",  # Building part of bigger (tested) model.
    "TFDPREncoder",  # Building part of bigger (tested) model.
@@ -129,6 +134,10 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
    "RagModel",
    "RagSequenceForGeneration",
    "RagTokenForGeneration",
    "RealmEmbedder",
    "RealmForOpenQA",
    "RealmScorer",
    "RealmReader",
    "TFDPRReader",
    "TFGPT2DoubleHeadsModel",
    "TFOpenAIGPTDoubleHeadsModel",