Trainer: move Seq2SeqTrainer imports under the typing guard (#22401)
This commit is contained in: parent 0e708178ed · commit 53155b520d
@@ -14,39 +14,42 @@
 
 from copy import deepcopy
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
 from torch.utils.data import Dataset
 
-from .data.data_collator import DataCollator
 from .deepspeed import is_deepspeed_zero3_enabled
 from .generation.configuration_utils import GenerationConfig
-from .modeling_utils import PreTrainedModel
-from .tokenization_utils_base import PreTrainedTokenizerBase
 from .trainer import Trainer
-from .trainer_callback import TrainerCallback
-from .trainer_utils import EvalPrediction, PredictionOutput
-from .training_args import TrainingArguments
 from .utils import logging
 
 
+if TYPE_CHECKING:
+    from .data.data_collator import DataCollator
+    from .modeling_utils import PreTrainedModel
+    from .tokenization_utils_base import PreTrainedTokenizerBase
+    from .trainer_callback import TrainerCallback
+    from .trainer_utils import EvalPrediction, PredictionOutput
+    from .training_args import TrainingArguments
+
+
 logger = logging.get_logger(__name__)
 
 
 class Seq2SeqTrainer(Trainer):
     def __init__(
         self,
-        model: Union[PreTrainedModel, nn.Module] = None,
-        args: TrainingArguments = None,
-        data_collator: Optional[DataCollator] = None,
+        model: Union["PreTrainedModel", nn.Module] = None,
+        args: "TrainingArguments" = None,
+        data_collator: Optional["DataCollator"] = None,
         train_dataset: Optional[Dataset] = None,
         eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
-        tokenizer: Optional[PreTrainedTokenizerBase] = None,
-        model_init: Optional[Callable[[], PreTrainedModel]] = None,
-        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
-        callbacks: Optional[List[TrainerCallback]] = None,
+        tokenizer: Optional["PreTrainedTokenizerBase"] = None,
+        model_init: Optional[Callable[[], "PreTrainedModel"]] = None,
+        compute_metrics: Optional[Callable[["EvalPrediction"], Dict]] = None,
+        callbacks: Optional[List["TrainerCallback"]] = None,
         optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
         preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
     ):
@@ -161,7 +164,7 @@ class Seq2SeqTrainer(Trainer):
         ignore_keys: Optional[List[str]] = None,
         metric_key_prefix: str = "test",
         **gen_kwargs,
-    ) -> PredictionOutput:
+    ) -> "PredictionOutput":
         """
         Run prediction and returns predictions and potential metrics.
 
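For context, the change relies on `typing.TYPE_CHECKING`, a constant that is `False` at runtime but treated as `True` by static type checkers, so imports placed under the guard never execute when the module loads, and the guarded names must be referenced as string (forward-reference) annotations, as in the quoted `"PreTrainedModel"` hints above. A minimal, generic sketch of the pattern follows; the `heavy_module` / `HeavyClass` names are hypothetical placeholders and not part of this commit:

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Evaluated only by static type checkers (mypy, pyright, IDEs);
    # skipped at runtime, so the import adds no startup cost and
    # cannot participate in an import cycle.
    from heavy_module import HeavyClass  # hypothetical module for illustration


def process(item: Optional["HeavyClass"] = None) -> "HeavyClass":
    # Quoted annotations are forward references: they stay plain strings
    # at runtime, so HeavyClass only needs to resolve during type checking.
    ...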