Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-24 23:08:57 +06:00)

Question Answering support for Albert and Roberta in TF (#3812)

* Add TFAlbertForQuestionAnswering
* Add TFRobertaForQuestionAnswering
* Update TFAutoModel with Roberta/Albert for QA
* Clean `super` TF Albert calls

Parent: f399c00610
Commit: 6d00033e97
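
For orientation, a minimal usage sketch of the new TF question-answering heads, loaded through the auto class whose mapping this commit updates (the sketch itself is not part of the diff; the stock checkpoints ship with untrained QA heads, so fine-tune first, e.g. with examples/run_squad.py):

    import tensorflow as tf
    from transformers import AutoTokenizer, TFAutoModelForQuestionAnswering

    # 'roberta-base' resolves to the new TFRobertaForQuestionAnswering through the updated mapping.
    tokenizer = AutoTokenizer.from_pretrained("roberta-base")
    model = TFAutoModelForQuestionAnswering.from_pretrained("roberta-base")

    input_ids = tokenizer.encode("Who was Jim Henson?", "Jim Henson was a nice puppet")
    start_logits, end_logits = model(tf.constant(input_ids)[None, :])  # batch size 1, one score per token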
@@ -424,6 +424,7 @@ if is_tf_available():
         TFRobertaForMaskedLM,
         TFRobertaForSequenceClassification,
         TFRobertaForTokenClassification,
+        TFRobertaForQuestionAnswering,
         TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
     )

@@ -466,6 +467,7 @@ if is_tf_available():
         TFAlbertModel,
         TFAlbertForMaskedLM,
         TFAlbertForSequenceClassification,
+        TFAlbertForQuestionAnswering,
         TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
     )

@@ -721,7 +721,7 @@ class TFAlbertModel(TFAlbertPreTrainedModel):
 @add_start_docstrings("""Albert Model with a `language modeling` head on top. """, ALBERT_START_DOCSTRING)
 class TFAlbertForMaskedLM(TFAlbertPreTrainedModel):
     def __init__(self, config, *inputs, **kwargs):
-        super(TFAlbertForMaskedLM, self).__init__(config, *inputs, **kwargs)
+        super().__init__(config, *inputs, **kwargs)

         self.albert = TFAlbertMainLayer(config, name="albert")
         self.predictions = TFAlbertMLMHead(config, self.albert.embeddings, name="predictions")
@@ -777,7 +777,7 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel):
 )
 class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel):
     def __init__(self, config, *inputs, **kwargs):
-        super(TFAlbertForSequenceClassification, self).__init__(config, *inputs, **kwargs)
+        super().__init__(config, *inputs, **kwargs)
         self.num_labels = config.num_labels

         self.albert = TFAlbertMainLayer(config, name="albert")
@@ -826,3 +826,68 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel):
         outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

         return outputs  # logits, (hidden_states), (attentions)
+
+
+@add_start_docstrings(
+    """Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). """,
+    ALBERT_START_DOCSTRING,
+)
+class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel):
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.num_labels = config.num_labels
+
+        self.albert = TFAlbertMainLayer(config, name="albert")
+        self.qa_outputs = tf.keras.layers.Dense(
+            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
+        )
+
+    @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
+    def call(self, inputs, **kwargs):
+        r"""
+        Return:
+            :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
+            start_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`):
+                Span-start scores (before SoftMax).
+            end_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`):
+                Span-end scores (before SoftMax).
+            hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
+                tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
+                of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+            attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
+                tuple of :obj:`tf.Tensor` (one for each layer) of shape
+                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
+
+                Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+        Examples::
+
+            # The checkpoint albert-base-v2 is not fine-tuned for question answering. Please see the
+            # examples/run_squad.py example to see how to fine-tune a model on a question answering task.
+
+            import tensorflow as tf
+            from transformers import AlbertTokenizer, TFAlbertForQuestionAnswering
+
+            tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
+            model = TFAlbertForQuestionAnswering.from_pretrained('albert-base-v2')
+            input_ids = tokenizer.encode("Who was Jim Henson?", "Jim Henson was a nice puppet")
+            start_scores, end_scores = model(tf.constant(input_ids)[None, :])  # Batch size 1
+
+            all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
+            answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0] + 1])
+
+        """
+        outputs = self.albert(inputs, **kwargs)
+
+        sequence_output = outputs[0]
+
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = tf.split(logits, 2, axis=-1)
+        start_logits = tf.squeeze(start_logits, axis=-1)
+        end_logits = tf.squeeze(end_logits, axis=-1)
+
+        outputs = (start_logits, end_logits,) + outputs[2:]
+
+        return outputs  # start_logits, end_logits, (hidden_states), (attentions)
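
The call above returns span logits only; these TF heads do not compute a loss internally, so a training loop has to add one. A hypothetical sketch of the usual SQuAD-style objective, where qa_span_loss, start_positions and end_positions are names introduced here for illustration and not part of the diff:

    import tensorflow as tf

    def qa_span_loss(start_logits, end_logits, start_positions, end_positions):
        # start_positions / end_positions: integer label tensors of shape (batch_size,), assumed to come from the dataset.
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        start_loss = loss_fn(start_positions, start_logits)
        end_loss = loss_fn(end_positions, end_logits)
        # Average the start- and end-span cross-entropies, as is standard for extractive QA.
        return (start_loss + end_loss) / 2.0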
@@ -36,6 +36,7 @@ from .configuration_utils import PretrainedConfig
 from .modeling_tf_albert import (
     TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
     TFAlbertForMaskedLM,
+    TFAlbertForQuestionAnswering,
     TFAlbertForSequenceClassification,
     TFAlbertModel,
 )
@@ -62,6 +63,7 @@ from .modeling_tf_openai import TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, TFOp
 from .modeling_tf_roberta import (
     TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
     TFRobertaForMaskedLM,
+    TFRobertaForQuestionAnswering,
     TFRobertaForSequenceClassification,
     TFRobertaForTokenClassification,
     TFRobertaModel,
@@ -172,6 +174,8 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
 TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
     [
         (DistilBertConfig, TFDistilBertForQuestionAnswering),
+        (AlbertConfig, TFAlbertForQuestionAnswering),
+        (RobertaConfig, TFRobertaForQuestionAnswering),
         (BertConfig, TFBertForQuestionAnswering),
         (XLNetConfig, TFXLNetForQuestionAnsweringSimple),
         (XLMConfig, TFXLMForQuestionAnsweringSimple),
@@ -827,6 +831,8 @@ class TFAutoModelForQuestionAnswering(object):
         The model class to instantiate is selected as the first pattern matching
         in the `pretrained_model_name_or_path` string (in the following order):
             - contains `distilbert`: TFDistilBertForQuestionAnswering (DistilBERT model)
+            - contains `albert`: TFAlbertForQuestionAnswering (ALBERT model)
+            - contains `roberta`: TFRobertaForQuestionAnswering (RoBERTa model)
             - contains `bert`: TFBertForQuestionAnswering (Bert model)
             - contains `xlnet`: TFXLNetForQuestionAnswering (XLNet model)
             - contains `xlm`: TFXLMForQuestionAnswering (XLM model)
@@ -849,6 +855,8 @@ class TFAutoModelForQuestionAnswering(object):
             config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
                 The model class to instantiate is selected based on the configuration class:
                     - isInstance of `distilbert` configuration class: DistilBertModel (DistilBERT model)
+                    - isInstance of `albert` configuration class: AlbertModel (ALBERT model)
+                    - isInstance of `roberta` configuration class: RobertaModel (RoBERTa model)
                     - isInstance of `bert` configuration class: BertModel (Bert model)
                     - isInstance of `xlnet` configuration class: XLNetModel (XLNet model)
                     - isInstance of `xlm` configuration class: XLMModel (XLM model)
@@ -856,7 +864,7 @@ class TFAutoModelForQuestionAnswering(object):
         Examples::

            config = BertConfig.from_pretrained('bert-base-uncased')  # Download configuration from S3 and cache.
-           model = AutoModelForSequenceClassification.from_config(config)  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
+           model = TFAutoModelForQuestionAnswering.from_config(config)  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
         """
         for config_class, model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
             if isinstance(config, config_class):
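
A note on the isinstance loop above: RobertaConfig subclasses BertConfig, so the ordering of TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING matters, and the RoBERTa entry is placed before the BERT one so that the first match wins. A small illustrative check, not part of the diff:

    from transformers import BertConfig, RobertaConfig

    config = RobertaConfig()
    # Both checks succeed because of the inheritance; from_config therefore has to try RobertaConfig first.
    assert isinstance(config, RobertaConfig)
    assert isinstance(config, BertConfig)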
@@ -882,6 +890,8 @@ class TFAutoModelForQuestionAnswering(object):
         The model class to instantiate is selected as the first pattern matching
         in the `pretrained_model_name_or_path` string (in the following order):
             - contains `distilbert`: TFDistilBertForQuestionAnswering (DistilBERT model)
+            - contains `albert`: TFAlbertForQuestionAnswering (ALBERT model)
+            - contains `roberta`: TFRobertaForQuestionAnswering (RoBERTa model)
             - contains `bert`: TFBertForQuestionAnswering (Bert model)
             - contains `xlnet`: TFXLNetForQuestionAnswering (XLNet model)
             - contains `xlm`: TFXLMForQuestionAnswering (XLM model)
@@ -442,3 +442,68 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel):
         outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

         return outputs  # scores, (hidden_states), (attentions)
+
+
+@add_start_docstrings(
+    """RoBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). """,
+    ROBERTA_START_DOCSTRING,
+)
+class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel):
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.num_labels = config.num_labels
+
+        self.roberta = TFRobertaMainLayer(config, name="roberta")
+        self.qa_outputs = tf.keras.layers.Dense(
+            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
+        )
+
+    @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
+    def call(self, inputs, **kwargs):
+        r"""
+        Return:
+            :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
+            start_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`):
+                Span-start scores (before SoftMax).
+            end_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`):
+                Span-end scores (before SoftMax).
+            hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
+                tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
+                of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+            attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
+                tuple of :obj:`tf.Tensor` (one for each layer) of shape
+                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
+
+                Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+
+        Examples::
+
+            # The checkpoint roberta-base is not fine-tuned for question answering. Please see the
+            # examples/run_squad.py example to see how to fine-tune a model on a question answering task.
+
+            import tensorflow as tf
+            from transformers import RobertaTokenizer, TFRobertaForQuestionAnswering
+
+            tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
+            model = TFRobertaForQuestionAnswering.from_pretrained('roberta-base')
+            input_ids = tokenizer.encode("Who was Jim Henson?", "Jim Henson was a nice puppet")
+            start_scores, end_scores = model(tf.constant(input_ids)[None, :])  # Batch size 1
+
+            all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
+            answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0] + 1])
+
+        """
+        outputs = self.roberta(inputs, **kwargs)
+
+        sequence_output = outputs[0]
+
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = tf.split(logits, 2, axis=-1)
+        start_logits = tf.squeeze(start_logits, axis=-1)
+        end_logits = tf.squeeze(end_logits, axis=-1)
+
+        outputs = (start_logits, end_logits,) + outputs[2:]
+
+        return outputs  # start_logits, end_logits, (hidden_states), (attentions)
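
The docstring example reconstructs the answer by joining sub-word tokens with spaces, which leaves BPE artefacts for RoBERTa. A hypothetical alternative, not part of the diff, is to decode the selected id span directly:

    import tensorflow as tf
    from transformers import RobertaTokenizer, TFRobertaForQuestionAnswering

    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    model = TFRobertaForQuestionAnswering.from_pretrained('roberta-base')  # QA head is still untrained here

    input_ids = tokenizer.encode("Who was Jim Henson?", "Jim Henson was a nice puppet")
    start_scores, end_scores = model(tf.constant(input_ids)[None, :])

    # Pick the highest-scoring start and end positions, then detokenize that id span.
    start = int(tf.math.argmax(start_scores, 1)[0])
    end = int(tf.math.argmax(end_scores, 1)[0])
    answer = tokenizer.decode(input_ids[start : end + 1])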
@@ -28,6 +28,7 @@ if is_tf_available():
         TFAlbertModel,
         TFAlbertForMaskedLM,
         TFAlbertForSequenceClassification,
+        TFAlbertForQuestionAnswering,
         TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
     )

@@ -36,7 +37,9 @@ if is_tf_available():
 class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):

     all_model_classes = (
-        (TFAlbertModel, TFAlbertForMaskedLM, TFAlbertForSequenceClassification) if is_tf_available() else ()
+        (TFAlbertModel, TFAlbertForMaskedLM, TFAlbertForSequenceClassification, TFAlbertForQuestionAnswering)
+        if is_tf_available()
+        else ()
     )

     class TFAlbertModelTester(object):
@@ -175,6 +178,19 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
             }
             self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels])

+        def create_and_check_albert_for_question_answering(
+            self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+        ):
+            model = TFAlbertForQuestionAnswering(config=config)
+            inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
+            start_logits, end_logits = model(inputs)
+            result = {
+                "start_logits": start_logits.numpy(),
+                "end_logits": end_logits.numpy(),
+            }
+            self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
+            self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
+
         def prepare_config_and_inputs_for_common(self):
             config_and_inputs = self.prepare_config_and_inputs()
             (
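
The tester above calls the model with a dict of named tensors; the TF models also accept a bare input_ids tensor, so either style reaches the same forward pass. A small sketch of the two call styles, assuming encode_plus with return_tensors='tf' to build the inputs (not part of the diff):

    import tensorflow as tf
    from transformers import AlbertTokenizer, TFAlbertForQuestionAnswering

    tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
    model = TFAlbertForQuestionAnswering.from_pretrained('albert-base-v2')

    enc = tokenizer.encode_plus("Who was Jim Henson?", "Jim Henson was a nice puppet", return_tensors='tf')

    # Dict-style call, matching the test above ...
    start_logits, end_logits = model({"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]})
    # ... and the equivalent call with just the input_ids tensor.
    start_logits2, end_logits2 = model(enc["input_ids"])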
@@ -208,6 +224,10 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_albert_for_sequence_classification(*config_and_inputs)

+    def test_for_question_answering(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_albert_for_question_answering(*config_and_inputs)
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in list(TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
@@ -31,6 +31,7 @@ if is_tf_available():
         TFRobertaForMaskedLM,
         TFRobertaForSequenceClassification,
         TFRobertaForTokenClassification,
+        TFRobertaForQuestionAnswering,
         TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
     )

@@ -39,7 +40,15 @@ if is_tf_available():
 class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase):

     all_model_classes = (
-        (TFRobertaModel, TFRobertaForMaskedLM, TFRobertaForSequenceClassification) if is_tf_available() else ()
+        (
+            TFRobertaModel,
+            TFRobertaForMaskedLM,
+            TFRobertaForSequenceClassification,
+            TFRobertaForTokenClassification,
+            TFRobertaForQuestionAnswering,
+        )
+        if is_tf_available()
+        else ()
     )

     class TFRobertaModelTester(object):
@@ -171,6 +180,19 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase):
                 list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels]
             )

+        def create_and_check_roberta_for_question_answering(
+            self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+        ):
+            model = TFRobertaForQuestionAnswering(config=config)
+            inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
+            start_logits, end_logits = model(inputs)
+            result = {
+                "start_logits": start_logits.numpy(),
+                "end_logits": end_logits.numpy(),
+            }
+            self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
+            self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
+
         def prepare_config_and_inputs_for_common(self):
             config_and_inputs = self.prepare_config_and_inputs()
             (
@@ -204,6 +226,10 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_roberta_for_token_classification(*config_and_inputs)

+    def test_for_question_answering(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_roberta_for_question_answering(*config_and_inputs)
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in list(TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: