From 18e0cae207a38d2c430b5fa08f9597312d1c1ab3 Mon Sep 17 00:00:00 2001
From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
Date: Thu, 3 Jul 2025 11:17:27 +0200
Subject: [PATCH] Fix many HPU failures in the CI (#39066)

* more torch.hpu patches

* increase top_k because a low value results in flaky behavior when
  Temperature, TopP and TopK are used together, which ends up killing
  beams early

* remove temporary fix

* fix scatter operation when input and src are the same

* trigger

* fix and reduce

* skip batch size finding as it makes the HPU go loco

* fix fsdp (yay all are passing)

* fix equality checks on NaN values

* style

* remove models list

* order

* rename to cuda_extensions

* Update src/transformers/trainer.py

---
 .../workflows/self-scheduled-intel-gaudi.yml | 35 ++++-----
 .../self-scheduled-intel-gaudi3-caller.yml   |  4 +-
 src/transformers/utils/import_utils.py       | 75 +++++++++++--------
 tests/trainer/test_trainer.py                | 10 +++
 utils/split_model_tests.py                   |  1 +
 5 files changed, 71 insertions(+), 54 deletions(-)

diff --git a/.github/workflows/self-scheduled-intel-gaudi.yml b/.github/workflows/self-scheduled-intel-gaudi.yml
index 2db5ece064b..ad14d66b58b 100644
--- a/.github/workflows/self-scheduled-intel-gaudi.yml
+++ b/.github/workflows/self-scheduled-intel-gaudi.yml
@@ -84,8 +84,6 @@ jobs:
       machine_type: ${{ matrix.machine_type }}
       folder_slices: ${{ needs.setup.outputs.folder_slices }}
       runner: ${{ inputs.runner_scale_set }}-${{ matrix.machine_type }}
-      report_name_prefix: run_models_gpu
-
     secrets: inherit
 
   run_trainer_and_fsdp_gpu:
@@ -104,11 +102,10 @@
       folder_slices: ${{ needs.setup.outputs.folder_slices }}
       runner: ${{ inputs.runner_scale_set }}-${{ matrix.machine_type }}
       report_name_prefix: run_trainer_and_fsdp_gpu
-
     secrets: inherit
 
-  run_pipelines_gpu:
-    if: ${{ inputs.job == 'run_pipelines_gpu' }}
+  run_pipelines_torch_gpu:
+    if: ${{ inputs.job == 'run_pipelines_torch_gpu' }}
     name: Pipelines
     strategy:
       fail-fast: false
@@ -161,20 +158,20 @@
 
     - name: Run all pipeline tests on Intel Gaudi
       run: |
-        python3 -m pytest -v --make-reports=${{ env.machine_type }}_run_pipelines_gpu_test_reports tests/pipelines -m "not not_device_test"
+        python3 -m pytest -v --make-reports=${{ env.machine_type }}_run_pipelines_torch_gpu_test_reports tests/pipelines -m "not not_device_test"
 
     - name: Failure short reports
       if: ${{ failure() }}
       continue-on-error: true
       run: |
-        cat reports/${{ env.machine_type }}_run_pipelines_gpu_test_reports/failures_short.txt
+        cat reports/${{ env.machine_type }}_run_pipelines_torch_gpu_test_reports/failures_short.txt
 
-    - name: "Test suite reports artifacts: ${{ env.machine_type }}_run_pipelines_gpu_test_reports"
+    - name: "Test suite reports artifacts: ${{ env.machine_type }}_run_pipelines_torch_gpu_test_reports"
       if: ${{ always() }}
       uses: actions/upload-artifact@v4
       with:
-        name: ${{ env.machine_type }}_run_pipelines_gpu_test_reports
-        path: reports/${{ env.machine_type }}_run_pipelines_gpu_test_reports
+        name: ${{ env.machine_type }}_run_pipelines_torch_gpu_test_reports
+        path: reports/${{ env.machine_type }}_run_pipelines_torch_gpu_test_reports
 
   run_examples_gpu:
     if: ${{ inputs.job == 'run_examples_gpu' }}
@@ -248,8 +245,8 @@
       name: ${{ env.machine_type }}_run_examples_gpu_test_reports
       path: reports/${{ env.machine_type }}_run_examples_gpu_test_reports
 
-  run_deepspeed_gpu:
-    if: ${{ inputs.job == 'run_deepspeed_gpu' }}
+  run_torch_cuda_extensions_gpu:
+    if: ${{ inputs.job == 'run_torch_cuda_extensions_gpu' }}
     name: Intel Gaudi deepspeed tests
     strategy:
      fail-fast: false
@@ -305,20 +302,20 @@
 
     - name: Run all deepspeed tests on intel Gaudi
       run: |
-        python3 -m pytest -v --make-reports=${{ env.machine_type }}_run_deepspeed_gpu_test_reports tests/deepspeed -m "not not_device_test"
+        python3 -m pytest -v --make-reports=${{ env.machine_type }}_run_torch_cuda_extensions_gpu_test_reports tests/deepspeed -m "not not_device_test"
 
     - name: Failure short reports
       if: ${{ failure() }}
       continue-on-error: true
       run: |
-        cat reports/${{ env.machine_type }}_run_deepspeed_gpu_test_reports/failures_short.txt
+        cat reports/${{ env.machine_type }}_run_torch_cuda_extensions_gpu_test_reports/failures_short.txt
 
-    - name: "Test suite reports artifacts: ${{ env.machine_type }}_run_deepspeed_gpu_test_reports"
+    - name: "Test suite reports artifacts: ${{ env.machine_type }}_run_torch_cuda_extensions_gpu_test_reports"
       if: ${{ always() }}
       uses: actions/upload-artifact@v4
       with:
-        name: ${{ env.machine_type }}_run_deepspeed_gpu_test_reports
-        path: reports/${{ env.machine_type }}_run_deepspeed_gpu_test_reports
+        name: ${{ env.machine_type }}_run_torch_cuda_extensions_gpu_test_reports
+        path: reports/${{ env.machine_type }}_run_torch_cuda_extensions_gpu_test_reports
 
   send_results:
     name: Slack Report
@@ -327,8 +324,8 @@
       setup,
       run_models_gpu,
       run_examples_gpu,
-      run_pipelines_gpu,
-      run_deepspeed_gpu,
+      run_torch_cuda_extensions_gpu,
+      run_pipelines_torch_gpu,
       run_trainer_and_fsdp_gpu,
     ]
     if: ${{ always() }}
diff --git a/.github/workflows/self-scheduled-intel-gaudi3-caller.yml b/.github/workflows/self-scheduled-intel-gaudi3-caller.yml
index 83cb89290d3..8a3d70c4d43 100644
--- a/.github/workflows/self-scheduled-intel-gaudi3-caller.yml
+++ b/.github/workflows/self-scheduled-intel-gaudi3-caller.yml
@@ -23,7 +23,7 @@ jobs:
     name: Pipeline CI
     uses: ./.github/workflows/self-scheduled-intel-gaudi.yml
     with:
-      job: run_pipelines_gpu
+      job: run_pipelines_torch_gpu
       ci_event: Scheduled CI (Intel) - Gaudi3
       runner_scale_set: itac-bm-emr-gaudi3-dell
       slack_report_channel: "#transformers-ci-daily-intel-gaudi3"
@@ -47,7 +47,7 @@
     name: DeepSpeed CI
     uses: ./.github/workflows/self-scheduled-intel-gaudi.yml
    with:
-      job: run_deepspeed_gpu
+      job: run_torch_cuda_extensions_gpu
       ci_event: Scheduled CI (Intel) - Gaudi3
       runner_scale_set: itac-bm-emr-gaudi3-dell
       slack_report_channel: "#transformers-ci-daily-intel-gaudi3"
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 2c28cbde292..5c9c6d5690d 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -865,50 +865,59 @@ def is_torch_hpu_available():
     if not hasattr(torch, "hpu") or not torch.hpu.is_available():
         return False
 
-    import habana_frameworks.torch.utils.experimental as htexp  # noqa: F401
-
-    # IlyasMoutawwakil: We patch masked_fill_ for int64 tensors to avoid a bug on Gaudi1
-    # synNodeCreateWithId failed for node: masked_fill_fwd_i64 with synStatus 26 [Generic failure]
-    # This can be removed once Gaudi1 support is discontinued but for now we need it to keep using
-    # dl1.24xlarge Gaudi1 instances on AWS for testing.
-    # check if the device is Gaudi1 (vs Gaudi2, Gaudi3).
-    if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi:
-        original_masked_fill_ = torch.Tensor.masked_fill_
-
-        def patched_masked_fill_(self, mask, value):
-            if self.dtype == torch.int64:
-                logger.warning_once(
-                    "In-place tensor.masked_fill_(mask, value) is not supported for int64 tensors on Gaudi1. "
" - "This operation will be performed out-of-place using tensor[mask] = value." - ) - self[mask] = value - else: - original_masked_fill_(self, mask, value) - - torch.Tensor.masked_fill_ = patched_masked_fill_ - # We patch torch.gather for int64 tensors to avoid a bug on Gaudi # Graph compile failed with synStatus 26 [Generic failure] # This can be removed once bug is fixed but for now we need it. - original_gather = torch.Tensor.gather + original_gather = torch.gather def patched_gather(input: torch.Tensor, dim: int, index: torch.LongTensor) -> torch.Tensor: if input.dtype == torch.int64 and input.device.type == "hpu": - logger.warning_once( - "torch.gather is not supported for int64 tensors on Gaudi. " - "This operation will be performed patched_gather using indexing." - ) - - idx = [torch.arange(size, device=input.device, dtype=input.dtype) for size in input.shape] - idx[dim] = index - idx = tuple(idx) - output = input[idx] - return output + return original_gather(input.to(torch.int32), dim, index).to(torch.int64) else: return original_gather(input, dim, index) + torch.gather = patched_gather torch.Tensor.gather = patched_gather + original_take_along_dim = torch.take_along_dim + + def patched_take_along_dim( + input: torch.Tensor, indices: torch.LongTensor, dim: Optional[int] = None + ) -> torch.Tensor: + if input.dtype == torch.int64 and input.device.type == "hpu": + return original_take_along_dim(input.to(torch.int32), indices, dim).to(torch.int64) + else: + return original_take_along_dim(input, indices, dim) + + torch.take_along_dim = patched_take_along_dim + + original_cholesky = torch.linalg.cholesky + + def safe_cholesky(A, *args, **kwargs): + output = original_cholesky(A, *args, **kwargs) + + if torch.isnan(output).any(): + jitter_value = 1e-9 + diag_jitter = torch.eye(A.size(-1), dtype=A.dtype, device=A.device) * jitter_value + output = original_cholesky(A + diag_jitter, *args, **kwargs) + + return output + + torch.linalg.cholesky = safe_cholesky + + original_scatter = torch.scatter + + def patched_scatter( + input: torch.Tensor, dim: int, index: torch.Tensor, src: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: + if input.device.type == "hpu" and input is src: + return original_scatter(input, dim, index, src.clone(), *args, **kwargs) + else: + return original_scatter(input, dim, index, src, *args, **kwargs) + + torch.scatter = patched_scatter + torch.Tensor.scatter = patched_scatter + # IlyasMoutawwakil: we patch torch.compile to use the HPU backend by default # https://github.com/huggingface/transformers/pull/38790#discussion_r2157043944 # This is necessary for cases where torch.compile is used as a decorator (defaulting to inductor) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 878940b937f..1a7f5120253 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -662,6 +662,11 @@ class TrainerIntegrationCommon: metrics = trainer.evaluate() self.assertEqual(metrics[metric], best_value) + def remove_nan_logs(self, log): + for key in list(log.keys()): + if log[key] != log[key]: # Check if the value is NaN + del log[key] + def check_trainer_state_are_the_same(self, trainer_state, trainer_state1): # We'll pop things so operate on copies. 
         state = trainer_state.copy()
@@ -675,6 +680,10 @@ class TrainerIntegrationCommon:
         for key in skip_log_keys:
             _ = log.pop(key, None)
             _ = log1.pop(key, None)
+
+        self.remove_nan_logs(log)
+        self.remove_nan_logs(log1)
+
         self.assertEqual(log, log1)
 
     def convert_to_sharded_checkpoint(self, folder, save_safe=True, load_safe=True):
@@ -3174,6 +3183,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertAlmostEqual(b, b1, delta=1e-5)
 
     @slow
+    @require_non_hpu
     @require_accelerate
     @require_torch_non_multi_accelerator
     def test_auto_batch_size_finder(self):
diff --git a/utils/split_model_tests.py b/utils/split_model_tests.py
index e5083aaeb46..3539a2fb317 100644
--- a/utils/split_model_tests.py
+++ b/utils/split_model_tests.py
@@ -62,4 +62,5 @@ if __name__ == "__main__":
         start = end
         end = start + num_jobs_per_splits + (1 if idx < num_jobs % args.num_splits else 0)
         model_splits.append(d[start:end])
+    print(model_splits)
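
A note on the int64 gather workaround in import_utils.py above: since int64 gathers fail
to compile on Gaudi, patched_gather and patched_take_along_dim downcast the input to
int32, gather, and cast the result back to int64. The sketch below (CPU, illustrative
values only, not part of the patch) shows the round-trip; the implicit assumption is
that every gathered value fits in int32, otherwise the downcast would silently overflow.

    import torch

    def gather_via_int32(input, dim, index):
        # Same shape of workaround as patched_gather: downcast, gather, upcast.
        # Only safe while every value of `input` fits in int32.
        return torch.gather(input.to(torch.int32), dim, index).to(torch.int64)

    x = torch.tensor([[10, 20, 30], [40, 50, 60]], dtype=torch.int64)
    idx = torch.tensor([[2, 0], [1, 1]])
    assert torch.equal(gather_via_int32(x, 1, idx), torch.gather(x, 1, idx))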
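
safe_cholesky detects a failed decomposition by NaNs in its output (per the patch, the
HPU kernel produces NaNs where CPU would raise) and retries after adding 1e-9 to the
diagonal, a standard jitter trick for matrices that are only borderline positive
definite. A minimal CPU sketch of the jitter step, with an illustrative singular matrix
(float64 on purpose: a 1e-9 jitter would vanish in float32 next to these diagonal
entries):

    import torch

    # Rank-deficient, so not positive-definite: plain Cholesky fails on it.
    A = torch.tensor([[4.0, 2.0], [2.0, 1.0]], dtype=torch.float64)

    jitter_value = 1e-9
    diag_jitter = torch.eye(A.size(-1), dtype=A.dtype, device=A.device) * jitter_value
    L = torch.linalg.cholesky(A + diag_jitter)  # succeeds on the jittered matrix
    assert not torch.isnan(L).any()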
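
patched_scatter clones src when it is the very same tensor object as input on HPU,
which is the aliased call the commit message flags ("fix scatter operation when input
and src are the same"). A CPU sketch of the pattern being guarded, with illustrative
shapes; on CPU the out-of-place kernel already handles the alias:

    import torch

    t = torch.arange(6).reshape(2, 3)
    index = torch.tensor([[2, 1, 0], [0, 2, 1]])

    # input and src alias the same tensor: this is the call the patch
    # intercepts on HPU by scattering a clone of src instead.
    aliased = torch.scatter(t, 1, index, t)
    safe = torch.scatter(t, 1, index, t.clone())
    assert torch.equal(aliased, safe)  # identical on CPU; the HPU kernel was not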
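
Finally, the remove_nan_logs helper in the tests leans on the IEEE 754 rule that NaN
compares unequal to everything, including itself, so log[key] != log[key] is true
exactly for NaN entries; without it, two otherwise identical trainer logs containing a
NaN loss fail assertEqual. A self-contained sketch (dictionary contents are
illustrative):

    def remove_nan_logs(log):
        for key in list(log.keys()):
            if log[key] != log[key]:  # true only for NaN values
                del log[key]

    log = {"loss": float("nan"), "epoch": 1.0}
    log1 = {"loss": float("nan"), "epoch": 1.0}
    assert log != log1  # the NaN entries make the dicts compare unequal

    remove_nan_logs(log)
    remove_nan_logs(log1)
    assert log == log1 == {"epoch": 1.0}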