diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index d2ec76091f3..a1f172801ad 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1641,7 +1641,7 @@ class TrainingArguments: self.do_eval = True if self.torch_empty_cache_steps is not None: - if not (isinstance(self.torch_empty_cache_steps, int) or self.torch_empty_cache_steps > 0): + if not (isinstance(self.torch_empty_cache_steps, int) and self.torch_empty_cache_steps > 0): raise ValueError( f"`torch_empty_cache_steps` must be an integer bigger than 0, got {self.torch_empty_cache_steps}." ) diff --git a/tests/test_training_args.py b/tests/test_training_args.py index a4da834582e..c207196abc8 100644 --- a/tests/test_training_args.py +++ b/tests/test_training_args.py @@ -40,3 +40,28 @@ class TestTrainingArguments(unittest.TestCase): self.assertFalse(os.path.exists(output_dir)) # Still shouldn't exist # Directory should be created when actually needed (e.g. in Trainer) + + def test_torch_empty_cache_steps_requirements(self): + """Test that torch_empty_cache_steps is a positive integer or None.""" + + # None is acceptable (feature is disabled): + args = TrainingArguments(torch_empty_cache_steps=None) + self.assertIsNone(args.torch_empty_cache_steps) + + # non-int is unacceptable: + with self.assertRaises(ValueError): + TrainingArguments(torch_empty_cache_steps=1.0) + with self.assertRaises(ValueError): + TrainingArguments(torch_empty_cache_steps="none") + + # negative int is unacceptable: + with self.assertRaises(ValueError): + TrainingArguments(torch_empty_cache_steps=-1) + + # zero is unacceptable: + with self.assertRaises(ValueError): + TrainingArguments(torch_empty_cache_steps=0) + + # positive int is acceptable: + args = TrainingArguments(torch_empty_cache_steps=1) + self.assertEqual(args.torch_empty_cache_steps, 1)