diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index aa56b87dee6..3b43fddf548 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -3748,11 +3748,13 @@ class GenerationIntegrationTests(unittest.TestCase):
         self.assertTrue(y_prob <= 1.0 and n_prob <= 1.0)
 
     @slow
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
     def test_assisted_decoding_in_different_gpu(self):
-        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda:0")
+        device_0 = f"{torch_device}:0" if torch_device != "cpu" else "cpu"
+        device_1 = f"{torch_device}:1" if torch_device != "cpu" else "cpu"
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(device_0)
         assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
-            "cuda:1"
+            device_1
         )
         tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
         model.config.pad_token_id = tokenizer.eos_token_id