mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 05:10:06 +06:00
Remove redundant test_sdpa_equivalence test (#38436)
* Remove redundant test
* make fixup
This commit is contained in:
parent
51e0fac29f
commit
0ed6f7e6b4
@@ -22,9 +22,7 @@ from transformers import set_seed
|
|||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
is_flaky,
|
is_flaky,
|
||||||
require_flash_attn,
|
require_flash_attn,
|
||||||
require_torch_accelerator,
|
|
||||||
require_torch_gpu,
|
require_torch_gpu,
|
||||||
require_torch_sdpa,
|
|
||||||
slow,
|
slow,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -410,39 +408,6 @@ class CausalLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
|||||||
with self.assertRaises(AssertionError):
|
with self.assertRaises(AssertionError):
|
||||||
torch.testing.assert_close(yarn_sin_long, original_sin_long)
|
torch.testing.assert_close(yarn_sin_long, original_sin_long)
|
||||||
|
|
||||||
@require_torch_sdpa
|
|
||||||
@require_torch_accelerator
|
|
||||||
@slow
|
|
||||||
def test_sdpa_equivalence(self):
|
|
||||||
for model_class in self.all_model_classes:
|
|
||||||
if not model_class._supports_sdpa:
|
|
||||||
self.skipTest(reason="Model does not support SDPA")
|
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
|
||||||
model = model_class(config)
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
|
||||||
model.save_pretrained(tmpdirname)
|
|
||||||
model_sdpa = model_class.from_pretrained(
|
|
||||||
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="sdpa"
|
|
||||||
)
|
|
||||||
model_sdpa.to(torch_device)
|
|
||||||
|
|
||||||
model = model_class.from_pretrained(
|
|
||||||
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="eager"
|
|
||||||
)
|
|
||||||
model.to(torch_device)
|
|
||||||
|
|
||||||
dummy_input = inputs_dict[model_class.main_input_name]
|
|
||||||
dummy_input = dummy_input.to(torch_device)
|
|
||||||
outputs = model(dummy_input, output_hidden_states=True)
|
|
||||||
outputs_sdpa = model_sdpa(dummy_input, output_hidden_states=True)
|
|
||||||
|
|
||||||
logits = outputs.hidden_states[-1]
|
|
||||||
logits_sdpa = outputs_sdpa.hidden_states[-1]
|
|
||||||
|
|
||||||
assert torch.allclose(logits_sdpa, logits, atol=2e-3)
|
|
||||||
|
|
||||||
@require_flash_attn
|
@require_flash_attn
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@pytest.mark.flash_attn_test
|
@pytest.mark.flash_attn_test
|
||||||
|
Loading…
Reference in New Issue
Block a user