Contrastive Search peak memory reduction (#24120)

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Benjamin Badger 2023-07-20 13:46:53 -04:00 committed by GitHub
parent aa1b09c5d1
commit caf5e369fc
3 changed files with 147 additions and 31 deletions

View File

@@ -189,6 +189,9 @@ class GenerationConfig(PushToHubMixin):
            The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
            Higher guidance scale encourages the model to generate samples that are more closely linked to the input
            prompt, usually at the expense of poorer quality.
        low_memory (`bool`, *optional*):
            Switch to sequential top-k for contrastive search to reduce peak memory. Only used with contrastive search.

        > Parameters that define the output variables of `generate`
@@ -270,6 +273,7 @@ class GenerationConfig(PushToHubMixin):
        self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None)
        self.sequence_bias = kwargs.pop("sequence_bias", None)
        self.guidance_scale = kwargs.pop("guidance_scale", None)
        self.low_memory = kwargs.pop("low_memory", None)

        # Parameters that define the output variables of `generate`
        self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
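For reference, a minimal usage sketch of the new option (the checkpoint name, prompt, and generation length below are placeholders, not part of this commit). `low_memory` can be passed directly to `generate` or stored in a `GenerationConfig`, and it only takes effect when contrastive search is active, i.e. when `top_k > 1` and `penalty_alpha > 0`:

# Hypothetical usage sketch for the new `low_memory` flag.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("DeepMind Company is", return_tensors="pt")

# Contrastive search with sequential top-k to reduce peak memory.
output = model.generate(
    **inputs,
    top_k=4,
    penalty_alpha=0.6,
    low_memory=True,  # flag introduced by this commit
    max_new_tokens=64,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))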

View File

@@ -1569,6 +1569,7 @@ class GenerationMixin:
                return_dict_in_generate=generation_config.return_dict_in_generate,
                synced_gpus=synced_gpus,
                streamer=streamer,
                sequential=generation_config.low_memory,
                **model_kwargs,
            )
@@ -1832,6 +1833,7 @@ class GenerationMixin:
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        streamer: Optional["BaseStreamer"] = None,
        sequential: Optional[bool] = None,
        **model_kwargs,
    ) -> Union[ContrastiveSearchOutput, torch.LongTensor]:
        r"""
@@ -1882,6 +1884,8 @@ class GenerationMixin:
            streamer (`BaseStreamer`, *optional*):
                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
            sequential (`bool`, *optional*):
                Switches top-k hidden state computation from parallel to sequential to reduce memory if True.
            model_kwargs:
                Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
                If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -1921,6 +1925,7 @@ class GenerationMixin:
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
        sequential = sequential if sequential is not None else self.generation_config.low_memory
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
@@ -1986,6 +1991,7 @@ class GenerationMixin:
                    last_hidden_states = outputs.decoder_hidden_states[-1]
                else:
                    last_hidden_states = outputs.hidden_states[-1]

                # next logit for contrastive search to select top-k candidate tokens
                logit_for_next_step = outputs.logits[:, -1, :]
@@ -1995,7 +2001,7 @@ class GenerationMixin:
                    is_encoder_decoder=self.config.is_encoder_decoder,
                    standardize_cache_format=True,
                )
                if not sequential:
                    # Expands model inputs top_k times, for batched forward passes (akin to beam search).
                    _, model_kwargs = self._expand_inputs_for_generation(
                        expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
@@ -2019,7 +2025,6 @@ class GenerationMixin:
            # contrastive_search main logic start:
            # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by
            # degeneration penalty
            logit_for_next_step = logits_processor(input_ids, logit_for_next_step)
            logit_for_next_step = logits_warper(input_ids, logit_for_next_step)
            next_probs = nn.functional.softmax(logit_for_next_step, dim=-1)
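Immediately after this hunk, contrastive search recalls its candidates by taking the `top_k` most probable next tokens from `next_probs`; the sketch below (plain PyTorch with placeholder shapes, not text from the commit) illustrates that step. The resulting `top_k_ids` are what the sequential and parallel branches in the next hunk score with extra forward passes.

# Candidate-recall sketch (illustrative only): pick the top_k most probable next tokens.
import torch

batch_size, vocab_size, top_k = 1, 100, 4
logit_for_next_step = torch.randn(batch_size, vocab_size)  # placeholder logits
next_probs = torch.nn.functional.softmax(logit_for_next_step, dim=-1)
top_k_probs, top_k_ids = torch.topk(next_probs, dim=-1, k=top_k)
# top_k_ids has shape (batch_size, top_k); each id is a candidate continuation.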
@@ -2049,18 +2054,64 @@ class GenerationMixin:
                items = []
                # item is either the key or the value matrix
                for item in layer:
                    if sequential:
                        items.append(item.repeat_interleave(1, dim=0))
                    else:
                        items.append(item.repeat_interleave(top_k, dim=0))
                new_key_values.append(items)
            model_kwargs["past_key_values"] = new_key_values

            if sequential:
                all_outputs = {key: [] for key in outputs}  # defined in first loop iteration
                all_last_hstates, all_hstates, all_logits = [], [], []
                for i in range(top_k):
                    # compute the candidate tokens by the language model and collect their hidden_states
                    next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].view(-1, 1), **model_kwargs)

                    outputs = self(
                        **next_model_inputs,
                        return_dict=True,
                        output_hidden_states=True,
                        output_attentions=output_attentions,
                    )
                    for key in all_outputs:
                        all_outputs[key].append(outputs[key])

                    if self.config.is_encoder_decoder:
                        next_hidden = outputs.decoder_hidden_states[-1]
                        full_hidden_states = outputs.decoder_hidden_states
                    else:
                        next_hidden = outputs.hidden_states[-1]
                        full_hidden_states = outputs.hidden_states

                    all_last_hstates.append(torch.squeeze(next_hidden, 0))
                    all_hstates.append(full_hidden_states)
                    all_logits.append(outputs.logits[:, -1, :])

                # stack hidden states
                next_hidden = torch.stack([all_last_hstates[i] for i in range(top_k)], dim=0)
                final_full_hstates = [0 for i in range(len(full_hidden_states))]
                for layer in range(len(full_hidden_states)):
                    final_full_hstates[layer] = torch.stack(
                        [torch.squeeze(all_hstates[i][layer], 0) for i in range(top_k)], dim=0
                    )
                full_hidden_states = tuple(final_full_hstates)

                # stack logits
                logits = torch.cat(all_logits, dim=0)

            else:
                # compute the candidate tokens by the language model and collect their hidden_states
                # assembles top_k_ids into batch of size k
                next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs)
                outputs = self(
                    **next_model_inputs,
                    return_dict=True,
                    output_hidden_states=True,
                    output_attentions=output_attentions,
                )

                # name is different for encoder-decoder and decoder-only models
                if self.config.is_encoder_decoder:
                    next_hidden = outputs.decoder_hidden_states[-1]
@@ -2068,6 +2119,9 @@ class GenerationMixin:
                else:
                    next_hidden = outputs.hidden_states[-1]
                    full_hidden_states = outputs.hidden_states

                logits = outputs.logits[:, -1, :]

            context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)

            # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
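The re-ranking that the comment above introduces (it continues past the hunk boundary) is unchanged by this commit. As background, here is a self-contained sketch of the contrastive search scoring rule, written from the algorithm's definition rather than copied from the library; the helper name `rank_candidates` and the shapes are assumptions:

# Sketch of contrastive search re-ranking with a degeneration penalty (illustrative, not library code).
# Assumed shapes:
#   context_hidden: (batch_size * top_k, seq_len, hidden)  hidden states of the existing context
#   next_hidden:    (batch_size * top_k, 1, hidden)        hidden state of each candidate token
#   top_k_probs:    (batch_size, top_k)                    model confidence for each candidate
import torch

def rank_candidates(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k):
    norm_context = context_hidden / context_hidden.norm(dim=2, keepdim=True)
    norm_next = next_hidden / next_hidden.norm(dim=2, keepdim=True)
    # degeneration penalty: max cosine similarity between a candidate and any previous position
    cosine = torch.matmul(norm_context, norm_next.transpose(1, 2)).squeeze(-1)  # (B*K, seq_len)
    degeneration_penalty, _ = torch.max(cosine, dim=-1)                         # (B*K,)
    scores = (1.0 - penalty_alpha) * top_k_probs.view(-1) - penalty_alpha * degeneration_penalty
    scores = torch.stack(torch.split(scores, top_k))                            # (batch_size, top_k)
    return scores.argmax(dim=-1)                                                # best candidate per batch item

The index this returns plays the role of `selected_idx` in the hunk below.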
@@ -2089,7 +2143,22 @@ class GenerationMixin:
                layer = torch.stack(torch.split(layer, top_k))[range(batch_size), selected_idx, :]
                next_decoder_hidden_states += (layer,)

            # generate past_key_values cache of only the selected token
            if sequential:
                next_model_input = self.prepare_inputs_for_generation(
                    top_k_ids[:, selected_idx].view(-1, 1), **model_kwargs
                )

                selected_outputs = self(
                    **next_model_input,
                    return_dict=True,
                    output_hidden_states=False,
                    output_attentions=False,
                )
                next_past_key_values = selected_outputs["past_key_values"]

            else:
                next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)

            new_key_values = ()
            for layer in next_past_key_values:
                items = ()
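Taken together, the two branches above trade compute for memory: the parallel path expands the candidate batch (and the replicated `past_key_values`) to `batch_size * top_k` for one forward pass, while the sequential path runs `top_k` passes at the original batch size and stacks the results, so peak activation memory no longer grows with `top_k`. A minimal plain-PyTorch sketch of that equivalence (toy shapes and a toy projection layer, independent of the `GenerationMixin` code above):

# Toy illustration of batched vs. sequential candidate scoring (not the library implementation).
import torch

torch.manual_seed(0)
batch_size, top_k, hidden = 1, 4, 8
proj = torch.nn.Linear(hidden, hidden)  # stand-in for one model forward pass
candidates = torch.randn(batch_size * top_k, 1, hidden)  # one embedded token per candidate

# Parallel path: a single pass over batch_size * top_k rows; peak memory scales with top_k.
parallel_hidden = proj(candidates)

# Sequential path: top_k passes at the original batch size, stacked afterwards;
# peak memory stays at the single-candidate level.
sequential_hidden = torch.cat(
    [proj(candidates[i : i + 1]) for i in range(batch_size * top_k)], dim=0
)

# Both strategies yield the same candidate hidden states, which is what the
# test_contrastive_generate_low_memory test below checks end-to-end.
assert torch.allclose(parallel_hidden, sequential_hidden, atol=1e-6)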

View File

@@ -1457,6 +1457,49 @@ class GenerationTesterMixin:
        for output in (output_contrastive, output_generate):
            self._check_outputs(output, input_ids, model.config, use_cache=True)

    def test_contrastive_generate_low_memory(self):
        # Check that choosing 'low_memory' does not change the model output
        for model_class in self.all_generative_model_classes:
            # won't fix: FSMT, Reformer, gptbigcode, and speech2text have a different cache variable type (and format).
            if any(
                model_name in model_class.__name__.lower()
                for model_name in ["fsmt", "reformer", "gptbigcode", "speech2text"]
            ):
                return

            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)

            # NOTE: contrastive search only works with cache on at the moment.
            if not hasattr(config, "use_cache"):
                return
            config.use_cache = True
            config.is_decoder = True

            # test output equality of low versus high memory
            model = model_class(config).to(torch_device).eval()

            low_output = model.generate(
                input_ids,
                top_k=4,
                penalty_alpha=0.6,
                low_memory=True,
                max_length=max_length,
                attention_mask=attention_mask,
            )

            high_output = model.generate(
                input_ids,
                top_k=4,
                penalty_alpha=0.6,
                low_memory=False,
                max_length=max_length,
                attention_mask=attention_mask,
            )
            self.assertListEqual(low_output.tolist(), high_output.tolist())

        return

    @slow  # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%.
    def test_assisted_decoding_matches_greedy_search(self):
        # This test ensures that the assisted generation does not introduce output changes over greedy search.