mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix IterableDataset with __len__ in Trainer (#8095)
This commit is contained in:
parent
d93acd6f13
commit
286dc19a4f
@ -384,7 +384,9 @@ class Trainer:
|
||||
dataset.set_format(type=dataset.format["type"], columns=columns)
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
||||
if not isinstance(self.train_dataset, collections.abc.Sized):
|
||||
if isinstance(self.train_dataset, torch.utils.data.IterableDataset) or not isinstance(
|
||||
self.train_dataset, collections.abc.Sized
|
||||
):
|
||||
return None
|
||||
elif is_torch_tpu_available():
|
||||
return get_tpu_sampler(self.train_dataset)
|
||||
|
Loading…
Reference in New Issue
Block a user