From b30879fe0c8138b1718491e25df22f3c2dfdc7e6 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Mon, 24 Aug 2020 09:22:03 -0400 Subject: [PATCH] Don't reset the dataset type + plug for rm unused columns (#6683) * Don't reset the type of the dataset * Formatting * Update trainer.py Co-authored-by: Teven --- src/transformers/trainer.py | 7 ++++++- src/transformers/training_args.py | 9 +++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index a03ac23facf..06846c833a6 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -244,6 +244,8 @@ class Trainer: self.scaler = torch.cuda.amp.GradScaler() def _remove_unused_columns(self, dataset: "nlp.Dataset", description: Optional[str] = None): + if not self.args.remove_unused_columns: + return # Inspect model forward signature to keep only the arguments it accepts. signature = inspect.signature(self.model.forward) signature_columns = list(signature.parameters.keys()) @@ -255,7 +257,10 @@ class Trainer: logger.info( f"The following columns {dset_description}don't have a corresponding argument in `{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}." ) - dataset.set_format(columns=columns) + ds_type = dataset.format["type"] + if ds_type == "python": + ds_type = None + dataset.set_format(type=ds_type, columns=columns) def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]: if isinstance(self.train_dataset, torch.utils.data.IterableDataset): diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 649001362cc..3c543414ff6 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -114,6 +114,11 @@ class TrainingArguments: at the next training step under the keyword argument ``mems``. run_name (:obj:`str`, `optional`): A descriptor for the run. 
Notably used for wandb logging. + remove_unused_columns (:obj:`bool`, `optional`, defaults to :obj:`True`): + If using `nlp.Dataset` datasets, whether or not to automatically remove the columns unused by the model's + forward method. + + (Note: this behavior is not implemented for :class:`~transformers.TFTrainer` yet.) """ output_dir: str = field( @@ -234,6 +239,10 @@ class TrainingArguments: default=None, metadata={"help": "An optional descriptor for the run. Notably used for wandb logging."} ) + remove_unused_columns: Optional[bool] = field( + default=True, metadata={"help": "Remove columns not required by the model when using an nlp.Dataset."} + ) + @property def train_batch_size(self) -> int: """