mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
Fix TrainingArguments.torch_empty_cache_steps
post_init check (#36734)
Mistaken use of De Morgan's law. Fixed "not (X or Y)" to correct "not (X and Y)" check to raise a ValueError. Added corresponding test to check "positive int or None" condition. Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
parent
8e67230860
commit
c8a2b25f91
@ -1641,7 +1641,7 @@ class TrainingArguments:
|
|||||||
self.do_eval = True
|
self.do_eval = True
|
||||||
|
|
||||||
if self.torch_empty_cache_steps is not None:
|
if self.torch_empty_cache_steps is not None:
|
||||||
if not (isinstance(self.torch_empty_cache_steps, int) or self.torch_empty_cache_steps > 0):
|
if not (isinstance(self.torch_empty_cache_steps, int) and self.torch_empty_cache_steps > 0):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`torch_empty_cache_steps` must be an integer bigger than 0, got {self.torch_empty_cache_steps}."
|
f"`torch_empty_cache_steps` must be an integer bigger than 0, got {self.torch_empty_cache_steps}."
|
||||||
)
|
)
|
||||||
|
@ -40,3 +40,28 @@ class TestTrainingArguments(unittest.TestCase):
|
|||||||
self.assertFalse(os.path.exists(output_dir)) # Still shouldn't exist
|
self.assertFalse(os.path.exists(output_dir)) # Still shouldn't exist
|
||||||
|
|
||||||
# Directory should be created when actually needed (e.g. in Trainer)
|
# Directory should be created when actually needed (e.g. in Trainer)
|
||||||
|
|
||||||
|
def test_torch_empty_cache_steps_requirements(self):
|
||||||
|
"""Test that torch_empty_cache_steps is a positive integer or None."""
|
||||||
|
|
||||||
|
# None is acceptable (feature is disabled):
|
||||||
|
args = TrainingArguments(torch_empty_cache_steps=None)
|
||||||
|
self.assertIsNone(args.torch_empty_cache_steps)
|
||||||
|
|
||||||
|
# non-int is unacceptable:
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
TrainingArguments(torch_empty_cache_steps=1.0)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
TrainingArguments(torch_empty_cache_steps="none")
|
||||||
|
|
||||||
|
# negative int is unacceptable:
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
TrainingArguments(torch_empty_cache_steps=-1)
|
||||||
|
|
||||||
|
# zero is unacceptable:
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
TrainingArguments(torch_empty_cache_steps=0)
|
||||||
|
|
||||||
|
# positive int is acceptable:
|
||||||
|
args = TrainingArguments(torch_empty_cache_steps=1)
|
||||||
|
self.assertEqual(args.torch_empty_cache_steps, 1)
|
||||||
|
Loading…
Reference in New Issue
Block a user