[BART] cleanup: remove redundant kwargs, improve docstrings (#3319)

Sam Shleifer 2020-03-19 11:16:51 -04:00 committed by GitHub
parent cd21d8bc00
commit ad7233fc01
2 changed files with 39 additions and 80 deletions


@@ -56,7 +56,7 @@ BART_GENERATION_EXAMPLE = r"""
ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
# Generate Summary
- summary_ids = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], num_beams=4, max_length=5)
+ summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True)
print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])
"""
@@ -84,8 +84,9 @@ LARGE_NEGATIVE = -1e8
def _prepare_bart_decoder_inputs(
config, input_ids, decoder_input_ids=None, decoder_attn_mask=None, mask_dtype=None,
):
- """Prepare masks that ignore padding tokens decoder and a causal lm mask for the decoder if
+ """Prepare masks that ignore padding tokens in the decoder and a causal lm mask for the decoder if
none are provided. This mimics the default behavior in fairseq. To override it pass in masks.
+ Note: this is not called during generation
"""
pad_token_id = config.pad_token_id
need_causal_mask = not config.output_past
@@ -114,8 +115,6 @@ class PretrainedBartModel(PreTrainedModel):
def _init_weights(self, module):
std = self.config.init_std
- # called init_bert_params in fairseq
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
@@ -127,16 +126,9 @@ class PretrainedBartModel(PreTrainedModel):
@property
def dummy_inputs(self):
- pad_token = 1
- input_ids = torch.Tensor(
- [
- [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2],
- [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 2, pad_token],
- ]
- ).long()
- decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(
- self.config, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attn_mask=None
- )
+ pad_token = self.config.pad_token_id
+ input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]])
+ decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(self.config, input_ids,)
dummy_inputs = {
"decoder_input_ids": decoder_input_ids,
"attention_mask": input_ids.ne(pad_token),
@@ -149,7 +141,7 @@
def _make_linear_from_emb(emb):
vocab_size, emb_size = emb.weight.shape
lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
- lin_layer.weight.data = emb.weight.data # .T
+ lin_layer.weight.data = emb.weight.data
return lin_layer
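
A hedged reading of why the removed "# .T" hint is unnecessary (my explanation, not stated in the commit): nn.Linear stores its weight as (out_features, in_features) and F.linear computes x @ weight.T, so once the embedding matrix of shape (vocab_size, emb_size) is assigned as the weight, the layer maps emb_size-dimensional hidden states to vocab_size logits with no explicit transpose:

    # Hedged sketch of the weight tying above (illustrative sizes).
    import torch
    import torch.nn as nn

    emb = nn.Embedding(num_embeddings=50, embedding_dim=16)  # weight: (50, 16)
    lin = nn.Linear(50, 16, bias=False)                      # same constructor args as above
    lin.weight.data = emb.weight.data                        # weight is now (50, 16), storage shared with emb

    hidden = torch.randn(2, 7, 16)                           # (batch, seq, emb_size)
    logits = lin(hidden)                                     # F.linear: hidden @ weight.T -> (2, 7, 50)
    assert logits.shape == (2, 7, 50)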
@@ -160,8 +152,8 @@ def _check_shapes(shape_1, shape2):
def _combine_masks(key_padding_mask, causal_lm_mask, targ_size):
- # targ_size = (bsz, tgt_len, src_len)
- a = torch.zeros(targ_size)
+ """Make one mask of shape (bsz, 1, tgt_len, src_len) """
+ a = torch.zeros(targ_size) # targ_size is(bsz, tgt_len, src_len)
b = torch.zeros(targ_size)
if key_padding_mask is not None: # (bsz, tgt_len) -> targ_size
_check_shapes(key_padding_mask.shape, targ_size[:2])
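
For intuition, a sketch of what a combined mask of shape (bsz, 1, tgt_len, src_len) can look like when built from a padding mask and a causal mask (illustrative code, not the function body itself):

    # Hedged sketch: merging a padding mask and a causal mask into one additive
    # mask of shape (bsz, 1, tgt_len, src_len). Illustrative, not the exact implementation.
    import torch

    LARGE_NEGATIVE = -1e8
    bsz, tgt_len, src_len = 2, 4, 4
    key_padding_mask = torch.tensor([[0, 0, 0, 1],
                                     [0, 0, 1, 1]], dtype=torch.bool)       # True = pad
    causal = torch.triu(torch.full((tgt_len, src_len), LARGE_NEGATIVE), diagonal=1)

    combined = causal.expand(bsz, tgt_len, src_len).clone()
    combined = combined.masked_fill(key_padding_mask.unsqueeze(1), LARGE_NEGATIVE)
    combined = combined.unsqueeze(1)        # (bsz, 1, tgt_len, src_len), added to attention scores
    assert combined.shape == (bsz, 1, tgt_len, src_len)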
@@ -223,7 +215,7 @@ class EncoderLayer(nn.Module):
encoded output of shape `(seq_len, batch, embed_dim)`
"""
residual = x
- x, attn_weights = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask,)
+ x, attn_weights = self.self_attn(query=x, key=x, key_padding_mask=encoder_padding_mask,)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.self_attn_layer_norm(x)
@@ -266,7 +258,7 @@ class BartEncoder(nn.Module):
self.layernorm_embedding = LayerNorm(embed_dim)
def forward(
- self, input_ids=None, attention_mask=None,
+ self, input_ids, attention_mask=None,
):
"""
Args:
@@ -274,21 +266,19 @@
`(batch, src_len)`
attention_mask (torch.LongTensor): indicating which indices are padding tokens.
Returns:
- namedtuple: Tuple comprised of:
- **x** (Tensor): the last encoder layer's output of
shape `(src_len, batch, embed_dim)`
- **encoder_states** (List[Tensor]): all intermediate
hidden states of shape `(src_len, batch, embed_dim)`.
- Only populated if *return_all_hiddens* is True.
+ Only populated if *self.output_hidden_states:* is True.
- **all_attentions** (List[Tensor]): Attention weights for each layer.
During training might not be of length n_layers because of layer dropout.
"""
# check attention mask and invert
if attention_mask is not None:
assert attention_mask.dim() == 2
- attention_mask = (1.0 - attention_mask.long()) * LARGE_NEGATIVE
+ attention_mask = (1.0 - attention_mask.long()) * -10000.0
assert attention_mask.max() <= 0
inputs_embeds = self.embed_tokens(input_ids)
embed_pos = self.embed_positions(input_ids)
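
The switch from LARGE_NEGATIVE to -10000.0 above keeps the additive mask representable in fp16 (my reading, consistent with the half-precision test in the test file). A small illustration of the arithmetic:

    # Hedged illustration of the additive mask above: padded positions receive a
    # large negative bias so softmax assigns them ~zero attention probability.
    import torch

    attention_mask = torch.tensor([[1, 1, 1, 0, 0]])        # 1 = real token, 0 = padding
    additive = (1.0 - attention_mask.long()) * -10000.0     # [[0., 0., 0., -10000., -10000.]]
    scores = torch.randn(1, 5) + additive
    probs = scores.softmax(dim=-1)                          # padding columns get ~0 probability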
@@ -300,10 +290,7 @@ class BartEncoder(nn.Module):
x = x.transpose(0, 1)
encoder_states, all_attentions = [], []
- # encoder layers
for encoder_layer in self.layers:
if self.output_hidden_states:
encoder_states.append(x)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@@ -320,7 +307,6 @@ class BartEncoder(nn.Module):
encoder_states.append(x)
encoder_states = [hidden_state.transpose(0, 1) for hidden_state in encoder_states]
return x, encoder_states, all_attentions
@@ -356,28 +342,12 @@ class DecoderLayer(nn.Module):
attention_mask=None,
need_attn_weights=False,
):
- """
- Args:
- x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
- encoder_attn_mask (ByteTensor, optional): binary
- ByteTensor of shape `(batch, src_len)` where padding
- elements are indicated by ``1``.
- need_attn_weights (bool, optional): return attention weights
- for each head (default: return average over heads).
- Returns:
- encoded output of shape `(seq_len, batch, embed_dim)`
- """
residual = x
- y = x # TODO(SS): figure out why fairseq did this, then hopefully delete it
if layer_state is None:
layer_state = {}
# next line mutates layer state
- x, self_attn_weights = self.self_attn(
- query=x, key=y, value=y, layer_state=layer_state, attn_mask=attention_mask,
- )
+ x, self_attn_weights = self.self_attn(query=x, key=x, layer_state=layer_state, attn_mask=attention_mask,)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.self_attn_layer_norm(x)
@@ -386,11 +356,9 @@ class DecoderLayer(nn.Module):
x, encoder_attn_weights = self.encoder_attn(
query=x,
- key=encoder_hidden_states, # could be None
- value=encoder_hidden_states,
+ key=encoder_hidden_states,
key_padding_mask=encoder_attn_mask,
layer_state=layer_state, # mutates layer state
- static_kv=True,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
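
This cross-attention block (and the self-attention block before it) follows the same residual pattern; as a compact sketch of the flow the removed docstring used to describe:

    # Hedged sketch of the repeated sublayer pattern visible above:
    # attention -> dropout -> residual add -> layer norm.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def sublayer(x, attn_fn, layer_norm, dropout, training=False):
        residual = x
        x = attn_fn(x)                                   # self-attn or encoder-attn
        x = F.dropout(x, p=dropout, training=training)
        x = residual + x
        return layer_norm(x)

    x = torch.randn(5, 2, 16)                            # (seq_len, batch, embed_dim)
    out = sublayer(x, lambda t: t, nn.LayerNorm(16), dropout=0.1)
    assert out.shape == x.shape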
@@ -527,19 +495,15 @@ class BartDecoder(nn.Module):
return x, next_cache, all_hidden_states, list(all_self_attns)
- def reorder_attn_buffer(input_buffer, new_order):
- """Reorder buffered internal state (for incremental generation)."""
- # input_buffer = self._get_input_buffer(incremental_state)
- for k in input_buffer.keys():
- input_buffer_k = input_buffer[k]
+ def _reorder_buffer(attn_cache, new_order):
+ for k, input_buffer_k in attn_cache.items():
if input_buffer_k is not None:
- input_buffer[k] = input_buffer_k.index_select(0, new_order)
- # incremental_state = self._set_input_buffer(incremental_state, input_buffer)
- return input_buffer
+ attn_cache[k] = input_buffer_k.index_select(0, new_order)
+ return attn_cache
class SelfAttention(nn.Module):
- """Multi-headed attention from "Attention Is All You Need"""
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
@@ -551,7 +515,6 @@ class SelfAttention(nn.Module):
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
@@ -572,42 +535,29 @@ class SelfAttention(nn.Module):
self,
query,
key: Optional[Tensor],
- value: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
- layer_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
- static_kv: bool = False,
+ layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
attn_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
- """Input shape: Time(SeqLen) x Batch x Channel
- Args:
- key_padding_mask (ByteTensor, optional): mask to exclude
- keys that are pads, of shape `(batch, src_len)`, where
- padding elements are indicated by 1s.
- attn_mask (ByteTensor, optional): typically used to
- implement causal attention, where the mask prevents the
- attention from looking forward in time (default: None).
- """
+ """Input shape: Time(SeqLen) x Batch x Channel"""
+ static_kv = self.encoder_decoder_attention # type: bool
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
# get here for encoder decoder cause of static_kv
- if layer_state is not None: # get the last k,v and mask for reuse
+ if layer_state is not None: # reuse k,v and encoder_padding_mask
saved_state = layer_state.get(self.cache_key, {})
if "prev_key" in saved_state:
# previous time steps are cached - no need to recompute key and value if they are static
if static_kv:
- assert self.encoder_decoder_attention
- key = value = None
+ key = None
else:
saved_state = None
layer_state = {}
q = self.q_proj(query) * self.scaling
- if self.encoder_decoder_attention:
+ if static_kv:
if key is None:
- assert value is None
k = v = None
else:
k = self.k_proj(key)
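
A rough sketch of the caching idea behind static_kv (hypothetical standalone code, not the SelfAttention class itself): for cross-attention the encoder keys/values never change during decoding, so they are projected once, stored in layer_state, and reused at every generation step.

    # Hedged sketch of static_kv cross-attention caching: encoder keys/values
    # are computed once, then reused from the cache on every decoding step.
    import torch
    import torch.nn as nn

    d_model = 16
    k_proj, v_proj = nn.Linear(d_model, d_model), nn.Linear(d_model, d_model)
    layer_state = {}  # per-layer cache, keyed by attention type

    def cross_attn_kv(encoder_hidden_states):
        cached = layer_state.get("encoder_decoder")
        if cached is None:  # first decoding step: project and cache
            cached = {"prev_key": k_proj(encoder_hidden_states),
                      "prev_value": v_proj(encoder_hidden_states)}
            layer_state["encoder_decoder"] = cached
        return cached["prev_key"], cached["prev_value"]  # identical on every later step

    enc = torch.randn(2, 7, d_model)          # (batch, src_len, d_model)
    k1, v1 = cross_attn_kv(enc)
    k2, v2 = cross_attn_kv(enc)
    assert torch.equal(k1, k2)                # reused, not recomputed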
@@ -624,7 +574,6 @@ class SelfAttention(nn.Module):
if saved_state is not None:
k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz)
- # assert self.cache_key != 'encoder_decoder' or key_padding_mask is None
# Update cache
layer_state[self.cache_key] = {
@@ -636,7 +585,6 @@ class SelfAttention(nn.Module):
assert k is not None
src_len = k.size(1)
attn_weights = torch.bmm(q, k.transpose(1, 2))
assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len)
if attn_mask is not None:
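
For the shape assertion above, a quick sketch of the batched score computation with the heads folded into the batch dimension (illustrative sizes):

    # Hedged sketch of the attention score shapes checked above.
    import torch

    bsz, num_heads, head_dim, tgt_len, src_len = 2, 4, 8, 3, 5
    q = torch.randn(bsz * num_heads, tgt_len, head_dim)
    k = torch.randn(bsz * num_heads, src_len, head_dim)
    attn_weights = torch.bmm(q, k.transpose(1, 2))      # (bsz*num_heads, tgt_len, src_len)
    assert attn_weights.size() == (bsz * num_heads, tgt_len, src_len)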
@@ -984,7 +932,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
for layer_past in decoder_cached_states:
# get the correct batch idx from decoder layer's batch dim for cross and self-attn
layer_past_new = {
- attn_key: reorder_attn_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
+ attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
}
# reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
# reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
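
A small sketch of what this cache reordering does during beam search (illustrative values; the real cache holds projected key/value tensors): every cached tensor's batch dimension is re-indexed with the surviving beam indices, while None entries are left alone.

    # Hedged sketch of _reorder_buffer-style cache reordering for beam search.
    import torch

    beam_idx = torch.tensor([1, 1, 0])                    # which beams survived this step
    attn_cache = {
        "prev_key": torch.arange(3 * 2 * 4, dtype=torch.float).view(3, 2, 4),
        "prev_value": torch.arange(3 * 2 * 4, dtype=torch.float).view(3, 2, 4),
        "prev_key_padding_mask": None,                    # None entries are skipped
    }
    for k, t in attn_cache.items():
        if t is not None:
            attn_cache[k] = t.index_select(0, beam_idx)   # pick rows for the surviving beams
    assert attn_cache["prev_key"].shape == (3, 2, 4)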


@@ -330,6 +330,17 @@ class BartHeadTests(unittest.TestCase):
lm_model = BartForConditionalGeneration(config).eval().to(torch_device).half()
lm_model(input_ids, attention_mask=attention_mask)
+ def test_default_generate_kwargs(self):
+ config, input_ids, _ = self._get_config_and_data(output_past=True)
+ model = BartForConditionalGeneration(config).eval().to(torch_device)
+ model.generate(input_ids)
+ model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
+ def test_dummy_inputs(self):
+ config, *_ = self._get_config_and_data(output_past=True)
+ model = BartForConditionalGeneration(config).eval().to(torch_device)
+ model(**model.dummy_inputs)
def test_prepare_bart_decoder_inputs(self):
config, *_ = self._get_config_and_data(output_past=False)
input_ids = _long_tensor(([4, 4, 2])) # only used for .device if decoder_input_ids is passed