Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 10:12:23 +06:00
Fix mask creations of GPTNeoX and GPT2 (#31944)

* fix mask creation of gpt2 and gpt_neox caused by me
* forgot the reshape of masks when shape > 2
* add tests for gpt neox and gpt2
* nit on a comment
This commit is contained in:
parent 2782aadae2
commit 605f3245dc
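The regression tracked in issue 31943 (referenced in the new tests) appears when a cached forward pass is run without an explicit attention mask under the SDPA attention implementation. A minimal reproduction sketch in the spirit of those tests, where the "gpt2" checkpoint, the sequence length, and the tolerance are illustrative choices rather than values from this commit:

# Hedged repro sketch for the bug this commit fixes (issue 31943); values are illustrative.
import torch
from transformers import GPT2Model

model = GPT2Model.from_pretrained("gpt2", attn_implementation="sdpa").eval()

input_ids = torch.randint(0, model.config.vocab_size, (1, 8))
attention_mask = torch.ones_like(input_ids)
cached_len = input_ids.shape[-1] // 2

with torch.no_grad():
    # Prefill half of the sequence to build a key/value cache.
    cache = model(input_ids[:, :cached_len], attention_mask=attention_mask[:, :cached_len]).past_key_values
    # Continue once with the full attention mask and once without it.
    with_mask = model(
        input_ids[:, cached_len:], attention_mask=attention_mask, past_key_values=cache
    ).last_hidden_state
    without_mask = model(input_ids[:, cached_len:], past_key_values=cache).last_hidden_state

# Both outputs should match; before this fix the maskless call diverged under SDPA.
print(torch.allclose(with_mask, without_mask, atol=1e-5))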
@@ -1030,18 +1030,18 @@ class GPT2Model(GPT2PreTrainedModel):
         # Attention mask.
         _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
-        if attention_mask is not None:
-            attention_mask = attention_mask.view(batch_size, -1)
-            if self._attn_implementation == "flash_attention_2":
-                attention_mask = attention_mask if 0 in attention_mask else None
-            elif _use_sdpa:
-                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                    attention_mask=attention_mask,
-                    input_shape=(batch_size, input_shape[-1]),
-                    inputs_embeds=inputs_embeds,
-                    past_key_values_length=past_length,
-                )
-            else:
+        attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None
+        if self._attn_implementation == "flash_attention_2":
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        elif _use_sdpa:
+            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask=attention_mask,
+                input_shape=(batch_size, input_shape[-1]),
+                inputs_embeds=inputs_embeds,
+                past_key_values_length=past_length,
+            )
+        else:
+            if attention_mask is not None:
                 # We create a 3D attention mask from a 2D tensor mask.
                 # Sizes are [batch_size, 1, 1, to_seq_length]
                 # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
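The change above moves the mask handling out of the `if attention_mask is not None:` block, so the SDPA branch now also runs when no mask is passed. When that branch is skipped, the attention call is typically left to torch's `is_causal` flag, whose causal pattern is anchored at the top-left corner and therefore cuts new queries off from most of the cached keys. A small standalone illustration of that mismatch (not library code, the sizes are made up):

# Standalone sketch: compare torch's top-left-aligned is_causal pattern with the
# cache-aware pattern a 4D causal mask encodes (1 = may attend, 0 = masked).
import torch

q_len, past_len = 3, 5          # 3 new tokens on top of 5 cached tokens
kv_len = past_len + q_len

is_causal_pattern = torch.ones(q_len, kv_len).tril(diagonal=0)           # what is_causal=True implies
cache_aware_pattern = torch.ones(q_len, kv_len).tril(diagonal=past_len)  # what the 4D mask helper builds

print(is_causal_pattern)     # first new token may only attend to key 0, ignoring the cache
print(cache_aware_pattern)   # every new token attends to the full cache plus its causal prefix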
@@ -824,25 +824,23 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
             inputs_embeds = self.embed_in(input_ids)
 
         # Attention mask.
-        if attention_mask is not None:
-            assert batch_size > 0, "batch_size has to be defined and > 0"
-            attention_mask = attention_mask.view(batch_size, -1)
-            if self._attn_implementation == "flash_attention_2":
-                attention_mask = attention_mask if 0 in attention_mask else None
-            elif self._attn_implementation == "sdpa" and not output_attentions and head_mask is None:
-                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                    attention_mask=attention_mask,
-                    input_shape=(batch_size, seq_length),
-                    inputs_embeds=inputs_embeds,
-                    past_key_values_length=past_length,
-                )
-            else:
-                attention_mask = _prepare_4d_causal_attention_mask(
-                    attention_mask=attention_mask,
-                    input_shape=(batch_size, seq_length),
-                    inputs_embeds=inputs_embeds,
-                    past_key_values_length=past_length,
-                )
+        attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None
+        if self._attn_implementation == "flash_attention_2":
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        elif self._attn_implementation == "sdpa" and not output_attentions and head_mask is None:
+            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask=attention_mask,
+                input_shape=(batch_size, seq_length),
+                inputs_embeds=inputs_embeds,
+                past_key_values_length=past_length,
+            )
+        else:
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask=attention_mask,
+                input_shape=(batch_size, seq_length),
+                inputs_embeds=inputs_embeds,
+                past_key_values_length=past_length,
+            )
 
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
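For GPT-NeoX the same restructuring additionally drops the old `assert batch_size > 0` guard and calls `_prepare_4d_causal_attention_mask` in the eager branch without the `None` check. That works because the helper accepts `attention_mask=None` and builds the causal mask from the input shape and cache length alone; a standalone sketch with made-up sizes (not part of this commit):

# Standalone sketch: the helper builds a 4D causal mask even when attention_mask is None.
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

inputs_embeds = torch.zeros(1, 3, 8)  # (batch, new_tokens, hidden_size), sizes are illustrative
mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask=None,
    input_shape=(1, 3),
    inputs_embeds=inputs_embeds,
    past_key_values_length=5,  # pretend five tokens are already cached
)
print(mask_4d.shape)  # torch.Size([1, 1, 3, 8]): each new query sees the cache plus its causal prefix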
@@ -426,6 +426,36 @@ class GPT2ModelTester:
                 self.parent.assertLessEqual(abs(torch.std(model.state_dict()[key]) - model_std), 0.001)
                 self.parent.assertLessEqual(abs(torch.mean(model.state_dict()[key]) - 0.0), 0.01)
 
+    def create_and_check_cached_forward_with_and_without_attention_mask(self, config, input_ids, *args):
+        # Relevant issue: https://github.com/huggingface/transformers/issues/31943
+        model = GPT2Model(config)
+        model.to(torch_device)
+        model.eval()
+
+        # We want this for SDPA, eager works with a `None` attention mask
+        assert (
+            model.config._attn_implementation == "sdpa"
+        ), "This test assumes the model to have the SDPA implementation for its attention calculations."
+
+        # Prepare cache and non_cache input, needs a full attention mask
+        cached_len = input_ids.shape[-1] // 2
+        input_mask = torch.ones(size=input_ids.size()).to(torch_device)
+        cache_inputs = {"input_ids": input_ids[:, :cached_len], "attention_mask": input_mask[:, :cached_len]}
+        non_cache_inputs = {"input_ids": input_ids[:, cached_len:], "attention_mask": input_mask}
+
+        # Cached forward once with the attention mask provided and the other time without it (which should assume full attention)
+        cache_outputs = model(**cache_inputs)
+        full_outputs_with_attention_mask = model(
+            **non_cache_inputs, past_key_values=cache_outputs.past_key_values
+        ).last_hidden_state
+        full_outputs_without_attention_mask = model(
+            non_cache_inputs["input_ids"], past_key_values=cache_outputs.past_key_values
+        ).last_hidden_state
+
+        self.parent.assertTrue(
+            torch.allclose(full_outputs_with_attention_mask, full_outputs_without_attention_mask, atol=1e-5)
+        )
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
@@ -570,6 +600,10 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_gpt2_weight_initialization(*config_and_inputs)
 
+    def test_cached_forward_with_and_without_attention_mask(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_cached_forward_with_and_without_attention_mask(*config_and_inputs)
+
     @unittest.skip(
         reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
     )
@@ -219,6 +219,36 @@ class GPTNeoXModelTester:
         # test that outputs are equal for slice
         self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
 
+    def create_and_check_cached_forward_with_and_without_attention_mask(self, config, input_ids, *args):
+        # Relevant issue: https://github.com/huggingface/transformers/issues/31943
+        model = GPTNeoXModel(config)
+        model.to(torch_device)
+        model.eval()
+
+        # We want this for SDPA, eager works with a `None` attention mask
+        assert (
+            model.config._attn_implementation == "sdpa"
+        ), "This test assumes the model to have the SDPA implementation for its attention calculations."
+
+        # Prepare cache and non_cache input, needs a full attention mask
+        cached_len = input_ids.shape[-1] // 2
+        input_mask = torch.ones(size=input_ids.size()).to(torch_device)
+        cache_inputs = {"input_ids": input_ids[:, :cached_len], "attention_mask": input_mask[:, :cached_len]}
+        non_cache_inputs = {"input_ids": input_ids[:, cached_len:], "attention_mask": input_mask}
+
+        # Cached forward once with the attention mask provided and the other time without it (which should assume full attention)
+        cache_outputs = model(**cache_inputs)
+        full_outputs_with_attention_mask = model(
+            **non_cache_inputs, past_key_values=cache_outputs.past_key_values
+        ).last_hidden_state
+        full_outputs_without_attention_mask = model(
+            non_cache_inputs["input_ids"], past_key_values=cache_outputs.past_key_values
+        ).last_hidden_state
+
+        self.parent.assertTrue(
+            torch.allclose(full_outputs_with_attention_mask, full_outputs_without_attention_mask, atol=1e-5)
+        )
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         config, input_ids, input_mask, token_labels = config_and_inputs
@@ -300,6 +330,10 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
 
+    def test_cached_forward_with_and_without_attention_mask(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_cached_forward_with_and_without_attention_mask(*config_and_inputs)
+
     @unittest.skip(reason="Feed forward chunking is not implemented")
     def test_feed_forward_chunking(self):
         pass
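Both test classes gain the same `test_cached_forward_with_and_without_attention_mask` entry point, so the new checks can be run in isolation with a `-k` filter. A hedged usage sketch, assuming the standard transformers test file locations (the paths are not shown in this diff):

# Hedged usage sketch: run only the new cached-forward tests for both models.
# The file paths are assumed from the usual transformers test layout, not part of this diff.
import pytest

pytest.main(
    [
        "-k", "test_cached_forward_with_and_without_attention_mask",
        "tests/models/gpt2/test_modeling_gpt2.py",
        "tests/models/gpt_neox/test_modeling_gpt_neox.py",
    ]
)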