mirror of https://github.com/huggingface/transformers.git
Llama: allow custom 4d masks (#29618)
This commit is contained in:
parent 88a4f68fe5
commit 1e21c4fbe0
@@ -975,11 +975,16 @@ class GemmaModel(GemmaPreTrainedModel):
         causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
         causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
-        if attention_mask is not None and attention_mask.dim() == 2:
+        if attention_mask is not None:
             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-            mask_length = attention_mask.shape[-1]
-            padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
-            causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
+            if attention_mask.dim() == 2:
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
+                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
+            elif attention_mask.dim() == 4:
+                mask_shape = attention_mask.shape
+                mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
+                causal_mask[: mask_shape[0], : mask_shape[1], : mask_shape[2], : mask_shape[3]] = mask_slice

         if (
             self.config._attn_implementation == "sdpa"
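For reference, a minimal standalone sketch (not part of the diff) of what the new `elif attention_mask.dim() == 4` branch computes: entries equal to 0.0 in a user-supplied 4D mask become `min_dtype` (blocked) and everything else becomes 0.0 (visible), i.e. the additive form the attention layers consume. The mask values below are illustrative only.

import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min

# Hypothetical custom mask, shape (batch=1, heads=1, q_len=3, kv_len=3); 1.0 = attend, 0.0 = block.
attention_mask = torch.tensor([[[[1.0, 0.0, 0.0],
                                 [1.0, 1.0, 0.0],
                                 [1.0, 1.0, 1.0]]]])

# Same expression as the new branch: 0.0 entries map to min_dtype, the rest to 0.0.
mask_slice = attention_mask.eq(0.0).to(dtype=dtype) * min_dtype
print(mask_slice)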
@@ -1083,11 +1083,16 @@ class LlamaModel(LlamaPreTrainedModel):
         min_dtype = torch.finfo(dtype).min
         causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
         causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
-        if attention_mask is not None and attention_mask.dim() == 2:
+        if attention_mask is not None:
             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-            mask_length = attention_mask.shape[-1]
-            padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
-            causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
+            if attention_mask.dim() == 2:
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
+                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
+            elif attention_mask.dim() == 4:
+                mask_shape = attention_mask.shape
+                mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
+                causal_mask[: mask_shape[0], : mask_shape[1], : mask_shape[2], : mask_shape[3]] = mask_slice

         if (
             self.config._attn_implementation == "sdpa"
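The same branch is added to LlamaModel, so a caller can pass a ready-made 4D mask straight into the model. A hedged usage sketch, assuming the small JackFram/llama-68m checkpoint the tests below rely on; the token ids and the plain causal mask are illustrative only.

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("JackFram/llama-68m")

input_ids = torch.tensor([[1, 278, 6635, 3290]])   # (batch=1, seq_len=4)
position_ids = torch.tensor([[0, 1, 2, 3]])        # (batch=1, seq_len=4)
mask_4d = torch.tril(torch.ones(1, 1, 4, 4))       # (batch, 1, q_len, kv_len); nonzero = attend, 0 = block

out = model(input_ids, attention_mask=mask_4d, position_ids=position_ids)
print(out.logits.shape)                            # (1, 4, vocab_size)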
@@ -1992,6 +1992,8 @@ class Mask4DTestBase(unittest.TestCase):
         # [ 1, 278, 6635, 750],
         # [ 1, 278, 6635, 338]], device='cuda:0')

+        position_ids_0 = torch.tensor([[0, 1, 2, 3]] * 3, device=torch_device, dtype=torch.int64)
+
         # Combining common prefix with the unique ending tokens:
         input_1 = torch.cat([input_0[0][:-1], input_0[:, -1]]).unsqueeze(0)
         # tensor([[ 1, 278, 6635, 3290, 750, 338]], device='cuda:0')
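The test data packs three sequences that share a three-token prefix into a single six-token row (input_1), while position_ids_1 in the next hunk repeats position 3 for the three alternative last tokens. The excerpt does not show how mask_1 is constructed, so the following is only a sketch of a 4D mask with the structure the packed input implies: each of the last three tokens attends to the shared prefix and to itself, but not to the other tails (1 = attend, 0 = block).

import torch

# Sketch only: (batch=1, heads=1, q_len=6, kv_len=6) mask for the packed row
# [prefix0, prefix1, prefix2, tail_a, tail_b, tail_c].
mask_1 = torch.tensor(
    [[[[1, 0, 0, 0, 0, 0],
       [1, 1, 0, 0, 0, 0],
       [1, 1, 1, 0, 0, 0],
       [1, 1, 1, 1, 0, 0],
       [1, 1, 1, 0, 1, 0],
       [1, 1, 1, 0, 0, 1]]]]
)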
@@ -2017,81 +2019,63 @@ class Mask4DTestBase(unittest.TestCase):
         # Creating a position_ids tensor. note the repeating figures in the end.
         position_ids_1 = torch.tensor([[0, 1, 2, 3, 3, 3]], device=torch_device, dtype=torch.int64)

-        return input_0, input_1, mask_1, position_ids_1
+        return input_0, position_ids_0, input_1, mask_1, position_ids_1


 @slow
 @require_torch_gpu
 class Mask4DTestFP32(Mask4DTestBase):
     def setUp(self):
         model_name = "JackFram/llama-68m"  # small Llama-like model from FlexFlow
-        model_dtype = torch.float32
+        self.model_dtype = torch.float32
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype).to(torch_device)
+        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)

     def test_attention(self):
         """comparing outputs of attention layer"""
-        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()
+        input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()
+        causal_mask_1 = (1 - mask_1).to(self.model_dtype) * torch.finfo(self.model_dtype).min

         hid_0 = self.model.model.embed_tokens(input_0)
-        outs_0 = self.model.model.layers[0].self_attn.forward(hid_0)[0]
+        outs_0 = self.model.model.layers[0].self_attn.forward(hid_0, position_ids=position_ids_0)[0]
         # outs_0.shape == torch.Size([3, 4, 768])

         hid_1 = self.model.model.embed_tokens(input_1)
         outs_1 = self.model.model.layers[0].self_attn.forward(
-            hid_1, attention_mask=mask_1.bool(), position_ids=position_ids_1
+            hid_1, attention_mask=causal_mask_1, position_ids=position_ids_1
         )[0]
         # outs_1.shape == torch.Size([1, 6, 768])

         outs_0_last_tokens = outs_0[:, -1, :]  # last tokens in each batch line
         outs_1_last_tokens = outs_1[0, -3:, :]  # last three tokens
-        assert torch.allclose(outs_0_last_tokens, outs_1_last_tokens)
-
-    def test_inner_model(self):
-        """comparing hidden outputs of whole inner model"""
-        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()
-
-        logits_0 = self.model.forward(input_0).logits
-        logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits
-
-        logits_0_last_tokens = logits_0[:, -1, :]  # last tokens in each batch line
-        logits_1_last_tokens = logits_1[0, -3:, :]  # last three tokens
-        torch.testing.assert_close(
-            logits_0_last_tokens,
-            logits_1_last_tokens,
-        )
+        torch.testing.assert_close(outs_0_last_tokens, outs_1_last_tokens)

     def test_causal_model_logits(self):
         """comparing logits outputs of whole inner model"""
-        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()
+        input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()

-        logits_0 = self.model.forward(input_0).logits
+        logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits
         logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits

         logits_0_last_tokens = logits_0[:, -1, :]  # last tokens in each batch line
         logits_1_last_tokens = logits_1[0, -3:, :]  # last three tokens
-        torch.testing.assert_close(
-            logits_0_last_tokens,
-            logits_1_last_tokens,
-        )
+        torch.testing.assert_close(logits_0_last_tokens, logits_1_last_tokens)


 @slow
 @require_torch_gpu
 class Mask4DTestFP16(Mask4DTestBase):
     test_attention = Mask4DTestFP32.test_attention

     def setUp(self):
         model_name = "JackFram/llama-68m"  # small Llama-like model from FlexFlow
-        model_dtype = torch.float16
+        self.model_dtype = torch.float16
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype).to(torch_device)
+        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)

     def test_causal_model_logits(self):
         """comparing logits outputs of whole inner model"""
-        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()
+        input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()

-        logits_0 = self.model.forward(input_0).logits
+        logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits
         logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits

         logits_0_last_tokens = logits_0[:, -1, :]  # last tokens in each batch line
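One note on the updated test_attention: when the attention layer is called directly it expects an additive mask rather than the 0/1 mask accepted at the model level, which is why the test now builds causal_mask_1 as (1 - mask_1) * finfo.min. A tiny standalone illustration of that conversion, with placeholder mask values:

import torch

model_dtype = torch.float16
mask_1 = torch.tensor([[[[1, 0], [1, 1]]]])  # 1 = attend, 0 = block (placeholder values)

# 1 -> 0.0 (visible), 0 -> finfo.min (blocked), i.e. the additive form used by the attention layer
causal_mask_1 = (1 - mask_1).to(model_dtype) * torch.finfo(model_dtype).min
print(causal_mask_1)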