mirror of https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00

[tests] make test_model_parallelism device-agnostic (#30844)

* enable on xpu
* fix style
* add comment and mps

This commit is contained in:
parent 42d8dd8716
commit 04c7c176d7
@@ -76,6 +76,7 @@ from transformers.testing_utils import (
     require_safetensors,
     require_torch,
     require_torch_gpu,
+    require_torch_multi_accelerator,
     require_torch_multi_gpu,
     require_torch_sdpa,
     slow,
@@ -3009,8 +3010,11 @@ class ModelTesterMixin:
             param_device = device_map[param_name]
             if param_device in ["cpu", "disk"]:
                 self.assertEqual(param.device, torch.device("meta"))
+            elif param_device in ["mps"]:
+                self.assertEqual(param.device, torch.device("mps"))
             else:
-                self.assertEqual(param.device, torch.device(param_device))
+                # when loaded with device_map, `param_device` are integer values for cuda/xpu/npu/mlu
+                self.assertEqual(param.device, torch.device(f"{torch_device}:{param_device}"))
 
     @require_accelerate
     @mark.accelerate_tests
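
For context, the branch above encodes the convention that accelerate's device maps use: parameters offloaded to "cpu" or "disk" stay on the meta device until they are needed, "mps" exposes a single unindexed device, and accelerator entries are bare integer indices that must be combined with the active backend name. A minimal standalone sketch of that mapping, assuming torch_device holds the backend string ("cuda", "xpu", ...):

import torch

torch_device = "cuda"  # assumption: the active backend, e.g. "xpu" or "npu"


def expected_device(param_device):
    # Offloaded weights are materialized lazily, so they sit on `meta`.
    if param_device in ["cpu", "disk"]:
        return torch.device("meta")
    # mps has no device indices.
    if param_device == "mps":
        return torch.device("mps")
    # Accelerator entries are bare integers for cuda/xpu/npu/mlu.
    return torch.device(f"{torch_device}:{param_device}")


print(expected_device("disk"))  # meta
print(expected_device(1))       # cuda:1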
@@ -3129,7 +3133,7 @@ class ModelTesterMixin:
 
     @require_accelerate
     @mark.accelerate_tests
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
     def test_model_parallelism(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
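
The decorator swap is the core of the change: require_torch_multi_gpu only counts CUDA devices, while require_torch_multi_accelerator skips the test unless the active backend, whatever it is, exposes at least two devices. A rough sketch of the distinction, with a hypothetical _device_count helper standing in for the backend-dispatch logic in transformers.testing_utils:

import unittest

import torch


def _device_count(backend):
    # Hypothetical stand-in for the backend dispatch in testing_utils.
    if backend == "cuda":
        return torch.cuda.device_count()
    if backend == "xpu" and hasattr(torch, "xpu"):
        return torch.xpu.device_count()
    return 0


def require_multi_accelerator(test_case):
    # Skip unless some accelerator backend exposes two or more devices.
    n = max(_device_count("cuda"), _device_count("xpu"))
    return unittest.skipUnless(n > 1, "test requires multiple accelerators")(test_case)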
@@ -3155,7 +3159,6 @@ class ModelTesterMixin:
                     new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
                     # Making sure part of the model will actually end up offloaded
                     self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
-
                     self.check_device_map_is_respected(new_model, new_model.hf_device_map)
 
                     torch.manual_seed(0)
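
For reference, the assertion on hf_device_map checks that device_map="auto" really sharded the checkpoint: max_memory caps each device's budget so accelerate cannot place everything on device 0. A hedged usage sketch; the checkpoint name and memory caps below are illustrative placeholders, while the real test derives max_memory from the saved model's size:

from transformers import AutoModelForCausalLM

# Placeholder checkpoint and caps; small budgets force a split across devices.
model = AutoModelForCausalLM.from_pretrained(
    "sshleifer/tiny-gpt2",
    device_map="auto",
    max_memory={0: "2MB", 1: "10MB"},
)

# hf_device_map records where each submodule landed, e.g.
# {"transformer.wte": 0, "transformer.h.0": 0, ..., "lm_head": 1}
print(model.hf_device_map)

The test then verifies that both devices were actually used (set(new_model.hf_device_map.values()) == {0, 1}) before handing the map to check_device_map_is_respected.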