mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
tests for deci
This commit is contained in:
parent
c649f6ed31
commit
bb8e9d726b
@ -333,6 +333,10 @@ class GgufModelTests(unittest.TestCase):
|
||||
# GGUF file names for Gemma 3 checkpoints (the repo id and the tests that use them are outside this view).
q4_0_gemma3_qat_model_id = "gemma-3-1b-it-q4_0.gguf"
|
||||
bf16_gemma3_text_model_id = "gemma-3-1b-it-BF16.gguf"
|
||||
bf16_gemma3_vision_model_id = "gemma-3-4b-it-BF16.gguf"
|
||||
# Identifier of the original (non-GGUF) Deci checkpoint used as the reference in the weight-conversion test.
deci_original_model_id = "Deci/DeciLM-7B"
|
||||
# Identifier passed to `from_pretrained` together with a `gguf_file` name in the Deci tests.
deci_model_id = "Deci/DeciLM-7B-instruct-GGUF"
|
||||
# Q8_0 GGUF file loaded by `test_deci_q8_0`.
q8_0_deci_model_id = "decilm-7b-uniform-gqa-q8_0.gguf"
|
||||
# FP16 GGUF file loaded by `test_deci_weights_conversion_fp16`.
fp16_deci_model_id = "decilm-7b-uniform-gqa-f16.gguf"
|
||||
|
||||
# Prompt that the generation tests tokenize and feed to `generate`.
example_text = "Hello"
|
||||
|
||||
@ -955,3 +959,91 @@ class GgufModelTests(unittest.TestCase):
|
||||
torch.testing.assert_close(original_params, converted_state_dict[layer_name])
|
||||
else:
|
||||
raise ValueError(f"Layer {layer_name} is not presented in GGUF model")
|
||||
|
||||
def test_deci_q8_0(self):
    """Test Deci model loading and inference with Q8_0 quantization.

    Loads the tokenizer and model from the Q8_0 GGUF file, generates up to
    10 new tokens from ``self.example_text`` and compares the decoded
    output with a fixed expected string.
    """
    tokenizer = AutoTokenizer.from_pretrained(self.deci_model_id, gguf_file=self.q8_0_deci_model_id)
    model = AutoModelForCausalLM.from_pretrained(
        self.deci_model_id,
        gguf_file=self.q8_0_deci_model_id,
        device_map="auto",
        torch_dtype=torch.float16,
    )

    # Tokenize the prompt and move the tensors to the test device before generating.
    inputs = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
    out = model.generate(**inputs, max_new_tokens=10)

    EXPECTED_TEXT = "Hello, I am a language model developed"
    self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
|
||||
|
||||
def test_deci_weights_conversion_fp16(self):
    """Test that GGUF Deci model weights match the original model weights.

    Loads the original checkpoint and the FP16 GGUF file, both as float16,
    then checks that every parameter of the original state dict exists in
    the converted one with the same shape and close values.

    The test is skipped when either model cannot be loaded.

    Raises:
        ValueError: if a layer of the original model is missing from the
            converted GGUF model.
    """
    # NOTE(review): the broad ``except Exception`` turns *any* loading error
    # into a skip, which can hide real regressions; narrow it once the set of
    # expected "not available" errors is known.
    # ``skipTest`` raises ``unittest.SkipTest``, so nothing after it in the
    # handler would run and no explicit ``return`` is needed.
    try:
        original_model = AutoModelForCausalLM.from_pretrained(
            # Use the shared class-level id instead of a duplicated literal.
            self.deci_original_model_id,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            device_map="auto",
        )
    except Exception as e:
        self.skipTest(f"Original Deci model not available for comparison: {e}")

    # You need to have an FP16 version of your GGUF model for accurate comparison
    try:
        converted_model = AutoModelForCausalLM.from_pretrained(
            self.deci_model_id,
            gguf_file=self.fp16_deci_model_id,
            torch_dtype=torch.float16,
            device_map="auto",
        )
    except Exception as e:
        self.skipTest(f"GGUF FP16 model not available: {e}")

    converted_state_dict = converted_model.state_dict()
    original_state_dict = original_model.state_dict()

    for layer_name, original_params in original_state_dict.items():
        if layer_name in converted_state_dict:
            converted_params = converted_state_dict[layer_name]
            # assertEqual reports both shapes on failure, unlike assertTrue(a == b).
            self.assertEqual(
                original_params.shape,
                converted_params.shape,
                f"Shape mismatch for layer {layer_name}",
            )
            torch.testing.assert_close(original_params, converted_params)
        else:
            raise ValueError(f"Layer {layer_name} is not presented in GGUF model")
|
||||
|
||||
def test_deci_config_mapping(self):
    """Verify the GGUF-to-transformers config key mapping registered for "deci".

    Checks that a "deci" entry exists in ``GGUF_CONFIG_MAPPING``, that each
    expected GGUF key maps to the expected transformers config name, and
    that "rope.dimension_count" is mapped to ``None``.
    """
    from transformers.integrations.ggml import GGUF_CONFIG_MAPPING

    self.assertIn("deci", GGUF_CONFIG_MAPPING)
    mapping = GGUF_CONFIG_MAPPING["deci"]

    # (GGUF key, transformers config attribute) pairs, checked in this order.
    expected_pairs = (
        ("context_length", "max_position_embeddings"),
        ("block_count", "num_hidden_layers"),
        ("feed_forward_length", "intermediate_size"),
        ("embedding_length", "hidden_size"),
        ("rope.freq_base", "rope_theta"),
        ("attention.head_count", "num_attention_heads"),
        ("attention.head_count_kv", "num_key_value_heads"),
        ("attention.layer_norm_rms_epsilon", "rms_norm_eps"),
        ("vocab_size", "vocab_size"),
    )
    for source_key, target_key in expected_pairs:
        self.assertEqual(mapping[source_key], target_key)

    self.assertIsNone(mapping["rope.dimension_count"])
|
||||
|
||||
def test_deci_architecture_mapping(self):
    """Verify that "deci" and "decilm" are registered with GGUFLlamaConverter.

    Both names must be present in ``GGUF_TO_FAST_CONVERTERS`` and both must
    resolve to ``GGUFLlamaConverter``.
    """
    from transformers.integrations.ggml import GGUF_TO_FAST_CONVERTERS, GGUFLlamaConverter

    architecture_names = ("deci", "decilm")

    # Presence is checked for every name before any converter is compared.
    for arch in architecture_names:
        self.assertIn(arch, GGUF_TO_FAST_CONVERTERS)

    for arch in architecture_names:
        self.assertEqual(GGUF_TO_FAST_CONVERTERS[arch], GGUFLlamaConverter)
|
||||
|
Loading…
Reference in New Issue
Block a user