Fix IterableDataset with __len__ in Trainer (#8095)

This commit is contained in:
Jonathan Chang 2020-10-27 21:52:35 +08:00 committed by GitHub
parent d93acd6f13
commit 286dc19a4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -384,7 +384,9 @@ class Trainer:
dataset.set_format(type=dataset.format["type"], columns=columns)
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
if not isinstance(self.train_dataset, collections.abc.Sized):
if isinstance(self.train_dataset, torch.utils.data.IterableDataset) or not isinstance(
self.train_dataset, collections.abc.Sized
):
return None
elif is_torch_tpu_available():
return get_tpu_sampler(self.train_dataset)