Add GGUF support to Gemma3 Text backbone (#37424)

* add gemma3 gguf support

Signed-off-by: Isotr0py <2037008807@qq.com>

* fix typo and add gguf limit

Signed-off-by: Isotr0py <2037008807@qq.com>

* fix a typo

Signed-off-by: Isotr0py <2037008807@qq.com>

* add vision conversion test

Signed-off-by: Isotr0py <2037008807@qq.com>

* fix typos

Signed-off-by: Isotr0py <2037008807@qq.com>

---------

Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
Isotr0py 2025-04-10 23:15:43 +08:00 committed by GitHub
parent 0ea1151222
commit 6daec12d0b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 96 additions and 0 deletions

View File

@ -204,6 +204,23 @@ GGUF_CONFIG_MAPPING = {
"attention.head_count": "num_attention_heads",
"attention.head_count_kv": "num_key_value_heads",
"attention.layer_norm_rms_epsilon": "rms_norm_eps",
"attention.sliding_window": "sliding_window",
"vocab_size": "vocab_size",
},
"gemma3": {
"context_length": "max_position_embeddings",
"block_count": "num_hidden_layers",
"feed_forward_length": "intermediate_size",
"embedding_length": "hidden_size",
"rope.dimension_count": None,
"rope.freq_base": "rope_theta",
# NOTE: Gemma3 has key_length==value_length==head_dim
# See: https://github.com/ggml-org/llama.cpp/blob/fe5b78c89670b2f37ecb216306bed3e677b49d9f/convert_hf_to_gguf.py#L3495-L3496
"attention.key_length": "head_dim",
"attention.head_count": "num_attention_heads",
"attention.head_count_kv": "num_key_value_heads",
"attention.layer_norm_rms_epsilon": "rms_norm_eps",
"attention.sliding_window": "sliding_window",
"vocab_size": "vocab_size",
},
}
@ -669,6 +686,7 @@ GGUF_TO_FAST_CONVERTERS = {
"mamba": GGUFGPTConverter,
"nemotron": GGUFGPTConverter,
"gemma2": GGUFGemmaConverter,
"gemma3_text": GGUFGemmaConverter,
}

View File

