transformers/tests/models/llava/test_configuration_llava.py
Kyle Sayers bcfc9d795e
[Bugfix] Fix reloading of pixtral/llava configs (#36077)
* add is_composition flag to LlavaConfig

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* WIP: pixtral text config

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* fix style

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* add test

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* use is_composition for pixtral

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* Revert "use is_composition for pixtral"

This reverts commit a53d5f9fc5.

* Revert "Revert "use is_composition for pixtral""

This reverts commit 3ab1c99404.

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
2025-02-14 15:27:05 +01:00

71 lines
2.2 KiB
Python

import tempfile
import unittest
from transformers import LlavaConfig
class LlavaConfigTest(unittest.TestCase):
    """Save/reload round-trip tests for `LlavaConfig`."""

    # Nested config dicts can exceed unittest's default diff length limit;
    # disable the limit so a failure shows exactly which keys differ.
    maxDiff = None

    def _assert_reload_matches(self, config):
        """
        Save `config` to a temporary directory, reload it with
        `LlavaConfig.from_pretrained`, and check that both serialize to the
        same dict.

        Args:
            config: the `LlavaConfig` instance to round-trip.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            config.save_pretrained(tmp_dir)
            reloaded = LlavaConfig.from_pretrained(tmp_dir)

            # Use a unittest assertion instead of a bare `assert`: it is not
            # stripped under `python -O` and reports a dict diff on failure.
            self.assertEqual(config.to_dict(), reloaded.to_dict())

    def test_llava_reload(self):
        """
        Simple test for reloading default llava configs
        """
        self._assert_reload_matches(LlavaConfig())

    def test_pixtral_reload(self):
        """
        Simple test for reloading pixtral configs
        (a "pixtral" vision config combined with a "mistral" text config)
        """
        vision_config = {
            "model_type": "pixtral",
            "head_dim": 64,
            "hidden_act": "silu",
            "image_size": 1024,
            "is_composition": True,
            "patch_size": 16,
            "rope_theta": 10000.0,
            "tie_word_embeddings": False,
        }
        text_config = {
            "model_type": "mistral",
            "hidden_size": 5120,
            "head_dim": 128,
            "num_attention_heads": 32,
            "intermediate_size": 14336,
            "is_composition": True,
            "max_position_embeddings": 1024000,
            "num_hidden_layers": 40,
            "num_key_value_heads": 8,
            "rms_norm_eps": 1e-05,
            "rope_theta": 1000000000.0,
            "sliding_window": None,
            "vocab_size": 131072,
        }

        config = LlavaConfig(vision_config=vision_config, text_config=text_config)
        self._assert_reload_matches(config)

    def test_arbitrary_reload(self):
        """
        Simple test for reloading arbitrarily composed subconfigs
        """
        # Start from the default serialized config and override only the
        # sub-config model types before rebuilding the config from it.
        default_values = LlavaConfig().to_dict()
        default_values["vision_config"]["model_type"] = "qwen2_vl"
        default_values["text_config"]["model_type"] = "opt"

        self._assert_reload_matches(LlavaConfig(**default_values))