mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-05 13:50:13 +06:00
Device agnostic testing (#25870)
* adds agnostic decorators and availability fns * renaming decorators and fixing imports * updating some representative example tests bloom, opt, and reformer for now * wip device agnostic functions * lru cache to device checking functions * adds `TRANSFORMERS_TEST_DEVICE_SPEC` if present, imports the target file and updates device to function mappings * comments `TRANSFORMERS_TEST_DEVICE_SPEC` code * extra checks on device name * `make style; make quality` * updates default functions for agnostic calls * applies suggestions from review * adds `is_torch_available` guard * Add spec file to docs, rename function dispatch names to backend_* * add backend import to docs example for spec file * change instances of to * Move register backend to before device check as per @statelesshz changes * make style * make opt test require fp16 to run --------- Co-authored-by: arsalanu <arsalanu@graphcore.ai> Co-authored-by: arsalanu <hzji210@gmail.com>
This commit is contained in:
parent
41496b95da
commit
9da451713d
@ -525,6 +525,25 @@ Certain devices will require an additional import after importing `torch` for th
|
|||||||
```bash
|
```bash
|
||||||
TRANSFORMERS_TEST_BACKEND="torch_npu" pytest tests/utils/test_logging.py
|
TRANSFORMERS_TEST_BACKEND="torch_npu" pytest tests/utils/test_logging.py
|
||||||
```
|
```
|
||||||
|
Alternative backends may also require the replacement of device-specific functions. For example `torch.cuda.manual_seed` may need to be replaced with a device-specific seed setter like `torch.npu.manual_seed` to correctly set a random seed on the device. To specify a new backend with backend-specific device functions when running the test suite, create a Python device specification file in the format:
|
||||||
|
|
||||||
|
```
|
||||||
|
import torch
|
||||||
|
import torch_npu
|
||||||
|
# !! Further additional imports can be added here !!
|
||||||
|
|
||||||
|
# Specify the device name (eg. 'cuda', 'cpu', 'npu')
|
||||||
|
DEVICE_NAME = 'npu'
|
||||||
|
|
||||||
|
# Specify device-specific backends to dispatch to.
|
||||||
|
# If not specified, will fallback to 'default' in 'testing_utils.py`
|
||||||
|
MANUAL_SEED_FN = torch.npu.manual_seed
|
||||||
|
EMPTY_CACHE_FN = torch.npu.empty_cache
|
||||||
|
DEVICE_COUNT_FN = torch.npu.device_count
|
||||||
|
```
|
||||||
|
This format also allows for specification of any additional imports required. To use this file to replace equivalent methods in the test suite, set the environment variable `TRANSFORMERS_TEST_DEVICE_SPEC` to the path of the spec file.
|
||||||
|
|
||||||
|
Currently, only `MANUAL_SEED_FN`, `EMPTY_CACHE_FN` and `DEVICE_COUNT_FN` are supported for device-specific dispatch.
|
||||||
|
|
||||||
|
|
||||||
### Distributed training
|
### Distributed training
|
||||||
|
@ -32,7 +32,7 @@ import unittest
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Iterable, Iterator, List, Optional, Union
|
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
import huggingface_hub
|
import huggingface_hub
|
||||||
@ -98,8 +98,10 @@ from .utils import (
|
|||||||
is_timm_available,
|
is_timm_available,
|
||||||
is_tokenizers_available,
|
is_tokenizers_available,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
|
is_torch_bf16_available_on_device,
|
||||||
is_torch_bf16_cpu_available,
|
is_torch_bf16_cpu_available,
|
||||||
is_torch_bf16_gpu_available,
|
is_torch_bf16_gpu_available,
|
||||||
|
is_torch_fp16_available_on_device,
|
||||||
is_torch_neuroncore_available,
|
is_torch_neuroncore_available,
|
||||||
is_torch_npu_available,
|
is_torch_npu_available,
|
||||||
is_torch_tensorrt_fx_available,
|
is_torch_tensorrt_fx_available,
|
||||||
@ -713,6 +715,16 @@ if is_torch_available():
|
|||||||
# Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
|
# Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
if "TRANSFORMERS_TEST_BACKEND" in os.environ:
|
||||||
|
backend = os.environ["TRANSFORMERS_TEST_BACKEND"]
|
||||||
|
try:
|
||||||
|
_ = importlib.import_module(backend)
|
||||||
|
except ModuleNotFoundError as e:
|
||||||
|
raise ModuleNotFoundError(
|
||||||
|
f"Failed to import `TRANSFORMERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module. The original error (look up to see its"
|
||||||
|
f" traceback):\n{e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
if "TRANSFORMERS_TEST_DEVICE" in os.environ:
|
if "TRANSFORMERS_TEST_DEVICE" in os.environ:
|
||||||
torch_device = os.environ["TRANSFORMERS_TEST_DEVICE"]
|
torch_device = os.environ["TRANSFORMERS_TEST_DEVICE"]
|
||||||
try:
|
try:
|
||||||
@ -730,17 +742,6 @@ if is_torch_available():
|
|||||||
torch_device = "xpu"
|
torch_device = "xpu"
|
||||||
else:
|
else:
|
||||||
torch_device = "cpu"
|
torch_device = "cpu"
|
||||||
|
|
||||||
if "TRANSFORMERS_TEST_BACKEND" in os.environ:
|
|
||||||
backend = os.environ["TRANSFORMERS_TEST_BACKEND"]
|
|
||||||
try:
|
|
||||||
_ = importlib.import_module(backend)
|
|
||||||
except ModuleNotFoundError as e:
|
|
||||||
raise ModuleNotFoundError(
|
|
||||||
f"Failed to import `TRANSFORMERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module. The original error (look up to see its"
|
|
||||||
f" traceback):\n{e}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
torch_device = None
|
torch_device = None
|
||||||
|
|
||||||
@ -770,6 +771,25 @@ def require_torch_gpu(test_case):
|
|||||||
return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
|
return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
|
||||||
|
|
||||||
|
|
||||||
|
def require_torch_accelerator(test_case):
|
||||||
|
"""Decorator marking a test that requires an accessible accelerator and PyTorch."""
|
||||||
|
return unittest.skipUnless(torch_device != "cpu", "test requires accelerator")(test_case)
|
||||||
|
|
||||||
|
|
||||||
|
def require_torch_fp16(test_case):
|
||||||
|
"""Decorator marking a test that requires a device that supports fp16"""
|
||||||
|
return unittest.skipUnless(
|
||||||
|
is_torch_fp16_available_on_device(torch_device), "test requires device with fp16 support"
|
||||||
|
)(test_case)
|
||||||
|
|
||||||
|
|
||||||
|
def require_torch_bf16(test_case):
|
||||||
|
"""Decorator marking a test that requires a device that supports bf16"""
|
||||||
|
return unittest.skipUnless(
|
||||||
|
is_torch_bf16_available_on_device(torch_device), "test requires device with bf16 support"
|
||||||
|
)(test_case)
|
||||||
|
|
||||||
|
|
||||||
def require_torch_bf16_gpu(test_case):
|
def require_torch_bf16_gpu(test_case):
|
||||||
"""Decorator marking a test that requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0"""
|
"""Decorator marking a test that requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0"""
|
||||||
return unittest.skipUnless(
|
return unittest.skipUnless(
|
||||||
@ -2176,3 +2196,86 @@ class HfDoctestModule(Module):
|
|||||||
for test in finder.find(module, module.__name__):
|
for test in finder.find(module, module.__name__):
|
||||||
if test.examples: # skip empty doctests and cuda
|
if test.examples: # skip empty doctests and cuda
|
||||||
yield DoctestItem.from_parent(self, name=test.name, runner=runner, dtest=test)
|
yield DoctestItem.from_parent(self, name=test.name, runner=runner, dtest=test)
|
||||||
|
|
||||||
|
|
||||||
|
def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs):
|
||||||
|
if device not in dispatch_table:
|
||||||
|
return dispatch_table["default"](*args, **kwargs)
|
||||||
|
|
||||||
|
fn = dispatch_table[device]
|
||||||
|
|
||||||
|
# Some device agnostic functions return values. Need to guard against `None`
|
||||||
|
# instead at user level.
|
||||||
|
if fn is None:
|
||||||
|
return None
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
# Mappings from device names to callable functions to support device agnostic
|
||||||
|
# testing.
|
||||||
|
BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed}
|
||||||
|
BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "default": None}
|
||||||
|
BACKEND_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "default": lambda: 1}
|
||||||
|
|
||||||
|
|
||||||
|
def backend_manual_seed(device: str, seed: int):
|
||||||
|
return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)
|
||||||
|
|
||||||
|
|
||||||
|
def backend_empty_cache(device: str):
|
||||||
|
return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)
|
||||||
|
|
||||||
|
|
||||||
|
def backend_device_count(device: str):
|
||||||
|
return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
# If `TRANSFORMERS_TEST_DEVICE_SPEC` is enabled we need to import extra entries
|
||||||
|
# into device to function mappings.
|
||||||
|
if "TRANSFORMERS_TEST_DEVICE_SPEC" in os.environ:
|
||||||
|
device_spec_path = os.environ["TRANSFORMERS_TEST_DEVICE_SPEC"]
|
||||||
|
if not Path(device_spec_path).is_file():
|
||||||
|
raise ValueError(
|
||||||
|
f"Specified path to device spec file is not a file or not found. Received '{device_spec_path}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to strip extension for later import – also verifies we are importing a
|
||||||
|
# python file.
|
||||||
|
try:
|
||||||
|
import_name = device_spec_path[: device_spec_path.index(".py")]
|
||||||
|
except ValueError as e:
|
||||||
|
raise ValueError(f"Provided device spec file was not a Python file! Received '{device_spec_path}") from e
|
||||||
|
|
||||||
|
device_spec_module = importlib.import_module(import_name)
|
||||||
|
|
||||||
|
# Imported file must contain `DEVICE_NAME`. If it doesn't, terminate early.
|
||||||
|
try:
|
||||||
|
device_name = device_spec_module.DEVICE_NAME
|
||||||
|
except AttributeError as e:
|
||||||
|
raise AttributeError("Device spec file did not contain `DEVICE_NAME`") from e
|
||||||
|
|
||||||
|
if "TRANSFORMERS_TEST_DEVICE" in os.environ and torch_device != device_name:
|
||||||
|
msg = f"Mismatch between environment variable `TRANSFORMERS_TEST_DEVICE` '{torch_device}' and device found in spec '{device_name}'\n"
|
||||||
|
msg += "Either unset `TRANSFORMERS_TEST_DEVICE` or ensure it matches device spec name."
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
torch_device = device_name
|
||||||
|
|
||||||
|
def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name: str):
|
||||||
|
try:
|
||||||
|
# Try to import the function directly
|
||||||
|
spec_fn = getattr(device_spec_module, attribute_name)
|
||||||
|
device_fn_dict[torch_device] = spec_fn
|
||||||
|
except AttributeError as e:
|
||||||
|
# If the function doesn't exist, and there is no default, throw an error
|
||||||
|
if "default" not in device_fn_dict:
|
||||||
|
raise AttributeError(
|
||||||
|
f"`{attribute_name}` not found in '{device_spec_path}' and no default fallback function found."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
# Add one entry here for each `BACKEND_*` dictionary.
|
||||||
|
update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN")
|
||||||
|
update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
|
||||||
|
update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
|
||||||
|
@ -166,10 +166,12 @@ from .import_utils import (
|
|||||||
is_tokenizers_available,
|
is_tokenizers_available,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
is_torch_bf16_available,
|
is_torch_bf16_available,
|
||||||
|
is_torch_bf16_available_on_device,
|
||||||
is_torch_bf16_cpu_available,
|
is_torch_bf16_cpu_available,
|
||||||
is_torch_bf16_gpu_available,
|
is_torch_bf16_gpu_available,
|
||||||
is_torch_compile_available,
|
is_torch_compile_available,
|
||||||
is_torch_cuda_available,
|
is_torch_cuda_available,
|
||||||
|
is_torch_fp16_available_on_device,
|
||||||
is_torch_fx_available,
|
is_torch_fx_available,
|
||||||
is_torch_fx_proxy,
|
is_torch_fx_proxy,
|
||||||
is_torch_mps_available,
|
is_torch_mps_available,
|
||||||
|
@ -351,6 +351,45 @@ def is_torch_bf16_available():
|
|||||||
return is_torch_bf16_gpu_available()
|
return is_torch_bf16_gpu_available()
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def is_torch_fp16_available_on_device(device):
|
||||||
|
if not is_torch_available():
|
||||||
|
return False
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
try:
|
||||||
|
x = torch.zeros(2, 2, dtype=torch.float16).to(device)
|
||||||
|
_ = x @ x
|
||||||
|
except: # noqa: E722
|
||||||
|
# TODO: more precise exception matching, if possible.
|
||||||
|
# most backends should return `RuntimeError` however this is not guaranteed.
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def is_torch_bf16_available_on_device(device):
|
||||||
|
if not is_torch_available():
|
||||||
|
return False
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if device == "cuda":
|
||||||
|
return is_torch_bf16_gpu_available()
|
||||||
|
|
||||||
|
try:
|
||||||
|
x = torch.zeros(2, 2, dtype=torch.bfloat16).to(device)
|
||||||
|
_ = x @ x
|
||||||
|
except: # noqa: E722
|
||||||
|
# TODO: more precise exception matching, if possible.
|
||||||
|
# most backends should return `RuntimeError` however this is not guaranteed.
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def is_torch_tf32_available():
|
def is_torch_tf32_available():
|
||||||
if not is_torch_available():
|
if not is_torch_available():
|
||||||
return False
|
return False
|
||||||
|
@ -18,7 +18,7 @@ import math
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import BloomConfig, is_torch_available
|
from transformers import BloomConfig, is_torch_available
|
||||||
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
|
from transformers.testing_utils import require_torch, require_torch_accelerator, slow, torch_device
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
@ -401,7 +401,7 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_simple_generation(self):
|
def test_simple_generation(self):
|
||||||
# This test is a bit flaky. For some GPU architectures, pytorch sets by default allow_fp16_reduced_precision_reduction = True and some operations
|
# This test is a bit flaky. For some GPU architectures, pytorch sets by default allow_fp16_reduced_precision_reduction = True and some operations
|
||||||
# do not give the same results under this configuration, especially torch.baddmm and torch.bmm. https://pytorch.org/docs/stable/notes/numerical_accuracy.html#fp16-on-mi200
|
# do not give the same results under this configuration, especially torch.baddmm and torch.bmm. https://pytorch.org/docs/stable/notes/numerical_accuracy.html#fp16-on-mi200
|
||||||
@ -440,7 +440,7 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
self.assertEqual(tokenizer.decode(greedy_output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
|
self.assertEqual(tokenizer.decode(greedy_output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_batch_generation(self):
|
def test_batch_generation(self):
|
||||||
path_560m = "bigscience/bloom-560m"
|
path_560m = "bigscience/bloom-560m"
|
||||||
model = BloomForCausalLM.from_pretrained(path_560m, use_cache=True, revision="gs555750").to(torch_device)
|
model = BloomForCausalLM.from_pretrained(path_560m, use_cache=True, revision="gs555750").to(torch_device)
|
||||||
@ -460,7 +460,7 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
)
|
)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_batch_generation_padd(self):
|
def test_batch_generation_padd(self):
|
||||||
path_560m = "bigscience/bloom-560m"
|
path_560m = "bigscience/bloom-560m"
|
||||||
model = BloomForCausalLM.from_pretrained(path_560m, use_cache=True, revision="gs555750").to(torch_device)
|
model = BloomForCausalLM.from_pretrained(path_560m, use_cache=True, revision="gs555750").to(torch_device)
|
||||||
@ -489,7 +489,7 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
)
|
)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_batch_generated_text(self):
|
def test_batch_generated_text(self):
|
||||||
path_560m = "bigscience/bloom-560m"
|
path_560m = "bigscience/bloom-560m"
|
||||||
|
|
||||||
|
@ -19,7 +19,7 @@ import unittest
|
|||||||
|
|
||||||
from transformers import CodeGenConfig, is_torch_available
|
from transformers import CodeGenConfig, is_torch_available
|
||||||
from transformers.file_utils import cached_property
|
from transformers.file_utils import cached_property
|
||||||
from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
|
from transformers.testing_utils import backend_manual_seed, is_flaky, require_torch, slow, torch_device
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
@ -498,8 +498,7 @@ class CodeGenModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
|
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
if torch_device == "cuda":
|
backend_manual_seed(torch_device, 0)
|
||||||
torch.cuda.manual_seed(0)
|
|
||||||
|
|
||||||
tokenized = tokenizer("def hello_world():", return_tensors="pt", return_token_type_ids=True)
|
tokenized = tokenizer("def hello_world():", return_tensors="pt", return_token_type_ids=True)
|
||||||
input_ids = tokenized.input_ids.to(torch_device)
|
input_ids = tokenized.input_ids.to(torch_device)
|
||||||
|
@ -22,7 +22,7 @@ import unittest
|
|||||||
import timeout_decorator # noqa
|
import timeout_decorator # noqa
|
||||||
|
|
||||||
from transformers import OPTConfig, is_torch_available
|
from transformers import OPTConfig, is_torch_available
|
||||||
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
|
from transformers.testing_utils import require_torch, require_torch_fp16, require_torch_gpu, slow, torch_device
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
@ -286,12 +286,12 @@ class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
model(**inputs)[0]
|
model(**inputs)[0]
|
||||||
|
|
||||||
|
@require_torch_fp16
|
||||||
def test_generate_fp16(self):
|
def test_generate_fp16(self):
|
||||||
config, input_dict = self.model_tester.prepare_config_and_inputs()
|
config, input_dict = self.model_tester.prepare_config_and_inputs()
|
||||||
input_ids = input_dict["input_ids"]
|
input_ids = input_dict["input_ids"]
|
||||||
attention_mask = input_ids.ne(1).to(torch_device)
|
attention_mask = input_ids.ne(1).to(torch_device)
|
||||||
model = OPTForCausalLM(config).eval().to(torch_device)
|
model = OPTForCausalLM(config).eval().to(torch_device)
|
||||||
if torch_device == "cuda":
|
|
||||||
model.half()
|
model.half()
|
||||||
model.generate(input_ids, attention_mask=attention_mask)
|
model.generate(input_ids, attention_mask=attention_mask)
|
||||||
model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
|
model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
|
||||||
|
@ -20,6 +20,7 @@ from transformers.testing_utils import (
|
|||||||
require_sentencepiece,
|
require_sentencepiece,
|
||||||
require_tokenizers,
|
require_tokenizers,
|
||||||
require_torch,
|
require_torch,
|
||||||
|
require_torch_fp16,
|
||||||
require_torch_multi_gpu,
|
require_torch_multi_gpu,
|
||||||
slow,
|
slow,
|
||||||
torch_device,
|
torch_device,
|
||||||
@ -563,12 +564,12 @@ class ReformerTesterMixin:
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_reformer_random_seed(*config_and_inputs)
|
self.model_tester.create_and_check_reformer_random_seed(*config_and_inputs)
|
||||||
|
|
||||||
@unittest.skipIf(torch_device == "cpu", "Cant do half precision")
|
@require_torch_fp16
|
||||||
def test_reformer_model_fp16_forward(self):
|
def test_reformer_model_fp16_forward(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_reformer_model_fp16_forward(*config_and_inputs)
|
self.model_tester.create_and_check_reformer_model_fp16_forward(*config_and_inputs)
|
||||||
|
|
||||||
@unittest.skipIf(torch_device == "cpu", "Cant do half precision")
|
@require_torch_fp16
|
||||||
def test_reformer_model_fp16_generate(self):
|
def test_reformer_model_fp16_generate(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_reformer_model_fp16_generate(*config_and_inputs)
|
self.model_tester.create_and_check_reformer_model_fp16_generate(*config_and_inputs)
|
||||||
|
Loading…
Reference in New Issue
Block a user