enable xpu in test_trainer (#37774)

* enable xpu in test_trainer

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* fix style

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* enhance _device_agnostic_dispatch to cover values

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

* add default values for the case where torch is not available

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

---------

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Signed-off-by: Yao Matrix <matrix.yao@intel.com>
Authored by Yao Matrix on 2025-05-06 23:13:35 +08:00, committed by GitHub
parent 7db5d5b9ea
commit 5534b80b7f
2 changed files with 67 additions and 31 deletions

Changed file 1 of 2:

@@ -2946,10 +2946,10 @@ def _device_agnostic_dispatch(device: str, dispatch_table: dict[str, Callable],
fn = dispatch_table[device]
# Some device agnostic functions return values. Need to guard against `None`
# instead at user level.
if fn is None:
return None
# Some device agnostic functions return values or None, will return them directly.
if not callable(fn):
return fn
return fn(*args, **kwargs)
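For context, a minimal, self-contained sketch of the dispatch behaviour this hunk enables: table entries may now be plain values (or None) as well as callables. The names below are illustrative, not the library's own.

from typing import Callable, Union

def dispatch(device: str, table: dict[str, Union[Callable, int, None]], *args, **kwargs):
    # Illustrative re-implementation of the dispatch pattern shown above.
    fn = table.get(device, table["default"])
    # Non-callable entries (plain values or None) are returned as-is.
    if not callable(fn):
        return fn
    return fn(*args, **kwargs)

MEMORY_ALLOCATED = {"cuda": lambda: 123, "cpu": 0, "default": 0}
assert dispatch("cpu", MEMORY_ALLOCATED) == 0     # value entry, returned directly
assert dispatch("cuda", MEMORY_ALLOCATED) == 123  # callable entry, invoked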
@@ -2971,10 +2971,29 @@ if is_torch_available():
"cpu": lambda: 0,
"default": lambda: 1,
}
BACKEND_RESET_MAX_MEMORY_ALLOCATED = {
"cuda": torch.cuda.reset_max_memory_allocated,
"cpu": None,
"default": None,
}
BACKEND_MAX_MEMORY_ALLOCATED = {
"cuda": torch.cuda.max_memory_allocated,
"cpu": 0,
"default": 0,
}
BACKEND_MEMORY_ALLOCATED = {
"cuda": torch.cuda.memory_allocated,
"cpu": 0,
"default": 0,
}
else:
BACKEND_MANUAL_SEED = {"default": None}
BACKEND_EMPTY_CACHE = {"default": None}
BACKEND_DEVICE_COUNT = {"default": lambda: 0}
BACKEND_RESET_MAX_MEMORY_ALLOCATED = {"default": None}
BACKEND_MAX_MEMORY_ALLOCATED = {"default": 0}
BACKEND_MEMORY_ALLOCATED = {"default": 0}
if is_torch_hpu_available():
BACKEND_MANUAL_SEED["hpu"] = torch.hpu.manual_seed
@@ -2994,6 +3013,9 @@ if is_torch_xpu_available():
BACKEND_EMPTY_CACHE["xpu"] = torch.xpu.empty_cache
BACKEND_MANUAL_SEED["xpu"] = torch.xpu.manual_seed
BACKEND_DEVICE_COUNT["xpu"] = torch.xpu.device_count
BACKEND_RESET_MAX_MEMORY_ALLOCATED["xpu"] = torch.xpu.reset_peak_memory_stats
BACKEND_MAX_MEMORY_ALLOCATED["xpu"] = torch.xpu.max_memory_allocated
BACKEND_MEMORY_ALLOCATED["xpu"] = torch.xpu.memory_allocated
if is_torch_xla_available():
BACKEND_EMPTY_CACHE["xla"] = torch.cuda.empty_cache
@@ -3013,6 +3035,18 @@ def backend_device_count(device: str):
return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
def backend_reset_max_memory_allocated(device: str):
return _device_agnostic_dispatch(device, BACKEND_RESET_MAX_MEMORY_ALLOCATED)
def backend_max_memory_allocated(device: str):
return _device_agnostic_dispatch(device, BACKEND_MAX_MEMORY_ALLOCATED)
def backend_memory_allocated(device: str):
return _device_agnostic_dispatch(device, BACKEND_MEMORY_ALLOCATED)
if is_torch_available():
# If `TRANSFORMERS_TEST_DEVICE_SPEC` is enabled we need to import extra entries
# into device to function mappings.
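Taken together, these helpers give tests a device-agnostic way to query accelerator memory. A hedged usage sketch, assuming this commit's transformers.testing_utils and a CUDA or XPU runner (on CPU the reset call is a no-op and the readings are 0):

import torch
from transformers.testing_utils import (
    backend_empty_cache,
    backend_max_memory_allocated,
    backend_memory_allocated,
    backend_reset_max_memory_allocated,
    torch_device,
)

backend_empty_cache(torch_device)
backend_reset_max_memory_allocated(torch_device)  # maps to reset_peak_memory_stats on xpu
begin = backend_memory_allocated(torch_device)
x = torch.randn(1024, 1024, device=torch_device)  # stand-in for a real workload
peak = backend_max_memory_allocated(torch_device)
print(f"peak delta: {peak - begin} bytes")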

Changed file 2 of 2:

@@ -62,6 +62,10 @@ from transformers.testing_utils import (
TemporaryHubRepo,
TestCasePlus,
backend_device_count,
backend_empty_cache,
backend_max_memory_allocated,
backend_memory_allocated,
backend_reset_max_memory_allocated,
evaluate_side_effect_factory,
execute_subprocess_async,
get_gpu_count,
@@ -78,7 +82,6 @@ from transformers.testing_utils import (
require_liger_kernel,
require_lomo,
require_non_hpu,
require_non_xpu,
require_optuna,
require_peft,
require_ray,
@@ -245,18 +248,18 @@ def bytes2megabytes(x):
class TorchTracemalloc:
def __enter__(self):
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero
self.begin = torch.cuda.memory_allocated()
if torch_device in ["cuda", "xpu"]:
backend_empty_cache(torch_device)
backend_reset_max_memory_allocated(torch_device) # reset the peak gauge to zero
self.begin = backend_memory_allocated(torch_device)
return self
def __exit__(self, *exc):
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
self.end = torch.cuda.memory_allocated()
self.peak = torch.cuda.max_memory_allocated()
if torch_device in ["cuda", "xpu"]:
backend_empty_cache(torch_device)
self.end = backend_memory_allocated(torch_device)
self.peak = backend_max_memory_allocated(torch_device)
self.used = bytes2megabytes(self.end - self.begin)
self.peaked = bytes2megabytes(self.peak - self.begin)
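A hedged sketch of how the reworked context manager is exercised, assuming the TorchTracemalloc helper above is in scope and a CUDA or XPU device is available; the tensor allocation stands in for a real training step:

import torch
from transformers.testing_utils import torch_device

with TorchTracemalloc() as tm:
    x = torch.randn(2048, 2048, device=torch_device)  # stand-in accelerator workload

# used / peaked are megabyte deltas computed through the backend_* helpers,
# so the same assertions run unchanged on CUDA and XPU runners.
print(tm.used, tm.peaked)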
@@ -1246,7 +1249,6 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
# will add more specific tests once there are some bugs to fix
@require_non_xpu
@require_torch_gpu
@require_torch_tf32
def test_tf32(self):
@@ -1838,7 +1840,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
_ = trainer.train()
@require_lomo
@require_torch_gpu
@require_torch_accelerator
def test_lomo(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
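The decorator swap in this and the following hunks (require_torch_gpu to require_torch_accelerator) is what opens these tests to XPU. As a rough, illustrative sketch rather than the library's exact implementation, the accelerator variant skips only on CPU-only machines instead of demanding CUDA:

import unittest
from transformers.testing_utils import torch_device

def require_accelerator_sketch(test_case):
    # Skip unless the resolved test device is some accelerator (cuda, xpu, ...).
    return unittest.skipUnless(torch_device != "cpu", "test requires an accelerator")(test_case)

@require_accelerator_sketch
class _Example(unittest.TestCase):
    def test_runs_on_accelerator(self):
        self.assertNotEqual(torch_device, "cpu")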
@@ -1861,7 +1863,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertFalse(torch.allclose(param, previous_params[name].to(param.device), rtol=1e-12, atol=1e-12))
@require_lomo
@require_torch_gpu
@require_torch_accelerator
def test_adalomo(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
@@ -2027,7 +2029,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertFalse(is_regex)
@require_galore_torch
@require_torch_gpu
@require_torch_accelerator
def test_galore(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
@@ -2048,7 +2050,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
_ = trainer.train()
@require_galore_torch
@require_torch_gpu
@require_torch_accelerator
def test_galore_extra_args(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
@@ -2070,7 +2072,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
_ = trainer.train()
@require_galore_torch
@require_torch_gpu
@require_torch_accelerator
def test_galore_layerwise(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
@@ -2091,7 +2093,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
_ = trainer.train()
@require_galore_torch
@require_torch_gpu
@require_torch_accelerator
def test_galore_layerwise_with_scheduler(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
@@ -2113,7 +2115,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
_ = trainer.train()
@require_galore_torch
@require_torch_gpu
@require_torch_accelerator
def test_galore_adamw_8bit(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
@@ -2134,7 +2136,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
_ = trainer.train()
@require_galore_torch
@require_torch_gpu
@require_torch_accelerator
def test_galore_adafactor(self):
# These are the intervals of the peak memory usage of training such a tiny model
# if the peak memory goes outside that range, then we know there might be a bug somewhere
@@ -2166,7 +2168,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertTrue(lower_bound_pm < galore_peak_memory)
@require_galore_torch
@require_torch_gpu
@require_torch_accelerator
def test_galore_adafactor_attention_only(self):
# These are the intervals of the peak memory usage of training such a tiny model
# if the peak memory goes outside that range, then we know there might be a bug somewhere
@@ -2197,7 +2199,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertTrue(lower_bound_pm < galore_peak_memory)
@require_galore_torch
@require_torch_gpu
@require_torch_accelerator
def test_galore_adafactor_all_linear(self):
# These are the intervals of the peak memory usage of training such a tiny model
# if the peak memory goes outside that range, then we know there might be a bug somewhere
@@ -2305,7 +2307,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertTrue(len(decreasing_lrs) > len(increasing_lrs))
@require_apollo_torch
@require_torch_gpu
@require_torch_accelerator
def test_apollo(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
@@ -2326,7 +2328,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
_ = trainer.train()
@require_apollo_torch
@require_torch_gpu
@require_torch_accelerator
def test_apollo_extra_args(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
@@ -2348,7 +2350,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
_ = trainer.train()
@require_apollo_torch
@require_torch_gpu
@require_torch_accelerator
def test_apollo_layerwise(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
@@ -2369,7 +2371,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
_ = trainer.train()
@require_apollo_torch
@require_torch_gpu
@require_torch_accelerator
def test_apollo_layerwise_with_scheduler(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
@@ -2391,7 +2393,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
_ = trainer.train()
@require_apollo_torch
@require_torch_gpu
@require_torch_accelerator
def test_apollo_lr_display_without_scheduler(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
@@ -2416,7 +2418,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(trainer.get_learning_rates(), [learning_rate, learning_rate])
@require_apollo_torch
@require_torch_gpu
@require_torch_accelerator
def test_apollo_lr_display_with_scheduler(self):
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
tiny_llama = LlamaForCausalLM(config)
@@ -3995,7 +3997,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
# perfect world: fp32_init/2 == fp16_eval
self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)
@require_non_xpu
@require_torch_gpu
@require_torch_non_multi_gpu
@require_torch_tensorrt_fx
def test_torchdynamo_full_eval(self):