diff --git a/docs/source/model_doc/reformer.rst b/docs/source/model_doc/reformer.rst
index 016c8937a38..40539c3d68d 100644
--- a/docs/source/model_doc/reformer.rst
+++ b/docs/source/model_doc/reformer.rst
@@ -112,3 +112,11 @@ ReformerModelWithLMHead
 
 .. autoclass:: transformers.ReformerModelWithLMHead
     :members:
+
+
+ReformerForQuestionAnswering
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.ReformerForQuestionAnswering
+    :members:
+
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 4447d5443dc..763ba157292 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -367,6 +367,7 @@ if is_torch_available():
         ReformerLayer,
         ReformerModel,
         ReformerModelWithLMHead,
+        ReformerForQuestionAnswering,
         REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
     )
diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py
index fff71ef6862..c33c00b38ef 100644
--- a/src/transformers/modeling_auto.py
+++ b/src/transformers/modeling_auto.py
@@ -122,7 +122,7 @@ from .modeling_mobilebert import (
     MobileBertModel,
 )
 from .modeling_openai import OpenAIGPTLMHeadModel, OpenAIGPTModel
-from .modeling_reformer import ReformerModel, ReformerModelWithLMHead
+from .modeling_reformer import ReformerForQuestionAnswering, ReformerModel, ReformerModelWithLMHead
 from .modeling_retribert import RetriBertModel
 from .modeling_roberta import (
     RobertaForMaskedLM,
@@ -310,6 +310,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
         (MobileBertConfig, MobileBertForQuestionAnswering),
         (XLMConfig, XLMForQuestionAnsweringSimple),
         (ElectraConfig, ElectraForQuestionAnswering),
+        (ReformerConfig, ReformerForQuestionAnswering),
     ]
 )
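With the mapping entry above applied, `AutoModelForQuestionAnswering` should dispatch a Reformer config to the new head. A minimal sketch of that lookup, assuming this diff is applied (the default `ReformerConfig()` values are toy placeholders, and all weights are randomly initialized):

```python
# Hedged sketch: check that the new MODEL_FOR_QUESTION_ANSWERING_MAPPING entry
# resolves ReformerConfig to ReformerForQuestionAnswering.
from transformers import AutoModelForQuestionAnswering, ReformerConfig, ReformerForQuestionAnswering

config = ReformerConfig()  # library-default hyper-parameters, no pretrained weights
model = AutoModelForQuestionAnswering.from_config(config)
assert isinstance(model, ReformerForQuestionAnswering)
```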
diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py
index c31da90578f..8190c14fbf9 100644
--- a/src/transformers/modeling_reformer.py
+++ b/src/transformers/modeling_reformer.py
@@ -1789,3 +1789,109 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
             inputs_dict["num_hashes"] = kwargs["num_hashes"]
 
         return inputs_dict
+
+
+@add_start_docstrings(
+    """Reformer Model with a span classification head on top for
+    extractive question-answering tasks like SQuAD / TriviaQA (a linear layer on
+    top of the hidden-states output to compute `span start logits` and `span end logits`).""",
+    REFORMER_START_DOCSTRING,
+)
+class ReformerForQuestionAnswering(ReformerPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+
+        self.reformer = ReformerModel(config)
+        # 2 * config.hidden_size because we use reversible residual layers
+        self.qa_outputs = nn.Linear(2 * config.hidden_size, config.num_labels)
+
+        self.init_weights()
+
+    def tie_weights(self):
+        # word embeddings are not tied in Reformer
+        pass
+
+    @add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/reformer-crime-and-punishment")
+    def forward(
+        self,
+        input_ids=None,
+        position_ids=None,
+        attention_mask=None,
+        head_mask=None,
+        inputs_embeds=None,
+        num_hashes=None,
+        start_positions=None,
+        end_positions=None,
+        output_hidden_states=None,
+        output_attentions=None,
+    ):
+        r"""
+    start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+        Labels for position (index) of the start of the labelled span for computing the token classification loss.
+        Positions are clamped to the length of the sequence (`sequence_length`).
+        Positions outside of the sequence are not taken into account for computing the loss.
+    end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+        Labels for position (index) of the end of the labelled span for computing the token classification loss.
+        Positions are clamped to the length of the sequence (`sequence_length`).
+        Positions outside of the sequence are not taken into account for computing the loss.
+
+    Return:
+        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ReformerConfig`) and inputs:
+        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`start_positions` and :obj:`end_positions` are provided):
+            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
+        start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
+            Span-start scores (before SoftMax).
+        end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
+            Span-end scores (before SoftMax).
+        all_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        all_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+ """ + + reformer_outputs = self.reformer( + input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + num_hashes=num_hashes, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + ) + + sequence_output = reformer_outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + outputs = (start_logits, end_logits,) + reformer_outputs[1:] + + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions.clamp_(0, ignored_index) + end_positions.clamp_(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + outputs = (total_loss,) + outputs + + return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions) diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index 1a3b9ad4c1d..e3a2667dd68 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -29,6 +29,7 @@ if is_torch_available(): ReformerModelWithLMHead, ReformerTokenizer, ReformerLayer, + ReformerForQuestionAnswering, REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, ) import torch @@ -43,6 +44,7 @@ class ReformerModelTester: is_training=None, is_decoder=None, use_input_mask=None, + use_labels=None, vocab_size=None, attention_head_size=None, hidden_size=None, @@ -81,6 +83,7 @@ class ReformerModelTester: self.is_training = is_training self.is_decoder = is_decoder self.use_input_mask = use_input_mask + self.use_labels = use_labels self.vocab_size = vocab_size self.attention_head_size = attention_head_size self.hidden_size = hidden_size @@ -128,6 +131,10 @@ class ReformerModelTester: if self.use_input_mask: input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + choice_labels = None + if self.use_labels: + choice_labels = ids_tensor([self.batch_size], 2) + config = ReformerConfig( vocab_size=self.vocab_size, hidden_size=self.hidden_size, @@ -160,14 +167,13 @@ class ReformerModelTester: config, input_ids, input_mask, + choice_labels, ) def check_loss_output(self, result): self.parent.assertListEqual(list(result["loss"].size()), []) - def create_and_check_reformer_model( - self, config, input_ids, input_mask, - ): + def create_and_check_reformer_model(self, config, input_ids, input_mask, choice_labels): model = ReformerModel(config=config) model.to(torch_device) model.eval() @@ -182,18 +188,14 @@ class ReformerModelTester: list(result["sequence_output"].size()), [self.batch_size, self.seq_length, 2 * self.hidden_size], ) - def create_and_check_reformer_model_with_lm_backward( - self, config, input_ids, input_mask, - ): + def create_and_check_reformer_model_with_lm_backward(self, config, input_ids, input_mask, choice_labels): model = ReformerModelWithLMHead(config=config) model.to(torch_device) model.eval() loss = model(input_ids, attention_mask=input_mask, labels=input_ids)[0] loss.backward() - def 
diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py
index 1a3b9ad4c1d..e3a2667dd68 100644
--- a/tests/test_modeling_reformer.py
+++ b/tests/test_modeling_reformer.py
@@ -29,6 +29,7 @@ if is_torch_available():
         ReformerModelWithLMHead,
         ReformerTokenizer,
         ReformerLayer,
+        ReformerForQuestionAnswering,
         REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
     )
     import torch
@@ -43,6 +44,7 @@ class ReformerModelTester:
         is_training=None,
         is_decoder=None,
         use_input_mask=None,
+        use_labels=None,
         vocab_size=None,
         attention_head_size=None,
         hidden_size=None,
@@ -81,6 +83,7 @@ class ReformerModelTester:
         self.is_training = is_training
         self.is_decoder = is_decoder
         self.use_input_mask = use_input_mask
+        self.use_labels = use_labels
         self.vocab_size = vocab_size
         self.attention_head_size = attention_head_size
         self.hidden_size = hidden_size
@@ -128,6 +131,10 @@ class ReformerModelTester:
         if self.use_input_mask:
             input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
 
+        choice_labels = None
+        if self.use_labels:
+            choice_labels = ids_tensor([self.batch_size], 2)
+
         config = ReformerConfig(
             vocab_size=self.vocab_size,
             hidden_size=self.hidden_size,
@@ -160,14 +167,13 @@ class ReformerModelTester:
             config,
             input_ids,
             input_mask,
+            choice_labels,
         )
 
     def check_loss_output(self, result):
         self.parent.assertListEqual(list(result["loss"].size()), [])
 
-    def create_and_check_reformer_model(
-        self, config, input_ids, input_mask,
-    ):
+    def create_and_check_reformer_model(self, config, input_ids, input_mask, choice_labels):
         model = ReformerModel(config=config)
         model.to(torch_device)
         model.eval()
@@ -182,18 +188,14 @@ class ReformerModelTester:
             list(result["sequence_output"].size()),
             [self.batch_size, self.seq_length, 2 * self.hidden_size],
         )
 
-    def create_and_check_reformer_model_with_lm_backward(
-        self, config, input_ids, input_mask,
-    ):
+    def create_and_check_reformer_model_with_lm_backward(self, config, input_ids, input_mask, choice_labels):
         model = ReformerModelWithLMHead(config=config)
         model.to(torch_device)
         model.eval()
         loss = model(input_ids, attention_mask=input_mask, labels=input_ids)[0]
         loss.backward()
 
-    def create_and_check_reformer_with_lm(
-        self, config, input_ids, input_mask,
-    ):
+    def create_and_check_reformer_with_lm(self, config, input_ids, input_mask, choice_labels):
         model = ReformerModelWithLMHead(config=config)
         model.to(torch_device)
         model.eval()
@@ -207,7 +209,7 @@ class ReformerModelTester:
             )
         self.check_loss_output(result)
 
-    def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, input_mask, is_decoder):
+    def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, input_mask, choice_labels, is_decoder):
         # no special position embeddings
         config.axial_pos_embds = False
         config.is_decoder = is_decoder
@@ -248,7 +250,7 @@ class ReformerModelTester:
         self.parent.assertTrue(torch.allclose(output_padded, output_padded_rolled, atol=1e-3))
 
-    def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, is_decoder):
+    def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, choice_labels, is_decoder):
         config.is_decoder = is_decoder
         layer = ReformerLayer(config).to(torch_device)
         layer.train()
@@ -281,7 +283,7 @@ class ReformerModelTester:
             torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3,)
         )
 
-    def create_and_check_reformer_feed_forward_chunking(self, config, input_ids, input_mask):
+    def create_and_check_reformer_feed_forward_chunking(self, config, input_ids, input_mask, choice_labels):
         torch.manual_seed(0)
         model = ReformerModel(config=config)
         model.to(torch_device)
@@ -299,7 +301,7 @@ class ReformerModelTester:
         hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)[0]
         self.parent.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
 
-    def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask):
+    def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask, choice_labels):
         if not self.is_training:
             return
@@ -341,7 +343,7 @@ class ReformerModelTester:
             torch.allclose(grad_slice_position_factor_2_chunk, grad_slice_position_factor_2_no_chunk, atol=1e-3)
         )
 
-    def create_and_check_reformer_random_seed(self, config, input_ids, input_mask):
+    def create_and_check_reformer_random_seed(self, config, input_ids, input_mask, choice_labels):
         layer = ReformerLayer(config).to(torch_device)
         layer.train()
@@ -372,7 +374,7 @@ class ReformerModelTester:
             seeds.append(layer.feed_forward_seed)
         self.parent.assertGreater(len(set(seeds)), 70)
 
-    def create_and_check_reformer_model_fp16_forward(self, config, input_ids, input_mask):
+    def create_and_check_reformer_model_fp16_forward(self, config, input_ids, input_mask, choice_labels):
         model = ReformerModel(config=config)
         model.to(torch_device)
         model.half()
@@ -380,7 +382,7 @@ class ReformerModelTester:
         output = model(input_ids, attention_mask=input_mask)[0]
         self.parent.assertFalse(torch.isnan(output).any().item())
 
-    def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask):
+    def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask, choice_labels):
         model = ReformerModelWithLMHead(config=config)
         model.to(torch_device)
         model.half()
@@ -388,7 +390,7 @@ class ReformerModelTester:
         output = model.generate(input_ids, attention_mask=input_mask, do_sample=False)
         self.parent.assertFalse(torch.isnan(output).any().item())
 
-    def create_and_check_reformer_no_chunking(self, config, input_ids, input_mask):
+    def create_and_check_reformer_no_chunking(self, config, input_ids, input_mask, choice_labels):
         # force chunk length to be bigger than input_ids
         config.lsh_attn_chunk_length = 2 * input_ids.shape[-1]
         config.local_attn_chunk_length = 2 * input_ids.shape[-1]
@@ -398,9 +400,25 @@ class ReformerModelTester:
         output_logits = model(input_ids, attention_mask=input_mask)[0]
         self.parent.assertTrue(output_logits.shape[1] == input_ids.shape[-1])
 
+    def create_and_check_reformer_for_question_answering(self, config, input_ids, input_mask, choice_labels):
+        model = ReformerForQuestionAnswering(config=config)
+        model.to(torch_device)
+        model.eval()
+        loss, start_logits, end_logits = model(
+            input_ids, attention_mask=input_mask, start_positions=choice_labels, end_positions=choice_labels,
+        )
+        result = {
+            "loss": loss,
+            "start_logits": start_logits,
+            "end_logits": end_logits,
+        }
+        self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
+        self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
+        self.check_loss_output(result)
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
-        (config, input_ids, input_mask,) = config_and_inputs
+        (config, input_ids, input_mask, choice_labels) = config_and_inputs
         inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
         return config, inputs_dict
 
@@ -470,7 +488,9 @@ class ReformerTesterMixin:
 
 @require_torch
 class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
-    all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
+    all_model_classes = (
+        (ReformerModel, ReformerModelWithLMHead, ReformerForQuestionAnswering) if is_torch_available() else ()
+    )
     all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
     test_pruning = False
     test_headmasking = False
@@ -483,6 +503,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
         "is_training": True,
         "is_decoder": False,
         "use_input_mask": True,
+        "use_labels": True,
         "vocab_size": 32,
         "attention_head_size": 16,
         "hidden_size": 32,
@@ -524,7 +545,9 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
 
 @require_torch
 class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
-    all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
+    all_model_classes = (
+        (ReformerModel, ReformerModelWithLMHead, ReformerForQuestionAnswering) if is_torch_available() else ()
+    )
     all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
     test_pruning = False
    test_headmasking = False
@@ -535,6 +558,7 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
         "batch_size": 13,
         "seq_length": 13,
         "use_input_mask": True,
+        "use_labels": True,
         "is_training": False,
         "is_decoder": False,
         "vocab_size": 32,
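For completeness, a hedged end-to-end sketch of the new head with a tiny config, mirroring the test above. Only `vocab_size`, `attention_head_size`, and `hidden_size` come from the test settings in this diff; every other value is an illustrative assumption, and no pretrained QA checkpoint is implied:

```python
# Toy end-to-end run of ReformerForQuestionAnswering. Config values are
# illustrative assumptions, not from the PR; weights are randomly initialized.
import torch
from transformers import ReformerConfig, ReformerForQuestionAnswering

config = ReformerConfig(
    vocab_size=32,
    hidden_size=32,
    num_attention_heads=2,
    attention_head_size=16,
    attn_layers=["local", "local"],  # local attention only, to keep the toy model simple
    axial_pos_embds=False,           # plain learned position embeddings for short inputs
    is_decoder=False,
    num_labels=2,
)
model = ReformerForQuestionAnswering(config)
model.eval()

batch_size, seq_length = 2, 16
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_length))
start_positions = torch.tensor([3, 7])
end_positions = torch.tensor([5, 9])

with torch.no_grad():
    loss, start_logits, end_logits = model(
        input_ids, start_positions=start_positions, end_positions=end_positions
    )[:3]

assert start_logits.shape == (batch_size, seq_length)
print(loss.item(), start_logits.argmax(-1), end_logits.argmax(-1))
```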