mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Adds TRANSFORMERS_TEST_DEVICE
(#25506)
* Adds `TRANSFORMERS_TEST_DEVICE` Mirrors the same API in the diffusers library. Useful in transformers too. * replace backend checking with trying `torch.device` * Adds better error message for unknown test devices * `make style` * adds documentation showing `TRANSFORMERS_TEST_DEVICE` usage.
This commit is contained in:
parent
e7e9261a20
commit
1791ef8df6
@ -511,6 +511,16 @@ from transformers.testing_utils import get_gpu_count
|
||||
n_gpu = get_gpu_count() # works with torch and tf
|
||||
```
|
||||
|
||||
### Testing with a specific PyTorch backend
|
||||
|
||||
To run the test suite on a specific torch backend add `TRANSFORMERS_TEST_DEVICE="$device"` where `$device` is the target backend. For example, to test on CPU only:
|
||||
```bash
|
||||
TRANSFORMERS_TEST_DEVICE="cpu" pytest tests/test_logging.py
|
||||
```
|
||||
|
||||
This variable is useful for testing custom or less common PyTorch backends such as `mps`. It can also be used to achieve the same effect as `CUDA_VISIBLE_DEVICES` by targeting specific GPUs or testing in CPU-only mode.
|
||||
|
||||
|
||||
### Distributed training
|
||||
|
||||
`pytest` can't deal with distributed training directly. If this is attempted - the sub-processes don't do the right
|
||||
|
@ -614,7 +614,16 @@ if is_torch_available():
|
||||
# Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
|
||||
import torch
|
||||
|
||||
if torch.cuda.is_available():
|
||||
if "TRANSFORMERS_TEST_DEVICE" in os.environ:
|
||||
torch_device = os.environ["TRANSFORMERS_TEST_DEVICE"]
|
||||
try:
|
||||
# try creating device to see if provided device is valid
|
||||
_ = torch.device(torch_device)
|
||||
except RuntimeError as e:
|
||||
raise RuntimeError(
|
||||
f"Unknown testing device specified by environment variable `TRANSFORMERS_TEST_DEVICE`: {torch_device}"
|
||||
) from e
|
||||
elif torch.cuda.is_available():
|
||||
torch_device = "cuda"
|
||||
elif _run_third_party_device_tests and is_torch_npu_available():
|
||||
torch_device = "npu"
|
||||
|
Loading…
Reference in New Issue
Block a user