diff --git a/docs/source/en/testing.md b/docs/source/en/testing.md
index 25df6e295a1..d704c84a628 100644
--- a/docs/source/en/testing.md
+++ b/docs/source/en/testing.md
@@ -525,6 +525,25 @@ Certain devices will require an additional import after importing `torch` for th
 ```bash
 TRANSFORMERS_TEST_BACKEND="torch_npu" pytest tests/utils/test_logging.py
 ```
+Alternative backends may also require the replacement of device-specific functions. For example, `torch.cuda.manual_seed` may need to be replaced with a device-specific seed setter such as `torch.npu.manual_seed` to correctly set a random seed on the device. To specify a new backend with backend-specific device functions when running the test suite, create a Python device specification file in the following format:
+
+```python
+import torch
+import torch_npu
+# !! Further additional imports can be added here !!
+
+# Specify the device name (e.g. 'cuda', 'cpu', 'npu')
+DEVICE_NAME = 'npu'
+
+# Specify device-specific backends to dispatch to.
+# If not specified, will fall back to 'default' in `testing_utils.py`
+MANUAL_SEED_FN = torch.npu.manual_seed
+EMPTY_CACHE_FN = torch.npu.empty_cache
+DEVICE_COUNT_FN = torch.npu.device_count
+```
+This format also allows you to specify any additional imports required. To use this file to replace the equivalent methods in the test suite, set the environment variable `TRANSFORMERS_TEST_DEVICE_SPEC` to the path of the spec file.
+
+Currently, only `MANUAL_SEED_FN`, `EMPTY_CACHE_FN` and `DEVICE_COUNT_FN` are supported for device-specific dispatch.
 
 ### Distributed training
 
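Since any function not named in the spec file falls back to the `default` entry in `testing_utils.py`, a spec only has to override what actually differs on the target device. A minimal sketch, assuming a hypothetical `torch_mydevice` backend module that only needs a custom seed setter:

```python
# Hypothetical minimal spec file, e.g. saved as my_device_spec.py.
# EMPTY_CACHE_FN and DEVICE_COUNT_FN are omitted on purpose, so dispatch
# falls back to the 'default' entries defined in testing_utils.py.
import torch
import torch_mydevice  # assumed backend package; replace with the real one

DEVICE_NAME = 'mydevice'

MANUAL_SEED_FN = torch_mydevice.manual_seed  # assumed seed setter for the device
```

The file would then be picked up by pointing `TRANSFORMERS_TEST_DEVICE_SPEC` at `my_device_spec.py` when invoking pytest.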
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 50f50c83c4e..a5cac8ddfe9 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -32,7 +32,7 @@ import unittest
 from collections.abc import Mapping
 from io import StringIO
 from pathlib import Path
-from typing import Iterable, Iterator, List, Optional, Union
+from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union
 from unittest import mock
 
 import huggingface_hub
@@ -98,8 +98,10 @@ from .utils import (
     is_timm_available,
     is_tokenizers_available,
     is_torch_available,
+    is_torch_bf16_available_on_device,
     is_torch_bf16_cpu_available,
     is_torch_bf16_gpu_available,
+    is_torch_fp16_available_on_device,
     is_torch_neuroncore_available,
     is_torch_npu_available,
     is_torch_tensorrt_fx_available,
@@ -713,6 +715,16 @@ if is_torch_available():
     # Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
     import torch
 
+    if "TRANSFORMERS_TEST_BACKEND" in os.environ:
+        backend = os.environ["TRANSFORMERS_TEST_BACKEND"]
+        try:
+            _ = importlib.import_module(backend)
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError(
+                f"Failed to import `TRANSFORMERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module. The original error (look up to see its"
+                f" traceback):\n{e}"
+            ) from e
+
     if "TRANSFORMERS_TEST_DEVICE" in os.environ:
         torch_device = os.environ["TRANSFORMERS_TEST_DEVICE"]
         try:
@@ -730,17 +742,6 @@ if is_torch_available():
             torch_device = "xpu"
         else:
             torch_device = "cpu"
-
-    if "TRANSFORMERS_TEST_BACKEND" in os.environ:
-        backend = os.environ["TRANSFORMERS_TEST_BACKEND"]
-        try:
-            _ = importlib.import_module(backend)
-        except ModuleNotFoundError as e:
-            raise ModuleNotFoundError(
-                f"Failed to import `TRANSFORMERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module. The original error (look up to see its"
-                f" traceback):\n{e}"
-            ) from e
-
 else:
     torch_device = None
 
@@ -770,6 +771,25 @@ def require_torch_gpu(test_case):
     return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
 
 
+def require_torch_accelerator(test_case):
+    """Decorator marking a test that requires an accessible accelerator and PyTorch."""
+    return unittest.skipUnless(torch_device != "cpu", "test requires accelerator")(test_case)
+
+
+def require_torch_fp16(test_case):
+    """Decorator marking a test that requires a device that supports fp16."""
+    return unittest.skipUnless(
+        is_torch_fp16_available_on_device(torch_device), "test requires device with fp16 support"
+    )(test_case)
+
+
+def require_torch_bf16(test_case):
+    """Decorator marking a test that requires a device that supports bf16."""
+    return unittest.skipUnless(
+        is_torch_bf16_available_on_device(torch_device), "test requires device with bf16 support"
+    )(test_case)
+
+
 def require_torch_bf16_gpu(test_case):
     """Decorator marking a test that requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0"""
     return unittest.skipUnless(
@@ -2176,3 +2196,86 @@ class HfDoctestModule(Module):
         for test in finder.find(module, module.__name__):
             if test.examples:  # skip empty doctests and cuda
                 yield DoctestItem.from_parent(self, name=test.name, runner=runner, dtest=test)
+
+
+def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs):
+    if device not in dispatch_table:
+        return dispatch_table["default"](*args, **kwargs)
+
+    fn = dispatch_table[device]
+
+    # Some device agnostic functions return values, so guard against `None` entries
+    # here instead of at the user level.
+    if fn is None:
+        return None
+    return fn(*args, **kwargs)
+
+
+if is_torch_available():
+    # Mappings from device names to callable functions to support device agnostic
+    # testing.
+    BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed}
+    BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "default": None}
+    BACKEND_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "default": lambda: 1}
+
+
+def backend_manual_seed(device: str, seed: int):
+    return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)
+
+
+def backend_empty_cache(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)
+
+
+def backend_device_count(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
+
+
+if is_torch_available():
+    # If `TRANSFORMERS_TEST_DEVICE_SPEC` is enabled we need to import extra entries
+    # into the device to function mappings.
+    if "TRANSFORMERS_TEST_DEVICE_SPEC" in os.environ:
+        device_spec_path = os.environ["TRANSFORMERS_TEST_DEVICE_SPEC"]
+        if not Path(device_spec_path).is_file():
+            raise ValueError(
+                f"Specified path to device spec file is not a file or not found. Received '{device_spec_path}'"
+            )
+
+        # Try to strip extension for later import – also verifies we are importing a
+        # Python file.
+        try:
+            import_name = device_spec_path[: device_spec_path.index(".py")]
+        except ValueError as e:
+            raise ValueError(f"Provided device spec file was not a Python file! Received '{device_spec_path}'") from e
+
+        device_spec_module = importlib.import_module(import_name)
+
+        # Imported file must contain `DEVICE_NAME`. If it doesn't, terminate early.
+        try:
+            device_name = device_spec_module.DEVICE_NAME
+        except AttributeError as e:
+            raise AttributeError("Device spec file did not contain `DEVICE_NAME`") from e
+
+        if "TRANSFORMERS_TEST_DEVICE" in os.environ and torch_device != device_name:
+            msg = f"Mismatch between environment variable `TRANSFORMERS_TEST_DEVICE` '{torch_device}' and device found in spec '{device_name}'\n"
+            msg += "Either unset `TRANSFORMERS_TEST_DEVICE` or ensure it matches the device spec name."
+            raise ValueError(msg)
+
+        torch_device = device_name
+
+        def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name: str):
+            try:
+                # Try to import the function directly from the spec module
+                spec_fn = getattr(device_spec_module, attribute_name)
+                device_fn_dict[torch_device] = spec_fn
+            except AttributeError as e:
+                # If the function doesn't exist, and there is no default, throw an error
+                if "default" not in device_fn_dict:
+                    raise AttributeError(
+                        f"`{attribute_name}` not found in '{device_spec_path}' and no default fallback function found."
+                    ) from e
+
+        # Add one entry here for each `BACKEND_*` dictionary.
+        update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN")
+        update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
+        update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
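The `backend_*` wrappers above are what model tests are expected to call instead of hard-coding `torch.cuda.*`, as the CodeGen test change further down does. A minimal sketch of the pattern, with a placeholder test body:

```python
import torch

from transformers.testing_utils import (
    backend_device_count,
    backend_empty_cache,
    backend_manual_seed,
    torch_device,
)


def test_generation_is_reproducible():
    # Seed the host RNG and, via the dispatch table, the active backend
    # (torch.cuda.manual_seed on CUDA, MANUAL_SEED_FN from a spec file, ...).
    torch.manual_seed(0)
    backend_manual_seed(torch_device, 0)

    # ... run the model under test here (placeholder) ...

    # Release cached memory; dispatches to a no-op on backends mapped to None.
    backend_empty_cache(torch_device)

    # 0 on cpu, the backend's device count elsewhere.
    assert backend_device_count(torch_device) >= 0
```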
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index b66f0db3892..25dfa53d8d7 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -166,10 +166,12 @@ from .import_utils import (
     is_tokenizers_available,
     is_torch_available,
     is_torch_bf16_available,
+    is_torch_bf16_available_on_device,
     is_torch_bf16_cpu_available,
     is_torch_bf16_gpu_available,
     is_torch_compile_available,
     is_torch_cuda_available,
+    is_torch_fp16_available_on_device,
     is_torch_fx_available,
     is_torch_fx_proxy,
     is_torch_mps_available,
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index c135034d02b..48e4e702360 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -351,6 +351,45 @@ def is_torch_bf16_available():
     return is_torch_bf16_gpu_available()
 
 
+@lru_cache()
+def is_torch_fp16_available_on_device(device):
+    if not is_torch_available():
+        return False
+
+    import torch
+
+    try:
+        x = torch.zeros(2, 2, dtype=torch.float16).to(device)
+        _ = x @ x
+    except:  # noqa: E722
+        # TODO: more precise exception matching, if possible.
+        # most backends should raise `RuntimeError`, however this is not guaranteed.
+        return False
+
+    return True
+
+
+@lru_cache()
+def is_torch_bf16_available_on_device(device):
+    if not is_torch_available():
+        return False
+
+    import torch
+
+    if device == "cuda":
+        return is_torch_bf16_gpu_available()
+
+    try:
+        x = torch.zeros(2, 2, dtype=torch.bfloat16).to(device)
+        _ = x @ x
+    except:  # noqa: E722
+        # TODO: more precise exception matching, if possible.
+        # most backends should raise `RuntimeError`, however this is not guaranteed.
+        return False
+
+    return True
+
+
 def is_torch_tf32_available():
     if not is_torch_available():
         return False
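These per-device availability checks back the new `require_torch_fp16` / `require_torch_bf16` decorators, and the test-file changes below swap them in for `require_torch_gpu` and ad-hoc `torch_device == "cuda"` guards. A sketch of the resulting pattern in a test file (the class and test names are placeholders):

```python
import unittest

from transformers.testing_utils import require_torch_accelerator, require_torch_fp16


class ExampleModelTest(unittest.TestCase):
    @require_torch_accelerator
    def test_runs_on_any_accelerator(self):
        # Skipped automatically when torch_device resolves to "cpu".
        ...

    @require_torch_fp16
    def test_fp16_forward(self):
        # Skipped unless is_torch_fp16_available_on_device(torch_device) is True.
        ...
```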
diff --git a/tests/models/bloom/test_modeling_bloom.py b/tests/models/bloom/test_modeling_bloom.py
index c05d45ebecc..5d090c5785f 100644
--- a/tests/models/bloom/test_modeling_bloom.py
+++ b/tests/models/bloom/test_modeling_bloom.py
@@ -18,7 +18,7 @@ import math
 import unittest
 
 from transformers import BloomConfig, is_torch_available
-from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
+from transformers.testing_utils import require_torch, require_torch_accelerator, slow, torch_device
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -401,7 +401,7 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
         self.assertIsNotNone(model)
 
     @slow
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_simple_generation(self):
         # This test is a bit flaky. For some GPU architectures, pytorch sets by default allow_fp16_reduced_precision_reduction = True and some operations
         # do not give the same results under this configuration, especially torch.baddmm and torch.bmm. https://pytorch.org/docs/stable/notes/numerical_accuracy.html#fp16-on-mi200
@@ -440,7 +440,7 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
         self.assertEqual(tokenizer.decode(greedy_output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
 
     @slow
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_batch_generation(self):
         path_560m = "bigscience/bloom-560m"
         model = BloomForCausalLM.from_pretrained(path_560m, use_cache=True, revision="gs555750").to(torch_device)
@@ -460,7 +460,7 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
         )
 
     @slow
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_batch_generation_padd(self):
         path_560m = "bigscience/bloom-560m"
         model = BloomForCausalLM.from_pretrained(path_560m, use_cache=True, revision="gs555750").to(torch_device)
@@ -489,7 +489,7 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
         )
 
     @slow
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_batch_generated_text(self):
         path_560m = "bigscience/bloom-560m"
 
diff --git a/tests/models/codegen/test_modeling_codegen.py b/tests/models/codegen/test_modeling_codegen.py
index 34a32caa7ff..e042ccac71d 100644
--- a/tests/models/codegen/test_modeling_codegen.py
+++ b/tests/models/codegen/test_modeling_codegen.py
@@ -19,7 +19,7 @@ import unittest
 
 from transformers import CodeGenConfig, is_torch_available
 from transformers.file_utils import cached_property
-from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
+from transformers.testing_utils import backend_manual_seed, is_flaky, require_torch, slow, torch_device
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -498,8 +498,7 @@ class CodeGenModelLanguageGenerationTest(unittest.TestCase):
         model.to(torch_device)
 
         torch.manual_seed(0)
-        if torch_device == "cuda":
-            torch.cuda.manual_seed(0)
+        backend_manual_seed(torch_device, 0)
 
         tokenized = tokenizer("def hello_world():", return_tensors="pt", return_token_type_ids=True)
         input_ids = tokenized.input_ids.to(torch_device)
diff --git a/tests/models/opt/test_modeling_opt.py b/tests/models/opt/test_modeling_opt.py
index 232a29d4f50..18c0c9a7efe 100644
--- a/tests/models/opt/test_modeling_opt.py
+++ b/tests/models/opt/test_modeling_opt.py
@@ -22,7 +22,7 @@ import unittest
 import timeout_decorator  # noqa
 
 from transformers import OPTConfig, is_torch_available
-from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
+from transformers.testing_utils import require_torch, require_torch_fp16, require_torch_gpu, slow, torch_device
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -286,13 +286,13 @@ class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
             with torch.no_grad():
                 model(**inputs)[0]
 
+    @require_torch_fp16
     def test_generate_fp16(self):
         config, input_dict = self.model_tester.prepare_config_and_inputs()
         input_ids = input_dict["input_ids"]
         attention_mask = input_ids.ne(1).to(torch_device)
         model = OPTForCausalLM(config).eval().to(torch_device)
-        if torch_device == "cuda":
-            model.half()
+        model.half()
         model.generate(input_ids, attention_mask=attention_mask)
         model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
 
diff --git a/tests/models/reformer/test_modeling_reformer.py b/tests/models/reformer/test_modeling_reformer.py
index c84f729633c..11cd7e1a33b 100644
--- a/tests/models/reformer/test_modeling_reformer.py
+++ b/tests/models/reformer/test_modeling_reformer.py
@@ -20,6 +20,7 @@ from transformers.testing_utils import (
     require_sentencepiece,
     require_tokenizers,
     require_torch,
+    require_torch_fp16,
     require_torch_multi_gpu,
     slow,
     torch_device,
@@ -563,12 +564,12 @@ class ReformerTesterMixin:
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_reformer_random_seed(*config_and_inputs)
 
-    @unittest.skipIf(torch_device == "cpu", "Cant do half precision")
+    @require_torch_fp16
     def test_reformer_model_fp16_forward(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_reformer_model_fp16_forward(*config_and_inputs)
 
-    @unittest.skipIf(torch_device == "cpu", "Cant do half precision")
+    @require_torch_fp16
     def test_reformer_model_fp16_generate(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_reformer_model_fp16_generate(*config_and_inputs)