mirror of https://github.com/huggingface/transformers.git (synced 2025-07-19 20:48:22 +06:00)
Generate: Fix modern llm generate calls with synced_gpus (#34095)
This commit is contained in:
parent 617b21273a
commit 37ea04013b
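For context, the fix targets multi-GPU generation where every rank must keep calling forward in lockstep even after its own sequences finish. Below is a minimal sketch of that kind of call, assuming a ZeRO-3/FSDP-style setup in which each rank runs this function; the model name and prompt are illustrative, not taken from this commit.

# Minimal sketch of a synced_gpus generate call (illustrative; "gpt2" and the prompt are placeholders).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def generate_on_this_rank(prompt: str, device: str = "cuda") -> torch.Tensor:
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # synced_gpus=True keeps all ranks looping together: a rank whose sequences are already
    # finished still runs dummy forward passes so sharded-weight collectives don't deadlock.
    return model.generate(**inputs, max_new_tokens=32, synced_gpus=True)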
@@ -379,9 +379,10 @@ class GenerationMixin:
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         # Exception 1: when passing input_embeds, input_ids may be missing entries
         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case
         if past_key_values is not None:
             model_inputs["past_key_values"] = past_key_values
-            if inputs_embeds is not None:  # Exception 1
+            if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]:  # Exception 1 or Exception 3
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                 input_ids = input_ids[:, cache_position]
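In isolation, the slicing above keeps only the positions the model has not processed yet. A small self-contained sketch of the branches, with made-up tensor values (not library code):

# Self-contained sketch of the cache_position slicing above (made-up values, not library code).
import torch

input_ids = torch.tensor([[5, 8, 13, 21, 34]])  # all tokens so far (length 5)
cache_position = torch.tensor([4])               # position(s) the cache has not covered yet

if cache_position[-1] >= input_ids.shape[1]:
    # Exception 3: on a finished rank with synced GPUs, cache_position keeps advancing past the
    # end of input_ids, so the trailing token(s) are re-fed as a dummy input of the right shape.
    input_ids = input_ids[:, -cache_position.shape[0]:]
elif input_ids.shape[1] != cache_position.shape[0]:
    # Default case: feed only the unprocessed token(s).
    input_ids = input_ids[:, cache_position]

print(input_ids)  # tensor([[34]]) -- only the newest token goes through the forward pass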
@@ -2609,8 +2610,14 @@ class GenerationMixin:
                     outputs.hidden_states[candidate_premature_layer][:, -1, :]
                 ).to(final_logits.device)
 
+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
             if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
+                continue
 
             next_token_logits = _dola_select_contrast(
                 candidate_premature_layers, candidate_premature_logits, final_logits
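The same reordering is repeated in every decoding loop below: the model kwargs (cache, cache_position, attention mask, ...) are now updated on every iteration, before the synced_gpus skip, instead of after token selection. A toy, self-contained sketch of why that matters, using plain-Python stand-ins rather than the real GenerationMixin helpers:

# Toy sketch: kwargs must advance even on the dummy iterations a finished rank runs
# under synced_gpus (plain-Python stand-ins, not the real GenerationMixin code).

def run_rank(num_steps: int, finishes_after: int, synced_gpus: bool = True) -> int:
    cache_position = 0           # stands in for all model kwargs that grow each step
    this_peer_finished = False
    for step in range(num_steps):        # every rank runs the same number of forward passes
        # (dummy or real forward pass happens here)
        cache_position += 1              # update kwargs FIRST, on every iteration
        if synced_gpus and this_peer_finished:
            continue                     # skip only token selection / scoring
        if step + 1 >= finishes_after:
            this_peer_finished = True
    return cache_position

# A rank that finished early and a rank that ran to the end stay in sync:
assert run_rank(num_steps=5, finishes_after=2) == run_rank(num_steps=5, finishes_after=5) == 5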
@@ -2652,11 +2659,6 @@ class GenerationMixin:
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
 
             # stop when each sentence is finished
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
@@ -3016,8 +3018,14 @@ class GenerationMixin:
             )
             # contrastive_search main logic end
 
+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
             if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
+                continue
 
             # finished sentences should have their next token be a padding token
             if has_eos_stopping_criteria:
@@ -3027,11 +3035,6 @@ class GenerationMixin:
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
 
             # stop when each sentence is finished
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
@@ -3168,8 +3171,14 @@ class GenerationMixin:
             # forward pass to get next token
             outputs = self(**model_inputs, return_dict=True)
 
+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
             if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
+                continue
 
             # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
             # (the clone itself is always small)
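The clone mentioned in the comment above matters because a plain slice of outputs.logits is a view that keeps the whole [batch, seq, vocab] tensor alive. A tiny illustration with made-up shapes:

# Tiny illustration of the clone-vs-view point from the comment above (made-up shapes).
import torch

full_logits = torch.randn(2, 4096, 32000)     # stand-in for outputs.logits on the first iteration
view_slice = full_logits[:, -1, :]            # a view: holding it keeps the whole big tensor alive
small_copy = full_logits[:, -1, :].clone()    # owns its own small storage
del full_logits, view_slice                   # with the clone, the large buffer can now be freed
print(small_copy.shape)                       # torch.Size([2, 32000])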
@@ -3214,11 +3223,6 @@ class GenerationMixin:
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
 
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
             this_peer_finished = unfinished_sequences.max() == 0
@@ -3415,9 +3419,15 @@ class GenerationMixin:
             else:  # Unchanged original behavior
                 outputs = self(**model_inputs, return_dict=True)
 
+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
-                continue  # don't waste resources running the code we don't need
+                continue
 
             # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
             # (the clone itself is always small)
@@ -3491,12 +3501,6 @@ class GenerationMixin:
 
             input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
 
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
-
             # This is needed to properly delete outputs.logits which may be very large for first iteration
             # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
             # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
@@ -3670,9 +3674,15 @@ class GenerationMixin:
 
             outputs = self(**model_inputs, return_dict=True)
 
+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
-                continue  # don't waste resources running the code we don't need
+                continue
 
             if output_scores:
                 processed_score = torch.zeros_like(outputs.logits[:, -1, :])
@@ -3782,12 +3792,6 @@ class GenerationMixin:
 
             input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
 
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
-
             # This is needed to properly delete outputs.logits which may be very large for first iteration
             # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
             # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
@@ -3948,9 +3952,15 @@ class GenerationMixin:
 
             outputs = self(**model_inputs, return_dict=True)
 
+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
-                continue  # don't waste resources running the code we don't need
+                continue
 
             # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
             # (the clone itself is always small)
@@ -4018,11 +4028,6 @@ class GenerationMixin:
             beam_idx = beam_outputs["next_beam_indices"]
 
             input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
 
             # This is needed to properly delete outputs.logits which may be very large for first iteration
             # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
@@ -4162,17 +4167,8 @@ class GenerationMixin:
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
 
-        # This is needed if return_dict_in_generate is True
-        start_from_empty_dynamic_cache = False
-        past_key_values = model_kwargs.get("past_key_values", None)
-        if isinstance(past_key_values, DynamicCache) or (
-            isinstance(past_key_values, EncoderDecoderCache)
-            and isinstance(past_key_values.self_attention_cache, DynamicCache)
-        ):
-            if past_key_values.get_seq_length() == 0:
-                start_from_empty_dynamic_cache = True
-
         this_peer_finished = False
+        is_first_iteration = True  # to preserve the same API in the output as other generation methods
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             cur_len = input_ids.shape[-1]
 
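The while condition above is what keeps finished ranks spinning: roughly, each rank contributes a flag and an all-reduce decides whether anyone still has work. A hedged sketch of that behaviour, assuming torch.distributed is initialized (approximate, not the exact private helper):

# Hedged sketch of the synced-GPU loop condition (approximate; not the exact _has_unfinished_sequences).
import torch
import torch.distributed as dist

def has_unfinished_sequences(this_peer_finished: bool, synced_gpus: bool, device) -> bool:
    if synced_gpus:
        # 1.0 while this rank still has work; the sum over ranks is 0 only when everyone is done,
        # so a finished rank keeps looping (running dummy forward passes) until the slowest rank ends.
        flag = torch.tensor(0.0 if this_peer_finished else 1.0, device=device)
        dist.all_reduce(flag, op=dist.ReduceOp.SUM)
        return flag.item() > 0.0
    return not this_peer_finished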
@@ -4271,34 +4267,36 @@ class GenerationMixin:
             # 5. Update the candidate generation strategy if needed
             candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches)
 
+            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+                num_new_tokens=n_matches + 1,
+            )
             if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
+                continue
 
             # Store scores, attentions and hidden_states when required
             # Assistant: modified to append one tuple element per token, as in the other generation methods.
             if return_dict_in_generate:
+                newly_added_length = n_matches + 1
                 if output_scores:
-                    scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1))
+                    scores += tuple(new_logits[:, i, :] for i in range(newly_added_length))
                 if output_logits:
-                    raw_logits += (next_token_logits,)
-
-                if "past_key_values" not in model_kwargs or start_from_empty_dynamic_cache:
-                    added_len = new_cur_len
-                    # set it to false for other iterations
-                    start_from_empty_dynamic_cache = False
-                else:
-                    added_len = n_matches + 1
+                    raw_logits += tuple(next_token_logits[:, i, :] for i in range(newly_added_length))
 
+                newly_added_length = new_cur_len if is_first_iteration else newly_added_length
                 if output_attentions:
                     if self.config.is_encoder_decoder:
                         cross_attentions = _split_model_outputs(
-                            cross_attentions, outputs.cross_attentions, cur_len, added_len
+                            cross_attentions, outputs.cross_attentions, cur_len, newly_added_length
                         )
                         decoder_attentions = _split_model_outputs(
                             decoder_attentions,
                             outputs.decoder_attentions,
                             cur_len,
-                            added_len,
+                            newly_added_length,
                             is_decoder_attention=True,
                         )
                     else:
@@ -4306,28 +4304,22 @@ class GenerationMixin:
                             decoder_attentions,
                             outputs.attentions,
                             cur_len,
-                            added_len,
+                            newly_added_length,
                             is_decoder_attention=True,
                         )
                 if output_hidden_states:
                     if self.config.is_encoder_decoder:
                         decoder_hidden_states = _split_model_outputs(
-                            decoder_hidden_states, outputs.decoder_hidden_states, cur_len, added_len
+                            decoder_hidden_states, outputs.decoder_hidden_states, cur_len, newly_added_length
                         )
                     else:
                         decoder_hidden_states = _split_model_outputs(
-                            decoder_hidden_states, outputs.hidden_states, cur_len, added_len
+                            decoder_hidden_states, outputs.hidden_states, cur_len, newly_added_length
                         )
 
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-                num_new_tokens=n_matches + 1,
-            )
-
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
             this_peer_finished = unfinished_sequences.max() == 0
+            is_first_iteration = False
 
             if streamer is not None:
                 streamer.end()
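The newly_added_length bookkeeping in the last two hunks is plain token accounting for assisted decoding: each iteration appends the accepted candidate tokens plus one token from the main model, except that the first iteration's outputs also cover the prompt. A small numeric sketch with illustrative values:

# Numeric sketch of the newly_added_length bookkeeping (illustrative values, not library code).
prompt_len = 7
n_matches = 3                        # candidate tokens accepted in this iteration
newly_added_length = n_matches + 1   # accepted candidates + one token sampled from the main model

# On the first iteration the forward pass also covered the prompt, so per-token outputs
# (scores, attentions, hidden states) must be split over the whole current length instead:
is_first_iteration = True
new_cur_len = prompt_len + newly_added_length
split_length = new_cur_len if is_first_iteration else newly_added_length
print(split_length)                  # 11 on the first iteration, 4 on later ones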