diff --git a/tests/peft_integration/test_peft_integration.py b/tests/peft_integration/test_peft_integration.py index 4a9132fcf6a..56156334e25 100644 --- a/tests/peft_integration/test_peft_integration.py +++ b/tests/peft_integration/test_peft_integration.py @@ -13,6 +13,7 @@ # limitations under the License. import importlib import os +import re import tempfile import unittest @@ -385,7 +386,7 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin): # Delete remaining adapter model.delete_adapter("adapter_2") - self.assertNotIn("adapter_2", model.peft_config) + self.assertFalse(hasattr(model, "peft_config")) self.assertFalse(model._hf_peft_config_loaded) # Re-add adapters for edge case tests @@ -394,11 +395,16 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin): # Attempt to delete multiple adapters at once model.delete_adapter(["adapter_1", "adapter_2"]) - self.assertNotIn("adapter_1", model.peft_config) - self.assertNotIn("adapter_2", model.peft_config) + self.assertFalse(hasattr(model, "peft_config")) self.assertFalse(model._hf_peft_config_loaded) # Test edge cases + msg = re.escape("No adapter loaded. Please load an adapter first.") + with self.assertRaisesRegex(ValueError, msg): + model.delete_adapter("nonexistent_adapter") + + model.add_adapter(peft_config_1, adapter_name="adapter_1") + with self.assertRaisesRegex(ValueError, "The following adapter\\(s\\) are not present"): model.delete_adapter("nonexistent_adapter") @@ -406,16 +412,11 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin): model.delete_adapter(["adapter_1", "nonexistent_adapter"]) # Deleting with an empty list or None should not raise errors - model.add_adapter(peft_config_1, adapter_name="adapter_1") model.add_adapter(peft_config_2, adapter_name="adapter_2") model.delete_adapter([]) # No-op self.assertIn("adapter_1", model.peft_config) self.assertIn("adapter_2", model.peft_config) - model.delete_adapter(None) # No-op - self.assertIn("adapter_1", model.peft_config) - self.assertIn("adapter_2", model.peft_config) - # Deleting duplicate adapter names in the list model.delete_adapter(["adapter_1", "adapter_1"]) self.assertNotIn("adapter_1", model.peft_config) @@ -832,7 +833,6 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin): # Input text for testing text = "Who is a Elon Musk?" - expected_error_msg = "The model 'PeftModel' is not supported for text-generation" model = AutoModelForCausalLM.from_pretrained( BASE_PATH, @@ -849,7 +849,7 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin): # Create pipeline with PEFT model while capturing log output # Check that the warning message is not present in the logs pipeline_logger = logging.get_logger("transformers.pipelines.base") - with self.assertNoLogs(pipeline_logger, logging.ERROR) as cl: + with self.assertNoLogs(pipeline_logger, logging.ERROR): lora_generator = pipeline( task="text-generation", model=lora_model, @@ -859,8 +859,3 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin): # Generate text to verify pipeline works _ = lora_generator(text) - - # Check that the warning message is not present in the logs - self.assertNotIn( - expected_error_msg, cl.out, f"Error message '{expected_error_msg}' should not appear when using PeftModel" - )