diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py
index 5976d49074e..ee8e7c54cc1 100644
--- a/src/transformers/modeling_bart.py
+++ b/src/transformers/modeling_bart.py
@@ -56,7 +56,7 @@ BART_GENERATION_EXAMPLE = r"""
     ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
     inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
     # Generate Summary
-    summary_ids = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], num_beams=4, max_length=5)
+    summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True)
     print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])
 """
@@ -84,8 +84,9 @@ LARGE_NEGATIVE = -1e8
 def _prepare_bart_decoder_inputs(
     config, input_ids, decoder_input_ids=None, decoder_attn_mask=None, mask_dtype=None,
 ):
-    """Prepare masks that ignore padding tokens decoder and a causal lm mask for the decoder if
+    """Prepare masks that ignore padding tokens in the decoder and a causal lm mask for the decoder if
     none are provided. This mimics the default behavior in fairseq. To override it pass in masks.
+    Note: this is not called during generation.
     """
     pad_token_id = config.pad_token_id
     need_causal_mask = not config.output_past
@@ -114,8 +115,6 @@ class PretrainedBartModel(PreTrainedModel):
     def _init_weights(self, module):
         std = self.config.init_std
-
-        # called init_bert_params in fairseq
         if isinstance(module, nn.Linear):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.bias is not None:
@@ -127,16 +126,9 @@ class PretrainedBartModel(PreTrainedModel):
     @property
     def dummy_inputs(self):
-        pad_token = 1
-        input_ids = torch.Tensor(
-            [
-                [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2],
-                [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 2, pad_token],
-            ]
-        ).long()
-        decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(
-            self.config, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attn_mask=None
-        )
+        pad_token = self.config.pad_token_id
+        input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]])
+        decoder_input_ids, decoder_attn_mask = _prepare_bart_decoder_inputs(self.config, input_ids,)
         dummy_inputs = {
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": input_ids.ne(pad_token),
@@ -149,7 +141,7 @@ class PretrainedBartModel(PreTrainedModel):
 def _make_linear_from_emb(emb):
     vocab_size, emb_size = emb.weight.shape
     lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
-    lin_layer.weight.data = emb.weight.data  # .T
+    lin_layer.weight.data = emb.weight.data
     return lin_layer
@@ -160,8 +152,8 @@ def _check_shapes(shape_1, shape2):
 def _combine_masks(key_padding_mask, causal_lm_mask, targ_size):
-    # targ_size = (bsz, tgt_len, src_len)
-    a = torch.zeros(targ_size)
+    """Make one mask of shape (bsz, 1, tgt_len, src_len)"""
+    a = torch.zeros(targ_size)  # targ_size is (bsz, tgt_len, src_len)
     b = torch.zeros(targ_size)
     if key_padding_mask is not None:  # (bsz, tgt_len) -> targ_size
         _check_shapes(key_padding_mask.shape, targ_size[:2])
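For illustration only (not part of the patch): a minimal sketch of the combination the new `_combine_masks` docstring describes, assuming additive float masks. The helper name and the -1e8 masking value are illustrative, not the file's exact implementation.

import torch

def combine_masks_sketch(key_padding_mask, causal_lm_mask, targ_size):
    """Merge a padding mask and a causal mask into one (bsz, 1, tgt_len, src_len) mask."""
    mask = torch.zeros(targ_size)                 # targ_size is (bsz, tgt_len, src_len)
    if key_padding_mask is not None:              # (bsz, tgt_len), 1 where padded
        mask = mask + key_padding_mask.unsqueeze(2).float() * -1e8
    if causal_lm_mask is not None:                # (tgt_len, src_len), additive causal mask
        mask = mask + causal_lm_mask.unsqueeze(0)
    return mask.unsqueeze(1)                      # extra dim broadcasts over attention heads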
@@ -223,7 +215,7 @@ class EncoderLayer(nn.Module):
             encoded output of shape `(seq_len, batch, embed_dim)`
         """
         residual = x
-        x, attn_weights = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask,)
+        x, attn_weights = self.self_attn(query=x, key=x, key_padding_mask=encoder_padding_mask,)
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
         x = self.self_attn_layer_norm(x)
@@ -266,7 +258,7 @@ class BartEncoder(nn.Module):
         self.layernorm_embedding = LayerNorm(embed_dim)

     def forward(
-        self, input_ids=None, attention_mask=None,
+        self, input_ids, attention_mask=None,
     ):
         """
         Args:
@@ -274,21 +266,19 @@ class BartEncoder(nn.Module):
                 `(batch, src_len)`
             attention_mask (torch.LongTensor): indicating which indices are padding tokens.
         Returns:
-            namedtuple:
+            Tuple comprised of:
                 - **x** (Tensor): the last encoder layer's output of
                   shape `(src_len, batch, embed_dim)`
-
                 - **encoder_states** (List[Tensor]): all intermediate
                   hidden states of shape `(src_len, batch, embed_dim)`.
-                  Only populated if *return_all_hiddens* is True.
+                  Only populated if *self.output_hidden_states* is True.
                 - **all_attentions** (List[Tensor]): Attention weights for each layer.
                 During training might not be of length n_layers because of layer dropout.
         """
         # check attention mask and invert
         if attention_mask is not None:
             assert attention_mask.dim() == 2
-
-            attention_mask = (1.0 - attention_mask.long()) * -10000.0
+            attention_mask = (1.0 - attention_mask.long()) * LARGE_NEGATIVE
             assert attention_mask.max() <= 0
         inputs_embeds = self.embed_tokens(input_ids)
         embed_pos = self.embed_positions(input_ids)
@@ -300,10 +290,7 @@ class BartEncoder(nn.Module):
         x = x.transpose(0, 1)

         encoder_states, all_attentions = [], []
-
-        # encoder layers
         for encoder_layer in self.layers:
-
             if self.output_hidden_states:
                 encoder_states.append(x)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@@ -320,7 +307,6 @@ class BartEncoder(nn.Module):
             encoder_states.append(x)
         encoder_states = [hidden_state.transpose(0, 1) for hidden_state in encoder_states]
-
         return x, encoder_states, all_attentions
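For reference, the inversion above turns a conventional 1/0 attention mask into an additive bias; a quick, illustrative check (values only, not part of the patch):

import torch

LARGE_NEGATIVE = -1e8
attention_mask = torch.tensor([[1, 1, 0]])                # 1 = real token, 0 = padding
inverted = (1.0 - attention_mask.long()) * LARGE_NEGATIVE
# real tokens map to 0, the padding position maps to -1e8, so softmax ignores it
assert inverted.max() <= 0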
@@ -356,28 +342,12 @@ class DecoderLayer(nn.Module):
         attention_mask=None,
         need_attn_weights=False,
     ):
-        """
-        Args:
-            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
-            encoder_attn_mask (ByteTensor, optional): binary
-                ByteTensor of shape `(batch, src_len)` where padding
-                elements are indicated by ``1``.
-            need_attn_weights (bool, optional): return attention weights
-                for each head (default: return average over heads).
-
-        Returns:
-            encoded output of shape `(seq_len, batch, embed_dim)`
-        """
-
         residual = x
-        y = x  # TODO(SS): figure out why fairseq did this, then hopefully delete it
         if layer_state is None:
             layer_state = {}
         # next line mutates layer state
-        x, self_attn_weights = self.self_attn(
-            query=x, key=y, value=y, layer_state=layer_state, attn_mask=attention_mask,
-        )
+        x, self_attn_weights = self.self_attn(query=x, key=x, layer_state=layer_state, attn_mask=attention_mask,)
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
         x = self.self_attn_layer_norm(x)
@@ -386,11 +356,9 @@ class DecoderLayer(nn.Module):
         x, encoder_attn_weights = self.encoder_attn(
             query=x,
-            key=encoder_hidden_states,  # could be None
-            value=encoder_hidden_states,
+            key=encoder_hidden_states,
             key_padding_mask=encoder_attn_mask,
             layer_state=layer_state,  # mutates layer state
-            static_kv=True,
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
@@ -527,19 +495,15 @@ class BartDecoder(nn.Module):
         return x, next_cache, all_hidden_states, list(all_self_attns)


-def reorder_attn_buffer(input_buffer, new_order):
-    """Reorder buffered internal state (for incremental generation)."""
-    # input_buffer = self._get_input_buffer(incremental_state)
-    for k in input_buffer.keys():
-        input_buffer_k = input_buffer[k]
+def _reorder_buffer(attn_cache, new_order):
+    for k, input_buffer_k in attn_cache.items():
         if input_buffer_k is not None:
-            input_buffer[k] = input_buffer_k.index_select(0, new_order)
-            # incremental_state = self._set_input_buffer(incremental_state, input_buffer)
-    return input_buffer
+            attn_cache[k] = input_buffer_k.index_select(0, new_order)
+    return attn_cache


 class SelfAttention(nn.Module):
-    """Multi-headed attention from "Attention Is All You Need"""
+    """Multi-headed attention from 'Attention Is All You Need' paper"""

     def __init__(
         self,
@@ -551,7 +515,6 @@ class SelfAttention(nn.Module):
     ):
         super().__init__()
         self.embed_dim = embed_dim
-
         self.num_heads = num_heads
         self.dropout = dropout
         self.head_dim = embed_dim // num_heads
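The slimmed-down `_reorder_buffer` above can be exercised in isolation; a small illustrative example (shapes and cache keys are hypothetical, mirroring the batch-major cache layout used during beam search):

import torch

attn_cache = {
    "prev_key": torch.arange(12.0).view(3, 1, 2, 2),  # dim 0 is batch * beams
    "prev_key_padding_mask": None,                    # None entries are left untouched
}
new_order = torch.tensor([2, 0, 1])                   # beam indices selected by the search
_reorder_buffer(attn_cache, new_order)
# attn_cache["prev_key"][0] now holds what was previously at index 2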
- """ + """Input shape: Time(SeqLen) x Batch x Channel""" + static_kv = self.encoder_decoder_attention # type: bool tgt_len, bsz, embed_dim = query.size() assert embed_dim == self.embed_dim assert list(query.size()) == [tgt_len, bsz, embed_dim] # get here for encoder decoder cause of static_kv - if layer_state is not None: # get the last k,v and mask for reuse + if layer_state is not None: # reuse k,v and encoder_padding_mask saved_state = layer_state.get(self.cache_key, {}) if "prev_key" in saved_state: # previous time steps are cached - no need to recompute key and value if they are static if static_kv: - assert self.encoder_decoder_attention - key = value = None + key = None else: saved_state = None layer_state = {} q = self.q_proj(query) * self.scaling - if self.encoder_decoder_attention: + if static_kv: if key is None: - assert value is None k = v = None else: k = self.k_proj(key) @@ -624,7 +574,6 @@ class SelfAttention(nn.Module): if saved_state is not None: k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz) - # assert self.cache_key != 'encoder_decoder' or key_padding_mask is None # Update cache layer_state[self.cache_key] = { @@ -636,7 +585,6 @@ class SelfAttention(nn.Module): assert k is not None src_len = k.size(1) attn_weights = torch.bmm(q, k.transpose(1, 2)) - assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len) if attn_mask is not None: @@ -984,7 +932,7 @@ class BartForConditionalGeneration(PretrainedBartModel): for layer_past in decoder_cached_states: # get the correct batch idx from decoder layer's batch dim for cross and self-attn layer_past_new = { - attn_key: reorder_attn_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items() + attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items() } # reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx] # reordered_layer_past = torch.cat(reordered_layer_past, dim=1) diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index 77aed9eb6cf..f08028c8d39 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -330,6 +330,17 @@ class BartHeadTests(unittest.TestCase): lm_model = BartForConditionalGeneration(config).eval().to(torch_device).half() lm_model(input_ids, attention_mask=attention_mask) + def test_default_generate_kwargs(self): + config, input_ids, _ = self._get_config_and_data(output_past=True) + model = BartForConditionalGeneration(config).eval().to(torch_device) + model.generate(input_ids) + model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3) + + def test_dummy_inputs(self): + config, *_ = self._get_config_and_data(output_past=True) + model = BartForConditionalGeneration(config).eval().to(torch_device) + model(**model.dummy_inputs) + def test_prepare_bart_decoder_inputs(self): config, *_ = self._get_config_and_data(output_past=False) input_ids = _long_tensor(([4, 4, 2])) # only used for .device if decoder_input_ids is passed