Fix TrainingArguments.torch_empty_cache_steps post_init check (#36734)

Mistaken application of De Morgan's law: changed the incorrect
"not (X or Y)" check to the correct "not (X and Y)" so that a
ValueError is raised for invalid values.

Added corresponding test to check "positive int or None" condition.

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
Petr Kuderov 2025-03-17 18:09:46 +03:00 committed by GitHub
parent 8e67230860
commit c8a2b25f91
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 26 additions and 1 deletions

View File

@@ -1641,7 +1641,7 @@ class TrainingArguments:
self.do_eval = True
if self.torch_empty_cache_steps is not None:
    # `torch_empty_cache_steps` is either None (feature disabled) or a
    # positive integer giving the step interval for emptying the cache.
    # NOTE: the condition must use `and`, not `or`. With `or`, any int
    # (even 0 or negative) would pass, and a non-int would fall through
    # to the `> 0` comparison, which can raise TypeError instead of the
    # intended ValueError.
    if not (isinstance(self.torch_empty_cache_steps, int) and self.torch_empty_cache_steps > 0):
        raise ValueError(
            f"`torch_empty_cache_steps` must be an integer bigger than 0, got {self.torch_empty_cache_steps}."
        )

View File

@@ -40,3 +40,28 @@ class TestTrainingArguments(unittest.TestCase):
self.assertFalse(os.path.exists(output_dir)) # Still shouldn't exist
# Directory should be created when actually needed (e.g. in Trainer)
def test_torch_empty_cache_steps_requirements(self):
    """torch_empty_cache_steps must be None or a positive integer."""
    # None is acceptable: it simply disables the feature.
    args = TrainingArguments(torch_empty_cache_steps=None)
    self.assertIsNone(args.torch_empty_cache_steps)

    # Everything that is not a positive int must be rejected:
    # non-int (float, str), negative int, and zero.
    for invalid in (1.0, "none", -1, 0):
        with self.assertRaises(ValueError):
            TrainingArguments(torch_empty_cache_steps=invalid)

    # A positive int is accepted and stored unchanged.
    args = TrainingArguments(torch_empty_cache_steps=1)
    self.assertEqual(args.torch_empty_cache_steps, 1)