mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Fix type hint for train_dataset param of Trainer.__init__() to allow IterableDataset. Issue 29678 (#29738)
* Fixed typehint for train_dataset param in Trainer.__init__(). Added IterableDataset option. * make fixup
This commit is contained in:
parent
e68ff30419
commit
347916130c
@ -52,7 +52,7 @@ import torch.distributed as dist
|
||||
from huggingface_hub import ModelCard, create_repo, upload_folder
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
|
||||
from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
|
||||
|
||||
from . import __version__
|
||||
from .configuration_utils import PretrainedConfig
|
||||
@ -353,7 +353,7 @@ class Trainer:
|
||||
model: Union[PreTrainedModel, nn.Module] = None,
|
||||
args: TrainingArguments = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
|
||||
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
|
||||
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
||||
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
||||
|
Loading…
Reference in New Issue
Block a user