[tests] make test_model_parallelism device-agnostic (#30844)

* enable on xpu

* fix style

* add comment and mps
Fanli Lin 2024-05-24 18:51:51 +08:00 committed by GitHub
parent 42d8dd8716
commit 04c7c176d7
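The gist of the change: in `check_device_map_is_respected`, parameters offloaded to CPU or disk are still expected on the `meta` device, `mps` entries carry no device index, and every other entry in the device map is an integer accelerator index that is now prefixed with the active backend (`cuda`, `xpu`, `npu`, `mlu`) instead of being passed to `torch.device` as-is. A minimal sketch of that rule, with `torch_device` standing in for `transformers.testing_utils.torch_device` (the fallback logic below is illustrative, not the library's code):

import torch

# Minimal sketch (not the test code itself) of the expected-device rule this
# commit introduces; `torch_device` stands in for testing_utils.torch_device.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

def expected_device(param_device):
    if param_device in ["cpu", "disk"]:
        # offloaded weights are not materialized, so they sit on the meta device
        return torch.device("meta")
    if param_device == "mps":
        # mps entries carry no device index in hf_device_map
        return torch.device("mps")
    # accelerator entries are bare integers; prefix them with the active backend
    return torch.device(f"{torch_device}:{param_device}")

print(expected_device(0), expected_device("cpu"))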

@@ -76,6 +76,7 @@ from transformers.testing_utils import (
     require_safetensors,
     require_torch,
     require_torch_gpu,
+    require_torch_multi_accelerator,
     require_torch_multi_gpu,
     require_torch_sdpa,
     slow,
@@ -3009,8 +3010,11 @@ class ModelTesterMixin:
             param_device = device_map[param_name]
             if param_device in ["cpu", "disk"]:
                 self.assertEqual(param.device, torch.device("meta"))
+            elif param_device in ["mps"]:
+                self.assertEqual(param.device, torch.device("mps"))
             else:
-                self.assertEqual(param.device, torch.device(param_device))
+                # when loaded with device_map, `param_device` are integer values for cuda/xpu/npu/mlu
+                self.assertEqual(param.device, torch.device(f"{torch_device}:{param_device}"))

     @require_accelerate
     @mark.accelerate_tests
@@ -3129,7 +3133,7 @@ class ModelTesterMixin:
     @require_accelerate
     @mark.accelerate_tests
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
     def test_model_parallelism(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -3155,7 +3159,6 @@
                     new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
                     # Making sure part of the model will actually end up offloaded
                     self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
                     self.check_device_map_is_respected(new_model, new_model.hf_device_map)

                     torch.manual_seed(0)
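For reference, the `assertSetEqual` on `hf_device_map` above checks that the `max_memory` budget really forces a split across two accelerators. A rough, self-contained sketch of how such a budget produces a split device map, using accelerate's `infer_auto_device_map` on a toy module (the toy model and budget values are illustrative, not taken from the test):

import torch
from torch import nn
from accelerate import infer_auto_device_map
from accelerate.utils import compute_module_sizes

# Toy stand-in for the pretrained model; sizes and budgets are illustrative only.
model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256), nn.Linear(256, 256))
model_size = compute_module_sizes(model)[""]

# Give device 0 a budget smaller than the whole model so the remainder spills
# over to device 1, mirroring the test's per-device max_memory budget.
max_memory = {0: model_size // 2, 1: model_size * 2, "cpu": model_size * 2}
device_map = infer_auto_device_map(model, max_memory=max_memory)

# Analogue of the test's assertSetEqual: both accelerators get a share of the model.
print(device_map, set(device_map.values()))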