Generate: move logits to same device as input_ids (#34076)

commit d314ce70bf
parent 5ee9e786d1
Author: Joao Gante
Date: 2024-10-15 13:32:09 +01:00 (committed by GitHub)
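When a model is sharded across devices (for example, loaded with accelerate's device_map="auto"), outputs.logits live on the device that holds the lm_head, which need not be the device that holds input_ids. The logits processors index the scores with input_ids, so each decoding strategy below now moves the next-step logits to input_ids.device before processing. A minimal sketch of the failure mode, assuming two CUDA devices (device names and the vocab size are illustrative, not from the patch):

import torch

input_ids = torch.tensor([[1, 2, 3]], device="cuda:0")  # prompt lives on GPU 0
logits = torch.randn(1, 3, 32000, device="cuda:1")      # lm_head output lands on GPU 1

next_token_logits = logits[:, -1, :].clone().float()
# Without this move, processors that index the scores with input_ids
# (repetition penalty, no-repeat n-gram, ...) raise a cross-device error.
next_token_logits = next_token_logits.to(input_ids.device)
assert next_token_logits.device == input_ids.device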

src/transformers/generation/utils.py

@@ -2623,6 +2623,7 @@ class GenerationMixin:
             next_token_logits = _dola_select_contrast(
                 candidate_premature_layers, candidate_premature_logits, final_logits
             )
+            next_token_logits = next_token_logits.to(input_ids.device)

             # pre-process distribution
             next_token_scores = logits_processor(input_ids, next_token_logits)
@@ -2794,6 +2795,7 @@
             # (the clone itself is always small)
             # .float() is needed to retain precision for later logits manipulations
             logit_for_next_step = outputs.logits[:, -1, :].clone().float()
+            logit_for_next_step = logit_for_next_step.to(input_ids.device)

             model_kwargs = self._update_model_kwargs_for_generation(
                 outputs,
@@ -2988,6 +2990,7 @@
                         next_past_key_values = tuple(new_key_values)

             logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :]
+            logit_for_next_step = logit_for_next_step.to(input_ids.device)

             # Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration
             if self.config.is_encoder_decoder:
@@ -3184,6 +3187,7 @@
             # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
             # (the clone itself is always small)
             next_token_logits = outputs.logits.clone()[:, -1, :].float()
+            next_token_logits = next_token_logits.to(input_ids.device)

             # pre-process distribution
             next_token_scores = logits_processor(input_ids, next_token_logits)
@@ -3434,6 +3438,7 @@
             # (the clone itself is always small)
             # .float() is needed to retain precision for later logits manipulations
             next_token_logits = outputs.logits[:, -1, :].clone().float()
+            next_token_logits = next_token_logits.to(input_ids.device)
             next_token_scores = nn.functional.log_softmax(
                 next_token_logits, dim=-1
             )  # (batch_size * num_beams, vocab_size)
@@ -3691,6 +3696,7 @@
             # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
             # (the clone itself is always small)
             raw_logit_score = outputs.logits[:, -1, :].clone()
+            raw_logit_score = raw_logit_score.to(input_ids.device)

             for beam_group_idx in range(num_beam_groups):
                 group_start_idx = beam_group_idx * num_sub_beams
@@ -3710,6 +3716,7 @@
                 # No need to clone() the logits here as they will not retain outputs.logits at the end of the loop
                 # .float() is needed to retain precision for later logits manipulations
                 next_token_logits = outputs.logits[batch_group_indices, -1, :].float()
+                next_token_logits = next_token_logits.to(input_ids.device)

                 next_token_scores = nn.functional.log_softmax(
                     next_token_logits, dim=-1
@@ -3967,6 +3974,7 @@
             # (the clone itself is always small)
             # .float() is needed to retain precision for later logits manipulations
             next_token_logits = outputs.logits[:, -1, :].clone().float()
+            next_token_logits = next_token_logits.to(input_ids.device)
             next_token_scores = nn.functional.log_softmax(
                 next_token_logits, dim=-1
             )  # (batch_size * num_beams, vocab_size)
@@ -4215,6 +4223,7 @@
             # 2.3. Process the new logits
             # .float() is needed to retain precision for later logits manipulations
             new_logits = outputs.logits[:, -candidate_length - 1 :].float()  # excludes the input prompt if present
+            new_logits = new_logits.to(input_ids.device)
             next_token_logits = new_logits.clone()
             if len(logits_processor) > 0:
                 for i in range(candidate_length + 1):
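The same one-line fix is applied at each decoding strategy's logit-extraction site (DoLa, contrastive search, sampling, beam search, group beam search, constrained beam search, and assisted decoding). Only the logits needed for the current step are moved, i.e. the last-position slice, or the candidate window in assisted decoding, so the per-step transfer stays small relative to the full (batch_size, sequence_length, vocab_size) logits. To see why the processors require co-located tensors, here is a sketch paraphrasing the gather/scatter logic of RepetitionPenaltyLogitsProcessor (simplified, not the exact transformers code):

import torch

def repetition_penalty(input_ids: torch.LongTensor, scores: torch.FloatTensor, penalty: float = 1.2):
    # torch.gather requires `input_ids` (the index) and `scores` to share a device
    score = torch.gather(scores, 1, input_ids)
    # penalize tokens that were already generated
    score = torch.where(score < 0, score * penalty, score / penalty)
    return scores.scatter(1, input_ids, score)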