mirror of https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
parent 2933894985
commit 9e730689c3
@@ -199,12 +199,12 @@ if is_datasets_available():
 
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
     import torch_xla.debug.metrics as met
+    import torch_xla.runtime as xr
     from torch_xla import __version__ as XLA_VERSION
 
     IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION)
     if IS_XLA_FSDPV2_POST_2_2:
         import torch_xla.distributed.spmd as xs
-        import torch_xla.runtime as xr
 else:
     IS_XLA_FSDPV2_POST_2_2 = False
@@ -1042,7 +1042,7 @@ class Trainer:
         if self.args.use_legacy_prediction_loop:
             if is_torch_xla_available():
                 return SequentialDistributedSampler(
-                    eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
+                    eval_dataset, num_replicas=xr.world_size(), rank=xr.global_ordinal()
                 )
             elif is_sagemaker_mp_enabled():
                 return SequentialDistributedSampler(
@@ -52,7 +52,7 @@ if is_training_run_on_sagemaker():
     logging.add_handler(StreamHandler(sys.stdout))
 
 if is_torch_xla_available():
-    import torch_xla.core.xla_model as xm
+    import torch_xla.runtime as xr
 
 if is_torch_available():
     from torch.optim.lr_scheduler import LRScheduler
@@ -398,9 +398,9 @@ class SequentialDistributedSampler(Sampler):
 
 
 def get_tpu_sampler(dataset: torch.utils.data.Dataset, batch_size: int):
-    if xm.xrt_world_size() <= 1:
+    if xr.world_size() <= 1:
         return RandomSampler(dataset)
-    return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
+    return DistributedSampler(dataset, num_replicas=xr.world_size(), rank=xr.global_ordinal())
 
 
 def nested_new_like(arrays, num_samples, padding_index=-100):
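For context, a hypothetical usage sketch of the sampler helper touched above, assuming torch_xla (>= 2.x) is installed and the code runs in an XLA environment; the toy dataset and loader are illustrative only and not part of the commit:

import torch
from torch.utils.data import DataLoader, TensorDataset

from transformers.trainer_pt_utils import get_tpu_sampler

# Toy dataset; per the hunk above, get_tpu_sampler falls back to RandomSampler
# when xr.world_size() <= 1 and otherwise shards the data with a
# DistributedSampler keyed on xr.global_ordinal().
dataset = TensorDataset(torch.arange(128, dtype=torch.float32).unsqueeze(-1))
sampler = get_tpu_sampler(dataset, batch_size=8)
loader = DataLoader(dataset, batch_size=8, sampler=sampler)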
@@ -362,13 +362,13 @@ class HPSearchBackend(ExplicitEnum):
 
 
 def is_main_process(local_rank):
     """
-    Whether or not the current process is the local process, based on `xm.get_ordinal()` (for TPUs) first, then on
+    Whether or not the current process is the local process, based on `xr.global_ordinal()` (for TPUs) first, then on
     `local_rank`.
     """
     if is_torch_xla_available():
-        import torch_xla.core.xla_model as xm
+        import torch_xla.runtime as xr
 
-        return xm.get_ordinal() == 0
+        return xr.global_ordinal() == 0
     return local_rank in [-1, 0]
 
@@ -377,9 +377,9 @@ def total_processes_number(local_rank):
     Return the number of processes launched in parallel. Works with `torch.distributed` and TPUs.
     """
     if is_torch_xla_available():
-        import torch_xla.core.xla_model as xm
+        import torch_xla.runtime as xr
 
-        return xm.xrt_world_size()
+        return xr.world_size()
     elif local_rank != -1 and is_torch_available():
         import torch
 
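Taken together, the hunks apply one mechanical migration: the deprecated ordinal and world-size helpers in `torch_xla.core.xla_model` are replaced by their `torch_xla.runtime` equivalents. A minimal sketch of the mapping, assuming a working torch_xla (>= 2.x) installation; the helper function name below is illustrative, not part of the commit:

import torch_xla.runtime as xr


def rank_and_world_size():
    # Old (deprecated) calls: xm.get_ordinal(), xm.xrt_world_size()
    # New equivalents:        xr.global_ordinal(), xr.world_size()
    return xr.global_ordinal(), xr.world_size()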