mirror of https://github.com/huggingface/transformers.git
finish reformer qa head (#5433)
commit fe81f7d12c, parent d697b6ca75
@@ -112,3 +112,11 @@ ReformerModelWithLMHead
 
 .. autoclass:: transformers.ReformerModelWithLMHead
     :members:
+
+
+ReformerForQuestionAnswering
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.ReformerForQuestionAnswering
+    :members:
+
@@ -367,6 +367,7 @@ if is_torch_available():
         ReformerLayer,
         ReformerModel,
         ReformerModelWithLMHead,
+        ReformerForQuestionAnswering,
         REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
     )
 
@@ -122,7 +122,7 @@ from .modeling_mobilebert import (
     MobileBertModel,
 )
 from .modeling_openai import OpenAIGPTLMHeadModel, OpenAIGPTModel
-from .modeling_reformer import ReformerModel, ReformerModelWithLMHead
+from .modeling_reformer import ReformerForQuestionAnswering, ReformerModel, ReformerModelWithLMHead
 from .modeling_retribert import RetriBertModel
 from .modeling_roberta import (
     RobertaForMaskedLM,
@@ -310,6 +310,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
         (MobileBertConfig, MobileBertForQuestionAnswering),
         (XLMConfig, XLMForQuestionAnsweringSimple),
         (ElectraConfig, ElectraForQuestionAnswering),
+        (ReformerConfig, ReformerForQuestionAnswering),
     ]
 )
 
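
With the mapping entry added above, AutoModelForQuestionAnswering can now resolve a Reformer config to the new head. A minimal sketch of that dispatch, for illustration only (not part of this diff; the default ReformerConfig() is just a placeholder and yields randomly initialized weights):

    from transformers import AutoModelForQuestionAnswering, ReformerConfig, ReformerForQuestionAnswering

    # from_config walks MODEL_FOR_QUESTION_ANSWERING_MAPPING and picks the class
    # registered for ReformerConfig, i.e. the new ReformerForQuestionAnswering head.
    config = ReformerConfig()
    model = AutoModelForQuestionAnswering.from_config(config)
    assert isinstance(model, ReformerForQuestionAnswering)
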
@@ -1789,3 +1789,109 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
             inputs_dict["num_hashes"] = kwargs["num_hashes"]
 
         return inputs_dict
+
+
+@add_start_docstrings(
+    """Reformer Model with a span classification head on top for
+    extractive question-answering tasks like SQuAD / TriviaQA (a linear layer on
+    top of the hidden-states output to compute `span start logits` and `span end logits`). """,
+    REFORMER_START_DOCSTRING,
+)
+class ReformerForQuestionAnswering(ReformerPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+
+        self.reformer = ReformerModel(config)
+        # 2 * config.hidden_size because we use reversible residual layers
+        self.qa_outputs = nn.Linear(2 * config.hidden_size, config.num_labels)
+
+        self.init_weights()
+
+    def tie_weights(self):
+        # word embeddings are not tied in Reformer
+        pass
+
+    @add_start_docstrings_to_callable(REFORMER_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/reformer-crime-and-punishment")
+    def forward(
+        self,
+        input_ids=None,
+        position_ids=None,
+        attention_mask=None,
+        head_mask=None,
+        inputs_embeds=None,
+        num_hashes=None,
+        start_positions=None,
+        end_positions=None,
+        output_hidden_states=None,
+        output_attentions=None,
+    ):
+        r"""
+        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for the position (index) of the start of the labelled span, used to compute the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Positions outside of the sequence are not taken into account for computing the loss.
+        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
+            Labels for the position (index) of the end of the labelled span, used to compute the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`).
+            Positions outside of the sequence are not taken into account for computing the loss.
+
+    Return:
+        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ReformerConfig`) and inputs:
+        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
+            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
+        start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
+            Span-start scores (before SoftMax).
+        end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
+            Span-end scores (before SoftMax).
+        all_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        all_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
+            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        """
+
+        reformer_outputs = self.reformer(
+            input_ids,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            num_hashes=num_hashes,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+        )
+
+        sequence_output = reformer_outputs[0]
+
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1)
+        end_logits = end_logits.squeeze(-1)
+
+        outputs = (start_logits, end_logits,) + reformer_outputs[1:]
+
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, the split adds a dimension
+            if len(start_positions.size()) > 1:
+                start_positions = start_positions.squeeze(-1)
+            if len(end_positions.size()) > 1:
+                end_positions = end_positions.squeeze(-1)
+            # sometimes the start/end positions are outside our model inputs, we ignore these terms
+            ignored_index = start_logits.size(1)
+            start_positions.clamp_(0, ignored_index)
+            end_positions.clamp_(0, ignored_index)
+
+            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+            outputs = (total_loss,) + outputs
+
+        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
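
For context, a minimal end-to-end sketch of the new question-answering head (illustration only, not part of this diff; the tiny config values and dummy tensors below are assumptions chosen so that a short input divides evenly into the local attention chunks):

    import torch
    from transformers import ReformerConfig, ReformerForQuestionAnswering

    # A tiny, randomly initialized model; real use would fine-tune a pretrained checkpoint.
    config = ReformerConfig(
        vocab_size=32,
        hidden_size=32,
        num_attention_heads=2,
        attention_head_size=16,
        feed_forward_size=64,
        attn_layers=["local", "local"],  # local attention only, so no LSH hashing is involved
        local_attn_chunk_length=4,
        axial_pos_embds=False,           # plain learned position embeddings
        is_decoder=False,
    )
    model = ReformerForQuestionAnswering(config).eval()

    batch_size, seq_length = 2, 8        # seq_length is a multiple of local_attn_chunk_length
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_length))
    start_positions = torch.tensor([1, 2])
    end_positions = torch.tensor([3, 5])

    # With both span labels given, the forward pass returns (loss, start_logits, end_logits, ...)
    loss, start_logits, end_logits = model(
        input_ids, start_positions=start_positions, end_positions=end_positions
    )[:3]
    assert start_logits.shape == (batch_size, seq_length)
    assert end_logits.shape == (batch_size, seq_length)
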
@@ -29,6 +29,7 @@ if is_torch_available():
         ReformerModelWithLMHead,
         ReformerTokenizer,
         ReformerLayer,
+        ReformerForQuestionAnswering,
         REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
     )
     import torch
@@ -43,6 +44,7 @@ class ReformerModelTester:
         is_training=None,
         is_decoder=None,
         use_input_mask=None,
+        use_labels=None,
         vocab_size=None,
         attention_head_size=None,
         hidden_size=None,
@@ -81,6 +83,7 @@ class ReformerModelTester:
         self.is_training = is_training
         self.is_decoder = is_decoder
         self.use_input_mask = use_input_mask
+        self.use_labels = use_labels
         self.vocab_size = vocab_size
         self.attention_head_size = attention_head_size
         self.hidden_size = hidden_size
@@ -128,6 +131,10 @@ class ReformerModelTester:
         if self.use_input_mask:
             input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
 
+        choice_labels = None
+        if self.use_labels:
+            choice_labels = ids_tensor([self.batch_size], 2)
+
         config = ReformerConfig(
             vocab_size=self.vocab_size,
             hidden_size=self.hidden_size,
@@ -160,14 +167,13 @@ class ReformerModelTester:
             config,
             input_ids,
             input_mask,
+            choice_labels,
         )
 
     def check_loss_output(self, result):
         self.parent.assertListEqual(list(result["loss"].size()), [])
 
-    def create_and_check_reformer_model(
-        self, config, input_ids, input_mask,
-    ):
+    def create_and_check_reformer_model(self, config, input_ids, input_mask, choice_labels):
         model = ReformerModel(config=config)
         model.to(torch_device)
         model.eval()
@@ -182,18 +188,14 @@ class ReformerModelTester:
             list(result["sequence_output"].size()), [self.batch_size, self.seq_length, 2 * self.hidden_size],
         )
 
-    def create_and_check_reformer_model_with_lm_backward(
-        self, config, input_ids, input_mask,
-    ):
+    def create_and_check_reformer_model_with_lm_backward(self, config, input_ids, input_mask, choice_labels):
         model = ReformerModelWithLMHead(config=config)
         model.to(torch_device)
         model.eval()
         loss = model(input_ids, attention_mask=input_mask, labels=input_ids)[0]
         loss.backward()
 
-    def create_and_check_reformer_with_lm(
-        self, config, input_ids, input_mask,
-    ):
+    def create_and_check_reformer_with_lm(self, config, input_ids, input_mask, choice_labels):
         model = ReformerModelWithLMHead(config=config)
         model.to(torch_device)
         model.eval()
@@ -207,7 +209,7 @@ class ReformerModelTester:
         )
         self.check_loss_output(result)
 
-    def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, input_mask, is_decoder):
+    def create_and_check_reformer_model_with_attn_mask(self, config, input_ids, input_mask, choice_labels, is_decoder):
         # no special position embeddings
         config.axial_pos_embds = False
         config.is_decoder = is_decoder
@@ -248,7 +250,7 @@ class ReformerModelTester:
 
         self.parent.assertTrue(torch.allclose(output_padded, output_padded_rolled, atol=1e-3))
 
-    def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, is_decoder):
+    def create_and_check_reformer_layer_dropout_seed(self, config, input_ids, input_mask, choice_labels, is_decoder):
         config.is_decoder = is_decoder
         layer = ReformerLayer(config).to(torch_device)
         layer.train()
@@ -281,7 +283,7 @@ class ReformerModelTester:
             torch.allclose(next_hidden_states, hidden_states + feed_forward_hidden_states, atol=1e-3,)
         )
 
-    def create_and_check_reformer_feed_forward_chunking(self, config, input_ids, input_mask):
+    def create_and_check_reformer_feed_forward_chunking(self, config, input_ids, input_mask, choice_labels):
         torch.manual_seed(0)
         model = ReformerModel(config=config)
         model.to(torch_device)
@@ -299,7 +301,7 @@ class ReformerModelTester:
         hidden_states_with_chunk = model(input_ids, attention_mask=input_mask)[0]
         self.parent.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
 
-    def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask):
+    def create_and_check_reformer_feed_backward_chunking(self, config, input_ids, input_mask, choice_labels):
         if not self.is_training:
             return
 
@@ -341,7 +343,7 @@ class ReformerModelTester:
             torch.allclose(grad_slice_position_factor_2_chunk, grad_slice_position_factor_2_no_chunk, atol=1e-3)
         )
 
-    def create_and_check_reformer_random_seed(self, config, input_ids, input_mask):
+    def create_and_check_reformer_random_seed(self, config, input_ids, input_mask, choice_labels):
         layer = ReformerLayer(config).to(torch_device)
         layer.train()
 
@@ -372,7 +374,7 @@ class ReformerModelTester:
             seeds.append(layer.feed_forward_seed)
         self.parent.assertGreater(len(set(seeds)), 70)
 
-    def create_and_check_reformer_model_fp16_forward(self, config, input_ids, input_mask):
+    def create_and_check_reformer_model_fp16_forward(self, config, input_ids, input_mask, choice_labels):
         model = ReformerModel(config=config)
         model.to(torch_device)
         model.half()
@@ -380,7 +382,7 @@ class ReformerModelTester:
         output = model(input_ids, attention_mask=input_mask)[0]
         self.parent.assertFalse(torch.isnan(output).any().item())
 
-    def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask):
+    def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask, choice_labels):
         model = ReformerModelWithLMHead(config=config)
         model.to(torch_device)
         model.half()
@@ -388,7 +390,7 @@ class ReformerModelTester:
         output = model.generate(input_ids, attention_mask=input_mask, do_sample=False)
         self.parent.assertFalse(torch.isnan(output).any().item())
 
-    def create_and_check_reformer_no_chunking(self, config, input_ids, input_mask):
+    def create_and_check_reformer_no_chunking(self, config, input_ids, input_mask, choice_labels):
         # force chunk length to be bigger than input_ids
         config.lsh_attn_chunk_length = 2 * input_ids.shape[-1]
         config.local_attn_chunk_length = 2 * input_ids.shape[-1]
@@ -398,9 +400,25 @@ class ReformerModelTester:
         output_logits = model(input_ids, attention_mask=input_mask)[0]
         self.parent.assertTrue(output_logits.shape[1] == input_ids.shape[-1])
 
+    def create_and_check_longformer_for_question_answering(self, config, input_ids, input_mask, choice_labels):
+        model = ReformerForQuestionAnswering(config=config)
+        model.to(torch_device)
+        model.eval()
+        loss, start_logits, end_logits = model(
+            input_ids, attention_mask=input_mask, start_positions=choice_labels, end_positions=choice_labels,
+        )
+        result = {
+            "loss": loss,
+            "start_logits": start_logits,
+            "end_logits": end_logits,
+        }
+        self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
+        self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
+        self.check_loss_output(result)
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
-        (config, input_ids, input_mask,) = config_and_inputs
+        (config, input_ids, input_mask, choice_labels) = config_and_inputs
         inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
         return config, inputs_dict
 
@@ -470,7 +488,9 @@ class ReformerTesterMixin:
 
 @require_torch
 class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
-    all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
+    all_model_classes = (
+        (ReformerModel, ReformerModelWithLMHead, ReformerForQuestionAnswering) if is_torch_available() else ()
+    )
     all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
     test_pruning = False
     test_headmasking = False
@@ -483,6 +503,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
         "is_training": True,
         "is_decoder": False,
         "use_input_mask": True,
+        "use_labels": True,
         "vocab_size": 32,
         "attention_head_size": 16,
         "hidden_size": 32,
@@ -524,7 +545,9 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
 
 @require_torch
 class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
-    all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
+    all_model_classes = (
+        (ReformerModel, ReformerModelWithLMHead, ReformerForQuestionAnswering) if is_torch_available() else ()
+    )
     all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
     test_pruning = False
     test_headmasking = False
@@ -535,6 +558,7 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
         "batch_size": 13,
         "seq_length": 13,
         "use_input_mask": True,
+        "use_labels": True,
         "is_training": False,
         "is_decoder": False,
         "vocab_size": 32,