[WIP] Tapas v4 (tres) (#9117)

* First commit: adding all files from tapas_v3

* Fix multiple bugs including soft dependency and new structure of the library

* Improve testing by adding torch_device to inputs and adding dependency on scatter

* Use Python 3 inheritance rather than Python 2

* First draft model cards of base sized models

* Remove model cards as they are already on the hub

* Fix multiple bugs with integration tests

* All model integration tests pass

* Remove print statement

* Add test for convert_logits_to_predictions method of TapasTokenizer

* Incorporate suggestions by Google authors

* Fix remaining tests

* Change position embeddings sizes to 512 instead of 1024

* Comment out positional embedding sizes

* Update PRETRAINED_VOCAB_FILES_MAP and PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

* Added more model names

* Fix truncation when no max length is specified

* Disable torchscript test

* Make style & make quality

* Quality

* Address CI needs

* Test the Masked LM model

* Fix the masked LM model

* Truncate when overflowing

* More much needed docs improvements

* Fix some URLs

* Some more docs improvements

* Test PyTorch scatter

* Set to slow + minify

* Calm flake8 down

* First commit: adding all files from tapas_v3

* Fix multiple bugs including soft dependency and new structure of the library

* Improve testing by adding torch_device to inputs and adding dependency on scatter

* Use Python 3 inheritance rather than Python 2

* First draft model cards of base sized models

* Remove model cards as they are already on the hub

* Fix multiple bugs with integration tests

* All model integration tests pass

* Remove print statement

* Add test for convert_logits_to_predictions method of TapasTokenizer

* Incorporate suggestions by Google authors

* Fix remaining tests

* Change position embeddings sizes to 512 instead of 1024

* Comment out positional embedding sizes

* Update PRETRAINED_VOCAB_FILES_MAP and PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

* Added more model names

* Fix truncation when no max length is specified

* Disable torchscript test

* Make style & make quality

* Quality

* Address CI needs

* Test the Masked LM model

* Fix the masked LM model

* Truncate when overflowing

* More much needed docs improvements

* Fix some URLs

* Some more docs improvements

* Add add_pooling_layer argument to TapasModel

Fix comments by @sgugger and @patrickvonplaten

* Fix issue in docs + fix style and quality

* Clean up conversion script and add task parameter to TapasConfig

* Revert the task parameter of TapasConfig

Some minor fixes

* Improve conversion script and add test for absolute position embeddings

* Improve conversion script and add test for absolute position embeddings

* Fix bug with reset_position_index_per_cell arg of the conversion cli

* Add notebooks to the examples directory and fix style and quality

* Apply suggestions from code review

* Move from `nielsr/` to `google/` namespace

* Apply Sylvain's comments

Co-authored-by: sgugger <sylvain.gugger@gmail.com>

Co-authored-by: Rogge Niels <niels.rogge@howest.be>
Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: sgugger <sylvain.gugger@gmail.com>
This commit is contained in:
NielsRogge 2020-12-15 23:08:49 +01:00 committed by GitHub
parent ad895af98d
commit 1551e2dc6d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 8497 additions and 78 deletions

View File

@ -79,6 +79,7 @@ jobs:
- v0.4-{{ checksum "setup.py" }}
- run: pip install --upgrade pip
- run: pip install .[sklearn,tf-cpu,torch,testing,sentencepiece]
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cpu.html
- save_cache:
key: v0.4-{{ checksum "setup.py" }}
paths:
@ -105,6 +106,7 @@ jobs:
- v0.4-{{ checksum "setup.py" }}
- run: pip install --upgrade pip
- run: pip install .[sklearn,torch,testing,sentencepiece]
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cpu.html
- save_cache:
key: v0.4-torch-{{ checksum "setup.py" }}
paths:
@ -183,6 +185,7 @@ jobs:
- v0.4-{{ checksum "setup.py" }}
- run: pip install --upgrade pip
- run: pip install .[sklearn,torch,testing,sentencepiece]
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cpu.html
- save_cache:
key: v0.4-torch-{{ checksum "setup.py" }}
paths:

View File

@ -50,6 +50,7 @@ jobs:
pip install --upgrade pip
pip install .[torch,sklearn,testing,onnxruntime,sentencepiece]
pip install git+https://github.com/huggingface/datasets
pip install pandas torch-scatter -f https://pytorch-geometric.com/whl/torch-$(python -c "import torch; print(''.join(torch.__version__)")+$(python -c "import torch; print(''.join(torch.version.cuda.split('.')))").html
- name: Are GPUs recognized by our DL frameworks
run: |
@ -187,6 +188,7 @@ jobs:
pip install --upgrade pip
pip install .[torch,sklearn,testing,onnxruntime,sentencepiece]
pip install git+https://github.com/huggingface/datasets
pip install pandas torch-scatter -f https://pytorch-geometric.com/whl/torch-$(python -c "import torch; print(''.join(torch.__version__)")+$(python -c "import torch; print(''.join(torch.version.cuda.split('.')))").html
- name: Are GPUs recognized by our DL frameworks
run: |

View File

@ -222,6 +222,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
ultilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/master/examples/distillation) and a German version of DistilBERT.
1. **[SqueezeBert](https://huggingface.co/transformers/model_doc/squeezebert.html)** released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer.
1. **[T5](https://huggingface.co/transformers/model_doc/t5.html)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
1. **[TAPAS](https://huggingface.co/transformers/master/model_doc/tapas.html)** released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos.
1. **[Transformer-XL](https://huggingface.co/transformers/model_doc/transformerxl.html)** (from Google/CMU) released with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
1. **[XLM](https://huggingface.co/transformers/model_doc/xlm.html)** (from Facebook) released together with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau.
1. **[XLM-ProphetNet](https://huggingface.co/transformers/model_doc/xlmprophetnet.html)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.

View File

@ -176,19 +176,22 @@ and conversion utilities for the following models:
30. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a
Unified Text-to-Text Transformer <https://arxiv.org/abs/1910.10683>`__ by Colin Raffel and Noam Shazeer and Adam
Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
31. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
31. `TAPAS <https://huggingface.co/transformers/master/model_doc/tapas.html>`__ released with the paper `TAPAS: Weakly
Supervised Table Parsing via Pre-training <https://arxiv.org/abs/2004.02349>`__ by Jonathan Herzig, Paweł Krzysztof
Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos.
32. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
Attentive Language Models Beyond a Fixed-Length Context <https://arxiv.org/abs/1901.02860>`__ by Zihang Dai*,
Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
32. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
33. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
Pretraining <https://arxiv.org/abs/1901.07291>`__ by Guillaume Lample and Alexis Conneau.
33. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
34. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
Predicting Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan,
Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
34. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
35. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
Cross-lingual Representation Learning at Scale <https://arxiv.org/abs/1911.02116>`__ by Alexis Conneau*, Kartikay
Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke
Zettlemoyer and Veselin Stoyanov.
35. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
36. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
Pretraining for Language Understanding <https://arxiv.org/abs/1906.08237>`__ by Zhilin Yang*, Zihang Dai*, Yiming
Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
@ -269,6 +272,8 @@ TensorFlow and/or Flax.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| T5 | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| TAPAS | ✅ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| XLM | ✅ | ❌ | ✅ | ✅ | ❌ |
@ -382,6 +387,7 @@ TensorFlow and/or Flax.
model_doc/roberta
model_doc/squeezebert
model_doc/t5
model_doc/tapas
model_doc/transformerxl
model_doc/xlm
model_doc/xlmprophetnet

View File

@ -0,0 +1,427 @@
TAPAS
-----------------------------------------------------------------------------------------------------------------------
Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The TAPAS model was proposed in `TAPAS: Weakly Supervised Table Parsing via Pre-training
<https://www.aclweb.org/anthology/2020.acl-main.398>`__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller,
Francesco Piccinno and Julian Martin Eisenschlos. It's a BERT-based model specifically designed (and pre-trained) for
answering questions about tabular data. Compared to BERT, TAPAS uses relative position embeddings and has 7 token types
that encode tabular structure. TAPAS is pre-trained on the masked language modeling (MLM) objective on a large dataset
comprising millions of tables from English Wikipedia and corresponding texts. For question answering, TAPAS has 2 heads
on top: a cell selection head and an aggregation head, for (optionally) performing aggregations (such as counting or
summing) among selected cells. TAPAS has been fine-tuned on several datasets: `SQA
<https://www.microsoft.com/en-us/download/details.aspx?id=54253>`__ (Sequential Question Answering by Microsoft), `WTQ
<https://github.com/ppasupat/WikiTableQuestions>`__ (Wiki Table Questions by Stanford University) and `WikiSQL
<https://github.com/salesforce/WikiSQL>`__ (by Salesforce). It achieves state-of-the-art on both SQA and WTQ, while
having comparable performance to SOTA on WikiSQL, with a much simpler architecture.
The abstract from the paper is the following:
*Answering natural language questions over tables is usually seen as a semantic parsing task. To alleviate the
collection cost of full logical forms, one popular approach focuses on weak supervision consisting of denotations
instead of logical forms. However, training semantic parsers from weak supervision poses difficulties, and in addition,
the generated logical forms are only used as an intermediate step prior to retrieving the denotation. In this paper, we
present TAPAS, an approach to question answering over tables without generating logical forms. TAPAS trains from weak
supervision, and predicts the denotation by selecting table cells and optionally applying a corresponding aggregation
operator to such selection. TAPAS extends BERT's architecture to encode tables as input, initializes from an effective
joint pre-training of text segments and tables crawled from Wikipedia, and is trained end-to-end. We experiment with
three different semantic parsing datasets, and find that TAPAS outperforms or rivals semantic parsing models by
improving state-of-the-art accuracy on SQA from 55.1 to 67.2 and performing on par with the state-of-the-art on WIKISQL
and WIKITQ, but with a simpler model architecture. We additionally find that transfer learning, which is trivial in our
setting, from WIKISQL to WIKITQ, yields 48.7 accuracy, 4.2 points above the state-of-the-art.*
In addition, the authors have further pre-trained TAPAS to recognize **table entailment**, by creating a balanced
dataset of millions of automatically created training examples which are learned in an intermediate step prior to
fine-tuning. The authors of TAPAS call this further pre-training intermediate pre-training (since TAPAS is first
pre-trained on MLM, and then on another dataset). They found that intermediate pre-training further improves
performance on SQA, achieving a new state-of-the-art as well as state-of-the-art on `TabFact
<https://github.com/wenhuchen/Table-Fact-Checking>`__, a large-scale dataset with 16k Wikipedia tables for table
entailment (a binary classification task). For more details, see their follow-up paper: `Understanding tables with
intermediate pre-training <https://www.aclweb.org/anthology/2020.findings-emnlp.27/>`__ by Julian Martin Eisenschlos,
Syrine Krichene and Thomas Müller.
The original code can be found `here <https://github.com/google-research/tapas>`__.
Tips:
- TAPAS is a model that uses relative position embeddings by default (restarting the position embeddings at every cell
of the table). Note that this is something that was added after the publication of the original TAPAS paper.
According to the authors, this usually results in a slightly better performance, and allows you to encode longer
sequences without running out of embeddings. This is reflected in the ``reset_position_index_per_cell`` parameter of
:class:`~transformers.TapasConfig`, which is set to ``True`` by default. The default versions of the models available
in the `model hub <https://huggingface.co/models?search=tapas>`_ all use relative position embeddings. You can still
use the ones with absolute position embeddings by passing in an additional argument ``revision="no_reset"`` when
calling the ``.from_pretrained()`` method. Note that it's usually advised to pad the inputs on the right rather than
the left.
- TAPAS is based on BERT, so ``TAPAS-base`` for example corresponds to a ``BERT-base`` architecture. Of course,
TAPAS-large will result in the best performance (the results reported in the paper are from TAPAS-large). Results of
the various sized models are shown on the `original Github repository <https://github.com/google-research/tapas>`_.
- TAPAS has checkpoints fine-tuned on SQA, which are capable of answering questions related to a table in a
conversational set-up. This means that you can ask follow-up questions such as "what is his age?" related to the
previous question. Note that the forward pass of TAPAS is a bit different in case of a conversational set-up: in that
case, you have to feed every table-question pair one by one to the model, such that the `prev_labels` token type ids
can be overwritten by the predicted `labels` of the model to the previous question. See "Usage" section for more
info.
- TAPAS is similar to BERT and therefore relies on the masked language modeling (MLM) objective. It is therefore
efficient at predicting masked tokens and at NLU in general, but is not optimal for text generation. Models trained
with a causal language modeling (CLM) objective are better in that regard.
Usage: fine-tuning
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Here we explain how you can fine-tune :class:`~transformers.TapasForQuestionAnswering` on your own dataset.
**STEP 1: Choose one of the 3 ways in which you can use TAPAS - or experiment**
Basically, there are 3 different ways in which one can fine-tune :class:`~transformers.TapasForQuestionAnswering`,
corresponding to the different datasets on which Tapas was fine-tuned:
1. SQA: if you're interested in asking follow-up questions related to a table, in a conversational set-up. For example
if you first ask "what's the name of the first actor?" then you can ask a follow-up question such as "how old is
he?". Here, questions do not involve any aggregation (all questions are cell selection questions).
2. WTQ: if you're not interested in asking questions in a conversational set-up, but rather just asking questions
related to a table, which might involve aggregation, such as counting a number of rows, summing up cell values or
averaging cell values. You can then for example ask "what's the total number of goals Cristiano Ronaldo made in his
career?". This case is also called **weak supervision**, since the model itself must learn the appropriate
aggregation operator (SUM/COUNT/AVERAGE/NONE) given only the answer to the question as supervision.
3. WikiSQL-supervised: this dataset is based on WikiSQL with the model being given the ground truth aggregation
operator during training. This is also called **strong supervision**. Here, learning the appropriate aggregation
operator is much easier.
To summarize:
+------------------------------------+----------------------+-------------------------------------------------------------------------------------------------------------------+
| **Task** | **Example dataset** | **Description** |
+------------------------------------+----------------------+-------------------------------------------------------------------------------------------------------------------+
| Conversational | SQA | Conversational, only cell selection questions |
+------------------------------------+----------------------+-------------------------------------------------------------------------------------------------------------------+
| Weak supervision for aggregation | WTQ | Questions might involve aggregation, and the model must learn this given only the answer as supervision |
+------------------------------------+----------------------+-------------------------------------------------------------------------------------------------------------------+
| Strong supervision for aggregation | WikiSQL-supervised | Questions might involve aggregation, and the model must learn this given the gold aggregation operator |
+------------------------------------+----------------------+-------------------------------------------------------------------------------------------------------------------+
Initializing a model with a pre-trained base and randomly initialized classification heads from the model hub can be
done as follows (be sure to have installed the `torch-scatter dependency <https://github.com/rusty1s/pytorch_scatter>`_
for your environment):
.. code-block::
>>> from transformers import TapasConfig, TapasForQuestionAnswering
>>> # for example, the base sized model with default SQA configuration
>>> model = TapasForQuestionAnswering.from_pretrained('google/tapas-base')
>>> # or, the base sized model with WTQ configuration
>>> config = TapasConfig.from_pretrained('google/tapas-base-finetuned-wtq')
>>> model = TapasForQuestionAnswering.from_pretrained('google/tapas-base', config=config)
>>> # or, the base sized model with WikiSQL configuration
>>> config = TapasConfig('google-base-finetuned-wikisql-supervised')
>>> model = TapasForQuestionAnswering.from_pretrained('google/tapas-base', config=config)
Of course, you don't necessarily have to follow one of these three ways in which TAPAS was fine-tuned. You can also
experiment by defining any hyperparameters you want when initializing :class:`~transformers.TapasConfig`, and then
create a :class:`~transformers.TapasForQuestionAnswering` based on that configuration. For example, if you have a
dataset that has both conversational questions and questions that might involve aggregation, then you can do it this
way. Here's an example:
.. code-block::
>>> from transformers import TapasConfig, TapasForQuestionAnswering
>>> # you can initialize the classification heads any way you want (see docs of TapasConfig)
>>> config = TapasConfig(num_aggregation_labels=3, average_logits_per_cell=True, select_one_column=False)
>>> # initializing the pre-trained base sized model with our custom classification heads
>>> model = TapasForQuestionAnswering.from_pretrained('google/tapas-base', config=config)
What you can also do is start from an already fine-tuned checkpoint. A note here is that the already fine-tuned
checkpoint on WTQ has some issues due to the L2-loss which is somewhat brittle. See `here
<https://github.com/google-research/tapas/issues/91#issuecomment-735719340>`__ for more info.
For a list of all pre-trained and fine-tuned TAPAS checkpoints available in the HuggingFace model hub, see `here
<https://huggingface.co/models?search=tapas>`__.
**STEP 2: Prepare your data in the SQA format**
Second, no matter what you picked above, you should prepare your dataset in the `SQA format
<https://www.microsoft.com/en-us/download/details.aspx?id=54253>`__. This format is a TSV/CSV file with the following
columns:
- ``id``: optional, id of the table-question pair, for bookkeeping purposes.
- ``annotator``: optional, id of the person who annotated the table-question pair, for bookkeeping purposes.
- ``position``: integer indicating if the question is the first, second, third,... related to the table. Only required
in case of conversational setup (SQA). You don't need this column in case you're going for WTQ/WikiSQL-supervised.
- ``question``: string
- ``table_file``: string, name of a csv file containing the tabular data
- ``answer_coordinates``: list of one or more tuples (each tuple being a cell coordinate, i.e. row, column pair that is
part of the answer)
- ``answer_text``: list of one or more strings (each string being a cell value that is part of the answer)
- ``aggregation_label``: index of the aggregation operator. Only required in case of strong supervision for aggregation
(the WikiSQL-supervised case)
- ``float_answer``: the float answer to the question, if there is one (np.nan if there isn't). Only required in case of
weak supervision for aggregation (such as WTQ and WikiSQL)
The tables themselves should be present in a folder, each table being a separate csv file. Note that the authors of the
TAPAS algorithm used conversion scripts with some automated logic to convert the other datasets (WTQ, WikiSQL) into the
SQA format. The author explains this `here
<https://github.com/google-research/tapas/issues/50#issuecomment-705465960>`__. Interestingly, these conversion scripts
are not perfect (the ``answer_coordinates`` and ``float_answer`` fields are populated based on the ``answer_text``),
meaning that WTQ and WikiSQL results could actually be improved.
**STEP 3: Convert your data into PyTorch tensors using TapasTokenizer**
Third, given that you've prepared your data in this TSV/CSV format (and corresponding CSV files containing the tabular
data), you can then use :class:`~transformers.TapasTokenizer` to convert table-question pairs into :obj:`input_ids`,
:obj:`attention_mask`, :obj:`token_type_ids` and so on. Again, based on which of the three cases you picked above,
:class:`~transformers.TapasForQuestionAnswering` requires different inputs to be fine-tuned:
+------------------------------------+----------------------------------------------------------------------------------------------+
| **Task** | **Required inputs** |
+------------------------------------+----------------------------------------------------------------------------------------------+
| Conversational | ``input_ids``, ``attention_mask``, ``token_type_ids``, ``labels`` |
+------------------------------------+----------------------------------------------------------------------------------------------+
| Weak supervision for aggregation | ``input_ids``, ``attention_mask``, ``token_type_ids``, ``labels``, ``numeric_values``, |
| | ``numeric_values_scale``, ``float_answer`` |
+------------------------------------+----------------------------------------------------------------------------------------------+
| Strong supervision for aggregation | ``input ids``, ``attention mask``, ``token type ids``, ``labels``, ``aggregation_labels`` |
+------------------------------------+----------------------------------------------------------------------------------------------+
:class:`~transformers.TapasTokenizer` creates the ``labels``, ``numeric_values`` and ``numeric_values_scale`` based on
the ``answer_coordinates`` and ``answer_text`` columns of the TSV file. The ``float_answer`` and ``aggregation_labels``
are already in the TSV file of step 2. Here's an example:
.. code-block::
>>> from transformers import TapasTokenizer
>>> import pandas as pd
>>> model_name = 'google/tapas-base'
>>> tokenizer = TapasTokenizer.from_pretrained(model_name)
>>> data = {'Actors': ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"], 'Number of movies': ["87", "53", "69"]}
>>> queries = ["What is the name of the first actor?", "How many movies has George Clooney played in?", "What is the total number of movies?"]
>>> answer_coordinates = [[(0, 0)], [(2, 1)], [(0, 1), (1, 1), (2, 1)]]
>>> answer_text = [["Brad Pitt"], ["69"], ["209"]]
>>> table = pd.DataFrame.from_dict(data)
>>> inputs = tokenizer(table=table, queries=queries, answer_coordinates=answer_coordinates, answer_text=answer_text, padding='max_length', return_tensors='pt')
>>> inputs
{'input_ids': tensor([[ ... ]]), 'attention_mask': tensor([[...]]), 'token_type_ids': tensor([[[...]]]),
'numeric_values': tensor([[ ... ]]), 'numeric_values_scale: tensor([[ ... ]]), labels: tensor([[ ... ]])}
Note that :class:`~transformers.TapasTokenizer` expects the data of the table to be **text-only**. You can use
``.astype(str)`` on a dataframe to turn it into text-only data. Of course, this only shows how to encode a single
training example. It is advised to create a PyTorch dataset and a corresponding dataloader:
.. code-block::
>>> import torch
>>> import pandas as pd
>>> tsv_path = "your_path_to_the_tsv_file"
>>> table_csv_path = "your_path_to_a_directory_containing_all_csv_files"
>>> class TableDataset(torch.utils.data.Dataset):
... def __init__(self, data, tokenizer):
... self.data = data
... self.tokenizer = tokenizer
...
... def __getitem__(self, idx):
... item = data.iloc[idx]
... table = pd.read_csv(table_csv_path + item.table_file).astype(str) # be sure to make your table data text only
... encoding = self.tokenizer(table=table,
... queries=item.question,
... answer_coordinates=item.answer_coordinates,
... answer_text=item.answer_text,
... truncation=True,
... padding="max_length",
... return_tensors="pt"
... )
... # remove the batch dimension which the tokenizer adds by default
... encoding = {key: val.squeeze(0) for key, val in encoding.items()}
... # add the float_answer which is also required (weak supervision for aggregation case)
... encoding["float_answer"] = torch.tensor(item.float_answer)
... return encoding
...
... def __len__(self):
... return len(self.data)
>>> data = pd.read_csv(tsv_path, sep='\t')
>>> train_dataset = TableDataset(data, tokenizer)
>>> train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)
Note that here, we encode each table-question pair independently. This is fine as long as your dataset is **not
conversational**. In case your dataset involves conversational questions (such as in SQA), then you should first group
together the ``queries``, ``answer_coordinates`` and ``answer_text`` per table (in the order of their ``position``
index) and batch encode each table with its questions. This will make sure that the ``prev_labels`` token types (see
docs of :class:`~transformers.TapasTokenizer`) are set correctly. See `this notebook
<https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Fine_tuning_TapasForQuestionAnswering_on_SQA.ipynb>`__
for more info.
**STEP 4: Train (fine-tune) TapasForQuestionAnswering**
You can then fine-tune :class:`~transformers.TapasForQuestionAnswering` using native PyTorch as follows (shown here for
the weak supervision for aggregation case):
.. code-block::
>>> from transformers import TapasConfig, TapasForQuestionAnswering, AdamW
>>> # this is the default WTQ configuration
>>> config = TapasConfig(
... num_aggregation_labels = 4,
... use_answer_as_supervision = True,
... answer_loss_cutoff = 0.664694,
... cell_selection_preference = 0.207951,
... huber_loss_delta = 0.121194,
... init_cell_selection_weights_to_zero = True,
... select_one_column = True,
... allow_empty_column_selection = False,
... temperature = 0.0352513,
... )
>>> model = TapasForQuestionAnswering.from_pretrained("google/tapas-base", config=config)
>>> optimizer = AdamW(model.parameters(), lr=5e-5)
>>> for epoch in range(2): # loop over the dataset multiple times
... for idx, batch in enumerate(train_dataloader):
... # get the inputs;
... input_ids = batch["input_ids"]
... attention_mask = batch["attention_mask"]
... token_type_ids = batch["token_type_ids"]
... labels = batch["labels"]
... numeric_values = batch["numeric_values"]
... numeric_values_scale = batch["numeric_values_scale"]
... float_answer = batch["float_answer"]
... # zero the parameter gradients
... optimizer.zero_grad()
... # forward + backward + optimize
... outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids,
... labels=labels, numeric_values=numeric_values, numeric_values_scale=numeric_values_scale,
... float_answer=float_answer)
... loss = outputs.loss
... loss.backward()
... optimizer.step()
Usage: inference
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Here we explain how you can use :class:`~transformers.TapasForQuestionAnswering` for inference (i.e. making predictions
on new data). For inference, only ``input_ids``, ``attention_mask`` and ``token_type_ids`` (which you can obtain using
:class:`~transformers.TapasTokenizer`) have to be provided to the model to obtain the logits. Next, you can use the
handy ``convert_logits_to_predictions`` method of :class:`~transformers.TapasTokenizer` to convert these into predicted
coordinates and optional aggregation indices.
However, note that inference is **different** depending on whether or not the setup is conversational. In a
non-conversational set-up, inference can be done in parallel on all table-question pairs of a batch. Here's an example
of that:
.. code-block::
>>> from transformers import TapasTokenizer, TapasForQuestionAnswering
>>> import pandas as pd
>>> model_name = 'google/tapas-base-finetuned-wtq'
>>> model = TapasForQuestionAnswering.from_pretrained(model_name)
>>> tokenizer = TapasTokenizer.from_pretrained(model_name)
>>> data = {'Actors': ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"], 'Number of movies': ["87", "53", "69"]}
>>> queries = ["What is the name of the first actor?", "How many movies has George Clooney played in?", "What is the total number of movies?"]
>>> table = pd.DataFrame.from_dict(data)
>>> inputs = tokenizer(table=table, queries=queries, padding='max_length', return_tensors="pt")
>>> outputs = model(**inputs)
>>> predicted_answer_coordinates, predicted_aggregation_indices = tokenizer.convert_logits_to_predictions(
... inputs,
... outputs.logits.detach(),
... outputs.logits_aggregation.detach()
...)
>>> # let's print out the results:
>>> id2aggregation = {0: "NONE", 1: "SUM", 2: "AVERAGE", 3:"COUNT"}
>>> aggregation_predictions_string = [id2aggregation[x] for x in predicted_aggregation_indices]
>>> answers = []
>>> for coordinates in predicted_answer_coordinates:
... if len(coordinates) == 1:
... # only a single cell:
... answers.append(table.iat[coordinates[0]])
... else:
... # multiple cells
... cell_values = []
... for coordinate in coordinates:
... cell_values.append(table.iat[coordinate])
... answers.append(", ".join(cell_values))
>>> display(table)
>>> print("")
>>> for query, answer, predicted_agg in zip(queries, answers, aggregation_predictions_string):
... print(query)
... if predicted_agg == "NONE":
... print("Predicted answer: " + answer)
... else:
... print("Predicted answer: " + predicted_agg + " > " + answer)
What is the name of the first actor?
Predicted answer: Brad Pitt
How many movies has George Clooney played in?
Predicted answer: COUNT > 69
What is the total number of movies?
Predicted answer: SUM > 87, 53, 69
In case of a conversational set-up, then each table-question pair must be provided **sequentially** to the model, such
that the ``prev_labels`` token types can be overwritten by the predicted ``labels`` of the previous table-question
pair. Again, more info can be found in `this notebook
<https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Fine_tuning_TapasForQuestionAnswering_on_SQA.ipynb>`__.
Tapas specific outputs
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.models.tapas.modeling_tapas.TableQuestionAnsweringOutput
:members:
TapasConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TapasConfig
:members:
TapasTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TapasTokenizer
:members: __call__, convert_logits_to_predictions, save_vocabulary
TapasModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TapasModel
:members: forward
TapasForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TapasForMaskedLM
:members: forward
TapasForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TapasForSequenceClassification
:members: forward
TapasForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TapasForQuestionAnswering
:members: forward

View File

@ -0,0 +1,123 @@
---
language: en
tags:
- tapas
- masked-lm
license: apache-2.0
---
# TAPAS base model
This model corresponds to the `tapas_inter_masklm_base_reset` checkpoint of the [original Github repository](https://github.com/google-research/tapas).
Disclaimer: The team releasing TAPAS did not write a model card for this model so this model card has been written by
the Hugging Face team and contributors.
## Model description
TAPAS is a BERT-like transformers model pretrained on a large corpus of English data from Wikipedia in a self-supervised fashion.
This means it was pretrained on the raw tables and associated texts only, with no humans labelling them in any way (which is why it
can use lots of publicly available data) with an automatic process to generate inputs and labels from those texts. More precisely, it
was pretrained with two objectives:
- Masked language modeling (MLM): taking a (flattened) table and associated context, the model randomly masks 15% of the words in
the input, then runs the entire (partially masked) sequence through the model. The model then has to predict the masked words.
This is different from traditional recurrent neural networks (RNNs) that usually see the words one after the other,
or from autoregressive models like GPT which internally mask the future tokens. It allows the model to learn a bidirectional
representation of a table and associated text.
- Intermediate pre-training: to encourage numerical reasoning on tables, the authors additionally pre-trained the model by creating
a balanced dataset of millions of syntactically created training examples. Here, the model must predict (classify) whether a sentence
is supported or refuted by the contents of a table. The training examples are created based on synthetic as well as counterfactual statements.
This way, the model learns an inner representation of the English language used in tables and associated texts, which can then be used
to extract features useful for downstream tasks such as answering questions about a table, or determining whether a sentence is entailed
or refuted by the contents of a table. Fine-tuning is done by adding classification heads on top of the pre-trained model, and then jointly
train the randomly initialized classification heads with the base model on a labelled dataset.
## Intended uses & limitations
You can use the raw model for masked language modeling, but it's mostly intended to be fine-tuned on a downstream task.
See the [model hub](https://huggingface.co/models?filter=tapas) to look for fine-tuned versions on a task that interests you.
Here is how to use this model to get the features of a given table-text pair in PyTorch:
```python
from transformers import TapasTokenizer, TapasModel
import pandas as pd
tokenizer = TapasTokenizer.from_pretrained('tapase-base')
model = TapasModel.from_pretrained("tapas-base")
data = {'Actors': ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
'Age': ["56", "45", "59"],
'Number of movies': ["87", "53", "69"]
}
table = pd.DataFrame.from_dict(data)
queries = ["How many movies has George Clooney played in?"]
text = "Replace me by any text you'd like."
encoded_input = tokenizer(table=table, queries=queries, return_tensors='pt')
output = model(**encoded_input)
```
## Training data
For masked language modeling (MLM), a collection of 6.2 million tables was extracted from English Wikipedia: 3.3M of class [Infobox](https://en.wikipedia.org/wiki/Help:Infobox)
and 2.9M of class WikiTable. The author only considered tables with at most 500 cells. As a proxy for questions that appear in the
downstream tasks, the authros extracted the table caption, article title, article description, segment title and text of the segment
the table occurs in as relevant text snippets. In this way, 21.3M snippets were created. For more info, see the original [TAPAS paper](https://www.aclweb.org/anthology/2020.acl-main.398.pdf).
For intermediate pre-training, 2 tasks are introduced: one based on synthetic and the other from counterfactual statements. The first one
generates a sentence by sampling from a set of logical expressions that filter, combine and compare the information on the table, which is
required in table entailment (e.g., knowing that Gerald Ford is taller than the average president requires summing
all presidents and dividing by the number of presidents). The second one corrupts sentences about tables appearing on Wikipedia by swapping
entities for plausible alternatives. Examples of the two tasks can be seen in Figure 1. The procedure is described in detail in section 3 of
the [TAPAS follow-up paper](https://www.aclweb.org/anthology/2020.findings-emnlp.27.pdf).
## Training procedure
### Preprocessing
The texts are lowercased and tokenized using WordPiece and a vocabulary size of 30,000. The inputs of the model are
then of the form:
```
[CLS] Context [SEP] Flattened table [SEP]
```
The details of the masking procedure for each sequence are the following:
- 15% of the tokens are masked.
- In 80% of the cases, the masked tokens are replaced by `[MASK]`.
- In 10% of the cases, the masked tokens are replaced by a random token (different) from the one they replace.
- In the 10% remaining cases, the masked tokens are left as is.
The details of the creation of the synthetic and counterfactual examples can be found in the [follow-up paper](https://arxiv.org/abs/2010.00571).
### Pretraining
The model was trained on 32 Cloud TPU v3 cores for one million steps with maximum sequence length 512 and batch size of 512.
In this setup, pre-training takes around 3 days. The optimizer used is Adam with a learning rate of 5e-5, and a warmup ratio
of 0.10.
### BibTeX entry and citation info
```bibtex
@misc{herzig2020tapas,
title={TAPAS: Weakly Supervised Table Parsing via Pre-training},
author={Jonathan Herzig and Paweł Krzysztof Nowak and Thomas Müller and Francesco Piccinno and Julian Martin Eisenschlos},
year={2020},
eprint={2004.02349},
archivePrefix={arXiv},
primaryClass={cs.IR}
}
```
```bibtex
@misc{eisenschlos2020understanding,
title={Understanding tables with intermediate pre-training},
author={Julian Martin Eisenschlos and Syrine Krichene and Thomas Müller},
year={2020},
eprint={2010.00571},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```

View File

@ -1,71 +1,73 @@
<!---
Copyright 2020 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# 🤗 Transformers Notebooks
You can find here a list of the official notebooks provided by Hugging Face.
Also, we would like to list here interesting content created by the community.
If you wrote some notebook(s) leveraging 🤗 Transformers and would like be listed here, please open a
Pull Request so it can be included under the Community notebooks.
## Hugging Face's notebooks 🤗
| Notebook | Description | |
|:----------|:-------------|------:|
| [Getting Started Tokenizers](https://github.com/huggingface/transformers/blob/master/notebooks/01-training-tokenizers.ipynb) | How to train and use your very own tokenizer |[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/transformers/blob/master/notebooks/01-training-tokenizers.ipynb) |
| [Getting Started Transformers](https://github.com/huggingface/transformers/blob/master/notebooks/02-transformers.ipynb) | How to easily start using transformers | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/transformers/blob/master/notebooks/02-transformers.ipynb) |
| [How to use Pipelines](https://github.com/huggingface/transformers/blob/master/notebooks/03-pipelines.ipynb) | Simple and efficient way to use State-of-the-Art models on downstream tasks through transformers | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/transformers/blob/master/notebooks/03-pipelines.ipynb) |
| [How to train a language model](https://github.com/huggingface/blog/blob/master/notebooks/01_how_to_train.ipynb)| Highlight all the steps to effectively train Transformer model on custom data | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/blog/blob/master/notebooks/01_how_to_train.ipynb)|
| [How to generate text](https://github.com/huggingface/blog/blob/master/notebooks/02_how_to_generate.ipynb)| How to use different decoding methods for language generation with transformers | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/blog/blob/master/notebooks/02_how_to_generate.ipynb)|
| [How to export model to ONNX](https://github.com/huggingface/transformers/blob/master/notebooks/04-onnx-export.ipynb) | Highlight how to export and run inference workloads through ONNX |
| [How to use Benchmarks](https://github.com/huggingface/transformers/blob/master/notebooks/05-benchmark.ipynb) | How to benchmark models with transformers | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/transformers/blob/master/notebooks/05-benchmark.ipynb)|
| [Reformer](https://github.com/huggingface/blog/blob/master/notebooks/03_reformer.ipynb) | How Reformer pushes the limits of language modeling | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patrickvonplaten/blog/blob/master/notebooks/03_reformer.ipynb)|
## Community notebooks:
| Notebook | Description | Author | |
|:----------|:-------------|:-------------|------:|
| [Train T5 in Tensorflow 2 ](https://github.com/snapthat/TF-T5-text-to-text) | How to train T5 for any task using Tensorflow 2. This notebook demonstrates a Question & Answer task implemented in Tensorflow 2 using SQUAD | [Muhammad Harris](https://github.com/HarrisDePerceptron) |[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/snapthat/TF-T5-text-to-text/blob/master/snapthatT5/notebooks/TF-T5-Datasets%20Training.ipynb) |
| [Train T5 on TPU](https://github.com/patil-suraj/exploring-T5/blob/master/T5_on_TPU.ipynb) | How to train T5 on SQUAD with Transformers and Nlp | [Suraj Patil](https://github.com/patil-suraj) |[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/exploring-T5/blob/master/T5_on_TPU.ipynb#scrollTo=QLGiFCDqvuil) |
| [Fine-tune T5 for Classification and Multiple Choice](https://github.com/patil-suraj/exploring-T5/blob/master/t5_fine_tuning.ipynb) | How to fine-tune T5 for classification and multiple choice tasks using a text-to-text format with PyTorch Lightning | [Suraj Patil](https://github.com/patil-suraj) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/exploring-T5/blob/master/t5_fine_tuning.ipynb) |
| [Fine-tune DialoGPT on New Datasets and Languages](https://github.com/ncoop57/i-am-a-nerd/blob/master/_notebooks/2020-05-12-chatbot-part-1.ipynb) | How to fine-tune the DialoGPT model on a new dataset for open-dialog conversational chatbots | [Nathan Cooper](https://github.com/ncoop57) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ncoop57/i-am-a-nerd/blob/master/_notebooks/2020-05-12-chatbot-part-1.ipynb) |
| [Long Sequence Modeling with Reformer](https://github.com/patrickvonplaten/notebooks/blob/master/PyTorch_Reformer.ipynb) | How to train on sequences as long as 500,000 tokens with Reformer | [Patrick von Platen](https://github.com/patrickvonplaten) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/PyTorch_Reformer.ipynb) |
| [Fine-tune BART for Summarization](https://github.com/ohmeow/ohmeow_website/blob/master/_notebooks/2020-05-23-text-generation-with-blurr.ipynb) | How to fine-tune BART for summarization with fastai using blurr | [Wayde Gilliam](https://ohmeow.com/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ohmeow/ohmeow_website/blob/master/_notebooks/2020-05-23-text-generation-with-blurr.ipynb) |
| [Fine-tune a pre-trained Transformer on anyone's tweets](https://colab.research.google.com/github/borisdayma/huggingtweets/blob/master/huggingtweets-demo.ipynb) | How to generate tweets in the style of your favorite Twitter account by fine-tune a GPT-2 model | [Boris Dayma](https://github.com/borisdayma) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/borisdayma/huggingtweets/blob/master/huggingtweets-demo.ipynb) |
| [A Step by Step Guide to Tracking Hugging Face Model Performance](https://colab.research.google.com/drive/1NEiqNPhiouu2pPwDAVeFoN4-vTYMz9F8) | A quick tutorial for training NLP models with HuggingFace and & visualizing their performance with Weights & Biases | [Jack Morris](https://github.com/jxmorris12) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1NEiqNPhiouu2pPwDAVeFoN4-vTYMz9F8) |
| [Pretrain Longformer](https://github.com/allenai/longformer/blob/master/scripts/convert_model_to_long.ipynb) | How to build a "long" version of existing pretrained models | [Iz Beltagy](https://beltagy.net) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/allenai/longformer/blob/master/scripts/convert_model_to_long.ipynb) |
| [Fine-tune Longformer for QA](https://github.com/patil-suraj/Notebooks/blob/master/longformer_qa_training.ipynb) | How to fine-tune longformer model for QA task | [Suraj Patil](https://github.com/patil-suraj) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/longformer_qa_training.ipynb) |
| [Evaluate Model with 🤗nlp](https://github.com/patrickvonplaten/notebooks/blob/master/How_to_evaluate_Longformer_on_TriviaQA_using_NLP.ipynb) | How to evaluate longformer on TriviaQA with `nlp` | [Patrick von Platen](https://github.com/patrickvonplaten) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1m7eTGlPmLRgoPkkA7rkhQdZ9ydpmsdLE?usp=sharing) |
| [Fine-tune T5 for Sentiment Span Extraction](https://github.com/enzoampil/t5-intro/blob/master/t5_qa_training_pytorch_span_extraction.ipynb) | How to fine-tune T5 for sentiment span extraction using a text-to-text format with PyTorch Lightning | [Lorenzo Ampil](https://github.com/enzoampil) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/enzoampil/t5-intro/blob/master/t5_qa_training_pytorch_span_extraction.ipynb) |
| [Fine-tune DistilBert for Multiclass Classification](https://github.com/abhimishra91/transformers-tutorials/blob/master/transformers_multiclass_classification.ipynb) | How to fine-tune DistilBert for multiclass classification with PyTorch | [Abhishek Kumar Mishra](https://github.com/abhimishra91) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/abhimishra91/transformers-tutorials/blob/master/transformers_multiclass_classification.ipynb)|
|[Fine-tune BERT for Multi-label Classification](https://github.com/abhimishra91/transformers-tutorials/blob/master/transformers_multi_label_classification.ipynb)|How to fine-tune BERT for multi-label classification using PyTorch|[Abhishek Kumar Mishra](https://github.com/abhimishra91) |[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/abhimishra91/transformers-tutorials/blob/master/transformers_multi_label_classification.ipynb)|
|[Fine-tune T5 for Summarization](https://github.com/abhimishra91/transformers-tutorials/blob/master/transformers_summarization_wandb.ipynb)|How to fine-tune T5 for summarization in PyTorch and track experiments with WandB|[Abhishek Kumar Mishra](https://github.com/abhimishra91) |[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/abhimishra91/transformers-tutorials/blob/master/transformers_summarization_wandb.ipynb)|
|[Speed up Fine-Tuning in Transformers with Dynamic Padding / Bucketing](https://github.com/ELS-RD/transformers-notebook/blob/master/Divide_Hugging_Face_Transformers_training_time_by_2_or_more.ipynb)|How to speed up fine-tuning by a factor of 2 using dynamic padding / bucketing|[Michael Benesty](https://github.com/pommedeterresautee) |[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1CBfRU1zbfu7-ijiOqAAQUA-RJaxfcJoO?usp=sharing)|
|[Pretrain Reformer for Masked Language Modeling](https://github.com/patrickvonplaten/notebooks/blob/master/Reformer_For_Masked_LM.ipynb)| How to train a Reformer model with bi-directional self-attention layers | [Patrick von Platen](https://github.com/patrickvonplaten) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1tzzh0i8PgDQGV3SMFUGxM7_gGae3K-uW?usp=sharing)|
|[Expand and Fine Tune Sci-BERT](https://github.com/lordtt13/word-embeddings/blob/master/COVID-19%20Research%20Data/COVID-SciBERT.ipynb)| How to increase vocabulary of a pretrained SciBERT model from AllenAI on the CORD dataset and pipeline it. | [Tanmay Thakur](https://github.com/lordtt13) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1rqAR40goxbAfez1xvF3hBJphSCsvXmh8)|
|[Fine-tune Electra and interpret with Integrated Gradients](https://github.com/elsanns/xai-nlp-notebooks/blob/master/electra_fine_tune_interpret_captum_ig.ipynb) | How to fine-tune Electra for sentiment analysis and interpret predictions with Captum Integrated Gradients | [Eliza Szczechla](https://elsanns.github.io) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/elsanns/xai-nlp-notebooks/blob/master/electra_fine_tune_interpret_captum_ig.ipynb)|
|[fine-tune a non-English GPT-2 Model with Trainer class](https://github.com/philschmid/fine-tune-GPT-2/blob/master/Fine_tune_a_non_English_GPT_2_Model_with_Huggingface.ipynb) | How to fine-tune a non-English GPT-2 Model with Trainer class | [Philipp Schmid](https://www.philschmid.de) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/fine-tune-GPT-2/blob/master/Fine_tune_a_non_English_GPT_2_Model_with_Huggingface.ipynb)|
|[Fine-tune a DistilBERT Model for Multi Label Classification task](https://github.com/DhavalTaunk08/Transformers_scripts/blob/master/Transformers_multilabel_distilbert.ipynb) | How to fine-tune a DistilBERT Model for Multi Label Classification task | [Dhaval Taunk](https://github.com/DhavalTaunk08) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DhavalTaunk08/Transformers_scripts/blob/master/Transformers_multilabel_distilbert.ipynb)|
|[Fine-tune ALBERT for sentence-pair classification](https://github.com/NadirEM/nlp-notebooks/blob/master/Fine_tune_ALBERT_sentence_pair_classification.ipynb) | How to fine-tune an ALBERT model or another BERT-based model for the sentence-pair classification task | [Nadir El Manouzi](https://github.com/NadirEM) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NadirEM/nlp-notebooks/blob/master/Fine_tune_ALBERT_sentence_pair_classification.ipynb)|
|[Fine-tune Roberta for sentiment analysis](https://github.com/DhavalTaunk08/NLP_scripts/blob/master/sentiment_analysis_using_roberta.ipynb) | How to fine-tune an Roberta model for sentiment analysis | [Dhaval Taunk](https://github.com/DhavalTaunk08) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DhavalTaunk08/NLP_scripts/blob/master/sentiment_analysis_using_roberta.ipynb)|
|[Evaluating Question Generation Models](https://github.com/flexudy-pipe/qugeev) | How accurate are the answers to questions generated by your seq2seq transformer model? | [Pascal Zoleko](https://github.com/zolekode) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1bpsSqCQU-iw_5nNoRm_crPq6FRuJthq_?usp=sharing)|
|[Classify text with DistilBERT and Tensorflow](https://github.com/peterbayerle/huggingface_notebook/blob/main/distilbert_tf.ipynb) | How to fine-tune DistilBERT for text classification in TensorFlow | [Peter Bayerle](https://github.com/peterbayerle) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/peterbayerle/huggingface_notebook/blob/main/distilbert_tf.ipynb)|
|[Leverage BERT for Encoder-Decoder Summarization on CNN/Dailymail](https://github.com/patrickvonplaten/notebooks/blob/master/BERT2BERT_for_CNN_Dailymail.ipynb) | How to warm-start a *EncoderDecoderModel* with a *bert-base-uncased* checkpoint for summarization on CNN/Dailymail | [Patrick von Platen](https://github.com/patrickvonplaten) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/BERT2BERT_for_CNN_Dailymail.ipynb)|
|[Leverage RoBERTa for Encoder-Decoder Summarization on BBC XSum](https://github.com/patrickvonplaten/notebooks/blob/master/RoBERTaShared_for_BBC_XSum.ipynb) | How to warm-start a shared *EncoderDecoderModel* with a *roberta-base* checkpoint for summarization on BBC/XSum | [Patrick von Platen](https://github.com/patrickvonplaten) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/RoBERTaShared_for_BBC_XSum.ipynb)|
<!---
Copyright 2020 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# 🤗 Transformers Notebooks
You can find here a list of the official notebooks provided by Hugging Face.
Also, we would like to list here interesting content created by the community.
If you wrote some notebook(s) leveraging 🤗 Transformers and would like be listed here, please open a
Pull Request so it can be included under the Community notebooks.
## Hugging Face's notebooks 🤗
| Notebook | Description | |
|:----------|:-------------|------:|
| [Getting Started Tokenizers](https://github.com/huggingface/transformers/blob/master/notebooks/01-training-tokenizers.ipynb) | How to train and use your very own tokenizer |[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/transformers/blob/master/notebooks/01-training-tokenizers.ipynb) |
| [Getting Started Transformers](https://github.com/huggingface/transformers/blob/master/notebooks/02-transformers.ipynb) | How to easily start using transformers | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/transformers/blob/master/notebooks/02-transformers.ipynb) |
| [How to use Pipelines](https://github.com/huggingface/transformers/blob/master/notebooks/03-pipelines.ipynb) | Simple and efficient way to use State-of-the-Art models on downstream tasks through transformers | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/transformers/blob/master/notebooks/03-pipelines.ipynb) |
| [How to train a language model](https://github.com/huggingface/blog/blob/master/notebooks/01_how_to_train.ipynb)| Highlight all the steps to effectively train Transformer model on custom data | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/blog/blob/master/notebooks/01_how_to_train.ipynb)|
| [How to generate text](https://github.com/huggingface/blog/blob/master/notebooks/02_how_to_generate.ipynb)| How to use different decoding methods for language generation with transformers | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/blog/blob/master/notebooks/02_how_to_generate.ipynb)|
| [How to export model to ONNX](https://github.com/huggingface/transformers/blob/master/notebooks/04-onnx-export.ipynb) | Highlight how to export and run inference workloads through ONNX |
| [How to use Benchmarks](https://github.com/huggingface/transformers/blob/master/notebooks/05-benchmark.ipynb) | How to benchmark models with transformers | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/transformers/blob/master/notebooks/05-benchmark.ipynb)|
| [Reformer](https://github.com/huggingface/blog/blob/master/notebooks/03_reformer.ipynb) | How Reformer pushes the limits of language modeling | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patrickvonplaten/blog/blob/master/notebooks/03_reformer.ipynb)|
## Community notebooks:
| Notebook | Description | Author | |
|:----------|:-------------|:-------------|------:|
| [Train T5 in Tensorflow 2 ](https://github.com/snapthat/TF-T5-text-to-text) | How to train T5 for any task using Tensorflow 2. This notebook demonstrates a Question & Answer task implemented in Tensorflow 2 using SQUAD | [Muhammad Harris](https://github.com/HarrisDePerceptron) |[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/snapthat/TF-T5-text-to-text/blob/master/snapthatT5/notebooks/TF-T5-Datasets%20Training.ipynb) |
| [Train T5 on TPU](https://github.com/patil-suraj/exploring-T5/blob/master/T5_on_TPU.ipynb) | How to train T5 on SQUAD with Transformers and Nlp | [Suraj Patil](https://github.com/patil-suraj) |[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/exploring-T5/blob/master/T5_on_TPU.ipynb#scrollTo=QLGiFCDqvuil) |
| [Fine-tune T5 for Classification and Multiple Choice](https://github.com/patil-suraj/exploring-T5/blob/master/t5_fine_tuning.ipynb) | How to fine-tune T5 for classification and multiple choice tasks using a text-to-text format with PyTorch Lightning | [Suraj Patil](https://github.com/patil-suraj) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/exploring-T5/blob/master/t5_fine_tuning.ipynb) |
| [Fine-tune DialoGPT on New Datasets and Languages](https://github.com/ncoop57/i-am-a-nerd/blob/master/_notebooks/2020-05-12-chatbot-part-1.ipynb) | How to fine-tune the DialoGPT model on a new dataset for open-dialog conversational chatbots | [Nathan Cooper](https://github.com/ncoop57) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ncoop57/i-am-a-nerd/blob/master/_notebooks/2020-05-12-chatbot-part-1.ipynb) |
| [Long Sequence Modeling with Reformer](https://github.com/patrickvonplaten/notebooks/blob/master/PyTorch_Reformer.ipynb) | How to train on sequences as long as 500,000 tokens with Reformer | [Patrick von Platen](https://github.com/patrickvonplaten) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/PyTorch_Reformer.ipynb) |
| [Fine-tune BART for Summarization](https://github.com/ohmeow/ohmeow_website/blob/master/_notebooks/2020-05-23-text-generation-with-blurr.ipynb) | How to fine-tune BART for summarization with fastai using blurr | [Wayde Gilliam](https://ohmeow.com/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ohmeow/ohmeow_website/blob/master/_notebooks/2020-05-23-text-generation-with-blurr.ipynb) |
| [Fine-tune a pre-trained Transformer on anyone's tweets](https://colab.research.google.com/github/borisdayma/huggingtweets/blob/master/huggingtweets-demo.ipynb) | How to generate tweets in the style of your favorite Twitter account by fine-tune a GPT-2 model | [Boris Dayma](https://github.com/borisdayma) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/borisdayma/huggingtweets/blob/master/huggingtweets-demo.ipynb) |
| [A Step by Step Guide to Tracking Hugging Face Model Performance](https://colab.research.google.com/drive/1NEiqNPhiouu2pPwDAVeFoN4-vTYMz9F8) | A quick tutorial for training NLP models with HuggingFace and & visualizing their performance with Weights & Biases | [Jack Morris](https://github.com/jxmorris12) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1NEiqNPhiouu2pPwDAVeFoN4-vTYMz9F8) |
| [Pretrain Longformer](https://github.com/allenai/longformer/blob/master/scripts/convert_model_to_long.ipynb) | How to build a "long" version of existing pretrained models | [Iz Beltagy](https://beltagy.net) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/allenai/longformer/blob/master/scripts/convert_model_to_long.ipynb) |
| [Fine-tune Longformer for QA](https://github.com/patil-suraj/Notebooks/blob/master/longformer_qa_training.ipynb) | How to fine-tune longformer model for QA task | [Suraj Patil](https://github.com/patil-suraj) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/longformer_qa_training.ipynb) |
| [Evaluate Model with 🤗nlp](https://github.com/patrickvonplaten/notebooks/blob/master/How_to_evaluate_Longformer_on_TriviaQA_using_NLP.ipynb) | How to evaluate longformer on TriviaQA with `nlp` | [Patrick von Platen](https://github.com/patrickvonplaten) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1m7eTGlPmLRgoPkkA7rkhQdZ9ydpmsdLE?usp=sharing) |
| [Fine-tune T5 for Sentiment Span Extraction](https://github.com/enzoampil/t5-intro/blob/master/t5_qa_training_pytorch_span_extraction.ipynb) | How to fine-tune T5 for sentiment span extraction using a text-to-text format with PyTorch Lightning | [Lorenzo Ampil](https://github.com/enzoampil) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/enzoampil/t5-intro/blob/master/t5_qa_training_pytorch_span_extraction.ipynb) |
| [Fine-tune DistilBert for Multiclass Classification](https://github.com/abhimishra91/transformers-tutorials/blob/master/transformers_multiclass_classification.ipynb) | How to fine-tune DistilBert for multiclass classification with PyTorch | [Abhishek Kumar Mishra](https://github.com/abhimishra91) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/abhimishra91/transformers-tutorials/blob/master/transformers_multiclass_classification.ipynb)|
|[Fine-tune BERT for Multi-label Classification](https://github.com/abhimishra91/transformers-tutorials/blob/master/transformers_multi_label_classification.ipynb)|How to fine-tune BERT for multi-label classification using PyTorch|[Abhishek Kumar Mishra](https://github.com/abhimishra91) |[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/abhimishra91/transformers-tutorials/blob/master/transformers_multi_label_classification.ipynb)|
|[Fine-tune T5 for Summarization](https://github.com/abhimishra91/transformers-tutorials/blob/master/transformers_summarization_wandb.ipynb)|How to fine-tune T5 for summarization in PyTorch and track experiments with WandB|[Abhishek Kumar Mishra](https://github.com/abhimishra91) |[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/abhimishra91/transformers-tutorials/blob/master/transformers_summarization_wandb.ipynb)|
|[Speed up Fine-Tuning in Transformers with Dynamic Padding / Bucketing](https://github.com/ELS-RD/transformers-notebook/blob/master/Divide_Hugging_Face_Transformers_training_time_by_2_or_more.ipynb)|How to speed up fine-tuning by a factor of 2 using dynamic padding / bucketing|[Michael Benesty](https://github.com/pommedeterresautee) |[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1CBfRU1zbfu7-ijiOqAAQUA-RJaxfcJoO?usp=sharing)|
|[Pretrain Reformer for Masked Language Modeling](https://github.com/patrickvonplaten/notebooks/blob/master/Reformer_For_Masked_LM.ipynb)| How to train a Reformer model with bi-directional self-attention layers | [Patrick von Platen](https://github.com/patrickvonplaten) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1tzzh0i8PgDQGV3SMFUGxM7_gGae3K-uW?usp=sharing)|
|[Expand and Fine Tune Sci-BERT](https://github.com/lordtt13/word-embeddings/blob/master/COVID-19%20Research%20Data/COVID-SciBERT.ipynb)| How to increase vocabulary of a pretrained SciBERT model from AllenAI on the CORD dataset and pipeline it. | [Tanmay Thakur](https://github.com/lordtt13) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1rqAR40goxbAfez1xvF3hBJphSCsvXmh8)|
|[Fine-tune Electra and interpret with Integrated Gradients](https://github.com/elsanns/xai-nlp-notebooks/blob/master/electra_fine_tune_interpret_captum_ig.ipynb) | How to fine-tune Electra for sentiment analysis and interpret predictions with Captum Integrated Gradients | [Eliza Szczechla](https://elsanns.github.io) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/elsanns/xai-nlp-notebooks/blob/master/electra_fine_tune_interpret_captum_ig.ipynb)|
|[fine-tune a non-English GPT-2 Model with Trainer class](https://github.com/philschmid/fine-tune-GPT-2/blob/master/Fine_tune_a_non_English_GPT_2_Model_with_Huggingface.ipynb) | How to fine-tune a non-English GPT-2 Model with Trainer class | [Philipp Schmid](https://www.philschmid.de) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/fine-tune-GPT-2/blob/master/Fine_tune_a_non_English_GPT_2_Model_with_Huggingface.ipynb)|
|[Fine-tune a DistilBERT Model for Multi Label Classification task](https://github.com/DhavalTaunk08/Transformers_scripts/blob/master/Transformers_multilabel_distilbert.ipynb) | How to fine-tune a DistilBERT Model for Multi Label Classification task | [Dhaval Taunk](https://github.com/DhavalTaunk08) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DhavalTaunk08/Transformers_scripts/blob/master/Transformers_multilabel_distilbert.ipynb)|
|[Fine-tune ALBERT for sentence-pair classification](https://github.com/NadirEM/nlp-notebooks/blob/master/Fine_tune_ALBERT_sentence_pair_classification.ipynb) | How to fine-tune an ALBERT model or another BERT-based model for the sentence-pair classification task | [Nadir El Manouzi](https://github.com/NadirEM) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NadirEM/nlp-notebooks/blob/master/Fine_tune_ALBERT_sentence_pair_classification.ipynb)|
|[Fine-tune Roberta for sentiment analysis](https://github.com/DhavalTaunk08/NLP_scripts/blob/master/sentiment_analysis_using_roberta.ipynb) | How to fine-tune an Roberta model for sentiment analysis | [Dhaval Taunk](https://github.com/DhavalTaunk08) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DhavalTaunk08/NLP_scripts/blob/master/sentiment_analysis_using_roberta.ipynb)|
|[Evaluating Question Generation Models](https://github.com/flexudy-pipe/qugeev) | How accurate are the answers to questions generated by your seq2seq transformer model? | [Pascal Zoleko](https://github.com/zolekode) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1bpsSqCQU-iw_5nNoRm_crPq6FRuJthq_?usp=sharing)|
|[Classify text with DistilBERT and Tensorflow](https://github.com/peterbayerle/huggingface_notebook/blob/main/distilbert_tf.ipynb) | How to fine-tune DistilBERT for text classification in TensorFlow | [Peter Bayerle](https://github.com/peterbayerle) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/peterbayerle/huggingface_notebook/blob/main/distilbert_tf.ipynb)|
|[Leverage BERT for Encoder-Decoder Summarization on CNN/Dailymail](https://github.com/patrickvonplaten/notebooks/blob/master/BERT2BERT_for_CNN_Dailymail.ipynb) | How to warm-start a *EncoderDecoderModel* with a *bert-base-uncased* checkpoint for summarization on CNN/Dailymail | [Patrick von Platen](https://github.com/patrickvonplaten) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/BERT2BERT_for_CNN_Dailymail.ipynb)|
|[Leverage RoBERTa for Encoder-Decoder Summarization on BBC XSum](https://github.com/patrickvonplaten/notebooks/blob/master/RoBERTaShared_for_BBC_XSum.ipynb) | How to warm-start a shared *EncoderDecoderModel* with a *roberta-base* checkpoint for summarization on BBC/XSum | [Patrick von Platen](https://github.com/patrickvonplaten) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/RoBERTaShared_for_BBC_XSum.ipynb)|
|[Fine-tuning TAPAS on Sequential Question Answering (SQA)](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Fine_tuning_TapasForQuestionAnswering_on_SQA.ipynb) | How to fine-tune *TapasForQuestionAnswering* with a *tapas-base* checkpoint on the Sequential Question Answering (SQA) dataset | [Niels Rogge](https://github.com/nielsrogge) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/Fine_tuning_TapasForQuestionAnswering_on_SQAipynb)|
|[Evaluating TAPAS on Table Fact Checking (TabFact)](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Evaluating_TAPAS_on_the_Tabfact_test_set.ipynb) | How to evaluate a fine-tuned *TapasForSequenceClassification* with a *tapas-base-finetuned-tabfact* checkpoint using a combination of the 🤗 datasets and 🤗 transformers libraries | [Niels Rogge](https://github.com/nielsrogge) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/Evaluating_TAPAS_on_the_Tabfact_test_set.ipynb)|

View File

@ -164,6 +164,7 @@ from .models.retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBert
from .models.roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig, RobertaTokenizer
from .models.squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig, SqueezeBertTokenizer
from .models.t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from .models.tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig, TapasTokenizer
from .models.transfo_xl import (
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
TransfoXLConfig,
@ -605,6 +606,13 @@ if is_torch_available():
T5PreTrainedModel,
load_tf_weights_in_t5,
)
from .models.tapas import (
TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST,
TapasForMaskedLM,
TapasForQuestionAnswering,
TapasForSequenceClassification,
TapasModel,
)
from .models.transfo_xl import (
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
AdaptiveEmbedding,

View File

@ -216,6 +216,29 @@ except ImportError:
_tokenizers_available = False
try:
import pandas # noqa: F401
_pandas_available = True
except ImportError:
_pandas_available = False
try:
import torch_scatter
# Check we're not importing a "torch_scatter" directory somewhere
_scatter_available = hasattr(torch_scatter, "__version__") and hasattr(torch_scatter, "scatter")
if _scatter_available:
logger.debug(f"Succesfully imported torch-scatter version {torch_scatter.__version__}")
else:
logger.debug("Imported a torch_scatter object but this doesn't seem to be the torch-scatter library.")
except ImportError:
_scatter_available = False
old_default_cache_path = os.path.join(torch_cache_home, "transformers")
# New default cache, shared with the Datasets library
hf_cache_home = os.path.expanduser(
@ -325,6 +348,14 @@ def is_in_notebook():
return _in_notebook
def is_scatter_available():
return _scatter_available
def is_pandas_available():
return _pandas_available
def torch_only_method(fn):
def wrapper(*args, **kwargs):
if not _torch_available:
@ -427,6 +458,13 @@ installation page: https://github.com/google/flax and follow the ones that match
"""
# docstyle-ignore
SCATTER_IMPORT_ERROR = """
{0} requires the torch-scatter library but it was not found in your environment. You can install it with pip as
explained here: https://github.com/rusty1s/pytorch_scatter.
"""
def requires_datasets(obj):
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
if not is_datasets_available():
@ -481,6 +519,12 @@ def requires_protobuf(obj):
raise ImportError(PROTOBUF_IMPORT_ERROR.format(name))
def requires_scatter(obj):
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
if not is_scatter_available():
raise ImportError(SCATTER_IMPORT_ERROR.format(name))
def add_start_docstrings(*docstr):
def docstring_decorator(fn):
fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")

View File

@ -51,6 +51,7 @@ from ..retribert.configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCH
from ..roberta.configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
from ..squeezebert.configuration_squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig
from ..t5.configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from ..tapas.configuration_tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig
from ..transfo_xl.configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
from ..xlm.configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
from ..xlm_prophetnet.configuration_xlm_prophetnet import (
@ -95,6 +96,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP,
]
for key, value, in pretrained_map.items()
)
@ -141,6 +143,7 @@ CONFIG_MAPPING = OrderedDict(
("dpr", DPRConfig),
("layoutlm", LayoutLMConfig),
("rag", RagConfig),
("tapas", TapasConfig),
]
)
@ -185,6 +188,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("prophetnet", "ProphetNet"),
("mt5", "mT5"),
("mpnet", "MPNet"),
("tapas", "TAPAS"),
]
)

View File

@ -165,6 +165,12 @@ from ..squeezebert.modeling_squeezebert import (
SqueezeBertModel,
)
from ..t5.modeling_t5 import T5ForConditionalGeneration, T5Model
from ..tapas.modeling_tapas import (
TapasForMaskedLM,
TapasForQuestionAnswering,
TapasForSequenceClassification,
TapasModel,
)
from ..transfo_xl.modeling_transfo_xl import TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel
from ..xlm.modeling_xlm import (
XLMForMultipleChoice,
@ -230,6 +236,7 @@ from .configuration_auto import (
RobertaConfig,
SqueezeBertConfig,
T5Config,
TapasConfig,
TransfoXLConfig,
XLMConfig,
XLMProphetNetConfig,
@ -277,6 +284,7 @@ MODEL_MAPPING = OrderedDict(
(XLMProphetNetConfig, XLMProphetNetModel),
(ProphetNetConfig, ProphetNetModel),
(MPNetConfig, MPNetModel),
(TapasConfig, TapasModel),
]
)
@ -308,6 +316,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
(LxmertConfig, LxmertForPreTraining),
(FunnelConfig, FunnelForPreTraining),
(MPNetConfig, MPNetForMaskedLM),
(TapasConfig, TapasForMaskedLM),
]
)
@ -340,6 +349,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
(ReformerConfig, ReformerModelWithLMHead),
(FunnelConfig, FunnelForMaskedLM),
(MPNetConfig, MPNetForMaskedLM),
(TapasConfig, TapasForMaskedLM),
]
)
@ -386,6 +396,7 @@ MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
(ReformerConfig, ReformerForMaskedLM),
(FunnelConfig, FunnelForMaskedLM),
(MPNetConfig, MPNetForMaskedLM),
(TapasConfig, TapasForMaskedLM),
]
)
@ -431,6 +442,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
(CTRLConfig, CTRLForSequenceClassification),
(TransfoXLConfig, TransfoXLForSequenceClassification),
(MPNetConfig, MPNetForSequenceClassification),
(TapasConfig, TapasForSequenceClassification),
]
)
@ -455,6 +467,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
(FunnelConfig, FunnelForQuestionAnswering),
(LxmertConfig, LxmertForQuestionAnswering),
(MPNetConfig, MPNetForQuestionAnswering),
(TapasConfig, TapasForQuestionAnswering),
]
)

View File

@ -47,6 +47,7 @@ from ..rag.tokenization_rag import RagTokenizer
from ..retribert.tokenization_retribert import RetriBertTokenizer
from ..roberta.tokenization_roberta import RobertaTokenizer
from ..squeezebert.tokenization_squeezebert import SqueezeBertTokenizer
from ..tapas.tokenization_tapas import TapasTokenizer
from ..transfo_xl.tokenization_transfo_xl import TransfoXLTokenizer
from ..xlm.tokenization_xlm import XLMTokenizer
from .configuration_auto import (
@ -84,6 +85,7 @@ from .configuration_auto import (
RobertaConfig,
SqueezeBertConfig,
T5Config,
TapasConfig,
TransfoXLConfig,
XLMConfig,
XLMProphetNetConfig,
@ -223,6 +225,7 @@ TOKENIZER_MAPPING = OrderedDict(
(XLMProphetNetConfig, (XLMProphetNetTokenizer, None)),
(ProphetNetConfig, (ProphetNetTokenizer, None)),
(MPNetConfig, (MPNetTokenizer, MPNetTokenizerFast)),
(TapasConfig, (TapasTokenizer, None)),
]
)

View File

@ -0,0 +1,31 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...file_utils import is_torch_available
from .configuration_tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig
from .tokenization_tapas import TapasTokenizer
if is_torch_available():
from .modeling_tapas import (
TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST,
TapasForMaskedLM,
TapasForQuestionAnswering,
TapasForSequenceClassification,
TapasModel,
)

View File

@ -0,0 +1,219 @@
# coding=utf-8
# Copyright 2020 Google Research and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
TAPAS configuration. Based on the BERT configuration with added parameters.
Hyperparameters are taken from run_task_main.py and hparam_utils.py of the original implementation. URLS:
- https://github.com/google-research/tapas/blob/master/tapas/run_task_main.py
- https://github.com/google-research/tapas/blob/master/tapas/utils/hparam_utils.py
"""
from ...configuration_utils import PretrainedConfig
TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"google/tapas-base-finetuned-sqa": "https://huggingface.co/google/tapas-base-finetuned-sqa/resolve/main/config.json",
"google/tapas-base-finetuned-wtq": "https://huggingface.co/google/tapas-base-finetuned-wtq/resolve/main/config.json",
"google/tapas-base-finetuned-wikisql-supervised": "https://huggingface.co/google/tapas-base-finetuned-wikisql-supervised/resolve/main/config.json",
"google/tapas-base-finetuned-tabfact": "https://huggingface.co/google/tapas-base-finetuned-tabfact/resolve/main/config.json",
}
class TapasConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a :class:`~transformers.TapasModel`. It is used to
instantiate a TAPAS model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the TAPAS `tapas-base-finetuned-sqa`
architecture. Configuration objects inherit from :class:`~transformers.PreTrainedConfig` and can be used to control
the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
Hyperparameters additional to BERT are taken from run_task_main.py and hparam_utils.py of the original
implementation. Original implementation available at https://github.com/google-research/tapas/tree/master.
Args:
vocab_size (:obj:`int`, `optional`, defaults to 30522):
Vocabulary size of the TAPAS model. Defines the number of different tokens that can be represented by the
:obj:`inputs_ids` passed when calling :class:`~transformers.TapasModel`.
hidden_size (:obj:`int`, `optional`, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (:obj:`int`, `optional`, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (:obj:`int`, `optional`, defaults to 3072):
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout ratio for the attention probabilities.
max_position_embeddings (:obj:`int`, `optional`, defaults to 1024):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
type_vocab_sizes (:obj:`List[int]`, `optional`, defaults to :obj:`[3, 256, 256, 2, 256, 256, 10]`):
The vocabulary sizes of the :obj:`token_type_ids` passed when calling :class:`~transformers.TapasModel`.
initializer_range (:obj:`float`, `optional`, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to use gradient checkpointing to save memory at the expense of a slower backward pass.
positive_label_weight (:obj:`float`, `optional`, defaults to 10.0):
Weight for positive labels.
num_aggregation_labels (:obj:`int`, `optional`, defaults to 0):
The number of aggregation operators to predict.
aggregation_loss_weight (:obj:`float`, `optional`, defaults to 1.0):
Importance weight for the aggregation loss.
use_answer_as_supervision (:obj:`bool`, `optional`):
Whether to use the answer as the only supervision for aggregation examples.
answer_loss_importance (:obj:`float`, `optional`, defaults to 1.0):
Importance weight for the regression loss.
use_normalized_answer_loss (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to normalize the answer loss by the maximum of the predicted and expected value.
huber_loss_delta: (:obj:`float`, `optional`):
Delta parameter used to calculate the regression loss.
temperature: (:obj:`float`, `optional`, defaults to 1.0):
Value used to control (OR change) the skewness of cell logits probabilities.
aggregation_temperature: (:obj:`float`, `optional`, defaults to 1.0):
Scales aggregation logits to control the skewness of probabilities.
use_gumbel_for_cells: (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to apply Gumbel-Softmax to cell selection.
use_gumbel_for_aggregation: (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to apply Gumbel-Softmax to aggregation selection.
average_approximation_function: (:obj:`string`, `optional`, defaults to :obj:`"ratio"`):
Method to calculate the expected average of cells in the weak supervision case. One of :obj:`"ratio"`,
:obj:`"first_order"` or :obj:`"second_order"`.
cell_selection_preference: (:obj:`float`, `optional`):
Preference for cell selection in ambiguous cases. Only applicable in case of weak supervision for
aggregation (WTQ, WikiSQL). If the total mass of the aggregation probabilities (excluding the "NONE"
operator) is higher than this hyperparameter, then aggregation is predicted for an example.
answer_loss_cutoff: (:obj:`float`, `optional`):
Ignore examples with answer loss larger than cutoff.
max_num_rows: (:obj:`int`, `optional`, defaults to 64):
Maximum number of rows.
max_num_columns: (:obj:`int`, `optional`, defaults to 32):
Maximum number of columns.
average_logits_per_cell: (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to average logits per cell.
select_one_column: (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to constrain the model to only select cells from a single column.
allow_empty_column_selection: (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to allow not to select any column.
init_cell_selection_weights_to_zero: (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to initialize cell selection weights to 0 so that the initial probabilities are 50%.
reset_position_index_per_cell: (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to restart position indexes at every cell (i.e. use relative position embeddings).
disable_per_token_loss: (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to disable any (strong or weak) supervision on cells.
Example::
>>> from transformers import TapasModel, TapasConfig
>>> # Initializing a default (SQA) Tapas configuration
>>> configuration = TapasConfig()
>>> # Initializing a model from the configuration
>>> model = TapasModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
"""
model_type = "tapas"
def __init__(
self,
vocab_size=30522,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=1024,
type_vocab_sizes=[3, 256, 256, 2, 256, 256, 10],
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=0,
gradient_checkpointing=False,
positive_label_weight=10.0,
num_aggregation_labels=0,
aggregation_loss_weight=1.0,
use_answer_as_supervision=None,
answer_loss_importance=1.0,
use_normalized_answer_loss=False,
huber_loss_delta=None,
temperature=1.0,
aggregation_temperature=1.0,
use_gumbel_for_cells=False,
use_gumbel_for_aggregation=False,
average_approximation_function="ratio",
cell_selection_preference=None,
answer_loss_cutoff=None,
max_num_rows=64,
max_num_columns=32,
average_logits_per_cell=False,
select_one_column=True,
allow_empty_column_selection=False,
init_cell_selection_weights_to_zero=False,
reset_position_index_per_cell=True,
disable_per_token_loss=False,
**kwargs
):
super().__init__(pad_token_id=pad_token_id, **kwargs)
# BERT hyperparameters (with updated max_position_embeddings and type_vocab_sizes)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_sizes = type_vocab_sizes
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.gradient_checkpointing = gradient_checkpointing
# Fine-tuning task hyperparameters
self.positive_label_weight = positive_label_weight
self.num_aggregation_labels = num_aggregation_labels
self.aggregation_loss_weight = aggregation_loss_weight
self.use_answer_as_supervision = use_answer_as_supervision
self.answer_loss_importance = answer_loss_importance
self.use_normalized_answer_loss = use_normalized_answer_loss
self.huber_loss_delta = huber_loss_delta
self.temperature = temperature
self.aggregation_temperature = aggregation_temperature
self.use_gumbel_for_cells = use_gumbel_for_cells
self.use_gumbel_for_aggregation = use_gumbel_for_aggregation
self.average_approximation_function = average_approximation_function
self.cell_selection_preference = cell_selection_preference
self.answer_loss_cutoff = answer_loss_cutoff
self.max_num_rows = max_num_rows
self.max_num_columns = max_num_columns
self.average_logits_per_cell = average_logits_per_cell
self.select_one_column = select_one_column
self.allow_empty_column_selection = allow_empty_column_selection
self.init_cell_selection_weights_to_zero = init_cell_selection_weights_to_zero
self.reset_position_index_per_cell = reset_position_index_per_cell
self.disable_per_token_loss = disable_per_token_loss

View File

@ -0,0 +1,137 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert TAPAS checkpoint."""
import argparse
from transformers.models.tapas.modeling_tapas import (
TapasConfig,
TapasForMaskedLM,
TapasForQuestionAnswering,
TapasForSequenceClassification,
TapasModel,
load_tf_weights_in_tapas,
)
from transformers.models.tapas.tokenization_tapas import TapasTokenizer
from transformers.utils import logging
logging.set_verbosity_info()
def convert_tf_checkpoint_to_pytorch(
task, reset_position_index_per_cell, tf_checkpoint_path, tapas_config_file, pytorch_dump_path
):
# Initialise PyTorch model.
# If you want to convert a checkpoint that uses absolute position embeddings, make sure to set reset_position_index_per_cell of
# TapasConfig to False.
# initialize configuration from json file
config = TapasConfig.from_json_file(tapas_config_file)
# set absolute/relative position embeddings parameter
config.reset_position_index_per_cell = reset_position_index_per_cell
# set remaining parameters of TapasConfig as well as the model based on the task
if task == "SQA":
model = TapasForQuestionAnswering(config=config)
elif task == "WTQ":
# run_task_main.py hparams
config.num_aggregation_labels = 4
config.use_answer_as_supervision = True
# hparam_utils.py hparams
config.answer_loss_cutoff = 0.664694
config.cell_selection_preference = 0.207951
config.huber_loss_delta = 0.121194
config.init_cell_selection_weights_to_zero = True
config.select_one_column = True
config.allow_empty_column_selection = False
config.temperature = 0.0352513
model = TapasForQuestionAnswering(config=config)
elif task == "WIKISQL_SUPERVISED":
# run_task_main.py hparams
config.num_aggregation_labels = 4
config.use_answer_as_supervision = False
# hparam_utils.py hparams
config.answer_loss_cutoff = 36.4519
config.cell_selection_preference = 0.903421
config.huber_loss_delta = 222.088
config.init_cell_selection_weights_to_zero = True
config.select_one_column = True
config.allow_empty_column_selection = True
config.temperature = 0.763141
model = TapasForQuestionAnswering(config=config)
elif task == "TABFACT":
model = TapasForSequenceClassification(config=config)
elif task == "MLM":
model = TapasForMaskedLM(config=config)
elif task == "INTERMEDIATE_PRETRAINING":
model = TapasModel(config=config)
print("Building PyTorch model from configuration: {}".format(str(config)))
# Load weights from tf checkpoint
load_tf_weights_in_tapas(model, config, tf_checkpoint_path)
# Save pytorch-model (weights and configuration)
print("Save PyTorch model to {}".format(pytorch_dump_path))
model.save_pretrained(pytorch_dump_path[:-17])
# Save tokenizer files
dir_name = r"C:\Users\niels.rogge\Documents\Python projecten\tensorflow\Tensorflow models\SQA\Base\tapas_sqa_inter_masklm_base_reset"
tokenizer = TapasTokenizer(vocab_file=dir_name + r"\vocab.txt", model_max_length=512)
print("Save tokenizer files to {}".format(pytorch_dump_path))
tokenizer.save_pretrained(pytorch_dump_path[:-17])
print("Used relative position embeddings:", model.config.reset_position_index_per_cell)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--task", default="SQA", type=str, help="Model task for which to convert a checkpoint. Defaults to SQA."
)
parser.add_argument(
"--reset_position_index_per_cell",
default=False,
action="store_true",
help="Whether to use relative position embeddings or not. Defaults to True.",
)
parser.add_argument(
"--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
)
parser.add_argument(
"--tapas_config_file",
default=None,
type=str,
required=True,
help="The config json file corresponding to the pre-trained TAPAS model. \n"
"This specifies the model architecture.",
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
args = parser.parse_args()
convert_tf_checkpoint_to_pytorch(
args.task,
args.reset_position_index_per_cell,
args.tf_checkpoint_path,
args.tapas_config_file,
args.pytorch_dump_path,
)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -28,6 +28,8 @@ from .file_utils import (
_datasets_available,
_faiss_available,
_flax_available,
_pandas_available,
_scatter_available,
_sentencepiece_available,
_tf_available,
_tokenizers_available,
@ -221,6 +223,27 @@ def require_tokenizers(test_case):
return test_case
def require_pandas(test_case):
"""
Decorator marking a test that requires pandas. These tests are skipped when pandas isn't installed.
"""
if not _pandas_available:
return unittest.skip("test requires pandas")(test_case)
else:
return test_case
def require_scatter(test_case):
"""
Decorator marking a test that requires PyTorch Scatter. These tests are skipped when PyTorch Scatter isn't
installed.
"""
if not _scatter_available:
return unittest.skip("test requires PyTorch Scatter")(test_case)
else:
return test_case
def require_torch_multi_gpu(test_case):
"""
Decorator marking a test that requires a multi-GPU setup (in PyTorch).

View File

@ -1867,6 +1867,45 @@ def load_tf_weights_in_t5(*args, **kwargs):
requires_pytorch(load_tf_weights_in_t5)
TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST = None
class TapasForMaskedLM:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
class TapasForQuestionAnswering:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
class TapasForSequenceClassification:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
class TapasModel:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST = None

1084
tests/test_modeling_tapas.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -584,7 +584,7 @@ class TokenizerTesterMixin:
# We want to have sequence 0 and sequence 1 are tagged
# respectively with 0 and 1 token_ids
# (regardeless of weither the model use token type ids)
# (regardless of whether the model use token type ids)
# We use this assumption in the QA pipeline among other place
output = tokenizer(seq_0, return_token_type_ids=True)
self.assertIn(0, output["token_type_ids"])
@ -600,7 +600,7 @@ class TokenizerTesterMixin:
# We want to have sequence 0 and sequence 1 are tagged
# respectively with 0 and 1 token_ids
# (regardeless of weither the model use token type ids)
# (regardless of whether the model use token type ids)
# We use this assumption in the QA pipeline among other place
output = tokenizer(seq_0)
self.assertIn(0, output.sequence_ids())

File diff suppressed because one or more lines are too long