Contrastive Search peak memory reduction (#24120)

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
Benjamin Badger 2023-07-20 13:46:53 -04:00 committed by GitHub
parent aa1b09c5d1
commit caf5e369fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 147 additions and 31 deletions

View File

@ -189,6 +189,9 @@ class GenerationConfig(PushToHubMixin):
The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
Higher guidance scale encourages the model to generate samples that are more closely linked to the input
prompt, usually at the expense of poorer quality.
low_memory (`bool`, *optional*):
Switch to sequential topk for contrastive search to reduce peak memory. Used with contrastive search.
> Parameters that define the output variables of `generate`
@ -270,6 +273,7 @@ class GenerationConfig(PushToHubMixin):
self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None)
self.sequence_bias = kwargs.pop("sequence_bias", None)
self.guidance_scale = kwargs.pop("guidance_scale", None)
self.low_memory = kwargs.pop("low_memory", None)
# Parameters that define the output variables of `generate`
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)

View File

@ -1569,6 +1569,7 @@ class GenerationMixin:
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
streamer=streamer,
sequential=generation_config.low_memory,
**model_kwargs,
)
@ -1832,6 +1833,7 @@ class GenerationMixin:
return_dict_in_generate: Optional[bool] = None,
synced_gpus: bool = False,
streamer: Optional["BaseStreamer"] = None,
sequential: Optional[bool] = None,
**model_kwargs,
) -> Union[ContrastiveSearchOutput, torch.LongTensor]:
r"""
@ -1882,6 +1884,8 @@ class GenerationMixin:
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
sequential (`bool`, *optional*):
Switches topk hidden state computation from parallel to sequential to reduce memory if True.
model_kwargs:
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
@ -1921,6 +1925,7 @@ class GenerationMixin:
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
sequential = sequential if sequential is not None else self.generation_config.low_memory
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
@ -1986,6 +1991,7 @@ class GenerationMixin:
last_hidden_states = outputs.decoder_hidden_states[-1]
else:
last_hidden_states = outputs.hidden_states[-1]
# next logit for contrastive search to select top-k candidate tokens
logit_for_next_step = outputs.logits[:, -1, :]
@ -1995,11 +2001,11 @@ class GenerationMixin:
is_encoder_decoder=self.config.is_encoder_decoder,
standardize_cache_format=True,
)
# Expands model inputs top_k times, for batched forward passes (akin to beam search).
_, model_kwargs = self._expand_inputs_for_generation(
expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
)
if not sequential:
# Expands model inputs top_k times, for batched forward passes (akin to beam search).
_, model_kwargs = self._expand_inputs_for_generation(
expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
)
past_key_values = model_kwargs.get("past_key_values")
if past_key_values is None:
@ -2019,7 +2025,6 @@ class GenerationMixin:
# contrastive_search main logic start:
# contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by
# degeneration penalty
logit_for_next_step = logits_processor(input_ids, logit_for_next_step)
logit_for_next_step = logits_warper(input_ids, logit_for_next_step)
next_probs = nn.functional.softmax(logit_for_next_step, dim=-1)
@ -2049,25 +2054,74 @@ class GenerationMixin:
items = []
# item is either the key or the value matrix
for item in layer:
items.append(item.repeat_interleave(top_k, dim=0))
if sequential:
items.append(item.repeat_interleave(1, dim=0))
else:
items.append(item.repeat_interleave(top_k, dim=0))
new_key_values.append(items)
model_kwargs["past_key_values"] = new_key_values
# compute the candidate tokens by the language model and collects their hidden_states
next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs)
outputs = self(
**next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions
)
next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)
if sequential:
all_outputs = {key: [] for key in outputs} # defined in first loop iteration
all_last_hstates, all_hstates, all_logits = [], [], []
for i in range(top_k):
# compute the candidate tokens by the language model and collect their hidden_states
next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].view(-1, 1), **model_kwargs)
outputs = self(
**next_model_inputs,
return_dict=True,
output_hidden_states=True,
output_attentions=output_attentions,
)
for key in all_outputs:
all_outputs[key].append(outputs[key])
if self.config.is_encoder_decoder:
next_hidden = outputs.decoder_hidden_states[-1]
full_hidden_states = outputs.decoder_hidden_states
else:
next_hidden = outputs.hidden_states[-1]
full_hidden_states = outputs.hidden_states
all_last_hstates.append(torch.squeeze(next_hidden, 0))
all_hstates.append(full_hidden_states)
all_logits.append(outputs.logits[:, -1, :])
# stack hidden states
next_hidden = torch.stack([all_last_hstates[i] for i in range(top_k)], dim=0)
final_full_hstates = [0 for i in range(len(full_hidden_states))]
for layer in range(len(full_hidden_states)):
final_full_hstates[layer] = torch.stack(
[torch.squeeze(all_hstates[i][layer], 0) for i in range(top_k)], dim=0
)
full_hidden_states = tuple(final_full_hstates)
# stack logits
logits = torch.cat(all_logits, dim=0)
logits = outputs.logits[:, -1, :]
# name is different for encoder-decoder and decoder-only models
if self.config.is_encoder_decoder:
next_hidden = outputs.decoder_hidden_states[-1]
full_hidden_states = outputs.decoder_hidden_states
else:
next_hidden = outputs.hidden_states[-1]
full_hidden_states = outputs.hidden_states
# compute the candidate tokens by the language model and collect their hidden_states
# assembles top_k_ids into batch of size k
next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs)
outputs = self(
**next_model_inputs,
return_dict=True,
output_hidden_states=True,
output_attentions=output_attentions,
)
# name is different for encoder-decoder and decoder-only models
if self.config.is_encoder_decoder:
next_hidden = outputs.decoder_hidden_states[-1]
full_hidden_states = outputs.decoder_hidden_states
else:
next_hidden = outputs.hidden_states[-1]
full_hidden_states = outputs.hidden_states
logits = outputs.logits[:, -1, :]
context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)
# compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
@ -2089,17 +2143,32 @@ class GenerationMixin:
layer = torch.stack(torch.split(layer, top_k))[range(batch_size), selected_idx, :]
next_decoder_hidden_states += (layer,)
# select the past_key_value
new_key_values = ()
for layer in next_past_key_values:
items = ()
# item is either the key or the value matrix
for item in layer:
item = torch.stack(torch.split(item, top_k, dim=0)) # [B, K, num_head, seq_len, esz]
item = item[range(batch_size), selected_idx, ...] # [B, num_head, seq_len, esz]
items += (item,)
new_key_values += (items,)
next_past_key_values = new_key_values
# generate past_key_values cache of only the selected token
if sequential:
next_model_input = self.prepare_inputs_for_generation(
top_k_ids[:, selected_idx].view(-1, 1), **model_kwargs
)
selected_outputs = self(
**next_model_input,
return_dict=True,
output_hidden_states=False,
output_attentions=False,
)
next_past_key_values = selected_outputs["past_key_values"]
else:
next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)
new_key_values = ()
for layer in next_past_key_values:
items = ()
# item is either the key or the value matrix
for item in layer:
item = torch.stack(torch.split(item, top_k, dim=0)) # [B, K, num_head, seq_len, esz]
item = item[range(batch_size), selected_idx, ...] # [B, num_head, seq_len, esz]
items += (item,)
new_key_values += (items,)
next_past_key_values = new_key_values
logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :]

View File

@ -1457,6 +1457,49 @@ class GenerationTesterMixin:
for output in (output_contrastive, output_generate):
self._check_outputs(output, input_ids, model.config, use_cache=True)
def test_contrastive_generate_low_memory(self):
# Check that choosing 'low_memory' does not change the model output
for model_class in self.all_generative_model_classes:
# won't fix: FSMT, Reformer, gptbigcode, and speech2text have a different cache variable type (and format).
if any(
model_name in model_class.__name__.lower()
for model_name in ["fsmt", "reformer", "gptbigcode", "speech2text"]
):
return
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
# NOTE: contrastive search only works with cache on at the moment.
if not hasattr(config, "use_cache"):
return
config.use_cache = True
config.is_decoder = True
# test output equality of low versus high memory
model = model_class(config).to(torch_device).eval()
low_output = model.generate(
input_ids,
top_k=4,
penalty_alpha=0.6,
low_memory=True,
max_length=max_length,
attention_mask=attention_mask,
)
high_output = model.generate(
input_ids,
top_k=4,
penalty_alpha=0.6,
low_memory=False,
max_length=max_length,
attention_mask=attention_mask,
)
self.assertListEqual(low_output.tolist(), high_output.tolist())
return
@slow # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%.
def test_assisted_decoding_matches_greedy_search(self):
# This test ensures that the assisted generation does not introduce output changes over greedy search.