[tests] make cuda-only tests device-agnostic (#35222)

fix cuda-only tests
This commit is contained in:
Fanli Lin 2024-12-18 17:14:22 +08:00 committed by GitHub
parent 1eee1cedfd
commit c7e48053aa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -33,7 +33,7 @@ from transformers.testing_utils import (
require_sentencepiece,
require_tokenizers,
require_torch,
require_torch_non_multi_gpu,
require_torch_non_multi_accelerator,
slow,
torch_device,
)
@ -678,7 +678,7 @@ class RagDPRT5Test(RagTestMixin, unittest.TestCase):
@require_retrieval
@require_sentencepiece
@require_tokenizers
@require_torch_non_multi_gpu
@require_torch_non_multi_accelerator
class RagModelIntegrationTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
@ -1002,7 +1002,7 @@ class RagModelIntegrationTests(unittest.TestCase):
torch_device
)
if torch_device == "cuda":
if torch_device != "cpu":
rag_token.half()
input_dict = tokenizer(