[BART] cleanup: remove redundant kwargs, improve docstrings (#3319)
Parent: cd21d8bc00
Commit: ad7233fc01
@@ -56,7 +56,7 @@ BART_GENERATION_EXAMPLE = r"""
     ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
     inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
     # Generate Summary
-    summary_ids = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], num_beams=4, max_length=5)
+    summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True)
     print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])

 """
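Note: for readers trying the updated example outside the docstring, a minimal end-to-end sketch could look like the following (assuming a BART summarization checkpoint such as 'facebook/bart-large-cnn'; the checkpoint name is not part of this diff):

    from transformers import BartForConditionalGeneration, BartTokenizer

    # Assumed checkpoint; any BART summarization checkpoint should behave similarly.
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")

    ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
    inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt")

    # As in the new docstring: attention_mask is no longer passed explicitly, and
    # early_stopping=True stops beam search once num_beams finished hypotheses exist.
    summary_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=5, early_stopping=True)
    print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])
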
@@ -84,8 +84,9 @@ LARGE_NEGATIVE = -1e8
 def _prepare_bart_decoder_inputs(
     config, input_ids, decoder_input_ids=None, decoder_attn_mask=None, mask_dtype=None,
 ):
-    """Prepare masks that ignore padding tokens decoder and a causal lm mask for the decoder if
+    """Prepare masks that ignore padding tokens in the decoder and a causal lm mask for the decoder if
     none are provided. This mimics the default behavior in fairseq. To override it pass in masks.
+    Note: this is not called during generation
     """
     pad_token_id = config.pad_token_id
     need_causal_mask = not config.output_past
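As a rough illustration of what this helper provides (a sketch inferred from the signature and docstring only, not the library implementation): it falls back to input_ids for the decoder, builds a padding mask from config.pad_token_id, and adds a causal mask when config.output_past is False:

    import torch

    def sketch_prepare_bart_decoder_inputs(config, input_ids):
        # Hypothetical simplification for illustration; names mirror the diff above.
        decoder_input_ids = input_ids                          # the real helper may shift/modify these
        decoder_padding_mask = decoder_input_ids.eq(config.pad_token_id)
        causal_mask = None
        if not config.output_past:                             # "need_causal_mask" in the diff
            tgt_len = decoder_input_ids.size(1)
            causal_mask = torch.triu(torch.full((tgt_len, tgt_len), -1e8), diagonal=1)
        return decoder_input_ids, decoder_padding_mask, causal_mask
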
@@ -114,8 +115,6 @@ class PretrainedBartModel(PreTrainedModel):

     def _init_weights(self, module):
         std = self.config.init_std
-
-        # called init_bert_params in fairseq
         if isinstance(module, nn.Linear):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.bias is not None:
@@ -127,16 +126,9 @@ class PretrainedBartModel(PreTrainedModel):

     @property
     def dummy_inputs(self):
-        pad_token = 1
-        input_ids = torch.Tensor(
-            [
-                [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2],
-                [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 2, pad_token],
-            ]
-        ).long()
-        decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(
-            self.config, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attn_mask=None
-        )
+        pad_token = self.config.pad_token_id
+        input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]])
+        decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(self.config, input_ids,)
         dummy_inputs = {
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": input_ids.ne(pad_token),
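The reworked dummy_inputs now reads the pad id from the config and returns a plain dict of tensors, so it can be fed straight back into the model; the new test_dummy_inputs added below does exactly that. A minimal usage sketch (config construction omitted):

    model = BartForConditionalGeneration(config).eval()
    batch = model.dummy_inputs          # dict with decoder_input_ids, attention_mask, ...
    outputs = model(**batch)            # what test_dummy_inputs exercises
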
@ -149,7 +141,7 @@ class PretrainedBartModel(PreTrainedModel):
|
|||||||
def _make_linear_from_emb(emb):
|
def _make_linear_from_emb(emb):
|
||||||
vocab_size, emb_size = emb.weight.shape
|
vocab_size, emb_size = emb.weight.shape
|
||||||
lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
|
lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
|
||||||
lin_layer.weight.data = emb.weight.data # .T
|
lin_layer.weight.data = emb.weight.data
|
||||||
return lin_layer
|
return lin_layer
|
||||||
|
|
||||||
|
|
||||||
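For context, _make_linear_from_emb ties a bias-free projection to an embedding matrix by reusing the same weight tensor; the commit only drops the stale "# .T" comment. A standalone sketch with illustrative sizes:

    import torch.nn as nn

    emb = nn.Embedding(50265, 1024)                 # illustrative vocab size / embedding dim
    lin = nn.Linear(*emb.weight.shape, bias=False)  # mirrors nn.Linear(vocab_size, emb_size, bias=False)
    lin.weight.data = emb.weight.data               # shares storage with the embedding weights
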
@@ -160,8 +152,8 @@ def _check_shapes(shape_1, shape2):


 def _combine_masks(key_padding_mask, causal_lm_mask, targ_size):
-    # targ_size = (bsz, tgt_len, src_len)
-    a = torch.zeros(targ_size)
+    """Make one mask of shape (bsz, 1, tgt_len, src_len) """
+    a = torch.zeros(targ_size)  # targ_size is(bsz, tgt_len, src_len)
     b = torch.zeros(targ_size)
     if key_padding_mask is not None:  # (bsz, tgt_len) -> targ_size
         _check_shapes(key_padding_mask.shape, targ_size[:2])
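The new one-line docstring pins down the output shape. As an illustration of the additive-mask convention it refers to (shapes follow the diff; the exact broadcasting in the real helper may differ):

    import torch

    bsz, tgt_len, src_len = 2, 5, 5
    key_padding_mask = torch.zeros(bsz, tgt_len, dtype=torch.bool)             # True where a token is padding
    causal_lm_mask = torch.triu(torch.full((tgt_len, src_len), -1e8), diagonal=1)

    combined = causal_lm_mask.unsqueeze(0).expand(bsz, -1, -1).clone()
    combined = combined.masked_fill(key_padding_mask.unsqueeze(2), -1e8)
    combined = combined.unsqueeze(1)                                           # -> (bsz, 1, tgt_len, src_len)
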
@@ -223,7 +215,7 @@ class EncoderLayer(nn.Module):
             encoded output of shape `(seq_len, batch, embed_dim)`
         """
         residual = x
-        x, attn_weights = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask,)
+        x, attn_weights = self.self_attn(query=x, key=x, key_padding_mask=encoder_padding_mask,)
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
         x = self.self_attn_layer_norm(x)
@@ -266,7 +258,7 @@ class BartEncoder(nn.Module):
         self.layernorm_embedding = LayerNorm(embed_dim)

     def forward(
-        self, input_ids=None, attention_mask=None,
+        self, input_ids, attention_mask=None,
     ):
         """
         Args:
@@ -274,21 +266,19 @@ class BartEncoder(nn.Module):
                 `(batch, src_len)`
             attention_mask (torch.LongTensor): indicating which indices are padding tokens.
         Returns:
-            namedtuple:
+            Tuple comprised of:
                 - **x** (Tensor): the last encoder layer's output of
                   shape `(src_len, batch, embed_dim)`

                 - **encoder_states** (List[Tensor]): all intermediate
                   hidden states of shape `(src_len, batch, embed_dim)`.
-                  Only populated if *return_all_hiddens* is True.
+                  Only populated if *self.output_hidden_states:* is True.
                 - **all_attentions** (List[Tensor]): Attention weights for each layer.
                   During training might not be of length n_layers because of layer dropout.
         """
         # check attention mask and invert
         if attention_mask is not None:
             assert attention_mask.dim() == 2
-            attention_mask = (1.0 - attention_mask.long()) * -10000.0
+            attention_mask = (1.0 - attention_mask.long()) * LARGE_NEGATIVE
             assert attention_mask.max() <= 0
         inputs_embeds = self.embed_tokens(input_ids)
         embed_pos = self.embed_positions(input_ids)
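The only functional change in this hunk is replacing the hard-coded -10000.0 with the module-level LARGE_NEGATIVE constant. The conversion itself turns a 1/0 keep-mask into an additive mask, for example:

    import torch

    LARGE_NEGATIVE = -1e8                              # module-level constant, as in the diff
    attention_mask = torch.tensor([[1, 1, 1, 0, 0]])   # 1 = real token, 0 = padding
    additive_mask = (1.0 - attention_mask.long()) * LARGE_NEGATIVE
    # real tokens map to 0, padding to LARGE_NEGATIVE; the result is added to the
    # attention scores before softmax, so padded positions get ~zero probability
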
@@ -300,10 +290,7 @@ class BartEncoder(nn.Module):
         x = x.transpose(0, 1)

         encoder_states, all_attentions = [], []
-
-        # encoder layers
         for encoder_layer in self.layers:
-
             if self.output_hidden_states:
                 encoder_states.append(x)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@@ -320,7 +307,6 @@ class BartEncoder(nn.Module):
             encoder_states.append(x)
-
         encoder_states = [hidden_state.transpose(0, 1) for hidden_state in encoder_states]

         return x, encoder_states, all_attentions


@@ -356,28 +342,12 @@ class DecoderLayer(nn.Module):
         attention_mask=None,
         need_attn_weights=False,
     ):
-        """
-        Args:
-            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
-            encoder_attn_mask (ByteTensor, optional): binary
-                ByteTensor of shape `(batch, src_len)` where padding
-                elements are indicated by ``1``.
-            need_attn_weights (bool, optional): return attention weights
-                for each head (default: return average over heads).
-
-        Returns:
-            encoded output of shape `(seq_len, batch, embed_dim)`
-        """

         residual = x
-        y = x  # TODO(SS): figure out why fairseq did this, then hopefully delete it
-
         if layer_state is None:
             layer_state = {}
         # next line mutates layer state
-        x, self_attn_weights = self.self_attn(
-            query=x, key=y, value=y, layer_state=layer_state, attn_mask=attention_mask,
-        )
+        x, self_attn_weights = self.self_attn(query=x, key=x, layer_state=layer_state, attn_mask=attention_mask,)
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
         x = self.self_attn_layer_norm(x)
@@ -386,11 +356,9 @@ class DecoderLayer(nn.Module):

         x, encoder_attn_weights = self.encoder_attn(
             query=x,
-            key=encoder_hidden_states,  # could be None
-            value=encoder_hidden_states,
+            key=encoder_hidden_states,
             key_padding_mask=encoder_attn_mask,
             layer_state=layer_state,  # mutates layer state
-            static_kv=True,
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
@@ -527,19 +495,15 @@ class BartDecoder(nn.Module):
         return x, next_cache, all_hidden_states, list(all_self_attns)


-def reorder_attn_buffer(input_buffer, new_order):
-    """Reorder buffered internal state (for incremental generation)."""
-    # input_buffer = self._get_input_buffer(incremental_state)
-    for k in input_buffer.keys():
-        input_buffer_k = input_buffer[k]
+def _reorder_buffer(attn_cache, new_order):
+    for k, input_buffer_k in attn_cache.items():
         if input_buffer_k is not None:
-            input_buffer[k] = input_buffer_k.index_select(0, new_order)
-    # incremental_state = self._set_input_buffer(incremental_state, input_buffer)
-    return input_buffer
+            attn_cache[k] = input_buffer_k.index_select(0, new_order)
+    return attn_cache


 class SelfAttention(nn.Module):
-    """Multi-headed attention from "Attention Is All You Need"""
+    """Multi-headed attention from 'Attention Is All You Need' paper"""

     def __init__(
         self,
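The renamed _reorder_buffer iterates over the cache dict directly instead of indexing by key. A toy usage sketch (tensor shapes are hypothetical; only the first dimension, which indexes beams, matters here):

    import torch

    attn_cache = {
        "prev_key": torch.randn(4, 2, 1, 8),      # (bsz * num_beams, heads, seq_len, head_dim), illustrative
        "prev_value": torch.randn(4, 2, 1, 8),
        "prev_key_padding_mask": None,            # None entries are left untouched
    }
    beam_idx = torch.tensor([2, 2, 0, 1])         # which beam each slot should be copied from
    attn_cache = _reorder_buffer(attn_cache, beam_idx)
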
@@ -551,7 +515,6 @@ class SelfAttention(nn.Module):
     ):
         super().__init__()
         self.embed_dim = embed_dim
-
         self.num_heads = num_heads
         self.dropout = dropout
         self.head_dim = embed_dim // num_heads
@@ -572,42 +535,29 @@ class SelfAttention(nn.Module):
         self,
         query,
         key: Optional[Tensor],
-        value: Optional[Tensor],
         key_padding_mask: Optional[Tensor] = None,
-        layer_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
-        static_kv: bool = False,
+        layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
         attn_mask: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Optional[Tensor]]:
-        """Input shape: Time(SeqLen) x Batch x Channel
-
-        Args:
-
-            key_padding_mask (ByteTensor, optional): mask to exclude
-                keys that are pads, of shape `(batch, src_len)`, where
-                padding elements are indicated by 1s.
-            attn_mask (ByteTensor, optional): typically used to
-                implement causal attention, where the mask prevents the
-                attention from looking forward in time (default: None).
-        """
+        """Input shape: Time(SeqLen) x Batch x Channel"""
+        static_kv = self.encoder_decoder_attention  # type: bool
         tgt_len, bsz, embed_dim = query.size()
         assert embed_dim == self.embed_dim
         assert list(query.size()) == [tgt_len, bsz, embed_dim]
         # get here for encoder decoder cause of static_kv
-        if layer_state is not None:  # get the last k,v and mask for reuse
+        if layer_state is not None:  # reuse k,v and encoder_padding_mask
             saved_state = layer_state.get(self.cache_key, {})
             if "prev_key" in saved_state:
                 # previous time steps are cached - no need to recompute key and value if they are static
                 if static_kv:
-                    assert self.encoder_decoder_attention
-                    key = value = None
+                    key = None
         else:
             saved_state = None
             layer_state = {}

         q = self.q_proj(query) * self.scaling
-        if self.encoder_decoder_attention:
+        if static_kv:
             if key is None:
-                assert value is None
                 k = v = None
             else:
                 k = self.k_proj(key)
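With static_kv now derived from self.encoder_decoder_attention, cross-attention computes its keys/values from the encoder output once and then reuses them from the cache on every decoding step, while self-attention keeps extending its cache. A schematic of the per-layer cache (the "self" key name and the layout are assumptions; "encoder_decoder" appears in the comment removed in the next hunk):

    # Schematic layer_state after a decoding step (illustrative only):
    layer_state = {
        "self": {"prev_key": "...grows by one position per step..."},
        "encoder_decoder": {"prev_key": "...computed once from encoder output, then static..."},
    }
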
@@ -624,7 +574,6 @@ class SelfAttention(nn.Module):

         if saved_state is not None:
             k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz)
-            # assert self.cache_key != 'encoder_decoder' or key_padding_mask is None

         # Update cache
         layer_state[self.cache_key] = {
@@ -636,7 +585,6 @@ class SelfAttention(nn.Module):
         assert k is not None
         src_len = k.size(1)
         attn_weights = torch.bmm(q, k.transpose(1, 2))
-
         assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len)

         if attn_mask is not None:
@@ -984,7 +932,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
         for layer_past in decoder_cached_states:
             # get the correct batch idx from decoder layer's batch dim for cross and self-attn
             layer_past_new = {
-                attn_key: reorder_attn_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
+                attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
             }
             # reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
             # reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
@@ -330,6 +330,17 @@ class BartHeadTests(unittest.TestCase):
         lm_model = BartForConditionalGeneration(config).eval().to(torch_device).half()
         lm_model(input_ids, attention_mask=attention_mask)

+    def test_default_generate_kwargs(self):
+        config, input_ids, _ = self._get_config_and_data(output_past=True)
+        model = BartForConditionalGeneration(config).eval().to(torch_device)
+        model.generate(input_ids)
+        model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
+
+    def test_dummy_inputs(self):
+        config, *_ = self._get_config_and_data(output_past=True)
+        model = BartForConditionalGeneration(config).eval().to(torch_device)
+        model(**model.dummy_inputs)
+
     def test_prepare_bart_decoder_inputs(self):
         config, *_ = self._get_config_and_data(output_past=False)
         input_ids = _long_tensor(([4, 4, 2]))  # only used for .device if decoder_input_ids is passed