Mirror of https://github.com/huggingface/transformers.git

Fixed a mistaken use of De Morgan's law: the "not (X or Y)" check is now the correct "not (X and Y)", so a ValueError is raised as intended. Added a corresponding test for the "positive int or None" condition. Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
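For context, here is a minimal sketch of what the corrected check in TrainingArguments looks like once De Morgan's law is applied correctly. The helper name and the error-message wording are illustrative assumptions; only the corrected boolean condition is taken from the commit message above.

def _validate_torch_empty_cache_steps(torch_empty_cache_steps):
    """Hedged sketch of the corrected check; not a verbatim copy of the library code."""
    if torch_empty_cache_steps is None:
        return  # None simply disables the feature, so it is accepted
    # De Morgan: not (A and B) == (not A) or (not B), so anything that is not an
    # int, or not strictly positive, is rejected. The buggy "not (A or B)" form
    # let values such as 1.0 or -1 slip through, because each satisfies one of
    # the two conditions.
    if not (isinstance(torch_empty_cache_steps, int) and torch_empty_cache_steps > 0):
        raise ValueError(
            f"`torch_empty_cache_steps` must be a positive integer or None, got {torch_empty_cache_steps!r}."
        )

With this check in place, a call like TrainingArguments(torch_empty_cache_steps=1.0) raises ValueError, which is exactly what the new test below asserts.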
68 lines · 2.6 KiB · Python
import os
import tempfile
import unittest

from transformers import TrainingArguments


class TestTrainingArguments(unittest.TestCase):
    def test_default_output_dir(self):
        """Test that output_dir defaults to 'trainer_output' when not specified."""
        args = TrainingArguments(output_dir=None)
        self.assertEqual(args.output_dir, "trainer_output")

    def test_custom_output_dir(self):
        """Test that output_dir is respected when specified."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            args = TrainingArguments(output_dir=tmp_dir)
            self.assertEqual(args.output_dir, tmp_dir)

    def test_output_dir_creation(self):
        """Test that output_dir is created only when needed."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            output_dir = os.path.join(tmp_dir, "test_output")

            # Directory should not exist before creating args
            self.assertFalse(os.path.exists(output_dir))

            # Create args with save_strategy="no" - should not create directory
            args = TrainingArguments(
                output_dir=output_dir,
                do_train=True,
                save_strategy="no",
                report_to=None,
            )
            self.assertFalse(os.path.exists(output_dir))

            # Now set save_strategy="steps" - should create directory when needed
            args.save_strategy = "steps"
            args.save_steps = 1
            self.assertFalse(os.path.exists(output_dir))  # Still shouldn't exist

            # Directory should be created when actually needed (e.g. in Trainer)

    def test_torch_empty_cache_steps_requirements(self):
        """Test that torch_empty_cache_steps is a positive integer or None."""
        # None is acceptable (feature is disabled):
        args = TrainingArguments(torch_empty_cache_steps=None)
        self.assertIsNone(args.torch_empty_cache_steps)

        # non-int is unacceptable:
        with self.assertRaises(ValueError):
            TrainingArguments(torch_empty_cache_steps=1.0)
        with self.assertRaises(ValueError):
            TrainingArguments(torch_empty_cache_steps="none")

        # negative int is unacceptable:
        with self.assertRaises(ValueError):
            TrainingArguments(torch_empty_cache_steps=-1)

        # zero is unacceptable:
        with self.assertRaises(ValueError):
            TrainingArguments(torch_empty_cache_steps=0)

        # positive int is acceptable:
        args = TrainingArguments(torch_empty_cache_steps=1)
        self.assertEqual(args.torch_empty_cache_steps, 1)