From a5a0c7b88828a7273bdedbdcaa2c7a252084c0d8 Mon Sep 17 00:00:00 2001 From: Yao Matrix Date: Mon, 26 May 2025 16:18:53 +0800 Subject: [PATCH] switch to device agnostic device calling for test cases (#38247) * use device agnostic APIs in test cases Signed-off-by: Matrix Yao * fix style Signed-off-by: Matrix Yao * add one more Signed-off-by: YAO Matrix * xpu now supports integer device id, aligning to CUDA behaviors Signed-off-by: Matrix Yao * update to use device_properties Signed-off-by: Matrix Yao * fix style Signed-off-by: Matrix Yao * update comment Signed-off-by: Matrix Yao * fix comments Signed-off-by: Matrix Yao * fix style Signed-off-by: Matrix Yao --------- Signed-off-by: Matrix Yao Signed-off-by: YAO Matrix Co-authored-by: ydshieh --- .../quantizers/quantizer_bnb_4bit.py | 2 +- .../quantizers/quantizer_bnb_8bit.py | 2 +- tests/models/bamba/test_modeling_bamba.py | 15 ++-- tests/models/bloom/test_modeling_bloom.py | 14 +-- tests/models/cohere2/test_modeling_cohere2.py | 9 -- .../deepseek_v3/test_modeling_deepseek_v3.py | 10 --- .../diffllama/test_modeling_diffllama.py | 10 --- tests/models/gemma/test_modeling_gemma.py | 42 ++++----- tests/models/gemma2/test_modeling_gemma2.py | 9 -- tests/models/glm/test_modeling_glm.py | 9 -- tests/models/glm4/test_modeling_glm4.py | 9 -- tests/models/granite/test_modeling_granite.py | 47 ++++------ .../granitemoe/test_modeling_granitemoe.py | 36 +++----- .../test_modeling_granitemoehybrid.py | 10 --- .../test_modeling_granitemoeshared.py | 36 +++----- tests/models/helium/test_modeling_helium.py | 9 -- tests/models/jamba/test_modeling_jamba.py | 65 +++++++------- tests/models/llama4/test_modeling_llama4.py | 6 -- tests/models/mistral/test_modeling_mistral.py | 35 ++++---- tests/models/mixtral/test_modeling_mixtral.py | 89 ++++++++----------- .../models/nemotron/test_modeling_nemotron.py | 10 --- .../aqlm_integration/test_aqlm.py | 9 +- tests/quantization/autoawq/test_awq.py | 17 ++-- .../bitnet_integration/test_bitnet.py | 19 ++-- tests/quantization/bnb/test_4bit.py | 15 ++-- tests/quantization/bnb/test_mixed_int8.py | 18 ++-- .../test_compressed_models.py | 6 +- .../test_compressed_tensors.py | 4 +- .../eetq_integration/test_eetq.py | 3 +- .../fbgemm_fp8/test_fbgemm_fp8.py | 3 +- .../quantization/finegrained_fp8/test_fp8.py | 5 +- tests/quantization/higgs/test_higgs.py | 3 +- tests/quantization/hqq/test_hqq.py | 7 +- .../spqr_integration/test_spqr.py | 9 +- .../torchao_integration/test_torchao.py | 23 ++--- .../vptq_integration/test_vptq.py | 3 +- tests/test_modeling_common.py | 21 +++-- tests/trainer/test_trainer.py | 2 +- tests/utils/test_cache_utils.py | 7 +- 39 files changed, 259 insertions(+), 389 deletions(-) diff --git a/src/transformers/quantizers/quantizer_bnb_4bit.py b/src/transformers/quantizers/quantizer_bnb_4bit.py index 7fb9176c467..8beabfa7346 100644 --- a/src/transformers/quantizers/quantizer_bnb_4bit.py +++ b/src/transformers/quantizers/quantizer_bnb_4bit.py @@ -273,7 +273,7 @@ class Bnb4BitHfQuantizer(HfQuantizer): elif is_torch_hpu_available(): device_map = {"": f"hpu:{torch.hpu.current_device()}"} elif is_torch_xpu_available(): - device_map = {"": f"xpu:{torch.xpu.current_device()}"} + device_map = {"": torch.xpu.current_device()} else: device_map = {"": "cpu"} logger.info( diff --git a/src/transformers/quantizers/quantizer_bnb_8bit.py b/src/transformers/quantizers/quantizer_bnb_8bit.py index cac339b16b9..e0b5811fc7f 100644 --- a/src/transformers/quantizers/quantizer_bnb_8bit.py +++ 
b/src/transformers/quantizers/quantizer_bnb_8bit.py @@ -136,7 +136,7 @@ class Bnb8BitHfQuantizer(HfQuantizer): if torch.cuda.is_available(): device_map = {"": torch.cuda.current_device()} elif is_torch_xpu_available(): - device_map = {"": f"xpu:{torch.xpu.current_device()}"} + device_map = {"": torch.xpu.current_device()} else: device_map = {"": "cpu"} logger.info( diff --git a/tests/models/bamba/test_modeling_bamba.py b/tests/models/bamba/test_modeling_bamba.py index 7c00a7a030d..8213a85c175 100644 --- a/tests/models/bamba/test_modeling_bamba.py +++ b/tests/models/bamba/test_modeling_bamba.py @@ -28,6 +28,7 @@ from transformers import ( ) from transformers.testing_utils import ( Expectations, + get_device_properties, require_deterministic_for_xpu, require_flash_attn, require_torch, @@ -572,10 +573,10 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi return_tensors="pt", return_seq_idx=True, return_flash_attn_kwargs=True ) batch = data_collator(features) - batch_cuda = {k: t.cuda() if torch.is_tensor(t) else t for k, t in batch.items()} + batch_accelerator = {k: t.to(torch_device) if torch.is_tensor(t) else t for k, t in batch.items()} res_padded = model(**inputs_dict) - res_padfree = model(**batch_cuda) + res_padfree = model(**batch_accelerator) logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()] logits_padfree = res_padfree.logits[0] @@ -594,7 +595,7 @@ class BambaModelIntegrationTest(unittest.TestCase): tokenizer = None # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) # Depending on the hardware we get different logits / generations - cuda_compute_capability_major_version = None + device_properties = None @classmethod def setUpClass(cls): @@ -606,9 +607,7 @@ class BambaModelIntegrationTest(unittest.TestCase): cls.tokenizer.pad_token_id = cls.model.config.pad_token_id cls.tokenizer.padding_side = "left" - if is_torch_available() and torch.cuda.is_available(): - # 8 is for A100 / A10 and 7 for T4 - cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] + cls.device_properties = get_device_properties() def test_simple_generate(self): expectations = Expectations( @@ -639,7 +638,7 @@ class BambaModelIntegrationTest(unittest.TestCase): self.assertEqual(output_sentence, expected) # TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist - if self.cuda_compute_capability_major_version == 8: + if self.device_properties == ("cuda", 8): with torch.no_grad(): logits = self.model(input_ids=input_ids, logits_to_keep=40).logits @@ -692,7 +691,7 @@ class BambaModelIntegrationTest(unittest.TestCase): self.assertEqual(output_sentences[1], EXPECTED_TEXT[1]) # TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist - if self.cuda_compute_capability_major_version == 8: + if self.device_properties == ("cuda", 8): with torch.no_grad(): logits = self.model(input_ids=inputs["input_ids"]).logits diff --git a/tests/models/bloom/test_modeling_bloom.py b/tests/models/bloom/test_modeling_bloom.py index 014c0f2f58f..787a99c9329 100644 --- a/tests/models/bloom/test_modeling_bloom.py +++ b/tests/models/bloom/test_modeling_bloom.py @@ -390,7 +390,7 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi def test_simple_generation(self): # This test is a bit flaky. 
For some GPU architectures, pytorch sets by default allow_fp16_reduced_precision_reduction = True and some operations # do not give the same results under this configuration, especially torch.baddmm and torch.bmm. https://pytorch.org/docs/stable/notes/numerical_accuracy.html#fp16-on-mi200 - # As we leave the default value (True) for allow_fp16_reduced_precision_reduction , the tests failed when running in half-precision with smaller models (560m) + # As we leave the default value (True) for allow_fp16_reduced_precision_reduction, the tests failed when running in half-precision with smaller models (560m) # Please see: https://pytorch.org/docs/stable/notes/cuda.html#reduced-precision-reduction-in-fp16-gemms # This discrepancy is observed only when using small models and seems to be stable for larger models. # Our conclusion is that these operations are flaky for small inputs but seems to be stable for larger inputs (for the functions `baddmm` and `bmm`), and therefore for larger models. @@ -763,7 +763,6 @@ class BloomEmbeddingTest(unittest.TestCase): @require_torch def test_hidden_states_transformers(self): - cuda_available = torch.cuda.is_available() model = BloomModel.from_pretrained(self.path_bigscience_model, use_cache=False, torch_dtype="auto").to( torch_device ) @@ -782,7 +781,7 @@ class BloomEmbeddingTest(unittest.TestCase): "max": logits.last_hidden_state.max(dim=-1).values[0][0].item(), } - if cuda_available: + if torch_device == "cuda": self.assertAlmostEqual(MEAN_VALUE_LAST_LM, logits.last_hidden_state.mean().item(), places=4) else: self.assertAlmostEqual(MEAN_VALUE_LAST_LM, logits.last_hidden_state.mean().item(), places=3) @@ -791,7 +790,6 @@ class BloomEmbeddingTest(unittest.TestCase): @require_torch def test_logits(self): - cuda_available = torch.cuda.is_available() model = BloomForCausalLM.from_pretrained(self.path_bigscience_model, use_cache=False, torch_dtype="auto").to( torch_device ) # load in bf16 @@ -807,9 +805,5 @@ class BloomEmbeddingTest(unittest.TestCase): output = model(tensor_ids).logits output_gpu_1, output_gpu_2 = output.split(125440, dim=-1) - if cuda_available: - self.assertAlmostEqual(output_gpu_1.mean().item(), MEAN_LOGITS_GPU_1, places=6) - self.assertAlmostEqual(output_gpu_2.mean().item(), MEAN_LOGITS_GPU_2, places=6) - else: - self.assertAlmostEqual(output_gpu_1.mean().item(), MEAN_LOGITS_GPU_1, places=6) # 1e-06 precision!! 
- self.assertAlmostEqual(output_gpu_2.mean().item(), MEAN_LOGITS_GPU_2, places=6) + self.assertAlmostEqual(output_gpu_1.mean().item(), MEAN_LOGITS_GPU_1, places=6) + self.assertAlmostEqual(output_gpu_2.mean().item(), MEAN_LOGITS_GPU_2, places=6) diff --git a/tests/models/cohere2/test_modeling_cohere2.py b/tests/models/cohere2/test_modeling_cohere2.py index 195be1c23d8..9282c22d413 100644 --- a/tests/models/cohere2/test_modeling_cohere2.py +++ b/tests/models/cohere2/test_modeling_cohere2.py @@ -133,15 +133,6 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase): @require_torch_large_gpu class Cohere2IntegrationTest(unittest.TestCase): input_text = ["Hello I am doing", "Hi today"] - # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) - # Depending on the hardware we get different logits / generations - cuda_compute_capability_major_version = None - - @classmethod - def setUpClass(cls): - if is_torch_available() and torch.cuda.is_available(): - # 8 is for A100 / A10 and 7 for T4 - cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] def test_model_bf16(self): model_id = "CohereForAI/c4ai-command-r7b-12-2024" diff --git a/tests/models/deepseek_v3/test_modeling_deepseek_v3.py b/tests/models/deepseek_v3/test_modeling_deepseek_v3.py index 9b796937b08..e0a21002ef0 100644 --- a/tests/models/deepseek_v3/test_modeling_deepseek_v3.py +++ b/tests/models/deepseek_v3/test_modeling_deepseek_v3.py @@ -495,16 +495,6 @@ class DeepseekV3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste @require_torch_accelerator class DeepseekV3IntegrationTest(unittest.TestCase): - # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) - # Depending on the hardware we get different logits / generations - cuda_compute_capability_major_version = None - - @classmethod - def setUpClass(cls): - if is_torch_available() and torch.cuda.is_available(): - # 8 is for A100 / A10 and 7 for T4 - cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] - def tearDown(self): # See LlamaIntegrationTest.tearDown(). Can be removed once LlamaIntegrationTest.tearDown() is removed. cleanup(torch_device, gc_collect=False) diff --git a/tests/models/diffllama/test_modeling_diffllama.py b/tests/models/diffllama/test_modeling_diffllama.py index c738fbf76d1..50525a3ec4e 100644 --- a/tests/models/diffllama/test_modeling_diffllama.py +++ b/tests/models/diffllama/test_modeling_diffllama.py @@ -565,16 +565,6 @@ class DiffLlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester @require_torch_accelerator class DiffLlamaIntegrationTest(unittest.TestCase): - # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) - # Depending on the hardware we get different logits / generations - cuda_compute_capability_major_version = None - - @classmethod - def setUpClass(cls): - if is_torch_available() and torch.cuda.is_available(): - # 8 is for A100 / A10 and 7 for T4 - cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] - def tearDown(self): # See LlamaIntegrationTest.tearDown(). Can be removed once LlamaIntegrationTest.tearDown() is removed. 
cleanup(torch_device, gc_collect=False) diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py index 649c837b9c2..940aa7fedc0 100644 --- a/tests/models/gemma/test_modeling_gemma.py +++ b/tests/models/gemma/test_modeling_gemma.py @@ -21,7 +21,9 @@ from packaging import version from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available from transformers.generation.configuration_utils import GenerationConfig from transformers.testing_utils import ( + Expectations, cleanup, + get_device_properties, require_bitsandbytes, require_flash_attn, require_read_token, @@ -105,15 +107,13 @@ class GemmaModelTest(CausalLMModelTest, unittest.TestCase): @require_torch_accelerator class GemmaIntegrationTest(unittest.TestCase): input_text = ["Hello I am doing", "Hi today"] - # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) + # This variable is used to determine which accelerator are we using for our runners (e.g. A10 or T4) # Depending on the hardware we get different logits / generations - cuda_compute_capability_major_version = None + device_properties = None @classmethod def setUpClass(cls): - if is_torch_available() and torch.cuda.is_available(): - # 8 is for A100 / A10 and 7 for T4 - cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] + cls.device_properties = get_device_properties() def tearDown(self): # See LlamaIntegrationTest.tearDown(). Can be removed once LlamaIntegrationTest.tearDown() is removed. @@ -270,7 +270,7 @@ class GemmaIntegrationTest(unittest.TestCase): @require_read_token def test_model_7b_fp16(self): - if self.cuda_compute_capability_major_version == 7: + if self.device_properties == ("cuda", 7): self.skipTest("This test is failing (`torch.compile` fails) on Nvidia T4 GPU (OOM).") model_id = "google/gemma-7b" @@ -293,7 +293,7 @@ class GemmaIntegrationTest(unittest.TestCase): @require_read_token def test_model_7b_bf16(self): - if self.cuda_compute_capability_major_version == 7: + if self.device_properties == ("cuda", 7): self.skipTest("This test is failing (`torch.compile` fails) on Nvidia T4 GPU (OOM).") model_id = "google/gemma-7b" @@ -302,20 +302,16 @@ class GemmaIntegrationTest(unittest.TestCase): # # Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s, # considering differences in hardware processing and potential deviations in generated text. 
- EXPECTED_TEXTS = { - 7: [ - """Hello I am doing a project on a 1991 240sx and I am trying to find""", - "Hi today I am going to show you how to make a very simple and easy to make a very simple and", - ], - 8: [ - "Hello I am doing a project for my school and I am trying to make a program that will read a .txt file", - "Hi today I am going to show you how to make a very simple and easy to make a very simple and", - ], - 9: [ - "Hello I am doing a project for my school and I am trying to get a servo to move a certain amount of degrees", - "Hi today I am going to show you how to make a very simple and easy to make DIY light up sign", - ], - } + # fmt: off + EXPECTED_TEXTS = Expectations( + { + ("cuda", 7): ["""Hello I am doing a project on a 1991 240sx and I am trying to find""", "Hi today I am going to show you how to make a very simple and easy to make a very simple and",], + ("cuda", 8): ["Hello I am doing a project for my school and I am trying to make a program that will read a .txt file", "Hi today I am going to show you how to make a very simple and easy to make a very simple and",], + ("rocm", 9): ["Hello I am doing a project for my school and I am trying to get a servo to move a certain amount of degrees", "Hi today I am going to show you how to make a very simple and easy to make DIY light up sign",], + } + ) + # fmt: on + expected_text = EXPECTED_TEXTS.get_expectation() model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to( torch_device @@ -326,11 +322,11 @@ class GemmaIntegrationTest(unittest.TestCase): output = model.generate(**inputs, max_new_tokens=20, do_sample=False) output_text = tokenizer.batch_decode(output, skip_special_tokens=True) - self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version]) + self.assertEqual(output_text, expected_text) @require_read_token def test_model_7b_fp16_static_cache(self): - if self.cuda_compute_capability_major_version == 7: + if self.device_properties == ("cuda", 7): self.skipTest("This test is failing (`torch.compile` fails) on Nvidia T4 GPU (OOM).") model_id = "google/gemma-7b" diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index cb98a5a0e69..5e78efe540c 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -176,15 +176,6 @@ class Gemma2ModelTest(CausalLMModelTest, unittest.TestCase): @require_torch_accelerator class Gemma2IntegrationTest(unittest.TestCase): input_text = ["Hello I am doing", "Hi today"] - # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) - # Depending on the hardware we get different logits / generations - cuda_compute_capability_major_version = None - - @classmethod - def setUpClass(cls): - if is_torch_available() and torch.cuda.is_available(): - # 8 is for A100 / A10 and 7 for T4 - cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] @tooslow @require_read_token diff --git a/tests/models/glm/test_modeling_glm.py b/tests/models/glm/test_modeling_glm.py index e246ea867a0..5438b4d158c 100644 --- a/tests/models/glm/test_modeling_glm.py +++ b/tests/models/glm/test_modeling_glm.py @@ -80,15 +80,6 @@ class GlmIntegrationTest(unittest.TestCase): input_text = ["Hello I am doing", "Hi today"] model_id = "THUDM/glm-4-9b" revision = "refs/pr/15" - # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) - # Depending on the 
hardware we get different logits / generations - cuda_compute_capability_major_version = None - - @classmethod - def setUpClass(cls): - if is_torch_available() and torch.cuda.is_available(): - # 8 is for A100 / A10 and 7 for T4 - cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] def test_model_9b_fp16(self): EXPECTED_TEXTS = [ diff --git a/tests/models/glm4/test_modeling_glm4.py b/tests/models/glm4/test_modeling_glm4.py index 295954fe20c..d7a8074a5c9 100644 --- a/tests/models/glm4/test_modeling_glm4.py +++ b/tests/models/glm4/test_modeling_glm4.py @@ -82,15 +82,6 @@ class Glm4IntegrationTest(unittest.TestCase): input_text = ["Hello I am doing", "Hi today"] model_id = "THUDM/glm-4-0414-9b-chat" revision = "refs/pr/15" - # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) - # Depending on the hardware we get different logits / generations - cuda_compute_capability_major_version = None - - @classmethod - def setUpClass(cls): - if is_torch_available() and torch.cuda.is_available(): - # 8 is for A100 / A10 and 7 for T4 - cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] def test_model_9b_fp16(self): EXPECTED_TEXTS = [ diff --git a/tests/models/granite/test_modeling_granite.py b/tests/models/granite/test_modeling_granite.py index be1b5841ff8..c540655577f 100644 --- a/tests/models/granite/test_modeling_granite.py +++ b/tests/models/granite/test_modeling_granite.py @@ -305,16 +305,6 @@ class GraniteModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi @require_torch_accelerator class GraniteIntegrationTest(unittest.TestCase): - # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) - # Depending on the hardware we get different logits / generations - cuda_compute_capability_major_version = None - - @classmethod - def setUpClass(cls): - if is_torch_available() and torch.cuda.is_available(): - # 8 is for A100 / A10 and 7 for T4 - cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] - @slow @require_read_token def test_model_3b_logits_bf16(self): @@ -330,24 +320,24 @@ class GraniteIntegrationTest(unittest.TestCase): # fmt: off EXPECTED_MEANS = Expectations( - { - ("xpu", 3): torch.tensor([[-3.1406, -2.5469, -2.6250, -2.1250, -2.6250, -2.6562, -2.6875, -2.9688]]), - ("cuda", 7): torch.tensor([[-1.9798, -3.1626, -2.8062, -2.3777, -2.7091, -2.2338, -2.5924, -2.3974]]), - ("cuda", 8): torch.tensor([[-3.1406, -2.5469, -2.6250, -2.1250, -2.6250, -2.6562, -2.6875, -2.9688]]), - } - ) + { + ("xpu", 3): torch.tensor([[-3.1406, -2.5469, -2.6250, -2.1250, -2.6250, -2.6562, -2.6875, -2.9688]]), + ("cuda", 7): torch.tensor([[-1.9798, -3.1626, -2.8062, -2.3777, -2.7091, -2.2338, -2.5924, -2.3974]]), + ("cuda", 8): torch.tensor([[-3.1406, -2.5469, -2.6250, -2.1250, -2.6250, -2.6562, -2.6875, -2.9688]]), + } + ) EXPECTED_MEAN = EXPECTED_MEANS.get_expectation() torch.testing.assert_close(EXPECTED_MEAN.to(torch_device), out.logits.mean(-1).float(), rtol=1e-2, atol=1e-2) # slicing logits[0, 0, 0:15] EXPECTED_SLICES = Expectations( - { - ("xpu", 3): torch.tensor([[2.2031, -5.0625, -5.0625, -5.0625, -5.0625, -0.9180, -5.0625, -5.0625, -5.0625, -5.0625, -5.5312, -2.1719, -1.7891, -0.4922, -2.5469]]), - ("cuda", 7): torch.tensor([[4.8750, -2.1875, -2.1875, -2.1875, -2.1875, -2.8438, -2.1875, -2.1875, -2.1875, -2.1875, -2.1875, -2.1875, -2.1875, -2.1875, -2.1875]]), - ("cuda", 8): torch.tensor([[2.0938, -5.0312, -5.0312, -5.0312, 
-5.0312, -1.0469, -5.0312, -5.0312, -5.0312, -5.0312, -5.5625, -2.1875, -1.7891, -0.5820, -2.6250]]), - } - ) + { + ("xpu", 3): torch.tensor([[2.2031, -5.0625, -5.0625, -5.0625, -5.0625, -0.9180, -5.0625, -5.0625, -5.0625, -5.0625, -5.5312, -2.1719, -1.7891, -0.4922, -2.5469]]), + ("cuda", 7): torch.tensor([[4.8750, -2.1875, -2.1875, -2.1875, -2.1875, -2.8438, -2.1875, -2.1875, -2.1875, -2.1875, -2.1875, -2.1875, -2.1875, -2.1875, -2.1875]]), + ("cuda", 8): torch.tensor([[2.0938, -5.0312, -5.0312, -5.0312, -5.0312, -1.0469, -5.0312, -5.0312, -5.0312, -5.0312, -5.5625, -2.1875, -1.7891, -0.5820, -2.6250]]), + } + ) EXPECTED_SLICE = EXPECTED_SLICES.get_expectation() # fmt: on self.assertTrue( @@ -372,12 +362,13 @@ class GraniteIntegrationTest(unittest.TestCase): # fmt: off # Expected mean on dim = -1 EXPECTED_MEANS = Expectations( - { - ("xpu", 3): torch.tensor([[-3.2693, -2.5957, -2.6234, -2.1675, -2.6386, -2.6850, -2.7039, -2.9656]]), - ("cuda", 7): torch.tensor([[-2.0984, -3.1294, -2.8153, -2.3568, -2.7337, -2.2624, -2.6016, -2.4022]]), - ("cuda", 8): torch.tensor([[-3.2934, -2.6019, -2.6258, -2.1691, -2.6394, -2.6876, -2.7032, -2.9688]]), - } - ) + { + ("xpu", 3): torch.tensor([[-3.2693, -2.5957, -2.6234, -2.1675, -2.6386, -2.6850, -2.7039, -2.9656]]), + ("cuda", 7): torch.tensor([[-2.0984, -3.1294, -2.8153, -2.3568, -2.7337, -2.2624, -2.6016, -2.4022]]), + ("cuda", 8): torch.tensor([[-3.2934, -2.6019, -2.6258, -2.1691, -2.6394, -2.6876, -2.7032, -2.9688]]), + } + ) + # fmt: on EXPECTED_MEAN = EXPECTED_MEANS.get_expectation() torch.testing.assert_close(EXPECTED_MEAN.to(torch_device), out.logits.float().mean(-1), rtol=1e-2, atol=1e-2) diff --git a/tests/models/granitemoe/test_modeling_granitemoe.py b/tests/models/granitemoe/test_modeling_granitemoe.py index e451ff30c84..ccc0dfd6a51 100644 --- a/tests/models/granitemoe/test_modeling_granitemoe.py +++ b/tests/models/granitemoe/test_modeling_granitemoe.py @@ -304,16 +304,6 @@ class GraniteMoeModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test @require_torch_accelerator class GraniteMoeIntegrationTest(unittest.TestCase): - # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) - # Depending on the hardware we get different logits / generations - cuda_compute_capability_major_version = None - - @classmethod - def setUpClass(cls): - if is_torch_available() and torch.cuda.is_available(): - # 8 is for A100 / A10 and 7 for T4 - cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] - @slow @require_read_token def test_model_3b_logits(self): @@ -327,24 +317,24 @@ class GraniteMoeIntegrationTest(unittest.TestCase): # fmt: off # Expected mean on dim = -1 EXPECTED_MEANS = Expectations( - { - ("xpu", 3): torch.tensor([[-4.4005, -3.6689, -3.6187, -2.8308, -3.9871, -3.1001, -2.8738, -2.8063]]), - ("cuda", 7): torch.tensor([[-2.2122, -1.6632, -2.9269, -2.3344, -2.0143, -3.0146, -2.6839, -2.5610]]), - ("cuda", 8): torch.tensor([[-4.4005, -3.6689, -3.6187, -2.8308, -3.9871, -3.1001, -2.8738, -2.8063]]), - } - ) + { + ("xpu", 3): torch.tensor([[-4.4005, -3.6689, -3.6187, -2.8308, -3.9871, -3.1001, -2.8738, -2.8063]]), + ("cuda", 7): torch.tensor([[-2.2122, -1.6632, -2.9269, -2.3344, -2.0143, -3.0146, -2.6839, -2.5610]]), + ("cuda", 8): torch.tensor([[-4.4005, -3.6689, -3.6187, -2.8308, -3.9871, -3.1001, -2.8738, -2.8063]]), + } + ) EXPECTED_MEAN = EXPECTED_MEANS.get_expectation() torch.testing.assert_close(EXPECTED_MEAN.to(torch_device), out.logits.float().mean(-1), rtol=1e-2, 
atol=1e-2) # slicing logits[0, 0, 0:15] EXPECTED_SLICES = Expectations( - { - ("xpu", 3): torch.tensor([[2.5479, -9.2123, -9.2121, -9.2175, -9.2122, -1.5024, -9.2121, -9.2122, -9.2161, -9.2122, -6.3100, -3.6223, -3.6377, -5.2542, -5.2523]]), - ("cuda", 7): torch.tensor([[4.8785, -2.2890, -2.2892, -2.2885, -2.2890, -3.5007, -2.2897, -2.2892, -2.2895, -2.2891, -2.2887, -2.2882, -2.2889, -2.2898, -2.2892]]), - ("cuda", 8): torch.tensor([[2.5479, -9.2124, -9.2121, -9.2175, -9.2122, -1.5024, -9.2121, -9.2122, -9.2162, -9.2122, -6.3101, -3.6224, -3.6377, -5.2542, -5.2524]]), - } - ) + { + ("xpu", 3): torch.tensor([[2.5479, -9.2123, -9.2121, -9.2175, -9.2122, -1.5024, -9.2121, -9.2122, -9.2161, -9.2122, -6.3100, -3.6223, -3.6377, -5.2542, -5.2523]]), + ("cuda", 7): torch.tensor([[4.8785, -2.2890, -2.2892, -2.2885, -2.2890, -3.5007, -2.2897, -2.2892, -2.2895, -2.2891, -2.2887, -2.2882, -2.2889, -2.2898, -2.2892]]), + ("cuda", 8): torch.tensor([[2.5479, -9.2124, -9.2121, -9.2175, -9.2122, -1.5024, -9.2121, -9.2122, -9.2162, -9.2122, -6.3101, -3.6224, -3.6377, -5.2542, -5.2524]]), + } + ) EXPECTED_SLICE = EXPECTED_SLICES.get_expectation() # fmt: on @@ -360,6 +350,7 @@ class GraniteMoeIntegrationTest(unittest.TestCase): @slow def test_model_3b_generation(self): # ground truth text generated with dola_layers="low", repetition_penalty=1.2 + # fmt: off EXPECTED_TEXT_COMPLETIONS = Expectations( { ("xpu", 3): ( @@ -378,6 +369,7 @@ class GraniteMoeIntegrationTest(unittest.TestCase): ), } ) + # fmt: on EXPECTED_TEXT_COMPLETION = EXPECTED_TEXT_COMPLETIONS.get_expectation() prompt = "Simply put, the theory of relativity states that " diff --git a/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py b/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py index 3f4f45017d0..fc3d93a6640 100644 --- a/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py +++ b/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py @@ -105,16 +105,6 @@ class GraniteMoeHybridModelTest(BambaModelTest, GenerationTesterMixin, unittest. 
@unittest.skip(reason="GraniteMoeHybrid models are not yet released") @require_torch_gpu class GraniteMoeHybridIntegrationTest(unittest.TestCase): - # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) - # Depending on the hardware we get different logits / generations - cuda_compute_capability_major_version = None - - @classmethod - def setUpClass(cls): - if is_torch_available() and torch.cuda.is_available(): - # 8 is for A100 / A10 and 7 for T4 - cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] - @slow def test_model_logits(self): input_ids = [31390, 631, 4162, 30, 322, 25342, 432, 1875, 43826, 10066, 688, 225] diff --git a/tests/models/granitemoeshared/test_modeling_granitemoeshared.py b/tests/models/granitemoeshared/test_modeling_granitemoeshared.py index 5de3552c20f..bfd9464c75a 100644 --- a/tests/models/granitemoeshared/test_modeling_granitemoeshared.py +++ b/tests/models/granitemoeshared/test_modeling_granitemoeshared.py @@ -307,16 +307,6 @@ class GraniteMoeSharedModelTest(ModelTesterMixin, GenerationTesterMixin, unittes @require_torch_accelerator class GraniteMoeSharedIntegrationTest(unittest.TestCase): - # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) - # Depending on the hardware we get different logits / generations - cuda_compute_capability_major_version = None - - @classmethod - def setUpClass(cls): - if is_torch_available() and torch.cuda.is_available(): - # 8 is for A100 / A10 and 7 for T4 - cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] - @slow @require_read_token def test_model_3b_logits(self): @@ -330,24 +320,24 @@ class GraniteMoeSharedIntegrationTest(unittest.TestCase): # fmt: off # Expected mean on dim = -1 EXPECTED_MEANS = Expectations( - { - ("xpu", 3): torch.tensor([[-4.4005, -3.6689, -3.6187, -2.8308, -3.9871, -3.1001, -2.8738, -2.8063]]), - ("cuda", 7): torch.tensor([[-2.2122, -1.6632, -2.9269, -2.3344, -2.0143, -3.0146, -2.6839, -2.5610]]), - ("cuda", 8): torch.tensor([[-4.4005, -3.6689, -3.6187, -2.8308, -3.9871, -3.1001, -2.8738, -2.8063]]), - } - ) + { + ("xpu", 3): torch.tensor([[-4.4005, -3.6689, -3.6187, -2.8308, -3.9871, -3.1001, -2.8738, -2.8063]]), + ("cuda", 7): torch.tensor([[-2.2122, -1.6632, -2.9269, -2.3344, -2.0143, -3.0146, -2.6839, -2.5610]]), + ("cuda", 8): torch.tensor([[-4.4005, -3.6689, -3.6187, -2.8308, -3.9871, -3.1001, -2.8738, -2.8063]]), + } + ) EXPECTED_MEAN = EXPECTED_MEANS.get_expectation() torch.testing.assert_close(EXPECTED_MEAN.to(torch_device), out.logits.float().mean(-1), rtol=1e-2, atol=1e-2) # slicing logits[0, 0, 0:15] EXPECTED_SLICES = Expectations( - { - ("xpu", 3): torch.tensor([[2.5479, -9.2123, -9.2121, -9.2175, -9.2122, -1.5024, -9.2121, -9.2122, -9.2161, -9.2122, -6.3100, -3.6223, -3.6377, -5.2542, -5.2523]]), - ("cuda", 7): torch.tensor([[4.8785, -2.2890, -2.2892, -2.2885, -2.2890, -3.5007, -2.2897, -2.2892, -2.2895, -2.2891, -2.2887, -2.2882, -2.2889, -2.2898, -2.2892]]), - ("cuda", 8): torch.tensor([[2.5479, -9.2123, -9.2121, -9.2175, -9.2122, -1.5024, -9.2121, -9.2122, -9.2161, -9.2122, -6.3100, -3.6223, -3.6377, -5.2542, -5.2523]]), - } - ) + { + ("xpu", 3): torch.tensor([[2.5479, -9.2123, -9.2121, -9.2175, -9.2122, -1.5024, -9.2121, -9.2122, -9.2161, -9.2122, -6.3100, -3.6223, -3.6377, -5.2542, -5.2523]]), + ("cuda", 7): torch.tensor([[4.8785, -2.2890, -2.2892, -2.2885, -2.2890, -3.5007, -2.2897, -2.2892, -2.2895, -2.2891, -2.2887, -2.2882, -2.2889, -2.2898, 
-2.2892]]), + ("cuda", 8): torch.tensor([[2.5479, -9.2123, -9.2121, -9.2175, -9.2122, -1.5024, -9.2121, -9.2122, -9.2161, -9.2122, -6.3100, -3.6223, -3.6377, -5.2542, -5.2523]]), + } + ) EXPECTED_SLICE = EXPECTED_SLICES.get_expectation() # fmt: on @@ -363,6 +353,7 @@ class GraniteMoeSharedIntegrationTest(unittest.TestCase): @slow def test_model_3b_generation(self): # ground truth text generated with dola_layers="low", repetition_penalty=1.2 + # fmt: off EXPECTED_TEXT_COMPLETIONS = Expectations( { ("xpu", 3): ( @@ -381,6 +372,7 @@ class GraniteMoeSharedIntegrationTest(unittest.TestCase): ), } ) + # fmt: on EXPECTED_TEXT_COMPLETION = EXPECTED_TEXT_COMPLETIONS.get_expectation() prompt = "Simply put, the theory of relativity states that " diff --git a/tests/models/helium/test_modeling_helium.py b/tests/models/helium/test_modeling_helium.py index f4a555588e9..cb46167bae4 100644 --- a/tests/models/helium/test_modeling_helium.py +++ b/tests/models/helium/test_modeling_helium.py @@ -79,15 +79,6 @@ class HeliumModelTest(GemmaModelTest, unittest.TestCase): # @require_torch_gpu class HeliumIntegrationTest(unittest.TestCase): input_text = ["Hello, today is a great day to"] - # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) - # Depending on the hardware we get different logits / generations - cuda_compute_capability_major_version = None - - @classmethod - def setUpClass(cls): - if is_torch_available() and torch.cuda.is_available(): - # 8 is for A100 / A10 and 7 for T4 - cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] @require_read_token def test_model_2b(self): diff --git a/tests/models/jamba/test_modeling_jamba.py b/tests/models/jamba/test_modeling_jamba.py index 993db32378b..cd27180a5cf 100644 --- a/tests/models/jamba/test_modeling_jamba.py +++ b/tests/models/jamba/test_modeling_jamba.py @@ -21,6 +21,8 @@ import pytest from transformers import AutoTokenizer, JambaConfig, is_torch_available from transformers.testing_utils import ( + Expectations, + get_device_properties, require_bitsandbytes, require_flash_attn, require_torch, @@ -554,30 +556,32 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi class JambaModelIntegrationTest(unittest.TestCase): model = None tokenizer = None - # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) + # This variable is used to determine which acclerator are we using for our runners (e.g. A10 or T4) # Depending on the hardware we get different logits / generations - cuda_compute_capability_major_version = None + device_properties = None @classmethod def setUpClass(cls): model_id = "ai21labs/Jamba-tiny-dev" cls.model = JambaForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True) cls.tokenizer = AutoTokenizer.from_pretrained(model_id) - if is_torch_available() and torch.cuda.is_available(): - # 8 is for A100 / A10 and 7 for T4 - cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] + cls.device_properties = get_device_properties() @slow def test_simple_generate(self): - # Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4. + # ("cuda", 8) for A100/A10, and ("cuda", 7) for T4. # - # Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s, # considering differences in hardware processing and potential deviations in generated text. - EXPECTED_TEXTS = { - 7: "<|startoftext|>Hey how are you doing on this lovely evening? 
Canyon rins hugaughter glamour Rutgers Singh<|reserved_797|>cw algunas", - 8: "<|startoftext|>Hey how are you doing on this lovely evening? I'm so glad you're here.", - 9: "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh Hebrew llam bb", - } + # fmt: off + EXPECTED_TEXTS = Expectations( + { + ("cuda", 7): "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh<|reserved_797|>cw algunas", + ("cuda", 8): "<|startoftext|>Hey how are you doing on this lovely evening? I'm so glad you're here.", + ("rocm", 9): "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh Hebrew llam bb", + } + ) + # fmt: on + expected_sentence = EXPECTED_TEXTS.get_expectation() self.model.to(torch_device) @@ -586,10 +590,10 @@ class JambaModelIntegrationTest(unittest.TestCase): ].to(torch_device) out = self.model.generate(input_ids, do_sample=False, max_new_tokens=10) output_sentence = self.tokenizer.decode(out[0, :]) - self.assertEqual(output_sentence, EXPECTED_TEXTS[self.cuda_compute_capability_major_version]) + self.assertEqual(output_sentence, expected_sentence) # TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist - if self.cuda_compute_capability_major_version == 8: + if self.device_properties == ("cuda", 8): with torch.no_grad(): logits = self.model(input_ids=input_ids).logits @@ -607,24 +611,19 @@ class JambaModelIntegrationTest(unittest.TestCase): @slow def test_simple_batched_generate_with_padding(self): - # Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4. + # ("cuda", 8) for A100/A10, and ("cuda", 7) for T4. # - # Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s, # considering differences in hardware processing and potential deviations in generated text. - EXPECTED_TEXTS = { - 7: [ - "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh Hebrew cases Cats", - "<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a storyptus Nets Madison El chamadamodern updximVaparsed", - ], - 8: [ - "<|startoftext|>Hey how are you doing on this lovely evening? I'm so glad you're here.", - "<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a story about a woman who was born in the United States", - ], - 9: [ - "<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh<|reserved_797|>cw algunas", - "<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a storyptus Nets Madison El chamadamodern updximVaparsed", - ], - } + # fmt: off + EXPECTED_TEXTS = Expectations( + { + ("cuda", 7): ["<|startoftext|>Hey how are you doing on this lovely evening? Canyon rins hugaughter glamour Rutgers Singh Hebrew cases Cats", "<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a storyptus Nets Madison El chamadamodern updximVaparsed",], + ("cuda", 8): ["<|startoftext|>Hey how are you doing on this lovely evening? I'm so glad you're here.", "<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a story about a woman who was born in the United States",], + ("rocm", 9): ["<|startoftext|>Hey how are you doing on this lovely evening? 
Canyon rins hugaughter glamour Rutgers Singh<|reserved_797|>cw algunas", "<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|>Tell me a storyptus Nets Madison El chamadamodern updximVaparsed",], + } + ) + # fmt: on + expected_sentences = EXPECTED_TEXTS.get_expectation() self.model.to(torch_device) @@ -633,11 +632,11 @@ class JambaModelIntegrationTest(unittest.TestCase): ).to(torch_device) out = self.model.generate(**inputs, do_sample=False, max_new_tokens=10) output_sentences = self.tokenizer.batch_decode(out) - self.assertEqual(output_sentences[0], EXPECTED_TEXTS[self.cuda_compute_capability_major_version][0]) - self.assertEqual(output_sentences[1], EXPECTED_TEXTS[self.cuda_compute_capability_major_version][1]) + self.assertEqual(output_sentences[0], expected_sentences[0]) + self.assertEqual(output_sentences[1], expected_sentences[1]) # TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist - if self.cuda_compute_capability_major_version == 8: + if self.device_properties == ("cuda", 8): with torch.no_grad(): logits = self.model(input_ids=inputs["input_ids"]).logits diff --git a/tests/models/llama4/test_modeling_llama4.py b/tests/models/llama4/test_modeling_llama4.py index b349c47e3c4..d9362d397e0 100644 --- a/tests/models/llama4/test_modeling_llama4.py +++ b/tests/models/llama4/test_modeling_llama4.py @@ -38,15 +38,9 @@ if is_torch_available(): @require_read_token class Llama4IntegrationTest(unittest.TestCase): model_id = "meta-llama/Llama-4-Scout-17B-16E" - # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) - # Depending on the hardware we get different logits / generations - cuda_compute_capability_major_version = None @classmethod def setUpClass(cls): - if is_torch_available() and torch.cuda.is_available(): - # 8 is for A100 / A10 and 7 for T4 - cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] cls.model = Llama4ForConditionalGeneration.from_pretrained( "meta-llama/Llama-4-Scout-17B-16E", device_map="auto", diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index bb5a24c3cec..8410bfcfb6e 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -21,8 +21,10 @@ from packaging import version from transformers import AutoTokenizer, MistralConfig, is_torch_available, set_seed from transformers.testing_utils import ( + Expectations, backend_empty_cache, cleanup, + get_device_properties, require_bitsandbytes, require_flash_attn, require_read_token, @@ -110,15 +112,13 @@ class MistralModelTest(CausalLMModelTest, unittest.TestCase): @require_torch_accelerator class MistralIntegrationTest(unittest.TestCase): - # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) + # This variable is used to determine which accelerator are we using for our runners (e.g. 
A10 or T4) # Depending on the hardware we get different logits / generations - cuda_compute_capability_major_version = None + device_properties = None @classmethod def setUpClass(cls): - if is_torch_available() and torch.cuda.is_available(): - # 8 is for A100 / A10 and 7 for T4 - cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] + cls.device_properties = get_device_properties() def tearDown(self): cleanup(torch_device, gc_collect=True) @@ -136,19 +136,20 @@ class MistralIntegrationTest(unittest.TestCase): EXPECTED_MEAN = torch.tensor([[-2.5548, -2.5737, -3.0600, -2.5906, -2.8478, -2.8118, -2.9325, -2.7694]]) torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2) - # Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4. - # - # Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s, + # ("cuda", 8) for A100/A10, and ("cuda", 7) 7 for T4. # considering differences in hardware processing and potential deviations in output. - EXPECTED_SLICE = { - 7: torch.tensor([-5.8828, -5.8633, -0.1042, -4.7266, -5.8828, -5.8789, -5.8789, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -1.0801, 1.7598, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828]), - 8: torch.tensor([-5.8711, -5.8555, -0.1050, -4.7148, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -1.0781, 1.7568, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711]), - 9: torch.tensor([-5.8750, -5.8594, -0.1047, -4.7188, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -1.0781, 1.7578, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750]), - } # fmt: skip - - torch.testing.assert_close( - out[0, 0, :30], EXPECTED_SLICE[self.cuda_compute_capability_major_version], atol=1e-4, rtol=1e-4 + # fmt: off + EXPECTED_SLICES = Expectations( + { + ("cuda", 7): torch.tensor([-5.8828, -5.8633, -0.1042, -4.7266, -5.8828, -5.8789, -5.8789, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -1.0801, 1.7598, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828, -5.8828]), + ("cuda", 8): torch.tensor([-5.8711, -5.8555, -0.1050, -4.7148, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -1.0781, 1.7568, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711, -5.8711]), + ("rocm", 9): torch.tensor([-5.8750, -5.8594, -0.1047, -4.7188, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -1.0781, 1.7578, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750, -5.8750]), + } ) + # fmt: on + expected_slice = EXPECTED_SLICES.get_expectation() + + torch.testing.assert_close(out[0, 0, :30], expected_slice, atol=1e-4, rtol=1e-4) @slow @require_bitsandbytes @@ -278,7 +279,7 @@ class MistralIntegrationTest(unittest.TestCase): if version.parse(torch.__version__) < version.parse("2.3.0"): self.skipTest(reason="This test requires torch >= 2.3 to run.") - if self.cuda_compute_capability_major_version == 7: + if self.device_properties == ("cuda", 7): self.skipTest(reason="This test is failing (`torch.compile` 
fails) on Nvidia T4 GPU.") NUM_TOKENS_TO_GENERATE = 40 diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py index 532ebb7348a..efe076e70ab 100644 --- a/tests/models/mixtral/test_modeling_mixtral.py +++ b/tests/models/mixtral/test_modeling_mixtral.py @@ -19,6 +19,8 @@ import pytest from transformers import MixtralConfig, is_torch_available from transformers.testing_utils import ( + Expectations, + get_device_properties, require_flash_attn, require_torch, require_torch_accelerator, @@ -142,13 +144,11 @@ class MistralModelTest(CausalLMModelTest, unittest.TestCase): class MixtralIntegrationTest(unittest.TestCase): # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) # Depending on the hardware we get different logits / generations - cuda_compute_capability_major_version = None + device_properties = None @classmethod def setUpClass(cls): - if is_torch_available() and torch.cuda.is_available(): - # 8 is for A100 / A10 and 7 for T4 - cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] + cls.device_properties = get_device_properties() @slow @require_torch_accelerator @@ -161,32 +161,26 @@ class MixtralIntegrationTest(unittest.TestCase): ) # TODO: might need to tweak it in case the logits do not match on our daily runners # these logits have been obtained with the original megablocks implementation. - # Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4. - # - # Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s, + # ("cuda", 8) for A100/A10, and ("cuda", 7) for T4 # considering differences in hardware processing and potential deviations in output. - EXPECTED_LOGITS = { - 7: torch.Tensor([[0.1640, 0.1621, 0.6093], [-0.8906, -0.1640, -0.6093], [0.1562, 0.1250, 0.7226]]).to( - torch_device - ), - 8: torch.Tensor([[0.1631, 0.1621, 0.6094], [-0.8906, -0.1621, -0.6094], [0.1572, 0.1270, 0.7227]]).to( - torch_device - ), - 9: torch.Tensor([[0.1641, 0.1621, 0.6094], [-0.8906, -0.1631, -0.6094], [0.1572, 0.1260, 0.7227]]).to( - torch_device - ), - } + # fmt: off + EXPECTED_LOGITS = Expectations( + { + ("cuda", 7): torch.Tensor([[0.1640, 0.1621, 0.6093], [-0.8906, -0.1640, -0.6093], [0.1562, 0.1250, 0.7226]]).to(torch_device), + ("cuda", 8): torch.Tensor([[0.1631, 0.1621, 0.6094], [-0.8906, -0.1621, -0.6094], [0.1572, 0.1270, 0.7227]]).to(torch_device), + ("rocm", 9): torch.Tensor([[0.1641, 0.1621, 0.6094], [-0.8906, -0.1631, -0.6094], [0.1572, 0.1260, 0.7227]]).to(torch_device), + } + ) + # fmt: on + expected_logit = EXPECTED_LOGITS.get_expectation() + with torch.no_grad(): logits = model(dummy_input).logits logits = logits.float() - torch.testing.assert_close( - logits[0, :3, :3], EXPECTED_LOGITS[self.cuda_compute_capability_major_version], atol=1e-3, rtol=1e-3 - ) - torch.testing.assert_close( - logits[1, :3, :3], EXPECTED_LOGITS[self.cuda_compute_capability_major_version], atol=1e-3, rtol=1e-3 - ) + torch.testing.assert_close(logits[0, :3, :3], expected_logit, atol=1e-3, rtol=1e-3) + torch.testing.assert_close(logits[1, :3, :3], expected_logit, atol=1e-3, rtol=1e-3) @slow @require_torch_accelerator @@ -201,33 +195,28 @@ class MixtralIntegrationTest(unittest.TestCase): # TODO: might need to tweak it in case the logits do not match on our daily runners # - # Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4. + # ("cuda", 8) for A100/A10, and ("cuda", 7) for T4. 
# - # Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s, # considering differences in hardware processing and potential deviations in generated text. - EXPECTED_LOGITS_LEFT_UNPADDED = { - 7: torch.Tensor( - [[0.2236, 0.5195, -0.3828], [0.8203, -0.2275, 0.6054], [0.2656, -0.7070, 0.2460]], - ).to(torch_device), - 8: torch.Tensor([[0.2207, 0.5234, -0.3828], [0.8203, -0.2285, 0.6055], [0.2656, -0.7109, 0.2451]]).to( - torch_device, - ), - 9: torch.Tensor([[0.2236, 0.5195, -0.3828], [0.8203, -0.2285, 0.6055], [0.2637, -0.7109, 0.2451]]).to( - torch_device - ), - } + # fmt: off + EXPECTED_LOGITS_LEFT_UNPADDED = Expectations( + { + ("cuda", 7): torch.Tensor([[0.2236, 0.5195, -0.3828], [0.8203, -0.2275, 0.6054], [0.2656, -0.7070, 0.2460]]).to(torch_device), + ("cuda", 8): torch.Tensor([[0.2207, 0.5234, -0.3828], [0.8203, -0.2285, 0.6055], [0.2656, -0.7109, 0.2451]]).to(torch_device), + ("rocm", 9): torch.Tensor([[0.2236, 0.5195, -0.3828], [0.8203, -0.2285, 0.6055], [0.2637, -0.7109, 0.2451]]).to(torch_device), + } + ) + expected_left_unpadded = EXPECTED_LOGITS_LEFT_UNPADDED.get_expectation() - EXPECTED_LOGITS_RIGHT_UNPADDED = { - 7: torch.Tensor([[0.2167, 0.1269, -0.1640], [-0.3496, 0.2988, -1.0312], [0.0688, 0.7929, 0.8007]]).to( - torch_device - ), - 8: torch.Tensor([[0.2178, 0.1270, -0.1621], [-0.3496, 0.3008, -1.0312], [0.0693, 0.7930, 0.7969]]).to( - torch_device, - ), - 9: torch.Tensor([[0.2197, 0.1250, -0.1611], [-0.3516, 0.3008, -1.0312], [0.0684, 0.7930, 0.8008]]).to( - torch_device - ), - } + EXPECTED_LOGITS_RIGHT_UNPADDED = Expectations( + { + ("cuda", 7): torch.Tensor([[0.2167, 0.1269, -0.1640], [-0.3496, 0.2988, -1.0312], [0.0688, 0.7929, 0.8007]]).to(torch_device), + ("cuda", 8): torch.Tensor([[0.2178, 0.1270, -0.1621], [-0.3496, 0.3008, -1.0312], [0.0693, 0.7930, 0.7969]]).to(torch_device), + ("rocm", 9): torch.Tensor([[0.2197, 0.1250, -0.1611], [-0.3516, 0.3008, -1.0312], [0.0684, 0.7930, 0.8008]]).to(torch_device), + } + ) + expected_right_unpadded = EXPECTED_LOGITS_RIGHT_UNPADDED.get_expectation() + # fmt: on with torch.no_grad(): logits = model(dummy_input, attention_mask=attention_mask).logits @@ -235,13 +224,13 @@ class MixtralIntegrationTest(unittest.TestCase): torch.testing.assert_close( logits[0, -3:, -3:], - EXPECTED_LOGITS_LEFT_UNPADDED[self.cuda_compute_capability_major_version], + expected_left_unpadded, atol=1e-3, rtol=1e-3, ) torch.testing.assert_close( logits[1, -3:, -3:], - EXPECTED_LOGITS_RIGHT_UNPADDED[self.cuda_compute_capability_major_version], + expected_right_unpadded, atol=1e-3, rtol=1e-3, ) diff --git a/tests/models/nemotron/test_modeling_nemotron.py b/tests/models/nemotron/test_modeling_nemotron.py index 9ef543edeb7..d24ab44736e 100644 --- a/tests/models/nemotron/test_modeling_nemotron.py +++ b/tests/models/nemotron/test_modeling_nemotron.py @@ -99,16 +99,6 @@ class NemotronModelTest(CausalLMModelTest, unittest.TestCase): @require_torch_accelerator class NemotronIntegrationTest(unittest.TestCase): - # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) - # Depending on the hardware we get different logits / generations - cuda_compute_capability_major_version = None - - @classmethod - def setUpClass(cls): - if is_torch_available() and torch.cuda.is_available(): - # 8 is for A100 / A10 and 7 for T4 - cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] - @slow @require_read_token def test_nemotron_8b_generation_sdpa(self): diff --git 
a/tests/quantization/aqlm_integration/test_aqlm.py b/tests/quantization/aqlm_integration/test_aqlm.py index 03cf79dafa5..b339343627b 100644 --- a/tests/quantization/aqlm_integration/test_aqlm.py +++ b/tests/quantization/aqlm_integration/test_aqlm.py @@ -22,6 +22,7 @@ from packaging import version from transformers import AqlmConfig, AutoConfig, AutoModelForCausalLM, AutoTokenizer, OPTForCausalLM, StaticCache from transformers.testing_utils import ( + backend_empty_cache, require_accelerate, require_aqlm, require_torch_gpu, @@ -81,8 +82,6 @@ class AqlmTest(unittest.TestCase): EXPECTED_OUTPUT = "Hello my name is Katie. I am a 20 year old college student. I am a very outgoing person. I love to have fun and be active. I" - device_map = "cuda" - # called only once for all test in this class @classmethod def setUpClass(cls): @@ -92,12 +91,12 @@ class AqlmTest(unittest.TestCase): cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name) cls.quantized_model = AutoModelForCausalLM.from_pretrained( cls.model_name, - device_map=cls.device_map, + device_map=torch_device, ) def tearDown(self): gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) gc.collect() def test_quantized_model_conversion(self): @@ -170,7 +169,7 @@ class AqlmTest(unittest.TestCase): """ with tempfile.TemporaryDirectory() as tmpdirname: self.quantized_model.save_pretrained(tmpdirname) - model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map) + model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=torch_device) input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) diff --git a/tests/quantization/autoawq/test_awq.py b/tests/quantization/autoawq/test_awq.py index 195480be497..54234499673 100644 --- a/tests/quantization/autoawq/test_awq.py +++ b/tests/quantization/autoawq/test_awq.py @@ -19,6 +19,7 @@ import unittest from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AwqConfig, OPTForCausalLM from transformers.testing_utils import ( backend_empty_cache, + get_device_properties, require_accelerate, require_auto_awq, require_flash_attn, @@ -61,12 +62,10 @@ class AwqConfigTest(unittest.TestCase): # Only cuda and xpu devices can run this function support_llm_awq = False - if torch.cuda.is_available(): - compute_capability = torch.cuda.get_device_capability() - major, minor = compute_capability - if major >= 8: - support_llm_awq = True - elif torch.xpu.is_available(): + device_type, major = get_device_properties() + if device_type == "cuda" and major >= 8: + support_llm_awq = True + elif device_type == "xpu": support_llm_awq = True if support_llm_awq: @@ -357,7 +356,7 @@ class AwqFusedTest(unittest.TestCase): self.assertTrue(isinstance(model.model.layers[0].block_sparse_moe.gate, torch.nn.Linear)) @unittest.skipIf( - torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8, + get_device_properties()[0] == "cuda" and get_device_properties()[1] < 8, "Skipping because RuntimeError: FlashAttention only supports Ampere GPUs or newer, so not supported on GPU with capability < 8.0", ) @require_flash_attn @@ -388,7 +387,7 @@ class AwqFusedTest(unittest.TestCase): @require_flash_attn @require_torch_gpu @unittest.skipIf( - torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8, + get_device_properties()[0] == "cuda" and get_device_properties()[1] < 8, "Skipping because RuntimeError: FlashAttention only supports Ampere GPUs or newer, so not supported on GPU with capability < 8.0", ) def 
test_generation_fused_batched(self): @@ -441,7 +440,7 @@ class AwqFusedTest(unittest.TestCase): @require_flash_attn @require_torch_multi_gpu @unittest.skipIf( - torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8, + get_device_properties()[0] == "cuda" and get_device_properties()[1] < 8, "Skipping because RuntimeError: FlashAttention only supports Ampere GPUs or newer, so not supported on GPU with capability < 8.0", ) def test_generation_custom_model(self): diff --git a/tests/quantization/bitnet_integration/test_bitnet.py b/tests/quantization/bitnet_integration/test_bitnet.py index 10a1843dc17..fabe980ca29 100644 --- a/tests/quantization/bitnet_integration/test_bitnet.py +++ b/tests/quantization/bitnet_integration/test_bitnet.py @@ -23,6 +23,7 @@ from transformers import ( OPTForCausalLM, ) from transformers.testing_utils import ( + backend_empty_cache, require_accelerate, require_torch_gpu, slow, @@ -56,7 +57,6 @@ class BitNetQuantConfigTest(unittest.TestCase): @require_accelerate class BitNetTest(unittest.TestCase): model_name = "HF1BitLLM/Llama3-8B-1.58-100B-tokens" - device = "cuda" # called only once for all test in this class @classmethod @@ -65,11 +65,11 @@ class BitNetTest(unittest.TestCase): Load the model """ cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name) - cls.quantized_model = AutoModelForCausalLM.from_pretrained(cls.model_name, device_map=cls.device) + cls.quantized_model = AutoModelForCausalLM.from_pretrained(cls.model_name, device_map=torch_device) def tearDown(self): gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) gc.collect() def test_replace_with_bitlinear(self): @@ -100,7 +100,7 @@ class BitNetTest(unittest.TestCase): """ input_text = "What are we having for dinner?" expected_output = "What are we having for dinner? What are we going to do for fun this weekend?" 
- input_ids = self.tokenizer(input_text, return_tensors="pt").to("cuda") + input_ids = self.tokenizer(input_text, return_tensors="pt").to(torch_device) output = self.quantized_model.generate(**input_ids, max_new_tokens=11, do_sample=False) self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), expected_output) @@ -127,7 +127,7 @@ class BitNetTest(unittest.TestCase): from transformers.integrations import BitLinear layer = BitLinear(in_features=4, out_features=2, bias=False, dtype=torch.float32) - layer.to(self.device) + layer.to(torch_device) input_tensor = torch.tensor([1.0, -1.0, -1.0, 1.0], dtype=torch.float32).to(torch_device) @@ -202,9 +202,8 @@ class BitNetTest(unittest.TestCase): class BitNetSerializationTest(unittest.TestCase): def test_model_serialization(self): model_name = "HF1BitLLM/Llama3-8B-1.58-100B-tokens" - device = "cuda" - quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device) - input_tensor = torch.zeros((1, 8), dtype=torch.int32, device=device) + quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=torch_device) + input_tensor = torch.zeros((1, 8), dtype=torch.int32, device=torch_device) with torch.no_grad(): logits_ref = quantized_model.forward(input_tensor).logits @@ -215,10 +214,10 @@ class BitNetSerializationTest(unittest.TestCase): # Remove old model del quantized_model - torch.cuda.empty_cache() + backend_empty_cache(torch_device) # Load and check if the logits match - model_loaded = AutoModelForCausalLM.from_pretrained("quant_model", device_map=device) + model_loaded = AutoModelForCausalLM.from_pretrained("quant_model", device_map=torch_device) with torch.no_grad(): logits_loaded = model_loaded.forward(input_tensor).logits diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index 5c7ef7a9159..5887445bbc0 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -32,6 +32,7 @@ from transformers.models.opt.modeling_opt import OPTAttention from transformers.testing_utils import ( apply_skip_if_not_implemented, backend_empty_cache, + backend_torch_accelerator_module, is_bitsandbytes_available, is_torch_available, require_accelerate, @@ -376,7 +377,7 @@ class Bnb4BitT5Test(unittest.TestCase): avoid unexpected behaviors. 
Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27 """ gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_inference_without_keep_in_fp32(self): r""" @@ -460,7 +461,7 @@ class Classes4BitModelTest(Base4bitTest): del self.seq_to_seq_model gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_correct_head_class(self): r""" @@ -491,7 +492,7 @@ class Pipeline4BitTest(Base4bitTest): del self.pipe gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_pipeline(self): r""" @@ -589,10 +590,10 @@ class Bnb4BitTestTraining(Base4bitTest): # Step 1: freeze all parameters model = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_4bit=True) - if torch.cuda.is_available(): - self.assertEqual(set(model.hf_device_map.values()), {torch.cuda.current_device()}) - elif torch.xpu.is_available(): - self.assertEqual(set(model.hf_device_map.values()), {f"xpu:{torch.xpu.current_device()}"}) + if torch_device in ["cuda", "xpu"]: + self.assertEqual( + set(model.hf_device_map.values()), {backend_torch_accelerator_module(torch_device).current_device()} + ) else: self.assertTrue(all(param.device.type == "cpu" for param in model.parameters())) diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index 8c718d69f41..5790497a405 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -31,6 +31,8 @@ from transformers import ( from transformers.models.opt.modeling_opt import OPTAttention from transformers.testing_utils import ( apply_skip_if_not_implemented, + backend_empty_cache, + backend_torch_accelerator_module, is_accelerate_available, is_bitsandbytes_available, is_torch_available, @@ -137,7 +139,7 @@ class MixedInt8Test(BaseMixedInt8Test): del self.model_8bit gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_get_keys_to_not_convert(self): r""" @@ -484,7 +486,7 @@ class MixedInt8T5Test(unittest.TestCase): avoid unexpected behaviors. 
Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27 """ gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_inference_without_keep_in_fp32(self): r""" @@ -599,7 +601,7 @@ class MixedInt8ModelClassesTest(BaseMixedInt8Test): del self.seq_to_seq_model gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_correct_head_class(self): r""" @@ -631,7 +633,7 @@ class MixedInt8TestPipeline(BaseMixedInt8Test): del self.pipe gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_pipeline(self): r""" @@ -872,10 +874,10 @@ class MixedInt8TestTraining(BaseMixedInt8Test): model = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True) model.train() - if torch.cuda.is_available(): - self.assertEqual(set(model.hf_device_map.values()), {torch.cuda.current_device()}) - elif torch.xpu.is_available(): - self.assertEqual(set(model.hf_device_map.values()), {f"xpu:{torch.xpu.current_device()}"}) + if torch_device in ["cuda", "xpu"]: + self.assertEqual( + set(model.hf_device_map.values()), {backend_torch_accelerator_module(torch_device).current_device()} + ) else: self.assertTrue(all(param.device.type == "cpu" for param in model.parameters())) diff --git a/tests/quantization/compressed_tensors_integration/test_compressed_models.py b/tests/quantization/compressed_tensors_integration/test_compressed_models.py index 074c943431a..f956f0c08c1 100644 --- a/tests/quantization/compressed_tensors_integration/test_compressed_models.py +++ b/tests/quantization/compressed_tensors_integration/test_compressed_models.py @@ -3,7 +3,7 @@ import unittest import warnings from transformers import AutoModelForCausalLM, AutoTokenizer -from transformers.testing_utils import require_compressed_tensors, require_torch +from transformers.testing_utils import backend_empty_cache, require_compressed_tensors, require_torch, torch_device from transformers.utils import is_torch_available from transformers.utils.quantization_config import CompressedTensorsConfig @@ -41,7 +41,7 @@ class StackCompressedModelTest(unittest.TestCase): def tearDown(self): gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) gc.collect() def test_compressed_uncompressed_model_shapes(self): @@ -160,7 +160,7 @@ class RunCompressedTest(unittest.TestCase): def tearDown(self): gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) gc.collect() def test_default_run_compressed__True(self): diff --git a/tests/quantization/compressed_tensors_integration/test_compressed_tensors.py b/tests/quantization/compressed_tensors_integration/test_compressed_tensors.py index 47e78498060..d44e560fff0 100644 --- a/tests/quantization/compressed_tensors_integration/test_compressed_tensors.py +++ b/tests/quantization/compressed_tensors_integration/test_compressed_tensors.py @@ -2,7 +2,7 @@ import gc import unittest from transformers import AutoModelForCausalLM, AutoTokenizer, CompressedTensorsConfig -from transformers.testing_utils import require_compressed_tensors, require_torch +from transformers.testing_utils import backend_empty_cache, require_compressed_tensors, require_torch, torch_device from transformers.utils import is_torch_available @@ -22,7 +22,7 @@ class CompressedTensorsTest(unittest.TestCase): def tearDown(self): gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) gc.collect() def test_config_args(self): diff --git a/tests/quantization/eetq_integration/test_eetq.py 
b/tests/quantization/eetq_integration/test_eetq.py index a5e989ca94e..1bd1fbe45c1 100644 --- a/tests/quantization/eetq_integration/test_eetq.py +++ b/tests/quantization/eetq_integration/test_eetq.py @@ -18,6 +18,7 @@ import unittest from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, EetqConfig, OPTForCausalLM from transformers.testing_utils import ( + backend_empty_cache, require_accelerate, require_eetq, require_torch_gpu, @@ -87,7 +88,7 @@ class EetqTest(unittest.TestCase): def tearDown(self): gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) gc.collect() def test_quantized_model_conversion(self): diff --git a/tests/quantization/fbgemm_fp8/test_fbgemm_fp8.py b/tests/quantization/fbgemm_fp8/test_fbgemm_fp8.py index d3ecbb671fe..e31bd9adf51 100644 --- a/tests/quantization/fbgemm_fp8/test_fbgemm_fp8.py +++ b/tests/quantization/fbgemm_fp8/test_fbgemm_fp8.py @@ -18,6 +18,7 @@ import unittest from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, FbgemmFp8Config, OPTForCausalLM from transformers.testing_utils import ( + backend_empty_cache, require_accelerate, require_fbgemm_gpu, require_read_token, @@ -126,7 +127,7 @@ class FbgemmFp8Test(unittest.TestCase): def tearDown(self): gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) gc.collect() def test_quantized_model_conversion(self): diff --git a/tests/quantization/finegrained_fp8/test_fp8.py b/tests/quantization/finegrained_fp8/test_fp8.py index b5a586b0302..5622ab252fe 100644 --- a/tests/quantization/finegrained_fp8/test_fp8.py +++ b/tests/quantization/finegrained_fp8/test_fp8.py @@ -19,6 +19,7 @@ import unittest from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, FineGrainedFP8Config, OPTForCausalLM from transformers.testing_utils import ( backend_empty_cache, + get_device_properties, require_accelerate, require_read_token, require_torch_accelerator, @@ -254,7 +255,7 @@ class FP8LinearTest(unittest.TestCase): device = torch_device @unittest.skipIf( - torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 9, + get_device_properties()[0] == "cuda" and get_device_properties()[1] < 9, "Skipping FP8LinearTest because it is not supported on GPU with capability < 9.0", ) def test_linear_preserves_shape(self): @@ -270,7 +271,7 @@ class FP8LinearTest(unittest.TestCase): self.assertEqual(x_.shape, x.shape) @unittest.skipIf( - torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 9, + get_device_properties()[0] == "cuda" and get_device_properties()[1] < 9, "Skipping FP8LinearTest because it is not supported on GPU with capability < 9.0", ) def test_linear_with_diff_feature_size_preserves_shape(self): diff --git a/tests/quantization/higgs/test_higgs.py b/tests/quantization/higgs/test_higgs.py index 65dd151e98a..20727620269 100644 --- a/tests/quantization/higgs/test_higgs.py +++ b/tests/quantization/higgs/test_higgs.py @@ -18,6 +18,7 @@ import unittest from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, HiggsConfig, OPTForCausalLM from transformers.testing_utils import ( + backend_empty_cache, require_accelerate, require_flute_hadamard, require_torch_gpu, @@ -87,7 +88,7 @@ class HiggsTest(unittest.TestCase): def tearDown(self): gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) gc.collect() def test_quantized_model_conversion(self): diff --git a/tests/quantization/hqq/test_hqq.py b/tests/quantization/hqq/test_hqq.py index a686bbd7de7..5effe1c8616 100755 --- 
a/tests/quantization/hqq/test_hqq.py +++ b/tests/quantization/hqq/test_hqq.py @@ -17,6 +17,7 @@ import unittest from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig from transformers.testing_utils import ( + backend_empty_cache, require_accelerate, require_hqq, require_torch_gpu, @@ -50,7 +51,7 @@ class HQQLLMRunner: def cleanup(): - torch.cuda.empty_cache() + backend_empty_cache(torch_device) gc.collect() @@ -187,7 +188,7 @@ class HQQTestBias(unittest.TestCase): hqq_runner.model.save_pretrained(tmpdirname) del hqq_runner.model - torch.cuda.empty_cache() + backend_empty_cache(torch_device) model_loaded = AutoModelForCausalLM.from_pretrained( tmpdirname, torch_dtype=torch.float16, device_map=torch_device @@ -228,7 +229,7 @@ class HQQSerializationTest(unittest.TestCase): # Remove old model del hqq_runner.model - torch.cuda.empty_cache() + backend_empty_cache(torch_device) # Load and check if the logits match model_loaded = AutoModelForCausalLM.from_pretrained( diff --git a/tests/quantization/spqr_integration/test_spqr.py b/tests/quantization/spqr_integration/test_spqr.py index 961d11478c6..9f7ab7f4b9b 100644 --- a/tests/quantization/spqr_integration/test_spqr.py +++ b/tests/quantization/spqr_integration/test_spqr.py @@ -18,6 +18,7 @@ import unittest from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, SpQRConfig, StaticCache from transformers.testing_utils import ( + backend_empty_cache, require_accelerate, require_spqr, require_torch_gpu, @@ -82,8 +83,6 @@ class SpQRTest(unittest.TestCase): ) EXPECTED_OUTPUT_COMPILE = "Hello my name is Jake and I am a 20 year old student at the University of North Texas. (Go Mean Green!) I am a huge fan of the Dallas" - device_map = "cuda" - # called only once for all test in this class @classmethod def setUpClass(cls): @@ -93,12 +92,12 @@ class SpQRTest(unittest.TestCase): cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name) cls.quantized_model = AutoModelForCausalLM.from_pretrained( cls.model_name, - device_map=cls.device_map, + device_map=torch_device, ) def tearDown(self): gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) gc.collect() def test_quantized_model_conversion(self): @@ -158,7 +157,7 @@ class SpQRTest(unittest.TestCase): """ with tempfile.TemporaryDirectory() as tmpdirname: self.quantized_model.save_pretrained(tmpdirname) - model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map) + model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=torch_device) input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) diff --git a/tests/quantization/torchao_integration/test_torchao.py b/tests/quantization/torchao_integration/test_torchao.py index 8f1c15c94d6..37fc538fdbd 100644 --- a/tests/quantization/torchao_integration/test_torchao.py +++ b/tests/quantization/torchao_integration/test_torchao.py @@ -21,10 +21,13 @@ from packaging import version from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig from transformers.testing_utils import ( + backend_empty_cache, + get_device_properties, require_torch_gpu, require_torch_multi_gpu, require_torchao, require_torchao_version_greater_or_equal, + torch_device, ) from transformers.utils import is_torch_available, is_torchao_available @@ -131,7 +134,7 @@ class TorchAoTest(unittest.TestCase): def tearDown(self): gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) gc.collect() def test_int4wo_quant(self): @@ -260,7 +263,7 @@ class 
TorchAoTest(unittest.TestCase): @require_torch_gpu class TorchAoGPUTest(TorchAoTest): - device = "cuda" + device = torch_device quant_scheme_kwargs = {"group_size": 32} def test_int4wo_offload(self): @@ -397,7 +400,7 @@ class TorchAoSerializationTest(unittest.TestCase): def tearDown(self): gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) gc.collect() def test_original_model_expected_output(self): @@ -452,33 +455,33 @@ class TorchAoSerializationW8CPUTest(TorchAoSerializationTest): @require_torch_gpu class TorchAoSerializationGPTTest(TorchAoSerializationTest): quant_scheme, quant_scheme_kwargs = "int4_weight_only", {"group_size": 32} - device = "cuda:0" + device = f"{torch_device}:0" @require_torch_gpu class TorchAoSerializationW8A8GPUTest(TorchAoSerializationTest): quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {} EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" - device = "cuda:0" + device = f"{torch_device}:0" @require_torch_gpu class TorchAoSerializationW8GPUTest(TorchAoSerializationTest): quant_scheme, quant_scheme_kwargs = "int8_weight_only", {} EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" - device = "cuda:0" + device = f"{torch_device}:0" @require_torch_gpu @require_torchao_version_greater_or_equal("0.10.0") class TorchAoSerializationFP8GPUTest(TorchAoSerializationTest): EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" - device = "cuda:0" + device = f"{torch_device}:0" # called only once for all test in this class @classmethod def setUpClass(cls): - if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9: + if not (get_device_properties()[0] == "cuda" and get_device_properties()[1] >= 9): raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests") from torchao.quantization import Float8WeightOnlyConfig @@ -493,12 +496,12 @@ class TorchAoSerializationFP8GPUTest(TorchAoSerializationTest): @require_torchao_version_greater_or_equal("0.10.0") class TorchAoSerializationA8W4Test(TorchAoSerializationTest): EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" - device = "cuda:0" + device = f"{torch_device}:0" # called only once for all test in this class @classmethod def setUpClass(cls): - if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9: + if not (get_device_properties()[0] == "cuda" and get_device_properties()[1] >= 9): raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests") from torchao.quantization import Int8DynamicActivationInt4WeightConfig diff --git a/tests/quantization/vptq_integration/test_vptq.py b/tests/quantization/vptq_integration/test_vptq.py index fdbd703bae3..0f9e03d4b74 100644 --- a/tests/quantization/vptq_integration/test_vptq.py +++ b/tests/quantization/vptq_integration/test_vptq.py @@ -18,6 +18,7 @@ import unittest from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, VptqConfig from transformers.testing_utils import ( + backend_empty_cache, require_accelerate, require_torch_gpu, require_torch_multi_gpu, @@ -74,7 +75,7 @@ class VptqTest(unittest.TestCase): def tearDown(self): gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) gc.collect() def test_quantized_model(self): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 87fab3f8af7..512f6346e7f 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -73,7 +73,10 @@ 
from transformers.models.auto.modeling_auto import ( ) from transformers.testing_utils import ( CaptureLogger, + backend_device_count, backend_empty_cache, + backend_memory_allocated, + backend_torch_accelerator_module, get_device_properties, hub_retry, is_flaky, @@ -2613,7 +2616,7 @@ class ModelTesterMixin: for k in blacklist_non_batched_params: inputs_dict.pop(k, None) - # move input tensors to cuda:O + # move input tensors to accelerator O for k, v in inputs_dict.items(): if torch.is_tensor(v): inputs_dict[k] = v.to(0) @@ -2636,12 +2639,12 @@ class ModelTesterMixin: # a candidate for testing_utils def get_current_gpu_memory_use(): - """returns a list of cuda memory allocations per GPU in MBs""" + """returns a list of VRAM allocations per GPU in MBs""" per_device_memory = [] - for id in range(torch.cuda.device_count()): - with torch.cuda.device(id): - per_device_memory.append(torch.cuda.memory_allocated() >> 20) + for id in range(backend_device_count(torch_device)): + with backend_torch_accelerator_module(torch_device).device(id): + per_device_memory.append(backend_memory_allocated(torch_device) >> 20) return per_device_memory @@ -2657,7 +2660,7 @@ class ModelTesterMixin: # Put model on device 0 and take a memory snapshot model = model_class(config) - model.to("cuda:0") + model.to(f"{torch_device}:0") memory_after_model_load = get_current_gpu_memory_use() # The memory use on device 0 should be higher than it was initially. @@ -2717,7 +2720,7 @@ class ModelTesterMixin: model.parallelize() - parallel_output = model(**cast_to_device(inputs_dict, "cuda:0")) + parallel_output = model(**cast_to_device(inputs_dict, f"{torch_device}:0")) for value, parallel_value in zip(output, parallel_output): if isinstance(value, torch.Tensor): @@ -4240,10 +4243,10 @@ class ModelTesterMixin: # add position_ids + fa_kwargs data_collator = DataCollatorWithFlattening(return_tensors="pt", return_flash_attn_kwargs=True) batch = data_collator(features) - batch_cuda = {k: t.cuda() if torch.is_tensor(t) else t for k, t in batch.items()} + batch_accelerator = {k: t.to(torch_device) if torch.is_tensor(t) else t for k, t in batch.items()} res_padded = model(**inputs_dict) - res_padfree = model(**batch_cuda) + res_padfree = model(**batch_accelerator) logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()] logits_padfree = res_padfree.logits[0] diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 21b8622473f..41bc61e8a7f 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -3224,7 +3224,7 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): # For more than 1 GPUs, since the randomness is introduced in the model and with DataParallel (which is used # in this test for more than 2 GPUs), the calls to the torch RNG will happen in a random order (sometimes # GPU 0 will call first and sometimes GPU 1). 
- random_torch = not torch.cuda.is_available() or torch.cuda.device_count() <= 1 + random_torch = not torch.cuda.is_available() or backend_device_count(torch_device) <= 1 if torch.cuda.is_available(): torch.backends.cudnn.deterministic = True diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 3d1fa7a4474..9d435cb7ed1 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -22,6 +22,7 @@ from transformers.generation.configuration_utils import ALL_CACHE_IMPLEMENTATION from transformers.testing_utils import ( CaptureStderr, backend_device_count, + backend_torch_accelerator_module, cleanup, get_gpu_count, is_torch_available, @@ -430,11 +431,7 @@ class CacheHardIntegrationTest(unittest.TestCase): original = GenerationConfig(**common) offloaded = GenerationConfig(cache_implementation="offloaded", **common) - torch_accelerator_module = None - if device.type == "cuda": - torch_accelerator_module = torch.cuda - elif device.type == "xpu": - torch_accelerator_module = torch.xpu + torch_accelerator_module = backend_torch_accelerator_module(device.type) torch_accelerator_module.reset_peak_memory_stats(device) model.generate(generation_config=original, **inputs)
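
For reference, the device-agnostic pattern applied throughout the hunks above can be summarized in a small standalone test skeleton. This is a minimal sketch and not part of the patch: it assumes the `transformers.testing_utils` helpers imported in the hunks (`torch_device`, `backend_empty_cache`, `get_device_properties`, `backend_torch_accelerator_module`) behave as they are used there, and the `facebook/opt-125m` checkpoint name is used purely for illustration.

    # Illustrative sketch only (not part of the patch): the device-agnostic
    # test pattern used throughout the hunks above, as a standalone unittest.
    import gc
    import unittest

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.testing_utils import (
        backend_empty_cache,
        backend_torch_accelerator_module,
        get_device_properties,
        torch_device,
    )


    class DeviceAgnosticExampleTest(unittest.TestCase):
        # Hypothetical checkpoint, chosen only to keep the example small.
        model_name = "facebook/opt-125m"

        @classmethod
        def setUpClass(cls):
            # device_map takes the generic torch_device string ("cuda", "xpu",
            # "cpu", ...) instead of a hard-coded "cuda".
            cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
            cls.model = AutoModelForCausalLM.from_pretrained(cls.model_name, device_map=torch_device)

        def tearDown(self):
            # backend_empty_cache replaces torch.cuda.empty_cache() and works
            # for whichever accelerator backend is active.
            gc.collect()
            backend_empty_cache(torch_device)
            gc.collect()

        @unittest.skipIf(
            get_device_properties()[0] == "cuda" and get_device_properties()[1] < 8,
            "Requires compute capability >= 8.0 on CUDA devices",
        )
        def test_generate(self):
            # Inputs are moved with .to(torch_device) rather than .cuda() or .to("cuda").
            inputs = self.tokenizer("Hello my name is", return_tensors="pt").to(torch_device)
            output = self.model.generate(**inputs, max_new_tokens=5, do_sample=False)
            self.assertEqual(output.shape[0], 1)

        def test_current_device(self):
            # backend_torch_accelerator_module maps torch_device to the matching
            # accelerator module (torch.cuda, torch.xpu, ...).
            if torch_device in ["cuda", "xpu"]:
                accelerator = backend_torch_accelerator_module(torch_device)
                self.assertIsInstance(accelerator.current_device(), int)


    if __name__ == "__main__":
        unittest.main()

As in the hunks above, the capability check short-circuits on non-CUDA backends, so get_device_properties()[1] is only compared against 8 when the first element is "cuda".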