mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-24 23:08:57 +06:00
finish reformer qa head (#5433)
This commit is contained in:
parent
d697b6ca75
commit
fe81f7d12c
@ -112,3 +112,11 @@ ReformerModelWithLMHead
|
||||
|
||||
.. autoclass:: transformers.ReformerModelWithLMHead
|
||||
:members:
|
||||
|
||||
|
||||
ReformerForQuestionAnswering
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.ReformerForQuestionAnswering
|
||||
:members:
|
||||
|
||||
|
@ -367,6 +367,7 @@ if is_torch_available():
|
||||
ReformerLayer,
|
||||
ReformerModel,
|
||||
ReformerModelWithLMHead,
|
||||
ReformerForQuestionAnswering,
|
||||
REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
)
|
||||
|
||||
|
@ -122,7 +122,7 @@ from .modeling_mobilebert import (
|
||||
MobileBertModel,
|
||||
)
|
||||
from .modeling_openai import OpenAIGPTLMHeadModel, OpenAIGPTModel
|
||||
from .modeling_reformer import ReformerModel, ReformerModelWithLMHead
|
||||
from .modeling_reformer import ReformerForQuestionAnswering, ReformerModel, ReformerModelWithLMHead
|
||||
from .modeling_retribert import RetriBertModel
|
||||
from .modeling_roberta import (
|
||||
RobertaForMaskedLM,
|
||||
@ -310,6 +310,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
|
||||
(MobileBertConfig, MobileBertForQuestionAnswering),
|
||||
(XLMConfig, XLMForQuestionAnsweringSimple),
|
||||
(ElectraConfig, ElectraForQuestionAnswering),
|
||||
(ReformerConfig, ReformerForQuestionAnswering),
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -1789,3 +1789,109 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
|
||||
inputs_dict["num_hashes"] = kwargs["num_hashes"]
|
||||
|
||||
return inputs_dict
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""Reformer Model with a span classification head on top for
|
||||
extractive question-answering tasks like SQuAD / TriviaQA ( a linear layer on
|
||||
top of hidden-states output to compute `span start logits` and `span end logits`. """,
|
||||
REFORMER_START_DOCSTRING,
|
||||
)
|
||||
class ReformerForQuestionAnswering(ReformerPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
|
||||
self.reformer = ReformerModel(config)
|
||||
# 2 * config.hidden_size because we use reversible residual layers
|
||||
self.qa_outputs = nn.Linear(2 * config.hidden_size, config.num_labels)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def tie_weights(self):
|
||||
# word embeddings are not tied in Reformer
|
||||
pass
|
||||
|
||||
@add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/reformer-crime-and-punishment")
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
position_ids=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
num_hashes=None,
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
output_hidden_states=None,
|
||||
output_attentions=None,
|
||||
):
|
||||
r"""
|
||||
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
||||
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`).
|
||||
Position outside of the sequence are not taken into account for computing the loss.
|
||||
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
||||
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`).
|
||||
Position outside of the sequence are not taken into account for computing the loss.
|
||||
Return:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ReformerConfig`) and inputs:
|
||||
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
|
||||
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
|
||||
start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
|
||||
Span-start scores (before SoftMax).
|
||||
end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
|
||||
Span-end scores (before SoftMax).
|
||||
all_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
all_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
reformer_outputs = self.reformer(
|
||||
input_ids,
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
num_hashes=num_hashes,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
sequence_output = reformer_outputs[0]
|
||||
|
||||
logits = self.qa_outputs(sequence_output)
|
||||
start_logits, end_logits = logits.split(1, dim=-1)
|
||||
start_logits = start_logits.squeeze(-1)
|
||||
end_logits = end_logits.squeeze(-1)
|
||||
|
||||
outputs = (start_logits, end_logits,) + reformer_outputs[1:]
|
||||
|
||||
if start_positions is not None and end_positions is not None:
|
||||
# If we are on multi-GPU, split add a dimension
|
||||
if len(start_positions.size()) > 1:
|
||||
start_positions = start_positions.squeeze(-1)
|
||||
if len(end_positions.size()) > 1:
|
||||
end_positions = end_positions.squeeze(-1)
|
||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||
ignored_index = start_logits.size(1)
|
||||
start_positions.clamp_(0, ignored_index)
|
||||
end_positions.clamp_(0, ignored_index)
|
||||
|
||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||
start_loss = loss_fct(start_logits, start_positions)
|
||||
end_loss = loss_fct(end_logits, end_positions)
|
||||
total_loss = (start_loss + end_loss) / 2
|
||||
outputs = (total_loss,) + outputs
|
||||
|
||||
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
|
||||
|
@ -29,6 +29,7 @@ if is_torch_available():
|
||||
ReformerModelWithLMHead,
|
||||
ReformerTokenizer,
|
||||
ReformerLayer,
|
||||
ReformerForQuestionAnswering,
|
||||
REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
)
|
||||
import torch
|
||||
@ -43,6 +44,7 @@ class ReformerModelTester:
|
||||
is_training=None,
|
||||
is_decoder=None,
|
||||
use_input_mask=None,
|
||||
use_labels=None,
|
||||
vocab_size=None,
|
||||
attention_head_size=None,
|
||||
hidden_size=None,
|
||||
@ -81,6 +83,7 @@ class ReformerModelTester:
|
||||
self.is_training = is_training
|
||||
self.is_decoder = is_decoder
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.attention_head_size = attention_head_size
|
||||
self.hidden_size = hidden_size
|
||||
@ -128,6 +131,10 @@ class ReformerModelTester:
|
||||
if self.use_input_mask:
|
||||
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||
|
||||
choice_labels = None
|
||||
if self.use_labels:
|
||||
choice_labels = ids_tensor([self.batch_size], 2)
|
||||
|
||||
config = ReformerConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
@ -160,14 +167,13 @@ class ReformerModelTester:
|
||||
config,
|
||||
input_ids,
|
||||
input_mask,
|
||||
choice_labels,
|
||||
)
|
||||
|
||||
def check_loss_output(self, result):
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
|
||||
def create_and_check_reformer_model(
|
||||
self, config, input_ids, input_mask,
|
||||
):
|
||||
def create_and_check_reformer_model(self, config, input_ids, input_mask, choice_labels):
|
||||
model = ReformerModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@ -182,18 +188,14 @@ class ReformerModelTester:
|
||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, 2 * self.hidden_size],
|
||||
)
|
||||
|
||||
def create_and_check_reformer_model_with_lm_backward(
|
||||
self, config, input_ids, input_mask,
|
||||
):
|
||||
def create_and_check_reformer_model_with_lm_backward(self, config, input_ids, input_mask, choice_labels):
|
||||
model = ReformerModelWithLMHead(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss = model(input_ids, attention_mask=input_mask, labels=input_ids)[0]
|
||||
loss.backward()
|
||||
|
||||
def create_and_check_reformer_with_lm(
|
||||
self, config, input_ids, input_mask,
|
||||
):
|
||||
def create_and_check_reformer_with_lm(self, config, input_ids, input_mask, choice_labels):
|
||||
model = ReformerModelWithLMHead(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@ -207,7 +209,7 @@ class ReformerModelTester:
|
||||
)
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, input_mask, is_decoder):
|
||||
def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, input_mask, choice_labels, is_decoder):
|
||||
# no special position embeddings
|
||||
config.axial_pos_embds = False
|
||||
config.is_decoder = is_decoder
|
||||
@ -248,7 +250,7 @@ class ReformerModelTester:
|
||||
|
||||
self.parent.assertTrue(torch.allclose(output_padded, output_padded_rolled, atol=1e-3))
|
||||
|
||||
def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, is_decoder):
|
||||
def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, choice_labels, is_decoder):
|
||||
config.is_decoder = is_decoder
|
||||
layer = ReformerLayer(config).to(torch_device)
|
||||
layer.train()
|
||||
@ -281,7 +283,7 @@ class ReformerModelTester:
|
||||
torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3,)
|
||||
)
|
||||
|
||||
def create_and_check_reformer_feed_forward_chunking(self, config, input_ids, input_mask):
|
||||
def create_and_check_reformer_feed_forward_chunking(self, config, input_ids, input_mask, choice_labels):
|
||||
torch.manual_seed(0)
|
||||
model = ReformerModel(config=config)
|
||||
model.to(torch_device)
|
||||
@ -299,7 +301,7 @@ class ReformerModelTester:
|
||||
hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)[0]
|
||||
self.parent.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
|
||||
|
||||
def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask):
|
||||
def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask, choice_labels):
|
||||
if not self.is_training:
|
||||
return
|
||||
|
||||
@ -341,7 +343,7 @@ class ReformerModelTester:
|
||||
torch.allclose(grad_slice_position_factor_2_chunk, grad_slice_position_factor_2_no_chunk, atol=1e-3)
|
||||
)
|
||||
|
||||
def create_and_check_reformer_random_seed(self, config, input_ids, input_mask):
|
||||
def create_and_check_reformer_random_seed(self, config, input_ids, input_mask, choice_labels):
|
||||
layer = ReformerLayer(config).to(torch_device)
|
||||
layer.train()
|
||||
|
||||
@ -372,7 +374,7 @@ class ReformerModelTester:
|
||||
seeds.append(layer.feed_forward_seed)
|
||||
self.parent.assertGreater(len(set(seeds)), 70)
|
||||
|
||||
def create_and_check_reformer_model_fp16_forward(self, config, input_ids, input_mask):
|
||||
def create_and_check_reformer_model_fp16_forward(self, config, input_ids, input_mask, choice_labels):
|
||||
model = ReformerModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.half()
|
||||
@ -380,7 +382,7 @@ class ReformerModelTester:
|
||||
output = model(input_ids, attention_mask=input_mask)[0]
|
||||
self.parent.assertFalse(torch.isnan(output).any().item())
|
||||
|
||||
def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask):
|
||||
def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask, choice_labels):
|
||||
model = ReformerModelWithLMHead(config=config)
|
||||
model.to(torch_device)
|
||||
model.half()
|
||||
@ -388,7 +390,7 @@ class ReformerModelTester:
|
||||
output = model.generate(input_ids, attention_mask=input_mask, do_sample=False)
|
||||
self.parent.assertFalse(torch.isnan(output).any().item())
|
||||
|
||||
def create_and_check_reformer_no_chunking(self, config, input_ids, input_mask):
|
||||
def create_and_check_reformer_no_chunking(self, config, input_ids, input_mask, choice_labels):
|
||||
# force chunk length to be bigger than input_ids
|
||||
config.lsh_attn_chunk_length = 2 * input_ids.shape[-1]
|
||||
config.local_attn_chunk_length = 2 * input_ids.shape[-1]
|
||||
@ -398,9 +400,25 @@ class ReformerModelTester:
|
||||
output_logits = model(input_ids, attention_mask=input_mask)[0]
|
||||
self.parent.assertTrue(output_logits.shape[1] == input_ids.shape[-1])
|
||||
|
||||
def create_and_check_longformer_for_question_answering(self, config, input_ids, input_mask, choice_labels):
|
||||
model = ReformerForQuestionAnswering(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, start_logits, end_logits = model(
|
||||
input_ids, attention_mask=input_mask, start_positions=choice_labels, end_positions=choice_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"start_logits": start_logits,
|
||||
"end_logits": end_logits,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(config, input_ids, input_mask,) = config_and_inputs
|
||||
(config, input_ids, input_mask, choice_labels) = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
@ -470,7 +488,9 @@ class ReformerTesterMixin:
|
||||
|
||||
@require_torch
|
||||
class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
|
||||
all_model_classes = (
|
||||
(ReformerModel, ReformerModelWithLMHead, ReformerForQuestionAnswering) if is_torch_available() else ()
|
||||
)
|
||||
all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
@ -483,6 +503,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
|
||||
"is_training": True,
|
||||
"is_decoder": False,
|
||||
"use_input_mask": True,
|
||||
"use_labels": True,
|
||||
"vocab_size": 32,
|
||||
"attention_head_size": 16,
|
||||
"hidden_size": 32,
|
||||
@ -524,7 +545,9 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
|
||||
|
||||
@require_torch
|
||||
class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
|
||||
all_model_classes = (
|
||||
(ReformerModel, ReformerModelWithLMHead, ReformerForQuestionAnswering) if is_torch_available() else ()
|
||||
)
|
||||
all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
@ -535,6 +558,7 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
|
||||
"batch_size": 13,
|
||||
"seq_length": 13,
|
||||
"use_input_mask": True,
|
||||
"use_labels": True,
|
||||
"is_training": False,
|
||||
"is_decoder": False,
|
||||
"vocab_size": 32,
|
||||
|
Loading…
Reference in New Issue
Block a user