mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix TrainingArguments regression with torch <2.0.0 for dataloader_prefetch_factor (#29447)
* Fix TrainingArguments regression with torch <2.0.0 for dataloader_prefetch_factor dataloader_prefetch_factor was added to TrainingArguments in #28498 with the default value None, but versions of torch<2.0.0 do not accept None and will raise an error if num_workers == 0 and prefetch_factor != 2 * Add is_torch_available() check * Use is_torch_greater_or_equal_than_2_0 add back check for dataloader_prefetch_factor
This commit is contained in:
parent
b27aa206dd
commit
2890116ab7
@ -66,6 +66,8 @@ if is_torch_available():
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from .pytorch_utils import is_torch_greater_or_equal_than_2_0
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate.state import AcceleratorState, PartialState
|
||||
from accelerate.utils import DistributedType
|
||||
@ -1023,13 +1025,13 @@ class TrainingArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
dataloader_prefetch_factor: int = field(
|
||||
default=None,
|
||||
dataloader_prefetch_factor: Optional[int] = field(
|
||||
default=None if not is_torch_available() or is_torch_greater_or_equal_than_2_0 else 2,
|
||||
metadata={
|
||||
"help": (
|
||||
"Number of batches loaded in advance by each worker. "
|
||||
"2 means there will be a total of 2 * num_workers batches prefetched across all workers. "
|
||||
"Default is unset"
|
||||
"Default is 2 for PyTorch < 2.0.0 and otherwise None."
|
||||
)
|
||||
},
|
||||
)
|
||||
@ -1807,7 +1809,11 @@ class TrainingArguments:
|
||||
if self.use_cpu:
|
||||
self.dataloader_pin_memory = False
|
||||
|
||||
if self.dataloader_num_workers == 0 and self.dataloader_prefetch_factor is not None:
|
||||
if (
|
||||
(not is_torch_available() or is_torch_greater_or_equal_than_2_0)
|
||||
and self.dataloader_num_workers == 0
|
||||
and self.dataloader_prefetch_factor is not None
|
||||
):
|
||||
raise ValueError(
|
||||
"--dataloader_prefetch_factor can only be set when data is loaded in a different process, i.e."
|
||||
" when --dataloader_num_workers > 1."
|
||||
|
Loading…
Reference in New Issue
Block a user