diff --git a/tests/models/levit/test_modeling_levit.py b/tests/models/levit/test_modeling_levit.py index 725b279fd02..2b3436f3d05 100644 --- a/tests/models/levit/test_modeling_levit.py +++ b/tests/models/levit/test_modeling_levit.py @@ -20,6 +20,8 @@ import unittest import warnings from math import ceil, floor +from packaging import version + from transformers import LevitConfig from transformers.file_utils import cached_property, is_torch_available, is_vision_available from transformers.models.auto import get_values @@ -335,6 +337,11 @@ class LevitModelTest(ModelTesterMixin, unittest.TestCase): loss.backward() def test_problem_types(self): + + parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version) + if parsed_torch_version_base.base_version.startswith("1.9"): + self.skipTest(reason="This test fails with PyTorch 1.9.x: some CUDA issue") + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() problem_types = [