mirror of https://github.com/huggingface/transformers.git
Contrastive Search peak memory reduction (#24120)
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
commit caf5e369fc (parent aa1b09c5d1)
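
The diff below adds a `low_memory` flag to `GenerationConfig` and threads it into contrastive search as a `sequential` argument: instead of expanding the batch by `top_k` and scoring all candidate tokens in one forward pass, the candidates are scored one at a time, lowering peak memory at the cost of extra forward passes. As a minimal usage sketch (not part of this commit; the checkpoint name and generation lengths are illustrative):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# "gpt2" is only an illustrative checkpoint; any decoder-only model with cache support works the same way.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("DeepMind Company is", return_tensors="pt")

# penalty_alpha > 0 together with top_k > 1 enables contrastive search;
# low_memory=True selects the sequential top-k path added in this commit.
with torch.no_grad():
    output_ids = model.generate(
        **inputs,
        penalty_alpha=0.6,
        top_k=4,
        low_memory=True,
        max_new_tokens=64,
    )
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

The flag only changes how the top-k candidates are scored, not which tokens are selected; the test added at the end of the diff asserts that `low_memory=True` and `low_memory=False` produce identical output.
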
@@ -189,6 +189,9 @@ class GenerationConfig(PushToHubMixin):
             The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
             Higher guidance scale encourages the model to generate samples that are more closely linked to the input
             prompt, usually at the expense of poorer quality.
+        low_memory (`bool`, *optional*):
+            Switch to sequential topk for contrastive search to reduce peak memory. Used with contrastive search.
+
 
         > Parameters that define the output variables of `generate`
 
@@ -270,6 +273,7 @@ class GenerationConfig(PushToHubMixin):
         self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None)
         self.sequence_bias = kwargs.pop("sequence_bias", None)
         self.guidance_scale = kwargs.pop("guidance_scale", None)
+        self.low_memory = kwargs.pop("low_memory", None)
 
         # Parameters that define the output variables of `generate`
         self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
@@ -1569,6 +1569,7 @@ class GenerationMixin:
                 return_dict_in_generate=generation_config.return_dict_in_generate,
                 synced_gpus=synced_gpus,
                 streamer=streamer,
+                sequential=generation_config.low_memory,
                 **model_kwargs,
             )
 
@@ -1832,6 +1833,7 @@ class GenerationMixin:
         return_dict_in_generate: Optional[bool] = None,
         synced_gpus: bool = False,
         streamer: Optional["BaseStreamer"] = None,
+        sequential: Optional[bool] = None,
         **model_kwargs,
     ) -> Union[ContrastiveSearchOutput, torch.LongTensor]:
         r"""
@@ -1882,6 +1884,8 @@ class GenerationMixin:
             streamer (`BaseStreamer`, *optional*):
                 Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                 through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+            sequential (`bool`, *optional*):
+                Switches topk hidden state computation from parallel to sequential to reduce memory if True.
             model_kwargs:
                 Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
                 If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -1921,6 +1925,7 @@ class GenerationMixin:
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        sequential = sequential if sequential is not None else self.generation_config.low_memory
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
         eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
@@ -1986,6 +1991,7 @@ class GenerationMixin:
                     last_hidden_states = outputs.decoder_hidden_states[-1]
                 else:
                     last_hidden_states = outputs.hidden_states[-1]
+
                 # next logit for contrastive search to select top-k candidate tokens
                 logit_for_next_step = outputs.logits[:, -1, :]
 
@@ -1995,7 +2001,7 @@ class GenerationMixin:
                     is_encoder_decoder=self.config.is_encoder_decoder,
                     standardize_cache_format=True,
                 )
-                # Expands model inputs top_k times, for batched forward passes (akin to beam search).
-                _, model_kwargs = self._expand_inputs_for_generation(
-                    expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
+                if not sequential:
+                    # Expands model inputs top_k times, for batched forward passes (akin to beam search).
+                    _, model_kwargs = self._expand_inputs_for_generation(
+                        expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
@@ -2019,7 +2025,6 @@ class GenerationMixin:
             # contrastive_search main logic start:
             # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by
             # degeneration penalty
-
             logit_for_next_step = logits_processor(input_ids, logit_for_next_step)
             logit_for_next_step = logits_warper(input_ids, logit_for_next_step)
             next_probs = nn.functional.softmax(logit_for_next_step, dim=-1)
@@ -2049,18 +2054,64 @@ class GenerationMixin:
                 items = []
                 # item is either the key or the value matrix
                 for item in layer:
-                    items.append(item.repeat_interleave(top_k, dim=0))
+                    if sequential:
+                        items.append(item.repeat_interleave(1, dim=0))
+                    else:
+                        items.append(item.repeat_interleave(top_k, dim=0))
                 new_key_values.append(items)
             model_kwargs["past_key_values"] = new_key_values
 
-            # compute the candidate tokens by the language model and collects their hidden_states
-            next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs)
-            outputs = self(
-                **next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions
-            )
-            next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)
-
-            logits = outputs.logits[:, -1, :]
+            if sequential:
+                all_outputs = {key: [] for key in outputs}  # defined in first loop iteration
+                all_last_hstates, all_hstates, all_logits = [], [], []
+                for i in range(top_k):
+                    # compute the candidate tokens by the language model and collect their hidden_states
+                    next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].view(-1, 1), **model_kwargs)
+
+                    outputs = self(
+                        **next_model_inputs,
+                        return_dict=True,
+                        output_hidden_states=True,
+                        output_attentions=output_attentions,
+                    )
+                    for key in all_outputs:
+                        all_outputs[key].append(outputs[key])
+
+                    if self.config.is_encoder_decoder:
+                        next_hidden = outputs.decoder_hidden_states[-1]
+                        full_hidden_states = outputs.decoder_hidden_states
+
+                    else:
+                        next_hidden = outputs.hidden_states[-1]
+                        full_hidden_states = outputs.hidden_states
+
+                    all_last_hstates.append(torch.squeeze(next_hidden, 0))
+                    all_hstates.append(full_hidden_states)
+                    all_logits.append(outputs.logits[:, -1, :])
+
+                # stack hidden states
+                next_hidden = torch.stack([all_last_hstates[i] for i in range(top_k)], dim=0)
+                final_full_hstates = [0 for i in range(len(full_hidden_states))]
+                for layer in range(len(full_hidden_states)):
+                    final_full_hstates[layer] = torch.stack(
+                        [torch.squeeze(all_hstates[i][layer], 0) for i in range(top_k)], dim=0
+                    )
+                full_hidden_states = tuple(final_full_hstates)
+
+                # stack logits
+                logits = torch.cat(all_logits, dim=0)
+
+            else:
+                # compute the candidate tokens by the language model and collect their hidden_states
+                # assembles top_k_ids into batch of size k
+                next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs)
+
+                outputs = self(
+                    **next_model_inputs,
+                    return_dict=True,
+                    output_hidden_states=True,
+                    output_attentions=output_attentions,
+                )
             # name is different for encoder-decoder and decoder-only models
             if self.config.is_encoder_decoder:
                 next_hidden = outputs.decoder_hidden_states[-1]
@@ -2068,6 +2119,9 @@ class GenerationMixin:
             else:
                 next_hidden = outputs.hidden_states[-1]
                 full_hidden_states = outputs.hidden_states
+
+            logits = outputs.logits[:, -1, :]
+
             context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)
 
             # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
@@ -2089,7 +2143,22 @@ class GenerationMixin:
                 layer = torch.stack(torch.split(layer, top_k))[range(batch_size), selected_idx, :]
                 next_decoder_hidden_states += (layer,)
 
-            # select the past_key_value
+            # generate past_key_values cache of only the selected token
+            if sequential:
+                next_model_input = self.prepare_inputs_for_generation(
+                    top_k_ids[:, selected_idx].view(-1, 1), **model_kwargs
+                )
+
+                selected_outputs = self(
+                    **next_model_input,
+                    return_dict=True,
+                    output_hidden_states=False,
+                    output_attentions=False,
+                )
+                next_past_key_values = selected_outputs["past_key_values"]
+
+            else:
+                next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)
             new_key_values = ()
             for layer in next_past_key_values:
                 items = ()
@@ -1457,6 +1457,49 @@ class GenerationTesterMixin:
         for output in (output_contrastive, output_generate):
             self._check_outputs(output, input_ids, model.config, use_cache=True)
 
+    def test_contrastive_generate_low_memory(self):
+        # Check that choosing 'low_memory' does not change the model output
+        for model_class in self.all_generative_model_classes:
+            # won't fix: FSMT, Reformer, gptbigcode, and speech2text have a different cache variable type (and format).
+            if any(
+                model_name in model_class.__name__.lower()
+                for model_name in ["fsmt", "reformer", "gptbigcode", "speech2text"]
+            ):
+                return
+
+            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
+
+            # NOTE: contrastive search only works with cache on at the moment.
+            if not hasattr(config, "use_cache"):
+                return
+
+            config.use_cache = True
+            config.is_decoder = True
+
+            # test output equality of low versus high memory
+            model = model_class(config).to(torch_device).eval()
+
+            low_output = model.generate(
+                input_ids,
+                top_k=4,
+                penalty_alpha=0.6,
+                low_memory=True,
+                max_length=max_length,
+                attention_mask=attention_mask,
+            )
+
+            high_output = model.generate(
+                input_ids,
+                top_k=4,
+                penalty_alpha=0.6,
+                low_memory=False,
+                max_length=max_length,
+                attention_mask=attention_mask,
+            )
+            self.assertListEqual(low_output.tolist(), high_output.tolist())
+
+            return
+
     @slow  # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%.
     def test_assisted_decoding_matches_greedy_search(self):
         # This test ensures that the assisted generation does not introduce output changes over greedy search.
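
The tradeoff exercised by the `if sequential:` branches above can be seen in isolation: a batched pass over all `top_k` candidates and a candidate-by-candidate loop produce the same scores, but the loop keeps peak activation memory roughly `top_k` times smaller. A toy sketch of the idea (independent of the transformers code; the `score` module and tensor sizes are invented for illustration):

import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for a model forward pass: maps a (batch, hidden) state to logits.
score = nn.Linear(16, 100)

top_k = 4
candidate_states = torch.randn(top_k, 16)  # one hidden state per candidate token

# Batched path: one forward pass over all candidates -> peak activations grow with top_k.
batched_logits = score(candidate_states)

# Sequential path (what low_memory/sequential enables): one candidate at a time,
# smaller peak activations, same result once the pieces are concatenated.
sequential_logits = torch.cat([score(candidate_states[i : i + 1]) for i in range(top_k)], dim=0)

assert torch.allclose(batched_logits, sequential_logits)

The new `test_contrastive_generate_low_memory` test checks the same property end to end: generation output must not change when `low_memory` is toggled.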