Mirror of https://github.com/huggingface/transformers.git
Trainer with Iterable Dataset (#7858)
* Fix #5990
* Accommodate an iterable dataset without a predefined length
* Support it as a single use case: provide max_steps, and no num_train_epochs
* Is a merge of master and PR #5995
* Fix the trainer test under TF
* Fix only for torch; the TF trainer is untouched
* Trainer tests are skipped when torch is not available
* Address comments
* Fix quality checks
* Remove torch.dataset from test_trainer
* Remove unnecessary inheritance
* RegressionDataset implements the needed methods __len__ and __getitem__
* Fix quality checks
* Restore RegressionDataset; it was wrongly under is_torch_available()
This commit is contained in:
parent 2422cda01b
commit a09fe140c1
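The use case this PR enables, sketched below in rough form: a streaming dataset with no __len__ is passed as train_dataset and trained for a fixed number of steps. The StreamingTextDataset class and the sample texts are illustrative placeholders, not part of the commit; the model ID, collator, and max_steps mirror the new test, and the sketch assumes a reasonably recent version of transformers.

# Minimal sketch (illustrative names, not from the commit): train on a dataset
# that has no predefined length by giving the Trainer an explicit max_steps.
from torch.utils.data import IterableDataset

from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)


class StreamingTextDataset(IterableDataset):
    """Yields tokenized examples one by one and deliberately defines no __len__."""

    def __init__(self, texts, tokenizer):
        self.texts = texts
        self.tokenizer = tokenizer

    def __iter__(self):
        for text in self.texts:
            yield self.tokenizer(text, truncation=True, max_length=64)


tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-distilbert-base-cased")
model = AutoModelForMaskedLM.from_pretrained("sshleifer/tiny-distilbert-base-cased")
train_dataset = StreamingTextDataset(["some text", "more text"] * 8, tokenizer)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

# max_steps is required because the dataset cannot report its length
training_args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=data_collator)
trainer.train()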
src/transformers/trainer.py

@@ -16,7 +16,9 @@
 The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
 """
 
+import collections
 import inspect
+import math
 import os
 import re
 import shutil
@@ -283,6 +285,15 @@ class Trainer:
                 FutureWarning,
             )
 
+        if args.max_steps > 0:
+            logger.info("max_steps is given, it will override any value given in num_train_epochs")
+
+        # Enforce rules on using datasets with no __len__
+        if train_dataset is not None and not isinstance(train_dataset, collections.abc.Sized) and args.max_steps <= 0:
+            raise ValueError("train_dataset does not implement __len__, max_steps has to be specified")
+        if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
+            raise ValueError("eval_dataset must implement __len__")
+
         if is_datasets_available():
             if isinstance(train_dataset, datasets.Dataset):
                 self._remove_unused_columns(self.train_dataset, description="training")
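For reference, the check above relies on collections.abc.Sized, which is satisfied by any object that defines __len__. A quick illustration with two hypothetical dataset classes (not from the commit):

import collections.abc

from torch.utils.data import Dataset, IterableDataset


class MapStyleDataset(Dataset):
    # defines __len__, so it counts as Sized
    def __len__(self):
        return 10

    def __getitem__(self, i):
        return {"x": i}


class StreamingDataset(IterableDataset):
    # no __len__, so it is not Sized and requires max_steps
    def __iter__(self):
        return iter({"x": i} for i in range(10))


assert isinstance(MapStyleDataset(), collections.abc.Sized)
assert not isinstance(StreamingDataset(), collections.abc.Sized)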
@@ -361,7 +372,7 @@ class Trainer:
         dataset.set_format(type=dataset.format["type"], columns=columns)
 
     def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
-        if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
+        if not isinstance(self.train_dataset, collections.abc.Sized):
             return None
         elif is_torch_tpu_available():
             return get_tpu_sampler(self.train_dataset)
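Returning None here matters because PyTorch's DataLoader, when given an IterableDataset and no sampler, falls back to its internal _InfiniteConstantSampler (the new test below asserts exactly this). A small check, with a throwaway streaming dataset:

import torch
from torch.utils.data import DataLoader, IterableDataset


class StreamingDataset(IterableDataset):
    def __iter__(self):
        return iter(range(8))


loader = DataLoader(StreamingDataset(), batch_size=4, sampler=None)
print(type(loader.sampler))  # <class 'torch.utils.data.dataloader._InfiniteConstantSampler'>
for batch in loader:
    print(batch)  # tensor([0, 1, 2, 3]) then tensor([4, 5, 6, 7])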
@@ -376,7 +387,7 @@ class Trainer:
         """
         Returns the training :class:`~torch.utils.data.DataLoader`.
 
-        Will use no sampler if :obj:`self.train_dataset` is a :obj:`torch.utils.data.IterableDataset`, a random sampler
+        Will use no sampler if :obj:`self.train_dataset` does not implement :obj:`__len__`, a random sampler
         (adapted to distributed training if necessary) otherwise.
 
         Subclass and override this method if you want to inject some custom behavior.
@@ -395,9 +406,7 @@ class Trainer:
         )
 
     def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
-        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
-            return None
-        elif is_torch_tpu_available():
+        if is_torch_tpu_available():
             return SequentialDistributedSampler(eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
         elif self.args.local_rank != -1:
             return SequentialDistributedSampler(eval_dataset)
@@ -408,19 +417,18 @@ class Trainer:
         """
         Returns the evaluation :class:`~torch.utils.data.DataLoader`.
 
-        Will use no sampler if :obj:`self.eval_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
-        sampler (adapted to distributed training if necessary) otherwise.
-
         Subclass and override this method if you want to inject some custom behavior.
 
         Args:
             eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
                 If provided, will override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, columns not
-                accepted by the ``model.forward()`` method are automatically removed.
+                accepted by the ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
         """
         if eval_dataset is None and self.eval_dataset is None:
             raise ValueError("Trainer: evaluation requires an eval_dataset.")
-        elif eval_dataset is not None and is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
+        elif eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
+            raise ValueError("eval_dataset must implement __len__")
+        elif is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
             self._remove_unused_columns(eval_dataset, description="evaluation")
         eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
         eval_sampler = self._get_eval_sampler(eval_dataset)
@@ -438,17 +446,16 @@ class Trainer:
         """
         Returns the test :class:`~torch.utils.data.DataLoader`.
 
-        Will use no sampler if :obj:`test_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
-        sampler (adapted to distributed training if necessary) otherwise.
-
         Subclass and override this method if you want to inject some custom behavior.
 
         Args:
-            eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
+            test_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
                 The test dataset to use. If it is an :obj:`datasets.Dataset`, columns not accepted by the
-                ``model.forward()`` method are automatically removed.
+                ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
         """
-        if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
+        if not isinstance(test_dataset, collections.abc.Sized):
+            raise ValueError("test_dataset must implement __len__")
+        elif is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
             self._remove_unused_columns(test_dataset, description="test")
         test_sampler = self._get_eval_sampler(test_dataset)
 
@@ -494,6 +501,8 @@ class Trainer:
     def num_examples(self, dataloader: DataLoader) -> int:
         """
         Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.
+
+        Will raise an exception if the underlying dataset does not implement method :obj:`__len__`
         """
         return len(dataloader.dataset)
 
@@ -579,19 +588,32 @@ class Trainer:
             # Reinitializes optimizer and scheduler
             self.optimizer, self.lr_scheduler = None, None
 
+        # Keeping track whether we can call len() on the dataset or not
+        train_dataset_is_sized = isinstance(self.train_dataset, collections.abc.Sized)
+
         # Data loader and number of training steps
         train_dataloader = self.get_train_dataloader()
-        num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
-        num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
-        if self.args.max_steps > 0:
-            max_steps = self.args.max_steps
-            num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
-                self.args.max_steps % num_update_steps_per_epoch > 0
-            )
+
+        # Setting up training control variables:
+        # number of training epochs: num_train_epochs
+        # number of training steps per epoch: num_update_steps_per_epoch
+        # total number of training steps to execute: max_steps
+        if train_dataset_is_sized:
+            num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
+            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
+            if self.args.max_steps > 0:
+                max_steps = self.args.max_steps
+                num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
+                    self.args.max_steps % num_update_steps_per_epoch > 0
+                )
+            else:
+                max_steps = math.ceil(self.args.num_train_epochs * num_update_steps_per_epoch)
+                num_train_epochs = math.ceil(self.args.num_train_epochs)
         else:
-            max_steps = int(num_update_steps_per_epoch * self.args.num_train_epochs)
-            num_train_epochs = self.args.num_train_epochs
-        num_train_epochs = int(np.ceil(num_train_epochs))
+            # see __init__. max_steps is set when the dataset has no __len__
+            max_steps = self.args.max_steps
+            num_train_epochs = 1
+            num_update_steps_per_epoch = max_steps
 
         self.create_optimizer_and_scheduler(num_training_steps=max_steps)
         self.state = TrainerState()
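A worked example of the two branches above, with made-up numbers (100 batches per epoch, 4 accumulation steps, 3 epochs for the sized case; max_steps=500 for the streaming case):

import math

# Sized dataset: length is known, so steps are derived from epochs
len_train_dataloader = 100          # batches per epoch (made-up)
gradient_accumulation_steps = 4
num_train_epochs_arg = 3

num_update_steps_per_epoch = max(len_train_dataloader // gradient_accumulation_steps, 1)  # 25
max_steps = math.ceil(num_train_epochs_arg * num_update_steps_per_epoch)                  # 75
num_train_epochs = math.ceil(num_train_epochs_arg)                                        # 3

# Unsized (streaming) dataset: only max_steps is meaningful
max_steps = 500                     # must come from args.max_steps
num_train_epochs = 1                # the whole run is treated as a single epoch
num_update_steps_per_epoch = max_steps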
@@ -645,8 +667,15 @@ class Trainer:
             * self.args.gradient_accumulation_steps
             * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
         )
 
+        num_examples = (
+            self.num_examples(train_dataloader)
+            if train_dataset_is_sized
+            else total_train_batch_size * self.args.max_steps
+        )
+
         logger.info("***** Running training *****")
-        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
+        logger.info("  Num examples = %d", num_examples)
         logger.info("  Num Epochs = %d", num_train_epochs)
         logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
         logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
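When the dataset has no length, the "Num examples" figure logged above is only an estimate of how many examples max_steps will consume. With made-up numbers:

# Made-up numbers: 8 examples per device, 1 device, 4 accumulation steps, 500 steps
per_device_train_batch_size = 8
n_devices = 1
gradient_accumulation_steps = 4
max_steps = 500

total_train_batch_size = per_device_train_batch_size * n_devices * gradient_accumulation_steps  # 32
num_examples = total_train_batch_size * max_steps  # 16000 examples drawn from the stream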
@@ -703,6 +732,7 @@ class Trainer:
             if self.args.past_index >= 0:
                 self._past = None
 
+            steps_in_epoch = len(epoch_iterator) if train_dataset_is_sized else self.args.max_steps
             self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control)
 
             for step, inputs in enumerate(epoch_iterator):
@@ -728,8 +758,8 @@ class Trainer:
 
                 if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                     # last step in epoch but step is always smaller than gradient_accumulation_steps
-                    len(epoch_iterator) <= self.args.gradient_accumulation_steps
-                    and (step + 1) == len(epoch_iterator)
+                    steps_in_epoch <= self.args.gradient_accumulation_steps
+                    and (step + 1) == steps_in_epoch
                 ):
                     if self.args.fp16 and _use_native_amp:
                         self.scaler.unscale_(self.optimizer)
@@ -750,7 +780,7 @@ class Trainer:
                         self.lr_scheduler.step()
                     model.zero_grad()
                     self.state.global_step += 1
-                    self.state.epoch = epoch + (step + 1) / len(epoch_iterator)
+                    self.state.epoch = epoch + (step + 1) / steps_in_epoch
                     self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)
 
                     self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
@@ -1207,11 +1237,15 @@ class Trainer:
         Args:
             eval_dataset (:obj:`Dataset`, `optional`):
                 Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`,
-                columns not accepted by the ``model.forward()`` method are automatically removed.
+                columns not accepted by the ``model.forward()`` method are automatically removed. It must implement
+                the :obj:`__len__` method.
 
         Returns:
             A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
         """
+        if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
+            raise ValueError("eval_dataset must implement __len__")
+
         eval_dataloader = self.get_eval_dataloader(eval_dataset)
 
         output = self.prediction_loop(eval_dataloader, description="Evaluation")
@@ -1234,7 +1268,7 @@ class Trainer:
         Args:
             test_dataset (:obj:`Dataset`):
                 Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the
-                ``model.forward()`` method are automatically removed.
+                ``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__`
 
         Returns:
             `NamedTuple`:
@@ -1245,6 +1279,9 @@ class Trainer:
             metrics (:obj:`Dict[str, float]`, `optional`):
                 The potential dictionary of metrics (if the dataset contained labels).
         """
+        if test_dataset is not None and not isinstance(test_dataset, collections.abc.Sized):
+            raise ValueError("test_dataset must implement __len__")
+
         test_dataloader = self.get_test_dataloader(test_dataset)
 
         return self.prediction_loop(test_dataloader, description="Prediction")
@@ -1264,6 +1301,8 @@ class Trainer:
             )
             return self._prediction_loop(dataloader, description, prediction_loss_only=prediction_loss_only)
 
+        if not isinstance(dataloader.dataset, collections.abc.Sized):
+            raise ValueError("dataset must implement __len__")
         prediction_loss_only = (
             prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
         )

tests/test_trainer.py  (62 lines changed, Executable file → Normal file)

@@ -31,11 +31,14 @@ if is_torch_available():
     from torch.utils.data import IterableDataset
 
     from transformers import (
+        AutoModelForMaskedLM,
         AutoModelForSequenceClassification,
+        DataCollatorForLanguageModeling,
         GlueDataset,
         GlueDataTrainingArguments,
         LineByLineTextDataset,
         PreTrainedModel,
+        TextDataset,
         Trainer,
         TrainerState,
     )
@@ -83,15 +86,16 @@ class RegressionModelConfig(PretrainedConfig):
 if is_torch_available():
 
     class SampleIterableDataset(IterableDataset):
-        def __init__(self, file_path):
-            self.file_path = file_path
+        """
+        Criteria is not whether it is IterableDataset or not, criteria is whether __len__ is implemented
+        """
 
-        def parse_file(self):
-            f = open(self.file_path, "r")
-            return f.readlines()
+        def __init__(self, file_path, tokenizer):
+            self.ds = TextDataset(file_path=file_path, tokenizer=tokenizer, block_size=64)
 
         def __iter__(self):
-            return iter(self.parse_file())
+            for i in range(len(self.ds)):
+                yield self.ds[i]
 
     class RegressionModel(torch.nn.Module):
         def __init__(self, a=0, b=0, double_output=False):
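The wrapper above hides the length of the underlying TextDataset on purpose: the wrapped dataset still knows its size, but the IterableDataset wrapper does not expose __len__, so it exercises the "no predefined length" path. In the context of tests/test_trainer.py this looks roughly like:

# Names (PATH_SAMPLE_TEXT, tokenizer) come from tests/test_trainer.py; shown here
# only to illustrate that the wrapper is not Sized even though its inner dataset is.
import collections.abc

ds = SampleIterableDataset(file_path=PATH_SAMPLE_TEXT, tokenizer=tokenizer)
assert not isinstance(ds, collections.abc.Sized)   # len(ds) would raise TypeError
assert isinstance(ds.ds, collections.abc.Sized)    # the wrapped TextDataset has a length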
@@ -540,13 +544,51 @@ class TrainerIntegrationTest(unittest.TestCase):
         self.assertEqual(len(dataset), 31)
 
     def test_trainer_iterable_dataset(self):
+        # Simulate Language Modeling with an IterableDataset, with no __len__ method
+        # Pick-up a tiny model, so it works on CPU
+        # See Issue #5990: https://github.com/huggingface/transformers/issues/5990
         MODEL_ID = "sshleifer/tiny-distilbert-base-cased"
-        model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
-        train_dataset = SampleIterableDataset(PATH_SAMPLE_TEXT)
-        training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
-        trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
+        model = AutoModelForMaskedLM.from_pretrained(MODEL_ID)
+        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+        train_dataset = SampleIterableDataset(file_path=PATH_SAMPLE_TEXT, tokenizer=tokenizer)
+        training_args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
+        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
+
+        training_args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
+        trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=data_collator)
+        trainer.train()
 
         loader = trainer.get_train_dataloader()
         self.assertIsInstance(loader, torch.utils.data.DataLoader)
+        self.assertIsInstance(loader.sampler, torch.utils.data.dataloader._InfiniteConstantSampler)
+
+        # Exception if giving iterable dataset and no max_steps
+        with self.assertRaises(ValueError):
+            training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
+            _ = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=data_collator)
+
+        # Exception if eval_dataset is iterable in __init__
+        with self.assertRaises(ValueError):
+            training_args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
+            _ = Trainer(
+                model=model,
+                args=training_args,
+                train_dataset=train_dataset,
+                eval_dataset=train_dataset,
+                data_collator=data_collator,
+            )
+
+        # Exception if predicting with iterable dataset
+        with self.assertRaises(ValueError):
+            training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
+            trainer = Trainer(model=model, args=training_args, data_collator=data_collator)
+            trainer.predict(train_dataset)
+
+        # Exception if evaluating with iterable dataset
+        with self.assertRaises(ValueError):
+            training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
+            trainer = Trainer(model=model, args=training_args, data_collator=data_collator)
+            trainer.evaluate(train_dataset)
 
     def test_num_train_epochs_in_training(self):
         # len(train_dl) < gradient_accumulation_steps shouldn't give ``ZeroDivisionError`` when ``max_steps`` is given.