mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[testing] rename skip targets + docs (#7863)
* rename skip targets + docs * fix quotes * style * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * small improvements * fix Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
c912ba5f69
commit
3e31e7f956
@ -400,29 +400,46 @@ or if you have multiple gpus, you can specify which one is to be used by ``pytes
|
||||
CUDA_VISIBLE_DEVICES="1" pytest tests/test_logging.py
|
||||
|
||||
This is handy when you want to run different tasks on different GPUs.
|
||||
|
||||
And we have these decorators that require the condition described by the marker.
|
||||
|
||||
``
|
||||
@require_torch
|
||||
@require_tf
|
||||
@require_multigpu
|
||||
@require_non_multigpu
|
||||
@require_torch_tpu
|
||||
@require_torch_and_cuda
|
||||
``
|
||||
Some tests must be run on CPU-only, others on either CPU or GPU or TPU, yet others on multiple-GPUs. The following skip decorators are used to set the requirements of tests CPU/GPU/TPU-wise:
|
||||
|
||||
* ``require_torch`` - this test will run only under torch
|
||||
* ``require_torch_gpu`` - as ``require_torch`` plus requires at least 1 GPU
|
||||
* ``require_torch_multigpu`` - as ``require_torch`` plus requires at least 2 GPUs
|
||||
* ``require_torch_non_multigpu`` - as ``require_torch`` plus requires 0 or 1 GPUs
|
||||
* ``require_torch_tpu`` - as ``require_torch`` plus requires at least 1 TPU
|
||||
|
||||
For example, here is a test that must be run only when there are 2 or more GPUs available and pytorch is installed:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@require_torch_multigpu
|
||||
def test_example_with_multigpu():
|
||||
|
||||
If a test requires ``tensorflow`` use the ``require_tf`` decorator. For example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@require_tf
|
||||
def test_tf_thing_with_tensorflow():
|
||||
|
||||
These decorators can be stacked. For example, if a test is slow and requires at least one GPU under pytorch, here is how to set it up:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
def test_example_slow_on_gpu():
|
||||
|
||||
Some decorators like ``@parametrized`` rewrite test names, therefore ``@require_*`` skip decorators have to be listed last for them to work correctly. Here is an example of the correct usage:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@parameterized.expand(...)
|
||||
@require_multigpu
|
||||
@require_torch_multigpu
|
||||
def test_integration_foo():
|
||||
|
||||
There is no problem whatsoever with ``@pytest.mark.parametrize`` (but it only works with non-unittests) - can use it in any order.
|
||||
|
||||
This section will be expanded soon once our work in progress on those decorators is finished.
|
||||
This order problem doesn't exist with ``@pytest.mark.parametrize``, you can put it first or last and it will still work. But it only works with non-unittests.
|
||||
|
||||
Inside tests:
|
||||
|
||||
|
@ -19,7 +19,7 @@ from run_eval import generate_summaries_or_translations, run_generate
|
||||
from run_eval_search import run_search
|
||||
from transformers import AutoConfig, AutoModelForSeq2SeqLM
|
||||
from transformers.hf_api import HfApi
|
||||
from transformers.testing_utils import CaptureStderr, CaptureStdout, TestCasePlus, require_torch_and_cuda, slow
|
||||
from transformers.testing_utils import CaptureStderr, CaptureStdout, TestCasePlus, require_torch_gpu, slow
|
||||
from utils import ROUGE_KEYS, label_smoothed_nll_loss, lmap, load_json
|
||||
|
||||
|
||||
@ -125,9 +125,9 @@ class TestSummarizationDistiller(TestCasePlus):
|
||||
return cls
|
||||
|
||||
@slow
|
||||
@require_torch_and_cuda
|
||||
@require_torch_gpu
|
||||
def test_hub_configs(self):
|
||||
"""I put require_torch_and_cuda cause I only want this to run with self-scheduled."""
|
||||
"""I put require_torch_gpu cause I only want this to run with self-scheduled."""
|
||||
|
||||
model_list = HfApi().model_list()
|
||||
org = "sshleifer"
|
||||
|
@ -154,7 +154,7 @@ def require_tokenizers(test_case):
|
||||
return test_case
|
||||
|
||||
|
||||
def require_multigpu(test_case):
|
||||
def require_torch_multigpu(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires a multi-GPU setup (in PyTorch).
|
||||
|
||||
@ -174,7 +174,7 @@ def require_multigpu(test_case):
|
||||
return test_case
|
||||
|
||||
|
||||
def require_non_multigpu(test_case):
|
||||
def require_torch_non_multigpu(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires 0 or 1 GPU setup (in PyTorch).
|
||||
"""
|
||||
@ -208,7 +208,7 @@ else:
|
||||
torch_device = None
|
||||
|
||||
|
||||
def require_torch_and_cuda(test_case):
|
||||
def require_torch_gpu(test_case):
|
||||
"""Decorator marking a test that requires CUDA and PyTorch. """
|
||||
if torch_device != "cuda":
|
||||
return unittest.skip("test requires CUDA")(test_case)
|
||||
|
@ -17,7 +17,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, require_torch_and_cuda, slow, torch_device
|
||||
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
@ -302,6 +302,6 @@ class XxxModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
"""Test loss/gradients same as reference implementation, for example."""
|
||||
pass
|
||||
|
||||
@require_torch_and_cuda
|
||||
@require_torch_gpu
|
||||
def test_large_inputs_in_fp16_dont_cause_overflow(self):
|
||||
pass
|
||||
|
@ -22,7 +22,7 @@ import unittest
|
||||
from typing import List, Tuple
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_multigpu, require_torch, slow, torch_device
|
||||
from transformers.testing_utils import require_torch, require_torch_multigpu, slow, torch_device
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@ -980,7 +980,7 @@ class ModelTesterMixin:
|
||||
return True
|
||||
return False
|
||||
|
||||
@require_multigpu
|
||||
@require_torch_multigpu
|
||||
def test_multigpu_data_parallel_forward(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
@ -18,7 +18,7 @@ import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import require_torch, require_torch_and_cuda, slow, torch_device
|
||||
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
@ -234,6 +234,6 @@ class LayoutLMModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
"""Test loss/gradients same as reference implementation, for example."""
|
||||
pass
|
||||
|
||||
@require_torch_and_cuda
|
||||
@require_torch_gpu
|
||||
def test_large_inputs_in_fp16_dont_cause_overflow(self):
|
||||
pass
|
||||
|
@ -17,10 +17,10 @@ import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
require_multigpu,
|
||||
require_sentencepiece,
|
||||
require_tokenizers,
|
||||
require_torch,
|
||||
require_torch_multigpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@ -558,7 +558,7 @@ class ReformerTesterMixin:
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_reformer_model_fp16_generate(*config_and_inputs)
|
||||
|
||||
@require_multigpu
|
||||
@require_torch_multigpu
|
||||
def test_multigpu_data_parallel_forward(self):
|
||||
# Opt-out of this test.
|
||||
pass
|
||||
|
@ -17,7 +17,7 @@ import random
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_multigpu, require_torch, slow, torch_device
|
||||
from transformers.testing_utils import require_torch, require_torch_multigpu, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
@ -204,7 +204,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
output_result = self.model_tester.create_transfo_xl_lm_head(*config_and_inputs)
|
||||
self.model_tester.check_transfo_xl_lm_head_output(output_result)
|
||||
|
||||
@require_multigpu
|
||||
@require_torch_multigpu
|
||||
def test_multigpu_data_parallel_forward(self):
|
||||
# Opt-out of this test.
|
||||
pass
|
||||
|
@ -34,7 +34,7 @@ import unittest
|
||||
import pytest
|
||||
|
||||
from parameterized import parameterized
|
||||
from transformers.testing_utils import require_torch, require_torch_and_cuda, slow, torch_device
|
||||
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
|
||||
|
||||
|
||||
# skipping in unittest tests
|
||||
@ -63,11 +63,11 @@ def check_slow_torch_cuda():
|
||||
@require_torch
|
||||
class SkipTester(unittest.TestCase):
|
||||
@slow
|
||||
@require_torch_and_cuda
|
||||
@require_torch_gpu
|
||||
def test_2_skips_slow_first(self):
|
||||
check_slow_torch_cuda()
|
||||
|
||||
@require_torch_and_cuda
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
def test_2_skips_slow_last(self):
|
||||
check_slow_torch_cuda()
|
||||
@ -97,12 +97,12 @@ class SkipTester(unittest.TestCase):
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_and_cuda
|
||||
@require_torch_gpu
|
||||
def test_pytest_2_skips_slow_first():
|
||||
check_slow_torch_cuda()
|
||||
|
||||
|
||||
@require_torch_and_cuda
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
def test_pytest_2_skips_slow_last():
|
||||
check_slow_torch_cuda()
|
||||
|
Loading…
Reference in New Issue
Block a user