Generate: fix candidate device placement (#28493)

* fix candidate device

* this line shouldn't have been in
This commit is contained in:
Joao Gante 2024-01-13 20:31:25 +00:00 committed by GitHub
parent e304f9769c
commit bc72b4e2cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 4 deletions

View File

@@ -169,6 +169,8 @@ class AssistedCandidateGenerator(CandidateGenerator):
             assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
             vocabulary_size)` containing the logits associated to each candidate.
         """
+        input_ids = input_ids.to(self.assistant_model.device)
+
         # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
         # (which implicitly contains the number of accepted candidates from the previous round)
         has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None

View File

@@ -4591,11 +4591,10 @@ class GenerationMixin:
         cur_len = input_ids.shape[-1]
         # 1. Fetch candidate sequences from a `CandidateGenerator`
-        candidate_input_ids, candidate_logits = candidate_generator.get_candidates(
-            input_ids.to(candidate_generator.assistant_model.device)
-        )
+        candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
         candidate_input_ids = candidate_input_ids.to(self.device)
-        candidate_logits = candidate_logits.to(self.device)
+        if candidate_logits is not None:
+            candidate_logits = candidate_logits.to(self.device)
         candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
         last_assistant_token_is_eos = (