Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Granite language models (#31502)
* first commit
* drop tokenizer
* drop tokenizer
* drop tokenizer
* drop convert
* granite
* drop tokenization test
* mup
* fix
* reformat
* reformat
* reformat
* fix docs
* stop checking for checkpoint
* update support
* attention multiplier
* update model
* tiny drop
* saibo drop
* skip test
* fix test
* fix test
* drop
* drop useless imports
* update docs
* drop flash function
* copied from
* drop pretraining tp
* drop pretraining tp
* drop pretraining tp
* drop unused import
* drop code path
* change name
* softmax scale
* head dim
* drop legacy cache
* rename params
* cleanup
* fix copies
* comments
* add back legacy cache
* multipliers
* multipliers
* multipliers
* text fix
* fix copies
* merge
* multipliers
* attention multiplier
* drop unused imports
* fix
* fix
* fix
* move rope?
* Update src/transformers/models/granite/configuration_granite.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* fix
* Update src/transformers/models/granite/modeling_granite.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* fix
* fix
* fix
* fix
* fix-copies
* torch rmsnorm
* add authors
* change model path
* fix
* test
* drop static cache test
* update readme
* drop non-causal
* readme
* drop useless imports
* Update docs/source/en/model_doc/granite.md
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update docs/source/en/model_doc/granite.md
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* Update docs/source/en/model_doc/granite.md
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent 7591ca5bc5
commit c35d2ccf5a
@@ -412,6 +412,8 @@
       title: GPTSAN Japanese
     - local: model_doc/gpt-sw3
       title: GPTSw3
+    - local: model_doc/granite
+      title: Granite
     - local: model_doc/herbert
       title: HerBERT
     - local: model_doc/ibert
@@ -158,6 +158,7 @@ Flax), PyTorch, and/or TensorFlow.
 | [GPT-Sw3](model_doc/gpt-sw3) | ✅ | ✅ | ✅ |
 | [GPTBigCode](model_doc/gpt_bigcode) | ✅ | ❌ | ❌ |
 | [GPTSAN-japanese](model_doc/gptsan-japanese) | ✅ | ❌ | ❌ |
+| [Granite](model_doc/granite) | ✅ | ❌ | ❌ |
 | [Graphormer](model_doc/graphormer) | ✅ | ❌ | ❌ |
 | [Grounding DINO](model_doc/grounding-dino) | ✅ | ❌ | ❌ |
 | [GroupViT](model_doc/groupvit) | ✅ | ✅ | ❌ |
docs/source/en/model_doc/granite.md (new file, 74 lines)
@@ -0,0 +1,74 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Granite

## Overview

The Granite model was proposed in [Power Scheduler: A Batch Size and Token Number Agnostic Learning Rate Scheduler](https://arxiv.org/abs/2408.13359) by Yikang Shen, Matthew Stallone, Mayank Mishra, Gaoyuan Zhang, Shawn Tan, Aditya Prasad, Adriana Meza Soria, David D. Cox and Rameswar Panda.

PowerLM-3B is a 3B-parameter state-of-the-art small language model trained with the Power learning rate scheduler. It is trained on a wide range of open-source and synthetic datasets with permissive licenses. PowerLM-3B has shown promising results compared to other models in its size category across various benchmarks, including natural language multiple-choice, code generation, and math reasoning.

The abstract from the paper is the following:

*Finding the optimal learning rate for language model pretraining is a challenging task.
This is not only because there is a complicated correlation between learning rate, batch size, number of training tokens, model size, and other hyperparameters but also because it is prohibitively expensive to perform a hyperparameter search for large language models with Billions or Trillions of parameters. Recent studies propose using small proxy models and small corpus to perform hyperparameter searches and transposing the optimal parameters to large models and large corpus. While the zero-shot transferability is theoretically and empirically proven for model size related hyperparameters, like depth and width, the zero-shot transfer from small corpus to large corpus is underexplored.
In this paper, we study the correlation between optimal learning rate, batch size, and number of training tokens for the recently proposed WSD scheduler. After thousands of small experiments, we found a power-law relationship between variables and demonstrated its transferability across model sizes. Based on the observation, we propose a new learning rate scheduler, Power scheduler, that is agnostic about the number of training tokens and batch size. The experiment shows that combining the Power scheduler with Maximum Update Parameterization (µP) can consistently achieve impressive performance with one set of hyperparameters regardless of the number of training tokens, batch size, model size, and even model architecture. Our 3B dense and MoE models trained with the Power scheduler achieve comparable performance as state-of-the-art small language models.
We [open source](https://huggingface.co/collections/ibm/power-lm-66be64ae647ddf11b9808000) these pretrained models.*
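As a rough illustration of the power-law relationship the abstract describes (and only an illustration: the functional form below is generic, and the coefficient `a` and exponent `b` are hypothetical placeholders rather than the paper's fitted values), a token-count-dependent learning rate might look like:

```python
# Hypothetical sketch of a power-law learning-rate rule in the spirit of the Power
# scheduler: the learning rate decays as a power of the number of training tokens.
# `a`, `b`, and `lr_max` are made-up placeholders, not values from the paper.
def power_law_lr(num_tokens: float, a: float = 4.0, b: float = 0.5, lr_max: float = 3e-4) -> float:
    return min(lr_max, a * num_tokens ** (-b))

for n in (1e8, 1e9, 1e10, 1e11):
    print(f"{n:.0e} tokens -> lr = {power_law_lr(n):.2e}")
```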
Tips:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "ibm/PowerLM-3b"
tokenizer = AutoTokenizer.from_pretrained(model_path)

# drop device_map if running on CPU
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
model.eval()

# change input text as desired
prompt = "Write a code to find the maximum value in a list of numbers."

# tokenize the text
input_tokens = tokenizer(prompt, return_tensors="pt")
# generate output tokens
output = model.generate(**input_tokens, max_new_tokens=100)
# decode output tokens into text
output = tokenizer.batch_decode(output)
# loop over the batch to print; in this example the batch size is 1
for i in output:
    print(i)
```

This model was contributed by [mayank-mishra](https://huggingface.co/mayank-mishra).
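Because the checkpoint is registered with the `text-generation` auto mappings (see the auto-mapping hunks later in this diff), it can also be driven through the high-level `pipeline` API. A minimal sketch, assuming the `ibm/PowerLM-3b` checkpoint and a GPU are available:

```python
from transformers import pipeline

# build a text-generation pipeline on top of the causal LM; drop device_map on CPU
generator = pipeline("text-generation", model="ibm/PowerLM-3b", device_map="auto")
print(generator("Write a function that reverses a string.", max_new_tokens=50)[0]["generated_text"])
```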
## GraniteConfig

[[autodoc]] GraniteConfig

## GraniteModel

[[autodoc]] GraniteModel
    - forward

## GraniteForCausalLM

[[autodoc]] GraniteForCausalLM
    - forward
@@ -51,6 +51,7 @@ FlashAttention-2 is currently supported for the following architectures:
 * [GPTNeo](https://huggingface.co/docs/transformers/model_doc/gpt_neo#transformers.GPTNeoModel)
 * [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel)
 * [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj#transformers.GPTJModel)
+* [Granite](https://huggingface.co/docs/transformers/model_doc/granite#transformers.GraniteModel)
 * [Idefics2](https://huggingface.co/docs/transformers/model_doc/idefics2#transformers.Idefics2Model)
 * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
 * [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel)

@@ -215,6 +216,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
 * [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)
 * [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
 * [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel)
+* [Granite](https://huggingface.co/docs/transformers/model_doc/granite#transformers.GraniteModel)
 * [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel)
 * [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
 * [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
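The two hunks above advertise FlashAttention-2 and SDPA support for `GraniteModel`. A minimal sketch of opting into either attention backend at load time (assuming a CUDA GPU, and the `flash-attn` package installed for the first variant):

```python
import torch
from transformers import AutoModelForCausalLM

# FlashAttention-2 backend (requires flash-attn and a supported GPU)
model_fa2 = AutoModelForCausalLM.from_pretrained(
    "ibm/PowerLM-3b",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="flash_attention_2",
)

# PyTorch SDPA backend (no extra dependency)
model_sdpa = AutoModelForCausalLM.from_pretrained(
    "ibm/PowerLM-3b",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa",
)
```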
@@ -461,6 +461,7 @@ _import_structure = {
     "models.gpt_neox_japanese": ["GPTNeoXJapaneseConfig"],
     "models.gpt_sw3": [],
     "models.gptj": ["GPTJConfig"],
+    "models.granite": ["GraniteConfig"],
     "models.grounding_dino": [
         "GroundingDinoConfig",
         "GroundingDinoProcessor",

@@ -2325,6 +2326,13 @@ else:
             "GPTJPreTrainedModel",
         ]
     )
+    _import_structure["models.granite"].extend(
+        [
+            "GraniteForCausalLM",
+            "GraniteModel",
+            "GranitePreTrainedModel",
+        ]
+    )
     _import_structure["models.grounding_dino"].extend(
         [
             "GroundingDinoForObjectDetection",

@@ -5216,6 +5224,7 @@ if TYPE_CHECKING:
         GPTNeoXJapaneseConfig,
     )
     from .models.gptj import GPTJConfig
+    from .models.granite import GraniteConfig
     from .models.grounding_dino import (
         GroundingDinoConfig,
         GroundingDinoProcessor,

@@ -6938,6 +6947,11 @@ if TYPE_CHECKING:
         GPTJModel,
         GPTJPreTrainedModel,
     )
+    from .models.granite import (
+        GraniteForCausalLM,
+        GraniteModel,
+        GranitePreTrainedModel,
+    )
     from .models.grounding_dino import (
         GroundingDinoForObjectDetection,
         GroundingDinoModel,
@@ -105,6 +105,7 @@ from . import (
     gpt_neox_japanese,
     gpt_sw3,
     gptj,
+    granite,
     grounding_dino,
     groupvit,
     herbert,

@@ -122,6 +122,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
         ("gpt_neox_japanese", "GPTNeoXJapaneseConfig"),
         ("gptj", "GPTJConfig"),
         ("gptsan-japanese", "GPTSanJapaneseConfig"),
+        ("granite", "GraniteConfig"),
         ("graphormer", "GraphormerConfig"),
         ("grounding-dino", "GroundingDinoConfig"),
         ("groupvit", "GroupViTConfig"),

@@ -411,6 +412,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
         ("gpt_neox_japanese", "GPT NeoX Japanese"),
         ("gptj", "GPT-J"),
         ("gptsan-japanese", "GPTSAN-japanese"),
+        ("granite", "Granite"),
         ("graphormer", "Graphormer"),
         ("grounding-dino", "Grounding DINO"),
         ("groupvit", "GroupViT"),

@@ -119,6 +119,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
         ("gpt_neox_japanese", "GPTNeoXJapaneseModel"),
         ("gptj", "GPTJModel"),
         ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
+        ("granite", "GraniteModel"),
         ("graphormer", "GraphormerModel"),
         ("grounding-dino", "GroundingDinoModel"),
         ("groupvit", "GroupViTModel"),

@@ -479,6 +480,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
         ("gpt_neox", "GPTNeoXForCausalLM"),
         ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
         ("gptj", "GPTJForCausalLM"),
+        ("granite", "GraniteForCausalLM"),
         ("jamba", "JambaForCausalLM"),
         ("jetmoe", "JetMoeForCausalLM"),
         ("llama", "LlamaForCausalLM"),
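With the `granite` key registered in the config and model auto-mappings above, the Auto classes resolve the architecture by name. A small sketch (the tiny hyperparameters below are arbitrary, chosen only to keep the randomly initialized model small):

```python
from transformers import AutoConfig, AutoModelForCausalLM

# build a tiny, randomly initialized Granite model purely from the registered mapping
config = AutoConfig.for_model(
    "granite",
    vocab_size=99,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
)
model = AutoModelForCausalLM.from_config(config)
print(type(config).__name__, type(model).__name__)  # GraniteConfig GraniteForCausalLM
```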
src/transformers/models/granite/__init__.py (new file, 57 lines)
@@ -0,0 +1,57 @@
# Copyright 2024 EleutherAI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_torch_available,
)


_import_structure = {
    "configuration_granite": ["GraniteConfig"],
}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_granite"] = [
        "GraniteForCausalLM",
        "GraniteModel",
        "GranitePreTrainedModel",
    ]

if TYPE_CHECKING:
    from .configuration_granite import GraniteConfig

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_granite import (
            GraniteForCausalLM,
            GraniteModel,
            GranitePreTrainedModel,
        )

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
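The file above wires Granite into the library's lazy-import machinery: symbols listed in `_import_structure` are only imported when first accessed, and the torch-only modeling classes are skipped when torch is missing. A tiny sketch of what that enables, assuming torch is installed:

```python
# Accessing the submodule attribute triggers the real import through _LazyModule.
from transformers.models.granite import GraniteConfig, GraniteForCausalLM

config = GraniteConfig(vocab_size=99, hidden_size=64, intermediate_size=128,
                       num_hidden_layers=2, num_attention_heads=4)
model = GraniteForCausalLM(config)  # randomly initialized; torch backend required
print(sum(p.numel() for p in model.parameters()))
```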
src/transformers/models/granite/configuration_granite.py (new file, 179 lines)
@@ -0,0 +1,179 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Granite model configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class GraniteConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`GraniteModel`]. It is used to instantiate an Granite
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the Granite-3B.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 32000):
|
||||
Vocabulary size of the Granite model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`GraniteModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 11008):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
num_key_value_heads (`int`, *optional*):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
||||
`num_attention_heads`.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
pad_token_id (`int`, *optional*):
|
||||
Padding token id.
|
||||
bos_token_id (`int`, *optional*, defaults to 1):
|
||||
Beginning of stream token id.
|
||||
eos_token_id (`int`, *optional*, defaults to 2):
|
||||
End of stream token id.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether to tie weight embeddings
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
|
||||
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
|
||||
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
||||
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
|
||||
these scaling strategies behave:
|
||||
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
|
||||
experimental feature, subject to breaking API changes in future versions.
|
||||
attention_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
mlp_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
|
||||
embedding_multiplier (`float`, *optional*, defaults to 1.0):
Multiplier applied to the input embeddings.
logits_scaling (`float`, *optional*, defaults to 1.0):
Divisor applied to the output logits.
residual_multiplier (`float`, *optional*, defaults to 1.0):
Multiplier applied to each residual connection.
attention_multiplier (`float`, *optional*, defaults to 1.0):
Scaling factor applied to the attention scores, used as the softmax scale.
|
||||
|
||||
```python
|
||||
>>> from transformers import GraniteModel, GraniteConfig
|
||||
|
||||
>>> # Initializing a Granite granite-3b style configuration
|
||||
>>> configuration = GraniteConfig()
|
||||
|
||||
>>> # Initializing a model from the granite-3b style configuration
|
||||
>>> model = GraniteModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "granite"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=32000,
|
||||
hidden_size=4096,
|
||||
intermediate_size=11008,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=None,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=2048,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=None,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000.0,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
mlp_bias=False,
|
||||
embedding_multiplier=1.0,
|
||||
logits_scaling=1.0,
|
||||
residual_multiplier=1.0,
|
||||
attention_multiplier=1.0,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.mlp_bias = mlp_bias
|
||||
|
||||
self.embedding_multiplier = embedding_multiplier
|
||||
self.logits_scaling = logits_scaling
|
||||
self.residual_multiplier = residual_multiplier
|
||||
self.attention_multiplier = attention_multiplier
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
rope_config_validation(self)
|
src/transformers/models/granite/modeling_granite.py (new file, 1182 lines)
File diff suppressed because it is too large.
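The modeling file itself is too large to render here. As a hedged, self-contained sketch only (inferred from the configuration's `embedding_multiplier`, `attention_multiplier`, `residual_multiplier`, and `logits_scaling` fields, not from the suppressed diff), the four µP-style multipliers enter a decoder forward pass roughly like this; the multiplier values and the single-head, mask-free attention are arbitrary simplifications for the demo:

```python
import torch
import torch.nn.functional as F
from transformers import GraniteConfig

cfg = GraniteConfig(vocab_size=99, hidden_size=64, intermediate_size=128,
                    num_hidden_layers=2, num_attention_heads=4,
                    embedding_multiplier=12.0, attention_multiplier=0.015625,
                    residual_multiplier=0.22, logits_scaling=8.0)  # arbitrary demo values

embed = torch.nn.Embedding(cfg.vocab_size, cfg.hidden_size)
lm_head = torch.nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)
q_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=False)
k_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=False)
v_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=False)

input_ids = torch.randint(0, cfg.vocab_size, (1, 8))
hidden = embed(input_ids) * cfg.embedding_multiplier                 # 1) scale input embeddings
q, k, v = q_proj(hidden), k_proj(hidden), v_proj(hidden)             # single head, causal mask omitted
scores = q @ k.transpose(-1, -2) * cfg.attention_multiplier          # 2) softmax scale instead of 1/sqrt(d)
hidden = hidden + (F.softmax(scores, dim=-1) @ v) * cfg.residual_multiplier  # 3) scaled residual add
logits = lm_head(hidden) / cfg.logits_scaling                        # 4) logits are divided, not multiplied
print(logits.shape)  # torch.Size([1, 8, 99])
```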
@@ -24,7 +24,7 @@ from torch import nn
 from .utils import is_torch_xla_available, logging


-ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
+ALL_LAYERNORM_LAYERS = [nn.LayerNorm, nn.RMSNorm]

 logger = logging.get_logger(__name__)
@@ -4649,6 +4649,27 @@ class GPTJPreTrainedModel(metaclass=DummyObject):
         requires_backends(self, ["torch"])


+class GraniteForCausalLM(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class GraniteModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class GranitePreTrainedModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class GroundingDinoForObjectDetection(metaclass=DummyObject):
     _backends = ["torch"]
tests/models/granite/__init__.py (new empty file)

tests/models/granite/test_modeling_granite.py (new file, 615 lines)
@@ -0,0 +1,615 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch Granite model."""
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoTokenizer, GraniteConfig, is_torch_available, set_seed
|
||||
from transformers.testing_utils import (
|
||||
require_bitsandbytes,
|
||||
require_flash_attn,
|
||||
require_read_token,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torch_sdpa,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
GraniteForCausalLM,
|
||||
GraniteModel,
|
||||
)
|
||||
from transformers.models.granite.modeling_granite import (
|
||||
GraniteRotaryEmbedding,
|
||||
)
|
||||
|
||||
|
||||
class GraniteModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_token_type_ids=False,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=512,
|
||||
type_vocab_size=16,
|
||||
type_sequence_label_size=2,
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
pad_token_id=0,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_token_type_ids = use_token_type_ids
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.type_sequence_label_size = type_sequence_label_size
|
||||
self.initializer_range = initializer_range
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.pad_token_id = pad_token_id
|
||||
self.scope = scope
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
input_mask = None
|
||||
if self.use_input_mask:
|
||||
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)
|
||||
|
||||
token_type_ids = None
|
||||
if self.use_token_type_ids:
|
||||
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
||||
|
||||
sequence_labels = None
|
||||
token_labels = None
|
||||
choice_labels = None
|
||||
if self.use_labels:
|
||||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||
|
||||
config = self.get_config()
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
||||
def get_config(self):
|
||||
return GraniteConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
pad_token_id=self.pad_token_id,
|
||||
)
|
||||
|
||||
def create_and_check_model(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = GraniteModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask)
|
||||
result = model(input_ids)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
def create_and_check_model_as_decoder(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
):
|
||||
config.add_cross_attention = True
|
||||
model = GraniteModel(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
result = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
result = model(input_ids, attention_mask=input_mask)
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
|
||||
def create_and_check_for_causal_lm(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
):
|
||||
model = GraniteForCausalLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_decoder_model_past_large_inputs(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
):
|
||||
config.is_decoder = True
|
||||
config.add_cross_attention = True
|
||||
model = GraniteForCausalLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# first forward pass
|
||||
outputs = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=True,
|
||||
)
|
||||
past_key_values = outputs.past_key_values
|
||||
|
||||
# create hypothetical multiple next tokens and extend to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
|
||||
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
|
||||
|
||||
# append to next input_ids and attention mask
|
||||
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
|
||||
|
||||
output_from_no_past = model(
|
||||
next_input_ids,
|
||||
attention_mask=next_attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_hidden_states=True,
|
||||
)["hidden_states"][0]
|
||||
output_from_past = model(
|
||||
next_tokens,
|
||||
attention_mask=next_attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
output_hidden_states=True,
|
||||
)["hidden_states"][0]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
|
||||
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
|
||||
|
||||
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
||||
|
||||
# test that outputs are equal for slice
|
||||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class GraniteModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(
|
||||
GraniteModel,
|
||||
GraniteForCausalLM,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (GraniteForCausalLM,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": GraniteModel,
|
||||
"text-generation": GraniteForCausalLM,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
test_headmasking = False
|
||||
test_pruning = False
|
||||
fx_compatible = False
|
||||
|
||||
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
|
||||
# This is because we are hitting edge cases with the causal_mask buffer
|
||||
model_split_percents = [0.5, 0.7, 0.8]
|
||||
|
||||
# used in `test_torch_compile`
|
||||
_torch_compile_test_ckpt = "ibm/PowerLM-3b"
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = GraniteModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=GraniteConfig, hidden_size=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_model_various_embeddings(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
for type in ["absolute", "relative_key", "relative_key_query"]:
|
||||
config_and_inputs[0].position_embedding_type = type
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
# def test_granite_sequence_classification_model(self):
|
||||
# config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
# config.num_labels = 3
|
||||
# input_ids = input_dict["input_ids"]
|
||||
# attention_mask = input_ids.ne(1).to(torch_device)
|
||||
# sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
|
||||
# model = GraniteForSequenceClassification(config)
|
||||
# model.to(torch_device)
|
||||
# model.eval()
|
||||
# result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
# self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
# def test_granite_sequence_classification_model_for_single_label(self):
|
||||
# config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
# config.num_labels = 3
|
||||
# config.problem_type = "single_label_classification"
|
||||
# input_ids = input_dict["input_ids"]
|
||||
# attention_mask = input_ids.ne(1).to(torch_device)
|
||||
# sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
|
||||
# model = GraniteForSequenceClassification(config)
|
||||
# model.to(torch_device)
|
||||
# model.eval()
|
||||
# result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
# self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
# def test_granite_sequence_classification_model_for_multi_label(self):
|
||||
# config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
# config.num_labels = 3
|
||||
# config.problem_type = "multi_label_classification"
|
||||
# input_ids = input_dict["input_ids"]
|
||||
# attention_mask = input_ids.ne(1).to(torch_device)
|
||||
# sequence_labels = ids_tensor(
|
||||
# [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
|
||||
# ).to(torch.float)
|
||||
# model = GraniteForSequenceClassification(config)
|
||||
# model.to(torch_device)
|
||||
# model.eval()
|
||||
# result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
# self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
# def test_granite_token_classification_model(self):
|
||||
# config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
# config.num_labels = 3
|
||||
# input_ids = input_dict["input_ids"]
|
||||
# attention_mask = input_ids.ne(1).to(torch_device)
|
||||
# token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
|
||||
# model = GraniteForTokenClassification(config=config)
|
||||
# model.to(torch_device)
|
||||
# model.eval()
|
||||
# result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
|
||||
# self.assertEqual(
|
||||
# result.logits.shape,
|
||||
# (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
|
||||
# )
|
||||
|
||||
@unittest.skip("Granite buffers include complex numbers, which breaks this test")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
@parameterized.expand([("linear",), ("dynamic",)])
|
||||
def test_model_rope_scaling_from_config(self, scaling_type):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
short_input = ids_tensor([1, 10], config.vocab_size)
|
||||
long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
|
||||
|
||||
set_seed(42) # Fixed seed at init time so the two models get the same random weights
|
||||
original_model = GraniteModel(config)
|
||||
original_model.to(torch_device)
|
||||
original_model.eval()
|
||||
original_short_output = original_model(short_input).last_hidden_state
|
||||
original_long_output = original_model(long_input).last_hidden_state
|
||||
|
||||
set_seed(42) # Fixed seed at init time so the two models get the same random weights
|
||||
config.rope_scaling = {"type": scaling_type, "factor": 10.0}
|
||||
scaled_model = GraniteModel(config)
|
||||
scaled_model.to(torch_device)
|
||||
scaled_model.eval()
|
||||
scaled_short_output = scaled_model(short_input).last_hidden_state
|
||||
scaled_long_output = scaled_model(long_input).last_hidden_state
|
||||
|
||||
# Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
|
||||
# maximum sequence length, so the outputs for the short input should match.
|
||||
if scaling_type == "dynamic":
|
||||
self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
|
||||
else:
|
||||
self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
|
||||
|
||||
# The output should be different for long inputs
|
||||
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
|
||||
|
||||
def test_model_rope_scaling(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
scaling_factor = 10
|
||||
short_input_length = 10
|
||||
long_input_length = int(config.max_position_embeddings * 1.5)
|
||||
|
||||
# Inputs
|
||||
x = torch.randn(1, dtype=torch.float32, device=torch_device)  # used exclusively to get the dtype and the device
|
||||
position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
|
||||
position_ids_short = position_ids_short.unsqueeze(0)
|
||||
position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
|
||||
position_ids_long = position_ids_long.unsqueeze(0)
|
||||
|
||||
# Sanity check original RoPE
|
||||
original_rope = GraniteRotaryEmbedding(config=config).to(torch_device)
|
||||
original_cos_short, original_sin_short = original_rope(x, position_ids_short)
|
||||
original_cos_long, original_sin_long = original_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
|
||||
torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])
|
||||
|
||||
# Sanity check linear RoPE scaling
|
||||
# New position "x" should match original position with index "x/scaling_factor"
|
||||
config.rope_scaling = {"type": "linear", "factor": scaling_factor}
|
||||
linear_scaling_rope = GraniteRotaryEmbedding(config=config).to(torch_device)
|
||||
linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
|
||||
linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
|
||||
torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :])
|
||||
for new_position in range(0, long_input_length, scaling_factor):
|
||||
original_position = int(new_position // scaling_factor)
|
||||
torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :])
|
||||
torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :])
|
||||
|
||||
# Sanity check Dynamic NTK RoPE scaling
|
||||
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
|
||||
# with scaling_factor (or that `inv_freq` decreases)
|
||||
config.rope_scaling = {"type": "dynamic", "factor": scaling_factor}
|
||||
ntk_scaling_rope = GraniteRotaryEmbedding(config=config).to(torch_device)
|
||||
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
|
||||
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(ntk_cos_short, original_cos_short)
|
||||
torch.testing.assert_close(ntk_sin_short, original_sin_short)
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_close(ntk_cos_long, original_cos_long)
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_close(ntk_sin_long, original_sin_long)
|
||||
self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all())
|
||||
|
||||
# Sanity check Yarn RoPE scaling
|
||||
# Scaling should be over the entire input
|
||||
config.rope_scaling = {"type": "yarn", "factor": scaling_factor}
|
||||
yarn_scaling_rope = GraniteRotaryEmbedding(config=config).to(torch_device)
|
||||
yarn_cos_short, yarn_sin_short = yarn_scaling_rope(x, position_ids_short)
|
||||
yarn_cos_long, yarn_sin_long = yarn_scaling_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(yarn_cos_short, yarn_cos_long[:, :short_input_length, :])
|
||||
torch.testing.assert_close(yarn_sin_short, yarn_sin_long[:, :short_input_length, :])
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_close(yarn_cos_short, original_cos_short)
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_close(yarn_sin_short, original_sin_short)
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_close(yarn_cos_long, original_cos_long)
|
||||
with self.assertRaises(AssertionError):
|
||||
torch.testing.assert_close(yarn_sin_long, original_sin_long)
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@require_bitsandbytes
|
||||
@pytest.mark.flash_attn_test
|
||||
@require_read_token
|
||||
@slow
|
||||
def test_flash_attn_2_generate_padding_right(self):
|
||||
"""
|
||||
Overwriting the common test, as the test is flaky on tiny models
|
||||
"""
|
||||
model = GraniteForCausalLM.from_pretrained(
|
||||
"ibm/PowerLM-3b",
|
||||
load_in_4bit=True,
|
||||
device_map={"": 0},
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("ibm/PowerLM-3b")
|
||||
|
||||
texts = ["hi", "Hello this is a very long sentence"]
|
||||
|
||||
tokenizer.padding_side = "right"
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
|
||||
|
||||
output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
output_native = tokenizer.batch_decode(output_native)
|
||||
|
||||
model = GraniteForCausalLM.from_pretrained(
|
||||
"ibm/PowerLM-3b",
|
||||
load_in_4bit=True,
|
||||
device_map={"": 0},
|
||||
attn_implementation="flash_attention_2",
|
||||
)
|
||||
|
||||
output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
output_fa_2 = tokenizer.batch_decode(output_fa_2)
|
||||
|
||||
self.assertListEqual(output_native, output_fa_2)
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
def test_use_flash_attention_2_true(self):
|
||||
"""
|
||||
NOTE: this is the only test testing that the legacy `use_flash_attention_2=True` argument still works as intended.
|
||||
"""
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model = model_class(config)
|
||||
model.save_pretrained(tmp_dir)
|
||||
|
||||
new_model = GraniteForCausalLM.from_pretrained(
|
||||
tmp_dir, use_flash_attention_2=True, torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
self.assertTrue(new_model.config._attn_implementation == "flash_attention_2")
|
||||
|
||||
has_flash = False
|
||||
for name, submodule in new_model.named_modules():
|
||||
if "FlashAttention" in submodule.__class__.__name__:
|
||||
has_flash = True
|
||||
break
|
||||
if not has_flash:
|
||||
raise ValueError("The flash model should have flash attention layers")
|
||||
|
||||
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
|
||||
@require_torch_sdpa
|
||||
@slow
|
||||
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
|
||||
"""
|
||||
skipping the test since mup is very flaky and gets consistently different outputs
|
||||
"""
|
||||
self.skipTest("skipping the test since mup is very flaky and gets consistently different outputs")
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
class GraniteIntegrationTest(unittest.TestCase):
|
||||
# This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
|
||||
# Depending on the hardware we get different logits / generations
|
||||
cuda_compute_capability_major_version = None
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if is_torch_available() and torch.cuda.is_available():
|
||||
# 8 is for A100 / A10 and 7 for T4
|
||||
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
|
||||
|
||||
@slow
|
||||
@require_read_token
|
||||
def test_model_3b_logits_bf16(self):
|
||||
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
|
||||
|
||||
model = GraniteForCausalLM.from_pretrained(
|
||||
"ibm/PowerLM-3b", device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
out = model(torch.tensor([input_ids]).to(torch_device))
|
||||
# Expected mean on dim = -1
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_MEAN = torch.tensor([[-1.8799, -3.1269, -2.8297, -2.3755, -2.7364, -2.2389, -2.5914, -2.4154]])
|
||||
|
||||
self.assertTrue(torch.allclose(EXPECTED_MEAN.to(torch_device), out.logits.mean(-1), atol=1e-2, rtol=1e-2))
|
||||
|
||||
# slicing logits[0, 0, 0:15]
|
||||
EXPECTED_SLICE = torch.tensor([[4.8125, -2.0156, -2.0156, -2.0000, -2.0000, -2.8438, -2.0156, -2.0000, -2.0000, -2.0000, -2.0000, -2.0000, -2.0000, -2.0000, -2.0000]])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
EXPECTED_SLICE.to(torch_device),
|
||||
out.logits[0, 0, :15],
|
||||
atol=1e-3,
|
||||
rtol=1e-3,
|
||||
)
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_read_token
|
||||
def test_model_3b_logits(self):
|
||||
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
|
||||
|
||||
model = GraniteForCausalLM.from_pretrained("ibm/PowerLM-3b", device_map="auto", torch_dtype=torch.float16)
|
||||
|
||||
with torch.no_grad():
|
||||
out = model(torch.tensor([input_ids]).to(torch_device))
|
||||
|
||||
# fmt: off
|
||||
# Expected mean on dim = -1
|
||||
EXPECTED_MEAN = torch.tensor([[0.0000, 0.0000, -3.4374, -2.1636, -2.6245, -3.0029, -3.8229, -3.1158]])
|
||||
|
||||
self.assertTrue(torch.allclose(EXPECTED_MEAN.to(torch_device), out.logits.mean(-1), atol=1e-2, rtol=1e-2))
|
@@ -44,6 +44,7 @@ CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
     "VisionEncoderDecoderConfig",
     "VisionTextDualEncoderConfig",
     "LlamaConfig",
+    "GraniteConfig",
 }