[Bart] Replace config.output_past with use_cache kwarg (#3632)

Sam Shleifer 2020-04-07 19:08:26 -04:00 committed by GitHub
parent e344e3d402
commit 715aa5b135
4 changed files with 25 additions and 26 deletions

View File

@@ -20,7 +20,7 @@ def generate_summaries(
examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE
):
fout = Path(out_file).open("w")
-model = BartForConditionalGeneration.from_pretrained(model_name, output_past=True,).to(device)
+model = BartForConditionalGeneration.from_pretrained(model_name).to(device)
tokenizer = BartTokenizer.from_pretrained("bart-large")
max_length = 140
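
For orientation, this is roughly how the updated script's loading and generation fit together after the change: caching no longer has to be requested at load time, because generate() opts in on its own (see _do_output_past further down in this diff). A minimal sketch; the checkpoint name, input text, and generation settings are placeholders, not taken from this commit.

    from transformers import BartForConditionalGeneration, BartTokenizer

    # No output_past kwarg at load time any more; caching is handled per call.
    model = BartForConditionalGeneration.from_pretrained("bart-large-cnn")
    tokenizer = BartTokenizer.from_pretrained("bart-large")

    article = "New York (CNN) ... a long article to summarize ..."
    inputs = tokenizer.batch_encode_plus([article], return_tensors="pt")

    # generate() reuses the decoder key/value cache on every step by itself.
    summary_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=140)
    print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))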

View File

@@ -56,7 +56,6 @@ class BartConfig(PretrainedConfig):
max_position_embeddings=1024,
init_std=0.02,
classifier_dropout=0.0,
-output_past=False,
num_labels=3,
is_encoder_decoder=True,
pad_token_id=1,
@@ -72,7 +71,6 @@ class BartConfig(PretrainedConfig):
"""
super().__init__(
num_labels=num_labels,
-output_past=output_past,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,

View File

@@ -388,7 +388,6 @@ class BartDecoder(nn.Module):
def __init__(self, config: BartConfig, embed_tokens: nn.Embedding):
super().__init__()
-self.output_past = config.output_past
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.dropout = config.dropout
@@ -412,7 +411,7 @@ class BartDecoder(nn.Module):
decoder_padding_mask,
decoder_causal_mask,
decoder_cached_states=None,
-generation_mode=False,
+use_cache=False,
**unused
):
"""
@@ -438,9 +437,9 @@
encoder_padding_mask = invert_mask(encoder_padding_mask)
# embed positions
-positions = self.embed_positions(input_ids, generation_mode=generation_mode)
+positions = self.embed_positions(input_ids, use_cache=use_cache)
-if generation_mode:
+if use_cache:
input_ids = input_ids[:, -1:]
positions = positions[:, -1:] # happens after we embed them
assert input_ids.ne(self.padding_idx).any()
@@ -476,7 +475,7 @@
causal_mask=decoder_causal_mask,
)
-if self.output_past:
+if use_cache:
next_decoder_cache.append(layer_past.copy())
if self.output_hidden_states:
all_hidden_states += (x,)
@@ -488,7 +487,7 @@
x = x.transpose(0, 1)
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
-if self.output_past:
+if use_cache:
next_cache = ((encoder_hidden_states, encoder_padding_mask), next_decoder_cache)
else:
next_cache = None
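
When use_cache is True, the returned cache bundles the (unchanging) encoder tensors with the per-layer decoder states so the whole thing can be handed back on the next step. A small sketch of unpacking it; only the outer tuple layout is taken from the lines above, while the per-layer entries are an implementation detail of the attention modules and are not shown in this hunk.

    # Only the outer layout below comes from the diff; per-layer entries are
    # whatever the attention modules stored, and are treated as opaque here.
    def inspect_cache(next_cache):
        (encoder_hidden_states, encoder_padding_mask), per_layer_states = next_cache
        print("encoder hidden states:", tuple(encoder_hidden_states.shape))
        print("decoder layers cached:", len(per_layer_states))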
@@ -710,9 +709,9 @@ class LearnedPositionalEmbedding(nn.Embedding):
num_embeddings += padding_idx + 1 # WHY?
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
-def forward(self, input, generation_mode=False):
+def forward(self, input, use_cache=False):
"""Input is expected to be of size [bsz x seqlen]."""
-if generation_mode: # the position is our current step in the decoded sequence
+if use_cache: # the position is our current step in the decoded sequence
pos = int(self.padding_idx + input.size(1))
positions = input.data.new(1, 1).fill_(pos)
else:
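
In the use_cache branch only the position of the token currently being generated is needed, since earlier positions were already embedded on previous steps. A standalone sketch of that arithmetic in plain torch (not the library class itself):

    import torch

    padding_idx = 1
    input_ids = torch.tensor([[0, 31414, 232, 2]])  # bsz=1, seqlen=4; ids are arbitrary

    # Cached / generation step: embed only the newest token's position.
    pos = int(padding_idx + input_ids.size(1))       # 1 + 4 = 5
    positions = input_ids.data.new(1, 1).fill_(pos)  # tensor([[5]]), shape [1, 1]
    print(positions)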
@@ -772,11 +771,11 @@ class BartModel(PretrainedBartModel):
encoder_outputs=None, # type: Tuple
decoder_attention_mask=None,
decoder_cached_states=None,
-generation_mode=False,
+use_cache=False,
):
# make masks if user doesn't supply
-if not generation_mode:
+if not use_cache:
decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
self.config,
input_ids,
@@ -799,7 +798,7 @@
decoder_padding_mask,
decoder_causal_mask=causal_mask,
decoder_cached_states=decoder_cached_states,
-generation_mode=generation_mode,
+use_cache=use_cache,
)
# Attention and hidden_states will be [] or None if they aren't needed
decoder_outputs = _filter_out_falsey_values(decoder_outputs) # type: tuple
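
When the cache is not requested, next_cache comes back as None and is dropped here along with any empty attention/hidden-state tuples, so downstream code only indexes values that actually exist. A standalone sketch of that filtering idea using a local helper, under the assumption that "falsey" simply means None or empty (this is not the library's own helper):

    import torch

    def filter_out_falsey_values(tup):
        # Keep tensors unconditionally; drop None and empty containers.
        return tuple(x for x in tup if isinstance(x, torch.Tensor) or x)

    example = (torch.zeros(1, 3, 4), None, (), "kept")
    print(filter_out_falsey_values(example))  # the None and the empty tuple are gone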
@@ -841,7 +840,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
decoder_attention_mask=None,
decoder_cached_states=None,
lm_labels=None,
-generation_mode=False,
+use_cache=False,
**unused
):
r"""
@@ -892,7 +891,7 @@
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
decoder_cached_states=decoder_cached_states,
-generation_mode=generation_mode,
+use_cache=use_cache,
)
lm_logits = F.linear(outputs[0], self.model.shared.weight)
outputs = (lm_logits,) + outputs[1:] # Add hidden states and attention if they are here
@@ -918,7 +917,7 @@
"decoder_cached_states": decoder_cached_states,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"generation_mode": True,
"use_cache": True, # change this to avoid caching (presumably for debugging)
}
def prepare_scores_for_generation(self, scores, cur_len, max_length):
@@ -951,6 +950,10 @@ class BartForConditionalGeneration(PretrainedBartModel):
def get_output_embeddings(self):
return _make_linear_from_emb(self.model.shared) # make it on the fly
+def _do_output_past(self, *args, **kwargs):
+""" We should always use the cache in generate."""
+return True
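
Because this hook ignores its arguments and always returns True, generate() keeps outputs[1] as the past on every step and feeds it back through prepare_inputs_for_generation (shown above) as decoder_cached_states, with no user-facing flag. A quick check, using a tiny made-up config so it runs instantly; the config values are placeholders, not anything from this commit:

    from transformers import BartConfig, BartForConditionalGeneration

    # Tiny, made-up config so this runs instantly; any BART model answers the same.
    tiny = BartConfig(
        vocab_size=64, d_model=16,
        encoder_layers=1, decoder_layers=1,
        encoder_attention_heads=2, decoder_attention_heads=2,
        encoder_ffn_dim=32, decoder_ffn_dim=32,
    )
    model = BartForConditionalGeneration(tiny)

    # The hook ignores its arguments, so generate() always treats outputs[1] as
    # the past and routes it back in as decoder_cached_states on the next step.
    assert model._do_output_past() is True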
@add_start_docstrings(
"""Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """,

File diff suppressed because one or more lines are too long