diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 01f56001645..547e1d3afaf 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -424,6 +424,7 @@ if is_tf_available(): TFRobertaForMaskedLM, TFRobertaForSequenceClassification, TFRobertaForTokenClassification, + TFRobertaForQuestionAnswering, TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ) @@ -466,6 +467,7 @@ if is_tf_available(): TFAlbertModel, TFAlbertForMaskedLM, TFAlbertForSequenceClassification, + TFAlbertForQuestionAnswering, TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP, ) diff --git a/src/transformers/modeling_tf_albert.py b/src/transformers/modeling_tf_albert.py index 3473c92f819..6dc03af4ead 100644 --- a/src/transformers/modeling_tf_albert.py +++ b/src/transformers/modeling_tf_albert.py @@ -721,7 +721,7 @@ class TFAlbertModel(TFAlbertPreTrainedModel): @add_start_docstrings("""Albert Model with a `language modeling` head on top. """, ALBERT_START_DOCSTRING) class TFAlbertForMaskedLM(TFAlbertPreTrainedModel): def __init__(self, config, *inputs, **kwargs): - super(TFAlbertForMaskedLM, self).__init__(config, *inputs, **kwargs) + super().__init__(config, *inputs, **kwargs) self.albert = TFAlbertMainLayer(config, name="albert") self.predictions = TFAlbertMLMHead(config, self.albert.embeddings, name="predictions") @@ -777,7 +777,7 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel): ) class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel): def __init__(self, config, *inputs, **kwargs): - super(TFAlbertForSequenceClassification, self).__init__(config, *inputs, **kwargs) + super().__init__(config, *inputs, **kwargs) self.num_labels = config.num_labels self.albert = TFAlbertMainLayer(config, name="albert") @@ -826,3 +826,68 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel): outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here return outputs # logits, (hidden_states), (attentions) + + +@add_start_docstrings( + 
"""Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """, + ALBERT_START_DOCSTRING, +) +class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.num_labels = config.num_labels + + self.albert = TFAlbertMainLayer(config, name="albert") + self.qa_outputs = tf.keras.layers.Dense( + config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" + ) + + @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING) + def call(self, inputs, **kwargs): + r""" + Return: + :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs: + start_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`): + Span-start scores (before SoftMax). + end_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`): + Span-end scores (before SoftMax). + hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`): + tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): + tuple of :obj:`tf.Tensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`: + + Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads. + + Examples:: + + # The checkpoint albert-base-v2 is not fine-tuned for question answering. 
Please see the + # examples/run_squad.py example to see how to fine-tune a model to a question answering task. + + import tensorflow as tf + from transformers import AlbertTokenizer, TFAlbertForQuestionAnswering + + tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2') + model = TFAlbertForQuestionAnswering.from_pretrained('albert-base-v2') + input_ids = tokenizer.encode("Who was Jim Henson?", "Jim Henson was a nice puppet") + start_scores, end_scores = model(tf.constant(input_ids)[None, :]) # Batch size 1 + + all_tokens = tokenizer.convert_ids_to_tokens(input_ids) + answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1]) + + """ + outputs = self.albert(inputs, **kwargs) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = tf.split(logits, 2, axis=-1) + start_logits = tf.squeeze(start_logits, axis=-1) + end_logits = tf.squeeze(end_logits, axis=-1) + + outputs = (start_logits, end_logits,) + outputs[2:] + + return outputs # start_logits, end_logits, (hidden_states), (attentions) diff --git a/src/transformers/modeling_tf_auto.py b/src/transformers/modeling_tf_auto.py index 8ea0e6d8ac3..2804368b337 100644 --- a/src/transformers/modeling_tf_auto.py +++ b/src/transformers/modeling_tf_auto.py @@ -36,6 +36,7 @@ from .configuration_utils import PretrainedConfig from .modeling_tf_albert import ( TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP, TFAlbertForMaskedLM, + TFAlbertForQuestionAnswering, TFAlbertForSequenceClassification, TFAlbertModel, ) @@ -62,6 +63,7 @@ from .modeling_tf_openai import TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, TFOp from .modeling_tf_roberta import ( TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, TFRobertaForMaskedLM, + TFRobertaForQuestionAnswering, TFRobertaForSequenceClassification, TFRobertaForTokenClassification, TFRobertaModel, @@ -172,6 +174,8 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict( TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = 
OrderedDict( [ (DistilBertConfig, TFDistilBertForQuestionAnswering), + (AlbertConfig, TFAlbertForQuestionAnswering), + (RobertaConfig, TFRobertaForQuestionAnswering), (BertConfig, TFBertForQuestionAnswering), (XLNetConfig, TFXLNetForQuestionAnsweringSimple), (XLMConfig, TFXLMForQuestionAnsweringSimple), @@ -827,6 +831,8 @@ class TFAutoModelForQuestionAnswering(object): The model class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): - contains `distilbert`: TFDistilBertForQuestionAnswering (DistilBERT model) + - contains `albert`: TFAlbertForQuestionAnswering (ALBERT model) + - contains `roberta`: TFRobertaForQuestionAnswering (RoBERTa model) - contains `bert`: TFBertForQuestionAnswering (Bert model) - contains `xlnet`: TFXLNetForQuestionAnswering (XLNet model) - contains `xlm`: TFXLMForQuestionAnswering (XLM model) @@ -849,6 +855,8 @@ class TFAutoModelForQuestionAnswering(object): config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`: The model class to instantiate is selected based on the configuration class: - isInstance of `distilbert` configuration class: DistilBertModel (DistilBERT model) + - isInstance of `albert` configuration class: AlbertModel (ALBERT model) + - isInstance of `roberta` configuration class: RobertaModel (RoBERTa model) - isInstance of `bert` configuration class: BertModel (Bert model) - isInstance of `xlnet` configuration class: XLNetModel (XLNet model) - isInstance of `xlm` configuration class: XLMModel (XLM model) @@ -856,7 +864,7 @@ class TFAutoModelForQuestionAnswering(object): Examples:: config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. - model = AutoModelForSequenceClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')` + model = TFAutoModelForQuestionAnswering.from_config(config) # E.g. 
model was saved using `save_pretrained('./test/saved_model/')` """ for config_class, model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.items(): if isinstance(config, config_class): @@ -882,6 +890,8 @@ class TFAutoModelForQuestionAnswering(object): The model class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): - contains `distilbert`: TFDistilBertForQuestionAnswering (DistilBERT model) + - contains `albert`: TFAlbertForQuestionAnswering (ALBERT model) + - contains `roberta`: TFRobertaForQuestionAnswering (RoBERTa model) - contains `bert`: TFBertForQuestionAnswering (Bert model) - contains `xlnet`: TFXLNetForQuestionAnswering (XLNet model) - contains `xlm`: TFXLMForQuestionAnswering (XLM model) diff --git a/src/transformers/modeling_tf_roberta.py b/src/transformers/modeling_tf_roberta.py index 31fb43f1cc6..abe6e844e69 100644 --- a/src/transformers/modeling_tf_roberta.py +++ b/src/transformers/modeling_tf_roberta.py @@ -442,3 +442,68 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel): outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here return outputs # scores, (hidden_states), (attentions) + + +@add_start_docstrings( + """RoBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). 
""", + ROBERTA_START_DOCSTRING, +) +class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.num_labels = config.num_labels + + self.roberta = TFRobertaMainLayer(config, name="roberta") + self.qa_outputs = tf.keras.layers.Dense( + config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" + ) + + @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING) + def call(self, inputs, **kwargs): + r""" + Return: + :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs: + start_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`): + Span-start scores (before SoftMax). + end_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`): + Span-end scores (before SoftMax). + hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`): + tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``): + tuple of :obj:`tf.Tensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`: + + Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads. + + Examples:: + + # The checkpoint roberta-base is not fine-tuned for question answering. Please see the + # examples/run_squad.py example to see how to fine-tune a model to a question answering task. 
+ + import tensorflow as tf + from transformers import RobertaTokenizer, TFRobertaForQuestionAnswering + + tokenizer = RobertaTokenizer.from_pretrained('roberta-base') + model = TFRobertaForQuestionAnswering.from_pretrained('roberta-base') + input_ids = tokenizer.encode("Who was Jim Henson?", "Jim Henson was a nice puppet") + start_scores, end_scores = model(tf.constant(input_ids)[None, :]) # Batch size 1 + + all_tokens = tokenizer.convert_ids_to_tokens(input_ids) + answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1]) + + """ + outputs = self.roberta(inputs, **kwargs) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = tf.split(logits, 2, axis=-1) + start_logits = tf.squeeze(start_logits, axis=-1) + end_logits = tf.squeeze(end_logits, axis=-1) + + outputs = (start_logits, end_logits,) + outputs[2:] + + return outputs # start_logits, end_logits, (hidden_states), (attentions) diff --git a/tests/test_modeling_tf_albert.py b/tests/test_modeling_tf_albert.py index fb7b269cdca..00ddcc45dc9 100644 --- a/tests/test_modeling_tf_albert.py +++ b/tests/test_modeling_tf_albert.py @@ -28,6 +28,7 @@ if is_tf_available(): TFAlbertModel, TFAlbertForMaskedLM, TFAlbertForSequenceClassification, + TFAlbertForQuestionAnswering, TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP, ) @@ -36,7 +37,9 @@ if is_tf_available(): class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase): all_model_classes = ( - (TFAlbertModel, TFAlbertForMaskedLM, TFAlbertForSequenceClassification) if is_tf_available() else () + (TFAlbertModel, TFAlbertForMaskedLM, TFAlbertForSequenceClassification, TFAlbertForQuestionAnswering) + if is_tf_available() + else () ) class TFAlbertModelTester(object): @@ -175,6 +178,19 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase): } self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels]) + def 
create_and_check_albert_for_question_answering( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = TFAlbertForQuestionAnswering(config=config) + inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids} + start_logits, end_logits = model(inputs) + result = { + "start_logits": start_logits.numpy(), + "end_logits": end_logits.numpy(), + } + self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length]) + self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length]) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( @@ -208,6 +224,10 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_albert_for_sequence_classification(*config_and_inputs) + def test_for_question_answering(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_albert_for_question_answering(*config_and_inputs) + @slow def test_model_from_pretrained(self): for model_name in list(TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: diff --git a/tests/test_modeling_tf_roberta.py b/tests/test_modeling_tf_roberta.py index 9bc837c4e3c..7432cc36074 100644 --- a/tests/test_modeling_tf_roberta.py +++ b/tests/test_modeling_tf_roberta.py @@ -31,6 +31,7 @@ if is_tf_available(): TFRobertaForMaskedLM, TFRobertaForSequenceClassification, TFRobertaForTokenClassification, + TFRobertaForQuestionAnswering, TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ) @@ -39,7 +40,15 @@ if is_tf_available(): class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase): all_model_classes = ( - (TFRobertaModel, TFRobertaForMaskedLM, TFRobertaForSequenceClassification) if is_tf_available() else () + ( + TFRobertaModel, + TFRobertaForMaskedLM, + 
TFRobertaForSequenceClassification, + TFRobertaForTokenClassification, + TFRobertaForQuestionAnswering, + ) + if is_tf_available() + else () ) class TFRobertaModelTester(object): @@ -171,6 +180,19 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase): list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels] ) + def create_and_check_roberta_for_question_answering( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = TFRobertaForQuestionAnswering(config=config) + inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids} + start_logits, end_logits = model(inputs) + result = { + "start_logits": start_logits.numpy(), + "end_logits": end_logits.numpy(), + } + self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length]) + self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length]) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( @@ -204,6 +226,10 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_roberta_for_token_classification(*config_and_inputs) + def test_for_question_answering(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_roberta_for_question_answering(*config_and_inputs) + @slow def test_model_from_pretrained(self): for model_name in list(TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: