[Longformer For Question Answering] Conversion script, doc, small fixes (#4593)
* add new longformer for question answering model
* add new config as well
* fix links
* fix links part 2
parent a163c9ca5b
commit c589eae2b8
@@ -67,3 +67,10 @@ LongformerForMaskedLM
 .. autoclass:: transformers.LongformerForMaskedLM
     :members:
 
+
+LongformerForQuestionAnswering
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.LongformerForQuestionAnswering
+    :members:
+
@@ -25,6 +25,7 @@ logger = logging.getLogger(__name__)
 
 LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
     "longformer-base-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-base-4096/config.json",
     "longformer-large-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096/config.json",
+    "longformer-large-4096-finetuned-triviaqa": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096-finetuned-triviaqa/config.json",
 }
 
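With the entry above in place, the fine-tuned TriviaQA checkpoint resolves by its short identifier. A minimal sketch of that effect (not part of this diff; it assumes `LongformerConfig` is exported at the top level of `transformers`):

from transformers import LongformerConfig

# resolves via the LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP entry added above
config = LongformerConfig.from_pretrained("longformer-large-4096-finetuned-triviaqa")
print(config)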
@@ -0,0 +1,86 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Longformer question answering checkpoint."""


import argparse

import pytorch_lightning as pl
import torch

from transformers.modeling_longformer import LongformerForQuestionAnswering, LongformerModel


class LightningModel(pl.LightningModule):
    # wraps the base model plus the qa_outputs head so the state_dict keys
    # line up with the PyTorch Lightning checkpoint loaded below
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.num_labels = 2
        self.qa_outputs = torch.nn.Linear(self.model.config.hidden_size, self.num_labels)

    # implemented only because Lightning requires a forward method
    def forward(self):
        pass


def convert_longformer_qa_checkpoint_to_pytorch(
    longformer_model: str, longformer_question_answering_ckpt_path: str, pytorch_dump_folder_path: str
):

    # load longformer model from model identifier
    longformer = LongformerModel.from_pretrained(longformer_model)
    lightning_model = LightningModel(longformer)

    ckpt = torch.load(longformer_question_answering_ckpt_path, map_location=torch.device("cpu"))
    lightning_model.load_state_dict(ckpt["state_dict"])

    # init longformer question answering model
    longformer_for_qa = LongformerForQuestionAnswering.from_pretrained(longformer_model)

    # transfer weights
    longformer_for_qa.longformer.load_state_dict(lightning_model.model.state_dict())
    longformer_for_qa.qa_outputs.load_state_dict(lightning_model.qa_outputs.state_dict())
    longformer_for_qa.eval()

    # save model
    longformer_for_qa.save_pretrained(pytorch_dump_folder_path)

    print("Conversion successful. Model saved under {}".format(pytorch_dump_folder_path))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--longformer_model",
        default=None,
        type=str,
        required=True,
        help="model identifier of longformer. Should be either `longformer-base-4096` or `longformer-large-4096`.",
    )
    parser.add_argument(
        "--longformer_question_answering_ckpt_path",
        default=None,
        type=str,
        required=True,
        help="Path to the official PyTorch Lightning checkpoint.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_longformer_qa_checkpoint_to_pytorch(
        args.longformer_model, args.longformer_question_answering_ckpt_path, args.pytorch_dump_folder_path
    )
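A hedged sketch of how the conversion above might be driven directly from Python instead of the CLI entry point. The checkpoint and output paths are illustrative placeholders, not values taken from this diff, and the snippet assumes `convert_longformer_qa_checkpoint_to_pytorch` is in scope (for example by importing the new script as a module):

from transformers.modeling_longformer import LongformerForQuestionAnswering

convert_longformer_qa_checkpoint_to_pytorch(
    longformer_model="longformer-large-4096",
    longformer_question_answering_ckpt_path="./triviaqa_lightning.ckpt",    # placeholder path
    pytorch_dump_folder_path="./longformer-large-4096-finetuned-triviaqa",  # placeholder path
)

# the dumped folder then loads like any other pretrained model
model = LongformerForQuestionAnswering.from_pretrained("./longformer-large-4096-finetuned-triviaqa")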
@@ -33,6 +33,7 @@ logger = logging.getLogger(__name__)
 
 LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP = {
     "longformer-base-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-base-4096/pytorch_model.bin",
     "longformer-large-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096/pytorch_model.bin",
+    "longformer-large-4096-finetuned-triviaqa": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096-finetuned-triviaqa/pytorch_model.bin",
 }
 
@@ -710,7 +711,7 @@ class LongformerForMaskedLM(BertPreTrainedModel):
 
 
 @add_start_docstrings(
-    """Longformer Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
+    """Longformer Model with a span classification head on top for extractive question-answering tasks like SQuAD / TriviaQA (a linear layers on top of
     the hidden-states output to compute `span start logits` and `span end logits`). """,
     LONGFORMER_START_DOCSTRING,
 )
@@ -728,26 +729,27 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
 
         self.init_weights()
 
-    def _get_question_end_index(self, input_ids):
-        sep_token_indices = (input_ids == self.config.sep_token_id).nonzero()
-
-        assert sep_token_indices.size(1) == 2, "input_ids should have two dimensions"
-        assert sep_token_indices.size(0) == 3 * input_ids.size(
-            0
-        ), "There should be exactly three separator tokens in every sample for questions answering"
-
-        return sep_token_indices.view(input_ids.size(0), 3, 2)[:, 0, 1]
-
     def _compute_global_attention_mask(self, input_ids):
         question_end_index = self._get_question_end_index(input_ids)
         question_end_index = question_end_index.unsqueeze(dim=1)  # size: batch_size x 1
         # bool attention mask with True in locations of global attention
-        attention_mask = torch.arange(input_ids.size(1), device=input_ids.device)
+        attention_mask = torch.arange(input_ids.shape[1], device=input_ids.device)
         attention_mask = attention_mask.expand_as(input_ids) < question_end_index
 
-        attention_mask = attention_mask.int() + 1  # from True, False to 2, 1
+        attention_mask = attention_mask.int() + 1  # True => global attention; False => local attention
         return attention_mask.long()
 
+    def _get_question_end_index(self, input_ids):
+        sep_token_indices = (input_ids == self.config.sep_token_id).nonzero()
+        batch_size = input_ids.shape[0]
+
+        assert sep_token_indices.shape[1] == 2, "`input_ids` should have two dimensions"
+        assert (
+            sep_token_indices.shape[0] == 3 * batch_size
+        ), f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for question answering"
+
+        return sep_token_indices.view(batch_size, 3, 2)[:, 0, 1]
+
     @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING)
     def forward(
         self,
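A standalone sketch (not part of this diff) of what the two helpers above compute: for a RoBERTa-style pair `<s> question </s> </s> context </s>`, every position before the first `</s>` is flagged for global attention (value 2) and everything else stays local (value 1). The toy ids and `sep_token_id = 2` below are assumptions chosen to mirror the RoBERTa/Longformer defaults.

import torch

sep_token_id = 2  # assumed separator id, as in the RoBERTa/Longformer vocabulary
# one toy sample: <s> q q q </s> </s> c c c </s>
input_ids = torch.tensor([[0, 10, 11, 12, 2, 2, 20, 21, 22, 2]])

sep_token_indices = (input_ids == sep_token_id).nonzero()
batch_size = input_ids.shape[0]
# index of the first </s> in each sample, i.e. where the question ends
question_end_index = sep_token_indices.view(batch_size, 3, 2)[:, 0, 1].unsqueeze(dim=1)

mask = torch.arange(input_ids.shape[1], device=input_ids.device)
mask = (mask.expand_as(input_ids) < question_end_index).int() + 1
print(mask.long())  # tensor([[2, 2, 2, 2, 1, 1, 1, 1, 1, 1]])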
@@ -769,7 +771,7 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
             Positions are clamped to the length of the sequence (`sequence_length`).
             Position outside of the sequence are not taken into account for computing the loss.
     Returns:
-        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
+        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.LongformerConfig`) and inputs:
         loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
             Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
         start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
@@ -785,24 +787,29 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        from transformers import LongformerTokenizer, LongformerForQuestionAnswering
        import torch

-        tokenizer = LongformerTokenizer.from_pretrained('longformer-base-4096')
-        model = LongformerForQuestionAnswering.from_pretrained('longformer-base-4096')
+        tokenizer = LongformerTokenizer.from_pretrained("longformer-large-4096-finetuned-triviaqa")
+        model = LongformerForQuestionAnswering.from_pretrained("longformer-large-4096-finetuned-triviaqa")

        question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
-        encoding = tokenizer.encode_plus(question, text)
+        encoding = tokenizer.encode_plus(question, text, return_tensors="pt")
        input_ids = encoding["input_ids"]

        # default is local attention everywhere
        # the forward method will automatically set global attention on question tokens
        attention_mask = encoding["attention_mask"]

-        start_scores, end_scores = model(torch.tensor([input_ids]), attention_mask=attention_mask)
-        all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
-        answer = ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1])
+        start_scores, end_scores = model(input_ids, attention_mask=attention_mask)
+        all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
+
+        answer_tokens = all_tokens[torch.argmax(start_scores) :torch.argmax(end_scores)+1]
+        answer = tokenizer.decode(tokenizer.convert_tokens_to_ids(answer_tokens)) # remove space prepending space token

    """

    # set global attention on question tokens
@@ -24,12 +24,13 @@ logger = logging.getLogger(__name__)
 # vocab and merges same as roberta
 vocab_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json"
 merges_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt"
-_all_longformer_models = ["longformer-base-4096", "longformer-large-4096"]
+_all_longformer_models = ["longformer-base-4096", "longformer-large-4096", "longformer-large-4096-finetuned-triviaqa"]
 
 
 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
     "longformer-base-4096": 4096,
     "longformer-large-4096": 4096,
+    "longformer-large-4096-finetuned-triviaqa": 4096,
 }
 
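Finally, a small sketch (again not part of this diff) of the tokenizer side: once the identifier is listed in `_all_longformer_models` and `PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES`, the fine-tuned checkpoint loads a tokenizer that reuses the RoBERTa vocab and merges files referenced above.

from transformers import LongformerTokenizer

# resolves via the updated lists above; vocab/merges are shared with roberta-large
tokenizer = LongformerTokenizer.from_pretrained("longformer-large-4096-finetuned-triviaqa")
print(tokenizer.tokenize("Who was Jim Henson?"))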