mirror of https://github.com/huggingface/transformers.git
modeling_bart: 3 small cleanups that dont change outputs (#7381)
* Mbart passing
* boom boom
* cleaner assert
* add assert
* Fix tests
parent 9e68d075a4
commit 3c6bf8998f
@@ -307,7 +307,7 @@ class BartEncoder(nn.Module):
         self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
         self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
         # mbart has one extra layer_norm
-        self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None
+        self.layer_norm = LayerNorm(config.d_model) if config.add_final_layer_norm else None

     def forward(
         self, input_ids, attention_mask=None, output_attentions=False, output_hidden_states=False, return_dict=False
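This rename is the substance of the first cleanup: the extra mBART LayerNorm after the stack is now gated by a dedicated add_final_layer_norm flag instead of the pre-norm setting, so the two behaviours stay independent. A minimal sketch of the pattern, using a hypothetical SimpleConfig stand-in rather than the library's real config class:

import torch.nn as nn

class SimpleConfig:
    # hypothetical stand-in for the model config; flag names follow the diff above
    d_model = 1024
    normalize_before = True        # pre-norm inside each transformer layer
    add_final_layer_norm = True    # extra LayerNorm after the whole stack (mBART)

config = SimpleConfig()
# the extra norm is now keyed off add_final_layer_norm instead of normalize_before
layer_norm = nn.LayerNorm(config.d_model) if config.add_final_layer_norm else None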
@@ -550,6 +550,10 @@ class BartDecoder(nn.Module):
         positions = self.embed_positions(input_ids, use_cache=use_cache)

         if use_cache:
+            if input_ids.shape[1] != 1 or past_key_values is None:
+                # if you make this an AssertionError, test_benchmark breaks.
+                warnings.warn("pass decoder_past_key_value_states to use_cache")
+
             input_ids = input_ids[:, -1:]
             positions = positions[:, -1:]  # happens after we embed them
         # assert input_ids.ne(self.padding_idx).any()
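The slicing under use_cache is what makes incremental decoding cheap: once a past cache exists, only the newest token has to pass through the decoder. A toy illustration with made-up token ids and a plain arange standing in for the real learned position embeddings:

import torch

input_ids = torch.tensor([[0, 31414, 232, 2]])       # full sequence generated so far (toy ids)
positions = torch.arange(input_ids.shape[1])[None]   # stand-in for embed_positions(...)

use_cache = True
if use_cache:
    input_ids = input_ids[:, -1:]    # keep only the last token
    positions = positions[:, -1:]    # sliced the same way, after "embedding", as in the diff

print(input_ids.shape, positions.shape)  # torch.Size([1, 1]) torch.Size([1, 1])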
@@ -590,11 +594,12 @@ class BartDecoder(nn.Module):
             if use_cache:
                 next_decoder_cache.append(layer_past.copy())

-            if self.layer_norm and (idx == len(self.layers) - 1):  # if config.add_final_layer_norm (mBART)
-                x = self.layer_norm(x)
             if output_attentions:
                 all_self_attns += (layer_self_attn,)

+        if self.layer_norm:  # if config.add_final_layer_norm (mBART)
+            x = self.layer_norm(x)
+
         # Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
         if output_hidden_states:
             all_hidden_states = tuple(hidden_state.transpose(0, 1) for hidden_state in all_hidden_states)
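Hoisting the final layer_norm out of the per-layer loop is output-preserving: applying it only on the last iteration is the same as applying it once after the loop ends. A self-contained check of that equivalence, with stand-in Linear layers instead of real decoder layers:

import torch
import torch.nn as nn

torch.manual_seed(0)
layers = [nn.Linear(8, 8) for _ in range(3)]   # stand-ins for DecoderLayer modules
layer_norm = nn.LayerNorm(8)
x = torch.randn(2, 8)

# old style: norm applied inside the loop, but only when idx hits the last layer
y_old = x
for idx, layer in enumerate(layers):
    y_old = layer(y_old)
    if idx == len(layers) - 1:
        y_old = layer_norm(y_old)

# new style: norm applied once, after the loop
y_new = x
for layer in layers:
    y_new = layer(y_new)
y_new = layer_norm(y_new)

assert torch.allclose(y_old, y_new)  # identical outputs, as the commit title promises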
@@ -86,8 +86,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
         batch: BatchEncoding = self.tokenizer.prepare_seq2seq_batch(self.src_text).to(torch_device)
         translated_tokens = self.model.generate(**batch)
         decoded = self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
-        self.assertEqual(self.tgt_text[0], decoded[0])
-        self.assertEqual(self.tgt_text[1], decoded[1])
+        assert self.tgt_text == decoded

     def test_mbart_enro_config(self):
         mbart_models = ["facebook/mbart-large-en-ro"]
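Collapsing the two index-wise assertEqual calls into a single list comparison also pins the number of decoded outputs, not just the first two entries. A small illustration with placeholder strings in place of the real Romanian reference translations:

tgt_text = ["reference translation 0", "reference translation 1"]
decoded = ["reference translation 0", "reference translation 1", "unexpected extra output"]

# the old element-wise checks pass even though an extra item slipped in
assert tgt_text[0] == decoded[0] and tgt_text[1] == decoded[1]
# the whole-list comparison catches it
assert tgt_text != decoded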