Add-helium (#35669)

* Add the helium model.

* Add a missing helium.

* And add another missing helium.

* Use float for the rmsnorm mul.

* Add the Helium tokenizer converter.

* Add the pad token as suggested by Arthur.

* Update the RMSNorm + some other tweaks.

* Fix more rebase issues.

* fix copies and style

* fixes and add helium.md

* add missing tests

* update the backlink

* oops

* style

* update init, and expected results

* small fixes

* match test outputs

* style fixup, fix doc builder

* add dummies and we should be good to go!

* update sdpa and fa2 documentation

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
Arthur 2025-01-13 18:41:15 +01:00 committed by GitHub
parent a3f82328ed
commit c23a1c1932
17 changed files with 1826 additions and 0 deletions

docs/source/en/_toctree.yml

@@ -452,6 +452,8 @@
title: Granite
- local: model_doc/granitemoe
title: GraniteMoe
- local: model_doc/helium
title: Helium
- local: model_doc/herbert
title: HerBERT
- local: model_doc/ibert

docs/source/en/index.md

@@ -173,6 +173,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Graphormer](model_doc/graphormer) | ✅ | ❌ | ❌ |
| [Grounding DINO](model_doc/grounding-dino) | ✅ | ❌ | ❌ |
| [GroupViT](model_doc/groupvit) | ✅ | ✅ | ❌ |
| [Helium](model_doc/helium) | ✅ | ❌ | ❌ |
| [HerBERT](model_doc/herbert) | ✅ | ✅ | ✅ |
| [Hiera](model_doc/hiera) | ✅ | ❌ | ❌ |
| [Hubert](model_doc/hubert) | ✅ | ✅ | ❌ |

docs/source/en/model_doc/helium.md

@@ -0,0 +1,158 @@
<!--Copyright 2024 Kyutai and The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Helium
## Overview
Helium was proposed in [Announcing Helium-1 Preview](https://kyutai.org/2025/01/13/helium.html) by the Kyutai Team.
Helium-1 Preview is a lightweight language model with 2B parameters, targeting edge and mobile devices.
It supports the following languages: English, French, German, Italian, Portuguese, Spanish.
- **Developed by:** Kyutai
- **Model type:** Large Language Model
- **Language(s) (NLP):** English, French, German, Italian, Portuguese, Spanish
- **License:** CC-BY 4.0
## Evaluation
<!-- This section describes the evaluation protocols and provides the results. -->
### Testing Data
<!-- This should link to a Dataset Card if possible. -->
The model was evaluated on MMLU, TriviaQA, NaturalQuestions, ARC Easy & Challenge, Open Book QA, Common Sense QA,
Physical Interaction QA, Social Interaction QA, HellaSwag, WinoGrande, Multilingual Knowledge QA, FLORES 200.
### Metrics
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
We report accuracy on MMLU, ARC, OBQA, CSQA, PIQA, SIQA, HellaSwag, WinoGrande.
We report exact match on TriviaQA, NQ and MKQA.
We report BLEU on FLORES.
### English Results
| Benchmark | Helium-1 Preview | HF SmolLM2 (1.7B) | Gemma-2 (2.6B) | Llama-3.2 (3B) | Qwen2.5 (1.5B) |
|--------------|--------|--------|--------|--------|--------|
| | | | | | |
| MMLU | 51.2 | 50.4 | 53.1 | 56.6 | 61.0 |
| NQ | 17.3 | 15.1 | 17.7 | 22.0 | 13.1 |
| TQA | 47.9 | 45.4 | 49.9 | 53.6 | 35.9 |
| ARC E | 80.9 | 81.8 | 81.1 | 84.6 | 89.7 |
| ARC C | 62.7 | 64.7 | 66.0 | 69.0 | 77.2 |
| OBQA | 63.8 | 61.4 | 64.6 | 68.4 | 73.8 |
| CSQA | 65.6 | 59.0 | 64.4 | 65.4 | 72.4 |
| PIQA | 77.4 | 77.7 | 79.8 | 78.9 | 76.0 |
| SIQA | 64.4 | 57.5 | 61.9 | 63.8 | 68.7 |
| HS | 69.7 | 73.2 | 74.7 | 76.9 | 67.5 |
| WG | 66.5 | 65.6 | 71.2 | 72.0 | 64.8 |
| | | | | | |
| Average | 60.7 | 59.3 | 62.2 | 64.7 | 63.6 |
### Multilingual Results
| Language | Benchmark | Helium-1 Preview | HF SmolLM2 (1.7B) | Gemma-2 (2.6B) | Llama-3.2 (3B) | Qwen2.5 (1.5B) |
|-----|--------------|--------|--------|--------|--------|--------|
| | | | | | | |
|German| MMLU | 45.6 | 35.3 | 45.0 | 47.5 | 49.5 |
|| ARC C | 56.7 | 38.4 | 54.7 | 58.3 | 60.2 |
|| HS | 53.5 | 33.9 | 53.4 | 53.7 | 42.8 |
|| MKQA | 16.1 | 7.1 | 18.9 | 20.2 | 10.4 |
| | | | | | | |
|Spanish| MMLU | 46.5 | 38.9 | 46.2 | 49.6 | 52.8 |
|| ARC C | 58.3 | 43.2 | 58.8 | 60.0 | 68.1 |
|| HS | 58.6 | 40.8 | 60.5 | 61.1 | 51.4 |
|| MKQA | 16.0 | 7.9 | 18.5 | 20.6 | 10.6 |
## Technical Specifications
### Model Architecture and Objective
| Hyperparameter | Value |
|--------------|--------|
| Layers | 24 |
| Heads | 20 |
| Model dimension | 2560 |
| MLP dimension | 7040 |
| Context size | 4096 |
| RoPE theta | 100,000 |
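These hyperparameters mirror the `HeliumConfig` defaults, so the table can be sanity-checked against the config class directly (a minimal sketch, assuming a `transformers` build that ships Helium):

```python
>>> from transformers import HeliumConfig

>>> config = HeliumConfig()  # defaults correspond to Helium-1 Preview (2B)
>>> config.num_hidden_layers, config.num_attention_heads, config.hidden_size
(24, 20, 2560)
>>> config.intermediate_size, config.max_position_embeddings, config.rope_theta
(7040, 4096, 100000.0)
```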
Tips:
- This model was contributed by [Laurent Mazare](https://huggingface.co/lmz).
## Usage tips
`Helium` can be found on the [Hugging Face Hub](https://huggingface.co/collections/kyutai/helium-1-preview)
In the following, we demonstrate how to use `helium-1-preview` for inference.
```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto
>>> model = AutoModelForCausalLM.from_pretrained("kyutai/helium-1-preview", device_map="auto")
>>> tokenizer = AutoTokenizer.from_pretrained("kyutai/helium-1-preview")
>>> prompt = "Give me a short introduction to large language models."
>>> messages = [{"role": "user", "content": prompt}]
>>> text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
>>> model_inputs = tokenizer([text], return_tensors="pt").to(device)
>>> generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=512, do_sample=True)
>>> generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)]
>>> response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
>>> print(response)
```
## HeliumConfig
[[autodoc]] HeliumConfig
## HeliumModel
[[autodoc]] HeliumModel
- forward
## HeliumForCausalLM
[[autodoc]] HeliumForCausalLM
- forward
## HeliumForSequenceClassification
[[autodoc]] HeliumForSequenceClassification
- forward
## HeliumForTokenClassification
[[autodoc]] HeliumForTokenClassification
- forward

docs/source/en/perf_infer_gpu_one.md

@@ -109,6 +109,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip)
* [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)
* [Helium](https://huggingface.co/docs/transformers/main/en/model_doc/helium#transformers.HeliumModel)
You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.
@@ -324,6 +325,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaModel)
* [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/model_doc/xlm-roberta-xl#transformers.XLMRobertaXLModel)
* [YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos#transformers.YolosModel)
* [Helium](https://huggingface.co/docs/transformers/main/en/model_doc/helium#transformers.HeliumModel)
<Tip>

src/transformers/__init__.py

@@ -498,6 +498,7 @@ _import_structure = {
"GroupViTTextConfig",
"GroupViTVisionConfig",
],
"models.helium": ["HeliumConfig"],
"models.herbert": ["HerbertTokenizer"],
"models.hiera": ["HieraConfig"],
"models.hubert": ["HubertConfig"],
@@ -2506,6 +2507,15 @@ else:
"GroupViTVisionModel",
]
)
_import_structure["models.helium"].extend(
[
"HeliumForCausalLM",
"HeliumForSequenceClassification",
"HeliumForTokenClassification",
"HeliumModel",
"HeliumPreTrainedModel",
]
)
_import_structure["models.hiera"].extend(
[
"HieraBackbone",
@@ -5529,6 +5539,7 @@ if TYPE_CHECKING:
GroupViTTextConfig,
GroupViTVisionConfig,
)
from .models.helium import HeliumConfig
from .models.herbert import HerbertTokenizer
from .models.hiera import HieraConfig
from .models.hubert import HubertConfig
@@ -7371,6 +7382,13 @@ if TYPE_CHECKING:
GroupViTTextModel,
GroupViTVisionModel,
)
from .models.helium import (
HeliumForCausalLM,
HeliumForSequenceClassification,
HeliumForTokenClassification,
HeliumModel,
HeliumPreTrainedModel,
)
from .models.hiera import (
HieraBackbone,
HieraForImageClassification,

src/transformers/convert_slow_tokenizer.py

@@ -1446,6 +1446,95 @@ class MoshiConverter(SpmConverter):
return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False)
class HeliumConverter(SpmConverter):
handle_byte_fallback = True
def __init__(self, vocab_file=None, *args):
requires_backends(self, "protobuf")
Converter.__init__(self, vocab_file)
model_pb2 = import_protobuf()
m = model_pb2.ModelProto()
with open(vocab_file, "rb") as f:
m.ParseFromString(f.read())
self.proto = m
def tokenizer(self, proto):
vocab_scores = self.vocab(proto)
tokenizer = Tokenizer(
Unigram(
vocab_scores,
unk_id=self.unk_id(proto),
byte_fallback=self.handle_byte_fallback,
)
)
# control tokens are special
# user defined symbols are not
# both user and control tokens are AddedTokens
# Add user defined symbols (type == 4) from sentencepiece (https://github.com/google/sentencepiece/blob/6225e08edb2577757163b3f5dbba4c0b670ef445/src/sentencepiece_model.proto#L299C29-L299C33)
spm_added_tokens = [
(id, p.piece, p.type == 3 or p.piece in self.special_tokens)
for id, p in enumerate(proto.pieces)
if p.type in [3, 4]
]
tokenizer.add_tokens(
[
AddedToken(token, normalized=False, special=special, single_word=True)
for id, token, special in sorted(spm_added_tokens, key=lambda x: x[0])
]
)
tokenizer.add_tokens([AddedToken("\n", normalized=False, special=False)])
tokenizer.enable_padding(pad_token="<pad>", pad_id=3)
return tokenizer
def vocab(self, proto):
vocab = []
for piece in proto.pieces:
if piece.piece == "<0x0A>":
vocab += [("\n", piece.score)]
else:
vocab += [(piece.piece, piece.score)]
return vocab
def unk_id(self, proto):
unk_id = 0
return unk_id
def decoder(self, replacement, add_prefix_space):
sequence = [
decoders.Replace("▁", " "),
decoders.ByteFallback(),
decoders.Fuse(),
]
sequence += [decoders.Strip(content=" ", left=1)]
return decoders.Sequence(sequence)
def normalizer(self, proto):
return normalizers.Sequence([normalizers.Prepend(" "), normalizers.Replace(r" ", "▁")])
def pre_tokenizer(self, replacement, add_prefix_space):
return pre_tokenizers.Sequence([pre_tokenizers.Split("\n", "contiguous")])
def post_processor(self):
return processors.TemplateProcessing(
single=[
"<s>",
"$A",
],
pair=[
"<s>",
"$A",
"<s>",
"$B",
],
special_tokens=[
("<s>", 1),
],
)
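Taken together, `HeliumConverter` turns a SentencePiece proto into a fast Unigram tokenizer with byte fallback, `<pad>` padding at id 3, and a `<s>`-prefixed template. A hedged usage sketch (the `tokenizer.model` path is hypothetical):

```python
from transformers.convert_slow_tokenizer import HeliumConverter

# Convert a local SentencePiece file into a `tokenizers` fast tokenizer.
fast_tokenizer = HeliumConverter("tokenizer.model").converted()
# Persist it for reuse, e.g. via PreTrainedTokenizerFast(tokenizer_file="tokenizer.json").
fast_tokenizer.save("tokenizer.json")
```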
# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
def bytes_to_unicode():
"""

src/transformers/models/__init__.py

@@ -117,6 +117,7 @@ from . import (
granitemoe,
grounding_dino,
groupvit,
helium,
herbert,
hiera,
hubert,

src/transformers/models/auto/configuration_auto.py

@@ -137,6 +137,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("graphormer", "GraphormerConfig"),
("grounding-dino", "GroundingDinoConfig"),
("groupvit", "GroupViTConfig"),
("helium", "HeliumConfig"),
("hiera", "HieraConfig"),
("hubert", "HubertConfig"),
("ibert", "IBertConfig"),
@@ -458,6 +459,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("graphormer", "Graphormer"),
("grounding-dino", "Grounding DINO"),
("groupvit", "GroupViT"),
("helium", "Helium"),
("herbert", "HerBERT"),
("hiera", "Hiera"),
("hubert", "Hubert"),

src/transformers/models/auto/modeling_auto.py

@@ -132,6 +132,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("graphormer", "GraphormerModel"),
("grounding-dino", "GroundingDinoModel"),
("groupvit", "GroupViTModel"),
("helium", "HeliumModel"),
("hiera", "HieraModel"),
("hubert", "HubertModel"),
("ibert", "IBertModel"),
@@ -517,6 +518,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("gptj", "GPTJForCausalLM"),
("granite", "GraniteForCausalLM"),
("granitemoe", "GraniteMoeForCausalLM"),
("helium", "HeliumForCausalLM"),
("jamba", "JambaForCausalLM"),
("jetmoe", "JetMoeForCausalLM"),
("llama", "LlamaForCausalLM"),
@@ -989,6 +991,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("gpt_neo", "GPTNeoForSequenceClassification"),
("gpt_neox", "GPTNeoXForSequenceClassification"),
("gptj", "GPTJForSequenceClassification"),
("helium", "HeliumForSequenceClassification"),
("ibert", "IBertForSequenceClassification"),
("jamba", "JambaForSequenceClassification"),
("jetmoe", "JetMoeForSequenceClassification"),
@@ -1182,6 +1185,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("gpt_bigcode", "GPTBigCodeForTokenClassification"),
("gpt_neo", "GPTNeoForTokenClassification"),
("gpt_neox", "GPTNeoXForTokenClassification"),
("helium", "HeliumForTokenClassification"),
("ibert", "IBertForTokenClassification"),
("layoutlm", "LayoutLMForTokenClassification"),
("layoutlmv2", "LayoutLMv2ForTokenClassification"),

src/transformers/models/auto/tokenization_auto.py

@@ -226,6 +226,7 @@ else:
("gptsan-japanese", ("GPTSanJapaneseTokenizer", None)),
("grounding-dino", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("groupvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
("helium", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("herbert", ("HerbertTokenizer", "HerbertTokenizerFast" if is_tokenizers_available() else None)),
("hubert", ("Wav2Vec2CTCTokenizer", None)),
("ibert", ("RobertaTokenizer", "RobertaTokenizerFast" if is_tokenizers_available() else None)),

src/transformers/models/helium/__init__.py

@@ -0,0 +1,27 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_helium import *
from .modeling_helium import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
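With this pattern, `transformers.models.helium` materializes its submodules only on first attribute access; a quick sketch of the observable behavior (assuming Helium is present in the installed version):

```python
import transformers.models.helium as helium

# The module object is a _LazyModule; real code loads on first attribute access.
print(type(helium).__name__)           # _LazyModule
print(helium.HeliumConfig.model_type)  # "helium" -- this access triggers the import
```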

src/transformers/models/helium/configuration_helium.py

@@ -0,0 +1,140 @@
# coding=utf-8
# Copyright 2024 The Kyutai and HuggingFace Inc. teams. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...configuration_utils import PretrainedConfig
class HeliumConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`HeliumModel`]. It is used to instantiate a Helium
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Helium 2b model.
e.g. [kyutai/helium-2b](https://huggingface.co/kyutai/helium-2b)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 48000):
Vocabulary size of the Helium model. Defines the number of different tokens that can be represented by the
`input_ids` passed when calling [`HeliumModel`]
hidden_size (`int`, *optional*, defaults to 2560):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 7040):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 24):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 20):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*, defaults to 20):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
head_dim (`int`, *optional*, defaults to 128):
The attention head dimension.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) used in the decoder MLP.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
max_position_embeddings (`int`, *optional*, defaults to 4096):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-08):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie the input and output word embeddings.
rope_theta (`float`, *optional*, defaults to 100000.0):
The base period of the RoPE embeddings.
pad_token_id (`int`, *optional*, defaults to 3):
Padding token id.
eos_token_id (`int` | `list`, *optional*, defaults to 2):
End of stream token id.
bos_token_id (`int`, *optional*, defaults to 1):
Beginning of stream token id.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
mlp_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
```python
>>> from transformers import HeliumModel, HeliumConfig
>>> # Initializing a Helium 2b style configuration
>>> configuration = HeliumConfig()
>>> # Initializing a model from the Helium 2b style configuration
>>> model = HeliumModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "helium"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=48000,
hidden_size=2560,
intermediate_size=7040,
num_hidden_layers=24,
num_attention_heads=20,
num_key_value_heads=20,
head_dim=128,
hidden_act="silu",
attention_dropout=0.0,
max_position_embeddings=4096,
initializer_range=0.02,
rms_norm_eps=1e-8,
use_cache=True,
tie_word_embeddings=False,
rope_theta=100000.0,
pad_token_id=3,
eos_token_id=2,
bos_token_id=1,
attention_bias=False,
mlp_bias=False,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = head_dim
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
__all__ = ["HeliumConfig"]
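As the `num_key_value_heads` docstring explains, GQA shares each key/value head across a group of query heads; Helium's default (20 == 20) is plain multi-head attention, but the sharing mechanism itself can be sketched in a few lines (an illustration with a hypothetical 5-KV-head variant, not the library's exact code):

```python
import torch

def repeat_kv(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand (batch, num_kv_heads, seq, head_dim) so each KV head serves n_rep query heads.
    batch, num_kv_heads, seq_len, head_dim = hidden.shape
    hidden = hidden[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

keys = torch.randn(1, 5, 8, 128)  # 5 KV heads
print(repeat_kv(keys, 4).shape)   # torch.Size([1, 20, 8, 128]) -> matches 20 query heads
```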

src/transformers/models/helium/modeling_helium.py (diff suppressed because it is too large)

src/transformers/models/helium/modular_helium.py

@@ -0,0 +1,171 @@
# coding=utf-8
# Copyright 2024 The Kyutai and HuggingFace Inc. teams. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.utils.checkpoint
from ...utils import logging
from ..gemma.modeling_gemma import (
GemmaForCausalLM,
GemmaForSequenceClassification,
GemmaForTokenClassification,
)
from ..granite.modeling_granite import (
GraniteAttention,
)
from ..llama.modeling_llama import (
LlamaDecoderLayer,
LlamaMLP,
LlamaModel,
LlamaPreTrainedModel,
LlamaRotaryEmbedding,
)
from .configuration_helium import HeliumConfig
logger = logging.get_logger(__name__)
class HeliumRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return (self.weight.to(torch.float32) * hidden_states).to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
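Per the commit bullet "Use float for the rmsnorm mul", both the variance computation and the weight multiply run in float32 before casting back, which avoids precision loss in bf16. A small sketch of the observable behavior (assuming a build that ships Helium):

```python
import torch
from transformers.models.helium.modeling_helium import HeliumRMSNorm

norm = HeliumRMSNorm(hidden_size=8).to(torch.bfloat16)
x = torch.randn(2, 4, 8, dtype=torch.bfloat16)
out = norm(x)
print(out.dtype)                    # torch.bfloat16 -- cast back only after the fp32 multiply
print(out.float().pow(2).mean(-1))  # ~1.0 per position, since the weight starts as all ones
```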
class HeliumRotaryEmbedding(LlamaRotaryEmbedding):
pass
class HeliumMLP(LlamaMLP):
pass
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., 0::2]
x2 = x[..., 1::2]
return torch.stack((-x2, x1), dim=-1).flatten(-2)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
# Interleave the rotary dimensions instead of the usual half-split layout
cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1)
sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
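Unlike Llama's half-split rotation, Helium interleaves even and odd channels. A toy check of the interleaved `rotate_half` on a six-dimensional vector:

```python
import torch

x = torch.arange(6.0)      # [0, 1, 2, 3, 4, 5]
x1, x2 = x[0::2], x[1::2]  # even channels, odd channels
print(torch.stack((-x2, x1), dim=-1).flatten(-2))  # tensor([-1., 0., -3., 2., -5., 4.])
```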
class HeliumAttention(GraniteAttention):
def __init__(self, config: HeliumConfig, layer_idx: Optional[int] = None):
super().__init__(config, layer_idx)
self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
self.scaling = 1 / math.sqrt(self.head_dim)
class HeliumDecoderLayer(LlamaDecoderLayer):
def __init__(self, config: HeliumConfig, layer_idx: Optional[int] = None):
super().__init__()
self.mlp = HeliumMLP(config)
self.input_layernorm = HeliumRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = HeliumRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
class HeliumPreTrainedModel(LlamaPreTrainedModel):
pass
class HeliumModel(HeliumPreTrainedModel, LlamaModel):
def __init__(self, config: HeliumConfig):
super().__init__(config)
self.layers = nn.ModuleList(
[HeliumDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = HeliumRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = HeliumRotaryEmbedding(config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
class HeliumForCausalLM(GemmaForCausalLM):
def __init__(self, config: HeliumConfig):
super().__init__(config)
self.model = HeliumModel(config)
self.post_init()
class HeliumForSequenceClassification(GemmaForSequenceClassification):
def __init__(self, config: HeliumConfig):
super().__init__(config)
self.model = HeliumModel(config)
self.post_init()
class HeliumForTokenClassification(GemmaForTokenClassification):
def __init__(self, config: HeliumConfig):
super().__init__(config)
self.model = HeliumModel(config)
self.post_init()
__all__ = [
"HeliumPreTrainedModel",
"HeliumModel",
"HeliumForCausalLM",
"HeliumForSequenceClassification",
"HeliumForTokenClassification",
]
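Because the modular file only overrides what differs from the Llama/Gemma/Granite parents, the generated model can be smoke-tested end to end with a deliberately tiny config (a sketch, assuming a `transformers` build containing Helium):

```python
import torch
from transformers import HeliumConfig, HeliumForCausalLM

config = HeliumConfig(
    vocab_size=128, hidden_size=64, intermediate_size=128,
    num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4, head_dim=16,
)
model = HeliumForCausalLM(config)
logits = model(input_ids=torch.randint(0, 128, (1, 8))).logits
print(logits.shape)  # torch.Size([1, 8, 128])
```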

src/transformers/utils/dummy_pt_objects.py

@@ -4981,6 +4981,41 @@ class GroupViTVisionModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
class HeliumForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class HeliumForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class HeliumForTokenClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class HeliumModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class HeliumPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class HieraBackbone(metaclass=DummyObject):
_backends = ["torch"]

tests/models/helium/__init__.py (new, empty file)

tests/models/helium/test_modeling_helium.py

@@ -0,0 +1,110 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Helium model."""
import unittest
from transformers import AutoModelForCausalLM, AutoTokenizer, HeliumConfig, is_torch_available
from transformers.testing_utils import (
require_read_token,
require_torch,
slow,
torch_device,
)
from ...test_configuration_common import ConfigTester
from ..gemma.test_modeling_gemma import GemmaModelTest, GemmaModelTester
if is_torch_available():
import torch
from transformers import (
HeliumForCausalLM,
HeliumForSequenceClassification,
HeliumForTokenClassification,
HeliumModel,
)
class HeliumModelTester(GemmaModelTester):
if is_torch_available():
config_class = HeliumConfig
model_class = HeliumModel
for_causal_lm_class = HeliumForCausalLM
for_sequence_class = HeliumForSequenceClassification
for_token_class = HeliumForTokenClassification
@require_torch
class HeliumModelTest(GemmaModelTest, unittest.TestCase):
all_model_classes = (
(HeliumModel, HeliumForCausalLM, HeliumForSequenceClassification, HeliumForTokenClassification)
if is_torch_available()
else ()
)
all_generative_model_classes = (HeliumForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": HeliumModel,
"text-classification": HeliumForSequenceClassification,
"token-classification": HeliumForTokenClassification,
"text-generation": HeliumForCausalLM,
"zero-shot": HeliumForSequenceClassification,
}
if is_torch_available()
else {}
)
test_headmasking = False
test_pruning = False
_is_stateful = True
model_split_percents = [0.5, 0.6]
def setUp(self):
self.model_tester = HeliumModelTester(self)
self.config_tester = ConfigTester(self, config_class=HeliumConfig, hidden_size=37)
@slow
# @require_torch_gpu
class HeliumIntegrationTest(unittest.TestCase):
input_text = ["Hello, today is a great day to"]
# This variable is used to determine which CUDA device we are using for our runners (A10 or T4)
# Depending on the hardware we get different logits / generations
cuda_compute_capability_major_version = None
@classmethod
def setUpClass(cls):
if is_torch_available() and torch.cuda.is_available():
# 8 is for A100 / A10 and 7 for T4
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
@require_read_token
def test_model_2b(self):
model_id = "kyutai/helium-1-preview"
EXPECTED_TEXTS = [
"Hello, today is a great day to start a new project. I have been working on a new project for a while now and I have"
]
model = AutoModelForCausalLM.from_pretrained(
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, revision="refs/pr/1"
).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(model_id, revision="refs/pr/1")
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXTS)