[Bart] Refactor - fix issues, consistency with the library, naming (#8900)

* remove on-the-fly creation of the linear layer from the embedding

* start refactor

* big first refactor

* save intermediate

* save intermediate

* correct mask issue

* save tests

* refactor padding masks

* make all tests pass

* further refactor

* make pegasus test pass

* fix bool if

* fix leftover tests

* continue

* bart renaming

* delete torchscript test hack

* fix imports in tests

* correct shift

* fix docs and repo consistency

* re-add fix for FSMT

* typo in test

* fix typo

* fix another typo

* continue

* hot fix 2 for tf

* small fixes

* refactor types linting

* continue

* finish refactor

* fix import in tests

* better bart names

* further refactor and add test

* delete hack

* apply Sylvain's and Lysandre's comments

* small perf improv

* further perf improv

* improv perf

* fix typo

* make style

* small perf improv
Patrick von Platen authored on 2020-12-09 20:55:24 +01:00; committed by GitHub
parent 75627148ee
commit 06971ac4f9
13 changed files with 1079 additions and 823 deletions


@@ -105,8 +105,6 @@ BartModel
.. autoclass:: transformers.BartModel
:members: forward
.. autofunction:: transformers.models.bart.modeling_bart._prepare_bart_decoder_inputs
BartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
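
The autodoc entry for the private helper _prepare_bart_decoder_inputs is dropped above; after the refactor, callers build decoder inputs themselves, e.g. with shift_tokens_right as the updated tests further down do. A minimal sketch (editor's illustration, not part of the diff), assuming the two-argument signature used in those tests and BART's convention of starting decoding from the eos token:

import torch

from transformers.models.bart.modeling_bart import shift_tokens_right

pad_token_id = 1
labels = torch.tensor([[45, 23, 89, 2, pad_token_id]])  # target ids ending in eos (=2), then padding
decoder_input_ids = shift_tokens_right(labels, pad_token_id)
print(decoder_input_ids)  # should rotate the final non-pad token to position 0: tensor([[ 2, 45, 23, 89,  2]])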


@@ -378,6 +378,7 @@ if is_torch_available():
BartForQuestionAnswering,
BartForSequenceClassification,
BartModel,
BartPretrainedModel,
PretrainedBartModel,
)
from .models.bert import (


@@ -31,6 +31,7 @@ if is_torch_available():
BartForQuestionAnswering,
BartForSequenceClassification,
BartModel,
BartPretrainedModel,
PretrainedBartModel,
)

File diff suppressed because it is too large.
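
The suppressed diff is presumably the main rewrite of modeling_bart.py itself. A hedged sketch of the resulting public surface, inferred only from the other hunks in this commit (the grouping below is the editor's summary, not code from the diff):

# Names taken from the surrounding hunks; the module layout is an inference.
from transformers import BartModel, BartPretrainedModel, PretrainedBartModel  # new base-class name plus the legacy alias
from transformers.models.bart.modeling_bart import (
    BartDecoder,  # now a standalone, save/load-able sub-model (see the new standalone test)
    BartEncoder,  # likewise standalone
    BartSinusoidalPositionalEmbedding,  # renamed from SinusoidalPositionalEmbedding
    shift_tokens_right,  # replaces the removed _prepare_bart_decoder_inputs for callers
)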


@@ -577,9 +577,9 @@ class TFBartDecoder(tf.keras.layers.Layer):
encoder_padding_mask = invert_mask(encoder_padding_mask)
# embed positions
positions = self.embed_positions(input_ids, use_cache=use_cache)
positions = self.embed_positions(input_ids, use_cache=(use_cache and decoder_cached_states is not None))
if use_cache:
if use_cache and decoder_cached_states is not None:
input_ids = input_ids[:, -1:]
positions = positions[:, -1:]
@@ -964,7 +964,7 @@ class TFBartModel(TFPretrainedBartModel):
else self.config.output_hidden_states
)
if not inputs["use_cache"]:
if not use_cache or past_key_values is None:
inputs["decoder_input_ids"], decoder_padding_mask, causal_mask = self._prepare_bart_decoder_inputs(
inputs["input_ids"],
decoder_input_ids=inputs["decoder_input_ids"],
@@ -1154,6 +1154,7 @@ class TFBartForConditionalGeneration(TFPretrainedBartModel):
assert (
decoder_cached_states
), f"decoder cached states must be truthy. got {decoder_cached_states} from the 2nd element of past"
assert isinstance(
encoder_outputs, TFBaseModelOutput
), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."


@@ -813,9 +813,6 @@ class T5Stack(T5PreTrainedModel):
def get_input_embeddings(self):
return self.embed_tokens
def get_output_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, new_embeddings):
self.embed_tokens = new_embeddings


@@ -450,6 +450,15 @@ class BartModel:
requires_pytorch(self)
class BartPretrainedModel:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
class PretrainedBartModel:
def __init__(self, *args, **kwargs):
requires_pytorch(self)


@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import tempfile
import unittest
@@ -49,37 +50,68 @@ if is_torch_available():
pipeline,
)
from transformers.models.bart.modeling_bart import (
SinusoidalPositionalEmbedding,
_prepare_bart_decoder_inputs,
invert_mask,
BartDecoder,
BartEncoder,
BartSinusoidalPositionalEmbedding,
shift_tokens_right,
)
PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
def prepare_bart_inputs_dict(
config,
input_ids,
attention_mask=None,
):
if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
@require_torch
class ModelTester:
class BartModelTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_labels=False,
vocab_size=99,
hidden_size=16,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=4,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=20,
eos_token_id=2,
pad_token_id=1,
bos_token_id=0,
):
self.parent = parent
self.batch_size = 13
self.seq_length = 7
self.is_training = True
self.use_labels = False
self.vocab_size = 99
self.hidden_size = 16
self.num_hidden_layers = 2
self.num_attention_heads = 4
self.intermediate_size = 4
self.hidden_act = "gelu"
self.hidden_dropout_prob = 0.1
self.attention_probs_dropout_prob = 0.1
self.max_position_embeddings = 20
self.eos_token_id = 2
self.pad_token_id = 1
self.bos_token_id = 0
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
torch.manual_seed(0)
def prepare_config_and_inputs(self):
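
Parameterizing the tester instead of hard-coding its sizes lets other test suites reuse it, as another test file in this commit does via `from .test_modeling_bart import BartModelTester` (see the hunk further down). A hedged sketch of such a reuse; the derived test class is hypothetical:

import unittest

from .test_modeling_bart import BartModelTester  # relative import, matching the library's test layout

class SomeDerivedModelTest(unittest.TestCase):
    def setUp(self):
        # override only the defaults that differ, rather than editing the tester itself
        self.model_tester = BartModelTester(self, batch_size=2, seq_length=10, vocab_size=150)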
@@ -111,21 +143,67 @@ class ModelTester:
config, inputs_dict = self.prepare_config_and_inputs()
inputs_dict["decoder_input_ids"] = inputs_dict["input_ids"]
inputs_dict["decoder_attention_mask"] = inputs_dict["attention_mask"]
inputs_dict["use_cache"] = False
return config, inputs_dict
def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
model = BartModel(config=config).get_decoder().to(torch_device).eval()
input_ids = inputs_dict["input_ids"]
def prepare_bart_inputs_dict(
config,
input_ids,
attention_mask=None,
):
if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
# first forward pass
outputs = model(input_ids, use_cache=True)
output, past_key_values = outputs.to_tuple()
# create hypothetical multiple next tokens and extend to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
# append to next input_ids
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
output_from_no_past = model(next_input_ids)["last_hidden_state"]
output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2))
def check_encoder_decoder_model_standalone(self, config, inputs_dict):
model = BartModel(config=config).to(torch_device).eval()
outputs = model(**inputs_dict)
encoder_last_hidden_state = outputs.encoder_last_hidden_state
last_hidden_state = outputs.last_hidden_state
with tempfile.TemporaryDirectory() as tmpdirname:
encoder = model.get_encoder()
encoder.save_pretrained(tmpdirname)
encoder = BartEncoder.from_pretrained(tmpdirname).to(torch_device)
encoder_last_hidden_state_2 = encoder(inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"])[
0
]
self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3)
with tempfile.TemporaryDirectory() as tmpdirname:
decoder = model.get_decoder()
decoder.save_pretrained(tmpdirname)
decoder = BartDecoder.from_pretrained(tmpdirname).to(torch_device)
last_hidden_state_2 = decoder(
input_ids=inputs_dict["decoder_input_ids"],
attention_mask=inputs_dict["decoder_attention_mask"],
encoder_hidden_states=encoder_last_hidden_state,
encoder_attention_mask=inputs_dict["attention_mask"],
)[0]
self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 1e-3)
@require_torch
@@ -142,7 +220,7 @@ class BARTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
test_missing_keys = False
def setUp(self):
self.model_tester = ModelTester(self)
self.model_tester = BartModelTester(self)
self.config_tester = ConfigTester(self, config_class=BartConfig)
def test_config(self):
@@ -169,23 +247,25 @@ class BARTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
config.use_cache = False
inputs_dict["input_ids"][:, -2:] = config.pad_token_id
decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_bart_decoder_inputs(
config, inputs_dict["input_ids"]
)
model = BartModel(config).to(torch_device).eval()
model = BartModel(config).to(torch_device).eval()
decoder_features_with_created_mask = model(**inputs_dict)[0]
decoder_input_ids = shift_tokens_right(inputs_dict["input_ids"], config.pad_token_id)
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
decoder_attention_mask[:, 0] = decoder_attention_mask[:, 1]
decoder_features_with_passed_mask = model(
decoder_attention_mask=invert_mask(decoder_attn_mask), decoder_input_ids=decoder_input_ids, **inputs_dict
decoder_attention_mask=decoder_attention_mask, decoder_input_ids=decoder_input_ids, **inputs_dict
)[0]
assert_tensors_close(decoder_features_with_passed_mask, decoder_features_with_created_mask)
useless_mask = torch.zeros_like(decoder_attn_mask)
useless_mask = torch.zeros_like(decoder_attention_mask)
decoder_features = model(decoder_attention_mask=useless_mask, **inputs_dict)[0]
self.assertTrue(isinstance(decoder_features, torch.Tensor)) # no hidden states or attentions
self.assertEqual(
decoder_features.size(), (self.model_tester.batch_size, self.model_tester.seq_length, config.d_model)
)
if decoder_attn_mask.min().item() < -1e3: # some tokens were masked
if decoder_attention_mask.min().item() == 0: # some tokens were masked
self.assertFalse((decoder_features_with_created_mask == decoder_features).all().item())
# Test different encoder attention masks
@@ -204,13 +284,43 @@ class BARTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], [])
@unittest.skip("Passing inputs_embeds not implemented for Bart.")
def test_inputs_embeds(self):
pass
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
@unittest.skip("TODO: Decoder embeddings cannot be resized at the moment")
def test_resize_embeddings_untied(self):
pass
def test_encoder_decoder_model_standalone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)
# BartForSequenceClassification does not support inputs_embeds
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in (BartModel, BartForConditionalGeneration, BartForQuestionAnswering):
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
if not self.is_encoder_decoder:
input_ids = inputs["input_ids"]
del inputs["input_ids"]
else:
encoder_input_ids = inputs["input_ids"]
decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
del inputs["input_ids"]
inputs.pop("decoder_input_ids", None)
wte = model.get_input_embeddings()
if not self.is_encoder_decoder:
inputs["inputs_embeds"] = wte(input_ids)
else:
inputs["inputs_embeds"] = wte(encoder_input_ids)
inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)
with torch.no_grad():
model(**inputs)[0]
@require_sentencepiece
@require_tokenizers
@@ -386,20 +496,6 @@ class BartHeadTests(unittest.TestCase):
model = BartForConditionalGeneration(config).eval().to(torch_device)
model(**model.dummy_inputs)
def test_prepare_bart_decoder_inputs(self):
config, *_ = self._get_config_and_data()
input_ids = _long_tensor(([4, 4, 2]))
decoder_input_ids = _long_tensor([[26388, 2, config.pad_token_id]])
ignore = float("-inf")
decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_bart_decoder_inputs(
config, input_ids, decoder_input_ids
)
expected_causal_mask = torch.tensor(
[[0, ignore, ignore], [0, 0, ignore], [0, 0, 0]] # never attend to the final token, because its pad
).to(input_ids.device)
self.assertEqual(decoder_attn_mask.size(), decoder_input_ids.size())
self.assertTrue(torch.eq(expected_causal_mask, causal_mask).all())
def test_resize_tokens_embeddings_more(self):
config, input_ids, _ = self._get_config_and_data()
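
The deleted test above exercised the causal mask that _prepare_bart_decoder_inputs built explicitly; after the refactor the model constructs it internally. A minimal standalone sketch of the same mask, for reference only (the helper name is made up):

import torch

def make_causal_mask(tgt_len, dtype=torch.float32):
    mask = torch.full((tgt_len, tgt_len), float("-inf"), dtype=dtype)
    return torch.triu(mask, diagonal=1)  # 0 on and below the diagonal, -inf strictly above

print(make_causal_mask(3))
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])  -> matches the expected_causal_mask in the removed test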
@@ -470,14 +566,14 @@ class BartModelIntegrationTests(unittest.TestCase):
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE))
@slow
def test_bart_base_mask_filling(self):
def test_base_mask_filling(self):
pbase = pipeline(task="fill-mask", model="facebook/bart-base")
src_text = [" I went to the <mask>."]
results = [x["token_str"] for x in pbase(src_text)]
assert "Ġbathroom" in results
@slow
def test_bart_large_mask_filling(self):
def test_large_mask_filling(self):
plarge = pipeline(task="fill-mask", model="facebook/bart-large")
src_text = [" I went to the <mask>."]
results = [x["token_str"] for x in plarge(src_text)]
@@ -608,7 +704,7 @@ class BartModelIntegrationTests(unittest.TestCase):
@require_torch
class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
class TestBartSinusoidalPositionalEmbeddings(unittest.TestCase):
desired_weights = [
[0, 0, 0, 0, 0],
[0.84147096, 0.82177866, 0.80180490, 0.78165019, 0.76140374],
@@ -616,38 +712,30 @@ class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
]
def test_positional_emb_cache_logic(self):
pad = 1
input_ids = torch.tensor([[4, 10]], dtype=torch.long, device=torch_device)
emb1 = SinusoidalPositionalEmbedding(num_positions=32, embedding_dim=6, padding_idx=pad).to(torch_device)
no_cache = emb1(input_ids, use_cache=False)
yes_cache = emb1(input_ids, use_cache=True)
self.assertEqual((1, 1, 6), yes_cache.shape) # extra dim to allow broadcasting, feel free to delete!
self.assertListEqual(no_cache[-1].tolist(), yes_cache[0][0].tolist())
emb1 = BartSinusoidalPositionalEmbedding(num_positions=32, embedding_dim=6, padding_idx=1).to(torch_device)
no_cache = emb1((4, 10), past_key_values_length=0)
yes_cache = emb1((4, 10), past_key_values_length=2)
self.assertTrue(no_cache.shape == yes_cache.shape == (10, 6))
self.assertListEqual(no_cache[2:].tolist(), yes_cache[:-2].tolist())
def test_odd_embed_dim(self):
# odd embedding_dim is allowed
SinusoidalPositionalEmbedding(num_positions=4, embedding_dim=5, padding_idx=0).to(torch_device)
BartSinusoidalPositionalEmbedding(num_positions=4, embedding_dim=5, padding_idx=0).to(torch_device)
# odd num_positions is allowed
SinusoidalPositionalEmbedding(num_positions=5, embedding_dim=4, padding_idx=0).to(torch_device)
BartSinusoidalPositionalEmbedding(num_positions=5, embedding_dim=4, padding_idx=0).to(torch_device)
def test_positional_emb_weights_against_marian(self):
pad = 1
emb1 = SinusoidalPositionalEmbedding(num_positions=512, embedding_dim=512, padding_idx=pad).to(torch_device)
emb1 = BartSinusoidalPositionalEmbedding(num_positions=512, embedding_dim=512, padding_idx=pad).to(
torch_device
)
weights = emb1.weight.data[:3, :5].tolist()
for i, (expected_weight, actual_weight) in enumerate(zip(self.desired_weights, weights)):
for j in range(5):
self.assertAlmostEqual(expected_weight[j], actual_weight[j], places=3)
# test that forward pass is just a lookup, there is no ignore padding logic
input_ids = torch.tensor([[4, 10, pad, pad, pad]], dtype=torch.long, device=torch_device)
no_cache_pad_zero = emb1(input_ids)
self.assertTrue(
torch.allclose(
torch.tensor(self.desired_weights, device=torch_device), no_cache_pad_zero[:3, :5], atol=1e-3
)
)
def test_child_config_equivalence(self):
"""Test that configs associated with children of BartForConditionalGeneration are identical."""
child_classes = [BlenderbotConfig, MBartConfig, MarianConfig, PegasusConfig]
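
A hedged usage sketch of the renamed positional embedding's new interface as exercised above: positions are derived from an input shape plus a past-length offset rather than from input_ids and a use_cache flag. The single-token step below is an extrapolation from the cache-logic test, not taken from the diff:

import torch

from transformers.models.bart.modeling_bart import BartSinusoidalPositionalEmbedding

emb = BartSinusoidalPositionalEmbedding(num_positions=32, embedding_dim=6, padding_idx=1)
full = emb((4, 10), past_key_values_length=0)  # positions 0..9 for a (batch=4, seq_len=10) input, shape (10, 6)
step = emb((4, 1), past_key_values_length=9)   # presumably just position 9, as during cached decoding
print(torch.allclose(full[-1], step[0]))       # expected True if the offset logic matches the test above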


@@ -104,9 +104,6 @@ class BlenderbotTesterMixin(ModelTesterMixin, unittest.TestCase):
self.model_tester = BlenderbotModelTester(self)
self.config_tester = ConfigTester(self, config_class=BlenderbotConfig)
def test_inputs_embeds(self):
pass
def test_initialization_module(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = BlenderbotForConditionalGeneration(config).model


@@ -302,6 +302,8 @@ class ModelTesterMixin:
# Question Answering model returns start_logits and end_logits
if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
if "past_key_values" in outputs:
correct_outlen += 1 # past_key_values have been returned
self.assertEqual(out_len, correct_outlen)
@@ -386,7 +388,7 @@ class ModelTesterMixin:
try:
if model.config.is_encoder_decoder:
model.config.use_cache = False # TODO: this should be deleted after bug #7474 is solved
model.config.use_cache = False # FSMT still requires this hack -> FSMT should probably be refactored similarly to BART afterward
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
decoder_input_ids = inputs["decoder_input_ids"]
@@ -1020,7 +1022,6 @@ class ModelTesterMixin:
)
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:


@@ -37,7 +37,7 @@ from transformers.testing_utils import (
torch_device,
)
from .test_modeling_bart import ModelTester as BartModelTester
from .test_modeling_bart import BartModelTester
from .test_modeling_dpr import DPRModelTester
from .test_modeling_t5 import T5ModelTester


@@ -344,13 +344,6 @@ class TFModelTesterMixin:
tf_hidden_states[pt_nans] = 0
max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
# Debug info (remove when fixed)
if max_diff >= 4e-2:
print("===")
print(model_class)
print(config)
print(inputs_dict)
print(pt_inputs_dict)
self.assertLessEqual(max_diff, 4e-2)
# Check we can load pt model in tf and vice-versa with checkpoint => model functions


@@ -28,6 +28,8 @@ PATH_TO_DOC = "docs/source/model_doc"
# Update this list for models that are not tested with a comment explaining the reason it should not be.
# Being in this list is an exception and should **not** be the rule.
IGNORE_NON_TESTED = [
"BartDecoder", # Building part of bigger (tested) model.
"BartEncoder", # Building part of bigger (tested) model.
"BertLMHeadModel", # Needs to be setup as decoder.
"DPREncoder", # Building part of bigger (tested) model.
"DPRSpanPredictor", # Building part of bigger (tested) model.
@@ -58,9 +60,11 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
# Update this list for models that are not documented with a comment explaining the reason it should not be.
# Being in this list is an exception and should **not** be the rule.
IGNORE_NON_DOCUMENTED = [
"BartDecoder", # Building part of bigger (documented) model.
"BartEncoder", # Building part of bigger (documented) model.
"DPREncoder", # Building part of bigger (documented) model.
"DPRSpanPredictor", # Building part of bigger (documented) model.
"T5Stack", # Building part of bigger (tested) model.
"T5Stack", # Building part of bigger (documented) model.
"TFDPREncoder", # Building part of bigger (documented) model.
"TFDPRSpanPredictor", # Building part of bigger (documented) model.
]
@@ -78,6 +82,8 @@ MODEL_NAME_TO_DOC_FILE = {
# Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and
# should **not** be the rule.
IGNORE_NON_AUTO_CONFIGURED = [
"BartDecoder",
"BartEncoder",
"DPRContextEncoder",
"DPREncoder",
"DPRReader",