diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index ba89d914d76..5baa3e1b51f 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -66,6 +66,8 @@ if is_torch_available():
     import torch
     import torch.distributed as dist
 
+    from .pytorch_utils import is_torch_greater_or_equal_than_2_0
+
 if is_accelerate_available():
     from accelerate.state import AcceleratorState, PartialState
     from accelerate.utils import DistributedType
@@ -1023,13 +1025,13 @@ class TrainingArguments:
             )
         },
     )
-    dataloader_prefetch_factor: int = field(
-        default=None,
+    dataloader_prefetch_factor: Optional[int] = field(
+        default=None if not is_torch_available() or is_torch_greater_or_equal_than_2_0 else 2,
         metadata={
             "help": (
                 "Number of batches loaded in advance by each worker. "
                 "2 means there will be a total of 2 * num_workers batches prefetched across all workers. "
-                "Default is unset"
+                "Default is 2 for PyTorch < 2.0.0 and otherwise None."
             )
         },
     )
@@ -1807,7 +1809,11 @@ class TrainingArguments:
         if self.use_cpu:
             self.dataloader_pin_memory = False
 
-        if self.dataloader_num_workers == 0 and self.dataloader_prefetch_factor is not None:
+        if (
+            (not is_torch_available() or is_torch_greater_or_equal_than_2_0)
+            and self.dataloader_num_workers == 0
+            and self.dataloader_prefetch_factor is not None
+        ):
             raise ValueError(
                 "--dataloader_prefetch_factor can only be set when data is loaded in a different process, i.e."
                 " when --dataloader_num_workers > 1."
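
For reference, a minimal usage sketch of how the patched validation behaves from the caller's side; the `output_dir` value and the concrete numbers are illustrative assumptions, not part of this change:

```python
from transformers import TrainingArguments

# With dataloader_num_workers > 0, a prefetch factor is accepted:
# per the help text, 4 * 2 = 8 batches are prefetched across all workers.
args = TrainingArguments(
    output_dir="tmp_trainer",       # illustrative path
    dataloader_num_workers=2,
    dataloader_prefetch_factor=4,
)

# On PyTorch >= 2.0, combining dataloader_num_workers=0 with a non-None
# prefetch factor now raises ValueError, since prefetching only applies
# when data loading happens in worker processes:
# TrainingArguments(
#     output_dir="tmp_trainer",
#     dataloader_num_workers=0,
#     dataloader_prefetch_factor=2,
# )
```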