From 9e730689c3e8923b5e18981d194f5c662c2d4584 Mon Sep 17 00:00:00 2001
From: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Date: Mon, 28 Apr 2025 16:27:41 +0200
Subject: [PATCH] change XLA deprecated api (#37741)

* deprecated api

* fix
---
 src/transformers/trainer.py          |  4 ++--
 src/transformers/trainer_pt_utils.py |  6 +++---
 src/transformers/trainer_utils.py    | 10 +++++-----
 3 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 90957024bc6..568f2bfd7a5 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -199,12 +199,12 @@ if is_datasets_available():
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
     import torch_xla.debug.metrics as met
+    import torch_xla.runtime as xr
     from torch_xla import __version__ as XLA_VERSION
 
     IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION)
     if IS_XLA_FSDPV2_POST_2_2:
         import torch_xla.distributed.spmd as xs
-        import torch_xla.runtime as xr
 else:
     IS_XLA_FSDPV2_POST_2_2 = False
 
@@ -1042,7 +1042,7 @@ class Trainer:
         if self.args.use_legacy_prediction_loop:
             if is_torch_xla_available():
                 return SequentialDistributedSampler(
-                    eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
+                    eval_dataset, num_replicas=xr.world_size(), rank=xr.global_ordinal()
                 )
             elif is_sagemaker_mp_enabled():
                 return SequentialDistributedSampler(
diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py
index 30474daea6e..0fed4fb9041 100644
--- a/src/transformers/trainer_pt_utils.py
+++ b/src/transformers/trainer_pt_utils.py
@@ -52,7 +52,7 @@ if is_training_run_on_sagemaker():
     logging.add_handler(StreamHandler(sys.stdout))
 
 if is_torch_xla_available():
-    import torch_xla.core.xla_model as xm
+    import torch_xla.runtime as xr
 
 if is_torch_available():
     from torch.optim.lr_scheduler import LRScheduler
@@ -398,9 +398,9 @@ class SequentialDistributedSampler(Sampler):
 
 
 def get_tpu_sampler(dataset: torch.utils.data.Dataset, batch_size: int):
-    if xm.xrt_world_size() <= 1:
+    if xr.world_size() <= 1:
         return RandomSampler(dataset)
-    return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
+    return DistributedSampler(dataset, num_replicas=xr.world_size(), rank=xr.global_ordinal())
 
 
 def nested_new_like(arrays, num_samples, padding_index=-100):
diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py
index 49feecf694d..7b2d5c34322 100644
--- a/src/transformers/trainer_utils.py
+++ b/src/transformers/trainer_utils.py
@@ -362,13 +362,13 @@ class HPSearchBackend(ExplicitEnum):
 
 def is_main_process(local_rank):
     """
-    Whether or not the current process is the local process, based on `xm.get_ordinal()` (for TPUs) first, then on
+    Whether or not the current process is the local process, based on `xr.global_ordinal()` (for TPUs) first, then on
     `local_rank`.
     """
     if is_torch_xla_available():
-        import torch_xla.core.xla_model as xm
+        import torch_xla.runtime as xr
 
-        return xm.get_ordinal() == 0
+        return xr.global_ordinal() == 0
 
     return local_rank in [-1, 0]
 
@@ -377,9 +377,9 @@ def total_processes_number(local_rank):
     Return the number of processes launched in parallel. Works with `torch.distributed` and TPUs.
     """
     if is_torch_xla_available():
-        import torch_xla.core.xla_model as xm
+        import torch_xla.runtime as xr
 
-        return xm.xrt_world_size()
+        return xr.world_size()
     elif local_rank != -1 and is_torch_available():
         import torch
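
For context, a minimal sketch of the API mapping this patch applies: the deprecated
`torch_xla.core.xla_model` helpers are replaced by their `torch_xla.runtime` equivalents.
This assumes a recent torch_xla release where `torch_xla.runtime` exposes `world_size()`
and `global_ordinal()`; `describe_process` is a hypothetical helper used only for illustration.

import torch_xla.runtime as xr  # runtime module adopted throughout the patch

def describe_process():
    # Deprecated calls removed by this patch:
    #   import torch_xla.core.xla_model as xm
    #   xm.xrt_world_size(), xm.get_ordinal()
    # Replacements used in trainer.py, trainer_pt_utils.py and trainer_utils.py:
    world_size = xr.world_size()    # number of XLA processes
    rank = xr.global_ordinal()      # global rank of this process
    return f"process {rank} of {world_size}"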