From c8a2b25f915a7745d57c92635415e2517b739bc8 Mon Sep 17 00:00:00 2001 From: Petr Kuderov Date: Mon, 17 Mar 2025 18:09:46 +0300 Subject: [PATCH] Fix `TrainingArguments.torch_empty_cache_steps` post_init check (#36734) Mistaken use of De Morgan's law. Fixed "not (X or Y)" to correct "not (X and Y)" check to raise a ValueError. Added corresponding test to check "positive int or None" condition. Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com> --- src/transformers/training_args.py | 2 +- tests/test_training_args.py | 25 +++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index d2ec76091f3..a1f172801ad 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1641,7 +1641,7 @@ class TrainingArguments: self.do_eval = True if self.torch_empty_cache_steps is not None: - if not (isinstance(self.torch_empty_cache_steps, int) or self.torch_empty_cache_steps > 0): + if not (isinstance(self.torch_empty_cache_steps, int) and self.torch_empty_cache_steps > 0): raise ValueError( f"`torch_empty_cache_steps` must be an integer bigger than 0, got {self.torch_empty_cache_steps}." ) diff --git a/tests/test_training_args.py b/tests/test_training_args.py index a4da834582e..c207196abc8 100644 --- a/tests/test_training_args.py +++ b/tests/test_training_args.py @@ -40,3 +40,28 @@ class TestTrainingArguments(unittest.TestCase): self.assertFalse(os.path.exists(output_dir)) # Still shouldn't exist # Directory should be created when actually needed (e.g. in Trainer) + + def test_torch_empty_cache_steps_requirements(self): + """Test that torch_empty_cache_steps is a positive integer or None.""" + + # None is acceptable (feature is disabled): + args = TrainingArguments(torch_empty_cache_steps=None) + self.assertIsNone(args.torch_empty_cache_steps) + + # non-int is unacceptable: + with self.assertRaises(ValueError): + TrainingArguments(torch_empty_cache_steps=1.0) + with self.assertRaises(ValueError): + TrainingArguments(torch_empty_cache_steps="none") + + # negative int is unacceptable: + with self.assertRaises(ValueError): + TrainingArguments(torch_empty_cache_steps=-1) + + # zero is unacceptable: + with self.assertRaises(ValueError): + TrainingArguments(torch_empty_cache_steps=0) + + # positive int is acceptable: + args = TrainingArguments(torch_empty_cache_steps=1) + self.assertEqual(args.torch_empty_cache_steps, 1)