fix tests

This commit is contained in:
Cyril Vallez 2025-07-03 16:55:26 +02:00
parent 45e885bbdd
commit f0a3d28ded
No known key found for this signature in database

View File

@ -14,14 +14,14 @@
import unittest
from transformers.testing_utils import is_torch_available, require_read_token, require_torch
from transformers.testing_utils import is_torch_available, require_torch
if is_torch_available():
import torch
from torch.nn.attention.flex_attention import create_block_mask
from transformers import AutoConfig
from transformers import LlamaConfig
from transformers.masking_utils import create_causal_mask
@ -54,10 +54,9 @@ EXPECTED_PACKED_MASK = torch.tensor([[[
@require_torch
@require_read_token
class MaskTest(unittest.TestCase):
def test_packed_sequence_mask_sdpa(self):
config = AutoConfig.from_pretrained("meta-llama/Llama-3.2-1B")
config = LlamaConfig()
config._attn_implementation = "sdpa"
batch_size = 2
@ -77,10 +76,10 @@ class MaskTest(unittest.TestCase):
position_ids=position_ids,
)
self.assertEqual(causal_mask, EXPECTED_PACKED_MASK)
self.assertTrue((causal_mask == EXPECTED_PACKED_MASK).all())
def test_packed_sequence_mask_eager(self):
config = AutoConfig.from_pretrained("meta-llama/Llama-3.2-1B")
config = LlamaConfig()
config._attn_implementation = "eager"
batch_size = 2
@ -101,10 +100,10 @@ class MaskTest(unittest.TestCase):
)
min_dtype = torch.finfo(torch.float16).min
self.assertEqual(causal_mask, torch.where(EXPECTED_PACKED_MASK, 0.0, min_dtype))
self.assertTrue((causal_mask == torch.where(EXPECTED_PACKED_MASK, 0.0, min_dtype)).all())
def test_packed_sequence_mask_flex_attention(self):
config = AutoConfig.from_pretrained("meta-llama/Llama-3.2-1B")
config = LlamaConfig()
config._attn_implementation = "flex_attention"
batch_size = 2