Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 19:21:31 +06:00)
[WIP] Tapas v4 (tres) (#9117)
* First commit: adding all files from tapas_v3
* Fix multiple bugs including soft dependency and new structure of the library
* Improve testing by adding torch_device to inputs and adding dependency on scatter
* Use Python 3 inheritance rather than Python 2
* First draft model cards of base sized models
* Remove model cards as they are already on the hub
* Fix multiple bugs with integration tests
* All model integration tests pass
* Remove print statement
* Add test for convert_logits_to_predictions method of TapasTokenizer
* Incorporate suggestions by Google authors
* Fix remaining tests
* Change position embeddings sizes to 512 instead of 1024
* Comment out positional embedding sizes
* Update PRETRAINED_VOCAB_FILES_MAP and PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
* Added more model names
* Fix truncation when no max length is specified
* Disable torchscript test
* Make style & make quality
* Quality
* Address CI needs
* Test the Masked LM model
* Fix the masked LM model
* Truncate when overflowing
* More much needed docs improvements
* Fix some URLs
* Some more docs improvements
* Test PyTorch scatter
* Set to slow + minify
* Calm flake8 down
* Add add_pooling_layer argument to TapasModel. Fix comments by @sgugger and @patrickvonplaten
* Fix issue in docs + fix style and quality
* Clean up conversion script and add task parameter to TapasConfig
* Revert the task parameter of TapasConfig. Some minor fixes
* Improve conversion script and add test for absolute position embeddings
* Fix bug with reset_position_index_per_cell arg of the conversion cli
* Add notebooks to the examples directory and fix style and quality
* Apply suggestions from code review
* Move from `nielsr/` to `google/` namespace
* Apply Sylvain's comments

Co-authored-by: sgugger <sylvain.gugger@gmail.com>
Co-authored-by: Rogge Niels <niels.rogge@howest.be>
Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent ad895af98d
commit 1551e2dc6d
@ -79,6 +79,7 @@ jobs:
|
||||
- v0.4-{{ checksum "setup.py" }}
|
||||
- run: pip install --upgrade pip
|
||||
- run: pip install .[sklearn,tf-cpu,torch,testing,sentencepiece]
|
||||
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cpu.html
|
||||
- save_cache:
|
||||
key: v0.4-{{ checksum "setup.py" }}
|
||||
paths:
|
||||
@ -105,6 +106,7 @@ jobs:
|
||||
- v0.4-{{ checksum "setup.py" }}
|
||||
- run: pip install --upgrade pip
|
||||
- run: pip install .[sklearn,torch,testing,sentencepiece]
|
||||
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cpu.html
|
||||
- save_cache:
|
||||
key: v0.4-torch-{{ checksum "setup.py" }}
|
||||
paths:
|
||||
@ -183,6 +185,7 @@ jobs:
|
||||
- v0.4-{{ checksum "setup.py" }}
|
||||
- run: pip install --upgrade pip
|
||||
- run: pip install .[sklearn,torch,testing,sentencepiece]
|
||||
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cpu.html
|
||||
- save_cache:
|
||||
key: v0.4-torch-{{ checksum "setup.py" }}
|
||||
paths:
|
||||
|
2 .github/workflows/self-push.yml vendored
@ -50,6 +50,7 @@ jobs:
|
||||
pip install --upgrade pip
|
||||
pip install .[torch,sklearn,testing,onnxruntime,sentencepiece]
|
||||
pip install git+https://github.com/huggingface/datasets
|
||||
pip install pandas torch-scatter -f https://pytorch-geometric.com/whl/torch-$(python -c "import torch; print(''.join(torch.__version__))")+$(python -c "import torch; print(''.join(torch.version.cuda.split('.')))").html
|
||||
|
||||
- name: Are GPUs recognized by our DL frameworks
|
||||
run: |
|
||||
@ -187,6 +188,7 @@ jobs:
|
||||
pip install --upgrade pip
|
||||
pip install .[torch,sklearn,testing,onnxruntime,sentencepiece]
|
||||
pip install git+https://github.com/huggingface/datasets
|
||||
pip install pandas torch-scatter -f https://pytorch-geometric.com/whl/torch-$(python -c "import torch; print(''.join(torch.__version__))")+$(python -c "import torch; print(''.join(torch.version.cuda.split('.')))").html
|
||||
|
||||
- name: Are GPUs recognized by our DL frameworks
|
||||
run: |
|
||||
|
@ -222,6 +222,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
|
||||
multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/master/examples/distillation) and a German version of DistilBERT.
|
||||
1. **[SqueezeBert](https://huggingface.co/transformers/model_doc/squeezebert.html)** released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer.
|
||||
1. **[T5](https://huggingface.co/transformers/model_doc/t5.html)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
|
||||
1. **[TAPAS](https://huggingface.co/transformers/master/model_doc/tapas.html)** released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos.
|
||||
1. **[Transformer-XL](https://huggingface.co/transformers/model_doc/transformerxl.html)** (from Google/CMU) released with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
|
||||
1. **[XLM](https://huggingface.co/transformers/model_doc/xlm.html)** (from Facebook) released together with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau.
|
||||
1. **[XLM-ProphetNet](https://huggingface.co/transformers/model_doc/xlmprophetnet.html)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
|
||||
|
@ -176,19 +176,22 @@ and conversion utilities for the following models:
|
||||
30. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a
|
||||
Unified Text-to-Text Transformer <https://arxiv.org/abs/1910.10683>`__ by Colin Raffel and Noam Shazeer and Adam
|
||||
Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
|
||||
31. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
|
||||
31. `TAPAS <https://huggingface.co/transformers/master/model_doc/tapas.html>`__ released with the paper `TAPAS: Weakly
|
||||
Supervised Table Parsing via Pre-training <https://arxiv.org/abs/2004.02349>`__ by Jonathan Herzig, Paweł Krzysztof
|
||||
Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos.
|
||||
32. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
|
||||
Attentive Language Models Beyond a Fixed-Length Context <https://arxiv.org/abs/1901.02860>`__ by Zihang Dai*,
|
||||
Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
|
||||
32. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
|
||||
33. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
|
||||
Pretraining <https://arxiv.org/abs/1901.07291>`__ by Guillaume Lample and Alexis Conneau.
|
||||
33. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
|
||||
34. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
|
||||
Predicting Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan,
|
||||
Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
|
||||
34. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
|
||||
35. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
|
||||
Cross-lingual Representation Learning at Scale <https://arxiv.org/abs/1911.02116>`__ by Alexis Conneau*, Kartikay
|
||||
Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke
|
||||
Zettlemoyer and Veselin Stoyanov.
|
||||
35. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
|
||||
36. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
|
||||
Pretraining for Language Understanding <https://arxiv.org/abs/1906.08237>`__ by Zhilin Yang*, Zihang Dai*, Yiming
|
||||
Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
|
||||
|
||||
@ -269,6 +272,8 @@ TensorFlow and/or Flax.
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| T5 | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| TAPAS | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| XLM | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
@ -382,6 +387,7 @@ TensorFlow and/or Flax.
|
||||
model_doc/roberta
|
||||
model_doc/squeezebert
|
||||
model_doc/t5
|
||||
model_doc/tapas
|
||||
model_doc/transformerxl
|
||||
model_doc/xlm
|
||||
model_doc/xlmprophetnet
|
||||
|
427 docs/source/model_doc/tapas.rst Normal file
@ -0,0 +1,427 @@
|
||||
TAPAS
|
||||
-----------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
Overview
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
The TAPAS model was proposed in `TAPAS: Weakly Supervised Table Parsing via Pre-training
|
||||
<https://www.aclweb.org/anthology/2020.acl-main.398>`__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller,
|
||||
Francesco Piccinno and Julian Martin Eisenschlos. It's a BERT-based model specifically designed (and pre-trained) for
|
||||
answering questions about tabular data. Compared to BERT, TAPAS uses relative position embeddings and has 7 token types
|
||||
that encode tabular structure. TAPAS is pre-trained on the masked language modeling (MLM) objective on a large dataset
|
||||
comprising millions of tables from English Wikipedia and corresponding texts. For question answering, TAPAS has 2 heads
|
||||
on top: a cell selection head and an aggregation head, for (optionally) performing aggregations (such as counting or
|
||||
summing) among selected cells. TAPAS has been fine-tuned on several datasets: `SQA
|
||||
<https://www.microsoft.com/en-us/download/details.aspx?id=54253>`__ (Sequential Question Answering by Microsoft), `WTQ
|
||||
<https://github.com/ppasupat/WikiTableQuestions>`__ (Wiki Table Questions by Stanford University) and `WikiSQL
|
||||
<https://github.com/salesforce/WikiSQL>`__ (by Salesforce). It achieves state-of-the-art on both SQA and WTQ, while
|
||||
having comparable performance to SOTA on WikiSQL, with a much simpler architecture.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Answering natural language questions over tables is usually seen as a semantic parsing task. To alleviate the
|
||||
collection cost of full logical forms, one popular approach focuses on weak supervision consisting of denotations
|
||||
instead of logical forms. However, training semantic parsers from weak supervision poses difficulties, and in addition,
|
||||
the generated logical forms are only used as an intermediate step prior to retrieving the denotation. In this paper, we
|
||||
present TAPAS, an approach to question answering over tables without generating logical forms. TAPAS trains from weak
|
||||
supervision, and predicts the denotation by selecting table cells and optionally applying a corresponding aggregation
|
||||
operator to such selection. TAPAS extends BERT's architecture to encode tables as input, initializes from an effective
|
||||
joint pre-training of text segments and tables crawled from Wikipedia, and is trained end-to-end. We experiment with
|
||||
three different semantic parsing datasets, and find that TAPAS outperforms or rivals semantic parsing models by
|
||||
improving state-of-the-art accuracy on SQA from 55.1 to 67.2 and performing on par with the state-of-the-art on WIKISQL
|
||||
and WIKITQ, but with a simpler model architecture. We additionally find that transfer learning, which is trivial in our
|
||||
setting, from WIKISQL to WIKITQ, yields 48.7 accuracy, 4.2 points above the state-of-the-art.*
|
||||
|
||||
In addition, the authors have further pre-trained TAPAS to recognize **table entailment**, by creating a balanced
|
||||
dataset of millions of automatically created training examples which are learned in an intermediate step prior to
|
||||
fine-tuning. The authors of TAPAS call this further pre-training intermediate pre-training (since TAPAS is first
|
||||
pre-trained on MLM, and then on another dataset). They found that intermediate pre-training further improves
|
||||
performance on SQA, achieving a new state-of-the-art as well as state-of-the-art on `TabFact
|
||||
<https://github.com/wenhuchen/Table-Fact-Checking>`__, a large-scale dataset with 16k Wikipedia tables for table
|
||||
entailment (a binary classification task). For more details, see their follow-up paper: `Understanding tables with
|
||||
intermediate pre-training <https://www.aclweb.org/anthology/2020.findings-emnlp.27/>`__ by Julian Martin Eisenschlos,
|
||||
Syrine Krichene and Thomas Müller.
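Within the library, this table entailment setup corresponds to :class:`~transformers.TapasForSequenceClassification`
(see the API reference below). As a minimal sketch, classifying a sentence against a table could look like this. Note
that the checkpoint name and the label mapping used here are assumptions on our side, check the `model hub
<https://huggingface.co/models?search=tapas>`__ and the checkpoint's config:

.. code-block::

>>> from transformers import TapasTokenizer, TapasForSequenceClassification
>>> import pandas as pd

>>> # assumed checkpoint name of a TabFact fine-tuned model
>>> model_name = "google/tapas-base-finetuned-tabfact"
>>> tokenizer = TapasTokenizer.from_pretrained(model_name)
>>> model = TapasForSequenceClassification.from_pretrained(model_name)

>>> table = pd.DataFrame({'Actors': ["Brad Pitt", "George Clooney"], 'Number of movies': ["87", "69"]})
>>> inputs = tokenizer(table=table, queries="George Clooney played in 69 movies.", return_tensors="pt")
>>> outputs = model(**inputs)
>>> # assumption: index 1 corresponds to "entailed", index 0 to "refuted"
>>> predicted_class = outputs.logits.argmax(-1).item()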
|
||||
|
||||
The original code can be found `here <https://github.com/google-research/tapas>`__.
|
||||
|
||||
Tips:
|
||||
|
||||
- TAPAS is a model that uses relative position embeddings by default (restarting the position embeddings at every cell
|
||||
of the table). Note that this is something that was added after the publication of the original TAPAS paper.
|
||||
According to the authors, this usually results in a slightly better performance, and allows you to encode longer
|
||||
sequences without running out of embeddings. This is reflected in the ``reset_position_index_per_cell`` parameter of
|
||||
:class:`~transformers.TapasConfig`, which is set to ``True`` by default. The default versions of the models available
|
||||
in the `model hub <https://huggingface.co/models?search=tapas>`_ all use relative position embeddings. You can still
|
||||
use the ones with absolute position embeddings by passing in an additional argument ``revision="no_reset"`` when
calling the ``.from_pretrained()`` method, as shown in the example right after these tips. Note that it's usually
advised to pad the inputs on the right rather than the left.
|
||||
- TAPAS is based on BERT, so ``TAPAS-base`` for example corresponds to a ``BERT-base`` architecture. Of course,
|
||||
TAPAS-large will result in the best performance (the results reported in the paper are from TAPAS-large). Results of
|
||||
the various sized models are shown on the `original Github repository <https://github.com/google-research/tapas>`_.
|
||||
- TAPAS has checkpoints fine-tuned on SQA, which are capable of answering questions related to a table in a
|
||||
conversational set-up. This means that you can ask follow-up questions such as "what is his age?" related to the
|
||||
previous question. Note that the forward pass of TAPAS is a bit different in case of a conversational set-up: in that
|
||||
case, you have to feed every table-question pair one by one to the model, such that the `prev_labels` token type ids
|
||||
can be overwritten by the predicted `labels` of the model to the previous question. See "Usage" section for more
|
||||
info.
|
||||
- TAPAS is similar to BERT and therefore relies on the masked language modeling (MLM) objective. It is therefore
|
||||
efficient at predicting masked tokens and at NLU in general, but is not optimal for text generation. Models trained
|
||||
with a causal language modeling (CLM) objective are better in that regard.
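As an example of the first tip, loading a checkpoint with absolute instead of relative position embeddings could look
as follows (a minimal sketch, relying on the ``no_reset`` revision mentioned above being available for the checkpoint
you pick):

.. code-block::

>>> from transformers import TapasModel

>>> # default: relative position embeddings (reset_position_index_per_cell=True)
>>> model = TapasModel.from_pretrained("google/tapas-base")

>>> # absolute position embeddings instead
>>> model = TapasModel.from_pretrained("google/tapas-base", revision="no_reset")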
|
||||
|
||||
|
||||
Usage: fine-tuning
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Here we explain how you can fine-tune :class:`~transformers.TapasForQuestionAnswering` on your own dataset.
|
||||
|
||||
**STEP 1: Choose one of the 3 ways in which you can use TAPAS - or experiment**
|
||||
|
||||
Basically, there are 3 different ways in which one can fine-tune :class:`~transformers.TapasForQuestionAnswering`,
|
||||
corresponding to the different datasets on which Tapas was fine-tuned:
|
||||
|
||||
1. SQA: if you're interested in asking follow-up questions related to a table, in a conversational set-up. For example
|
||||
if you first ask "what's the name of the first actor?" then you can ask a follow-up question such as "how old is
|
||||
he?". Here, questions do not involve any aggregation (all questions are cell selection questions).
|
||||
2. WTQ: if you're not interested in asking questions in a conversational set-up, but rather just asking questions
|
||||
related to a table, which might involve aggregation, such as counting a number of rows, summing up cell values or
|
||||
averaging cell values. You can then for example ask "what's the total number of goals Cristiano Ronaldo made in his
|
||||
career?". This case is also called **weak supervision**, since the model itself must learn the appropriate
|
||||
aggregation operator (SUM/COUNT/AVERAGE/NONE) given only the answer to the question as supervision.
|
||||
3. WikiSQL-supervised: this dataset is based on WikiSQL with the model being given the ground truth aggregation
|
||||
operator during training. This is also called **strong supervision**. Here, learning the appropriate aggregation
|
||||
operator is much easier.
|
||||
|
||||
To summarize:
|
||||
|
||||
+------------------------------------+----------------------+-------------------------------------------------------------------------------------------------------------------+
|
||||
| **Task** | **Example dataset** | **Description** |
|
||||
+------------------------------------+----------------------+-------------------------------------------------------------------------------------------------------------------+
|
||||
| Conversational | SQA | Conversational, only cell selection questions |
|
||||
+------------------------------------+----------------------+-------------------------------------------------------------------------------------------------------------------+
|
||||
| Weak supervision for aggregation | WTQ | Questions might involve aggregation, and the model must learn this given only the answer as supervision |
|
||||
+------------------------------------+----------------------+-------------------------------------------------------------------------------------------------------------------+
|
||||
| Strong supervision for aggregation | WikiSQL-supervised | Questions might involve aggregation, and the model must learn this given the gold aggregation operator |
|
||||
+------------------------------------+----------------------+-------------------------------------------------------------------------------------------------------------------+
|
||||
|
||||
Initializing a model with a pre-trained base and randomly initialized classification heads from the model hub can be
|
||||
done as follows (be sure to have installed the `torch-scatter dependency <https://github.com/rusty1s/pytorch_scatter>`_
|
||||
for your environment):
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> from transformers import TapasConfig, TapasForQuestionAnswering
|
||||
|
||||
>>> # for example, the base sized model with default SQA configuration
|
||||
>>> model = TapasForQuestionAnswering.from_pretrained('google/tapas-base')
|
||||
|
||||
>>> # or, the base sized model with WTQ configuration
|
||||
>>> config = TapasConfig.from_pretrained('google/tapas-base-finetuned-wtq')
|
||||
>>> model = TapasForQuestionAnswering.from_pretrained('google/tapas-base', config=config)
|
||||
|
||||
>>> # or, the base sized model with WikiSQL configuration
|
||||
>>> config = TapasConfig.from_pretrained('google/tapas-base-finetuned-wikisql-supervised')
|
||||
>>> model = TapasForQuestionAnswering.from_pretrained('google/tapas-base', config=config)
|
||||
|
||||
|
||||
Of course, you don't necessarily have to follow one of these three ways in which TAPAS was fine-tuned. You can also
|
||||
experiment by defining any hyperparameters you want when initializing :class:`~transformers.TapasConfig`, and then
|
||||
create a :class:`~transformers.TapasForQuestionAnswering` based on that configuration. For example, if you have a
|
||||
dataset that has both conversational questions and questions that might involve aggregation, then you can do it this
|
||||
way. Here's an example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> from transformers import TapasConfig, TapasForQuestionAnswering
|
||||
|
||||
>>> # you can initialize the classification heads any way you want (see docs of TapasConfig)
|
||||
>>> config = TapasConfig(num_aggregation_labels=3, average_logits_per_cell=True, select_one_column=False)
|
||||
>>> # initializing the pre-trained base sized model with our custom classification heads
|
||||
>>> model = TapasForQuestionAnswering.from_pretrained('google/tapas-base', config=config)
|
||||
|
||||
What you can also do is start from an already fine-tuned checkpoint. A note here is that the already fine-tuned
|
||||
checkpoint on WTQ has some issues due to the L2-loss which is somewhat brittle. See `here
|
||||
<https://github.com/google-research/tapas/issues/91#issuecomment-735719340>`__ for more info.
|
||||
|
||||
For a list of all pre-trained and fine-tuned TAPAS checkpoints available in the HuggingFace model hub, see `here
|
||||
<https://huggingface.co/models?search=tapas>`__.
|
||||
|
||||
**STEP 2: Prepare your data in the SQA format**
|
||||
|
||||
Second, no matter what you picked above, you should prepare your dataset in the `SQA format
|
||||
<https://www.microsoft.com/en-us/download/details.aspx?id=54253>`__. This format is a TSV/CSV file with the following
|
||||
columns:
|
||||
|
||||
- ``id``: optional, id of the table-question pair, for bookkeeping purposes.
|
||||
- ``annotator``: optional, id of the person who annotated the table-question pair, for bookkeeping purposes.
|
||||
- ``position``: integer indicating if the question is the first, second, third,... related to the table. Only required
|
||||
in case of conversational setup (SQA). You don't need this column in case you're going for WTQ/WikiSQL-supervised.
|
||||
- ``question``: string
|
||||
- ``table_file``: string, name of a csv file containing the tabular data
|
||||
- ``answer_coordinates``: list of one or more tuples (each tuple being a cell coordinate, i.e. row, column pair that is
|
||||
part of the answer)
|
||||
- ``answer_text``: list of one or more strings (each string being a cell value that is part of the answer)
|
||||
- ``aggregation_label``: index of the aggregation operator. Only required in case of strong supervision for aggregation
|
||||
(the WikiSQL-supervised case)
|
||||
- ``float_answer``: the float answer to the question, if there is one (np.nan if there isn't). Only required in case of
|
||||
weak supervision for aggregation (such as WTQ and WikiSQL)
|
||||
|
||||
The tables themselves should be present in a folder, each table being a separate csv file. Note that the authors of the
|
||||
TAPAS algorithm used conversion scripts with some automated logic to convert the other datasets (WTQ, WikiSQL) into the
|
||||
SQA format. The author explains this `here
|
||||
<https://github.com/google-research/tapas/issues/50#issuecomment-705465960>`__. Interestingly, these conversion scripts
|
||||
are not perfect (the ``answer_coordinates`` and ``float_answer`` fields are populated based on the ``answer_text``),
|
||||
meaning that WTQ and WikiSQL results could actually be improved.
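For illustration, here is a minimal sketch of loading such a TSV file with pandas (the file name is a placeholder;
depending on how the file was written, the list-valued columns may be stored as strings and need to be parsed back):

.. code-block::

>>> import ast
>>> import pandas as pd

>>> data = pd.read_csv("your_sqa_format_data.tsv", sep="\t")  # placeholder file name
>>> # if answer_coordinates and answer_text were serialized as strings, parse them back into Python objects
>>> data["answer_coordinates"] = data["answer_coordinates"].apply(ast.literal_eval)
>>> data["answer_text"] = data["answer_text"].apply(ast.literal_eval)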
|
||||
|
||||
**STEP 3: Convert your data into PyTorch tensors using TapasTokenizer**
|
||||
|
||||
Third, given that you've prepared your data in this TSV/CSV format (and corresponding CSV files containing the tabular
|
||||
data), you can then use :class:`~transformers.TapasTokenizer` to convert table-question pairs into :obj:`input_ids`,
|
||||
:obj:`attention_mask`, :obj:`token_type_ids` and so on. Again, based on which of the three cases you picked above,
|
||||
:class:`~transformers.TapasForQuestionAnswering` requires different inputs to be fine-tuned:
|
||||
|
||||
+------------------------------------+----------------------------------------------------------------------------------------------+
|
||||
| **Task** | **Required inputs** |
|
||||
+------------------------------------+----------------------------------------------------------------------------------------------+
|
||||
| Conversational | ``input_ids``, ``attention_mask``, ``token_type_ids``, ``labels`` |
|
||||
+------------------------------------+----------------------------------------------------------------------------------------------+
|
||||
| Weak supervision for aggregation | ``input_ids``, ``attention_mask``, ``token_type_ids``, ``labels``, ``numeric_values``, |
|
||||
| | ``numeric_values_scale``, ``float_answer`` |
|
||||
+------------------------------------+----------------------------------------------------------------------------------------------+
|
||||
| Strong supervision for aggregation | ``input_ids``, ``attention_mask``, ``token_type_ids``, ``labels``, ``aggregation_labels`` |
|
||||
+------------------------------------+----------------------------------------------------------------------------------------------+
|
||||
|
||||
:class:`~transformers.TapasTokenizer` creates the ``labels``, ``numeric_values`` and ``numeric_values_scale`` based on
|
||||
the ``answer_coordinates`` and ``answer_text`` columns of the TSV file. The ``float_answer`` and ``aggregation_labels``
|
||||
are already in the TSV file of step 2. Here's an example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> from transformers import TapasTokenizer
|
||||
>>> import pandas as pd
|
||||
|
||||
>>> model_name = 'google/tapas-base'
|
||||
>>> tokenizer = TapasTokenizer.from_pretrained(model_name)
|
||||
|
||||
>>> data = {'Actors': ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"], 'Number of movies': ["87", "53", "69"]}
|
||||
>>> queries = ["What is the name of the first actor?", "How many movies has George Clooney played in?", "What is the total number of movies?"]
|
||||
>>> answer_coordinates = [[(0, 0)], [(2, 1)], [(0, 1), (1, 1), (2, 1)]]
|
||||
>>> answer_text = [["Brad Pitt"], ["69"], ["209"]]
|
||||
>>> table = pd.DataFrame.from_dict(data)
|
||||
>>> inputs = tokenizer(table=table, queries=queries, answer_coordinates=answer_coordinates, answer_text=answer_text, padding='max_length', return_tensors='pt')
|
||||
>>> inputs
|
||||
{'input_ids': tensor([[ ... ]]), 'attention_mask': tensor([[...]]), 'token_type_ids': tensor([[[...]]]),
|
||||
'numeric_values': tensor([[ ... ]]), 'numeric_values_scale': tensor([[ ... ]]), 'labels': tensor([[ ... ]])}
|
||||
|
||||
Note that :class:`~transformers.TapasTokenizer` expects the data of the table to be **text-only**. You can use
|
||||
``.astype(str)`` on a dataframe to turn it into text-only data. Of course, this only shows how to encode a single
|
||||
training example. It is advised to create a PyTorch dataset and a corresponding dataloader:
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> import torch
|
||||
>>> import pandas as pd
|
||||
|
||||
>>> tsv_path = "your_path_to_the_tsv_file"
|
||||
>>> table_csv_path = "your_path_to_a_directory_containing_all_csv_files"
|
||||
|
||||
>>> class TableDataset(torch.utils.data.Dataset):
|
||||
... def __init__(self, data, tokenizer):
|
||||
... self.data = data
|
||||
... self.tokenizer = tokenizer
|
||||
...
|
||||
... def __getitem__(self, idx):
|
||||
...         item = self.data.iloc[idx]
|
||||
... table = pd.read_csv(table_csv_path + item.table_file).astype(str) # be sure to make your table data text only
|
||||
... encoding = self.tokenizer(table=table,
|
||||
... queries=item.question,
|
||||
... answer_coordinates=item.answer_coordinates,
|
||||
... answer_text=item.answer_text,
|
||||
... truncation=True,
|
||||
... padding="max_length",
|
||||
... return_tensors="pt"
|
||||
... )
|
||||
... # remove the batch dimension which the tokenizer adds by default
|
||||
... encoding = {key: val.squeeze(0) for key, val in encoding.items()}
|
||||
... # add the float_answer which is also required (weak supervision for aggregation case)
|
||||
... encoding["float_answer"] = torch.tensor(item.float_answer)
|
||||
... return encoding
|
||||
...
|
||||
... def __len__(self):
|
||||
... return len(self.data)
|
||||
|
||||
>>> data = pd.read_csv(tsv_path, sep='\t')
|
||||
>>> train_dataset = TableDataset(data, tokenizer)
|
||||
>>> train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)
|
||||
|
||||
Note that here, we encode each table-question pair independently. This is fine as long as your dataset is **not
|
||||
conversational**. In case your dataset involves conversational questions (such as in SQA), then you should first group
|
||||
together the ``queries``, ``answer_coordinates`` and ``answer_text`` per table (in the order of their ``position``
|
||||
index) and batch encode each table with its questions. This will make sure that the ``prev_labels`` token types (see
|
||||
docs of :class:`~transformers.TapasTokenizer`) are set correctly. See `this notebook
|
||||
<https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Fine_tuning_TapasForQuestionAnswering_on_SQA.ipynb>`__
|
||||
for more info.
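As a rough sketch of such grouping (assuming a dataframe ``data`` in the SQA format of step 2, plus the ``tokenizer``
and ``table_csv_path`` defined above), this could look as follows:

.. code-block::

>>> import pandas as pd

>>> for table_file, group in data.groupby("table_file"):
...     group = group.sort_values("position")  # keep the conversational order
...     table = pd.read_csv(table_csv_path + table_file).astype(str)
...     # batch encode all questions of this table together so that prev_labels are set correctly
...     encoding = tokenizer(table=table,
...                          queries=group["question"].tolist(),
...                          answer_coordinates=group["answer_coordinates"].tolist(),
...                          answer_text=group["answer_text"].tolist(),
...                          padding="max_length",
...                          truncation=True,
...                          return_tensors="pt")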
|
||||
|
||||
**STEP 4: Train (fine-tune) TapasForQuestionAnswering**
|
||||
|
||||
You can then fine-tune :class:`~transformers.TapasForQuestionAnswering` using native PyTorch as follows (shown here for
|
||||
the weak supervision for aggregation case):
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> from transformers import TapasConfig, TapasForQuestionAnswering, AdamW
|
||||
|
||||
>>> # this is the default WTQ configuration
|
||||
>>> config = TapasConfig(
|
||||
... num_aggregation_labels = 4,
|
||||
... use_answer_as_supervision = True,
|
||||
... answer_loss_cutoff = 0.664694,
|
||||
... cell_selection_preference = 0.207951,
|
||||
... huber_loss_delta = 0.121194,
|
||||
... init_cell_selection_weights_to_zero = True,
|
||||
... select_one_column = True,
|
||||
... allow_empty_column_selection = False,
|
||||
... temperature = 0.0352513,
|
||||
... )
|
||||
>>> model = TapasForQuestionAnswering.from_pretrained("google/tapas-base", config=config)
|
||||
|
||||
>>> optimizer = AdamW(model.parameters(), lr=5e-5)
|
||||
|
||||
>>> for epoch in range(2): # loop over the dataset multiple times
|
||||
... for idx, batch in enumerate(train_dataloader):
|
||||
... # get the inputs;
|
||||
... input_ids = batch["input_ids"]
|
||||
... attention_mask = batch["attention_mask"]
|
||||
... token_type_ids = batch["token_type_ids"]
|
||||
... labels = batch["labels"]
|
||||
... numeric_values = batch["numeric_values"]
|
||||
... numeric_values_scale = batch["numeric_values_scale"]
|
||||
... float_answer = batch["float_answer"]
|
||||
|
||||
... # zero the parameter gradients
|
||||
... optimizer.zero_grad()
|
||||
|
||||
... # forward + backward + optimize
|
||||
... outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids,
|
||||
... labels=labels, numeric_values=numeric_values, numeric_values_scale=numeric_values_scale,
|
||||
... float_answer=float_answer)
|
||||
... loss = outputs.loss
|
||||
... loss.backward()
|
||||
... optimizer.step()
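>>> # hypothetical follow-up, not part of the original example: save the fine-tuned model and tokenizer
>>> model.save_pretrained("path_to_your_fine_tuned_tapas")  # placeholder path
>>> tokenizer.save_pretrained("path_to_your_fine_tuned_tapas")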
|
||||
|
||||
Usage: inference
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Here we explain how you can use :class:`~transformers.TapasForQuestionAnswering` for inference (i.e. making predictions
|
||||
on new data). For inference, only ``input_ids``, ``attention_mask`` and ``token_type_ids`` (which you can obtain using
|
||||
:class:`~transformers.TapasTokenizer`) have to be provided to the model to obtain the logits. Next, you can use the
|
||||
handy ``convert_logits_to_predictions`` method of :class:`~transformers.TapasTokenizer` to convert these into predicted
|
||||
coordinates and optional aggregation indices.
|
||||
|
||||
However, note that inference is **different** depending on whether or not the setup is conversational. In a
|
||||
non-conversational set-up, inference can be done in parallel on all table-question pairs of a batch. Here's an example
|
||||
of that:
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> from transformers import TapasTokenizer, TapasForQuestionAnswering
|
||||
>>> import pandas as pd
|
||||
|
||||
>>> model_name = 'google/tapas-base-finetuned-wtq'
|
||||
>>> model = TapasForQuestionAnswering.from_pretrained(model_name)
|
||||
>>> tokenizer = TapasTokenizer.from_pretrained(model_name)
|
||||
|
||||
>>> data = {'Actors': ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"], 'Number of movies': ["87", "53", "69"]}
|
||||
>>> queries = ["What is the name of the first actor?", "How many movies has George Clooney played in?", "What is the total number of movies?"]
|
||||
>>> table = pd.DataFrame.from_dict(data)
|
||||
>>> inputs = tokenizer(table=table, queries=queries, padding='max_length', return_tensors="pt")
|
||||
>>> outputs = model(**inputs)
|
||||
>>> predicted_answer_coordinates, predicted_aggregation_indices = tokenizer.convert_logits_to_predictions(
|
||||
... inputs,
|
||||
... outputs.logits.detach(),
|
||||
... outputs.logits_aggregation.detach()
|
||||
...)
|
||||
|
||||
>>> # let's print out the results:
|
||||
>>> id2aggregation = {0: "NONE", 1: "SUM", 2: "AVERAGE", 3:"COUNT"}
|
||||
>>> aggregation_predictions_string = [id2aggregation[x] for x in predicted_aggregation_indices]
|
||||
|
||||
>>> answers = []
|
||||
>>> for coordinates in predicted_answer_coordinates:
|
||||
... if len(coordinates) == 1:
|
||||
... # only a single cell:
|
||||
... answers.append(table.iat[coordinates[0]])
|
||||
... else:
|
||||
... # multiple cells
|
||||
... cell_values = []
|
||||
... for coordinate in coordinates:
|
||||
... cell_values.append(table.iat[coordinate])
|
||||
... answers.append(", ".join(cell_values))
|
||||
|
||||
>>> display(table)
|
||||
>>> print("")
|
||||
>>> for query, answer, predicted_agg in zip(queries, answers, aggregation_predictions_string):
|
||||
... print(query)
|
||||
... if predicted_agg == "NONE":
|
||||
... print("Predicted answer: " + answer)
|
||||
... else:
|
||||
... print("Predicted answer: " + predicted_agg + " > " + answer)
|
||||
What is the name of the first actor?
|
||||
Predicted answer: Brad Pitt
|
||||
How many movies has George Clooney played in?
|
||||
Predicted answer: COUNT > 69
|
||||
What is the total number of movies?
|
||||
Predicted answer: SUM > 87, 53, 69
|
||||
|
||||
In case of a conversational set-up, then each table-question pair must be provided **sequentially** to the model, such
|
||||
that the ``prev_labels`` token types can be overwritten by the predicted ``labels`` of the previous table-question
|
||||
pair. Again, more info can be found in `this notebook
|
||||
<https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Fine_tuning_TapasForQuestionAnswering_on_SQA.ipynb>`__.
|
||||
|
||||
|
||||
Tapas specific outputs
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.models.tapas.modeling_tapas.TableQuestionAnsweringOutput
|
||||
:members:
|
||||
|
||||
|
||||
TapasConfig
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TapasConfig
|
||||
:members:
|
||||
|
||||
|
||||
TapasTokenizer
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TapasTokenizer
|
||||
:members: __call__, convert_logits_to_predictions, save_vocabulary
|
||||
|
||||
|
||||
TapasModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TapasModel
|
||||
:members: forward
|
||||
|
||||
|
||||
TapasForMaskedLM
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TapasForMaskedLM
|
||||
:members: forward
|
||||
|
||||
|
||||
TapasForSequenceClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TapasForSequenceClassification
|
||||
:members: forward
|
||||
|
||||
|
||||
TapasForQuestionAnswering
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TapasForQuestionAnswering
|
||||
:members: forward
|
123 model_cards/google/tapas-base/README.md Normal file
@ -0,0 +1,123 @@
|
||||
---
|
||||
language: en
|
||||
tags:
|
||||
- tapas
|
||||
- masked-lm
|
||||
license: apache-2.0
|
||||
---
|
||||
|
||||
# TAPAS base model
|
||||
|
||||
This model corresponds to the `tapas_inter_masklm_base_reset` checkpoint of the [original Github repository](https://github.com/google-research/tapas).
|
||||
|
||||
Disclaimer: The team releasing TAPAS did not write a model card for this model so this model card has been written by
|
||||
the Hugging Face team and contributors.
|
||||
|
||||
## Model description
|
||||
|
||||
TAPAS is a BERT-like transformers model pretrained on a large corpus of English data from Wikipedia in a self-supervised fashion.
|
||||
This means it was pretrained on the raw tables and associated texts only, with no humans labelling them in any way (which is why it
|
||||
can use lots of publicly available data) with an automatic process to generate inputs and labels from those texts. More precisely, it
|
||||
was pretrained with two objectives:
|
||||
|
||||
- Masked language modeling (MLM): taking a (flattened) table and associated context, the model randomly masks 15% of the words in
|
||||
the input, then runs the entire (partially masked) sequence through the model. The model then has to predict the masked words.
|
||||
This is different from traditional recurrent neural networks (RNNs) that usually see the words one after the other,
|
||||
or from autoregressive models like GPT which internally mask the future tokens. It allows the model to learn a bidirectional
|
||||
representation of a table and associated text.
|
||||
- Intermediate pre-training: to encourage numerical reasoning on tables, the authors additionally pre-trained the model by creating
|
||||
a balanced dataset of millions of syntactically created training examples. Here, the model must predict (classify) whether a sentence
|
||||
is supported or refuted by the contents of a table. The training examples are created based on synthetic as well as counterfactual statements.
|
||||
|
||||
This way, the model learns an inner representation of the English language used in tables and associated texts, which can then be used
|
||||
to extract features useful for downstream tasks such as answering questions about a table, or determining whether a sentence is entailed
|
||||
or refuted by the contents of a table. Fine-tuning is done by adding classification heads on top of the pre-trained model, and then jointly
training the randomly initialized classification heads with the base model on a labelled dataset.
|
||||
|
||||
## Intended uses & limitations
|
||||
|
||||
You can use the raw model for masked language modeling, but it's mostly intended to be fine-tuned on a downstream task.
|
||||
See the [model hub](https://huggingface.co/models?filter=tapas) to look for fine-tuned versions on a task that interests you.
|
||||
|
||||
|
||||
Here is how to use this model to get the features of a given table-text pair in PyTorch:
|
||||
|
||||
```python
|
||||
from transformers import TapasTokenizer, TapasModel
|
||||
import pandas as pd
|
||||
tokenizer = TapasTokenizer.from_pretrained('google/tapas-base')
model = TapasModel.from_pretrained('google/tapas-base')
|
||||
data = {'Actors': ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
|
||||
'Age': ["56", "45", "59"],
|
||||
'Number of movies': ["87", "53", "69"]
|
||||
}
|
||||
table = pd.DataFrame.from_dict(data)
|
||||
queries = ["How many movies has George Clooney played in?"]
|
||||
|
||||
encoded_input = tokenizer(table=table, queries=queries, return_tensors='pt')
|
||||
output = model(**encoded_input)
|
||||
```
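Here is a sketch of using the masked language modeling head directly. The query string and the way the `[MASK]` token is located below are illustrative assumptions rather than an official example:

```python
from transformers import TapasTokenizer, TapasForMaskedLM
import pandas as pd

tokenizer = TapasTokenizer.from_pretrained('google/tapas-base')
model = TapasForMaskedLM.from_pretrained('google/tapas-base')

data = {'Actors': ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
        'Age': ["56", "45", "59"],
        'Number of movies': ["87", "53", "69"]
        }
table = pd.DataFrame.from_dict(data)
queries = ["George Clooney has played in [MASK] movies."]

inputs = tokenizer(table=table, queries=queries, return_tensors='pt')
outputs = model(**inputs)

# find the masked position and decode the most likely token for it
masked_index = (inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero()[0].item()
predicted_id = int(outputs.logits[0, masked_index].argmax(-1))
predicted_token = tokenizer.decode([predicted_id])
```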
|
||||
|
||||
## Training data
|
||||
|
||||
For masked language modeling (MLM), a collection of 6.2 million tables was extracted from English Wikipedia: 3.3M of class [Infobox](https://en.wikipedia.org/wiki/Help:Infobox)
|
||||
and 2.9M of class WikiTable. The authors only considered tables with at most 500 cells. As a proxy for questions that appear in the
downstream tasks, the authors extracted the table caption, article title, article description, segment title and text of the segment
|
||||
the table occurs in as relevant text snippets. In this way, 21.3M snippets were created. For more info, see the original [TAPAS paper](https://www.aclweb.org/anthology/2020.acl-main.398.pdf).
|
||||
|
||||
For intermediate pre-training, 2 tasks are introduced: one based on synthetic and the other from counterfactual statements. The first one
|
||||
generates a sentence by sampling from a set of logical expressions that filter, combine and compare the information on the table, which is
|
||||
required in table entailment (e.g., knowing that Gerald Ford is taller than the average president requires summing
|
||||
all presidents and dividing by the number of presidents). The second one corrupts sentences about tables appearing on Wikipedia by swapping
|
||||
entities for plausible alternatives. Examples of the two tasks can be seen in Figure 1. The procedure is described in detail in section 3 of
|
||||
the [TAPAS follow-up paper](https://www.aclweb.org/anthology/2020.findings-emnlp.27.pdf).
|
||||
|
||||
## Training procedure
|
||||
|
||||
### Preprocessing
|
||||
|
||||
The texts are lowercased and tokenized using WordPiece and a vocabulary size of 30,000. The inputs of the model are
|
||||
then of the form:
|
||||
|
||||
```
|
||||
[CLS] Context [SEP] Flattened table [SEP]
|
||||
```
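For instance, one quick way to inspect this flattened format (an illustrative snippet, not taken from the original TAPAS code) is to decode an encoded table-question pair:

```python
from transformers import TapasTokenizer
import pandas as pd

tokenizer = TapasTokenizer.from_pretrained('google/tapas-base')
table = pd.DataFrame({'Actors': ["Brad Pitt"], 'Number of movies': ["87"]})
encoding = tokenizer(table=table, queries="How many movies has Brad Pitt played in?")
# the decoded sequence starts with the (lowercased) question, followed by the flattened table
print(tokenizer.decode(encoding["input_ids"]))
```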
|
||||
|
||||
The details of the masking procedure for each sequence are the following:
|
||||
- 15% of the tokens are masked.
|
||||
- In 80% of the cases, the masked tokens are replaced by `[MASK]`.
|
||||
- In 10% of the cases, the masked tokens are replaced by a random token (different) from the one they replace.
|
||||
- In the 10% remaining cases, the masked tokens are left as is.
|
||||
|
||||
The details of the creation of the synthetic and counterfactual examples can be found in the [follow-up paper](https://arxiv.org/abs/2010.00571).
|
||||
|
||||
### Pretraining
|
||||
|
||||
The model was trained on 32 Cloud TPU v3 cores for one million steps with maximum sequence length 512 and batch size of 512.
|
||||
In this setup, pre-training takes around 3 days. The optimizer used is Adam with a learning rate of 5e-5, and a warmup ratio
|
||||
of 0.10.
|
||||
|
||||
|
||||
### BibTeX entry and citation info
|
||||
|
||||
```bibtex
|
||||
@misc{herzig2020tapas,
|
||||
title={TAPAS: Weakly Supervised Table Parsing via Pre-training},
|
||||
author={Jonathan Herzig and Paweł Krzysztof Nowak and Thomas Müller and Francesco Piccinno and Julian Martin Eisenschlos},
|
||||
year={2020},
|
||||
eprint={2004.02349},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.IR}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{eisenschlos2020understanding,
|
||||
title={Understanding tables with intermediate pre-training},
|
||||
author={Julian Martin Eisenschlos and Syrine Krichene and Thomas Müller},
|
||||
year={2020},
|
||||
eprint={2010.00571},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CL}
|
||||
}
|
||||
```
|
@ -1,71 +1,73 @@
|
||||
<!---
|
||||
Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
-->
|
||||
|
||||
# 🤗 Transformers Notebooks
|
||||
|
||||
You can find here a list of the official notebooks provided by Hugging Face.
|
||||
|
||||
Also, we would like to list here interesting content created by the community.
|
||||
If you wrote some notebook(s) leveraging 🤗 Transformers and would like be listed here, please open a
|
||||
Pull Request so it can be included under the Community notebooks.
|
||||
|
||||
|
||||
## Hugging Face's notebooks 🤗
|
||||
|
||||
|
||||
| Notebook | Description | |
|
||||
|:----------|:-------------|------:|
|
||||
| [Getting Started Tokenizers](https://github.com/huggingface/transformers/blob/master/notebooks/01-training-tokenizers.ipynb) | How to train and use your very own tokenizer |[](https://colab.research.google.com/github/huggingface/transformers/blob/master/notebooks/01-training-tokenizers.ipynb) |
|
||||
| [Getting Started Transformers](https://github.com/huggingface/transformers/blob/master/notebooks/02-transformers.ipynb) | How to easily start using transformers | [](https://colab.research.google.com/github/huggingface/transformers/blob/master/notebooks/02-transformers.ipynb) |
|
||||
| [How to use Pipelines](https://github.com/huggingface/transformers/blob/master/notebooks/03-pipelines.ipynb) | Simple and efficient way to use State-of-the-Art models on downstream tasks through transformers | [](https://colab.research.google.com/github/huggingface/transformers/blob/master/notebooks/03-pipelines.ipynb) |
|
||||
| [How to train a language model](https://github.com/huggingface/blog/blob/master/notebooks/01_how_to_train.ipynb)| Highlight all the steps to effectively train Transformer model on custom data | [](https://colab.research.google.com/github/huggingface/blog/blob/master/notebooks/01_how_to_train.ipynb)|
|
||||
| [How to generate text](https://github.com/huggingface/blog/blob/master/notebooks/02_how_to_generate.ipynb)| How to use different decoding methods for language generation with transformers | [](https://colab.research.google.com/github/huggingface/blog/blob/master/notebooks/02_how_to_generate.ipynb)|
|
||||
| [How to export model to ONNX](https://github.com/huggingface/transformers/blob/master/notebooks/04-onnx-export.ipynb) | Highlight how to export and run inference workloads through ONNX |
|
||||
| [How to use Benchmarks](https://github.com/huggingface/transformers/blob/master/notebooks/05-benchmark.ipynb) | How to benchmark models with transformers | [](https://colab.research.google.com/github/huggingface/transformers/blob/master/notebooks/05-benchmark.ipynb)|
|
||||
| [Reformer](https://github.com/huggingface/blog/blob/master/notebooks/03_reformer.ipynb) | How Reformer pushes the limits of language modeling | [](https://colab.research.google.com/github/patrickvonplaten/blog/blob/master/notebooks/03_reformer.ipynb)|
|
||||
|
||||
|
||||
## Community notebooks:
|
||||
|
||||
| Notebook | Description | Author | |
|
||||
|:----------|:-------------|:-------------|------:|
|
||||
| [Train T5 in Tensorflow 2 ](https://github.com/snapthat/TF-T5-text-to-text) | How to train T5 for any task using Tensorflow 2. This notebook demonstrates a Question & Answer task implemented in Tensorflow 2 using SQUAD | [Muhammad Harris](https://github.com/HarrisDePerceptron) |[](https://colab.research.google.com/github/snapthat/TF-T5-text-to-text/blob/master/snapthatT5/notebooks/TF-T5-Datasets%20Training.ipynb) |
|
||||
| [Train T5 on TPU](https://github.com/patil-suraj/exploring-T5/blob/master/T5_on_TPU.ipynb) | How to train T5 on SQUAD with Transformers and Nlp | [Suraj Patil](https://github.com/patil-suraj) |[](https://colab.research.google.com/github/patil-suraj/exploring-T5/blob/master/T5_on_TPU.ipynb#scrollTo=QLGiFCDqvuil) |
|
||||
| [Fine-tune T5 for Classification and Multiple Choice](https://github.com/patil-suraj/exploring-T5/blob/master/t5_fine_tuning.ipynb) | How to fine-tune T5 for classification and multiple choice tasks using a text-to-text format with PyTorch Lightning | [Suraj Patil](https://github.com/patil-suraj) | [](https://colab.research.google.com/github/patil-suraj/exploring-T5/blob/master/t5_fine_tuning.ipynb) |
|
||||
| [Fine-tune DialoGPT on New Datasets and Languages](https://github.com/ncoop57/i-am-a-nerd/blob/master/_notebooks/2020-05-12-chatbot-part-1.ipynb) | How to fine-tune the DialoGPT model on a new dataset for open-dialog conversational chatbots | [Nathan Cooper](https://github.com/ncoop57) | [](https://colab.research.google.com/github/ncoop57/i-am-a-nerd/blob/master/_notebooks/2020-05-12-chatbot-part-1.ipynb) |
|
||||
| [Long Sequence Modeling with Reformer](https://github.com/patrickvonplaten/notebooks/blob/master/PyTorch_Reformer.ipynb) | How to train on sequences as long as 500,000 tokens with Reformer | [Patrick von Platen](https://github.com/patrickvonplaten) | [](https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/PyTorch_Reformer.ipynb) |
|
||||
| [Fine-tune BART for Summarization](https://github.com/ohmeow/ohmeow_website/blob/master/_notebooks/2020-05-23-text-generation-with-blurr.ipynb) | How to fine-tune BART for summarization with fastai using blurr | [Wayde Gilliam](https://ohmeow.com/) | [](https://colab.research.google.com/github/ohmeow/ohmeow_website/blob/master/_notebooks/2020-05-23-text-generation-with-blurr.ipynb) |
|
||||
| [Fine-tune a pre-trained Transformer on anyone's tweets](https://colab.research.google.com/github/borisdayma/huggingtweets/blob/master/huggingtweets-demo.ipynb) | How to generate tweets in the style of your favorite Twitter account by fine-tune a GPT-2 model | [Boris Dayma](https://github.com/borisdayma) | [](https://colab.research.google.com/github/borisdayma/huggingtweets/blob/master/huggingtweets-demo.ipynb) |
|
||||
| [A Step by Step Guide to Tracking Hugging Face Model Performance](https://colab.research.google.com/drive/1NEiqNPhiouu2pPwDAVeFoN4-vTYMz9F8) | A quick tutorial for training NLP models with HuggingFace and & visualizing their performance with Weights & Biases | [Jack Morris](https://github.com/jxmorris12) | [](https://colab.research.google.com/drive/1NEiqNPhiouu2pPwDAVeFoN4-vTYMz9F8) |
|
||||
| [Pretrain Longformer](https://github.com/allenai/longformer/blob/master/scripts/convert_model_to_long.ipynb) | How to build a "long" version of existing pretrained models | [Iz Beltagy](https://beltagy.net) | [](https://colab.research.google.com/github/allenai/longformer/blob/master/scripts/convert_model_to_long.ipynb) |
|
||||
| [Fine-tune Longformer for QA](https://github.com/patil-suraj/Notebooks/blob/master/longformer_qa_training.ipynb) | How to fine-tune longformer model for QA task | [Suraj Patil](https://github.com/patil-suraj) | [](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/longformer_qa_training.ipynb) |
|
||||
| [Evaluate Model with 🤗nlp](https://github.com/patrickvonplaten/notebooks/blob/master/How_to_evaluate_Longformer_on_TriviaQA_using_NLP.ipynb) | How to evaluate longformer on TriviaQA with `nlp` | [Patrick von Platen](https://github.com/patrickvonplaten) | [](https://colab.research.google.com/drive/1m7eTGlPmLRgoPkkA7rkhQdZ9ydpmsdLE?usp=sharing) |
|
||||
| [Fine-tune T5 for Sentiment Span Extraction](https://github.com/enzoampil/t5-intro/blob/master/t5_qa_training_pytorch_span_extraction.ipynb) | How to fine-tune T5 for sentiment span extraction using a text-to-text format with PyTorch Lightning | [Lorenzo Ampil](https://github.com/enzoampil) | [](https://colab.research.google.com/github/enzoampil/t5-intro/blob/master/t5_qa_training_pytorch_span_extraction.ipynb) |
|
||||
| [Fine-tune DistilBert for Multiclass Classification](https://github.com/abhimishra91/transformers-tutorials/blob/master/transformers_multiclass_classification.ipynb) | How to fine-tune DistilBert for multiclass classification with PyTorch | [Abhishek Kumar Mishra](https://github.com/abhimishra91) | [](https://colab.research.google.com/github/abhimishra91/transformers-tutorials/blob/master/transformers_multiclass_classification.ipynb)|
|
||||
|[Fine-tune BERT for Multi-label Classification](https://github.com/abhimishra91/transformers-tutorials/blob/master/transformers_multi_label_classification.ipynb)|How to fine-tune BERT for multi-label classification using PyTorch|[Abhishek Kumar Mishra](https://github.com/abhimishra91) |[](https://colab.research.google.com/github/abhimishra91/transformers-tutorials/blob/master/transformers_multi_label_classification.ipynb)|
|
||||
|[Fine-tune T5 for Summarization](https://github.com/abhimishra91/transformers-tutorials/blob/master/transformers_summarization_wandb.ipynb)|How to fine-tune T5 for summarization in PyTorch and track experiments with WandB|[Abhishek Kumar Mishra](https://github.com/abhimishra91) |[](https://colab.research.google.com/github/abhimishra91/transformers-tutorials/blob/master/transformers_summarization_wandb.ipynb)|
|
||||
|[Speed up Fine-Tuning in Transformers with Dynamic Padding / Bucketing](https://github.com/ELS-RD/transformers-notebook/blob/master/Divide_Hugging_Face_Transformers_training_time_by_2_or_more.ipynb)|How to speed up fine-tuning by a factor of 2 using dynamic padding / bucketing|[Michael Benesty](https://github.com/pommedeterresautee) |[](https://colab.research.google.com/drive/1CBfRU1zbfu7-ijiOqAAQUA-RJaxfcJoO?usp=sharing)|
|
||||
|[Pretrain Reformer for Masked Language Modeling](https://github.com/patrickvonplaten/notebooks/blob/master/Reformer_For_Masked_LM.ipynb)| How to train a Reformer model with bi-directional self-attention layers | [Patrick von Platen](https://github.com/patrickvonplaten) | [](https://colab.research.google.com/drive/1tzzh0i8PgDQGV3SMFUGxM7_gGae3K-uW?usp=sharing)|
|
||||
|[Expand and Fine Tune Sci-BERT](https://github.com/lordtt13/word-embeddings/blob/master/COVID-19%20Research%20Data/COVID-SciBERT.ipynb)| How to increase the vocabulary of a pretrained SciBERT model from AllenAI and pipeline it on the CORD dataset. | [Tanmay Thakur](https://github.com/lordtt13) | [](https://colab.research.google.com/drive/1rqAR40goxbAfez1xvF3hBJphSCsvXmh8)|
|
||||
|[Fine-tune Electra and interpret with Integrated Gradients](https://github.com/elsanns/xai-nlp-notebooks/blob/master/electra_fine_tune_interpret_captum_ig.ipynb) | How to fine-tune Electra for sentiment analysis and interpret predictions with Captum Integrated Gradients | [Eliza Szczechla](https://elsanns.github.io) | [](https://colab.research.google.com/github/elsanns/xai-nlp-notebooks/blob/master/electra_fine_tune_interpret_captum_ig.ipynb)|
|
||||
|[Fine-tune a non-English GPT-2 Model with Trainer class](https://github.com/philschmid/fine-tune-GPT-2/blob/master/Fine_tune_a_non_English_GPT_2_Model_with_Huggingface.ipynb) | How to fine-tune a non-English GPT-2 model with the Trainer class | [Philipp Schmid](https://www.philschmid.de) | [](https://colab.research.google.com/github/philschmid/fine-tune-GPT-2/blob/master/Fine_tune_a_non_English_GPT_2_Model_with_Huggingface.ipynb)|
|
||||
|[Fine-tune a DistilBERT Model for Multi Label Classification task](https://github.com/DhavalTaunk08/Transformers_scripts/blob/master/Transformers_multilabel_distilbert.ipynb) | How to fine-tune a DistilBERT Model for Multi Label Classification task | [Dhaval Taunk](https://github.com/DhavalTaunk08) | [](https://colab.research.google.com/github/DhavalTaunk08/Transformers_scripts/blob/master/Transformers_multilabel_distilbert.ipynb)|
|
||||
|[Fine-tune ALBERT for sentence-pair classification](https://github.com/NadirEM/nlp-notebooks/blob/master/Fine_tune_ALBERT_sentence_pair_classification.ipynb) | How to fine-tune an ALBERT model or another BERT-based model for the sentence-pair classification task | [Nadir El Manouzi](https://github.com/NadirEM) | [](https://colab.research.google.com/github/NadirEM/nlp-notebooks/blob/master/Fine_tune_ALBERT_sentence_pair_classification.ipynb)|
|
||||
|[Fine-tune Roberta for sentiment analysis](https://github.com/DhavalTaunk08/NLP_scripts/blob/master/sentiment_analysis_using_roberta.ipynb) | How to fine-tune a RoBERTa model for sentiment analysis | [Dhaval Taunk](https://github.com/DhavalTaunk08) | [](https://colab.research.google.com/github/DhavalTaunk08/NLP_scripts/blob/master/sentiment_analysis_using_roberta.ipynb)|
|
||||
|[Evaluating Question Generation Models](https://github.com/flexudy-pipe/qugeev) | How accurate are the answers to questions generated by your seq2seq transformer model? | [Pascal Zoleko](https://github.com/zolekode) | [](https://colab.research.google.com/drive/1bpsSqCQU-iw_5nNoRm_crPq6FRuJthq_?usp=sharing)|
|
||||
|[Classify text with DistilBERT and Tensorflow](https://github.com/peterbayerle/huggingface_notebook/blob/main/distilbert_tf.ipynb) | How to fine-tune DistilBERT for text classification in TensorFlow | [Peter Bayerle](https://github.com/peterbayerle) | [](https://colab.research.google.com/github/peterbayerle/huggingface_notebook/blob/main/distilbert_tf.ipynb)|
|
||||
|[Leverage BERT for Encoder-Decoder Summarization on CNN/Dailymail](https://github.com/patrickvonplaten/notebooks/blob/master/BERT2BERT_for_CNN_Dailymail.ipynb) | How to warm-start an *EncoderDecoderModel* with a *bert-base-uncased* checkpoint for summarization on CNN/Dailymail | [Patrick von Platen](https://github.com/patrickvonplaten) | [](https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/BERT2BERT_for_CNN_Dailymail.ipynb)|
|
||||
|[Leverage RoBERTa for Encoder-Decoder Summarization on BBC XSum](https://github.com/patrickvonplaten/notebooks/blob/master/RoBERTaShared_for_BBC_XSum.ipynb) | How to warm-start a shared *EncoderDecoderModel* with a *roberta-base* checkpoint for summarization on BBC/XSum | [Patrick von Platen](https://github.com/patrickvonplaten) | [](https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/RoBERTaShared_for_BBC_XSum.ipynb)|
|
||||
<!---
|
||||
Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
-->
|
||||
|
||||
# 🤗 Transformers Notebooks
|
||||
|
||||
You can find here a list of the official notebooks provided by Hugging Face.
|
||||
|
||||
Also, we would like to list here interesting content created by the community.
|
||||
If you wrote some notebook(s) leveraging 🤗 Transformers and would like to be listed here, please open a
|
||||
Pull Request so it can be included under the Community notebooks.
|
||||
|
||||
|
||||
## Hugging Face's notebooks 🤗
|
||||
|
||||
|
||||
| Notebook | Description | |
|
||||
|:----------|:-------------|------:|
|
||||
| [Getting Started Tokenizers](https://github.com/huggingface/transformers/blob/master/notebooks/01-training-tokenizers.ipynb) | How to train and use your very own tokenizer |[](https://colab.research.google.com/github/huggingface/transformers/blob/master/notebooks/01-training-tokenizers.ipynb) |
|
||||
| [Getting Started Transformers](https://github.com/huggingface/transformers/blob/master/notebooks/02-transformers.ipynb) | How to easily start using transformers | [](https://colab.research.google.com/github/huggingface/transformers/blob/master/notebooks/02-transformers.ipynb) |
|
||||
| [How to use Pipelines](https://github.com/huggingface/transformers/blob/master/notebooks/03-pipelines.ipynb) | Simple and efficient way to use State-of-the-Art models on downstream tasks through transformers | [](https://colab.research.google.com/github/huggingface/transformers/blob/master/notebooks/03-pipelines.ipynb) |
|
||||
| [How to train a language model](https://github.com/huggingface/blog/blob/master/notebooks/01_how_to_train.ipynb)| Highlight all the steps to effectively train a Transformer model on custom data | [](https://colab.research.google.com/github/huggingface/blog/blob/master/notebooks/01_how_to_train.ipynb)|
|
||||
| [How to generate text](https://github.com/huggingface/blog/blob/master/notebooks/02_how_to_generate.ipynb)| How to use different decoding methods for language generation with transformers | [](https://colab.research.google.com/github/huggingface/blog/blob/master/notebooks/02_how_to_generate.ipynb)|
|
||||
| [How to export model to ONNX](https://github.com/huggingface/transformers/blob/master/notebooks/04-onnx-export.ipynb) | Highlight how to export and run inference workloads through ONNX |
|
||||
| [How to use Benchmarks](https://github.com/huggingface/transformers/blob/master/notebooks/05-benchmark.ipynb) | How to benchmark models with transformers | [](https://colab.research.google.com/github/huggingface/transformers/blob/master/notebooks/05-benchmark.ipynb)|
|
||||
| [Reformer](https://github.com/huggingface/blog/blob/master/notebooks/03_reformer.ipynb) | How Reformer pushes the limits of language modeling | [](https://colab.research.google.com/github/patrickvonplaten/blog/blob/master/notebooks/03_reformer.ipynb)|
|
||||
|
||||
|
||||
## Community notebooks:
|
||||
|
||||
| Notebook | Description | Author | |
|
||||
|:----------|:-------------|:-------------|------:|
|
||||
| [Train T5 in Tensorflow 2 ](https://github.com/snapthat/TF-T5-text-to-text) | How to train T5 for any task using Tensorflow 2. This notebook demonstrates a Question & Answer task implemented in Tensorflow 2 using SQUAD | [Muhammad Harris](https://github.com/HarrisDePerceptron) |[](https://colab.research.google.com/github/snapthat/TF-T5-text-to-text/blob/master/snapthatT5/notebooks/TF-T5-Datasets%20Training.ipynb) |
|
||||
| [Train T5 on TPU](https://github.com/patil-suraj/exploring-T5/blob/master/T5_on_TPU.ipynb) | How to train T5 on SQUAD with Transformers and Nlp | [Suraj Patil](https://github.com/patil-suraj) |[](https://colab.research.google.com/github/patil-suraj/exploring-T5/blob/master/T5_on_TPU.ipynb#scrollTo=QLGiFCDqvuil) |
|
||||
| [Fine-tune T5 for Classification and Multiple Choice](https://github.com/patil-suraj/exploring-T5/blob/master/t5_fine_tuning.ipynb) | How to fine-tune T5 for classification and multiple choice tasks using a text-to-text format with PyTorch Lightning | [Suraj Patil](https://github.com/patil-suraj) | [](https://colab.research.google.com/github/patil-suraj/exploring-T5/blob/master/t5_fine_tuning.ipynb) |
|
||||
| [Fine-tune DialoGPT on New Datasets and Languages](https://github.com/ncoop57/i-am-a-nerd/blob/master/_notebooks/2020-05-12-chatbot-part-1.ipynb) | How to fine-tune the DialoGPT model on a new dataset for open-dialog conversational chatbots | [Nathan Cooper](https://github.com/ncoop57) | [](https://colab.research.google.com/github/ncoop57/i-am-a-nerd/blob/master/_notebooks/2020-05-12-chatbot-part-1.ipynb) |
|
||||
| [Long Sequence Modeling with Reformer](https://github.com/patrickvonplaten/notebooks/blob/master/PyTorch_Reformer.ipynb) | How to train on sequences as long as 500,000 tokens with Reformer | [Patrick von Platen](https://github.com/patrickvonplaten) | [](https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/PyTorch_Reformer.ipynb) |
|
||||
| [Fine-tune BART for Summarization](https://github.com/ohmeow/ohmeow_website/blob/master/_notebooks/2020-05-23-text-generation-with-blurr.ipynb) | How to fine-tune BART for summarization with fastai using blurr | [Wayde Gilliam](https://ohmeow.com/) | [](https://colab.research.google.com/github/ohmeow/ohmeow_website/blob/master/_notebooks/2020-05-23-text-generation-with-blurr.ipynb) |
|
||||
| [Fine-tune a pre-trained Transformer on anyone's tweets](https://colab.research.google.com/github/borisdayma/huggingtweets/blob/master/huggingtweets-demo.ipynb) | How to generate tweets in the style of your favorite Twitter account by fine-tuning a GPT-2 model | [Boris Dayma](https://github.com/borisdayma) | [](https://colab.research.google.com/github/borisdayma/huggingtweets/blob/master/huggingtweets-demo.ipynb) |
|
||||
| [A Step by Step Guide to Tracking Hugging Face Model Performance](https://colab.research.google.com/drive/1NEiqNPhiouu2pPwDAVeFoN4-vTYMz9F8) | A quick tutorial for training NLP models with Hugging Face and visualizing their performance with Weights & Biases | [Jack Morris](https://github.com/jxmorris12) | [](https://colab.research.google.com/drive/1NEiqNPhiouu2pPwDAVeFoN4-vTYMz9F8) |
|
||||
| [Pretrain Longformer](https://github.com/allenai/longformer/blob/master/scripts/convert_model_to_long.ipynb) | How to build a "long" version of existing pretrained models | [Iz Beltagy](https://beltagy.net) | [](https://colab.research.google.com/github/allenai/longformer/blob/master/scripts/convert_model_to_long.ipynb) |
|
||||
| [Fine-tune Longformer for QA](https://github.com/patil-suraj/Notebooks/blob/master/longformer_qa_training.ipynb) | How to fine-tune a Longformer model for the QA task | [Suraj Patil](https://github.com/patil-suraj) | [](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/longformer_qa_training.ipynb) |
|
||||
| [Evaluate Model with 🤗nlp](https://github.com/patrickvonplaten/notebooks/blob/master/How_to_evaluate_Longformer_on_TriviaQA_using_NLP.ipynb) | How to evaluate longformer on TriviaQA with `nlp` | [Patrick von Platen](https://github.com/patrickvonplaten) | [](https://colab.research.google.com/drive/1m7eTGlPmLRgoPkkA7rkhQdZ9ydpmsdLE?usp=sharing) |
|
||||
| [Fine-tune T5 for Sentiment Span Extraction](https://github.com/enzoampil/t5-intro/blob/master/t5_qa_training_pytorch_span_extraction.ipynb) | How to fine-tune T5 for sentiment span extraction using a text-to-text format with PyTorch Lightning | [Lorenzo Ampil](https://github.com/enzoampil) | [](https://colab.research.google.com/github/enzoampil/t5-intro/blob/master/t5_qa_training_pytorch_span_extraction.ipynb) |
|
||||
| [Fine-tune DistilBert for Multiclass Classification](https://github.com/abhimishra91/transformers-tutorials/blob/master/transformers_multiclass_classification.ipynb) | How to fine-tune DistilBert for multiclass classification with PyTorch | [Abhishek Kumar Mishra](https://github.com/abhimishra91) | [](https://colab.research.google.com/github/abhimishra91/transformers-tutorials/blob/master/transformers_multiclass_classification.ipynb)|
|
||||
|[Fine-tune BERT for Multi-label Classification](https://github.com/abhimishra91/transformers-tutorials/blob/master/transformers_multi_label_classification.ipynb)|How to fine-tune BERT for multi-label classification using PyTorch|[Abhishek Kumar Mishra](https://github.com/abhimishra91) |[](https://colab.research.google.com/github/abhimishra91/transformers-tutorials/blob/master/transformers_multi_label_classification.ipynb)|
|
||||
|[Fine-tune T5 for Summarization](https://github.com/abhimishra91/transformers-tutorials/blob/master/transformers_summarization_wandb.ipynb)|How to fine-tune T5 for summarization in PyTorch and track experiments with WandB|[Abhishek Kumar Mishra](https://github.com/abhimishra91) |[](https://colab.research.google.com/github/abhimishra91/transformers-tutorials/blob/master/transformers_summarization_wandb.ipynb)|
|
||||
|[Speed up Fine-Tuning in Transformers with Dynamic Padding / Bucketing](https://github.com/ELS-RD/transformers-notebook/blob/master/Divide_Hugging_Face_Transformers_training_time_by_2_or_more.ipynb)|How to speed up fine-tuning by a factor of 2 using dynamic padding / bucketing|[Michael Benesty](https://github.com/pommedeterresautee) |[](https://colab.research.google.com/drive/1CBfRU1zbfu7-ijiOqAAQUA-RJaxfcJoO?usp=sharing)|
|
||||
|[Pretrain Reformer for Masked Language Modeling](https://github.com/patrickvonplaten/notebooks/blob/master/Reformer_For_Masked_LM.ipynb)| How to train a Reformer model with bi-directional self-attention layers | [Patrick von Platen](https://github.com/patrickvonplaten) | [](https://colab.research.google.com/drive/1tzzh0i8PgDQGV3SMFUGxM7_gGae3K-uW?usp=sharing)|
|
||||
|[Expand and Fine Tune Sci-BERT](https://github.com/lordtt13/word-embeddings/blob/master/COVID-19%20Research%20Data/COVID-SciBERT.ipynb)| How to increase the vocabulary of a pretrained SciBERT model from AllenAI and pipeline it on the CORD dataset. | [Tanmay Thakur](https://github.com/lordtt13) | [](https://colab.research.google.com/drive/1rqAR40goxbAfez1xvF3hBJphSCsvXmh8)|
|
||||
|[Fine-tune Electra and interpret with Integrated Gradients](https://github.com/elsanns/xai-nlp-notebooks/blob/master/electra_fine_tune_interpret_captum_ig.ipynb) | How to fine-tune Electra for sentiment analysis and interpret predictions with Captum Integrated Gradients | [Eliza Szczechla](https://elsanns.github.io) | [](https://colab.research.google.com/github/elsanns/xai-nlp-notebooks/blob/master/electra_fine_tune_interpret_captum_ig.ipynb)|
|
||||
|[Fine-tune a non-English GPT-2 Model with Trainer class](https://github.com/philschmid/fine-tune-GPT-2/blob/master/Fine_tune_a_non_English_GPT_2_Model_with_Huggingface.ipynb) | How to fine-tune a non-English GPT-2 model with the Trainer class | [Philipp Schmid](https://www.philschmid.de) | [](https://colab.research.google.com/github/philschmid/fine-tune-GPT-2/blob/master/Fine_tune_a_non_English_GPT_2_Model_with_Huggingface.ipynb)|
|
||||
|[Fine-tune a DistilBERT Model for Multi Label Classification task](https://github.com/DhavalTaunk08/Transformers_scripts/blob/master/Transformers_multilabel_distilbert.ipynb) | How to fine-tune a DistilBERT Model for Multi Label Classification task | [Dhaval Taunk](https://github.com/DhavalTaunk08) | [](https://colab.research.google.com/github/DhavalTaunk08/Transformers_scripts/blob/master/Transformers_multilabel_distilbert.ipynb)|
|
||||
|[Fine-tune ALBERT for sentence-pair classification](https://github.com/NadirEM/nlp-notebooks/blob/master/Fine_tune_ALBERT_sentence_pair_classification.ipynb) | How to fine-tune an ALBERT model or another BERT-based model for the sentence-pair classification task | [Nadir El Manouzi](https://github.com/NadirEM) | [](https://colab.research.google.com/github/NadirEM/nlp-notebooks/blob/master/Fine_tune_ALBERT_sentence_pair_classification.ipynb)|
|
||||
|[Fine-tune Roberta for sentiment analysis](https://github.com/DhavalTaunk08/NLP_scripts/blob/master/sentiment_analysis_using_roberta.ipynb) | How to fine-tune a RoBERTa model for sentiment analysis | [Dhaval Taunk](https://github.com/DhavalTaunk08) | [](https://colab.research.google.com/github/DhavalTaunk08/NLP_scripts/blob/master/sentiment_analysis_using_roberta.ipynb)|
|
||||
|[Evaluating Question Generation Models](https://github.com/flexudy-pipe/qugeev) | How accurate are the answers to questions generated by your seq2seq transformer model? | [Pascal Zoleko](https://github.com/zolekode) | [](https://colab.research.google.com/drive/1bpsSqCQU-iw_5nNoRm_crPq6FRuJthq_?usp=sharing)|
|
||||
|[Classify text with DistilBERT and Tensorflow](https://github.com/peterbayerle/huggingface_notebook/blob/main/distilbert_tf.ipynb) | How to fine-tune DistilBERT for text classification in TensorFlow | [Peter Bayerle](https://github.com/peterbayerle) | [](https://colab.research.google.com/github/peterbayerle/huggingface_notebook/blob/main/distilbert_tf.ipynb)|
|
||||
|[Leverage BERT for Encoder-Decoder Summarization on CNN/Dailymail](https://github.com/patrickvonplaten/notebooks/blob/master/BERT2BERT_for_CNN_Dailymail.ipynb) | How to warm-start an *EncoderDecoderModel* with a *bert-base-uncased* checkpoint for summarization on CNN/Dailymail | [Patrick von Platen](https://github.com/patrickvonplaten) | [](https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/BERT2BERT_for_CNN_Dailymail.ipynb)|
|
||||
|[Leverage RoBERTa for Encoder-Decoder Summarization on BBC XSum](https://github.com/patrickvonplaten/notebooks/blob/master/RoBERTaShared_for_BBC_XSum.ipynb) | How to warm-start a shared *EncoderDecoderModel* with a *roberta-base* checkpoint for summarization on BBC/XSum | [Patrick von Platen](https://github.com/patrickvonplaten) | [](https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/RoBERTaShared_for_BBC_XSum.ipynb)|
|
||||
|[Fine-tuning TAPAS on Sequential Question Answering (SQA)](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Fine_tuning_TapasForQuestionAnswering_on_SQA.ipynb) | How to fine-tune *TapasForQuestionAnswering* with a *tapas-base* checkpoint on the Sequential Question Answering (SQA) dataset | [Niels Rogge](https://github.com/nielsrogge) | [](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/Fine_tuning_TapasForQuestionAnswering_on_SQA.ipynb)|
|
||||
|[Evaluating TAPAS on Table Fact Checking (TabFact)](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Evaluating_TAPAS_on_the_Tabfact_test_set.ipynb) | How to evaluate a fine-tuned *TapasForSequenceClassification* with a *tapas-base-finetuned-tabfact* checkpoint using a combination of the 🤗 datasets and 🤗 transformers libraries | [Niels Rogge](https://github.com/nielsrogge) | [](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/Evaluating_TAPAS_on_the_Tabfact_test_set.ipynb)|
|
||||
|
@ -164,6 +164,7 @@ from .models.retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBert
|
||||
from .models.roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig, RobertaTokenizer
|
||||
from .models.squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig, SqueezeBertTokenizer
|
||||
from .models.t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
|
||||
from .models.tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig, TapasTokenizer
|
||||
from .models.transfo_xl import (
|
||||
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
TransfoXLConfig,
|
||||
@ -605,6 +606,13 @@ if is_torch_available():
|
||||
T5PreTrainedModel,
|
||||
load_tf_weights_in_t5,
|
||||
)
|
||||
from .models.tapas import (
|
||||
TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TapasForMaskedLM,
|
||||
TapasForQuestionAnswering,
|
||||
TapasForSequenceClassification,
|
||||
TapasModel,
|
||||
)
|
||||
from .models.transfo_xl import (
|
||||
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
AdaptiveEmbedding,
|
||||
|
@ -216,6 +216,29 @@ except ImportError:
|
||||
_tokenizers_available = False
|
||||
|
||||
|
||||
try:
|
||||
import pandas # noqa: F401
|
||||
|
||||
_pandas_available = True
|
||||
|
||||
except ImportError:
|
||||
_pandas_available = False
|
||||
|
||||
|
||||
try:
|
||||
import torch_scatter
|
||||
|
||||
# Check we're not importing a "torch_scatter" directory somewhere
|
||||
_scatter_available = hasattr(torch_scatter, "__version__") and hasattr(torch_scatter, "scatter")
|
||||
if _scatter_available:
|
||||
logger.debug(f"Succesfully imported torch-scatter version {torch_scatter.__version__}")
|
||||
else:
|
||||
logger.debug("Imported a torch_scatter object but this doesn't seem to be the torch-scatter library.")
|
||||
|
||||
except ImportError:
|
||||
_scatter_available = False
|
||||
|
||||
|
||||
old_default_cache_path = os.path.join(torch_cache_home, "transformers")
|
||||
# New default cache, shared with the Datasets library
|
||||
hf_cache_home = os.path.expanduser(
|
||||
@ -325,6 +348,14 @@ def is_in_notebook():
|
||||
return _in_notebook
|
||||
|
||||
|
||||
def is_scatter_available():
|
||||
return _scatter_available
|
||||
|
||||
|
||||
def is_pandas_available():
|
||||
return _pandas_available
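# Illustrative sketch: downstream code can branch on the two helpers above instead of
# importing the optional packages unconditionally. The import path is assumed to be
# `transformers.file_utils`, the module modified in this hunk.
from transformers.file_utils import is_pandas_available, is_scatter_available

if is_scatter_available():
    import torch_scatter  # optional dependency used by the TAPAS modeling code

if is_pandas_available():
    import pandas as pd  # tables are passed to TapasTokenizer as DataFrames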
|
||||
|
||||
|
||||
def torch_only_method(fn):
|
||||
def wrapper(*args, **kwargs):
|
||||
if not _torch_available:
|
||||
@ -427,6 +458,13 @@ installation page: https://github.com/google/flax and follow the ones that match
|
||||
"""
|
||||
|
||||
|
||||
# docstyle-ignore
|
||||
SCATTER_IMPORT_ERROR = """
|
||||
{0} requires the torch-scatter library but it was not found in your environment. You can install it with pip as
|
||||
explained here: https://github.com/rusty1s/pytorch_scatter.
|
||||
"""
|
||||
|
||||
|
||||
def requires_datasets(obj):
|
||||
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
|
||||
if not is_datasets_available():
|
||||
@ -481,6 +519,12 @@ def requires_protobuf(obj):
|
||||
raise ImportError(PROTOBUF_IMPORT_ERROR.format(name))
|
||||
|
||||
|
||||
def requires_scatter(obj):
|
||||
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
|
||||
if not is_scatter_available():
|
||||
raise ImportError(SCATTER_IMPORT_ERROR.format(name))
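# Illustrative sketch: a class that depends on torch-scatter can fail fast with the
# SCATTER_IMPORT_ERROR message by calling the requires_scatter helper above
# (the class name here is hypothetical).
from transformers.file_utils import requires_scatter


class NeedsScatter:
    def __init__(self):
        requires_scatter(self)  # raises ImportError when torch-scatter is missing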
|
||||
|
||||
|
||||
def add_start_docstrings(*docstr):
|
||||
def docstring_decorator(fn):
|
||||
fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
|
||||
|
@ -51,6 +51,7 @@ from ..retribert.configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCH
|
||||
from ..roberta.configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
|
||||
from ..squeezebert.configuration_squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig
|
||||
from ..t5.configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
|
||||
from ..tapas.configuration_tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig
|
||||
from ..transfo_xl.configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
|
||||
from ..xlm.configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
|
||||
from ..xlm_prophetnet.configuration_xlm_prophetnet import (
|
||||
@ -95,6 +96,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
|
||||
XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
]
|
||||
for key, value, in pretrained_map.items()
|
||||
)
|
||||
@ -141,6 +143,7 @@ CONFIG_MAPPING = OrderedDict(
|
||||
("dpr", DPRConfig),
|
||||
("layoutlm", LayoutLMConfig),
|
||||
("rag", RagConfig),
|
||||
("tapas", TapasConfig),
|
||||
]
|
||||
)
|
||||
|
||||
@ -185,6 +188,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
|
||||
("prophetnet", "ProphetNet"),
|
||||
("mt5", "mT5"),
|
||||
("mpnet", "MPNet"),
|
||||
("tapas", "TAPAS"),
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -165,6 +165,12 @@ from ..squeezebert.modeling_squeezebert import (
|
||||
SqueezeBertModel,
|
||||
)
|
||||
from ..t5.modeling_t5 import T5ForConditionalGeneration, T5Model
|
||||
from ..tapas.modeling_tapas import (
|
||||
TapasForMaskedLM,
|
||||
TapasForQuestionAnswering,
|
||||
TapasForSequenceClassification,
|
||||
TapasModel,
|
||||
)
|
||||
from ..transfo_xl.modeling_transfo_xl import TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel
|
||||
from ..xlm.modeling_xlm import (
|
||||
XLMForMultipleChoice,
|
||||
@ -230,6 +236,7 @@ from .configuration_auto import (
|
||||
RobertaConfig,
|
||||
SqueezeBertConfig,
|
||||
T5Config,
|
||||
TapasConfig,
|
||||
TransfoXLConfig,
|
||||
XLMConfig,
|
||||
XLMProphetNetConfig,
|
||||
@ -277,6 +284,7 @@ MODEL_MAPPING = OrderedDict(
|
||||
(XLMProphetNetConfig, XLMProphetNetModel),
|
||||
(ProphetNetConfig, ProphetNetModel),
|
||||
(MPNetConfig, MPNetModel),
|
||||
(TapasConfig, TapasModel),
|
||||
]
|
||||
)
|
||||
|
||||
@ -308,6 +316,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
|
||||
(LxmertConfig, LxmertForPreTraining),
|
||||
(FunnelConfig, FunnelForPreTraining),
|
||||
(MPNetConfig, MPNetForMaskedLM),
|
||||
(TapasConfig, TapasForMaskedLM),
|
||||
]
|
||||
)
|
||||
|
||||
@ -340,6 +349,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
|
||||
(ReformerConfig, ReformerModelWithLMHead),
|
||||
(FunnelConfig, FunnelForMaskedLM),
|
||||
(MPNetConfig, MPNetForMaskedLM),
|
||||
(TapasConfig, TapasForMaskedLM),
|
||||
]
|
||||
)
|
||||
|
||||
@ -386,6 +396,7 @@ MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
|
||||
(ReformerConfig, ReformerForMaskedLM),
|
||||
(FunnelConfig, FunnelForMaskedLM),
|
||||
(MPNetConfig, MPNetForMaskedLM),
|
||||
(TapasConfig, TapasForMaskedLM),
|
||||
]
|
||||
)
|
||||
|
||||
@ -431,6 +442,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
|
||||
(CTRLConfig, CTRLForSequenceClassification),
|
||||
(TransfoXLConfig, TransfoXLForSequenceClassification),
|
||||
(MPNetConfig, MPNetForSequenceClassification),
|
||||
(TapasConfig, TapasForSequenceClassification),
|
||||
]
|
||||
)
|
||||
|
||||
@ -455,6 +467,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
|
||||
(FunnelConfig, FunnelForQuestionAnswering),
|
||||
(LxmertConfig, LxmertForQuestionAnswering),
|
||||
(MPNetConfig, MPNetForQuestionAnswering),
|
||||
(TapasConfig, TapasForQuestionAnswering),
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -47,6 +47,7 @@ from ..rag.tokenization_rag import RagTokenizer
|
||||
from ..retribert.tokenization_retribert import RetriBertTokenizer
|
||||
from ..roberta.tokenization_roberta import RobertaTokenizer
|
||||
from ..squeezebert.tokenization_squeezebert import SqueezeBertTokenizer
|
||||
from ..tapas.tokenization_tapas import TapasTokenizer
|
||||
from ..transfo_xl.tokenization_transfo_xl import TransfoXLTokenizer
|
||||
from ..xlm.tokenization_xlm import XLMTokenizer
|
||||
from .configuration_auto import (
|
||||
@ -84,6 +85,7 @@ from .configuration_auto import (
|
||||
RobertaConfig,
|
||||
SqueezeBertConfig,
|
||||
T5Config,
|
||||
TapasConfig,
|
||||
TransfoXLConfig,
|
||||
XLMConfig,
|
||||
XLMProphetNetConfig,
|
||||
@ -223,6 +225,7 @@ TOKENIZER_MAPPING = OrderedDict(
|
||||
(XLMProphetNetConfig, (XLMProphetNetTokenizer, None)),
|
||||
(ProphetNetConfig, (ProphetNetTokenizer, None)),
|
||||
(MPNetConfig, (MPNetTokenizer, MPNetTokenizerFast)),
|
||||
(TapasConfig, (TapasTokenizer, None)),
|
||||
]
|
||||
)
|
||||
|
||||
|
31
src/transformers/models/tapas/__init__.py
Normal file
@ -0,0 +1,31 @@
|
||||
# flake8: noqa
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...file_utils import is_torch_available
|
||||
from .configuration_tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig
|
||||
from .tokenization_tapas import TapasTokenizer
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_tapas import (
|
||||
TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TapasForMaskedLM,
|
||||
TapasForQuestionAnswering,
|
||||
TapasForSequenceClassification,
|
||||
TapasModel,
|
||||
)
|
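# Illustrative sketch of the public API exposed by this __init__, assuming torch,
# torch-scatter and pandas are installed; the checkpoint name comes from the
# pretrained maps added in this PR.
import pandas as pd

from transformers import TapasForQuestionAnswering, TapasTokenizer

tokenizer = TapasTokenizer.from_pretrained("google/tapas-base-finetuned-wtq")
model = TapasForQuestionAnswering.from_pretrained("google/tapas-base-finetuned-wtq")

# TapasTokenizer expects the table as a pandas DataFrame with string values.
table = pd.DataFrame({"Actors": ["Brad Pitt", "Leonardo Di Caprio"], "Age": ["56", "45"]})
inputs = tokenizer(table=table, queries=["How old is Brad Pitt?"], padding="max_length", return_tensors="pt")
outputs = model(**inputs)  # cell-selection logits plus aggregation logits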
219
src/transformers/models/tapas/configuration_tapas.py
Normal file
@ -0,0 +1,219 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 Google Research and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
TAPAS configuration. Based on the BERT configuration with added parameters.
|
||||
|
||||
Hyperparameters are taken from run_task_main.py and hparam_utils.py of the original implementation. URLS:
|
||||
|
||||
- https://github.com/google-research/tapas/blob/master/tapas/run_task_main.py
|
||||
- https://github.com/google-research/tapas/blob/master/tapas/utils/hparam_utils.py
|
||||
|
||||
"""
|
||||
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"google/tapas-base-finetuned-sqa": "https://huggingface.co/google/tapas-base-finetuned-sqa/resolve/main/config.json",
|
||||
"google/tapas-base-finetuned-wtq": "https://huggingface.co/google/tapas-base-finetuned-wtq/resolve/main/config.json",
|
||||
"google/tapas-base-finetuned-wikisql-supervised": "https://huggingface.co/google/tapas-base-finetuned-wikisql-supervised/resolve/main/config.json",
|
||||
"google/tapas-base-finetuned-tabfact": "https://huggingface.co/google/tapas-base-finetuned-tabfact/resolve/main/config.json",
|
||||
}
|
||||
|
||||
|
||||
class TapasConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a :class:`~transformers.TapasModel`. It is used to
|
||||
instantiate a TAPAS model according to the specified arguments, defining the model architecture. Instantiating a
|
||||
configuration with the defaults will yield a similar configuration to that of the TAPAS `tapas-base-finetuned-sqa`
|
||||
architecture. Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control
|
||||
the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
|
||||
|
||||
Hyperparameters additional to BERT are taken from run_task_main.py and hparam_utils.py of the original
|
||||
implementation. Original implementation available at https://github.com/google-research/tapas/tree/master.
|
||||
|
||||
Args:
|
||||
vocab_size (:obj:`int`, `optional`, defaults to 30522):
|
||||
Vocabulary size of the TAPAS model. Defines the number of different tokens that can be represented by the
|
||||
:obj:`inputs_ids` passed when calling :class:`~transformers.TapasModel`.
|
||||
hidden_size (:obj:`int`, `optional`, defaults to 768):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (:obj:`int`, `optional`, defaults to 12):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
intermediate_size (:obj:`int`, `optional`, defaults to 3072):
|
||||
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
|
||||
hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string,
|
||||
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
|
||||
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
|
||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
|
||||
The dropout ratio for the attention probabilities.
|
||||
max_position_embeddings (:obj:`int`, `optional`, defaults to 1024):
|
||||
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
||||
just in case (e.g., 512 or 1024 or 2048).
|
||||
type_vocab_sizes (:obj:`List[int]`, `optional`, defaults to :obj:`[3, 256, 256, 2, 256, 256, 10]`):
|
||||
The vocabulary sizes of the :obj:`token_type_ids` passed when calling :class:`~transformers.TapasModel`.
|
||||
initializer_range (:obj:`float`, `optional`, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
|
||||
The epsilon used by the layer normalization layers.
|
||||
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to use gradient checkpointing to save memory at the expense of a slower backward pass.
|
||||
positive_label_weight (:obj:`float`, `optional`, defaults to 10.0):
|
||||
Weight for positive labels.
|
||||
num_aggregation_labels (:obj:`int`, `optional`, defaults to 0):
|
||||
The number of aggregation operators to predict.
|
||||
aggregation_loss_weight (:obj:`float`, `optional`, defaults to 1.0):
|
||||
Importance weight for the aggregation loss.
|
||||
use_answer_as_supervision (:obj:`bool`, `optional`):
|
||||
Whether to use the answer as the only supervision for aggregation examples.
|
||||
answer_loss_importance (:obj:`float`, `optional`, defaults to 1.0):
|
||||
Importance weight for the regression loss.
|
||||
use_normalized_answer_loss (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to normalize the answer loss by the maximum of the predicted and expected value.
|
||||
huber_loss_delta: (:obj:`float`, `optional`):
|
||||
Delta parameter used to calculate the regression loss.
|
||||
temperature: (:obj:`float`, `optional`, defaults to 1.0):
|
||||
Value used to control (or change) the skewness of cell logit probabilities.
|
||||
aggregation_temperature: (:obj:`float`, `optional`, defaults to 1.0):
|
||||
Scales aggregation logits to control the skewness of probabilities.
|
||||
use_gumbel_for_cells: (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to apply Gumbel-Softmax to cell selection.
|
||||
use_gumbel_for_aggregation: (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to apply Gumbel-Softmax to aggregation selection.
|
||||
average_approximation_function: (:obj:`string`, `optional`, defaults to :obj:`"ratio"`):
|
||||
Method to calculate the expected average of cells in the weak supervision case. One of :obj:`"ratio"`,
|
||||
:obj:`"first_order"` or :obj:`"second_order"`.
|
||||
cell_selection_preference: (:obj:`float`, `optional`):
|
||||
Preference for cell selection in ambiguous cases. Only applicable in case of weak supervision for
|
||||
aggregation (WTQ, WikiSQL). If the total mass of the aggregation probabilities (excluding the "NONE"
|
||||
operator) is higher than this hyperparameter, then aggregation is predicted for an example.
|
||||
answer_loss_cutoff: (:obj:`float`, `optional`):
|
||||
Ignore examples with answer loss larger than cutoff.
|
||||
max_num_rows: (:obj:`int`, `optional`, defaults to 64):
|
||||
Maximum number of rows.
|
||||
max_num_columns: (:obj:`int`, `optional`, defaults to 32):
|
||||
Maximum number of columns.
|
||||
average_logits_per_cell: (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to average logits per cell.
|
||||
select_one_column: (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether to constrain the model to only select cells from a single column.
|
||||
allow_empty_column_selection: (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to allow not to select any column.
|
||||
init_cell_selection_weights_to_zero: (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to initialize cell selection weights to 0 so that the initial probabilities are 50%.
|
||||
reset_position_index_per_cell: (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether to restart position indexes at every cell (i.e. use relative position embeddings).
|
||||
disable_per_token_loss: (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to disable any (strong or weak) supervision on cells.
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import TapasModel, TapasConfig
|
||||
>>> # Initializing a default (SQA) Tapas configuration
|
||||
>>> configuration = TapasConfig()
|
||||
>>> # Initializing a model from the configuration
|
||||
>>> model = TapasModel(configuration)
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
"""
|
||||
|
||||
model_type = "tapas"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=30522,
|
||||
hidden_size=768,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=1024,
|
||||
type_vocab_sizes=[3, 256, 256, 2, 256, 256, 10],
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-12,
|
||||
pad_token_id=0,
|
||||
gradient_checkpointing=False,
|
||||
positive_label_weight=10.0,
|
||||
num_aggregation_labels=0,
|
||||
aggregation_loss_weight=1.0,
|
||||
use_answer_as_supervision=None,
|
||||
answer_loss_importance=1.0,
|
||||
use_normalized_answer_loss=False,
|
||||
huber_loss_delta=None,
|
||||
temperature=1.0,
|
||||
aggregation_temperature=1.0,
|
||||
use_gumbel_for_cells=False,
|
||||
use_gumbel_for_aggregation=False,
|
||||
average_approximation_function="ratio",
|
||||
cell_selection_preference=None,
|
||||
answer_loss_cutoff=None,
|
||||
max_num_rows=64,
|
||||
max_num_columns=32,
|
||||
average_logits_per_cell=False,
|
||||
select_one_column=True,
|
||||
allow_empty_column_selection=False,
|
||||
init_cell_selection_weights_to_zero=False,
|
||||
reset_position_index_per_cell=True,
|
||||
disable_per_token_loss=False,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
||||
|
||||
# BERT hyperparameters (with updated max_position_embeddings and type_vocab_sizes)
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_sizes = type_vocab_sizes
|
||||
self.initializer_range = initializer_range
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
|
||||
# Fine-tuning task hyperparameters
|
||||
self.positive_label_weight = positive_label_weight
|
||||
self.num_aggregation_labels = num_aggregation_labels
|
||||
self.aggregation_loss_weight = aggregation_loss_weight
|
||||
self.use_answer_as_supervision = use_answer_as_supervision
|
||||
self.answer_loss_importance = answer_loss_importance
|
||||
self.use_normalized_answer_loss = use_normalized_answer_loss
|
||||
self.huber_loss_delta = huber_loss_delta
|
||||
self.temperature = temperature
|
||||
self.aggregation_temperature = aggregation_temperature
|
||||
self.use_gumbel_for_cells = use_gumbel_for_cells
|
||||
self.use_gumbel_for_aggregation = use_gumbel_for_aggregation
|
||||
self.average_approximation_function = average_approximation_function
|
||||
self.cell_selection_preference = cell_selection_preference
|
||||
self.answer_loss_cutoff = answer_loss_cutoff
|
||||
self.max_num_rows = max_num_rows
|
||||
self.max_num_columns = max_num_columns
|
||||
self.average_logits_per_cell = average_logits_per_cell
|
||||
self.select_one_column = select_one_column
|
||||
self.allow_empty_column_selection = allow_empty_column_selection
|
||||
self.init_cell_selection_weights_to_zero = init_cell_selection_weights_to_zero
|
||||
self.reset_position_index_per_cell = reset_position_index_per_cell
|
||||
self.disable_per_token_loss = disable_per_token_loss
|
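# Illustrative sketch: the fine-tuning hyperparameters above can be combined for a
# weak-supervision (WTQ-style) setup; the values mirror the ones the conversion
# script further down in this diff sets for the WTQ task.
from transformers import TapasConfig  # the class defined above

wtq_like_config = TapasConfig(
    num_aggregation_labels=4,
    use_answer_as_supervision=True,
    answer_loss_cutoff=0.664694,
    cell_selection_preference=0.207951,
    huber_loss_delta=0.121194,
    init_cell_selection_weights_to_zero=True,
    select_one_column=True,
    allow_empty_column_selection=False,
    temperature=0.0352513,
)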
@ -0,0 +1,137 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert TAPAS checkpoint."""
|
||||
|
||||
|
||||
import argparse
|
||||
|
||||
from transformers.models.tapas.modeling_tapas import (
|
||||
TapasConfig,
|
||||
TapasForMaskedLM,
|
||||
TapasForQuestionAnswering,
|
||||
TapasForSequenceClassification,
|
||||
TapasModel,
|
||||
load_tf_weights_in_tapas,
|
||||
)
|
||||
from transformers.models.tapas.tokenization_tapas import TapasTokenizer
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(
|
||||
task, reset_position_index_per_cell, tf_checkpoint_path, tapas_config_file, pytorch_dump_path
|
||||
):
|
||||
# Initialise PyTorch model.
|
||||
# If you want to convert a checkpoint that uses absolute position embeddings, make sure to set reset_position_index_per_cell of
|
||||
# TapasConfig to False.
|
||||
|
||||
# initialize configuration from json file
|
||||
config = TapasConfig.from_json_file(tapas_config_file)
|
||||
# set absolute/relative position embeddings parameter
|
||||
config.reset_position_index_per_cell = reset_position_index_per_cell
|
||||
|
||||
# set remaining parameters of TapasConfig as well as the model based on the task
|
||||
if task == "SQA":
|
||||
model = TapasForQuestionAnswering(config=config)
|
||||
elif task == "WTQ":
|
||||
# run_task_main.py hparams
|
||||
config.num_aggregation_labels = 4
|
||||
config.use_answer_as_supervision = True
|
||||
# hparam_utils.py hparams
|
||||
config.answer_loss_cutoff = 0.664694
|
||||
config.cell_selection_preference = 0.207951
|
||||
config.huber_loss_delta = 0.121194
|
||||
config.init_cell_selection_weights_to_zero = True
|
||||
config.select_one_column = True
|
||||
config.allow_empty_column_selection = False
|
||||
config.temperature = 0.0352513
|
||||
|
||||
model = TapasForQuestionAnswering(config=config)
|
||||
elif task == "WIKISQL_SUPERVISED":
|
||||
# run_task_main.py hparams
|
||||
config.num_aggregation_labels = 4
|
||||
config.use_answer_as_supervision = False
|
||||
# hparam_utils.py hparams
|
||||
config.answer_loss_cutoff = 36.4519
|
||||
config.cell_selection_preference = 0.903421
|
||||
config.huber_loss_delta = 222.088
|
||||
config.init_cell_selection_weights_to_zero = True
|
||||
config.select_one_column = True
|
||||
config.allow_empty_column_selection = True
|
||||
config.temperature = 0.763141
|
||||
|
||||
model = TapasForQuestionAnswering(config=config)
|
||||
elif task == "TABFACT":
|
||||
model = TapasForSequenceClassification(config=config)
|
||||
elif task == "MLM":
|
||||
model = TapasForMaskedLM(config=config)
|
||||
elif task == "INTERMEDIATE_PRETRAINING":
|
||||
model = TapasModel(config=config)
|
||||
|
||||
print("Building PyTorch model from configuration: {}".format(str(config)))
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
load_tf_weights_in_tapas(model, config, tf_checkpoint_path)
|
||||
|
||||
# Save pytorch-model (weights and configuration)
|
||||
print("Save PyTorch model to {}".format(pytorch_dump_path))
|
||||
model.save_pretrained(pytorch_dump_path[:-17])
|
||||
|
||||
# Save tokenizer files
|
||||
dir_name = r"C:\Users\niels.rogge\Documents\Python projecten\tensorflow\Tensorflow models\SQA\Base\tapas_sqa_inter_masklm_base_reset"
|
||||
tokenizer = TapasTokenizer(vocab_file=dir_name + r"\vocab.txt", model_max_length=512)
|
||||
|
||||
print("Save tokenizer files to {}".format(pytorch_dump_path))
|
||||
tokenizer.save_pretrained(pytorch_dump_path[:-17])
|
||||
|
||||
print("Used relative position embeddings:", model.config.reset_position_index_per_cell)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--task", default="SQA", type=str, help="Model task for which to convert a checkpoint. Defaults to SQA."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reset_position_index_per_cell",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether to use relative position embeddings or not. Defaults to True.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tapas_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The config json file corresponding to the pre-trained TAPAS model. \n"
|
||||
"This specifies the model architecture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_tf_checkpoint_to_pytorch(
|
||||
args.task,
|
||||
args.reset_position_index_per_cell,
|
||||
args.tf_checkpoint_path,
|
||||
args.tapas_config_file,
|
||||
args.pytorch_dump_path,
|
||||
)
|
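# Illustrative sketch: the converter above can also be called directly from Python
# rather than through argparse. Every path below is a placeholder; pytorch_dump_path
# is expected to end in "pytorch_model.bin" because of the `[:-17]` slicing above.
convert_tf_checkpoint_to_pytorch(
    task="WTQ",
    reset_position_index_per_cell=True,
    tf_checkpoint_path="/path/to/tapas_wtq_checkpoint/model.ckpt",
    tapas_config_file="/path/to/tapas_wtq_checkpoint/bert_config.json",
    pytorch_dump_path="/path/to/dump/pytorch_model.bin",
)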
2286
src/transformers/models/tapas/modeling_tapas.py
Normal file
File diff suppressed because it is too large
2774
src/transformers/models/tapas/tokenization_tapas.py
Normal file
File diff suppressed because it is too large
@ -28,6 +28,8 @@ from .file_utils import (
|
||||
_datasets_available,
|
||||
_faiss_available,
|
||||
_flax_available,
|
||||
_pandas_available,
|
||||
_scatter_available,
|
||||
_sentencepiece_available,
|
||||
_tf_available,
|
||||
_tokenizers_available,
|
||||
@ -221,6 +223,27 @@ def require_tokenizers(test_case):
|
||||
return test_case
|
||||
|
||||
|
||||
def require_pandas(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires pandas. These tests are skipped when pandas isn't installed.
|
||||
"""
|
||||
if not _pandas_available:
|
||||
return unittest.skip("test requires pandas")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
def require_scatter(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires PyTorch Scatter. These tests are skipped when PyTorch Scatter isn't
|
||||
installed.
|
||||
"""
|
||||
if not _scatter_available:
|
||||
return unittest.skip("test requires PyTorch Scatter")(test_case)
|
||||
else:
|
||||
return test_case
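# Illustrative sketch: combining the decorators above lets a TAPAS test skip itself
# cleanly when the optional dependencies are missing (the test class is hypothetical).
import unittest

from transformers.testing_utils import require_pandas, require_scatter, require_torch


@require_torch
@require_scatter
@require_pandas
class TapasDependencySmokeTest(unittest.TestCase):
    def test_import(self):
        from transformers import TapasModel  # noqa: F401 - only runs with torch + scatter installed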
|
||||
|
||||
|
||||
def require_torch_multi_gpu(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires a multi-GPU setup (in PyTorch).
|
||||
|
@ -1867,6 +1867,45 @@ def load_tf_weights_in_t5(*args, **kwargs):
|
||||
requires_pytorch(load_tf_weights_in_t5)
|
||||
|
||||
|
||||
TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class TapasForMaskedLM:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class TapasForQuestionAnswering:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class TapasForSequenceClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class TapasModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_pytorch(self)
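# Illustrative sketch of what the dummy objects above provide: in an environment
# without PyTorch, importing the TAPAS names still works, but using them raises an
# ImportError that points at the PyTorch installation instructions.
from transformers import TapasModel

try:
    TapasModel()
except ImportError as err:
    print(err)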
|
||||
|
||||
|
||||
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
1084
tests/test_modeling_tapas.py
Normal file
File diff suppressed because it is too large
@ -584,7 +584,7 @@ class TokenizerTesterMixin:
|
||||
|
||||
# We want to have sequence 0 and sequence 1 are tagged
|
||||
# respectively with 0 and 1 token_ids
|
||||
# (regardeless of weither the model use token type ids)
|
||||
# (regardless of whether the model use token type ids)
|
||||
# We use this assumption in the QA pipeline among other place
|
||||
output = tokenizer(seq_0, return_token_type_ids=True)
|
||||
self.assertIn(0, output["token_type_ids"])
|
||||
@ -600,7 +600,7 @@ class TokenizerTesterMixin:
|
||||
|
||||
# We want to have sequence 0 and sequence 1 are tagged
|
||||
# respectively with 0 and 1 token_ids
|
||||
# (regardeless of weither the model use token type ids)
|
||||
# (regardless of whether the model use token type ids)
|
||||
# We use this assumption in the QA pipeline among other place
|
||||
output = tokenizer(seq_0)
|
||||
self.assertIn(0, output.sequence_ids())
|
||||
|
1190
tests/test_tokenization_tapas.py
Normal file
File diff suppressed because one or more lines are too long