transformers/tests/utils/test_masking_utils.py
Cyril Vallez 554be7dac3
style
2025-07-03 16:06:19 +02:00

127 lines
5.4 KiB
Python

# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers.testing_utils import is_torch_available, require_torch
if is_torch_available():
    import torch

    from transformers import AutoConfig
    from transformers.masking_utils import create_causal_mask

    # Reference boolean mask for the packed-sequence tests, shape (batch=2, heads=1, query=10, key=10).
    # True means the query token may attend to the key token (the eager test maps True -> 0.0, False -> dtype min).
    # Each batch entry is block-diagonal and lower-triangular inside each block: a token only attends causally
    # to tokens of its own packed sequence (lengths 4/2/4 for the first entry, 6/4 for the second).
    # It is built inside the guard so that importing this module does not fail when torch is not installed.
    # fmt: off
    EXPECTED_PACKED_MASK = torch.tensor(
        [
            [[
                [ True, False, False, False, False, False, False, False, False, False],
                [ True,  True, False, False, False, False, False, False, False, False],
                [ True,  True,  True, False, False, False, False, False, False, False],
                [ True,  True,  True,  True, False, False, False, False, False, False],
                [False, False, False, False,  True, False, False, False, False, False],
                [False, False, False, False,  True,  True, False, False, False, False],
                [False, False, False, False, False, False,  True, False, False, False],
                [False, False, False, False, False, False,  True,  True, False, False],
                [False, False, False, False, False, False,  True,  True,  True, False],
                [False, False, False, False, False, False,  True,  True,  True,  True],
            ]],
            [[
                [ True, False, False, False, False, False, False, False, False, False],
                [ True,  True, False, False, False, False, False, False, False, False],
                [ True,  True,  True, False, False, False, False, False, False, False],
                [ True,  True,  True,  True, False, False, False, False, False, False],
                [ True,  True,  True,  True,  True, False, False, False, False, False],
                [ True,  True,  True,  True,  True,  True, False, False, False, False],
                [False, False, False, False, False, False,  True, False, False, False],
                [False, False, False, False, False, False,  True,  True, False, False],
                [False, False, False, False, False, False,  True,  True,  True, False],
                [False, False, False, False, False, False,  True,  True,  True,  True],
            ]],
        ],
        dtype=torch.bool,
    )
    # fmt: on
@require_torch
class MaskTest(unittest.TestCase):
    """Check that `create_causal_mask` keeps packed sequences isolated for each attention implementation."""

    batch_size = 2
    sequence_length = 10

    def _create_packed_causal_mask(self, attn_implementation):
        """Build the causal mask for the packed-sequence fixture.

        Args:
            attn_implementation: value assigned to `config._attn_implementation`
                (e.g. "sdpa", "eager", "flex_attention").

        Returns:
            A `(causal_mask, cache_position)` tuple; `cache_position` is returned so callers can reuse its device.
        """
        # A default Llama config is built locally instead of `from_pretrained("meta-llama/Llama-3.2-1B")`:
        # the checkpoint's values are not needed to build the mask, and this avoids network access to a gated repo.
        config = AutoConfig.for_model("llama")
        config._attn_implementation = attn_implementation

        cache_position = torch.arange(self.sequence_length)
        # First batch has 3 packed sequences of 4, 2 and 4 tokens respectively, second has 2 of 6 and 4 tokens
        position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 0, 1, 2, 3]])

        causal_mask = create_causal_mask(
            config=config,
            # we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
            input_embeds=torch.empty((self.batch_size, self.sequence_length), dtype=torch.float16),
            attention_mask=None,
            cache_position=cache_position,
            past_key_values=None,
            position_ids=position_ids,
        )
        return causal_mask, cache_position

    def test_packed_sequence_mask_sdpa(self):
        causal_mask, _ = self._create_packed_causal_mask("sdpa")
        # `assertEqual` cannot compare multi-element tensors (their truth value is ambiguous), so compare
        # element-wise; this also checks shape, dtype and device.
        torch.testing.assert_close(causal_mask, EXPECTED_PACKED_MASK)

    def test_packed_sequence_mask_eager(self):
        causal_mask, _ = self._create_packed_causal_mask("eager")
        min_dtype = torch.finfo(torch.float16).min
        # Eager masks are additive: 0 where attention is allowed, the dtype minimum where it is masked out.
        # `torch.where` with Python scalars yields float32, so cast to the float16 dtype of the produced mask.
        expected_mask = torch.where(EXPECTED_PACKED_MASK, 0.0, min_dtype).to(torch.float16)
        torch.testing.assert_close(causal_mask, expected_mask)

    def test_packed_sequence_mask_flex_attention(self):
        # Imported locally: the module-level imports are guarded by `is_torch_available()`.
        from torch.nn.attention.flex_attention import create_mask

        causal_mask, cache_position = self._create_packed_causal_mask("flex_attention")
        # `BlockMask.to_dense()` is block-granular (one entry per 128-token block by default), so it cannot be
        # compared to a token-level mask. Evaluate the BlockMask's `mask_mod` at every (batch, head, q, kv)
        # position instead to materialize the token-level boolean mask.
        dense_mask = create_mask(
            causal_mask.mask_mod,
            self.batch_size,
            1,
            self.sequence_length,
            self.sequence_length,
            device=cache_position.device,
        )
        torch.testing.assert_close(dense_mask, EXPECTED_PACKED_MASK)