Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-01 02:31:11 +06:00
Simplify running tests in a subprocess (#34213)

* check
* check
* check
* check
* add docstring

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent a1835195d1
commit 439334c8fb
@@ -2366,6 +2366,46 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
         test_case.fail(f'{results["error"]}')
 
 
+def run_test_using_subprocess(func):
+    """
+    To decorate a test to run in a subprocess using the `subprocess` module. This could avoid potential GPU memory
+    issues (GPU OOM or a test that causes many subsequent tests to fail with `CUDA error: device-side assert triggered`).
+    """
+    import pytest
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        if os.getenv("_INSIDE_SUB_PROCESS", None) == "1":
+            func(*args, **kwargs)
+        else:
+            test = " ".join(os.environ.get("PYTEST_CURRENT_TEST").split(" ")[:-1])
+            try:
+                import copy
+
+                env = copy.deepcopy(os.environ)
+                env["_INSIDE_SUB_PROCESS"] = "1"
+
+                # If not a subclass of `unittest.TestCase` and `pytestconfig` is used: try to grab and use the arguments
+                if "pytestconfig" in kwargs:
+                    command = list(kwargs["pytestconfig"].invocation_params.args)
+                    for idx, x in enumerate(command):
+                        if x in kwargs["pytestconfig"].args:
+                            test = test.split("::")[1:]
+                            command[idx] = "::".join([f"{func.__globals__['__file__']}"] + test)
+                    command = [f"{sys.executable}", "-m", "pytest"] + command
+                    command = [x for x in command if x not in ["--no-summary"]]
+                # Otherwise, simply run the test with no option at all
+                else:
+                    command = [f"{sys.executable}", "-m", "pytest", f"{test}"]
+
+                subprocess.run(command, env=env, check=True, capture_output=True)
+            except subprocess.CalledProcessError as e:
+                exception_message = e.stdout.decode()
+                raise pytest.fail(exception_message, pytrace=False)
+
+    return wrapper
+
+
 """
 The following contains utils to run the documentation tests without having to overwrite any files.
 
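As a quick illustration of how the new decorator is meant to be used (this sketch is not part of the commit; the test class and test body are made up, and `require_torch` is only included as a typical companion marker):

# A minimal usage sketch (not part of this diff): decorating a test so that,
# when collected by pytest in the parent process, it is re-executed in a fresh
# `python -m pytest <test id>` subprocess with `_INSIDE_SUB_PROCESS=1` set.
# `MyModelIntegrationTest` and the test body are hypothetical examples.
import unittest

from transformers.testing_utils import require_torch, run_test_using_subprocess


@require_torch
class MyModelIntegrationTest(unittest.TestCase):
    @run_test_using_subprocess
    def test_that_may_corrupt_cuda_state(self):
        # Any `CUDA error: device-side assert triggered` raised here stays
        # confined to the subprocess and cannot poison later tests.
        self.assertTrue(True)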
@@ -18,7 +18,7 @@ import inspect
 import unittest
 
 from transformers import ImageGPTConfig
-from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.testing_utils import require_torch, require_vision, run_test_using_subprocess, slow, torch_device
 from transformers.utils import cached_property, is_torch_available, is_vision_available
 
 from ...generation.test_utils import GenerationTesterMixin
@@ -257,11 +257,9 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
         self.assertEqual(len(scores), length)
         self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))
 
-    @unittest.skip(
-        reason="After #33632, this test still passes, but many subsequential tests fail with `device-side assert triggered`"
-    )
+    @run_test_using_subprocess
     def test_beam_search_generate_dict_outputs_use_cache(self):
-        pass
+        super().test_beam_search_generate_dict_outputs_use_cache()
 
     def setUp(self):
         self.model_tester = ImageGPTModelTester(self)
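Conceptually, when this wrapped ImageGPT test runs in the parent process, the decorator derives the re-run command from `PYTEST_CURRENT_TEST`. A rough sketch of that command construction (the file path and test id below are assumed examples, not taken from this diff):

# Illustration only: how the wrapper builds the subprocess command when no
# `pytestconfig` fixture is passed (the `unittest.TestCase` case used here).
import os
import sys

os.environ["PYTEST_CURRENT_TEST"] = (
    "tests/models/imagegpt/test_modeling_imagegpt.py::ImageGPTModelTest::"
    "test_beam_search_generate_dict_outputs_use_cache (call)"
)
# Drop the trailing "(call)" phase marker, keeping only the test id.
test = " ".join(os.environ.get("PYTEST_CURRENT_TEST").split(" ")[:-1])
command = [f"{sys.executable}", "-m", "pytest", f"{test}"]
print(command)
# e.g. ['/usr/bin/python3', '-m', 'pytest',
#       'tests/models/imagegpt/test_modeling_imagegpt.py::ImageGPTModelTest::test_beam_search_generate_dict_outputs_use_cache']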
@@ -28,7 +28,14 @@ from transformers import (
     is_torch_available,
     is_vision_available,
 )
-from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_gpu, slow, torch_device
+from transformers.testing_utils import (
+    require_bitsandbytes,
+    require_torch,
+    require_torch_gpu,
+    run_test_using_subprocess,
+    slow,
+    torch_device,
+)
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -248,9 +255,7 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
     def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
         pass
 
-    @unittest.skip(
-        reason="After #33533, this still passes, but many subsequential tests fail with `device-side assert triggered`"
-    )
+    @run_test_using_subprocess
     def test_mixed_input(self):
         config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
         for model_class in self.all_model_classes:
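The VideoLlava change follows the same pattern as the ImageGPT one. For completeness, a hypothetical sketch of the decorator's other branch, where a plain pytest-style test receives the `pytestconfig` fixture so the wrapper can reuse the original invocation arguments (the function and assertion below are made up for illustration):

# Hypothetical pytest-style test (not from this commit) exercising the
# `pytestconfig` branch: because `pytestconfig` arrives via **kwargs, the
# wrapper copies `pytestconfig.invocation_params.args`, swaps the collected
# path for an explicit `<file>::<test>` id, and re-runs it in a subprocess.
from transformers.testing_utils import run_test_using_subprocess


@run_test_using_subprocess
def test_standalone_example(pytestconfig):
    assert 2 + 2 == 4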