mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00
[tests] make test_model_parallelism
device-agnostic (#30844)
* enable on xpu * fix style * add comment and mps
This commit is contained in:
parent
42d8dd8716
commit
04c7c176d7
@ -76,6 +76,7 @@ from transformers.testing_utils import (
|
|||||||
require_safetensors,
|
require_safetensors,
|
||||||
require_torch,
|
require_torch,
|
||||||
require_torch_gpu,
|
require_torch_gpu,
|
||||||
|
require_torch_multi_accelerator,
|
||||||
require_torch_multi_gpu,
|
require_torch_multi_gpu,
|
||||||
require_torch_sdpa,
|
require_torch_sdpa,
|
||||||
slow,
|
slow,
|
||||||
@ -3009,8 +3010,11 @@ class ModelTesterMixin:
|
|||||||
param_device = device_map[param_name]
|
param_device = device_map[param_name]
|
||||||
if param_device in ["cpu", "disk"]:
|
if param_device in ["cpu", "disk"]:
|
||||||
self.assertEqual(param.device, torch.device("meta"))
|
self.assertEqual(param.device, torch.device("meta"))
|
||||||
|
elif param_device in ["mps"]:
|
||||||
|
self.assertEqual(param.device, torch.device("mps"))
|
||||||
else:
|
else:
|
||||||
self.assertEqual(param.device, torch.device(param_device))
|
# when loaded with device_map, `param_device` are integer values for cuda/xpu/npu/mlu
|
||||||
|
self.assertEqual(param.device, torch.device(f"{torch_device}:{param_device}"))
|
||||||
|
|
||||||
@require_accelerate
|
@require_accelerate
|
||||||
@mark.accelerate_tests
|
@mark.accelerate_tests
|
||||||
@ -3129,7 +3133,7 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
@require_accelerate
|
@require_accelerate
|
||||||
@mark.accelerate_tests
|
@mark.accelerate_tests
|
||||||
@require_torch_multi_gpu
|
@require_torch_multi_accelerator
|
||||||
def test_model_parallelism(self):
|
def test_model_parallelism(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
@ -3155,7 +3159,6 @@ class ModelTesterMixin:
|
|||||||
new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
|
new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
|
||||||
# Making sure part of the model will actually end up offloaded
|
# Making sure part of the model will actually end up offloaded
|
||||||
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
|
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
|
||||||
|
|
||||||
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
|
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||||
|
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
|
Loading…
Reference in New Issue
Block a user