modeling_bart: 3 small cleanups that dont change outputs (#7381)

* Mbart passing

* boom boom

* cleaner assert

* add assert

* Fix tests
This commit is contained in:
Sam Shleifer 2020-09-25 04:24:14 -04:00 committed by GitHub
parent 9e68d075a4
commit 3c6bf8998f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 5 deletions

View File

@ -307,7 +307,7 @@ class BartEncoder(nn.Module):
self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
# mbart has one extra layer_norm
self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None
self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None
def forward(
self, input_ids, attention_mask=None, output_attentions=False, output_hidden_states=False, return_dict=False
@ -550,6 +550,10 @@ class BartDecoder(nn.Module):
positions = self.embed_positions(input_ids, use_cache=use_cache)
if use_cache:
if input_ids.shape[1] != 1 or past_key_values is None:
# if you make this an AssertionError, test_benchmark breaks.
warnings.warn("pass decoder_past_key_value_states to use_cache")
input_ids = input_ids[:, -1:]
positions = positions[:, -1:] # happens after we embed them
# assert input_ids.ne(self.padding_idx).any()
@ -590,11 +594,12 @@ class BartDecoder(nn.Module):
if use_cache:
next_decoder_cache.append(layer_past.copy())
if self.layer_norm and (idx == len(self.layers) - 1): # if config.add_final_layer_norm (mBART)
x = self.layer_norm(x)
if output_attentions:
all_self_attns += (layer_self_attn,)
if self.layer_norm: # if config.add_final_layer_norm (mBART)
x = self.layer_norm(x)
# Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
if output_hidden_states:
all_hidden_states = tuple(hidden_state.transpose(0, 1) for hidden_state in all_hidden_states)

View File

@ -86,8 +86,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(self.src_text).to(torch_device)
translated_tokens = self.model.generate(**batch)
decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
self.assertEqual(self.tgt_text[0], decoded[0])
self.assertEqual(self.tgt_text[1], decoded[1])
assert self.tgt_text == decoded
def test_mbart_enro_config(self):
mbart_models = ["facebook/mbart-large-en-ro"]