mirror of https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
parent 5ee9e786d1
commit d314ce70bf
@@ -2623,6 +2623,7 @@ class GenerationMixin:
             next_token_logits = _dola_select_contrast(
                 candidate_premature_layers, candidate_premature_logits, final_logits
             )
+            next_token_logits = next_token_logits.to(input_ids.device)
             # pre-process distribution
             next_token_scores = logits_processor(input_ids, next_token_logits)

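The same one-line change repeats in each decoding loop touched below: right after the next-token logits are pulled out of the model output, they are moved onto the device of input_ids before any logits processors run. A minimal sketch of the idea (hypothetical helper, not the actual GenerationMixin code), assuming a model sharded across devices so that outputs.logits may not live on the same device as input_ids:

import torch

def select_next_token(outputs, input_ids, logits_processor):
    # With device_map-style sharding, outputs.logits can sit on the device of the
    # last model shard while input_ids stays on the device of the first one.
    next_token_logits = outputs.logits[:, -1, :]
    # Move the logits next to input_ids so processors that index into input_ids
    # (repetition penalties, bad-words filters, ...) operate on a single device.
    next_token_logits = next_token_logits.to(input_ids.device)
    next_token_scores = logits_processor(input_ids, next_token_logits)
    return torch.argmax(next_token_scores, dim=-1)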
@@ -2794,6 +2795,7 @@ class GenerationMixin:
                 # (the clone itself is always small)
                 # .float() is needed to retain precision for later logits manipulations
                 logit_for_next_step = outputs.logits[:, -1, :].clone().float()
+                logit_for_next_step = logit_for_next_step.to(input_ids.device)

                 model_kwargs = self._update_model_kwargs_for_generation(
                     outputs,
@@ -2988,6 +2990,7 @@ class GenerationMixin:
             next_past_key_values = tuple(new_key_values)

             logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :]
+            logit_for_next_step = logit_for_next_step.to(input_ids.device)

             # Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration
             if self.config.is_encoder_decoder:
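For context, the torch.stack(torch.split(...)) line in the contrastive-search hunk above regroups the flattened (batch_size * top_k) candidate logits by batch item and picks the logits of each item's selected candidate; a small sketch of that indexing with made-up sizes (not library code):

import torch

batch_size, top_k, vocab_size = 2, 4, 10
logits = torch.randn(batch_size * top_k, vocab_size)  # candidates flattened over the batch
selected_idx = torch.tensor([1, 3])                   # chosen candidate per batch item

# split -> batch_size chunks of shape [top_k, vocab_size]; stack -> [batch_size, top_k, vocab_size];
# then fancy-index the chosen candidate row for each batch item.
per_item = torch.stack(torch.split(logits, top_k))
logit_for_next_step = per_item[range(batch_size), selected_idx, :]
assert logit_for_next_step.shape == (batch_size, vocab_size)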
@@ -3184,6 +3187,7 @@ class GenerationMixin:
             # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
             # (the clone itself is always small)
             next_token_logits = outputs.logits.clone()[:, -1, :].float()
+            next_token_logits = next_token_logits.to(input_ids.device)

             # pre-process distribution
             next_token_scores = logits_processor(input_ids, next_token_logits)
@@ -3434,6 +3438,7 @@ class GenerationMixin:
             # (the clone itself is always small)
             # .float() is needed to retain precision for later logits manipulations
             next_token_logits = outputs.logits[:, -1, :].clone().float()
+            next_token_logits = next_token_logits.to(input_ids.device)
             next_token_scores = nn.functional.log_softmax(
                 next_token_logits, dim=-1
             )  # (batch_size * num_beams, vocab_size)
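The clone()/.float() comments repeated in these hunks describe a pre-existing memory/precision consideration rather than this commit's change; a rough stand-alone illustration, with made-up shapes, of why a cloned slice of outputs.logits (not a view) is kept and why it is upcast:

import torch

# Made-up shapes: on the first iteration outputs.logits covers the whole prompt.
logits = torch.randn(2, 1024, 32000, dtype=torch.float16)

last_as_view = logits[:, -1, :]          # a view: keeps the full prompt-sized buffer alive
last_as_copy = logits[:, -1, :].clone()  # a small standalone [2, 32000] tensor

# Upcasting to float32 keeps later manipulations (log_softmax, penalties) in full
# precision even when the forward pass ran in fp16/bf16.
scores = torch.nn.functional.log_softmax(last_as_copy.float(), dim=-1)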
@@ -3691,6 +3696,7 @@ class GenerationMixin:
             # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
             # (the clone itself is always small)
             raw_logit_score = outputs.logits[:, -1, :].clone()
+            raw_logit_score = raw_logit_score.to(input_ids.device)

             for beam_group_idx in range(num_beam_groups):
                 group_start_idx = beam_group_idx * num_sub_beams
@@ -3710,6 +3716,7 @@ class GenerationMixin:
                 # No need to clone() the logits here as they will not retain outputs.logits at the end of the loop
                 # .float() is needed to retain precision for later logits manipulations
                 next_token_logits = outputs.logits[batch_group_indices, -1, :].float()
+                next_token_logits = next_token_logits.to(input_ids.device)

                 next_token_scores = nn.functional.log_softmax(
                     next_token_logits, dim=-1
@@ -3967,6 +3974,7 @@ class GenerationMixin:
             # (the clone itself is always small)
             # .float() is needed to retain precision for later logits manipulations
             next_token_logits = outputs.logits[:, -1, :].clone().float()
+            next_token_logits = next_token_logits.to(input_ids.device)
             next_token_scores = nn.functional.log_softmax(
                 next_token_logits, dim=-1
             )  # (batch_size * num_beams, vocab_size)
@@ -4215,6 +4223,7 @@ class GenerationMixin:
             # 2.3. Process the new logits
             # .float() is needed to retain precision for later logits manipulations
             new_logits = outputs.logits[:, -candidate_length - 1 :].float()  # excludes the input prompt if present
+            new_logits = new_logits.to(input_ids.device)
             next_token_logits = new_logits.clone()
             if len(logits_processor) > 0:
                 for i in range(candidate_length + 1):
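In the assisted-decoding hunk above, the slice outputs.logits[:, -candidate_length - 1 :] keeps one row of scores per draft token plus one extra row for the token that follows the last candidate. A small sketch of that indexing with made-up sizes (not library code):

import torch

batch_size, seq_len, vocab_size = 1, 10, 100
candidate_length = 3  # number of draft tokens scored in this forward pass
logits = torch.randn(batch_size, seq_len, vocab_size)

# Drop the prompt positions, keep candidate_length + 1 score rows.
new_logits = logits[:, -candidate_length - 1 :]
assert new_logits.shape == (batch_size, candidate_length + 1, vocab_size)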