Trainer with Iterable Dataset (#7858)

* fix #5990

* accommodate an iterable dataset without a predefined length
* single supported use case: provide max_steps, and NO num_train_epochs
* is a merge of master and PR #5995

* fix trainer test under TF

* fix only for torch
* TF trainer untouched
* trainer tests are skipped when no torch

* address comments

* fix quality checks

* remove torch.dataset from test_trainer

* unnecessary inheritance
* RegressionDataset implements all needed methods __len__ and __getitem__

* fix quality checks

* restore RegressionDataset

* was wrongly under is_torch_available()
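
In practice, the single supported pattern is: pass a train_dataset without __len__ (for example a
streaming torch.utils.data.IterableDataset) and set max_steps in TrainingArguments. A minimal
sketch of that use case (the StreamingDataset class, the dummy tensors and the output_dir are
illustrative and not part of this PR):

    import torch
    from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

    class StreamingDataset(torch.utils.data.IterableDataset):
        # Hypothetical endless stream of pre-tokenized examples; deliberately no __len__.
        def __iter__(self):
            while True:
                yield {
                    "input_ids": torch.randint(0, 100, (16,)),
                    "attention_mask": torch.ones(16, dtype=torch.long),
                    "labels": torch.tensor(0),
                }

    model = AutoModelForSequenceClassification.from_pretrained("sshleifer/tiny-distilbert-base-cased")
    args = TrainingArguments(output_dir="./out", max_steps=100)  # max_steps is required for this dataset
    trainer = Trainer(model=model, args=args, train_dataset=StreamingDataset())
    trainer.train()
    # Omitting max_steps with such a dataset now raises:
    # ValueError: train_dataset does not implement __len__, max_steps has to be specified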
Julien Rossi 2020-10-19 17:57:39 +02:00 committed by GitHub
parent 2422cda01b
commit a09fe140c1
2 changed files with 123 additions and 42 deletions

src/transformers/trainer.py

@@ -16,7 +16,9 @@
 The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
 """
+import collections
 import inspect
+import math
 import os
 import re
 import shutil
@@ -283,6 +285,15 @@ class Trainer:
                 FutureWarning,
             )
+        if args.max_steps > 0:
+            logger.info("max_steps is given, it will override any value given in num_train_epochs")
+
+        # Enforce rules on using datasets with no __len__
+        if train_dataset is not None and not isinstance(train_dataset, collections.abc.Sized) and args.max_steps <= 0:
+            raise ValueError("train_dataset does not implement __len__, max_steps has to be specified")
+        if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
+            raise ValueError("eval_dataset must implement __len__")
+
         if is_datasets_available():
             if isinstance(train_dataset, datasets.Dataset):
                 self._remove_unused_columns(self.train_dataset, description="training")
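
As a side note, the collections.abc.Sized check used above is True exactly when the dataset class
defines __len__, so an IterableDataset that does implement __len__ keeps the previous behavior. A
quick illustration with hypothetical classes (not part of this diff):

    import collections.abc
    from torch.utils.data import IterableDataset

    class NoLen(IterableDataset):
        def __iter__(self):
            yield from range(3)

    class WithLen(IterableDataset):
        def __iter__(self):
            yield from range(3)

        def __len__(self):
            return 3

    isinstance(NoLen(), collections.abc.Sized)    # False -> max_steps must be given
    isinstance(WithLen(), collections.abc.Sized)  # True  -> treated like a sized dataset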
@@ -361,7 +372,7 @@
         dataset.set_format(type=dataset.format["type"], columns=columns)
 
     def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
-        if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
+        if not isinstance(self.train_dataset, collections.abc.Sized):
             return None
         elif is_torch_tpu_available():
             return get_tpu_sampler(self.train_dataset)
@@ -376,7 +387,7 @@
         """
         Returns the training :class:`~torch.utils.data.DataLoader`.
 
-        Will use no sampler if :obj:`self.train_dataset` is a :obj:`torch.utils.data.IterableDataset`, a random sampler
+        Will use no sampler if :obj:`self.train_dataset` does not implement :obj:`__len__`, a random sampler
         (adapted to distributed training if necessary) otherwise.
 
         Subclass and override this method if you want to inject some custom behavior.
@@ -395,9 +406,7 @@
         )
 
     def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
-        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
-            return None
-        elif is_torch_tpu_available():
+        if is_torch_tpu_available():
             return SequentialDistributedSampler(eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
         elif self.args.local_rank != -1:
             return SequentialDistributedSampler(eval_dataset)
@@ -408,19 +417,18 @@
         """
         Returns the evaluation :class:`~torch.utils.data.DataLoader`.
 
-        Will use no sampler if :obj:`self.eval_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
-        sampler (adapted to distributed training if necessary) otherwise.
-
         Subclass and override this method if you want to inject some custom behavior.
 
         Args:
             eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
                 If provided, will override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, columns not
-                accepted by the ``model.forward()`` method are automatically removed.
+                accepted by the ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
         """
         if eval_dataset is None and self.eval_dataset is None:
             raise ValueError("Trainer: evaluation requires an eval_dataset.")
-        elif eval_dataset is not None and is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
+        elif eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
+            raise ValueError("eval_dataset must implement __len__")
+        elif is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
             self._remove_unused_columns(eval_dataset, description="evaluation")
         eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
         eval_sampler = self._get_eval_sampler(eval_dataset)
@@ -438,17 +446,16 @@
         """
         Returns the test :class:`~torch.utils.data.DataLoader`.
 
-        Will use no sampler if :obj:`test_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
-        sampler (adapted to distributed training if necessary) otherwise.
-
         Subclass and override this method if you want to inject some custom behavior.
 
         Args:
-            eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
+            test_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
                 The test dataset to use. If it is an :obj:`datasets.Dataset`, columns not accepted by the
-                ``model.forward()`` method are automatically removed.
+                ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
         """
-        if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
+        if not isinstance(test_dataset, collections.abc.Sized):
+            raise ValueError("test_dataset must implement __len__")
+        elif is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
             self._remove_unused_columns(test_dataset, description="test")
 
         test_sampler = self._get_eval_sampler(test_dataset)
@@ -494,6 +501,8 @@
     def num_examples(self, dataloader: DataLoader) -> int:
         """
         Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.
+
+        Will raise an exception if the underlying dataset does not implement method :obj:`__len__`
         """
         return len(dataloader.dataset)
@@ -579,19 +588,32 @@
         # Reinitializes optimizer and scheduler
         self.optimizer, self.lr_scheduler = None, None
 
+        # Keeping track whether we can call len() on the dataset or not
+        train_dataset_is_sized = isinstance(self.train_dataset, collections.abc.Sized)
+
         # Data loader and number of training steps
         train_dataloader = self.get_train_dataloader()
-        num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
-        num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
-        if self.args.max_steps > 0:
-            max_steps = self.args.max_steps
-            num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
-                self.args.max_steps % num_update_steps_per_epoch > 0
-            )
+
+        # Setting up training control variables:
+        # number of training epochs: num_train_epochs
+        # number of training steps per epoch: num_update_steps_per_epoch
+        # total number of training steps to execute: max_steps
+        if train_dataset_is_sized:
+            num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
+            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
+            if self.args.max_steps > 0:
+                max_steps = self.args.max_steps
+                num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
+                    self.args.max_steps % num_update_steps_per_epoch > 0
+                )
+            else:
+                max_steps = math.ceil(self.args.num_train_epochs * num_update_steps_per_epoch)
+                num_train_epochs = math.ceil(self.args.num_train_epochs)
         else:
-            max_steps = int(num_update_steps_per_epoch * self.args.num_train_epochs)
-            num_train_epochs = self.args.num_train_epochs
-        num_train_epochs = int(np.ceil(num_train_epochs))
+            # see __init__. max_steps is set when the dataset has no __len__
+            max_steps = self.args.max_steps
+            num_train_epochs = 1
+            num_update_steps_per_epoch = max_steps
 
         self.create_optimizer_and_scheduler(num_training_steps=max_steps)
         self.state = TrainerState()
@@ -645,8 +667,15 @@
             * self.args.gradient_accumulation_steps
             * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
         )
+
+        num_examples = (
+            self.num_examples(train_dataloader)
+            if train_dataset_is_sized
+            else total_train_batch_size * self.args.max_steps
+        )
+
         logger.info("***** Running training *****")
-        logger.info(" Num examples = %d", self.num_examples(train_dataloader))
+        logger.info(" Num examples = %d", num_examples)
         logger.info(" Num Epochs = %d", num_train_epochs)
         logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
         logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
@@ -703,6 +732,7 @@
             if self.args.past_index >= 0:
                 self._past = None
 
+            steps_in_epoch = len(epoch_iterator) if train_dataset_is_sized else self.args.max_steps
             self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control)
 
             for step, inputs in enumerate(epoch_iterator):
@@ -728,8 +758,8 @@
 
                 if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                     # last step in epoch but step is always smaller than gradient_accumulation_steps
-                    len(epoch_iterator) <= self.args.gradient_accumulation_steps
-                    and (step + 1) == len(epoch_iterator)
+                    steps_in_epoch <= self.args.gradient_accumulation_steps
+                    and (step + 1) == steps_in_epoch
                 ):
                     if self.args.fp16 and _use_native_amp:
                         self.scaler.unscale_(self.optimizer)
@@ -750,7 +780,7 @@
                     self.lr_scheduler.step()
                     model.zero_grad()
                     self.state.global_step += 1
-                    self.state.epoch = epoch + (step + 1) / len(epoch_iterator)
+                    self.state.epoch = epoch + (step + 1) / steps_in_epoch
                     self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)
 
                     self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
@@ -1207,11 +1237,15 @@
         Args:
             eval_dataset (:obj:`Dataset`, `optional`):
                 Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`,
-                columns not accepted by the ``model.forward()`` method are automatically removed.
+                columns not accepted by the ``model.forward()`` method are automatically removed. It must implement
+                the :obj:`__len__` method.
 
         Returns:
             A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
         """
+        if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
+            raise ValueError("eval_dataset must implement __len__")
+
         eval_dataloader = self.get_eval_dataloader(eval_dataset)
 
         output = self.prediction_loop(eval_dataloader, description="Evaluation")
@@ -1234,7 +1268,7 @@
         Args:
             test_dataset (:obj:`Dataset`):
                 Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the
-                ``model.forward()`` method are automatically removed.
+                ``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__`
 
         Returns:
             `NamedTuple`:
@@ -1245,6 +1279,9 @@
             metrics (:obj:`Dict[str, float]`, `optional`):
                 The potential dictionary of metrics (if the dataset contained labels).
         """
+        if test_dataset is not None and not isinstance(test_dataset, collections.abc.Sized):
+            raise ValueError("test_dataset must implement __len__")
+
         test_dataloader = self.get_test_dataloader(test_dataset)
 
         return self.prediction_loop(test_dataloader, description="Prediction")
@@ -1264,6 +1301,8 @@
             )
             return self._prediction_loop(dataloader, description, prediction_loss_only=prediction_loss_only)
 
+        if not isinstance(dataloader.dataset, collections.abc.Sized):
+            raise ValueError("dataset must implement __len__")
         prediction_loss_only = (
             prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
         )
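
For reference, the training-control logic introduced in train() boils down to the following
standalone sketch (a simplification with an ad-hoc helper name; args stands for the
TrainingArguments instance and dataloader_len for len(train_dataloader)):

    import math
    from collections.abc import Sized

    def compute_training_control(train_dataset, dataloader_len, args):
        # Mirrors the diff above; not the actual Trainer method.
        if isinstance(train_dataset, Sized):
            # The dataset has __len__, so epochs are well defined.
            steps_per_epoch = max(dataloader_len // args.gradient_accumulation_steps, 1)
            if args.max_steps > 0:
                max_steps = args.max_steps
                num_train_epochs = args.max_steps // steps_per_epoch + int(args.max_steps % steps_per_epoch > 0)
            else:
                max_steps = math.ceil(args.num_train_epochs * steps_per_epoch)
                num_train_epochs = math.ceil(args.num_train_epochs)
        else:
            # No __len__: __init__ already enforced args.max_steps > 0, so train in a single "epoch".
            max_steps = args.max_steps
            num_train_epochs = 1
            steps_per_epoch = max_steps
        return max_steps, num_train_epochs, steps_per_epoch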

tests/test_trainer.py (62 changed lines; Executable file → Normal file)

@@ -31,11 +31,14 @@ if is_torch_available():
     from torch.utils.data import IterableDataset
 
     from transformers import (
+        AutoModelForMaskedLM,
         AutoModelForSequenceClassification,
+        DataCollatorForLanguageModeling,
         GlueDataset,
         GlueDataTrainingArguments,
         LineByLineTextDataset,
         PreTrainedModel,
+        TextDataset,
         Trainer,
         TrainerState,
     )
@@ -83,15 +86,16 @@ class RegressionModelConfig(PretrainedConfig):
 if is_torch_available():
 
     class SampleIterableDataset(IterableDataset):
-        def __init__(self, file_path):
-            self.file_path = file_path
+        """
+        The criterion is not whether the dataset is an IterableDataset, but whether it implements __len__.
+        """
 
-        def parse_file(self):
-            f = open(self.file_path, "r")
-            return f.readlines()
+        def __init__(self, file_path, tokenizer):
+            self.ds = TextDataset(file_path=file_path, tokenizer=tokenizer, block_size=64)
 
         def __iter__(self):
-            return iter(self.parse_file())
+            for i in range(len(self.ds)):
+                yield self.ds[i]
 
     class RegressionModel(torch.nn.Module):
         def __init__(self, a=0, b=0, double_output=False):
@@ -540,13 +544,51 @@ class TrainerIntegrationTest(unittest.TestCase):
         self.assertEqual(len(dataset), 31)
 
     def test_trainer_iterable_dataset(self):
+        # Simulate Language Modeling with an IterableDataset, with no __len__ method
+        # Pick-up a tiny model, so it works on CPU
+        # See Issue #5990: https://github.com/huggingface/transformers/issues/5990
         MODEL_ID = "sshleifer/tiny-distilbert-base-cased"
-        model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
-        train_dataset = SampleIterableDataset(PATH_SAMPLE_TEXT)
-        training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
-        trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
+        model = AutoModelForMaskedLM.from_pretrained(MODEL_ID)
+        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+        train_dataset = SampleIterableDataset(file_path=PATH_SAMPLE_TEXT, tokenizer=tokenizer)
+        training_args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
+        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
+
+        training_args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
+        trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=data_collator)
+        trainer.train()
 
         loader = trainer.get_train_dataloader()
         self.assertIsInstance(loader, torch.utils.data.DataLoader)
+        self.assertIsInstance(loader.sampler, torch.utils.data.dataloader._InfiniteConstantSampler)
+
+        # Exception if giving iterable dataset and no max_steps
+        with self.assertRaises(ValueError):
+            training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
+            _ = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=data_collator)
+
+        # Exception if eval_dataset is iterable in __init__
+        with self.assertRaises(ValueError):
+            training_args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
+            _ = Trainer(
+                model=model,
+                args=training_args,
+                train_dataset=train_dataset,
+                eval_dataset=train_dataset,
+                data_collator=data_collator,
+            )
+
+        # Exception if predicting with iterable dataset
+        with self.assertRaises(ValueError):
+            training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
+            trainer = Trainer(model=model, args=training_args, data_collator=data_collator)
+            trainer.predict(train_dataset)
+
+        # Exception if evaluating with iterable dataset
+        with self.assertRaises(ValueError):
+            training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
+            trainer = Trainer(model=model, args=training_args, data_collator=data_collator)
+            trainer.evaluate(train_dataset)
 
     def test_num_train_epochs_in_training(self):
         # len(train_dl) < gradient_accumulation_steps shouldn't give ``ZeroDivisionError`` when ``max_steps`` is given.