change XLA deprecated api (#37741)

* deprecated api

* fix
Marc Sun 2025-04-28 16:27:41 +02:00 committed by GitHub
parent 2933894985
commit 9e730689c3
3 changed files with 10 additions and 10 deletions
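The commit replaces the deprecated process-identity helpers from `torch_xla.core.xla_model` with their `torch_xla.runtime` equivalents. A minimal sketch of the mapping, assuming `torch_xla` 2.x is installed and the script runs under an XLA backend; the helper name `describe_process` is illustrative and not part of the commit:

import torch_xla.runtime as xr

# Deprecated                 ->  Replacement
# xm.xrt_world_size()        ->  xr.world_size()
# xm.get_ordinal()           ->  xr.global_ordinal()


def describe_process() -> str:
    # Report this process's global rank and the total number of XLA processes.
    return f"process {xr.global_ordinal()} of {xr.world_size()}"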

src/transformers/trainer.py

@@ -199,12 +199,12 @@ if is_datasets_available():
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
     import torch_xla.debug.metrics as met
+    import torch_xla.runtime as xr

     from torch_xla import __version__ as XLA_VERSION

     IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION)
     if IS_XLA_FSDPV2_POST_2_2:
         import torch_xla.distributed.spmd as xs
-        import torch_xla.runtime as xr
 else:
     IS_XLA_FSDPV2_POST_2_2 = False
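The `xr` import moves to the top of the XLA block because the runtime API is now used outside the FSDPv2 branch as well. The version gate itself is the usual pattern for conditioning an optional import on the installed torch_xla release; a condensed sketch, where the value "2.2.0" is assumed to mirror `XLA_FSDPV2_MIN_VERSION` defined elsewhere in trainer.py:

from packaging import version

from torch_xla import __version__ as XLA_VERSION

# Assumed value; the real constant lives elsewhere in trainer.py.
XLA_FSDPV2_MIN_VERSION = "2.2.0"

IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION)
if IS_XLA_FSDPV2_POST_2_2:
    import torch_xla.distributed.spmd as xs  # SPMD sharding API required for FSDPv2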
@@ -1042,7 +1042,7 @@ class Trainer:
         if self.args.use_legacy_prediction_loop:
             if is_torch_xla_available():
                 return SequentialDistributedSampler(
-                    eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
+                    eval_dataset, num_replicas=xr.world_size(), rank=xr.global_ordinal()
                 )
             elif is_sagemaker_mp_enabled():
                 return SequentialDistributedSampler(
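`SequentialDistributedSampler` is a transformers-internal class; the same rank/world-size substitution can be illustrated with torch's built-in `DistributedSampler`, with shuffling disabled to keep evaluation order deterministic. A minimal sketch, assuming an XLA multi-process launch; the toy dataset is illustrative:

import torch
import torch_xla.runtime as xr
from torch.utils.data import DistributedSampler, TensorDataset

eval_dataset = TensorDataset(torch.arange(100))
eval_sampler = DistributedSampler(
    eval_dataset,
    num_replicas=xr.world_size(),   # was xm.xrt_world_size()
    rank=xr.global_ordinal(),       # was xm.get_ordinal()
    shuffle=False,                  # evaluation: keep each shard in order
)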

src/transformers/trainer_pt_utils.py

@@ -52,7 +52,7 @@ if is_training_run_on_sagemaker():
     logging.add_handler(StreamHandler(sys.stdout))

 if is_torch_xla_available():
-    import torch_xla.core.xla_model as xm
+    import torch_xla.runtime as xr

 if is_torch_available():
     from torch.optim.lr_scheduler import LRScheduler
@@ -398,9 +398,9 @@ class SequentialDistributedSampler(Sampler):
 def get_tpu_sampler(dataset: torch.utils.data.Dataset, batch_size: int):
-    if xm.xrt_world_size() <= 1:
+    if xr.world_size() <= 1:
         return RandomSampler(dataset)
-    return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
+    return DistributedSampler(dataset, num_replicas=xr.world_size(), rank=xr.global_ordinal())


 def nested_new_like(arrays, num_samples, padding_index=-100):
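The same selection logic, wired into a `DataLoader`, might look like the sketch below; `make_tpu_sampler`, the toy dataset, and the batch size are illustrative, and `torch_xla` must be installed:

import torch
import torch_xla.runtime as xr
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, TensorDataset


def make_tpu_sampler(dataset):
    # Single XLA process: plain random sampling; multi-process: shard by global rank.
    if xr.world_size() <= 1:
        return RandomSampler(dataset)
    return DistributedSampler(dataset, num_replicas=xr.world_size(), rank=xr.global_ordinal())


train_dataset = TensorDataset(torch.randn(64, 4))
train_loader = DataLoader(train_dataset, batch_size=8, sampler=make_tpu_sampler(train_dataset))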

src/transformers/trainer_utils.py

@@ -362,13 +362,13 @@ class HPSearchBackend(ExplicitEnum):
 def is_main_process(local_rank):
     """
-    Whether or not the current process is the local process, based on `xm.get_ordinal()` (for TPUs) first, then on
+    Whether or not the current process is the local process, based on `xr.global_ordinal()` (for TPUs) first, then on
     `local_rank`.
     """
     if is_torch_xla_available():
-        import torch_xla.core.xla_model as xm
+        import torch_xla.runtime as xr

-        return xm.get_ordinal() == 0
+        return xr.global_ordinal() == 0
     return local_rank in [-1, 0]
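A typical call site guards side effects such as logging or checkpoint writes on the main process; a minimal sketch assuming the script runs under torch_xla (the message is illustrative):

import torch_xla.runtime as xr

# Only the globally first XLA process prints; all other workers stay silent.
if xr.global_ordinal() == 0:
    print(f"main process: coordinating {xr.world_size()} workers")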
@@ -377,9 +377,9 @@ def total_processes_number(local_rank):
     Return the number of processes launched in parallel. Works with `torch.distributed` and TPUs.
     """
     if is_torch_xla_available():
-        import torch_xla.core.xla_model as xm
+        import torch_xla.runtime as xr

-        return xm.xrt_world_size()
+        return xr.world_size()
     elif local_rank != -1 and is_torch_available():
         import torch
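The hunk is truncated at the `torch.distributed` fallback. A self-contained sketch of a world-size helper covering both branches, assuming the fallback uses `torch.distributed.get_world_size()`; the helper name `process_count` is illustrative:

import importlib.util

import torch


def process_count(local_rank: int = -1) -> int:
    # Prefer the XLA runtime's view of the world when torch_xla is installed.
    if importlib.util.find_spec("torch_xla") is not None:
        import torch_xla.runtime as xr

        return xr.world_size()
    # torch.distributed path: only meaningful once a process group has been initialized.
    if local_rank != -1 and torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    return 1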