Trainer: move Seq2SeqTrainer imports under the typing guard (#22401)

Author: Joao Gante, 2023-03-27 16:39:26 +01:00 (committed by GitHub)
Parent: 0e708178ed
Commit: 53155b520d

@@ -14,39 +14,42 @@
 from copy import deepcopy
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 from torch import nn
 from torch.utils.data import Dataset

-from .data.data_collator import DataCollator
 from .deepspeed import is_deepspeed_zero3_enabled
 from .generation.configuration_utils import GenerationConfig
-from .modeling_utils import PreTrainedModel
-from .tokenization_utils_base import PreTrainedTokenizerBase
 from .trainer import Trainer
-from .trainer_callback import TrainerCallback
-from .trainer_utils import EvalPrediction, PredictionOutput
-from .training_args import TrainingArguments
 from .utils import logging


+if TYPE_CHECKING:
+    from .data.data_collator import DataCollator
+    from .modeling_utils import PreTrainedModel
+    from .tokenization_utils_base import PreTrainedTokenizerBase
+    from .trainer_callback import TrainerCallback
+    from .trainer_utils import EvalPrediction, PredictionOutput
+    from .training_args import TrainingArguments
+
+
 logger = logging.get_logger(__name__)


 class Seq2SeqTrainer(Trainer):
     def __init__(
         self,
-        model: Union[PreTrainedModel, nn.Module] = None,
-        args: TrainingArguments = None,
-        data_collator: Optional[DataCollator] = None,
+        model: Union["PreTrainedModel", nn.Module] = None,
+        args: "TrainingArguments" = None,
+        data_collator: Optional["DataCollator"] = None,
         train_dataset: Optional[Dataset] = None,
         eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
-        tokenizer: Optional[PreTrainedTokenizerBase] = None,
-        model_init: Optional[Callable[[], PreTrainedModel]] = None,
-        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
-        callbacks: Optional[List[TrainerCallback]] = None,
+        tokenizer: Optional["PreTrainedTokenizerBase"] = None,
+        model_init: Optional[Callable[[], "PreTrainedModel"]] = None,
+        compute_metrics: Optional[Callable[["EvalPrediction"], Dict]] = None,
+        callbacks: Optional[List["TrainerCallback"]] = None,
         optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
         preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
     ):
@@ -161,7 +164,7 @@ class Seq2SeqTrainer(Trainer):
         ignore_keys: Optional[List[str]] = None,
         metric_key_prefix: str = "test",
         **gen_kwargs,
-    ) -> PredictionOutput:
+    ) -> "PredictionOutput":
         """
         Run prediction and returns predictions and potential metrics.
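For context, the pattern applied in this commit is Python's typing.TYPE_CHECKING guard: imports placed under it are visible to static type checkers but skipped at runtime, so the corresponding annotations must become quoted forward references, which is exactly why the types in the signatures above gain quotes. Below is a minimal, self-contained sketch of the pattern; the module name heavy_module, the class HeavyClass, and the method process() are hypothetical and not part of transformers.

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Evaluated only by static type checkers (mypy, pyright, IDEs);
    # never executed at runtime, so importing this module stays cheap
    # even if the dependency is slow or optional.
    from heavy_module import HeavyClass  # hypothetical heavy dependency


def run(obj: Optional["HeavyClass"] = None) -> None:
    # The annotation is a quoted forward reference, so HeavyClass does not
    # need to exist at runtime when this function is defined or called.
    if obj is not None:
        obj.process()  # hypothetical method on the heavy class

The trade-off is that the guarded names are no longer bound at runtime, which is why every symbol moved under the guard in the diff above also switches to a string annotation.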