transformers/tests/test_modeling_encoder_decoder.py
Patrick von Platen 1d6e71e116
[EncoderDecoder] Add Cross Attention for GPT2 (#6415)
* add cross attention layers for gpt2

* make gpt2 cross attention work

* finish bert2gpt2

* add explicit comments

* remove attention mask since not yet supported

* revert attn mask in pipeline

* Update src/transformers/modeling_gpt2.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_encoder_decoder.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
2020-08-14 09:43:29 +02:00

483 lines
19 KiB
Python

# coding=utf-8
# Copyright 2020 HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_modeling_bert import BertModelTester
from .test_modeling_common import ids_tensor
from .test_modeling_gpt2 import GPT2ModelTester
from .test_modeling_roberta import RobertaModelTester
if is_torch_available():
from transformers import (
BertModel,
BertLMHeadModel,
GPT2LMHeadModel,
RobertaModel,
RobertaForCausalLM,
EncoderDecoderModel,
EncoderDecoderConfig,
)
import numpy as np
import torch
@require_torch
class EncoderDecoderMixin:
    """Shared test logic for ``EncoderDecoderModel`` built from a concrete encoder/decoder pair.

    Concrete test cases combine this mixin with ``unittest.TestCase`` and implement the three hooks
    ``get_encoder_decoder_model``, ``prepare_config_and_inputs`` and ``get_pretrained_model``.
    Every ``check_*`` helper is called with the dict returned by ``prepare_config_and_inputs``
    unpacked as keyword arguments; entries a helper does not use are absorbed by ``**kwargs``.
    """

    def get_encoder_decoder_model(self, config, decoder_config):
        """Return an ``(encoder_model, decoder_model)`` pair built from ``config`` and ``decoder_config``."""
        raise NotImplementedError

    def prepare_config_and_inputs(self):
        """Return the keyword arguments (configs and input tensors) consumed by the ``check_*`` helpers."""
        raise NotImplementedError

    def get_pretrained_model(self):
        """Return a pretrained ``EncoderDecoderModel`` used by the slow save/load round-trip test."""
        raise NotImplementedError

    @staticmethod
    def _logits_to_numpy(logits):
        """Return ``logits`` as a NumPy array on the CPU with NaN entries set to 0.

        Zeroing NaNs keeps the max-abs-difference comparisons in the save/load checks well defined.
        """
        array = logits.cpu().numpy()
        array[np.isnan(array)] = 0
        return array

    def check_encoder_decoder_model_from_pretrained_configs(
        self,
        config,
        input_ids,
        attention_mask,
        encoder_hidden_states,
        decoder_config,
        decoder_input_ids,
        decoder_attention_mask,
        **kwargs
    ):
        """Build the model from a combined ``EncoderDecoderConfig`` and check its flags and output shapes."""
        encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
        self.assertTrue(encoder_decoder_config.decoder.is_decoder)

        enc_dec_model = EncoderDecoderModel(encoder_decoder_config)
        enc_dec_model.to(torch_device)
        enc_dec_model.eval()

        self.assertTrue(enc_dec_model.config.is_encoder_decoder)

        outputs_encoder_decoder = enc_dec_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )

        # output 0: one vocab-sized vector per decoder token; output 1: one hidden vector per encoder token
        self.assertEqual(outputs_encoder_decoder[0].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
        self.assertEqual(outputs_encoder_decoder[1].shape, (input_ids.shape + (config.hidden_size,)))

    def check_encoder_decoder_model(
        self,
        config,
        input_ids,
        attention_mask,
        encoder_hidden_states,
        decoder_config,
        decoder_input_ids,
        decoder_attention_mask,
        **kwargs
    ):
        """Wrap an explicit encoder/decoder pair and check output shapes, also with precomputed encoder outputs."""
        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
        enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
        self.assertTrue(enc_dec_model.config.decoder.is_decoder)
        self.assertTrue(enc_dec_model.config.decoder.add_cross_attention)
        self.assertTrue(enc_dec_model.config.is_encoder_decoder)
        enc_dec_model.to(torch_device)

        outputs_encoder_decoder = enc_dec_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )

        self.assertEqual(outputs_encoder_decoder[0].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
        self.assertEqual(outputs_encoder_decoder[1].shape, (input_ids.shape + (config.hidden_size,)))

        # Same call, but feeding ready-made encoder outputs instead of input_ids.
        encoder_outputs = (encoder_hidden_states,)
        outputs_encoder_decoder = enc_dec_model(
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )

        self.assertEqual(outputs_encoder_decoder[0].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
        self.assertEqual(outputs_encoder_decoder[1].shape, (input_ids.shape + (config.hidden_size,)))

    def check_encoder_decoder_model_from_pretrained(
        self,
        config,
        input_ids,
        attention_mask,
        encoder_hidden_states,
        decoder_config,
        decoder_input_ids,
        decoder_attention_mask,
        **kwargs
    ):
        """Build the model via ``from_encoder_decoder_pretrained`` from model instances and check output shapes."""
        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
        # Pass the instances explicitly instead of rebinding ``kwargs``, which would shadow the parameter.
        enc_dec_model = EncoderDecoderModel.from_encoder_decoder_pretrained(
            encoder_model=encoder_model, decoder_model=decoder_model
        )
        enc_dec_model.to(torch_device)

        outputs_encoder_decoder = enc_dec_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )

        self.assertEqual(outputs_encoder_decoder[0].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
        self.assertEqual(outputs_encoder_decoder[1].shape, (input_ids.shape + (config.hidden_size,)))

    def check_save_and_load(
        self,
        config,
        input_ids,
        attention_mask,
        encoder_hidden_states,
        decoder_config,
        decoder_input_ids,
        decoder_attention_mask,
        **kwargs
    ):
        """Save the whole model with ``save_pretrained``, reload it and check the reloaded outputs match."""
        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
        enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
        enc_dec_model.to(torch_device)
        enc_dec_model.eval()
        with torch.no_grad():
            outputs = enc_dec_model(
                input_ids=input_ids,
                decoder_input_ids=decoder_input_ids,
                attention_mask=attention_mask,
                decoder_attention_mask=decoder_attention_mask,
            )
            out_2 = self._logits_to_numpy(outputs[0])

            with tempfile.TemporaryDirectory() as tmpdirname:
                enc_dec_model.save_pretrained(tmpdirname)
                # Run the *reloaded* model: re-running `enc_dec_model` would only compare it with itself.
                loaded_model = EncoderDecoderModel.from_pretrained(tmpdirname)
                loaded_model.to(torch_device)
                loaded_model.eval()

                after_outputs = loaded_model(
                    input_ids=input_ids,
                    decoder_input_ids=decoder_input_ids,
                    attention_mask=attention_mask,
                    decoder_attention_mask=decoder_attention_mask,
                )
                out_1 = self._logits_to_numpy(after_outputs[0])
                max_diff = np.amax(np.abs(out_1 - out_2))
                self.assertLessEqual(max_diff, 1e-5)

    def check_save_and_load_encoder_decoder_model(
        self,
        config,
        input_ids,
        attention_mask,
        encoder_hidden_states,
        decoder_config,
        decoder_input_ids,
        decoder_attention_mask,
        **kwargs
    ):
        """Save encoder and decoder separately, rebuild via ``from_encoder_decoder_pretrained`` and compare outputs."""
        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
        enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
        enc_dec_model.to(torch_device)
        enc_dec_model.eval()
        with torch.no_grad():
            outputs = enc_dec_model(
                input_ids=input_ids,
                decoder_input_ids=decoder_input_ids,
                attention_mask=attention_mask,
                decoder_attention_mask=decoder_attention_mask,
            )
            out_2 = self._logits_to_numpy(outputs[0])

            with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
                enc_dec_model.encoder.save_pretrained(encoder_tmp_dirname)
                enc_dec_model.decoder.save_pretrained(decoder_tmp_dirname)
                # Use the rebuilt model; discarding it would make this comparison vacuous.
                loaded_model = EncoderDecoderModel.from_encoder_decoder_pretrained(
                    encoder_pretrained_model_name_or_path=encoder_tmp_dirname,
                    decoder_pretrained_model_name_or_path=decoder_tmp_dirname,
                )
                loaded_model.to(torch_device)
                loaded_model.eval()

                after_outputs = loaded_model(
                    input_ids=input_ids,
                    decoder_input_ids=decoder_input_ids,
                    attention_mask=attention_mask,
                    decoder_attention_mask=decoder_attention_mask,
                )
                out_1 = self._logits_to_numpy(after_outputs[0])
                max_diff = np.amax(np.abs(out_1 - out_2))
                self.assertLessEqual(max_diff, 1e-5)

    def check_encoder_decoder_model_labels(
        self,
        config,
        input_ids,
        attention_mask,
        encoder_hidden_states,
        decoder_config,
        decoder_input_ids,
        decoder_attention_mask,
        labels,
        **kwargs
    ):
        """With ``labels`` the first output is a loss: check it back-propagates and the other shapes are kept."""
        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
        enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
        enc_dec_model.to(torch_device)
        outputs_encoder_decoder = enc_dec_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels,
        )

        loss = outputs_encoder_decoder[0]
        # check that backprop works
        loss.backward()

        self.assertEqual(outputs_encoder_decoder[1].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
        self.assertEqual(outputs_encoder_decoder[2].shape, (input_ids.shape + (config.hidden_size,)))

    def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
        """Check ``generate`` yields ``decoder_config.max_length`` tokens for every input sequence."""
        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
        enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
        enc_dec_model.to(torch_device)

        # Bert does not have a bos token id, so use pad_token_id instead
        generated_output = enc_dec_model.generate(
            input_ids, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
        )
        self.assertEqual(generated_output.shape, (input_ids.shape[0],) + (decoder_config.max_length,))

    def test_encoder_decoder_model(self):
        input_ids_dict = self.prepare_config_and_inputs()
        self.check_encoder_decoder_model(**input_ids_dict)

    def test_encoder_decoder_model_from_pretrained_configs(self):
        input_ids_dict = self.prepare_config_and_inputs()
        self.check_encoder_decoder_model_from_pretrained_configs(**input_ids_dict)

    def test_encoder_decoder_model_from_pretrained(self):
        input_ids_dict = self.prepare_config_and_inputs()
        self.check_encoder_decoder_model_from_pretrained(**input_ids_dict)

    def test_save_and_load_from_pretrained(self):
        input_ids_dict = self.prepare_config_and_inputs()
        self.check_save_and_load(**input_ids_dict)

    def test_save_and_load_from_encoder_decoder_pretrained(self):
        input_ids_dict = self.prepare_config_and_inputs()
        self.check_save_and_load_encoder_decoder_model(**input_ids_dict)

    def test_encoder_decoder_model_labels(self):
        input_ids_dict = self.prepare_config_and_inputs()
        self.check_encoder_decoder_model_labels(**input_ids_dict)

    def test_encoder_decoder_model_generate(self):
        input_ids_dict = self.prepare_config_and_inputs()
        self.check_encoder_decoder_model_generate(**input_ids_dict)

    @slow
    def test_real_model_save_load_from_pretrained(self):
        """Round-trip a real pretrained model through ``save_pretrained``/``from_pretrained`` and compare outputs."""
        model_2 = self.get_pretrained_model()
        model_2.to(torch_device)
        model_2.eval()
        input_ids = ids_tensor([13, 5], model_2.config.encoder.vocab_size)
        # Decoder ids must index the decoder's vocabulary, which need not match the encoder's.
        decoder_input_ids = ids_tensor([13, 1], model_2.config.decoder.vocab_size)
        attention_mask = ids_tensor([13, 5], vocab_size=2)
        with torch.no_grad():
            outputs = model_2(input_ids=input_ids, decoder_input_ids=decoder_input_ids, attention_mask=attention_mask,)
            out_2 = self._logits_to_numpy(outputs[0])

            with tempfile.TemporaryDirectory() as tmp_dirname:
                model_2.save_pretrained(tmp_dirname)
                model_1 = EncoderDecoderModel.from_pretrained(tmp_dirname)
                model_1.to(torch_device)
                model_1.eval()

                after_outputs = model_1(
                    input_ids=input_ids, decoder_input_ids=decoder_input_ids, attention_mask=attention_mask,
                )
                out_1 = self._logits_to_numpy(after_outputs[0])
                max_diff = np.amax(np.abs(out_1 - out_2))
                self.assertLessEqual(max_diff, 1e-5)
class BertEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
    """Runs the shared encoder-decoder checks with a Bert encoder and a ``BertLMHeadModel`` decoder."""

    def get_encoder_decoder_model(self, config, decoder_config):
        """Instantiate the encoder and decoder directly from their configs (no pretrained weights)."""
        return BertModel(config), BertLMHeadModel(decoder_config)

    def prepare_config_and_inputs(self):
        """Return configs plus tester-generated inputs for the encoder and a cross-attending decoder."""
        tester = BertModelTester(self)
        # Encoder inputs are requested before decoder inputs.
        encoder_side = tester.prepare_config_and_inputs()
        decoder_side = tester.prepare_config_and_inputs_for_decoder()

        config, input_ids, _token_type_ids, input_mask, _sequence_labels, _token_labels, _choice_labels = encoder_side
        (
            decoder_config,
            decoder_input_ids,
            decoder_token_type_ids,
            decoder_input_mask,
            decoder_sequence_labels,
            decoder_token_labels,
            decoder_choice_labels,
            encoder_hidden_states,
            _encoder_attention_mask,
        ) = decoder_side

        # Setting this flag is what gives the decoder its cross-attention layers.
        decoder_config.add_cross_attention = True

        return dict(
            config=config,
            input_ids=input_ids,
            attention_mask=input_mask,
            decoder_config=decoder_config,
            decoder_input_ids=decoder_input_ids,
            decoder_token_type_ids=decoder_token_type_ids,
            decoder_attention_mask=decoder_input_mask,
            decoder_sequence_labels=decoder_sequence_labels,
            decoder_token_labels=decoder_token_labels,
            decoder_choice_labels=decoder_choice_labels,
            encoder_hidden_states=encoder_hidden_states,
            labels=decoder_token_labels,
        )

    def get_pretrained_model(self):
        """Warm-start both halves from the ``bert-base-cased`` checkpoint."""
        return EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "bert-base-cased")
class RoBertaEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
    """Runs the shared encoder-decoder checks with a RoBERTa encoder and a ``RobertaForCausalLM`` decoder."""

    def get_pretrained_model(self):
        """Warm-start both halves from the ``roberta-base`` checkpoint."""
        return EncoderDecoderModel.from_encoder_decoder_pretrained("roberta-base", "roberta-base")

    def get_encoder_decoder_model(self, config, decoder_config):
        """Instantiate the encoder and decoder directly from their configs (no pretrained weights)."""
        return RobertaModel(config), RobertaForCausalLM(decoder_config)

    def prepare_config_and_inputs(self):
        """Return configs plus tester-generated inputs for the encoder and a cross-attending decoder."""
        tester = RobertaModelTester(self)
        # Encoder inputs are requested before decoder inputs.
        encoder_side = tester.prepare_config_and_inputs()
        decoder_side = tester.prepare_config_and_inputs_for_decoder()

        config, input_ids, _token_type_ids, input_mask, _sequence_labels, _token_labels, _choice_labels = encoder_side
        (
            decoder_config,
            decoder_input_ids,
            decoder_token_type_ids,
            decoder_input_mask,
            decoder_sequence_labels,
            decoder_token_labels,
            decoder_choice_labels,
            encoder_hidden_states,
            _encoder_attention_mask,
        ) = decoder_side

        # Setting this flag is what gives the decoder its cross-attention layers.
        decoder_config.add_cross_attention = True

        return dict(
            config=config,
            input_ids=input_ids,
            attention_mask=input_mask,
            decoder_config=decoder_config,
            decoder_input_ids=decoder_input_ids,
            decoder_token_type_ids=decoder_token_type_ids,
            decoder_attention_mask=decoder_input_mask,
            decoder_sequence_labels=decoder_sequence_labels,
            decoder_token_labels=decoder_token_labels,
            decoder_choice_labels=decoder_choice_labels,
            encoder_hidden_states=encoder_hidden_states,
            labels=decoder_token_labels,
        )
class GPT2EncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
    """Runs the shared encoder-decoder checks with a Bert encoder and a ``GPT2LMHeadModel`` decoder."""

    def get_pretrained_model(self):
        """Warm-start a Bert encoder from ``bert-base-cased`` and a GPT-2 decoder from ``gpt2``."""
        return EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")

    def get_encoder_decoder_model(self, config, decoder_config):
        """Instantiate the Bert encoder and GPT-2 decoder directly from their configs (no pretrained weights)."""
        return BertModel(config), GPT2LMHeadModel(decoder_config)

    def prepare_config_and_inputs(self):
        """Return configs plus tester-generated inputs for a Bert encoder and a cross-attending GPT-2 decoder."""
        # Both testers use the same batch size so encoder and decoder inputs line up.
        bert_tester = BertModelTester(self, batch_size=13)
        gpt2_tester = GPT2ModelTester(self, batch_size=13)
        encoder_side = bert_tester.prepare_config_and_inputs()
        decoder_side = gpt2_tester.prepare_config_and_inputs_for_decoder()

        config, input_ids, _token_type_ids, input_mask, _sequence_labels, _token_labels, _choice_labels = encoder_side
        # The GPT-2 tester orders its tuple differently from the Bert tester (mask and head mask come first).
        (
            decoder_config,
            decoder_input_ids,
            decoder_input_mask,
            _decoder_head_mask,
            decoder_token_type_ids,
            decoder_sequence_labels,
            decoder_token_labels,
            decoder_choice_labels,
            encoder_hidden_states,
            _encoder_attention_mask,
        ) = decoder_side

        # Setting this flag is what gives the decoder its cross-attention layers.
        decoder_config.add_cross_attention = True
        # Keep use_cache off for now. NOTE(review): presumably caching is not yet supported together
        # with cross-attention here — confirm before enabling.
        decoder_config.use_cache = False

        return dict(
            config=config,
            input_ids=input_ids,
            attention_mask=input_mask,
            decoder_config=decoder_config,
            decoder_input_ids=decoder_input_ids,
            decoder_token_type_ids=decoder_token_type_ids,
            decoder_attention_mask=decoder_input_mask,
            decoder_sequence_labels=decoder_sequence_labels,
            decoder_token_labels=decoder_token_labels,
            decoder_choice_labels=decoder_choice_labels,
            encoder_hidden_states=encoder_hidden_states,
            labels=decoder_token_labels,
        )