tests: revert change of require_torch_multi_gpu to be device agnostic (#35721)

* tests: revert change of require_torch_multi_gpu to be device agnostic

Commit 11c27dd33 modified `require_torch_multi_gpu()` to be device agnostic
instead of CUDA specific. This broke some tests that are rightfully
CUDA specific, such as:

* `tests/trainer/test_trainer_distributed.py::TestTrainerDistributed`
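
For illustration, that test can be exercised directly on a machine with more
than one CUDA GPU (assuming a source checkout of Transformers):

    pytest -sv tests/trainer/test_trainer_distributed.py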

In the current Transformers test architecture, `require_torch_multi_accelerator()`
should be used to mark multi-GPU tests that are device agnostic, as sketched below.
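
A minimal sketch of how the two markers are intended to be used; the test
class and method names here are hypothetical:

    import unittest

    from transformers.testing_utils import (
        require_torch_multi_accelerator,
        require_torch_multi_gpu,
    )


    class ExampleMultiDeviceTest(unittest.TestCase):
        # CUDA specific: skipped unless more than one CUDA GPU is visible
        @require_torch_multi_gpu
        def test_multi_gpu_cuda_only_path(self):
            ...

        # Device agnostic: any backend exposing more than one accelerator
        @require_torch_multi_accelerator
        def test_multi_gpu_device_agnostic_path(self):
            ...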

This change addresses the issue introduced by 11c27dd33 and reverts the
modification of `require_torch_multi_gpu()`.

Fixes: 11c27dd33 ("Enable BNB multi-backend support (#31098)")
Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>

* fix bug: modification of frozen set (`get_available_devices()` returns a frozenset, so copy it into a mutable `set` before modifying it)

---------

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
Co-authored-by: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com>
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Commit b4b9da6d9b (parent d80d52b007)
Author: Dmitry Rogozhkin, 2025-02-25 04:36:10 -08:00, committed by GitHub
4 changed files with 10 additions and 21 deletions

src/transformers/integrations/bitsandbytes.py

@@ -486,7 +486,7 @@ def _validate_bnb_multi_backend_availability(raise_exception):
     import bitsandbytes as bnb
 
     bnb_supported_devices = getattr(bnb, "supported_torch_devices", set())
-    available_devices = get_available_devices()
+    available_devices = set(get_available_devices())
 
     if available_devices == {"cpu"} and not is_ipex_available():
         from importlib.util import find_spec
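
This hunk is the "fix bug: modification of frozen set" change: `get_available_devices()`
returns an immutable frozenset, and the validation code later mutates the set. A
minimal sketch of the failure mode and the fix; the `discard("cpu")` call is a
stand-in for the mutation performed downstream:

    # Stand-in for transformers' get_available_devices(), which returns a
    # frozenset of detected device type strings.
    def get_available_devices():
        return frozenset({"cpu", "cuda"})


    devices = get_available_devices()
    try:
        devices.discard("cpu")  # fails: frozenset has no mutating methods
    except AttributeError as err:
        print(err)  # 'frozenset' object has no attribute 'discard'

    devices = set(get_available_devices())  # the fix: copy into a mutable set
    devices.discard("cpu")  # succeeds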

src/transformers/testing_utils.py

@@ -238,17 +238,6 @@ _run_agent_tests = parse_flag_from_env("RUN_AGENT_TESTS", default=False)
 _run_third_party_device_tests = parse_flag_from_env("RUN_THIRD_PARTY_DEVICE_TESTS", default=False)
 
 
-def get_device_count():
-    import torch
-
-    if is_torch_xpu_available():
-        num_devices = torch.xpu.device_count()
-    else:
-        num_devices = torch.cuda.device_count()
-
-    return num_devices
-
-
 def is_staging_test(test_case):
     """
     Decorator marking a test as a staging test.
@@ -756,17 +745,17 @@ def require_spacy(test_case):
 def require_torch_multi_gpu(test_case):
     """
-    Decorator marking a test that requires a multi-GPU setup (in PyTorch). These tests are skipped on a machine without
-    multiple GPUs.
+    Decorator marking a test that requires a multi-GPU CUDA setup (in PyTorch). These tests are skipped on a machine without
+    multiple CUDA GPUs.
 
     To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests -k "multi_gpu"
     """
     if not is_torch_available():
         return unittest.skip(reason="test requires PyTorch")(test_case)
 
-    device_count = get_device_count()
+    import torch
 
-    return unittest.skipUnless(device_count > 1, "test requires multiple GPUs")(test_case)
+    return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple CUDA GPUs")(test_case)
 
 
 def require_torch_multi_accelerator(test_case):
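
For contrast with the now CUDA-specific decorator above, a sketch of a
device-agnostic skip condition; this illustrates the idea, not the exact body
of `require_torch_multi_accelerator`:

    import unittest


    def require_multi_accelerator_sketch(test_case):
        # Count accelerators across backends; the real decorator in
        # transformers.testing_utils dispatches on the detected torch_device.
        import torch

        if torch.cuda.is_available():
            count = torch.cuda.device_count()
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            count = torch.xpu.device_count()
        else:
            count = 0
        return unittest.skipUnless(count > 1, "test requires multiple accelerators")(test_case)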

tests/quantization/bnb/test_4bit.py

@@ -39,7 +39,7 @@ from transformers.testing_utils import (
     require_bitsandbytes,
     require_torch,
     require_torch_gpu_if_bnb_not_multi_backend_enabled,
-    require_torch_multi_gpu,
+    require_torch_multi_accelerator,
     slow,
     torch_device,
 )
@@ -517,7 +517,7 @@ class Pipeline4BitTest(Base4bitTest):
         self.assertIn(pipeline_output[0]["generated_text"], self.EXPECTED_OUTPUTS)
 
 
-@require_torch_multi_gpu
+@require_torch_multi_accelerator
 @apply_skip_if_not_implemented
 class Bnb4bitTestMultiGpu(Base4bitTest):
     def setUp(self):

tests/quantization/bnb/test_mixed_int8.py

@@ -39,7 +39,7 @@ from transformers.testing_utils import (
     require_bitsandbytes,
     require_torch,
     require_torch_gpu_if_bnb_not_multi_backend_enabled,
-    require_torch_multi_gpu,
+    require_torch_multi_accelerator,
     slow,
     torch_device,
 )
@@ -671,7 +671,7 @@ class MixedInt8TestPipeline(BaseMixedInt8Test):
         self.assertIn(pipeline_output[0]["generated_text"], self.EXPECTED_OUTPUTS)
 
 
-@require_torch_multi_gpu
+@require_torch_multi_accelerator
 @apply_skip_if_not_implemented
 class MixedInt8TestMultiGpu(BaseMixedInt8Test):
     def setUp(self):
@@ -700,7 +700,7 @@ class MixedInt8TestMultiGpu(BaseMixedInt8Test):
         self.assertIn(self.tokenizer.decode(output_parallel[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
 
 
-@require_torch_multi_gpu
+@require_torch_multi_accelerator
 @apply_skip_if_not_implemented
 class MixedInt8TestCpuGpu(BaseMixedInt8Test):
     def setUp(self):