Device agnostic testing (#25870)

* adds agnostic decorators and availability fns

* renaming decorators and fixing imports

* updating some representative example tests
bloom, opt, and reformer for now

* wip device agnostic functions

* lru cache to device checking functions

* adds `TRANSFORMERS_TEST_DEVICE_SPEC`
if present, imports the target file and updates device-to-function mappings

* comments `TRANSFORMERS_TEST_DEVICE_SPEC` code

* extra checks on device name

* `make style; make quality`

* updates default functions for agnostic calls

* applies suggestions from review

* adds `is_torch_available` guard

* Add spec file to docs, rename function dispatch names to backend_*

* add backend import to docs example for spec file

* change instances of  to

* Move register backend to before device check as per @statelesshz changes

* make style

* make opt test require fp16 to run

---------

Co-authored-by: arsalanu <arsalanu@graphcore.ai>
Co-authored-by: arsalanu <hzji210@gmail.com>
Alex McKinney authored 2023-10-24 15:49:26 +01:00, committed by GitHub
parent 41496b95da
commit 9da451713d
8 changed files with 188 additions and 25 deletions

docs/source/en/testing.md

@@ -525,6 +525,25 @@ Certain devices will require an additional import after importing `torch` for th
```bash
TRANSFORMERS_TEST_BACKEND="torch_npu" pytest tests/utils/test_logging.py
```
Alternative backends may also require the replacement of device-specific functions. For example, `torch.cuda.manual_seed` may need to be replaced with a device-specific seed setter like `torch.npu.manual_seed` to correctly set a random seed on the device. To specify a new backend with backend-specific device functions when running the test suite, create a Python device specification file in the format:
```python
import torch
import torch_npu
# !! Further additional imports can be added here !!
# Specify the device name (e.g. 'cuda', 'cpu', 'npu')
DEVICE_NAME = 'npu'
# Specify device-specific backends to dispatch to.
# If not specified, will fall back to 'default' in `testing_utils.py`
MANUAL_SEED_FN = torch.npu.manual_seed
EMPTY_CACHE_FN = torch.npu.empty_cache
DEVICE_COUNT_FN = torch.npu.device_count
```
This format also allows for specification of any additional imports required. To use this file to replace equivalent methods in the test suite, set the environment variable `TRANSFORMERS_TEST_DEVICE_SPEC` to the path of the spec file.
Currently, only `MANUAL_SEED_FN`, `EMPTY_CACHE_FN` and `DEVICE_COUNT_FN` are supported for device-specific dispatch.
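For example, assuming the spec file above is saved as `spec.py` (the filename is arbitrary), a run against it might look like:
```bash
TRANSFORMERS_TEST_DEVICE_SPEC="spec.py" pytest tests/utils/test_logging.py
```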
### Distributed training

src/transformers/testing_utils.py

@@ -32,7 +32,7 @@ import unittest
from collections.abc import Mapping
from io import StringIO
from pathlib import Path
from typing import Iterable, Iterator, List, Optional, Union
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union
from unittest import mock
import huggingface_hub
@@ -98,8 +98,10 @@ from .utils import (
is_timm_available,
is_tokenizers_available,
is_torch_available,
is_torch_bf16_available_on_device,
is_torch_bf16_cpu_available,
is_torch_bf16_gpu_available,
is_torch_fp16_available_on_device,
is_torch_neuroncore_available,
is_torch_npu_available,
is_torch_tensorrt_fx_available,
@@ -713,6 +715,16 @@ if is_torch_available():
# Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
import torch
if "TRANSFORMERS_TEST_BACKEND" in os.environ:
backend = os.environ["TRANSFORMERS_TEST_BACKEND"]
try:
_ = importlib.import_module(backend)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Failed to import `TRANSFORMERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module. The original error (look up to see its"
f" traceback):\n{e}"
) from e
if "TRANSFORMERS_TEST_DEVICE" in os.environ:
torch_device = os.environ["TRANSFORMERS_TEST_DEVICE"]
try:
@@ -730,17 +742,6 @@ if is_torch_available():
torch_device = "xpu"
else:
torch_device = "cpu"
if "TRANSFORMERS_TEST_BACKEND" in os.environ:
backend = os.environ["TRANSFORMERS_TEST_BACKEND"]
try:
_ = importlib.import_module(backend)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Failed to import `TRANSFORMERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module. The original error (look up to see its"
f" traceback):\n{e}"
) from e
else:
torch_device = None
@@ -770,6 +771,25 @@ def require_torch_gpu(test_case):
return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
def require_torch_accelerator(test_case):
"""Decorator marking a test that requires an accessible accelerator and PyTorch."""
return unittest.skipUnless(torch_device != "cpu", "test requires accelerator")(test_case)
def require_torch_fp16(test_case):
"""Decorator marking a test that requires a device that supports fp16"""
return unittest.skipUnless(
is_torch_fp16_available_on_device(torch_device), "test requires device with fp16 support"
)(test_case)
def require_torch_bf16(test_case):
"""Decorator marking a test that requires a device that supports bf16"""
return unittest.skipUnless(
is_torch_bf16_available_on_device(torch_device), "test requires device with bf16 support"
)(test_case)
def require_torch_bf16_gpu(test_case):
"""Decorator marking a test that requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0"""
return unittest.skipUnless(
@@ -2176,3 +2196,86 @@ class HfDoctestModule(Module):
for test in finder.find(module, module.__name__):
if test.examples: # skip empty doctests and cuda
yield DoctestItem.from_parent(self, name=test.name, runner=runner, dtest=test)
def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs):
# Fall back to the "default" entry for devices without a registered function.
fn = dispatch_table.get(device, dispatch_table["default"])
# Some device agnostic functions return values, so entries may deliberately be
# `None` (a no-op); guard against that here rather than at the user level.
if fn is None:
return None
return fn(*args, **kwargs)
if is_torch_available():
# Mappings from device names to callable functions to support device agnostic
# testing.
BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed}
BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "default": None}
BACKEND_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "default": lambda: 1}
def backend_manual_seed(device: str, seed: int):
return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)
def backend_empty_cache(device: str):
return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)
def backend_device_count(device: str):
return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
if is_torch_available():
# If `TRANSFORMERS_TEST_DEVICE_SPEC` is set, we need to import extra entries
# into the device-to-function mappings.
if "TRANSFORMERS_TEST_DEVICE_SPEC" in os.environ:
device_spec_path = os.environ["TRANSFORMERS_TEST_DEVICE_SPEC"]
if not Path(device_spec_path).is_file():
raise ValueError(
f"Specified path to device spec file is not a file or not found. Received '{device_spec_path}"
)
# Try to strip the extension for the later import; this also verifies we are
# importing a Python file.
try:
import_name = device_spec_path[: device_spec_path.index(".py")]
except ValueError as e:
raise ValueError(f"Provided device spec file was not a Python file! Received '{device_spec_path}") from e
device_spec_module = importlib.import_module(import_name)
# Imported file must contain `DEVICE_NAME`. If it doesn't, terminate early.
try:
device_name = device_spec_module.DEVICE_NAME
except AttributeError as e:
raise AttributeError("Device spec file did not contain `DEVICE_NAME`") from e
if "TRANSFORMERS_TEST_DEVICE" in os.environ and torch_device != device_name:
msg = f"Mismatch between environment variable `TRANSFORMERS_TEST_DEVICE` '{torch_device}' and device found in spec '{device_name}'\n"
msg += "Either unset `TRANSFORMERS_TEST_DEVICE` or ensure it matches device spec name."
raise ValueError(msg)
torch_device = device_name
def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name: str):
try:
# Try to import the function directly
spec_fn = getattr(device_spec_module, attribute_name)
device_fn_dict[torch_device] = spec_fn
except AttributeError as e:
# If the function doesn't exist and there is no default, raise an error
if "default" not in device_fn_dict:
raise AttributeError(
f"`{attribute_name}` not found in '{device_spec_path}' and no default fallback function found."
) from e
# Add one entry here for each `BACKEND_*` dictionary.
update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN")
update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
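Taken together, these helpers give tests a single call site that dispatches to whichever backend `torch_device` resolves to. A minimal usage sketch (not part of the diff; the calls mirror the mappings defined above):
```python
from transformers.testing_utils import (
    backend_device_count,
    backend_empty_cache,
    backend_manual_seed,
    torch_device,
)

# Dispatches to the seed setter registered for `torch_device` ("cuda", "npu", ...),
# falling back to `torch.manual_seed` for unregistered devices.
backend_manual_seed(torch_device, 0)

# Deliberately a no-op (returns None) on backends with nothing to clear, e.g. "cpu".
backend_empty_cache(torch_device)

# 0 on "cpu", `torch.cuda.device_count()` on "cuda", 1 as the generic fallback.
n_devices = backend_device_count(torch_device)
```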

src/transformers/utils/__init__.py

@@ -166,10 +166,12 @@ from .import_utils import (
is_tokenizers_available,
is_torch_available,
is_torch_bf16_available,
is_torch_bf16_available_on_device,
is_torch_bf16_cpu_available,
is_torch_bf16_gpu_available,
is_torch_compile_available,
is_torch_cuda_available,
is_torch_fp16_available_on_device,
is_torch_fx_available,
is_torch_fx_proxy,
is_torch_mps_available,

src/transformers/utils/import_utils.py

@@ -351,6 +351,45 @@ def is_torch_bf16_available():
return is_torch_bf16_gpu_available()
@lru_cache()
def is_torch_fp16_available_on_device(device):
if not is_torch_available():
return False
import torch
try:
x = torch.zeros(2, 2, dtype=torch.float16).to(device)
_ = x @ x
except: # noqa: E722
# TODO: more precise exception matching, if possible.
# Most backends should raise `RuntimeError`, but this is not guaranteed.
return False
return True
@lru_cache()
def is_torch_bf16_available_on_device(device):
if not is_torch_available():
return False
import torch
if device == "cuda":
return is_torch_bf16_gpu_available()
try:
x = torch.zeros(2, 2, dtype=torch.bfloat16).to(device)
_ = x @ x
except: # noqa: E722
# TODO: more precise exception matching, if possible.
# Most backends should raise `RuntimeError`, but this is not guaranteed.
return False
return True
def is_torch_tf32_available():
if not is_torch_available():
return False
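Both probes work by attempting a tiny reduced-precision matmul on the target device, so they reflect what the hardware actually supports rather than consulting a static allowlist. A quick check from a Python shell (the device string is illustrative):
```python
from transformers.utils import is_torch_bf16_available_on_device, is_torch_fp16_available_on_device

# Both functions are wrapped in lru_cache, so repeated queries for the same
# device string cost nothing after the first probe.
print(is_torch_fp16_available_on_device("cpu"))
print(is_torch_bf16_available_on_device("cpu"))
```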

tests/models/bloom/test_modeling_bloom.py

@@ -18,7 +18,7 @@ import math
import unittest
from transformers import BloomConfig, is_torch_available
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
from transformers.testing_utils import require_torch, require_torch_accelerator, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@@ -401,7 +401,7 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
self.assertIsNotNone(model)
@slow
@require_torch_gpu
@require_torch_accelerator
def test_simple_generation(self):
# This test is a bit flaky. For some GPU architectures, pytorch sets by default allow_fp16_reduced_precision_reduction = True and some operations
# do not give the same results under this configuration, especially torch.baddmm and torch.bmm. https://pytorch.org/docs/stable/notes/numerical_accuracy.html#fp16-on-mi200
@@ -440,7 +440,7 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
self.assertEqual(tokenizer.decode(greedy_output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
@slow
@require_torch_gpu
@require_torch_accelerator
def test_batch_generation(self):
path_560m = "bigscience/bloom-560m"
model = BloomForCausalLM.from_pretrained(path_560m, use_cache=True, revision="gs555750").to(torch_device)
@@ -460,7 +460,7 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
)
@slow
@require_torch_gpu
@require_torch_accelerator
def test_batch_generation_padd(self):
path_560m = "bigscience/bloom-560m"
model = BloomForCausalLM.from_pretrained(path_560m, use_cache=True, revision="gs555750").to(torch_device)
@@ -489,7 +489,7 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
)
@slow
@require_torch_gpu
@require_torch_accelerator
def test_batch_generated_text(self):
path_560m = "bigscience/bloom-560m"
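`require_torch_accelerator` only asserts `torch_device != "cpu"`, so these generation tests now run on any accelerator backend rather than CUDA alone. A minimal sketch of the pattern (the test class and body are illustrative, not from the diff):
```python
import unittest

from transformers.testing_utils import require_torch_accelerator, torch_device


class ExampleAcceleratorTest(unittest.TestCase):
    @require_torch_accelerator
    def test_runs_on_any_accelerator(self):
        # Skipped when torch_device resolves to "cpu"; runs on "cuda", "npu", "xpu", ...
        self.assertNotEqual(torch_device, "cpu")
```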

tests/models/codegen/test_modeling_codegen.py

@@ -19,7 +19,7 @@ import unittest
from transformers import CodeGenConfig, is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
from transformers.testing_utils import backend_manual_seed, is_flaky, require_torch, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@@ -498,8 +498,7 @@ class CodeGenModelLanguageGenerationTest(unittest.TestCase):
model.to(torch_device)
torch.manual_seed(0)
if torch_device == "cuda":
torch.cuda.manual_seed(0)
backend_manual_seed(torch_device, 0)
tokenized = tokenizer("def hello_world():", return_tensors="pt", return_token_type_ids=True)
input_ids = tokenized.input_ids.to(torch_device)

tests/models/opt/test_modeling_opt.py

@@ -22,7 +22,7 @@ import unittest
import timeout_decorator # noqa
from transformers import OPTConfig, is_torch_available
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
from transformers.testing_utils import require_torch, require_torch_fp16, require_torch_gpu, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@@ -286,13 +286,13 @@ class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
with torch.no_grad():
model(**inputs)[0]
@require_torch_fp16
def test_generate_fp16(self):
config, input_dict = self.model_tester.prepare_config_and_inputs()
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
model = OPTForCausalLM(config).eval().to(torch_device)
if torch_device == "cuda":
model.half()
model.half()
model.generate(input_ids, attention_mask=attention_mask)
model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
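With `@require_torch_fp16`, the fp16-only path no longer needs an in-test CUDA check: the test is skipped up front on devices without fp16 support, so `model.half()` can run unconditionally. A hedged sketch of the same pattern in isolation (class name and tensor shapes are illustrative):
```python
import unittest

import torch

from transformers.testing_utils import require_torch_fp16, torch_device


class ExampleFp16Test(unittest.TestCase):
    @require_torch_fp16
    def test_half_precision_matmul(self):
        # Skipped on devices where fp16 ops are unavailable.
        x = torch.ones(2, 2, dtype=torch.float16, device=torch_device)
        self.assertEqual((x @ x).dtype, torch.float16)
```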

tests/models/reformer/test_modeling_reformer.py

@@ -20,6 +20,7 @@ from transformers.testing_utils import (
require_sentencepiece,
require_tokenizers,
require_torch,
require_torch_fp16,
require_torch_multi_gpu,
slow,
torch_device,
@@ -563,12 +564,12 @@ class ReformerTesterMixin:
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_random_seed(*config_and_inputs)
@unittest.skipIf(torch_device == "cpu", "Cant do half precision")
@require_torch_fp16
def test_reformer_model_fp16_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_model_fp16_forward(*config_and_inputs)
@unittest.skipIf(torch_device == "cpu", "Cant do half precision")
@require_torch_fp16
def test_reformer_model_fp16_generate(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_model_fp16_generate(*config_and_inputs)