From bb3c6426d844683d77f3e579ca1f8fef8be66a5e Mon Sep 17 00:00:00 2001 From: Marc Sun <57196510+SunMarc@users.noreply.github.com> Date: Tue, 20 May 2025 14:59:53 +0200 Subject: [PATCH] Make `train_dataset` attribute in `_get_train_sampler` optional (#38226) make it optional --- src/transformers/trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 8312f9558b6..c2a44b6ff08 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -972,7 +972,9 @@ class Trainer: ) return remove_columns_collator - def _get_train_sampler(self, train_dataset) -> Optional[torch.utils.data.Sampler]: + def _get_train_sampler(self, train_dataset: Optional[Dataset] = None) -> Optional[torch.utils.data.Sampler]: + if train_dataset is None: + train_dataset = self.train_dataset if train_dataset is None or not has_length(train_dataset): return None