fix tests

parent 45e885bbdd
commit f0a3d28ded
@@ -14,14 +14,14 @@
 import unittest
 
-from transformers.testing_utils import is_torch_available, require_read_token, require_torch
+from transformers.testing_utils import is_torch_available, require_torch
 
 
 if is_torch_available():
     import torch
     from torch.nn.attention.flex_attention import create_block_mask
 
-    from transformers import AutoConfig
+    from transformers import LlamaConfig
     from transformers.masking_utils import create_causal_mask
 
 
@@ -54,10 +54,9 @@ EXPECTED_PACKED_MASK = torch.tensor([[[
 
 
 @require_torch
-@require_read_token
 class MaskTest(unittest.TestCase):
     def test_packed_sequence_mask_sdpa(self):
-        config = AutoConfig.from_pretrained("meta-llama/Llama-3.2-1B")
+        config = LlamaConfig()
         config._attn_implementation = "sdpa"
 
         batch_size = 2
@@ -77,10 +76,10 @@ class MaskTest(unittest.TestCase):
             position_ids=position_ids,
         )
 
-        self.assertEqual(causal_mask, EXPECTED_PACKED_MASK)
+        self.assertTrue((causal_mask == EXPECTED_PACKED_MASK).all())
 
     def test_packed_sequence_mask_eager(self):
-        config = AutoConfig.from_pretrained("meta-llama/Llama-3.2-1B")
+        config = LlamaConfig()
         config._attn_implementation = "eager"
 
         batch_size = 2
@@ -101,10 +100,10 @@ class MaskTest(unittest.TestCase):
         )
 
         min_dtype = torch.finfo(torch.float16).min
-        self.assertEqual(causal_mask, torch.where(EXPECTED_PACKED_MASK, 0.0, min_dtype))
+        self.assertTrue((causal_mask == torch.where(EXPECTED_PACKED_MASK, 0.0, min_dtype)).all())
 
     def test_packed_sequence_mask_flex_attention(self):
-        config = AutoConfig.from_pretrained("meta-llama/Llama-3.2-1B")
+        config = LlamaConfig()
         config._attn_implementation = "flex_attention"
 
         batch_size = 2
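
Note on the change: the diff drops @require_read_token and builds a default LlamaConfig() instead of fetching the gated meta-llama/Llama-3.2-1B config, so the tests no longer need a Hub access token. It also replaces assertEqual on whole tensors with an element-wise comparison reduced by .all(). A minimal sketch of why, assuming two multi-element torch tensors (the class and variable names below are illustrative, not from the repository):

import unittest

import torch


class TensorCompareExample(unittest.TestCase):
    def test_elementwise_equality(self):
        expected = torch.tensor([[True, False], [True, True]])
        computed = torch.tensor([[True, False], [True, True]])

        # assertEqual falls back to evaluating `not (computed == expected)`,
        # which tries to coerce a multi-element boolean tensor to a single
        # bool and raises "Boolean value of Tensor with more than one
        # element is ambiguous". Reducing the comparison with .all() gives
        # a single boolean the assertion can check:
        self.assertTrue((computed == expected).all())


if __name__ == "__main__":
    unittest.main()

torch.equal(computed, expected) would be an equivalent single-call check; the .all() form mirrors what the updated tests use.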