mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Generate: fix candidate device placement (#28493)
* fix candidate device * this line shouldn't have been in
This commit is contained in:
parent
e304f9769c
commit
bc72b4e2cd
@ -169,6 +169,8 @@ class AssistedCandidateGenerator(CandidateGenerator):
|
||||
assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
|
||||
vocabulary_size)` containing the logits associated to each candidate.
|
||||
"""
|
||||
input_ids = input_ids.to(self.assistant_model.device)
|
||||
|
||||
# 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
|
||||
# (which implicitly contains the number of accepted candidates from the previous round)
|
||||
has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
|
||||
|
@ -4591,11 +4591,10 @@ class GenerationMixin:
|
||||
cur_len = input_ids.shape[-1]
|
||||
|
||||
# 1. Fetch candidate sequences from a `CandidateGenerator`
|
||||
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(
|
||||
input_ids.to(candidate_generator.assistant_model.device)
|
||||
)
|
||||
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
|
||||
candidate_input_ids = candidate_input_ids.to(self.device)
|
||||
candidate_logits = candidate_logits.to(self.device)
|
||||
if candidate_logits is not None:
|
||||
candidate_logits = candidate_logits.to(self.device)
|
||||
|
||||
candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
|
||||
last_assistant_token_is_eos = (
|
||||
|
Loading…
Reference in New Issue
Block a user