transformers/tests/test_modeling_tf_roberta.py
Yih-Dar 8b240a0661
Add TFEncoderDecoderModel + Add cross-attention to some TF models (#13222)
* Add cross attentions to TFGPT2Model

* Add TFEncoderDecoderModel

* Add TFBaseModelOutputWithPoolingAndCrossAttentions

* Add cross attentions to TFBertModel

* Fix past or past_key_values argument issue

* Fix generation

* Fix save and load

* Add some checks and comments

* Clean the code that deals with past keys/values

* Add kwargs to processing_inputs

* Add serving_output to TFEncoderDecoderModel

* Some cleaning + fix use_cache value issue

* Fix tests + add bert2bert/bert2gpt2 tests

* Fix more tests

* Ignore crossattention.bias when loading GPT2 weights into TFGPT2

* Fix return_dict_in_generate in tf generation

* Fix is_token_logit_eos_token bug in tf generation

* Finalize the tests after fixing some bugs

* Fix another is_token_logit_eos_token bug in tf generation

* Add/Update docs

* Add TFBertEncoderDecoderModelTest

* Clean test script

* Add TFEncoderDecoderModel to the library

* Add cross attentions to TFRobertaModel

* Add TFRobertaEncoderDecoderModelTest

* make style

* Change the way of position_ids computation

* bug fix

* Fix copies in tf_albert

* Remove some copied from and apply some fix-copies

* Remove some copied

* Add cross attentions to some other TF models

* Remove encoder_hidden_states from TFLayoutLMModel.call for now

* Make style

* Fix TFRemBertForCausalLM

* Revert the change to longformer + Remove copies

* Revert the change to albert and convbert + Remove copies

* make quality

* make style

* Add TFRembertEncoderDecoderModelTest

* make quality and fix-copies

* test TFRobertaForCausalLM

* Fixes for failed tests

* Fixes for failed tests

* fix more tests

* Fixes for failed tests

* Fix Auto mapping order

* Fix TFRemBertEncoder return value

* fix tf_rembert

* Check copies are OK

* Fix missing TFBaseModelOutputWithPastAndCrossAttentions is not defined

* Add TFEncoderDecoderModelSaveLoadTests

* fix tf weight loading

* check the change of use_cache

* Revert the change

* Add missing test_for_causal_lm for TFRobertaModelTest

* Try cleaning past

* fix _reorder_cache

* Revert some files to original versions

* Keep as many copies as possible

* Apply suggested changes - Use raise ValueError instead of assert

* Move import to top

* Fix wrong require_torch

* Replace more assert by raise ValueError

* Add test_pt_tf_model_equivalence (the test won't pass for now)

* add test for loading/saving

* finish

* finish

* Remove test_pt_tf_model_equivalence

* Update tf modeling template

* Remove pooling, added in the prev. commit, from MainLayer

* Update tf modeling test template

* Move inputs["use_cache"] = False to modeling_tf_utils.py

* Fix torch.Tensor in the comment

* fix use_cache

* Fix missing use_cache in ElectraConfig

* Add a note to from_pretrained

* Fix style

* Change test_encoder_decoder_save_load_from_encoder_decoder_from_pt

* Fix TFMLP (in TFGPT2) activation issue

* Fix None past_key_values value in serving_output

* Don't call get_encoderdecoder_model in TFEncoderDecoderModelTest.test_configuration_tie until we have a TF checkpoint on Hub

* Apply review suggestions - style for cross_attns in serving_output

* Apply review suggestions - change assert + docstrings

* break the error message to respect the char limit

* deprecate the argument past

* fix docstring style

* Update the encoder-decoder rst file

* fix Unknown interpreted text role "method"

* fix typo

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2021-10-13 00:10:34 +02:00

305 lines
12 KiB
Python

# coding=utf-8
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import RobertaConfig, is_tf_available
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
if is_tf_available():
import numpy
import tensorflow as tf
from transformers.models.roberta.modeling_tf_roberta import (
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRobertaForCausalLM,
TFRobertaForMaskedLM,
TFRobertaForMultipleChoice,
TFRobertaForQuestionAnswering,
TFRobertaForSequenceClassification,
TFRobertaForTokenClassification,
TFRobertaModel,
)
class TFRobertaModelTester:
def __init__(
self,
parent,
):
self.parent = parent
self.batch_size = 13
self.seq_length = 7
self.is_training = True
self.use_input_mask = True
self.use_token_type_ids = True
self.use_labels = True
self.vocab_size = 99
self.hidden_size = 32
self.num_hidden_layers = 5
self.num_attention_heads = 4
self.intermediate_size = 37
self.hidden_act = "gelu"
self.hidden_dropout_prob = 0.1
self.attention_probs_dropout_prob = 0.1
self.max_position_embeddings = 512
self.type_vocab_size = 16
self.type_sequence_label_size = 2
self.initializer_range = 0.02
self.num_labels = 3
self.num_choices = 4
self.scope = None
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = RobertaConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range,
)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def prepare_config_and_inputs_for_decoder(self):
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = self.prepare_config_and_inputs()
config.is_decoder = True
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
return (
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
)
def create_and_check_roberta_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFRobertaModel(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
inputs = [input_ids, input_mask]
result = model(inputs)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_roberta_for_causal_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFRobertaForCausalLM(config=config)
result = model([input_ids, input_mask, token_type_ids])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_roberta_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFRobertaForMaskedLM(config=config)
result = model([input_ids, input_mask, token_type_ids])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_roberta_for_token_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
model = TFRobertaForTokenClassification(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def create_and_check_roberta_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFRobertaForQuestionAnswering(config=config)
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
result = model(inputs)
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
def create_and_check_roberta_for_multiple_choice(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_choices = self.num_choices
model = TFRobertaForMultipleChoice(config=config)
multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
multiple_choice_token_type_ids = tf.tile(tf.expand_dims(token_type_ids, 1), (1, self.num_choices, 1))
inputs = {
"input_ids": multiple_choice_inputs_ids,
"attention_mask": multiple_choice_input_mask,
"token_type_ids": multiple_choice_token_type_ids,
}
result = model(inputs)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
return config, inputs_dict
@require_tf
class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
TFRobertaModel,
TFRobertaForCausalLM,
TFRobertaForMaskedLM,
TFRobertaForSequenceClassification,
TFRobertaForTokenClassification,
TFRobertaForQuestionAnswering,
)
if is_tf_available()
else ()
)
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFRobertaModelTester(self)
self.config_tester = ConfigTester(self, config_class=RobertaConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_roberta_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_roberta_model(*config_and_inputs)
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_roberta_for_masked_lm(*config_and_inputs)
def test_for_causal_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_roberta_for_causal_lm(*config_and_inputs)
def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_roberta_for_token_classification(*config_and_inputs)
def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_roberta_for_question_answering(*config_and_inputs)
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_roberta_for_multiple_choice(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
for model_name in TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = TFRobertaModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@require_tf
@require_sentencepiece
@require_tokenizers
class TFRobertaModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_masked_lm(self):
model = TFRobertaForMaskedLM.from_pretrained("roberta-base")
input_ids = tf.constant([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
output = model(input_ids)[0]
expected_shape = [1, 11, 50265]
self.assertEqual(list(output.numpy().shape), expected_shape)
# compare the actual values for a slice.
expected_slice = tf.constant(
[[[33.8802, -4.3103, 22.7761], [4.6539, -2.8098, 13.6253], [1.8228, -3.6898, 8.8600]]]
)
self.assertTrue(numpy.allclose(output[:, :3, :3].numpy(), expected_slice.numpy(), atol=1e-4))
@slow
def test_inference_no_head(self):
model = TFRobertaModel.from_pretrained("roberta-base")
input_ids = tf.constant([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
output = model(input_ids)[0]
# compare the actual values for a slice.
expected_slice = tf.constant(
[[[-0.0231, 0.0782, 0.0074], [-0.1854, 0.0540, -0.0175], [0.0548, 0.0799, 0.1687]]]
)
self.assertTrue(numpy.allclose(output[:, :3, :3].numpy(), expected_slice.numpy(), atol=1e-4))
@slow
def test_inference_classification_head(self):
model = TFRobertaForSequenceClassification.from_pretrained("roberta-large-mnli")
input_ids = tf.constant([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
output = model(input_ids)[0]
expected_shape = [1, 3]
self.assertEqual(list(output.numpy().shape), expected_shape)
expected_tensor = tf.constant([[-0.9469, 0.3913, 0.5118]])
self.assertTrue(numpy.allclose(output.numpy(), expected_tensor.numpy(), atol=1e-4))