mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
Generate: fix assistant in different device (#33257)
This commit is contained in:
parent
52a0213755
commit
97c0f45b9c
@ -3964,6 +3964,7 @@ class GenerationMixin:
|
||||
|
||||
# 1. Fetch candidate sequences from a `CandidateGenerator`
|
||||
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
|
||||
candidate_input_ids = candidate_input_ids.to(self.device)
|
||||
if candidate_logits is not None:
|
||||
candidate_logits = candidate_logits.to(self.device)
|
||||
|
||||
|
@ -3323,7 +3323,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_assisted_decoding_in_gpu_cpu(self):
|
||||
def test_assisted_decoding_model_in_gpu_assistant_in_cpu(self):
|
||||
# PT-only test: TF doesn't support assisted decoding yet.
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda")
|
||||
assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
|
||||
|
Loading…
Reference in New Issue
Block a user