From 573942d96a3f371a9ff24e1878826f20621efc19 Mon Sep 17 00:00:00 2001 From: Ben Lewis Date: Tue, 8 Oct 2024 23:43:30 +0300 Subject: [PATCH] Fix `trainer_seq2seq.py`'s `__init__` type annotations (#34021) * Fix `trainer_seq2seq.py`'s `__init__` type annotations * Update src/transformers/trainer_seq2seq.py Co-authored-by: Lysandre Debut * Fix issue pointed out by `muellerzr` --------- Co-authored-by: Lysandre Debut --- src/transformers/trainer_seq2seq.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/transformers/trainer_seq2seq.py b/src/transformers/trainer_seq2seq.py index adbf89bb21a..03fc484fd80 100644 --- a/src/transformers/trainer_seq2seq.py +++ b/src/transformers/trainer_seq2seq.py @@ -24,11 +24,16 @@ from torch.utils.data import Dataset from .generation.configuration_utils import GenerationConfig from .integrations.deepspeed import is_deepspeed_zero3_enabled from .trainer import Trainer -from .utils import logging +from .utils import is_datasets_available, logging from .utils.deprecation import deprecate_kwarg +if is_datasets_available(): + import datasets + if TYPE_CHECKING: + from torch.utils.data import IterableDataset + from .data.data_collator import DataCollator from .feature_extraction_utils import FeatureExtractionMixin from .image_processing_utils import BaseImageProcessor @@ -50,7 +55,7 @@ class Seq2SeqTrainer(Trainer): model: Union["PreTrainedModel", nn.Module] = None, args: "TrainingArguments" = None, data_collator: Optional["DataCollator"] = None, - train_dataset: Optional[Dataset] = None, + train_dataset: Optional[Union[Dataset, "IterableDataset", "datasets.Dataset"]] = None, eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, processing_class: Optional[ Union["PreTrainedTokenizerBase", "BaseImageProcessor", "FeatureExtractionMixin", "ProcessorMixin"]