diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py
index 3d396d8f7f3..3a6093e6375 100644
--- a/tests/models/gemma2/test_modeling_gemma2.py
+++ b/tests/models/gemma2/test_modeling_gemma2.py
@@ -26,6 +26,7 @@ from transformers.testing_utils import (
     require_flash_attn,
     require_read_token,
     require_torch,
+    require_torch_accelerator,
     require_torch_gpu,
     slow,
     tooslow,
@@ -155,7 +156,7 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
 
 
 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class Gemma2IntegrationTest(unittest.TestCase):
     input_text = ["Hello I am doing", "Hi today"]
     # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
@@ -360,6 +361,10 @@ class Gemma2IntegrationTest(unittest.TestCase):
         we need to correctly slice the attention mask in all cases (because we use a HybridCache).
         Outputs for every attention functions should be coherent and identical.
         """
+
+        if torch_device == "xpu" and attn_implementation == "flash_attention_2":
+            self.skipTest(reason="Intel XPU doesn't support flash_attention_2 as of now.")
+
         model_id = "google/gemma-2-2b"
         EXPECTED_COMPLETIONS = [
             " the people, the food, the culture, the history, the music, the art, the architecture",