Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 18:22:34 +06:00)
[Gemma] Fix eager attention (#29187)
* fix modelling code
* add tests
* fix tests
* add some logit tests
* style
* fix fix
This commit is contained in:
parent fc37f38915
commit 2a9b1f80c4
@@ -276,7 +276,7 @@ class GemmaAttention(nn.Module):
 
         attn_output = attn_output.transpose(1, 2).contiguous()
 
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = attn_output.view(bsz, q_len, -1)
         attn_output = self.o_proj(attn_output)
 
         if not output_attentions:
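Why the one-line change above fixes eager attention: in Gemma, num_heads * head_dim is not guaranteed to equal hidden_size, so reshaping the attention output to hidden_size before o_proj can raise a shape error, while view(bsz, q_len, -1) keeps the full num_heads * head_dim width that o_proj expects. A minimal sketch of the shape arithmetic, using assumed Gemma-7B-like sizes rather than values taken from this diff:

import torch

# Assumed, Gemma-7B-like sizes; note num_heads * head_dim (4096) != hidden_size (3072).
bsz, q_len = 2, 5
num_heads, head_dim, hidden_size = 16, 256, 3072

# Eager attention output right before the output projection.
attn_output = torch.randn(bsz, num_heads, q_len, head_dim)
attn_output = attn_output.transpose(1, 2).contiguous()  # (bsz, q_len, num_heads, head_dim)

# Old code: attn_output.reshape(bsz, q_len, hidden_size) would fail here,
# since the tensor's elements cannot form a (2, 5, 3072) shape.

# Fixed code: infer the last dimension instead of assuming hidden_size.
attn_output = attn_output.view(bsz, q_len, -1)  # (bsz, q_len, num_heads * head_dim)

# o_proj then maps num_heads * head_dim back to hidden_size.
o_proj = torch.nn.Linear(num_heads * head_dim, hidden_size, bias=False)
print(o_proj(attn_output).shape)  # torch.Size([2, 5, 3072])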
@@ -26,6 +26,7 @@ from transformers.testing_utils import (
     require_flash_attn,
     require_torch,
     require_torch_gpu,
+    require_torch_sdpa,
     slow,
     torch_device,
 )
@@ -460,6 +461,71 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     def test_flash_attn_2_inference_padding_right(self):
         self.skipTest("Gemma flash attention does not support right padding")
 
+    @require_torch_sdpa
+    @require_torch_gpu
+    @slow
+    def test_sdpa_equivalence(self):
+        for model_class in self.all_model_classes:
+            if not model_class._supports_sdpa:
+                return
+
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            model = model_class(config)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+                model_sdpa = model_class.from_pretrained(
+                    tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa"
+                )
+                model_sdpa.to(torch_device)
+
+                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager")
+                model.to(torch_device)
+
+                dummy_input = inputs_dict[model_class.main_input_name]
+                dummy_input = dummy_input.to(torch_device)
+                outputs = model(dummy_input, output_hidden_states=True)
+                outputs_sdpa = model_sdpa(dummy_input, output_hidden_states=True)
+
+                logits = outputs.hidden_states[-1]
+                logits_sdpa = outputs_sdpa.hidden_states[-1]
+
+                # gemma sdpa needs a high tolerance
+                assert torch.allclose(logits_sdpa, logits, atol=3e-3)
+
+    @require_flash_attn
+    @require_torch_gpu
+    @pytest.mark.flash_attn_test
+    @slow
+    def test_flash_attn_2_equivalence(self):
+        for model_class in self.all_model_classes:
+            if not model_class._supports_flash_attn_2:
+                return
+
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            model = model_class(config)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+                model_fa = model_class.from_pretrained(
+                    tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
+                )
+                model_fa.to(torch_device)
+
+                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager")
+                model.to(torch_device)
+
+                dummy_input = inputs_dict[model_class.main_input_name]
+                dummy_input = dummy_input.to(torch_device)
+                outputs = model(dummy_input, output_hidden_states=True)
+                outputs_fa = model_fa(dummy_input, output_hidden_states=True)
+
+                logits = outputs.hidden_states[-1]
+                logits_fa = outputs_fa.hidden_states[-1]
+
+                # gemma flash attention 2 needs a high tolerance
+                assert torch.allclose(logits_fa, logits, atol=3e-3)
+
+
 @require_torch_gpu
 @slow
@@ -542,6 +608,69 @@ class GemmaIntegrationTest(unittest.TestCase):
 
         self.assertEqual(output_text, EXPECTED_TEXTS)
 
+    def test_model_2b_eager(self):
+        model_id = "google/gemma-2b"
+        EXPECTED_TEXTS = [
+            "Hello I am doing a project on the 1990s and I am looking for some information on the ",
+            "Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
+        ]
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="eager"
+        )
+        model.to(torch_device)
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
+
+        output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+        output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
+
+        self.assertEqual(output_text, EXPECTED_TEXTS)
+
+    @require_torch_sdpa
+    def test_model_2b_sdpa(self):
+        model_id = "google/gemma-2b"
+        EXPECTED_TEXTS = [
+            "Hello I am doing a project on the 1990s and I need to know what the most popular music",
+            "Hi today I am going to share with you a very easy and simple recipe of <strong><em>Khichdi",
+        ]
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="sdpa"
+        )
+        model.to(torch_device)
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
+
+        output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+        output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
+
+        self.assertEqual(output_text, EXPECTED_TEXTS)
+
+    @pytest.mark.flash_attn_test
+    @require_flash_attn
+    def test_model_2b_flash_attn(self):
+        model_id = "google/gemma-2b"
+        EXPECTED_TEXTS = [
+            "Hello I am doing a project on the 1990s and I need to know what the most popular music",
+            "Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
+        ]
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
+        )
+        model.to(torch_device)
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
+
+        output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+        output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
+
+        self.assertEqual(output_text, EXPECTED_TEXTS)
+
     @require_bitsandbytes
     def test_model_2b_4bit(self):
         model_id = "google/gemma-2b"
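The three new integration tests differ only in the attn_implementation argument passed to from_pretrained, which is how the eager, SDPA, and FlashAttention-2 code paths are selected at load time; as the pinned EXPECTED_TEXTS show, the backends are expected to produce similar but not necessarily identical greedy completions. A condensed sketch of that pattern outside the test harness (the prompt and device handling below are illustrative, not taken from this diff):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2b"
prompt = "Hello I am doing"  # illustrative prompt, not read from this diff

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

# Each backend is chosen purely through attn_implementation at load time.
# "flash_attention_2" additionally requires the flash-attn package and a supported GPU.
for impl in ("eager", "sdpa", "flash_attention_2"):
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, attn_implementation=impl
    ).to("cuda")
    output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
    print(impl, tokenizer.batch_decode(output, skip_special_tokens=True)[0])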