diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 2781e9e102e..0eef286732d 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -2366,6 +2366,46 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None): test_case.fail(f'{results["error"]}') +def run_test_using_subprocess(func): + """ + To decorate a test to run in a subprocess using the `subprocess` module. This could avoid potential GPU memory + issues (GPU OOM or a test that causes many subsequential failing with `CUDA error: device-side assert triggered`). + """ + import pytest + + @functools.wraps(func) + def wrapper(*args, **kwargs): + if os.getenv("_INSIDE_SUB_PROCESS", None) == "1": + func(*args, **kwargs) + else: + test = " ".join(os.environ.get("PYTEST_CURRENT_TEST").split(" ")[:-1]) + try: + import copy + + env = copy.deepcopy(os.environ) + env["_INSIDE_SUB_PROCESS"] = "1" + + # If not subclass of `unitTest.TestCase` and `pytestconfig` is used: try to grab and use the arguments + if "pytestconfig" in kwargs: + command = list(kwargs["pytestconfig"].invocation_params.args) + for idx, x in enumerate(command): + if x in kwargs["pytestconfig"].args: + test = test.split("::")[1:] + command[idx] = "::".join([f"{func.__globals__['__file__']}"] + test) + command = [f"{sys.executable}", "-m", "pytest"] + command + command = [x for x in command if x not in ["--no-summary"]] + # Otherwise, simply run the test with no option at all + else: + command = [f"{sys.executable}", "-m", "pytest", f"{test}"] + + subprocess.run(command, env=env, check=True, capture_output=True) + except subprocess.CalledProcessError as e: + exception_message = e.stdout.decode() + raise pytest.fail(exception_message, pytrace=False) + + return wrapper + + """ The following contains utils to run the documentation tests without having to overwrite any files. diff --git a/tests/models/imagegpt/test_modeling_imagegpt.py b/tests/models/imagegpt/test_modeling_imagegpt.py index 07972675528..cdbe815431f 100644 --- a/tests/models/imagegpt/test_modeling_imagegpt.py +++ b/tests/models/imagegpt/test_modeling_imagegpt.py @@ -18,7 +18,7 @@ import inspect import unittest from transformers import ImageGPTConfig -from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.testing_utils import require_torch, require_vision, run_test_using_subprocess, slow, torch_device from transformers.utils import cached_property, is_torch_available, is_vision_available from ...generation.test_utils import GenerationTesterMixin @@ -257,11 +257,9 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM self.assertEqual(len(scores), length) self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores)) - @unittest.skip( - reason="After #33632, this test still passes, but many subsequential tests fail with `device-side assert triggered`" - ) + @run_test_using_subprocess def test_beam_search_generate_dict_outputs_use_cache(self): - pass + super().test_beam_search_generate_dict_outputs_use_cache() def setUp(self): self.model_tester = ImageGPTModelTester(self) diff --git a/tests/models/video_llava/test_modeling_video_llava.py b/tests/models/video_llava/test_modeling_video_llava.py index 1bd01843981..fd4c49f4a69 100644 --- a/tests/models/video_llava/test_modeling_video_llava.py +++ b/tests/models/video_llava/test_modeling_video_llava.py @@ -28,7 +28,14 @@ from transformers import ( is_torch_available, is_vision_available, ) -from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_gpu, slow, torch_device +from transformers.testing_utils import ( + require_bitsandbytes, + require_torch, + require_torch_gpu, + run_test_using_subprocess, + slow, + torch_device, +) from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -248,9 +255,7 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): pass - @unittest.skip( - reason="After #33533, this still passes, but many subsequential tests fail with `device-side assert triggered`" - ) + @run_test_using_subprocess def test_mixed_input(self): config, inputs = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: