Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 03:01:07 +06:00)
Avoid looping when data exhausted (#14413)
* Stop training when a finite IterableDataset is exhausted.
  When using an iterable dataset, num_epochs is set to sys.maxsize to make sure all data is consumed; likewise we want to set max_steps high enough, but still stop when all data is consumed.
  (cherry picked from commit 6f0e1d6363153da9051e93acffe1cbab3a3f3b12)
* Fix typo: flase -> false.
* Add test for stopping training on an exhausted finite iterable dataset.
* Remove redundant gradient_accumulation_steps.
* Run make style; reformat the training_args docstring.
This commit is contained in:
parent 3e8d17e66d
commit a33168aa78
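In short, the patch makes the training loop notice when an epoch yields no samples at all and stop training instead of spinning until max_steps is reached. A minimal standalone sketch of that detection pattern, assuming an illustrative run_epoch helper and stop_training callback (not part of the Trainer API; the real change lives inside Trainer.train, as the hunks below show):

    def run_epoch(epoch_iterator, stop_training):
        # step stays at -1 when the iterator yields nothing, e.g. a finite
        # IterableDataset that was fully consumed in an earlier epoch.
        step = -1
        for step, inputs in enumerate(epoch_iterator):
            ...  # forward/backward pass, optimizer step, logging, etc.
        if step < 0:
            # No sample this epoch: tell the outer loop to stop instead of
            # looping until max_steps is reached.
            stop_training()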
@@ -1287,6 +1287,7 @@ class Trainer:
             )
             self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
 
+            step = -1
             for step, inputs in enumerate(epoch_iterator):
 
                 # Skip past any already trained steps if resuming training
@@ -1386,6 +1387,13 @@ class Trainer:
 
                 if self.control.should_epoch_stop or self.control.should_training_stop:
                     break
+            if step < 0:
+                logger.warning(
+                    f"There seems to be not a single sample in your epoch_iterator, stopping training at step"
+                    f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
+                    f" num_steps ({max_steps}) higher than the number of available samples."
+                )
+                self.control.should_training_stop = True
 
             self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
             self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
@@ -141,7 +141,8 @@ class TrainingArguments:
             the last epoch before stopping training).
         max_steps (:obj:`int`, `optional`, defaults to -1):
             If set to a positive number, the total number of training steps to perform. Overrides
-            :obj:`num_train_epochs`.
+            :obj:`num_train_epochs`. In case of using a finite iterable dataset the training may stop before reaching
+            the set number of steps when all data is exhausted
         lr_scheduler_type (:obj:`str` or :class:`~transformers.SchedulerType`, `optional`, defaults to :obj:`"linear"`):
             The scheduler type to use. See the documentation of :class:`~transformers.SchedulerType` for all possible
             values.
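To illustrate the documented behaviour, a short usage sketch built from this commit's test setup (FiniteIterableDataset is added in the test hunks below; RegressionModelConfig and RegressionPreTrainedModel are existing helpers from the same test suite; the "." output directory is just the test's placeholder). With only 10 samples available, training stops after 10 steps and logs the new warning even though max_steps is 11:

    model = RegressionPreTrainedModel(RegressionModelConfig())
    data = FiniteIterableDataset(length=10)      # only 10 samples to yield
    args = TrainingArguments(
        ".",
        max_steps=11,                            # more steps than the data can provide
        per_device_train_batch_size=1,
    )
    Trainer(model, args=args, train_dataset=data).train()
    # logs "... stopping training at step 10! ..." and stops early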
@@ -172,6 +172,16 @@ if is_torch_available():
             for i in range(len(self.dataset)):
                 yield self.dataset[i]
 
+    class FiniteIterableDataset(SampleIterableDataset):
+        def __init__(self, a=2, b=3, length=64, seed=42, label_names=None):
+            super().__init__(a, b, length, seed, label_names)
+            self.current_sample = 0
+
+        def __iter__(self):
+            while self.current_sample < len(self.dataset):
+                yield self.dataset[self.current_sample]
+                self.current_sample += 1
+
     class RegressionModel(nn.Module):
         def __init__(self, a=0, b=0, double_output=False):
             super().__init__()
@@ -856,7 +866,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertAlmostEqual(b, b1, delta=1e-8)
 
     # regression for this issue: https://github.com/huggingface/transformers/issues/12970
-    def test_training_with_resume_from_checkpoint_flase(self):
+    def test_training_with_resume_from_checkpoint_false(self):
         train_dataset = RegressionDataset(length=128)
         eval_dataset = RegressionDataset()
 
@@ -1058,6 +1068,26 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertIsInstance(loader, torch.utils.data.DataLoader)
         self.assertIsInstance(loader.sampler, torch.utils.data.dataloader._InfiniteConstantSampler)
 
+    def test_training_finite_iterable_dataset(self):
+        config = RegressionModelConfig()
+        model = RegressionPreTrainedModel(config)
+
+        batch_size = 1
+        num_samples = 10
+
+        available_steps = num_samples // batch_size
+
+        data = FiniteIterableDataset(length=num_samples)
+        train_args = TrainingArguments(
+            ".",
+            max_steps=available_steps + 1,  # set a higher number than actually available
+            per_device_train_batch_size=batch_size,
+        )
+        trainer = Trainer(model, train_dataset=data, args=train_args)
+        with self.assertLogs("transformers.trainer", level="WARNING") as logs:
+            trainer.train()
+        self.assertIn(f"stopping training at step {available_steps}!", logs.output[0])
+
     def test_evaluation_iterable_dataset(self):
         config = RegressionModelConfig(a=1.5, b=2.5)
         model = RegressionPreTrainedModel(config)