import tempfile
import unittest

from transformers import LlavaConfig


class LlavaConfigTest(unittest.TestCase):
    def test_llava_reload(self):
        """
        Simple test for reloading default llava configs
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            config = LlavaConfig()
            config.save_pretrained(tmp_dir)

            reloaded = LlavaConfig.from_pretrained(tmp_dir)
            assert config.to_dict() == reloaded.to_dict()

    def test_pixtral_reload(self):
        """
        Simple test for reloading pixtral configs
        """
        vision_config = {
            "model_type": "pixtral",
            "head_dim": 64,
            "hidden_act": "silu",
            "image_size": 1024,
            "is_composition": True,
            "patch_size": 16,
            "rope_theta": 10000.0,
            "tie_word_embeddings": False,
        }

        text_config = {
            "model_type": "mistral",
            "hidden_size": 5120,
            "head_dim": 128,
            "num_attention_heads": 32,
            "intermediate_size": 14336,
            "is_composition": True,
            "max_position_embeddings": 1024000,
            "num_hidden_layers": 40,
            "num_key_value_heads": 8,
            "rms_norm_eps": 1e-05,
            "rope_theta": 1000000000.0,
            "sliding_window": None,
            "vocab_size": 131072,
        }

        with tempfile.TemporaryDirectory() as tmp_dir:
            config = LlavaConfig(vision_config=vision_config, text_config=text_config)
            config.save_pretrained(tmp_dir)

            reloaded = LlavaConfig.from_pretrained(tmp_dir)
            assert config.to_dict() == reloaded.to_dict()

    def test_arbitrary_reload(self):
        """
        Simple test for reloading arbitrarily composed subconfigs
        """
        default_values = LlavaConfig().to_diff_dict()
        default_values["vision_config"]["model_type"] = "pixtral"
        default_values["text_config"]["model_type"] = "opt"
        self.maxDiff = None

        with tempfile.TemporaryDirectory() as tmp_dir:
            config = LlavaConfig(**default_values)
            config.save_pretrained(tmp_dir)

            reloaded = LlavaConfig.from_pretrained(tmp_dir)
            self.assertDictEqual(config.to_dict(), reloaded.to_dict())