Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 18:51:14 +06:00)
[BART] Remove unused kwargs (#3279)
* Remove unused kwargs
* don't call forward in tests
This commit is contained in:
parent 3814e167d9
commit 5ea8ba67b4
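
The test-side change swaps `model.forward(...)` for `model(...)`. A minimal sketch of why calling the module is preferable (the `TinyModel` below is hypothetical, not code from this repo): `nn.Module.__call__` dispatches any registered hooks around `forward`, while calling `.forward()` directly bypasses them.

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical module used only to illustrate __call__ vs .forward()."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)


model = TinyModel()
# Hooks are dispatched by nn.Module.__call__, not by forward() itself.
model.register_forward_pre_hook(lambda module, inputs: print("hook fired"))

x = torch.randn(1, 4)
model(x)          # prints "hook fired", then runs forward
model.forward(x)  # silent: bypasses __call__, so no hooks run
```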
@@ -844,7 +844,7 @@ class Translator(object):
             dec_out, dec_states = self.model.decoder(decoder_input, src_features, dec_states, step=step)

             # Generator forward.
-            log_probs = self.generator.forward(dec_out.transpose(0, 1).squeeze(0))
+            log_probs = self.generator(dec_out.transpose(0, 1).squeeze(0))
             vocab_size = log_probs.size(-1)

             if step < min_length:
@@ -223,9 +223,7 @@ class EncoderLayer(nn.Module):
             encoded output of shape `(seq_len, batch, embed_dim)`
         """
         residual = x
-        x, attn_weights = self.self_attn(
-            query=x, key=x, value=x, key_padding_mask=encoder_padding_mask, need_weights=self.output_attentions,
-        )
+        x, attn_weights = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask,)
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
         x = self.self_attn_layer_norm(x)
@@ -378,7 +376,7 @@ class DecoderLayer(nn.Module):
             layer_state = {}
         # next line mutates layer state
         x, self_attn_weights = self.self_attn(
-            query=x, key=y, value=y, layer_state=layer_state, need_weights=need_attn_weights, attn_mask=attention_mask,
+            query=x, key=y, value=y, layer_state=layer_state, attn_mask=attention_mask,
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
@@ -393,7 +391,6 @@ class DecoderLayer(nn.Module):
             key_padding_mask=encoder_attn_mask,
             layer_state=layer_state,  # mutates layer state
             static_kv=True,
-            need_weights=False,  # not returning it so why compute it
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
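
The hunk above drops `need_weights=False,  # not returning it so why compute it` because the kwarg is unused once it is removed from `SelfAttention.forward` below. For context, a hedged sketch of the pattern the flag comes from, shown with `torch.nn.MultiheadAttention` rather than this repo's `SelfAttention`: when the caller never consumes the attention weights, `need_weights=False` lets the layer skip returning them.

```python
import torch
from torch import nn

# Illustration with PyTorch's built-in attention layer, not modeling_bart's SelfAttention.
mha = nn.MultiheadAttention(embed_dim=16, num_heads=4)
x = torch.randn(5, 2, 16)  # (seq_len, batch, embed_dim), matching the layout used above

attn_output, attn_weights = mha(x, x, x, need_weights=False)
assert attn_weights is None           # weights are not returned when not requested
assert attn_output.shape == x.shape   # output keeps the (seq_len, batch, embed_dim) shape
```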
@@ -548,16 +545,12 @@ class SelfAttention(nn.Module):
         self,
         embed_dim,
         num_heads,
-        kdim=None,
-        vdim=None,
         dropout=0.0,
         bias=True,
         encoder_decoder_attention=False,  # otherwise self_attention
     ):
         super().__init__()
         self.embed_dim = embed_dim
-        self.kdim = kdim if kdim is not None else embed_dim
-        self.vdim = vdim if vdim is not None else embed_dim

         self.num_heads = num_heads
         self.dropout = dropout
@@ -566,13 +559,8 @@ class SelfAttention(nn.Module):
         self.scaling = self.head_dim ** -0.5

         self.encoder_decoder_attention = encoder_decoder_attention
-        qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim  # True for all BART
-
-        assert self.encoder_decoder_attention or qkv_same_dim, (
-            "Self-attention requires query, key and " "value to be of the same size"
-        )
-        self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
-        self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
+        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self"
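
With `kdim`/`vdim` gone, every projection in `SelfAttention` maps `embed_dim` to `embed_dim`; per the removed comment, queries, keys and values share the same size for all BART configurations. A minimal sketch of that shape flow (dimensions invented for illustration, not taken from a real config):

```python
import torch
from torch import nn

embed_dim, seq_len, batch = 16, 5, 2

# All three projections are embed_dim -> embed_dim, as in the simplified __init__ above.
q_proj = nn.Linear(embed_dim, embed_dim)
k_proj = nn.Linear(embed_dim, embed_dim)
v_proj = nn.Linear(embed_dim, embed_dim)

x = torch.randn(seq_len, batch, embed_dim)  # Time x Batch x Channel layout
q, k, v = q_proj(x), k_proj(x), v_proj(x)
assert q.shape == k.shape == v.shape == (seq_len, batch, embed_dim)
```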
@@ -587,7 +575,6 @@ class SelfAttention(nn.Module):
         value: Optional[Tensor],
         key_padding_mask: Optional[Tensor] = None,
         layer_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
-        need_weights: bool = False,
         static_kv: bool = False,
         attn_mask: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Optional[Tensor]]:
@@ -598,8 +585,6 @@ class SelfAttention(nn.Module):
             key_padding_mask (ByteTensor, optional): mask to exclude
                 keys that are pads, of shape `(batch, src_len)`, where
                 padding elements are indicated by 1s.
-            need_weights (bool, optional): return the attention weights,
-                averaged over heads (default: False).
             attn_mask (ByteTensor, optional): typically used to
                 implement causal attention, where the mask prevents the
                 attention from looking forward in time (default: None).
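
The docstring above keeps `key_padding_mask` and `attn_mask`; the `need_weights` entry is removed together with the kwarg. A hedged sketch of the two remaining masks (shapes follow the docstring, values are made up):

```python
import torch

batch, src_len = 2, 4

# key_padding_mask: (batch, src_len), 1s mark padding keys that should be ignored.
key_padding_mask = torch.tensor([[0, 0, 1, 1],
                                 [0, 0, 0, 1]]).bool()

# attn_mask: a causal mask so position i cannot attend to positions j > i,
# i.e. attention cannot "look forward in time".
attn_mask = torch.triu(torch.full((src_len, src_len), float("-inf")), diagonal=1)
print(attn_mask)  # -inf strictly above the diagonal, 0 elsewhere
```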
@@ -141,13 +141,13 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
         _check_var(model.encoder.layers[0].fc1)
         _check_var(model.encoder.embed_positions)

-        decoder_features_with_created_mask = model.forward(**inputs_dict)[0]
-        decoder_features_with_passed_mask = model.forward(
+        decoder_features_with_created_mask = model(**inputs_dict)[0]
+        decoder_features_with_passed_mask = model(
             decoder_attention_mask=decoder_attn_mask, decoder_input_ids=decoder_input_ids, **inputs_dict
         )[0]
         _assert_tensors_equal(decoder_features_with_passed_mask, decoder_features_with_created_mask)
         useless_mask = torch.zeros_like(decoder_attn_mask)
-        decoder_features = model.forward(decoder_attention_mask=useless_mask, **inputs_dict)[0]
+        decoder_features = model(decoder_attention_mask=useless_mask, **inputs_dict)[0]
         self.assertTrue(isinstance(decoder_features, torch.Tensor))  # no hidden states or attentions
         self.assertEqual(
             decoder_features.size(), (self.model_tester.batch_size, self.model_tester.seq_length, config.d_model)
@@ -156,7 +156,7 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
         self.assertFalse((decoder_features_with_created_mask == decoder_features).all().item())

         # Test different encoder attention masks
-        decoder_features_with_long_encoder_mask = model.forward(
+        decoder_features_with_long_encoder_mask = model(
             inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"].long()
         )[0]
         _assert_tensors_equal(decoder_features_with_long_encoder_mask, decoder_features_with_created_mask)
@@ -237,7 +237,7 @@ class BartHeadTests(unittest.TestCase):
         decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size).to(torch_device)
         lm_model = BartForConditionalGeneration(config)
         lm_model.to(torch_device)
-        loss, logits, enc_features = lm_model.forward(
+        loss, logits, enc_features = lm_model(
             input_ids=input_ids, lm_labels=decoder_lm_labels, decoder_input_ids=input_ids
         )
         expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
@@ -259,7 +259,7 @@ class BartHeadTests(unittest.TestCase):
         lm_model = BartForConditionalGeneration(config).to(torch_device)
         context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device)
         summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device)
-        loss, logits, enc_features = lm_model.forward(input_ids=context, decoder_input_ids=summary, lm_labels=summary)
+        loss, logits, enc_features = lm_model(input_ids=context, decoder_input_ids=summary, lm_labels=summary)
         expected_shape = (*summary.shape, config.vocab_size)
         self.assertEqual(logits.shape, expected_shape)

@@ -388,7 +388,7 @@ class BartModelIntegrationTest(unittest.TestCase):
         input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
         inputs_dict = prepare_bart_inputs_dict(model.config, input_ids)
         with torch.no_grad():
-            output = model.forward(**inputs_dict)[0]
+            output = model(**inputs_dict)[0]
         expected_shape = torch.Size((1, 11, 1024))
         self.assertEqual(output.shape, expected_shape)
         expected_slice = torch.tensor(
@@ -408,7 +408,7 @@ class BartModelIntegrationTest(unittest.TestCase):
         inputs_dict = prepare_bart_inputs_dict(model.config, input_ids)
         # Test that model hasn't changed
         with torch.no_grad():
-            batched_logits, features = model.forward(**inputs_dict)
+            batched_logits, features = model(**inputs_dict)
         expected_shape = torch.Size((2, 3))
         self.assertEqual(batched_logits.shape, expected_shape)
         expected_slice = torch.Tensor([[0.1907, 1.4342, -1.0289]]).to(torch_device)
@@ -419,7 +419,7 @@ class BartModelIntegrationTest(unittest.TestCase):

         inputs_dict = prepare_bart_inputs_dict(model.config, input_ids=input_ids_no_pad)
         with torch.no_grad():
-            logits2 = model.forward(**inputs_dict)[0]
+            logits2 = model(**inputs_dict)[0]
         _assert_tensors_equal(batched_logits[1], logits2, atol=TOLERANCE)
         _assert_tensors_equal(expected_slice, logits_arr, atol=TOLERANCE)