[Longformer For Question Answering] Conversion script, doc, small fixes (#4593)

* add new longformer for question answering model

* add new config as well

* fix links

* fix links part 2
Patrick von Platen 2020-05-26 14:58:47 +02:00 committed by GitHub
parent a163c9ca5b
commit c589eae2b8
5 changed files with 123 additions and 21 deletions


@@ -67,3 +67,10 @@ LongformerForMaskedLM
.. autoclass:: transformers.LongformerForMaskedLM
:members:
LongformerForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.LongformerForQuestionAnswering
:members:


@@ -25,6 +25,7 @@ logger = logging.getLogger(__name__)
LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"longformer-base-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-base-4096/config.json",
"longformer-large-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096/config.json",
"longformer-large-4096-finetuned-triviaqa": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096-finetuned-triviaqa/config.json",
}
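
With this entry in place, the fine-tuned checkpoint's configuration resolves by its short identifier. A minimal sketch, assuming only the archive map above and the standard `from_pretrained` API:

    from transformers import LongformerConfig

    # resolves through the LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP entry added above
    config = LongformerConfig.from_pretrained("longformer-large-4096-finetuned-triviaqa")
    print(config.max_position_embeddings)  # roughly 4096, plus a few special positions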


@@ -0,0 +1,86 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert RoBERTa checkpoint."""
import argparse
import pytorch_lightning as pl
import torch
from transformers.modeling_longformer import LongformerForQuestionAnswering, LongformerModel
class LightningModel(pl.LightningModule):
def __init__(self, model):
super().__init__()
self.model = model
self.num_labels = 2
self.qa_outputs = torch.nn.Linear(self.model.config.hidden_size, self.num_labels)
# implement only because lighning requires to do so
def forward(self):
pass
def convert_longformer_qa_checkpoint_to_pytorch(
longformer_model: str, longformer_question_answering_ckpt_path: str, pytorch_dump_folder_path: str
):
# load longformer model from model identifier
longformer = LongformerModel.from_pretrained(longformer_model)
lightning_model = LightningModel(longformer)
ckpt = torch.load(longformer_question_answering_ckpt_path, map_location=torch.device("cpu"))
lightning_model.load_state_dict(ckpt["state_dict"])
# init longformer question answering model
longformer_for_qa = LongformerForQuestionAnswering.from_pretrained(longformer_model)
# transfer weights
longformer_for_qa.longformer.load_state_dict(lightning_model.model.state_dict())
longformer_for_qa.qa_outputs.load_state_dict(lightning_model.qa_outputs.state_dict())
longformer_for_qa.eval()
# save model
longformer_for_qa.save_pretrained(pytorch_dump_folder_path)
print("Conversion succesful. Model saved under {}".format(pytorch_dump_folder_path))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--longformer_model",
default=None,
type=str,
required=True,
help="model identifier of longformer. Should be either `longformer-base-4096` or `longformer-large-4096`.",
)
parser.add_argument(
"--longformer_question_answering_ckpt_path",
default=None,
type=str,
required=True,
help="Path the official PyTorch Lighning Checkpoint.",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
args = parser.parse_args()
convert_longformer_qa_checkpoint_to_pytorch(
args.longformer_model, args.longformer_question_answering_ckpt_path, args.pytorch_dump_folder_path
)
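
If it helps to see the converter end to end, here is a sketch of driving it from Python instead of the command line; the checkpoint and output paths are hypothetical, and `convert_longformer_qa_checkpoint_to_pytorch` is the function defined above:

    from transformers import LongformerForQuestionAnswering

    convert_longformer_qa_checkpoint_to_pytorch(
        longformer_model="longformer-large-4096",
        longformer_question_answering_ckpt_path="./triviaqa_longformer_large.ckpt",  # hypothetical path
        pytorch_dump_folder_path="./longformer-large-4096-finetuned-triviaqa",
    )

    # the dumped folder then loads like any other pretrained model
    model = LongformerForQuestionAnswering.from_pretrained("./longformer-large-4096-finetuned-triviaqa")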


@@ -33,6 +33,7 @@ logger = logging.getLogger(__name__)
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP = {
"longformer-base-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-base-4096/pytorch_model.bin",
"longformer-large-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096/pytorch_model.bin",
"longformer-large-4096-finetuned-triviaqa": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096-finetuned-triviaqa/pytorch_model.bin",
}
@@ -710,7 +711,7 @@ class LongformerForMaskedLM(BertPreTrainedModel):
@add_start_docstrings(
"""Longformer Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
"""Longformer Model with a span classification head on top for extractive question-answering tasks like SQuAD / TriviaQA (a linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,
LONGFORMER_START_DOCSTRING,
)
@@ -728,26 +729,27 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
self.init_weights()
-    def _get_question_end_index(self, input_ids):
-        sep_token_indices = (input_ids == self.config.sep_token_id).nonzero()
-        assert sep_token_indices.size(1) == 2, "input_ids should have two dimensions"
-        assert sep_token_indices.size(0) == 3 * input_ids.size(
-            0
-        ), "There should be exactly three separator tokens in every sample for questions answering"
-        return sep_token_indices.view(input_ids.size(0), 3, 2)[:, 0, 1]
-
    def _compute_global_attention_mask(self, input_ids):
        question_end_index = self._get_question_end_index(input_ids)
        question_end_index = question_end_index.unsqueeze(dim=1)  # size: batch_size x 1
        # bool attention mask with True in locations of global attention
-        attention_mask = torch.arange(input_ids.size(1), device=input_ids.device)
+        attention_mask = torch.arange(input_ids.shape[1], device=input_ids.device)
        attention_mask = attention_mask.expand_as(input_ids) < question_end_index
-        attention_mask = attention_mask.int() + 1  # from True, False to 2, 1
+        attention_mask = attention_mask.int() + 1  # True => global attention; False => local attention
        return attention_mask.long()

+    def _get_question_end_index(self, input_ids):
+        sep_token_indices = (input_ids == self.config.sep_token_id).nonzero()
+        batch_size = input_ids.shape[0]
+
+        assert sep_token_indices.shape[1] == 2, "`input_ids` should have two dimensions"
+        assert (
+            sep_token_indices.shape[0] == 3 * batch_size
+        ), f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering"
+        return sep_token_indices.view(batch_size, 3, 2)[:, 0, 1]
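
To make the two helpers above concrete, here is a small standalone sketch of the same logic on a toy batch (the token ids are made up; `sep_token_id = 2` matches the RoBERTa vocabulary that Longformer reuses):

    import torch

    sep_token_id = 2  # </s> in the RoBERTa/Longformer vocabulary
    # toy sample laid out as <s> question </s> </s> context </s> -- three separators, as asserted above
    input_ids = torch.tensor([[0, 10, 11, 2, 2, 20, 21, 22, 2]])

    sep_token_indices = (input_ids == sep_token_id).nonzero()
    question_end_index = sep_token_indices.view(input_ids.shape[0], 3, 2)[:, 0, 1]  # position of the first </s>

    attention_mask = torch.arange(input_ids.shape[1], device=input_ids.device)
    attention_mask = attention_mask.expand_as(input_ids) < question_end_index.unsqueeze(dim=1)
    attention_mask = attention_mask.int() + 1  # 2 => global attention on question tokens, 1 => local attention
    print(attention_mask)  # tensor([[2, 2, 2, 1, 1, 1, 1, 1, 1]], dtype=torch.int32)
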
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING)
def forward(
self,
@@ -769,7 +771,7 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
Returns:
-        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
+        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.LongformerConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
@@ -785,24 +787,29 @@
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
    Examples::

        from transformers import LongformerTokenizer, LongformerForQuestionAnswering
        import torch

-        tokenizer = LongformerTokenizer.from_pretrained('longformer-base-4096')
-        model = LongformerForQuestionAnswering.from_pretrained('longformer-base-4096')
+        tokenizer = LongformerTokenizer.from_pretrained("longformer-large-4096-finetuned-triviaqa")
+        model = LongformerForQuestionAnswering.from_pretrained("longformer-large-4096-finetuned-triviaqa")

        question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
-        encoding = tokenizer.encode_plus(question, text)
+        encoding = tokenizer.encode_plus(question, text, return_tensors="pt")
        input_ids = encoding["input_ids"]

+        # default is local attention everywhere
+        # the forward method will automatically set global attention on question tokens
        attention_mask = encoding["attention_mask"]

-        start_scores, end_scores = model(torch.tensor([input_ids]), attention_mask=attention_mask)
-        all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
-        answer = ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1])
+        start_scores, end_scores = model(input_ids, attention_mask=attention_mask)
+        all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
+
+        answer_tokens = all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores) + 1]
+        answer = tokenizer.decode(tokenizer.convert_tokens_to_ids(answer_tokens))  # remove space prepending space token
"""
# set global attention on question tokens
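
As a usage note, the docstring example above can be wrapped into a small reusable helper; the function name `answer_question` is ours (not part of the library) and it assumes the same tokenizer/model pair as the example:

    import torch
    from transformers import LongformerTokenizer, LongformerForQuestionAnswering

    def answer_question(question, context, model, tokenizer):
        # hypothetical convenience wrapper around the documented forward call
        encoding = tokenizer.encode_plus(question, context, return_tensors="pt")
        input_ids = encoding["input_ids"]
        attention_mask = encoding["attention_mask"]
        with torch.no_grad():  # inference only
            start_scores, end_scores = model(input_ids, attention_mask=attention_mask)
        all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
        answer_tokens = all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores) + 1]
        return tokenizer.decode(tokenizer.convert_tokens_to_ids(answer_tokens)).strip()

    tokenizer = LongformerTokenizer.from_pretrained("longformer-large-4096-finetuned-triviaqa")
    model = LongformerForQuestionAnswering.from_pretrained("longformer-large-4096-finetuned-triviaqa")
    print(answer_question("Who was Jim Henson?", "Jim Henson was a nice puppet", model, tokenizer))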


@@ -24,12 +24,13 @@ logger = logging.getLogger(__name__)
# vocab and merges same as roberta
vocab_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json"
merges_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt"
-_all_longformer_models = ["longformer-base-4096", "longformer-large-4096"]
+_all_longformer_models = ["longformer-base-4096", "longformer-large-4096", "longformer-large-4096-finetuned-triviaqa"]
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"longformer-base-4096": 4096,
"longformer-large-4096": 4096,
"longformer-large-4096-finetuned-triviaqa": 4096,
}
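
Since the tokenizer side only registers the shared RoBERTa vocab/merges URLs and a 4096-position limit for the new identifier, a short sketch of what that looks like in practice (the question and context strings are made up):

    from transformers import LongformerTokenizer

    # resolves through the tables extended above; vocabulary and merges are shared with RoBERTa
    tokenizer = LongformerTokenizer.from_pretrained("longformer-large-4096-finetuned-triviaqa")

    question = "Who was Jim Henson?"
    context = "Jim Henson was a nice puppet. " * 200  # long contexts are where Longformer pays off
    encoding = tokenizer.encode_plus(question, context, return_tensors="pt")
    print(encoding["input_ids"].shape)  # well below the 4096-position limit registered above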