From ae9dd02ee1a8627d26be32202202b8081e9855a4 Mon Sep 17 00:00:00 2001
From: André Storhaug
Date: Mon, 8 Jul 2024 13:49:30 +0200
Subject: [PATCH] Fix incorrect accelerator device handling for MPS in `TrainingArguments` (#31812)

* Fix wrong accelerator device setup when using MPS

* More robust TrainingArguments MPS handling

* Update training_args.py

* Cleanup
---
 src/transformers/training_args.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 48179f0e0a4..6d68405ab35 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -48,6 +48,7 @@ from .utils import (
     is_torch_bf16_cpu_available,
     is_torch_bf16_gpu_available,
     is_torch_mlu_available,
+    is_torch_mps_available,
     is_torch_neuroncore_available,
     is_torch_npu_available,
     is_torch_tf32_available,
@@ -2178,6 +2179,8 @@ class TrainingArguments:
             )
         if self.use_cpu:
             device = torch.device("cpu")
+        elif is_torch_mps_available():
+            device = torch.device("mps")
         elif is_torch_xpu_available():
             if not is_ipex_available() and not is_accelerate_available("0.32.0.dev"):
                 raise ImportError("Using the XPU PyTorch backend requires `accelerate>=0.32.0.dev`")
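
For reference, a minimal sketch of how the patched device resolution behaves at runtime. This script is not part of the patch; it assumes a transformers checkout with this change applied and an MPS-enabled PyTorch build on Apple Silicon, and the `output_dir` value is arbitrary:

    # Hypothetical sanity check, not part of the patch itself.
    import torch
    from transformers import TrainingArguments

    # True on an Apple Silicon host with an MPS-enabled PyTorch build.
    print(torch.backends.mps.is_available())

    # With this patch, TrainingArguments should resolve to the "mps" device
    # on such hosts; elsewhere it falls through to the existing
    # XPU/NPU/CUDA/CPU branches.
    args = TrainingArguments(output_dir="out")
    print(args.device)  # expected: device(type='mps')

    # Explicitly requesting the CPU still takes precedence, since the
    # `use_cpu` branch is checked before the new MPS branch.
    args_cpu = TrainingArguments(output_dir="out", use_cpu=True)
    print(args_cpu.device)  # expected: device(type='cpu')

Note the ordering in the patch: the MPS check sits after `use_cpu` but before the XPU check, so an explicit CPU request is honored while MPS otherwise wins over the other accelerator probes on machines where it is available.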