enable 6 gemma2 cases on XPU (#37564)

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Yao Matrix 2025-04-18 18:10:34 +08:00 committed by GitHub
parent 049b75ea72
commit 3cd6627cd7

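The hunks below swap `@require_torch_gpu` for `@require_torch_accelerator` so the Gemma2 integration tests also run on non-CUDA accelerators such as Intel XPU, and add a device-conditional skip for `flash_attention_2` on XPU. The following is a minimal sketch of that pattern, not part of the commit; it assumes transformers' testing utilities (`require_torch_accelerator`, `torch_device`, `slow`) and the `parameterized` package, and the class name and placeholder test body are hypothetical.

```python
# Minimal sketch (illustrative, not the actual Gemma2 test): gate the test class
# on any torch accelerator instead of CUDA only, and skip flash_attention_2 on
# XPU, where it is not yet supported.
import unittest

from parameterized import parameterized

from transformers.testing_utils import require_torch_accelerator, slow, torch_device


@slow
@require_torch_accelerator  # runs on CUDA, XPU, or any other supported accelerator
class AcceleratorSmokeTest(unittest.TestCase):  # hypothetical class name
    @parameterized.expand(["eager", "sdpa", "flash_attention_2"])
    def test_attn_implementation(self, attn_implementation):
        # Mirror the skip added in this commit: XPU has no flash_attention_2 support yet.
        if torch_device == "xpu" and attn_implementation == "flash_attention_2":
            self.skipTest(reason="Intel XPU doesn't support flash_attention_2 as of now.")
        # A real test would load the model with attn_implementation=attn_implementation
        # and compare generations across implementations; this placeholder just passes.
        self.assertIn(attn_implementation, ("eager", "sdpa", "flash_attention_2"))
```

Gating on `require_torch_accelerator` keeps a single test class for all backends, while per-device skips handle features a given backend has not implemented yet.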

@@ -26,6 +26,7 @@ from transformers.testing_utils import (
     require_flash_attn,
     require_read_token,
     require_torch,
+    require_torch_accelerator,
     require_torch_gpu,
     slow,
     tooslow,
@@ -155,7 +156,7 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class Gemma2IntegrationTest(unittest.TestCase):
     input_text = ["Hello I am doing", "Hi today"]
     # This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
@@ -360,6 +361,10 @@ class Gemma2IntegrationTest(unittest.TestCase):
         we need to correctly slice the attention mask in all cases (because we use a HybridCache).
         Outputs for every attention functions should be coherent and identical.
         """
+        if torch_device == "xpu" and attn_implementation == "flash_attention_2":
+            self.skipTest(reason="Intel XPU doesn't support flash_attention_2 as of now.")
         model_id = "google/gemma-2-2b"
         EXPECTED_COMPLETIONS = [
             " the people, the food, the culture, the history, the music, the art, the architecture",