Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Add Nemotron GGUF Loading Support (#34725)

* Add Nemotron GGUF Loading Support
* Fix the Nemotron architecture assignment

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

Commit c57eafdaa1 (parent d4e1acbb7c)
@@ -87,6 +87,7 @@ For now the supported model architectures are the architectures that have been v
 - Starcoder2
 - T5
 - Mamba
+- Nemotron
 
 ## Example usage
 
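The documentation hunk above only adds Nemotron to the supported-architectures list; the "## Example usage" section it sits next to is not part of this diff. As a point of reference, loading the same checkpoints the integration tests below use would look roughly like this (repo id and GGUF filename are taken from those tests):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "bartowski/Nemotron-Mini-4B-Instruct-GGUF"
gguf_file = "Nemotron-Mini-4B-Instruct-Q6_K.gguf"

# Passing gguf_file= makes transformers dequantize the GGUF checkpoint on the fly
# and rebuild the tokenizer from the metadata stored in the same file.
tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file)
model = AutoModelForCausalLM.from_pretrained(model_id, gguf_file=gguf_file, torch_dtype=torch.float16)

input_ids = tokenizer("Hello", return_tensors="pt")["input_ids"]
out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.decode(out[0], skip_special_tokens=True))

Because the checkpoint is dequantized into a regular torch model, the fp16 conversion test further down can compare its state dict directly against nvidia/Nemotron-Mini-4B-Instruct.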
@@ -248,6 +248,20 @@ GGUF_TENSOR_MAPPING = {
         "output_norm": "backbone.norm_f",
         "output.weight": "lm_head.weight",
     },
+    "nemotron": {
+        "token_embd": "model.embed_tokens",
+        "blk": "model.layers",
+        "ffn_up": "mlp.up_proj",
+        "ffn_down": "mlp.down_proj",
+        "ffn_norm": "post_attention_layernorm",
+        "attn_norm": "input_layernorm",
+        "attn_q": "self_attn.q_proj",
+        "attn_v": "self_attn.v_proj",
+        "attn_k": "self_attn.k_proj",
+        "attn_output": "self_attn.o_proj",
+        "output.weight": "lm_head.weight",
+        "output_norm": "model.norm",
+    },
 }
 
 
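For readers unfamiliar with GGUF_TENSOR_MAPPING: the entries above are name-fragment substitutions applied when GGUF tensor names are rewritten to transformers parameter names. A rough, hypothetical sketch of the idea (the helper below is illustrative only, not the actual conversion code):

# Hypothetical helper: rename a GGUF tensor by substituting the fragments from
# the "nemotron" entry of GGUF_TENSOR_MAPPING with their transformers counterparts.
NEMOTRON_TENSOR_MAP = {
    "token_embd": "model.embed_tokens",
    "blk": "model.layers",
    "ffn_up": "mlp.up_proj",
    "ffn_down": "mlp.down_proj",
    "ffn_norm": "post_attention_layernorm",
    "attn_norm": "input_layernorm",
    "attn_q": "self_attn.q_proj",
    "attn_v": "self_attn.v_proj",
    "attn_k": "self_attn.k_proj",
    "attn_output": "self_attn.o_proj",
    "output.weight": "lm_head.weight",
    "output_norm": "model.norm",
}


def rename_gguf_tensor(name: str) -> str:
    # Naive sequential substring replacement, enough to show the intent.
    for gguf_fragment, hf_fragment in NEMOTRON_TENSOR_MAP.items():
        name = name.replace(gguf_fragment, hf_fragment)
    return name


# "blk.0.attn_q.weight" -> "model.layers.0.self_attn.q_proj.weight"
# "output_norm.weight"  -> "model.norm.weight"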
@@ -397,6 +411,18 @@ GGUF_CONFIG_MAPPING = {
         "ssm.time_step_rank": "time_step_rank",
         "ssm.inner_size": "intermediate_size",
     },
+    "nemotron": {
+        "context_length": "max_position_embeddings",
+        "block_count": "num_hidden_layers",
+        "feed_forward_length": "intermediate_size",
+        "embedding_length": "hidden_size",
+        "rope.dimension_count": None,
+        "rope.freq_base": "rope_theta",
+        "attention.head_count": "num_attention_heads",
+        "attention.head_count_kv": "num_key_value_heads",
+        "attention.layer_norm_rms_epsilon": "norm_eps",
+        "vocab_size": "vocab_size",
+    },
 }
 
 GGUF_TOKENIZER_MAPPING = {
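Similarly, GGUF_CONFIG_MAPPING translates GGUF metadata keys into transformers config attributes; an entry mapped to None (here rope.dimension_count) is, under the assumption made in this sketch, read but not forwarded to the config. A minimal hypothetical illustration (helper name is invented):

# Hypothetical helper: turn GGUF metadata into config kwargs via the
# "nemotron" entry of GGUF_CONFIG_MAPPING.
NEMOTRON_CONFIG_MAP = {
    "context_length": "max_position_embeddings",
    "block_count": "num_hidden_layers",
    "feed_forward_length": "intermediate_size",
    "embedding_length": "hidden_size",
    "rope.dimension_count": None,  # assumed to be read but not forwarded
    "rope.freq_base": "rope_theta",
    "attention.head_count": "num_attention_heads",
    "attention.head_count_kv": "num_key_value_heads",
    "attention.layer_norm_rms_epsilon": "norm_eps",
    "vocab_size": "vocab_size",
}


def gguf_metadata_to_config_kwargs(metadata: dict) -> dict:
    kwargs = {}
    for gguf_key, hf_key in NEMOTRON_CONFIG_MAP.items():
        if hf_key is not None and gguf_key in metadata:
            kwargs[hf_key] = metadata[gguf_key]
    return kwargs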
@@ -793,6 +819,7 @@ GGUF_TO_FAST_CONVERTERS = {
     "starcoder2": GGUFGPTConverter,
     "t5": GGUFT5Converter,
     "mamba": GGUFGPTConverter,
+    "nemotron": GGUFGPTConverter,
 }
 
 
@@ -61,6 +61,8 @@ class GgufIntegrationTests(unittest.TestCase):
     starcoder2_original_model_id = "bigcode/starcoder2-3b"
     mamba_original_model_id = "state-spaces/mamba-2.8b-hf"
     mamba_model_id = "jpodivin/mamba-2.8b-hf-GGUF"
+    nemotron_original_model_id = "nvidia/Nemotron-Mini-4B-Instruct"
+    nemotron_model_id = "bartowski/Nemotron-Mini-4B-Instruct-GGUF"
 
     # standard quants
     q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
@@ -106,6 +108,8 @@ class GgufIntegrationTests(unittest.TestCase):
     fp16_starcoder2_gguf_model_id = "starcoder2-3b.fp16.gguf"
     q6_k_mamba_model_id = "ggml-model-Q6_K.gguf"
     fp16_mamba_model_id = "ggml-model-f16.gguf"
+    q6_k_nemotron_model_id = "Nemotron-Mini-4B-Instruct-Q6_K.gguf"
+    fp16_nemotron_model_id = "Nemotron-Mini-4B-Instruct-f16.gguf"
 
     example_text = "Hello"
 
@@ -792,6 +796,42 @@ class GgufIntegrationTests(unittest.TestCase):
         EXPECTED_TEXT = "Hello,I answerthe question.\n\nA"
         self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
 
+    def test_nemotron_weights_conversion_fp16(self):
+        original_model = AutoModelForCausalLM.from_pretrained(
+            self.nemotron_original_model_id,
+            torch_dtype=torch.float16,
+        )
+
+        converted_model = AutoModelForCausalLM.from_pretrained(
+            self.nemotron_model_id,
+            gguf_file=self.fp16_nemotron_model_id,
+            torch_dtype=torch.float16,
+        )
+
+        converted_state_dict = converted_model.state_dict()
+        original_state_dict = original_model.state_dict()
+
+        for layer_name, original_params in original_state_dict.items():
+            if layer_name in converted_state_dict:
+                self.assertTrue(original_params.shape == converted_state_dict[layer_name].shape)
+                torch.testing.assert_close(original_params, converted_state_dict[layer_name])
+            else:
+                raise ValueError(f"Layer {layer_name} is not presented in GGUF model")
+
+    def test_nemotron_q6_k(self):
+        model = AutoModelForCausalLM.from_pretrained(
+            self.nemotron_model_id,
+            gguf_file=self.q6_k_nemotron_model_id,
+            torch_dtype=torch.float16,
+        )
+
+        tokenizer = AutoTokenizer.from_pretrained(self.nemotron_model_id, gguf_file=self.q6_k_nemotron_model_id)
+        text = tokenizer(self.example_text, return_tensors="pt")["input_ids"]
+        out = model.generate(text, max_new_tokens=10)
+
+        EXPECTED_TEXT = "'Hello. hotmail.com.'"
+        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
+
     def test_tokenization_xnli(self):
         import tqdm
         from datasets import load_dataset