Fix trainer_seq2seq.py's __init__ type annotations (#34021)

* Fix `trainer_seq2seq.py`'s `__init__` type annotations

* Update src/transformers/trainer_seq2seq.py

Co-authored-by: Lysandre Debut <hi@lysand.re>

* Fix issue pointed out by `muellerzr`

---------

Co-authored-by: Lysandre Debut <hi@lysand.re>
Ben Lewis, 2024-10-08 23:43:30 +03:00 (committed by GitHub)
parent 04b4e441dc
commit 573942d96a


@@ -24,11 +24,16 @@ from torch.utils.data import Dataset
 from .generation.configuration_utils import GenerationConfig
 from .integrations.deepspeed import is_deepspeed_zero3_enabled
 from .trainer import Trainer
-from .utils import logging
+from .utils import is_datasets_available, logging
 from .utils.deprecation import deprecate_kwarg


+if is_datasets_available():
+    import datasets
+
 if TYPE_CHECKING:
+    from torch.utils.data import IterableDataset
+
     from .data.data_collator import DataCollator
     from .feature_extraction_utils import FeatureExtractionMixin
     from .image_processing_utils import BaseImageProcessor
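For context, the hunk above pairs a runtime guard with type-checker-only imports. A minimal sketch of that pattern, assuming only that `is_datasets_available` comes from `transformers.utils`; the `describe` function is hypothetical:

from typing import TYPE_CHECKING, Optional, Union

from torch.utils.data import Dataset
from transformers.utils import is_datasets_available

if is_datasets_available():
    # Runtime guard: `datasets` is an optional extra, so only import it
    # when it is actually installed.
    import datasets

if TYPE_CHECKING:
    # Seen only by the type checker; never executed at runtime.
    from torch.utils.data import IterableDataset


def describe(dataset: Optional[Union[Dataset, "IterableDataset", "datasets.Dataset"]] = None) -> str:
    # String annotations are never evaluated at runtime, so this module
    # stays importable even when `datasets` is missing.
    return type(dataset).__name__ if dataset is not None else "no dataset"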
@@ -50,7 +55,7 @@ class Seq2SeqTrainer(Trainer):
         model: Union["PreTrainedModel", nn.Module] = None,
         args: "TrainingArguments" = None,
         data_collator: Optional["DataCollator"] = None,
-        train_dataset: Optional[Dataset] = None,
+        train_dataset: Optional[Union[Dataset, "IterableDataset", "datasets.Dataset"]] = None,
         eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
         processing_class: Optional[
             Union["PreTrainedTokenizerBase", "BaseImageProcessor", "FeatureExtractionMixin", "ProcessorMixin"]