From 89f6956015a42ab32b35de2a6055ea65b5ca53d4 Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Wed, 12 Mar 2025 09:08:12 +0100 Subject: [PATCH] HPU support (#36424) * test * fix * fix * skip some and run some first * test fsdp * fix * patches for generate * test distributed * copy * don't test distributed loss for hpu * require fp16 and run first * changes from marc's PR fixing zero3 * better alternative * return True when fp16 support on gaudi without creating bridge * fix * fix tested dtype in deepspeed inference test * test * fix * test * fix * skip * require fp16 * run first fsdp * Apply suggestions from code review * address comments * address comments and refactor test * reduce precison * avoid doing gaudi1 specific stuff in the genreation loop * document test_gradient_accumulation_loss_alignment_with_model_loss test a bit more --- setup.py | 2 + src/transformers/__init__.py | 2 + src/transformers/commands/env.py | 5 ++ src/transformers/dependency_versions_table.py | 1 + .../modeling_encoder_decoder.py | 2 + src/transformers/pipelines/base.py | 3 + .../quantizers/quantizer_bnb_4bit.py | 3 + src/transformers/testing_utils.py | 60 +++++++++++++++ src/transformers/trainer.py | 14 +++- src/transformers/trainer_utils.py | 21 ++++++ src/transformers/training_args.py | 21 ++++-- src/transformers/utils/__init__.py | 5 ++ src/transformers/utils/import_utils.py | 64 ++++++++++++++++ tests/deepspeed/test_deepspeed.py | 32 ++++++-- tests/fsdp/test_fsdp.py | 43 +++++++---- tests/test_modeling_common.py | 2 +- tests/trainer/test_trainer.py | 74 +++++++++++-------- tests/trainer/test_trainer_distributed.py | 53 ++----------- tests/trainer/test_trainer_fsdp.py | 69 +++++++++-------- 19 files changed, 337 insertions(+), 139 deletions(-) diff --git a/setup.py b/setup.py index e44c6703b48..fcb62c61579 100644 --- a/setup.py +++ b/setup.py @@ -152,6 +152,7 @@ _deps = [ "pytest-asyncio", "pytest-timeout", "pytest-xdist", + "pytest-order", "python>=3.9.0", "ray[tune]>=2.7.0", "regex!=2019.12.17", @@ -324,6 +325,7 @@ extras["testing"] = ( "pytest-asyncio", "pytest-rich", "pytest-xdist", + "pytest-order", "timeout-decorator", "parameterized", "psutil", diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 9ce1fe1378b..da8b1cacaa7 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1016,6 +1016,7 @@ _import_structure = { "is_timm_available", "is_tokenizers_available", "is_torch_available", + "is_torch_hpu_available", "is_torch_mlu_available", "is_torch_musa_available", "is_torch_neuroncore_available", @@ -6243,6 +6244,7 @@ if TYPE_CHECKING: is_timm_available, is_tokenizers_available, is_torch_available, + is_torch_hpu_available, is_torch_mlu_available, is_torch_musa_available, is_torch_neuroncore_available, diff --git a/src/transformers/commands/env.py b/src/transformers/commands/env.py index 855bbc961bc..4162f21e954 100644 --- a/src/transformers/commands/env.py +++ b/src/transformers/commands/env.py @@ -30,6 +30,7 @@ from ..utils import ( is_safetensors_available, is_tf_available, is_torch_available, + is_torch_hpu_available, is_torch_npu_available, ) from . 
import BaseTransformersCLICommand @@ -94,6 +95,7 @@ class EnvironmentCommand(BaseTransformersCLICommand): pt_version = torch.__version__ pt_cuda_available = torch.cuda.is_available() pt_npu_available = is_torch_npu_available() + pt_hpu_available = is_torch_hpu_available() tf_version = "not installed" tf_cuda_available = "NA" @@ -149,6 +151,9 @@ class EnvironmentCommand(BaseTransformersCLICommand): if pt_cuda_available: info["Using GPU in script?"] = "" info["GPU type"] = torch.cuda.get_device_name() + elif pt_hpu_available: + info["Using HPU in script?"] = "" + info["HPU type"] = torch.hpu.get_device_name() elif pt_npu_available: info["Using NPU in script?"] = "" info["NPU type"] = torch.npu.get_device_name() diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 61b322b7363..28ae4463667 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -58,6 +58,7 @@ deps = { "pytest-asyncio": "pytest-asyncio", "pytest-timeout": "pytest-timeout", "pytest-xdist": "pytest-xdist", + "pytest-order": "pytest-order", "python": "python>=3.9.0", "ray[tune]": "ray[tune]>=2.7.0", "regex": "regex!=2019.12.17", diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index 9ab4b7f2ced..decc4f8df0f 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -598,6 +598,8 @@ class EncoderDecoderModel(PreTrainedModel, GenerationMixin): kwargs_decoder = { argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") } + if "num_items_in_batch" in kwargs_encoder: + kwargs_decoder["num_items_in_batch"] = kwargs_encoder.pop("num_items_in_batch", None) if encoder_outputs is None: encoder_outputs = self.encoder( diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index d3ee4e871e2..70b2ec8ba52 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -45,6 +45,7 @@ from ..utils import ( is_tf_available, is_torch_available, is_torch_cuda_available, + is_torch_hpu_available, is_torch_mlu_available, is_torch_mps_available, is_torch_musa_available, @@ -963,6 +964,8 @@ class Pipeline(_ScikitCompat, PushToHubMixin): self.device = torch.device(f"cuda:{device}") elif is_torch_npu_available(): self.device = torch.device(f"npu:{device}") + elif is_torch_hpu_available(): + self.device = torch.device(f"hpu:{device}") elif is_torch_xpu_available(check_device=True): self.device = torch.device(f"xpu:{device}") elif is_torch_mps_available(): diff --git a/src/transformers/quantizers/quantizer_bnb_4bit.py b/src/transformers/quantizers/quantizer_bnb_4bit.py index 8657bda1662..ab04a295460 100644 --- a/src/transformers/quantizers/quantizer_bnb_4bit.py +++ b/src/transformers/quantizers/quantizer_bnb_4bit.py @@ -29,6 +29,7 @@ from ..utils import ( is_accelerate_available, is_bitsandbytes_available, is_torch_available, + is_torch_hpu_available, is_torch_npu_available, is_torch_xpu_available, logging, @@ -269,6 +270,8 @@ class Bnb4BitHfQuantizer(HfQuantizer): device_map = {"": torch.cuda.current_device()} elif is_torch_npu_available(): device_map = {"": f"npu:{torch.npu.current_device()}"} + elif is_torch_hpu_available(): + device_map = {"": f"hpu:{torch.hpu.current_device()}"} elif is_torch_xpu_available(): device_map = {"": 
f"xpu:{torch.xpu.current_device()}"} else: diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 76d25f61cb1..f6577469bf4 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -141,6 +141,7 @@ from .utils import ( is_torch_deterministic, is_torch_fp16_available_on_device, is_torch_greater_or_equal, + is_torch_hpu_available, is_torch_neuroncore_available, is_torch_npu_available, is_torch_sdpa_available, @@ -858,6 +859,13 @@ def require_torch_multi_npu(test_case): return unittest.skipUnless(torch.npu.device_count() > 1, "test requires multiple NPUs")(test_case) +def require_non_hpu(test_case): + """ + Decorator marking a test that should be skipped for HPU. + """ + return unittest.skipUnless(torch_device != "hpu", "test requires a non-HPU")(test_case) + + def require_torch_xpu(test_case): """ Decorator marking a test that requires XPU (in PyTorch). @@ -889,6 +897,19 @@ def require_torch_multi_xpu(test_case): return unittest.skipUnless(torch.xpu.device_count() > 1, "test requires multiple XPUs")(test_case) +def require_torch_multi_hpu(test_case): + """ + Decorator marking a test that requires a multi-HPU setup (in PyTorch). These tests are skipped on a machine without + multiple HPUs. + + To run *only* the multi_hpu tests, assuming all test names contain multi_hpu: $ pytest -sv ./tests -k "multi_hpu" + """ + if not is_torch_hpu_available(): + return unittest.skip(reason="test requires PyTorch HPU")(test_case) + + return unittest.skipUnless(torch.hpu.device_count() > 1, "test requires multiple HPUs")(test_case) + + if is_torch_available(): # Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode import torch @@ -917,6 +938,10 @@ if is_torch_available(): raise ValueError( f"TRANSFORMERS_TEST_DEVICE={torch_device}, but NPU is unavailable. Please double-check your testing environment." ) + if torch_device == "hpu" and not is_torch_hpu_available(): + raise ValueError( + f"TRANSFORMERS_TEST_DEVICE={torch_device}, but HPU is unavailable. Please double-check your testing environment." + ) try: # try creating device to see if provided device is valid @@ -929,6 +954,8 @@ if is_torch_available(): torch_device = "cuda" elif _run_third_party_device_tests and is_torch_npu_available(): torch_device = "npu" + elif _run_third_party_device_tests and is_torch_hpu_available(): + torch_device = "hpu" elif _run_third_party_device_tests and is_torch_xpu_available(): torch_device = "xpu" else: @@ -2565,6 +2592,20 @@ def hub_retry(max_attempts: int = 5, wait_before_retry: Optional[float] = 2): return decorator +def run_first(test_case): + """ + Decorator marking a test with order(1). When pytest-order plugin is installed, tests marked with this decorator + are garanteed to run first. + + This is especially useful in some test settings like on a Gaudi instance where a Gaudi device can only be used by a + single process at a time. So we make sure all tests that run in a subprocess are launched first, to avoid device + allocation conflicts. + """ + import pytest + + return pytest.mark.order(1)(test_case) + + def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None): """ To run a test in a subprocess. In particular, this can avoid (GPU) memory issue. 
@@ -2853,6 +2894,25 @@ else: BACKEND_EMPTY_CACHE = {"default": None} BACKEND_DEVICE_COUNT = {"default": lambda: 0} +if is_torch_hpu_available(): + BACKEND_MANUAL_SEED["hpu"] = torch.hpu.manual_seed + BACKEND_DEVICE_COUNT["hpu"] = torch.hpu.device_count + +if is_torch_npu_available(): + BACKEND_EMPTY_CACHE["npu"] = torch.npu.empty_cache + BACKEND_MANUAL_SEED["npu"] = torch.npu.manual_seed + BACKEND_DEVICE_COUNT["npu"] = torch.npu.device_count + +if is_torch_xpu_available(): + BACKEND_EMPTY_CACHE["xpu"] = torch.xpu.empty_cache + BACKEND_MANUAL_SEED["xpu"] = torch.xpu.manual_seed + BACKEND_DEVICE_COUNT["xpu"] = torch.xpu.device_count + +if is_torch_xla_available(): + BACKEND_EMPTY_CACHE["xla"] = torch.cuda.empty_cache + BACKEND_MANUAL_SEED["xla"] = torch.cuda.manual_seed + BACKEND_DEVICE_COUNT["xla"] = torch.cuda.device_count + def backend_manual_seed(device: str, seed: int): return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 0eca0f40d0a..79109fbd667 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -166,6 +166,7 @@ from .utils import ( is_sagemaker_mp_enabled, is_schedulefree_available, is_torch_compile_available, + is_torch_hpu_available, is_torch_mlu_available, is_torch_mps_available, is_torch_musa_available, @@ -3141,9 +3142,10 @@ class Trainer: set_rng_state_for_device("CUDA", torch.cuda, checkpoint_rng_state, is_distributed) if is_torch_npu_available(): set_rng_state_for_device("NPU", torch.npu, checkpoint_rng_state, is_distributed) + if is_torch_hpu_available(): + set_rng_state_for_device("HPU", torch.hpu, checkpoint_rng_state, is_distributed) if is_torch_mlu_available(): set_rng_state_for_device("MLU", torch.mlu, checkpoint_rng_state, is_distributed) - if is_torch_musa_available(): set_rng_state_for_device("MUSA", torch.musa, checkpoint_rng_state, is_distributed) @@ -3255,6 +3257,12 @@ class Trainer: else: rng_states["npu"] = torch.npu.random.get_rng_state() + if is_torch_hpu_available(): + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: + rng_states["hpu"] = torch.hpu.random.get_rng_state_all() + else: + rng_states["hpu"] = torch.hpu.random.get_rng_state() + if is_torch_mlu_available(): if self.args.parallel_mode == ParallelMode.DISTRIBUTED: rng_states["mlu"] = torch.mlu.random.get_rng_state_all() @@ -3725,6 +3733,10 @@ class Trainer: torch.npu.empty_cache() elif is_torch_mps_available(min_version="2.0"): torch.mps.empty_cache() + elif is_torch_hpu_available(): + logger.warning( + "`torch_empty_cache_steps` is set but HPU device/backend does not support empty_cache()." 
+ ) else: torch.cuda.empty_cache() diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index f7edded7d43..982f7b7c028 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -35,6 +35,7 @@ from .utils import ( is_tf_available, is_torch_available, is_torch_cuda_available, + is_torch_hpu_available, is_torch_mlu_available, is_torch_mps_available, is_torch_musa_available, @@ -113,6 +114,8 @@ def set_seed(seed: int, deterministic: bool = False): torch.musa.manual_seed_all(seed) if is_torch_npu_available(): torch.npu.manual_seed_all(seed) + if is_torch_hpu_available(): + torch.hpu.manual_seed_all(seed) if is_torch_xpu_available(): torch.xpu.manual_seed_all(seed) if is_tf_available(): @@ -506,6 +509,11 @@ class TrainerMemoryTracker: elif is_torch_npu_available(): import torch + self.torch = torch + self.gpu = {} + elif is_torch_hpu_available(): + import torch + self.torch = torch self.gpu = {} else: @@ -573,6 +581,10 @@ elif is_torch_npu_available(): self.torch.npu.reset_peak_memory_stats() self.torch.npu.empty_cache() + elif is_torch_hpu_available(): + self.torch.hpu.reset_peak_memory_stats() + # not available on hpu as it reserves all device memory for the current process + # self.torch.hpu.empty_cache() elif is_torch_mps_available(): self.torch.mps.empty_cache() @@ -588,6 +600,8 @@ self.gpu_mem_used_at_start = self.torch.xpu.memory_allocated() elif is_torch_npu_available(): self.gpu_mem_used_at_start = self.torch.npu.memory_allocated() + elif is_torch_hpu_available(): + self.gpu_mem_used_at_start = self.torch.hpu.memory_allocated() elif is_torch_mps_available(): self.gpu_mem_used_at_start = self.torch.mps.current_allocated_memory() @@ -623,6 +637,10 @@ self.torch.xpu.empty_cache() elif is_torch_npu_available(): self.torch.npu.empty_cache() + elif is_torch_hpu_available(): + # not available on hpu as it reserves all device memory for the current process + # self.torch.hpu.empty_cache() + pass elif is_torch_mps_available(): self.torch.mps.empty_cache() @@ -648,6 +666,9 @@ elif is_torch_npu_available(): self.gpu_mem_used_now = self.torch.npu.memory_allocated() self.gpu_mem_used_peak = self.torch.npu.max_memory_allocated() + elif is_torch_hpu_available(): + self.gpu_mem_used_now = self.torch.hpu.memory_allocated() + self.gpu_mem_used_peak = self.torch.hpu.max_memory_allocated() elif is_torch_mps_available(): self.gpu_mem_used_now = self.torch.mps.current_allocated_memory() # self.torch.mps.max_memory_allocated() does not exist yet diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index be9f9f7f93b..d2ec76091f3 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -48,6 +48,7 @@ from .utils import ( is_torch_available, is_torch_bf16_cpu_available, is_torch_bf16_gpu_available, + is_torch_hpu_available, is_torch_mlu_available, is_torch_mps_available, is_torch_musa_available, @@ -260,9 +261,9 @@ class TrainingArguments: prediction_loss_only (`bool`, *optional*, defaults to `False`): When performing evaluation and generating predictions, only returns the loss. per_device_train_batch_size (`int`, *optional*, defaults to 8): - The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for training. + The batch size per device accelerator core/CPU for training. 
per_device_eval_batch_size (`int`, *optional*, defaults to 8): - The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for evaluation. + The batch size per device accelerator core/CPU for evaluation. gradient_accumulation_steps (`int`, *optional*, defaults to 1): Number of updates steps to accumulate the gradients for, before performing a backward/update pass. @@ -275,7 +276,7 @@ class TrainingArguments: eval_accumulation_steps (`int`, *optional*): Number of predictions steps to accumulate the output tensors for, before moving the results to the CPU. If - left unset, the whole predictions are accumulated on GPU/NPU/TPU before being moved to the CPU (faster but + left unset, the whole predictions are accumulated on the device accelerator before being moved to the CPU (faster but requires more memory). eval_delay (`float`, *optional*): Number of epochs or steps to wait for before the first evaluation can be performed, depending on the @@ -853,10 +854,10 @@ class TrainingArguments: ) per_device_train_batch_size: int = field( - default=8, metadata={"help": "Batch size per GPU/TPU/MPS/NPU core/CPU for training."} + default=8, metadata={"help": "Batch size per device accelerator core/CPU for training."} ) per_device_eval_batch_size: int = field( - default=8, metadata={"help": "Batch size per GPU/TPU/MPS/NPU core/CPU for evaluation."} + default=8, metadata={"help": "Batch size per device accelerator core/CPU for evaluation."} ) per_gpu_train_batch_size: Optional[int] = field( @@ -1044,7 +1045,7 @@ class TrainingArguments: use_cpu: bool = field( default=False, metadata={ - "help": "Whether or not to use cpu. If set to False, we will use cuda/tpu/mps/npu device if available." + "help": "Whether or not to use cpu. If left to False, we will use the available torch device/backend (cuda/mps/xpu/hpu etc.)" }, ) use_mps_device: bool = field( @@ -1830,7 +1831,10 @@ class TrainingArguments: if (self.torch_compile_mode is not None or self.torch_compile_backend is not None) and not self.torch_compile: self.torch_compile = True if self.torch_compile and self.torch_compile_backend is None: - self.torch_compile_backend = "inductor" + if not self.use_cpu and is_torch_hpu_available(): + self.torch_compile_backend = "hpu_backend" + else: + self.torch_compile_backend = "inductor" # accelerate integration for torch compile if self.torch_compile: @@ -2312,6 +2316,9 @@ class TrainingArguments: elif is_torch_npu_available(): device = torch.device("npu:0") torch.npu.set_device(device) + elif is_torch_hpu_available(): + device = torch.device("hpu:0") + torch.hpu.set_device(device) else: # if n_gpu is > 1 we'll use nn.DataParallel. 
# If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0` diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 85ed61339f3..9561666db76 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -148,6 +148,7 @@ from .import_utils import ( is_gguf_available, is_gptqmodel_available, is_grokadamw_available, + is_habana_gaudi1, is_hadamard_available, is_hqq_available, is_in_notebook, @@ -218,6 +219,7 @@ from .import_utils import ( is_torch_fx_available, is_torch_fx_proxy, is_torch_greater_or_equal, + is_torch_hpu_available, is_torch_mlu_available, is_torch_mps_available, is_torch_musa_available, @@ -316,6 +318,9 @@ def get_available_devices() -> FrozenSet[str]: if is_torch_npu_available(): devices.add("npu") + if is_torch_hpu_available(): + devices.add("hpu") + if is_torch_mlu_available(): devices.add("mlu") diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 8b45028542d..f114b925482 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -542,6 +542,12 @@ def is_torch_fp16_available_on_device(device): if not is_torch_available(): return False + if is_torch_hpu_available(): + if is_habana_gaudi1(): + return False + else: + return True + import torch try: @@ -573,6 +579,9 @@ def is_torch_bf16_available_on_device(device): if device == "cuda": return is_torch_bf16_gpu_available() + if device == "hpu": + return True + try: x = torch.zeros(2, 2, dtype=torch.bfloat16).to(device) _ = x @ x @@ -773,6 +782,61 @@ def is_torch_musa_available(check_device=False): return hasattr(torch, "musa") and torch.musa.is_available() +@lru_cache +def is_torch_hpu_available(): + "Checks if `torch.hpu` is available and potentially if a HPU is in the environment" + if ( + not _torch_available + or importlib.util.find_spec("habana_frameworks") is None + or importlib.util.find_spec("habana_frameworks.torch") is None + ): + return False + + torch_hpu_min_version = "1.5.0" + if _accelerate_available and version.parse(_accelerate_version) < version.parse(torch_hpu_min_version): + return False + + import torch + + if not hasattr(torch, "hpu") or not torch.hpu.is_available(): + return False + + import habana_frameworks.torch.utils.experimental as htexp # noqa: F401 + + # IlyasMoutawwakil: We patch masked_fill_ for int64 tensors to avoid a bug on Gaudi1 + # synNodeCreateWithId failed for node: masked_fill_fwd_i64 with synStatus 26 [Generic failure] + # This can be removed once Gaudi1 support is discontinued but for now we need it to keep using + # dl1.24xlarge Gaudi1 instances on AWS for testing. + # check if the device is Gaudi1 (vs Gaudi2, Gaudi3). + if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi: + original_masked_fill_ = torch.Tensor.masked_fill_ + + def patched_masked_fill_(self, mask, value): + if self.dtype == torch.int64: + logger.warning( + "In-place tensor.masked_fill_(mask, value) is not supported for int64 tensors on Gaudi1. " + "This operation will be performed out-of-place using tensor[mask] = value." 
+ ) + self[mask] = value + else: + original_masked_fill_(self, mask, value) + + torch.Tensor.masked_fill_ = patched_masked_fill_ + + return True + + +@lru_cache +def is_habana_gaudi1(): + if not is_torch_hpu_available(): + return False + + import habana_frameworks.torch.utils.experimental as htexp # noqa: F401 + + # Check if the device is Gaudi1 (vs Gaudi2, Gaudi3) + return htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi + + def is_torchdynamo_available(): if not is_torch_available(): return False diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py index 80a926f08db..003e635a108 100644 --- a/tests/deepspeed/test_deepspeed.py +++ b/tests/deepspeed/test_deepspeed.py @@ -45,12 +45,14 @@ from transformers.testing_utils import ( require_deepspeed, require_optuna, require_torch_accelerator, + require_torch_fp16, require_torch_multi_accelerator, + run_first, slow, torch_device, ) from transformers.trainer_utils import get_last_checkpoint, set_seed -from transformers.utils import SAFE_WEIGHTS_NAME, is_torch_bf16_available_on_device +from transformers.utils import SAFE_WEIGHTS_NAME, is_torch_bf16_available_on_device, is_torch_fp16_available_on_device if is_torch_available(): @@ -150,10 +152,12 @@ optims = [HF_OPTIM, DS_OPTIM] schedulers = [HF_SCHEDULER, DS_SCHEDULER] stages = [ZERO2, ZERO3] + +dtypes = [] if is_torch_bf16_available_on_device(torch_device): - dtypes = [FP16, BF16] -else: - dtypes = [FP16] + dtypes.append(BF16) +if is_torch_fp16_available_on_device(torch_device): + dtypes.append(FP16) def parameterized_custom_name_func(func, param_num, param): @@ -228,6 +232,7 @@ class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon): AutoModel.from_pretrained(T5_TINY) self.assertNotIn("Detected DeepSpeed ZeRO-3", cl.out) + @require_torch_fp16 @require_torch_accelerator def test_init_zero3_fp16(self): # test that zero.Init() works correctly under zero3/fp16 @@ -456,6 +461,7 @@ class TrainerIntegrationDeepSpeedWithCustomConfig(TestCasePlus): @require_deepspeed +@require_torch_fp16 @require_torch_accelerator class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, TrainerIntegrationCommon): """ @@ -714,7 +720,7 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T # dynamic loss scale value set to: # "fp16.initial_scale_power": 1 # plus having the same WarmupLR's warmup_min_lr == warmup_max_lr in the config file - # but for some reason going to train_len=64 the weights, weights start to mismatch with this setup. + # but for some reason going to train_len=64, the weights start to mismatch with this setup. # the culprit seems to be `initial_scale_power` - putting it back to its default 32 keeps the weights identical train_len = 64 @@ -757,8 +763,12 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T # training with half the batch size but accumulation steps as 2 should give the same # weights, but sometimes get a slight difference still of 1e-6 - self.assertAlmostEqual(no_grad_accum_a, yes_grad_accum_a, places=5) - self.assertAlmostEqual(no_grad_accum_b, yes_grad_accum_b, places=5) + if torch_device == "hpu": + self.assertAlmostEqual(no_grad_accum_a, yes_grad_accum_a, delta=1e-4) + self.assertAlmostEqual(no_grad_accum_b, yes_grad_accum_b, delta=1e-4) + else: + self.assertAlmostEqual(no_grad_accum_a, yes_grad_accum_a, places=5) + self.assertAlmostEqual(no_grad_accum_b, yes_grad_accum_b, places=5) # Relative difference. 
See the note above how to get identical loss on a small bs self.assertTrue((no_grad_accum_loss - yes_grad_accum_loss) / (no_grad_accum_loss + 1e-15) <= 1e-3) @@ -1100,6 +1110,7 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T @slow +@run_first @require_deepspeed @require_torch_accelerator class TestDeepSpeedWithLauncher(TestCasePlus): @@ -1126,6 +1137,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus): def test_basic_distributed(self, stage, dtype): self.run_and_check(stage=stage, dtype=dtype, distributed=True) + @require_torch_fp16 def test_do_eval_no_train(self): # testing only zero3 since zero2 makes no sense with inference self.run_and_check( @@ -1199,12 +1211,15 @@ class TestDeepSpeedWithLauncher(TestCasePlus): if dtype == "bf16" and not is_torch_bf16_available_on_device(torch_device): self.skipTest(reason="test requires bfloat16 hardware support") + if dtype == "fp16" and not is_torch_fp16_available_on_device(torch_device): + self.skipTest(reason="test requires fp16 hardware support") + # this is just inference, so no optimizer should be loaded # it only works for z3 (makes no sense with z1-z2) fp32 = True if dtype == "fp32" else False self.run_and_check( stage=ZERO3, - dtype=FP16, + dtype=dtype, model_name=T5_TINY, distributed=True, do_train=False, @@ -1381,6 +1396,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus): # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die execute_subprocess_async(cmd, env=self.get_env()) + @require_torch_fp16 def test_clm_from_config_zero3_fp16(self): # this test exercises AutoModel.from_config(config) - to ensure zero.Init is called diff --git a/tests/fsdp/test_fsdp.py b/tests/fsdp/test_fsdp.py index f5af373f49b..cce33cc7e6c 100644 --- a/tests/fsdp/test_fsdp.py +++ b/tests/fsdp/test_fsdp.py @@ -33,12 +33,17 @@ from transformers.testing_utils import ( require_fsdp, require_torch_accelerator, require_torch_multi_accelerator, + run_first, slow, torch_device, ) from transformers.trainer_callback import TrainerState from transformers.trainer_utils import FSDPOption, set_seed -from transformers.utils import is_accelerate_available, is_torch_bf16_available_on_device +from transformers.utils import ( + is_accelerate_available, + is_torch_bf16_available_on_device, + is_torch_fp16_available_on_device, +) if is_torch_available(): @@ -49,14 +54,19 @@ else: # default torch.distributed port DEFAULT_MASTER_PORT = "10999" -dtypes = ["fp16"] + +dtypes = [] if is_torch_bf16_available_on_device(torch_device): dtypes += ["bf16"] +if is_torch_fp16_available_on_device(torch_device): + dtypes += ["fp16"] + sharding_strategies = ["full_shard", "shard_grad_op"] state_dict_types = ["FULL_STATE_DICT", "SHARDED_STATE_DICT"] -set_seed(42) params = list(itertools.product(sharding_strategies, dtypes)) +set_seed(42) + def get_master_port(real_launcher=False): """ @@ -140,13 +150,13 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon): } self.fsdp_config = { - "backward_prefetch": "backward_pre", - "forward_prefetch": "False", - "limit_all_gathers": "False", - "use_orig_params": "True", - "sync_module_states": "True", - "cpu_ram_efficient_loading": "True", - "activation_checkpointing": "False", + "backward_prefetch": "BACKWARD_PRE", + "forward_prefetch": "false", + "limit_all_gathers": "false", + "use_orig_params": "true", + "sync_module_states": "true", + "cpu_ram_efficient_loading": "true", + "activation_checkpointing": "false", "min_num_params": 1, } @@ -202,7 +212,7 @@ class TrainerIntegrationFSDP(TestCasePlus, 
TrainerIntegrationCommon): self.assertEqual( os.environ[f"{prefix}TRANSFORMER_CLS_TO_WRAP"], ",".join(fsdp_config["transformer_layer_cls_to_wrap"]) ) - self.assertEqual(os.environ[f"{prefix}BACKWARD_PREFETCH"], fsdp_config["backward_prefetch"].upper()) + self.assertEqual(os.environ[f"{prefix}BACKWARD_PREFETCH"], fsdp_config["backward_prefetch"]) self.assertEqual(os.environ[f"{prefix}FORWARD_PREFETCH"], fsdp_config["forward_prefetch"]) self.assertEqual(os.environ[f"{prefix}USE_ORIG_PARAMS"], fsdp_config["use_orig_params"]) self.assertEqual(os.environ[f"{prefix}SYNC_MODULE_STATES"], fsdp_config["sync_module_states"]) @@ -213,6 +223,7 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon): @parameterized.expand(params, name_func=_parameterized_custom_name_func) @require_torch_multi_accelerator + @run_first @slow def test_basic_run(self, sharding_strategy, dtype): launcher = get_launcher(distributed=True, use_accelerate=False) @@ -225,6 +236,7 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon): @parameterized.expand(params, name_func=_parameterized_custom_name_func) @require_torch_multi_accelerator + @run_first @slow def test_basic_run_with_gradient_accumulation(self, sharding_strategy, dtype): launcher = get_launcher(distributed=True, use_accelerate=False) @@ -237,6 +249,7 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon): @parameterized.expand(dtypes) @require_torch_multi_accelerator + @run_first @slow @unittest.skipIf(not is_torch_greater_or_equal_than_2_1, reason="This test on pytorch 2.0 takes 4 hours.") def test_basic_run_with_cpu_offload(self, dtype): @@ -250,6 +263,7 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon): @parameterized.expand(state_dict_types, name_func=_parameterized_custom_name_func) @require_torch_multi_accelerator + @run_first @slow def test_training_and_can_resume_normally(self, state_dict_type): output_dir = self.get_auto_remove_tmp_dir("./xxx", after=False) @@ -286,10 +300,13 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon): self.assertAlmostEqual(log["learning_rate"], log1["learning_rate"], delta=1e-5) @require_torch_multi_accelerator + @run_first @slow - @require_torch_accelerator - @require_fsdp def test_fsdp_cpu_offloading(self): + # TODO: This file is missing and should be added or the test should be removed + if not os.path.exists("utils/testing_scripts/fsdp_cpu_offloading.py"): + raise unittest.SkipTest("FSDP CPU offloading script not found!") + try: subprocess.run( "accelerate launch utils/testing_scripts/fsdp_cpu_offloading.py --config utils/testing_scripts/dummy_fsdp_config.yml", diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 0c0005d8287..b8c41c4ed49 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2770,7 +2770,7 @@ class ModelTesterMixin: elif param_device in ["mps"]: self.assertEqual(param.device, torch.device("mps")) else: - # when loaded with device_map, `param_device` are integer values for cuda/xpu/npu/mlu + # when loaded with device_map, `param_device` are integer values for cuda/xpu/hpu/npu/mlu self.assertEqual(param.device, torch.device(f"{torch_device}:{param_device}")) @require_accelerate diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 0a78a524f40..beee7fcb48a 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -75,6 +75,7 @@ from transformers.testing_utils import ( require_intel_extension_for_pytorch, 
require_liger_kernel, require_lomo, + require_non_hpu, require_non_xpu, require_optuna, require_peft, @@ -88,6 +89,7 @@ from transformers.testing_utils import ( require_torch, require_torch_accelerator, require_torch_bf16, + require_torch_fp16, require_torch_gpu, require_torch_multi_accelerator, require_torch_non_multi_accelerator, @@ -98,6 +100,7 @@ from transformers.testing_utils import ( require_torchdynamo, require_vision, require_wandb, + run_first, run_test_using_subprocess, slow, torch_device, @@ -119,6 +122,13 @@ from transformers.utils import ( from transformers.utils.hp_naming import TrialShortNamer +if torch_device == "hpu": + RTOL = 1e-3 + ATOL = 1e-3 +else: + RTOL = 1e-5 + ATOL = 1e-5 + if is_torch_available(): import torch from torch import nn @@ -726,11 +736,11 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon): trainer.train() self.alternate_trained_model = (trainer.model.a, trainer.model.b) - def check_trained_model(self, model, alternate_seed=False): + def check_trained_model(self, model, alternate_seed=False, **kwargs): # Checks a training seeded with learning_rate = 0.1 (a, b) = self.alternate_trained_model if alternate_seed else self.default_trained_model - torch.testing.assert_close(model.a, a) - torch.testing.assert_close(model.b, b) + torch.testing.assert_close(model.a, a, **kwargs) + torch.testing.assert_close(model.b, b, **kwargs) def test_reproducible_training(self): # Checks that training worked, model trained and seed made a reproducible training. @@ -812,11 +822,6 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon): data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) - model = AutoModelForCausalLM.from_pretrained(model_name) - state_dict = model.state_dict() - - base_loss_callback = StoreLossCallback() - args_kwargs = { "report_to": "none", "logging_steps": 1, @@ -830,6 +835,10 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon): tmp_dir, **args_kwargs, ) + # train with base loss + set_seed(42) + model = AutoModelForCausalLM.from_pretrained(model_name) + base_loss_callback = StoreLossCallback() trainer = Trainer( model, args, @@ -840,16 +849,17 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon): assert trainer.model_accepts_loss_kwargs trainer.train() - grad_accum_loss_callback = StoreLossCallback() - with tempfile.TemporaryDirectory() as tmp_dir: args = TrainingArguments( tmp_dir, **args_kwargs, gradient_accumulation_steps=2, per_device_train_batch_size=4, ) + + # train with gradient accumulation set_seed(42) model = AutoModelForCausalLM.from_pretrained(model_name) + grad_accum_loss_callback = StoreLossCallback() trainer = Trainer( model, args, @@ -857,10 +867,12 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon): callbacks=[grad_accum_loss_callback], data_collator=data_collator, ) + assert trainer.model_accepts_loss_kwargs trainer.train() + # train with broken loss set_seed(42) - model.load_state_dict(state_dict) + model = AutoModelForCausalLM.from_pretrained(model_name) broken_loss_callback = StoreLossCallback() trainer = Trainer( model, @@ -869,30 +881,28 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon): callbacks=[broken_loss_callback], data_collator=data_collator, ) - # disable model_accepts_loss_kwargs + # disable model_accepts_loss_kwargs so that "num_items_in_batch" is not passed to the model trainer.model_accepts_loss_kwargs = False trainer.train() - # Calculate 
the difference between the base loss and the grad_accum loss - diff_truth = [ - abs(base - grad) for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses) - ] - diff_broken = [ - abs(base - grad) for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses) - ] + # Calculate the difference between the base loss and the grad_accum loss + diff_truth = [ + abs(base - grad) for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses) + ] + diff_broken = [abs(base - grad) for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)] - # all diff truth should be quite close - self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01") + # all diff truth should be quite close + self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01") - # max diff broken should be very off - self.assertGreater(max(diff_broken), 1.5, f"Difference {max(diff_broken)} is not greater than 2") + # max diff broken should be very off + self.assertGreater(max(diff_broken), 1.3, f"Difference {max(diff_broken)} is not greater than 1.3") - loss_base = sum(base_loss_callback.losses) - loss_broken = sum(broken_loss_callback.losses) + loss_base = sum(base_loss_callback.losses) + loss_broken = sum(broken_loss_callback.losses) - # mean/sum loss should not vary too much. - relative_diff = abs(loss_base - loss_broken) / max(loss_base, loss_broken) - self.assertLess(relative_diff, 0.2, f"Relative difference {relative_diff} is not within 0.2") + # mean/sum loss should not vary too much. + relative_diff = abs(loss_base - loss_broken) / max(loss_base, loss_broken) + self.assertLess(relative_diff, 0.2, f"Relative difference {relative_diff} is not within 0.2") def test_gradient_accumulation_loss_alignment_with_loss_func(self): set_seed(42) @@ -1214,14 +1224,14 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon): self.assertFalse(torch.allclose(trainer.model.b, b)) self.assertGreater(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 0) - @require_torch_accelerator @require_torch_bf16 + @require_torch_accelerator def test_mixed_bf16(self): # very basic test with tempfile.TemporaryDirectory() as tmp_dir: trainer = get_regression_trainer(learning_rate=0.1, bf16=True, output_dir=tmp_dir) trainer.train() - self.check_trained_model(trainer.model) + self.check_trained_model(trainer.model, atol=ATOL, rtol=RTOL) # --bf16 --half_precision_backend apex can't be used together with tempfile.TemporaryDirectory() as tmp_dir: @@ -3582,6 +3592,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): ) @slow + @run_first def test_trainer_eval_mrpc(self): MODEL_ID = "google-bert/bert-base-cased-finetuned-mrpc" tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) @@ -3598,6 +3609,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): self.assertLess(result["eval_loss"], 0.2) @slow + @run_first def test_trainer_eval_multiple(self): MODEL_ID = "openai-community/gpt2" tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) @@ -3897,6 +3909,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): trainer = get_regression_trainer(skip_memory_metrics=True, output_dir=tmp_dir) self.check_mem_metrics(trainer, self.assertNotIn) + @require_torch_fp16 @require_torch_accelerator def test_fp16_full_eval(self): # this is a sensitive test so let's keep debugging printouts in place for quick diagnosis. 
@@ -4152,6 +4165,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): self.assertListEqual(trainer.optimizer.param_groups[1]["params"], no_wd_params) @slow + @require_non_hpu @require_torch_multi_accelerator def test_end_to_end_example(self): # Tests that `translation.py` will run without issues diff --git a/tests/trainer/test_trainer_distributed.py b/tests/trainer/test_trainer_distributed.py index 968f800174a..f7f34b83e7c 100644 --- a/tests/trainer/test_trainer_distributed.py +++ b/tests/trainer/test_trainer_distributed.py @@ -19,12 +19,11 @@ import numpy as np from transformers import EvalPrediction, HfArgumentParser, TrainingArguments, is_torch_available from transformers.testing_utils import ( TestCasePlus, + backend_device_count, execute_subprocess_async, get_torch_dist_unique_port, - require_torch_multi_gpu, - require_torch_multi_xpu, - require_torch_neuroncore, - require_torch_npu, + require_torch_multi_accelerator, + torch_device, ) from transformers.training_args import ParallelMode from transformers.utils import logging @@ -117,38 +116,10 @@ if is_torch_available(): return result -class TestTrainerDistributedNeuronCore(TestCasePlus): - @require_torch_neuroncore - def test_trainer(self): - distributed_args = f"""--nproc_per_node=2 - --master_port={get_torch_dist_unique_port()} - {self.test_file_dir}/test_trainer_distributed.py - """.split() - output_dir = self.get_auto_remove_tmp_dir() - args = f"--output_dir {output_dir}".split() - cmd = ["torchrun"] + distributed_args + args - execute_subprocess_async(cmd, env=self.get_env()) - # successful return here == success - any errors would have caused an error in the sub-call - - -class TestTrainerDistributedNPU(TestCasePlus): - @require_torch_npu - def test_trainer(self): - distributed_args = f"""--nproc_per_node=2 - --master_port={get_torch_dist_unique_port()} - {self.test_file_dir}/test_trainer_distributed.py - """.split() - output_dir = self.get_auto_remove_tmp_dir() - args = f"--output_dir {output_dir}".split() - cmd = ["torchrun"] + distributed_args + args - execute_subprocess_async(cmd, env=self.get_env()) - # successful return here == success - any errors would have caused an error in the sub-call - - class TestTrainerDistributed(TestCasePlus): - @require_torch_multi_gpu + @require_torch_multi_accelerator def test_trainer(self): - distributed_args = f"""--nproc_per_node={torch.cuda.device_count()} + distributed_args = f"""--nproc_per_node={backend_device_count(torch_device)} --master_port={get_torch_dist_unique_port()} {self.test_file_dir}/test_trainer_distributed.py """.split() @@ -159,20 +130,6 @@ class TestTrainerDistributed(TestCasePlus): # successful return here == success - any errors would have caused an error in the sub-call -@require_torch_multi_xpu -class TestTrainerDistributedXPU(TestCasePlus): - def test_trainer(self): - distributed_args = f"""--nproc_per_node={torch.xpu.device_count()} - --master_port={get_torch_dist_unique_port()} - {self.test_file_dir}/test_trainer_distributed.py - """.split() - output_dir = self.get_auto_remove_tmp_dir() - args = f"--output_dir {output_dir}".split() - cmd = ["torchrun"] + distributed_args + args - execute_subprocess_async(cmd, env=self.get_env()) - # successful return here == success - any errors would have caused an error in the sub-call - - if __name__ == "__main__": # The script below is meant to be run under torch.distributed, on a machine with multiple GPUs: # diff --git a/tests/trainer/test_trainer_fsdp.py b/tests/trainer/test_trainer_fsdp.py index 
eca6a30664f..255739a2d7f 100644 --- a/tests/trainer/test_trainer_fsdp.py +++ b/tests/trainer/test_trainer_fsdp.py @@ -17,12 +17,15 @@ from typing import Dict from transformers import is_torch_available from transformers.testing_utils import ( TestCasePlus, + backend_device_count, execute_subprocess_async, get_torch_dist_unique_port, require_accelerate, require_fp8, require_fsdp, - require_torch_multi_gpu, + require_torch_multi_accelerator, + run_first, + torch_device, ) @@ -64,9 +67,10 @@ if is_torch_available(): class TestFSDPTrainer(TestCasePlus): + @require_torch_multi_accelerator @require_accelerate - @require_torch_multi_gpu @require_fsdp + @run_first def test_trainer(self): output_dir = self.get_auto_remove_tmp_dir() cmd = [ @@ -76,7 +80,7 @@ class TestFSDPTrainer(TestCasePlus): "--main_process_port", f"{get_torch_dist_unique_port()}", "--num_processes", - f"{torch.cuda.device_count()}", + f"{backend_device_count(torch_device)}", "--fsdp_transformer_layer_cls_to_wrap", "GPT2Block", f"{self.test_file_dir}/test_trainer_fsdp.py", @@ -90,10 +94,11 @@ class TestFSDPTrainer(TestCasePlus): class TestFSDPTrainerFP8(TestCasePlus): + @require_torch_multi_accelerator @require_accelerate - @require_torch_multi_gpu @require_fsdp @require_fp8 + @run_first def test_trainer(self): output_dir = self.get_auto_remove_tmp_dir() cmd = [ @@ -103,7 +108,7 @@ class TestFSDPTrainerFP8(TestCasePlus): "--main_process_port", f"{get_torch_dist_unique_port()}", "--num_processes", - f"{torch.cuda.device_count()}", + f"{backend_device_count(torch_device)}", "--mixed_precision", "fp8", "--fsdp_transformer_layer_cls_to_wrap", @@ -117,32 +122,34 @@ class TestFSDPTrainerFP8(TestCasePlus): execute_subprocess_async(cmd, env=self.get_env()) # successful return here == success - any errors would have caused an error in the sub-call - class TestFSDPTrainerWrap(TestCasePlus): - @require_accelerate - @require_torch_multi_gpu - @require_fsdp - def test_trainer(self): - output_dir = self.get_auto_remove_tmp_dir() - cmd = [ - "accelerate", - "launch", - "--use_fsdp", - "--main_process_port", - f"{get_torch_dist_unique_port()}", - "--num_processes", - f"{torch.cuda.device_count()}", - "--fsdp_transformer_layer_cls_to_wrap", - "GPT2Block", - f"{self.test_file_dir}/test_trainer_fsdp.py", - "--output_dir", - f"{output_dir}", - "--report_to", - "none", - "--auto_find_batch_size", - "True", - ] - execute_subprocess_async(cmd, env=self.get_env()) - # successful return here == success - any errors would have caused an error in the sub-call + +class TestFSDPTrainerWrap(TestCasePlus): + @require_torch_multi_accelerator + @require_accelerate + @require_fsdp + @run_first + def test_trainer(self): + output_dir = self.get_auto_remove_tmp_dir() + cmd = [ + "accelerate", + "launch", + "--use_fsdp", + "--main_process_port", + f"{get_torch_dist_unique_port()}", + "--num_processes", + f"{backend_device_count(torch_device)}", + "--fsdp_transformer_layer_cls_to_wrap", + "GPT2Block", + f"{self.test_file_dir}/test_trainer_fsdp.py", + "--output_dir", + f"{output_dir}", + "--report_to", + "none", + "--auto_find_batch_size", + "True", + ] + execute_subprocess_async(cmd, env=self.get_env()) + # successful return here == success - any errors would have caused an error in the sub-call if __name__ == "__main__":
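As a usage illustration (a sketch under stated assumptions, not part of the diff above): the helpers this patch adds to transformers.testing_utils, run_first and require_torch_multi_hpu, are meant to be combined with the existing TestCasePlus subprocess machinery so that a multi-HPU test is both gated on the hardware and scheduled before other tests, respecting the one-process-per-Gaudi-device constraint described in the run_first docstring. The test class and the companion script name example_hpu_distributed.py below are invented for illustration; the imports and helper calls mirror the patterns used in the tests touched by this patch.

from transformers.testing_utils import (
    TestCasePlus,
    backend_device_count,
    execute_subprocess_async,
    get_torch_dist_unique_port,
    require_torch_multi_hpu,
    run_first,
    torch_device,
)


class ExampleHPUDistributedTest(TestCasePlus):
    @require_torch_multi_hpu
    @run_first  # launched before other tests so the Gaudi devices are not already claimed by another process
    def test_distributed_smoke(self):
        # "example_hpu_distributed.py" is a hypothetical companion script launched under torchrun,
        # following the same pattern as tests/trainer/test_trainer_distributed.py
        distributed_args = f"""--nproc_per_node={backend_device_count(torch_device)}
            --master_port={get_torch_dist_unique_port()}
            {self.test_file_dir}/example_hpu_distributed.py
        """.split()
        output_dir = self.get_auto_remove_tmp_dir()
        args = f"--output_dir {output_dir}".split()
        cmd = ["torchrun"] + distributed_args + args
        execute_subprocess_async(cmd, env=self.get_env())
        # a successful return is the assertion: any error in the subprocess raises here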