Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
[Bart] Replace config.output_past with use_cache kwarg (#3632)
parent e344e3d402
commit 715aa5b135
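This commit drops the output_past flag from BartConfig and instead threads an explicit use_cache keyword argument through BartDecoder, BartModel, and BartForConditionalGeneration; generate() turns the cache on per call via prepare_inputs_for_generation. A minimal caller-side sketch of the difference, assuming a transformers version from around this commit (the checkpoint names and input text are placeholders, not part of the diff):

from transformers import BartForConditionalGeneration, BartTokenizer

# Before this commit: the decoder cache had to be requested at load time.
# model = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True)

# After this commit: load normally; generate() requests the cache per call
# through prepare_inputs_for_generation, which sets use_cache=True.
model = BartForConditionalGeneration.from_pretrained("bart-large-cnn")
tokenizer = BartTokenizer.from_pretrained("bart-large")

batch = tokenizer.batch_encode_plus(["A long news article to summarize ..."], return_tensors="pt")
summary_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], max_length=140)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))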
@@ -20,7 +20,7 @@ def generate_summaries(
     examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE
 ):
     fout = Path(out_file).open("w")
-    model = BartForConditionalGeneration.from_pretrained(model_name, output_past=True,).to(device)
+    model = BartForConditionalGeneration.from_pretrained(model_name).to(device)
     tokenizer = BartTokenizer.from_pretrained("bart-large")

     max_length = 140
@@ -56,7 +56,6 @@ class BartConfig(PretrainedConfig):
         max_position_embeddings=1024,
         init_std=0.02,
         classifier_dropout=0.0,
-        output_past=False,
         num_labels=3,
         is_encoder_decoder=True,
         pad_token_id=1,
@@ -72,7 +71,6 @@ class BartConfig(PretrainedConfig):
         """
         super().__init__(
             num_labels=num_labels,
-            output_past=output_past,
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
             eos_token_id=eos_token_id,
@@ -388,7 +388,6 @@ class BartDecoder(nn.Module):

     def __init__(self, config: BartConfig, embed_tokens: nn.Embedding):
         super().__init__()
-        self.output_past = config.output_past
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
         self.dropout = config.dropout
@@ -412,7 +411,7 @@ class BartDecoder(nn.Module):
         decoder_padding_mask,
         decoder_causal_mask,
         decoder_cached_states=None,
-        generation_mode=False,
+        use_cache=False,
         **unused
     ):
         """
@@ -438,9 +437,9 @@ class BartDecoder(nn.Module):
             encoder_padding_mask = invert_mask(encoder_padding_mask)

         # embed positions
-        positions = self.embed_positions(input_ids, generation_mode=generation_mode)
+        positions = self.embed_positions(input_ids, use_cache=use_cache)

-        if generation_mode:
+        if use_cache:
             input_ids = input_ids[:, -1:]
             positions = positions[:, -1:]  # happens after we embed them
         assert input_ids.ne(self.padding_idx).any()
@@ -476,7 +475,7 @@ class BartDecoder(nn.Module):
                 causal_mask=decoder_causal_mask,
             )

-            if self.output_past:
+            if use_cache:
                 next_decoder_cache.append(layer_past.copy())
             if self.output_hidden_states:
                 all_hidden_states += (x,)
@@ -488,7 +487,7 @@ class BartDecoder(nn.Module):
         x = x.transpose(0, 1)
         encoder_hidden_states = encoder_hidden_states.transpose(0, 1)

-        if self.output_past:
+        if use_cache:
             next_cache = ((encoder_hidden_states, encoder_padding_mask), next_decoder_cache)
         else:
             next_cache = None
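When use_cache is set, the decoder returns its cache alongside the hidden states: a pair holding the (transposed) encoder hidden states and encoder padding mask, plus one cached entry per decoder layer collected in next_decoder_cache. A shape-only sketch of that structure, with strings standing in for tensors and the per-layer contents shown only schematically (they are not part of this hunk):

# Shape-only sketch of next_cache; strings stand in for real tensors and masks.
encoder_hidden_states = "encoder states, roughly [bsz x src_len x hidden]"
encoder_padding_mask = "encoder padding mask, roughly [bsz x src_len]"
next_decoder_cache = [
    {"layer": i, "cached_attention_state": "..."}  # schematic per-layer entry
    for i in range(12)  # 12 decoder layers assumed, e.g. bart-large
]
next_cache = ((encoder_hidden_states, encoder_padding_mask), next_decoder_cache)

# This tuple is the "past" that generate() threads back into the model, so later
# decoding steps only have to run the newest token through the decoder.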
@@ -710,9 +709,9 @@ class LearnedPositionalEmbedding(nn.Embedding):
         num_embeddings += padding_idx + 1  # WHY?
         super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)

-    def forward(self, input, generation_mode=False):
+    def forward(self, input, use_cache=False):
         """Input is expected to be of size [bsz x seqlen]."""
-        if generation_mode:  # the position is our current step in the decoded sequence
+        if use_cache:  # the position is our current step in the decoded sequence
             pos = int(self.padding_idx + input.size(1))
             positions = input.data.new(1, 1).fill_(pos)
         else:
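The use_cache branch above computes one position for the single new token instead of embedding positions for the whole prefix. A small worked example of that arithmetic, with made-up token ids:

import torch

# Mirrors the use_cache branch above: Bart's learned positions are offset by
# padding_idx, so after 3 already-decoded tokens the next position is 1 + 3 = 4.
padding_idx = 1
input = torch.tensor([[0, 464, 989]])        # hypothetical decoder input ids, seqlen = 3
pos = int(padding_idx + input.size(1))       # -> 4
positions = input.data.new(1, 1).fill_(pos)  # a 1x1 tensor holding the current step
print(pos, positions)                        # 4  tensor([[4]])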
@@ -772,11 +771,11 @@ class BartModel(PretrainedBartModel):
         encoder_outputs=None,  # type: Tuple
         decoder_attention_mask=None,
         decoder_cached_states=None,
-        generation_mode=False,
+        use_cache=False,
     ):

         # make masks if user doesn't supply
-        if not generation_mode:
+        if not use_cache:
             decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_bart_decoder_inputs(
                 self.config,
                 input_ids,
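Skipping the mask preparation under use_cache makes sense because a cached step feeds a single new token, and a causal mask only matters when several target positions are scored at once. A quick illustration of that reasoning (not library code):

import torch

# A 4-token scoring pass needs an additive causal mask so position i cannot
# attend to positions greater than i ...
seq_len = 4
full_causal_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
print(full_causal_mask)

# ... but a cached generation step processes only the newest token, which may
# attend to everything already cached, so no causal or padding mask is built.
single_step_mask = None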
@@ -799,7 +798,7 @@ class BartModel(PretrainedBartModel):
             decoder_padding_mask,
             decoder_causal_mask=causal_mask,
             decoder_cached_states=decoder_cached_states,
-            generation_mode=generation_mode,
+            use_cache=use_cache,
         )
         # Attention and hidden_states will be [] or None if they aren't needed
         decoder_outputs = _filter_out_falsey_values(decoder_outputs)  # type: tuple
@@ -841,7 +840,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
         decoder_attention_mask=None,
         decoder_cached_states=None,
         lm_labels=None,
-        generation_mode=False,
+        use_cache=False,
         **unused
     ):
         r"""
@@ -892,7 +891,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
             encoder_outputs=encoder_outputs,
             decoder_attention_mask=decoder_attention_mask,
             decoder_cached_states=decoder_cached_states,
-            generation_mode=generation_mode,
+            use_cache=use_cache,
         )
         lm_logits = F.linear(outputs[0], self.model.shared.weight)
         outputs = (lm_logits,) + outputs[1:]  # Add hidden states and attention if they are here
@@ -918,7 +917,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
             "decoder_cached_states": decoder_cached_states,
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
-            "generation_mode": True,
+            "use_cache": True,  # change this to avoid caching (presumably for debugging)
         }

     def prepare_scores_for_generation(self, scores, cur_len, max_length):
@@ -951,6 +950,10 @@ class BartForConditionalGeneration(PretrainedBartModel):
     def get_output_embeddings(self):
         return _make_linear_from_emb(self.model.shared)  # make it on the fly

+    def _do_output_past(self, *args, **kwargs):
+        """ We should always use the cache in generate."""
+        return True
+

 @add_start_docstrings(
     """Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """,
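The newly added _do_output_past hook is what the generation utilities of that era consulted to decide whether a model's outputs contain a reusable cache; returning True unconditionally means Bart's generate() always carries the decoder cache between steps. A toy illustration of that control flow, using a stand-in class and fake outputs rather than library code:

class FakeBart:
    def _do_output_past(self, *args, **kwargs):
        return True  # mirrors the override added above

    def step(self, token_id, past=None):
        # Stand-in forward pass returning (logits, cache); a real model would
        # only need to process token_id when past is supplied.
        return ("logits", {"cache_after_token": token_id})

model = FakeBart()
past = None
for token_id in [0, 42, 7]:             # hypothetical decoded token ids
    outputs = model.step(token_id, past=past)
    if model._do_output_past(outputs):  # Bart now always answers True
        past = outputs[1]               # reuse the cache on the next step
print(past)                             # {'cache_after_token': 7}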
File diff suppressed because one or more lines are too long