@ -253,6 +253,7 @@ TENSOR_PROCESSORS = {
"mamba": MambaTensorProcessor,
"nemotron": NemotronTensorProcessor,
"gemma2": Gemma2TensorProcessor,
"gemma3": Gemma2TensorProcessor,
}
@ -292,6 +293,8 @@ def get_gguf_hf_weights_map(
model_type = "command-r"
elif model_type == "qwen2_moe":
model_type = "qwen2moe"
elif model_type == "gemma3_text":
model_type = "gemma3"
arch = None
for key, value in MODEL_ARCH_NAMES.items():
if value == model_type:
@ -438,6 +441,10 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False, model_to_lo
if gguf_key in reader_keys:
logger.info(f"Some keys were not parsed and added into account {gguf_key} | {value}")
# Gemma3 GGUF checkpoint only contains weights of text backbone
if parsed_parameters["config"]["model_type"] == "gemma3":
parsed_parameters["config"]["model_type"] = "gemma3_text"
# retrieve config vocab_size from tokenizer
# Please refer to https://github.com/huggingface/transformers/issues/32526 for more details
if "vocab_size" not in parsed_parameters["config"]:

View File

@ -296,6 +296,10 @@ class GgufModelTests(unittest.TestCase):
nemotron_model_id = "bartowski/Nemotron-Mini-4B-Instruct-GGUF"
original_gemma2_model_id = "google/gemma-2-2b-it"
gemma2_model_id = "bartowski/gemma-2-2b-it-GGUF"
original_gemma3_text_model_id = "google/gemma-3-1b-it"
original_gemma3_vision_model_id = "google/gemma-3-4b-it"
gemma3_text_model_id = "unsloth/gemma-3-1b-it-GGUF"
gemma3_vision_model_id = "unsloth/gemma-3-4b-it-GGUF"
q4_0_phi3_model_id = "Phi-3-mini-4k-instruct-q4.gguf"
q4_0_mistral_model_id = "mistral-7b-instruct-v0.2.Q4_0.gguf"
@ -325,6 +329,9 @@ class GgufModelTests(unittest.TestCase):
q3_k_gemma2_model_id = "gemma-2-2b-it-Q3_K_L.gguf"
q8_0_gemma2_model_id = "gemma-2-2b-it-Q8_0.gguf"
fp32_gemma2_model_id = "gemma-2-2b-it-f32.gguf"
q2_k_gemma3_text_model_id = "gemma-3-1b-it-Q2_K.gguf"
bf16_gemma3_text_model_id = "gemma-3-1b-it-BF16.gguf"
bf16_gemma3_vision_model_id = "gemma-3-4b-it-BF16.gguf"
example_text = "Hello"
@ -881,3 +888,67 @@ class GgufModelTests(unittest.TestCase):
torch.testing.assert_close(original_params, converted_state_dict[layer_name])
else:
raise ValueError(f"Layer {layer_name} is not presented in GGUF model")
@unittest.skipUnless(is_gguf_available("0.16.0"), "test requires gguf version >= 0.16.0")
def test_gemma3_text_q2_k(self):
model = AutoModelForCausalLM.from_pretrained(
self.gemma3_text_model_id,
gguf_file=self.q2_k_gemma3_text_model_id,
torch_dtype=torch.float16,
)
tokenizer = AutoTokenizer.from_pretrained(self.gemma3_text_model_id, gguf_file=self.q2_k_gemma3_text_model_id)
text = tokenizer(self.example_text, return_tensors="pt")["input_ids"]
out = model.generate(text, max_new_tokens=10)
EXPECTED_TEXT = "Hello,\n\nI'm looking for a small,"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
@require_read_token
@unittest.skipUnless(is_gguf_available("0.16.0"), "test requires gguf version >= 0.16.0")
def test_gemma3_text_weights_conversion_bf16(self):
original_model = AutoModelForCausalLM.from_pretrained(
self.original_gemma3_text_model_id,
torch_dtype=torch.float16,
)
converted_model = AutoModelForCausalLM.from_pretrained(
self.gemma3_text_model_id,
gguf_file=self.bf16_gemma3_text_model_id,
torch_dtype=torch.float16,
)
converted_state_dict = converted_model.state_dict()
original_state_dict = original_model.state_dict()
for layer_name, original_params in original_state_dict.items():
if layer_name in converted_state_dict:
self.assertTrue(original_params.shape == converted_state_dict[layer_name].shape)
torch.testing.assert_close(original_params, converted_state_dict[layer_name])
else:
raise ValueError(f"Layer {layer_name} is not presented in GGUF model")
# Test text backbone conversion for gemma3 vision models
@require_read_token
@unittest.skipUnless(is_gguf_available("0.16.0"), "test requires gguf version >= 0.16.0")
def test_gemma3_vision_weights_conversion_bf16(self):
original_model = AutoModelForCausalLM.from_pretrained(
self.original_gemma3_vision_model_id,
torch_dtype=torch.float16,
).language_model
converted_model = AutoModelForCausalLM.from_pretrained(
self.gemma3_vision_model_id,
gguf_file=self.bf16_gemma3_vision_model_id,
torch_dtype=torch.float16,
)
converted_state_dict = converted_model.state_dict()
original_state_dict = original_model.state_dict()
for layer_name, original_params in original_state_dict.items():
if layer_name in converted_state_dict:
self.assertTrue(original_params.shape == converted_state_dict[layer_name].shape)
torch.testing.assert_close(original_params, converted_state_dict[layer_name])
else:
raise ValueError(f"Layer {layer_name} is not presented in GGUF model")