Generate: fix assistant in different device (#33257)

This commit is contained in:
Joao Gante 2024-09-02 14:37:49 +01:00 committed by GitHub
parent 52a0213755
commit 97c0f45b9c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 2 additions and 1 deletions

View File

@ -3964,6 +3964,7 @@ class GenerationMixin:
# 1. Fetch candidate sequences from a `CandidateGenerator`
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
candidate_input_ids = candidate_input_ids.to(self.device)
if candidate_logits is not None:
candidate_logits = candidate_logits.to(self.device)

View File

@ -3323,7 +3323,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
@slow
@require_torch_gpu
def test_assisted_decoding_in_gpu_cpu(self):
def test_assisted_decoding_model_in_gpu_assistant_in_cpu(self):
# PT-only test: TF doesn't support assisted decoding yet.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda")
assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(