diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 48179f0e0a4..6d68405ab35 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -48,6 +48,7 @@ from .utils import (
     is_torch_bf16_cpu_available,
     is_torch_bf16_gpu_available,
     is_torch_mlu_available,
+    is_torch_mps_available,
     is_torch_neuroncore_available,
     is_torch_npu_available,
     is_torch_tf32_available,
@@ -2178,6 +2179,8 @@ class TrainingArguments:
         )
         if self.use_cpu:
             device = torch.device("cpu")
+        elif is_torch_mps_available():
+            device = torch.device("mps")
         elif is_torch_xpu_available():
             if not is_ipex_available() and not is_accelerate_available("0.32.0.dev"):
                 raise ImportError("Using the XPU PyTorch backend requires `accelerate>=0.32.0.dev`")