Simplify running tests in a subprocess (#34213)

* check

* check

* check

* check

* add docstring

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2024-10-29 10:48:57 +01:00 committed by GitHub
parent a1835195d1
commit 439334c8fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 52 additions and 9 deletions

View File

@ -2366,6 +2366,46 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
test_case.fail(f'{results["error"]}') test_case.fail(f'{results["error"]}')
def run_test_using_subprocess(func):
"""
To decorate a test to run in a subprocess using the `subprocess` module. This could avoid potential GPU memory
issues (GPU OOM or a test that causes many subsequential failing with `CUDA error: device-side assert triggered`).
"""
import pytest
@functools.wraps(func)
def wrapper(*args, **kwargs):
if os.getenv("_INSIDE_SUB_PROCESS", None) == "1":
func(*args, **kwargs)
else:
test = " ".join(os.environ.get("PYTEST_CURRENT_TEST").split(" ")[:-1])
try:
import copy
env = copy.deepcopy(os.environ)
env["_INSIDE_SUB_PROCESS"] = "1"
# If not subclass of `unitTest.TestCase` and `pytestconfig` is used: try to grab and use the arguments
if "pytestconfig" in kwargs:
command = list(kwargs["pytestconfig"].invocation_params.args)
for idx, x in enumerate(command):
if x in kwargs["pytestconfig"].args:
test = test.split("::")[1:]
command[idx] = "::".join([f"{func.__globals__['__file__']}"] + test)
command = [f"{sys.executable}", "-m", "pytest"] + command
command = [x for x in command if x not in ["--no-summary"]]
# Otherwise, simply run the test with no option at all
else:
command = [f"{sys.executable}", "-m", "pytest", f"{test}"]
subprocess.run(command, env=env, check=True, capture_output=True)
except subprocess.CalledProcessError as e:
exception_message = e.stdout.decode()
raise pytest.fail(exception_message, pytrace=False)
return wrapper
""" """
The following contains utils to run the documentation tests without having to overwrite any files. The following contains utils to run the documentation tests without having to overwrite any files.

View File

@ -18,7 +18,7 @@ import inspect
import unittest import unittest
from transformers import ImageGPTConfig from transformers import ImageGPTConfig
from transformers.testing_utils import require_torch, require_vision, slow, torch_device from transformers.testing_utils import require_torch, require_vision, run_test_using_subprocess, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...generation.test_utils import GenerationTesterMixin from ...generation.test_utils import GenerationTesterMixin
@ -257,11 +257,9 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
self.assertEqual(len(scores), length) self.assertEqual(len(scores), length)
self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores)) self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))
@unittest.skip( @run_test_using_subprocess
reason="After #33632, this test still passes, but many subsequential tests fail with `device-side assert triggered`"
)
def test_beam_search_generate_dict_outputs_use_cache(self): def test_beam_search_generate_dict_outputs_use_cache(self):
pass super().test_beam_search_generate_dict_outputs_use_cache()
def setUp(self): def setUp(self):
self.model_tester = ImageGPTModelTester(self) self.model_tester = ImageGPTModelTester(self)

View File

@ -28,7 +28,14 @@ from transformers import (
is_torch_available, is_torch_available,
is_vision_available, is_vision_available,
) )
from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_gpu, slow, torch_device from transformers.testing_utils import (
require_bitsandbytes,
require_torch,
require_torch_gpu,
run_test_using_subprocess,
slow,
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
@ -248,9 +255,7 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
pass pass
@unittest.skip( @run_test_using_subprocess
reason="After #33533, this still passes, but many subsequential tests fail with `device-side assert triggered`"
)
def test_mixed_input(self): def test_mixed_input(self):
config, inputs = self.model_tester.prepare_config_and_inputs_for_common() config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes: for model_class in self.all_model_classes: