Fix TrainingArguments regression with torch <2.0.0 for dataloader_prefetch_factor (#29447)

* Fix TrainingArguments regression with torch <2.0.0 for dataloader_prefetch_factor

dataloader_prefetch_factor was added to TrainingArguments in #28498 with a default value of None, but versions of torch < 2.0.0 do not accept None: they raise an error whenever num_workers == 0 and prefetch_factor != 2 (see the repro sketch after these notes)

* Add is_torch_available() check

* Use is_torch_greater_or_equal_than_2_0

Add back the check for dataloader_prefetch_factor
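
A minimal repro sketch of the regression (assumes torch < 2.0.0 is installed; on torch >= 2.0.0 the same call succeeds, since prefetch_factor defaults to None there):

from torch.utils.data import DataLoader

# The TrainingArguments default of None ends up forwarded to DataLoader;
# torch < 2.0.0 validates `num_workers == 0 and prefetch_factor != 2`
# and so rejects None.
loader = DataLoader(
    dataset=[0, 1, 2],
    num_workers=0,
    prefetch_factor=None,  # ValueError on torch < 2.0.0; accepted on torch >= 2.0.0
)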
Matthew Hoffman 2024-03-06 01:44:08 -08:00 committed by GitHub
parent b27aa206dd
commit 2890116ab7

@@ -66,6 +66,8 @@ if is_torch_available():
     import torch
     import torch.distributed as dist
 
+    from .pytorch_utils import is_torch_greater_or_equal_than_2_0
+
 if is_accelerate_available():
     from accelerate.state import AcceleratorState, PartialState
     from accelerate.utils import DistributedType
@@ -1023,13 +1025,13 @@ class TrainingArguments:
             )
         },
     )
-    dataloader_prefetch_factor: int = field(
-        default=None,
+    dataloader_prefetch_factor: Optional[int] = field(
+        default=None if not is_torch_available() or is_torch_greater_or_equal_than_2_0 else 2,
         metadata={
             "help": (
                 "Number of batches loaded in advance by each worker. "
                 "2 means there will be a total of 2 * num_workers batches prefetched across all workers. "
-                "Default is unset"
+                "Default is 2 for PyTorch < 2.0.0 and otherwise None."
             )
         },
     )
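
The new default expression picks the only value each setup accepts. A standalone sketch of that logic (the boolean parameters here stand in for transformers' is_torch_available() and is_torch_greater_or_equal_than_2_0 helpers):

from typing import Optional

def default_prefetch_factor(torch_available: bool, torch_ge_2_0: bool) -> Optional[int]:
    # None is safe when torch is absent or when torch >= 2.0.0 accepts None;
    # torch < 2.0.0 requires prefetch_factor == 2 whenever num_workers == 0.
    return None if not torch_available or torch_ge_2_0 else 2

assert default_prefetch_factor(torch_available=False, torch_ge_2_0=False) is None
assert default_prefetch_factor(torch_available=True, torch_ge_2_0=True) is None
assert default_prefetch_factor(torch_available=True, torch_ge_2_0=False) == 2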
@@ -1807,7 +1809,11 @@
         if self.use_cpu:
             self.dataloader_pin_memory = False
 
-        if self.dataloader_num_workers == 0 and self.dataloader_prefetch_factor is not None:
+        if (
+            (not is_torch_available() or is_torch_greater_or_equal_than_2_0)
+            and self.dataloader_num_workers == 0
+            and self.dataloader_prefetch_factor is not None
+        ):
             raise ValueError(
                 "--dataloader_prefetch_factor can only be set when data is loaded in a different process, i.e."
                 " when --dataloader_num_workers > 1."