Make train_dataset attribute in _get_train_sampler optional (#38226)

make it optional
This commit is contained in:
Marc Sun 2025-05-20 14:59:53 +02:00 committed by GitHub
parent 2ad152f84c
commit bb3c6426d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -972,7 +972,9 @@ class Trainer:
) )
return remove_columns_collator return remove_columns_collator
def _get_train_sampler(self, train_dataset) -> Optional[torch.utils.data.Sampler]: def _get_train_sampler(self, train_dataset: Optional[Dataset] = None) -> Optional[torch.utils.data.Sampler]:
if train_dataset is None:
train_dataset = self.train_dataset
if train_dataset is None or not has_length(train_dataset): if train_dataset is None or not has_length(train_dataset):
return None return None