diff --git a/tests/causal_lm_tester.py b/tests/causal_lm_tester.py
index 2ef760e0fbe..f41d3ab6e32 100644
--- a/tests/causal_lm_tester.py
+++ b/tests/causal_lm_tester.py
@@ -22,9 +22,7 @@ from transformers import set_seed
 from transformers.testing_utils import (
     is_flaky,
     require_flash_attn,
-    require_torch_accelerator,
     require_torch_gpu,
-    require_torch_sdpa,
     slow,
 )
 
@@ -410,39 +408,6 @@ class CausalLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
         with self.assertRaises(AssertionError):
             torch.testing.assert_close(yarn_sin_long, original_sin_long)
 
-    @require_torch_sdpa
-    @require_torch_accelerator
-    @slow
-    def test_sdpa_equivalence(self):
-        for model_class in self.all_model_classes:
-            if not model_class._supports_sdpa:
-                self.skipTest(reason="Model does not support SDPA")
-
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-            model = model_class(config)
-
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                model.save_pretrained(tmpdirname)
-                model_sdpa = model_class.from_pretrained(
-                    tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="sdpa"
-                )
-                model_sdpa.to(torch_device)
-
-                model = model_class.from_pretrained(
-                    tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="eager"
-                )
-                model.to(torch_device)
-
-                dummy_input = inputs_dict[model_class.main_input_name]
-                dummy_input = dummy_input.to(torch_device)
-                outputs = model(dummy_input, output_hidden_states=True)
-                outputs_sdpa = model_sdpa(dummy_input, output_hidden_states=True)
-
-                logits = outputs.hidden_states[-1]
-                logits_sdpa = outputs_sdpa.hidden_states[-1]
-
-                assert torch.allclose(logits_sdpa, logits, atol=2e-3)
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test