mirror of https://github.com/huggingface/transformers.git
fix bug with padding mask + add corresponding test
This commit is contained in:
parent 3b0d2fa30e
commit da10de8466
@@ -127,9 +127,9 @@ def build_lm_labels(sequence, pad_token):

 def build_mask(sequence, pad_token):
     """ Builds the mask. The attention mechanism will only attend to positions
     with value 1. """
-    mask = sequence.clone()
-    mask[mask != pad_token] = 1
-    mask[mask == pad_token] = 0
+    mask = torch.ones_like(sequence)
+    idx_pad_tokens = (sequence == pad_token)
+    mask[idx_pad_tokens] = 0
     return mask

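For context, a minimal standalone sketch of why the old masking logic fails exactly when the padding token id is 1, which is the case the new test exercises. The names build_mask_old and build_mask_new are hypothetical and used only for this side-by-side comparison; only the fixed logic matches the code above.

import torch

def build_mask_old(sequence, pad_token):
    # Old logic: the first assignment turns every non-pad position into 1,
    # so when pad_token == 1 the second assignment zeroes the entire mask.
    mask = sequence.clone()
    mask[mask != pad_token] = 1
    mask[mask == pad_token] = 0
    return mask

def build_mask_new(sequence, pad_token):
    # Fixed logic: start from all ones and index with a comparison against
    # the original sequence, so the value of pad_token no longer matters.
    mask = torch.ones_like(sequence)
    idx_pad_tokens = sequence == pad_token
    mask[idx_pad_tokens] = 0
    return mask

sequence = torch.tensor([8, 2, 3, 4, 1, 1, 1])
print(build_mask_old(sequence, 1))  # tensor([0, 0, 0, 0, 0, 0, 0]) -- everything masked out
print(build_mask_new(sequence, 1))  # tensor([1, 1, 1, 1, 0, 0, 0]) -- only the padding masked

For any pad token other than 1 the two versions produce the same mask, which is why the bug only surfaces with a pad id of 1; the test added below pins that case down.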
@@ -116,6 +116,13 @@ class SummarizationDataProcessingTest(unittest.TestCase):
             build_mask(sequence, 23).numpy(), expected.numpy()
         )

+    def test_build_mask_with_padding_equal_to_one(self):
+        sequence = torch.tensor([8, 2, 3, 4, 1, 1, 1])
+        expected = torch.tensor([1, 1, 1, 1, 0, 0, 0])
+        np.testing.assert_array_equal(
+            build_mask(sequence, 1).numpy(), expected.numpy()
+        )
+
     def test_compute_token_type_ids(self):
         separator = 101
         batch = torch.tensor(
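As a usage note, an illustrative sketch (assumed shapes and variable names, not code from this commit) of how the docstring's statement that attention only attends to positions with value 1 typically plays out: the zero positions are masked out before the softmax.

import torch

mask = torch.tensor([1, 1, 1, 1, 0, 0, 0])   # 1 = real token, 0 = padding
scores = torch.randn(1, 7, 7)                 # (batch, query, key) attention scores
scores = scores.masked_fill(mask[None, None, :] == 0, float("-inf"))
weights = torch.softmax(scores, dim=-1)       # padded key positions receive zero weight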