mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00

* quick fix * 3 losses * oups * fix * nits * check how it scales for special models * propagate for conditiona detr * propagate * propagate * propagate * fixes * propagate changes * update * fixup * nits * f string * fixes * more fixes * ? * nit * arg annoying f string * nits * grumble * update * nit * refactor * fix fetch tests * nit * nit * Update src/transformers/loss/loss_utils.py Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> * update * nit * fixup * make pass * nits * port code to more models * fixup * ntis * arf * update * update * nits * update * fix * update * nits * fine * agjkfslga.jsdlkgjklas * nits * fix fx? * update * update * styel * fix imports * update * update * fixup to fix the torch fx? --------- Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
1042 lines
46 KiB
Python
1042 lines
46 KiB
Python
# coding=utf-8
|
|
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import math
|
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
|
|
|
import sentencepiece as spm
|
|
import torch
|
|
import torch.utils.checkpoint
|
|
from torch import nn
|
|
|
|
from ...activations import ACT2FN
|
|
from ...cache_utils import Cache, DynamicCache, StaticCache
|
|
from ...configuration_utils import PretrainedConfig
|
|
from ...modeling_flash_attention_utils import _flash_attention_forward
|
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
|
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
|
|
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
|
|
from ...utils import logging
|
|
from ..llama.modeling_llama import (
|
|
LlamaDecoderLayer,
|
|
LlamaFlashAttention2,
|
|
LlamaForCausalLM,
|
|
LlamaForSequenceClassification,
|
|
LlamaForTokenClassification,
|
|
LlamaModel,
|
|
apply_rotary_pos_emb,
|
|
repeat_kv,
|
|
)
|
|
from ..llama.tokenization_llama import LlamaTokenizer
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from ...tokenization_utils_base import TextInput
|
|
|
|
# File name under which the SentencePiece model is stored (not referenced elsewhere in this chunk;
# presumably consumed by the tokenizer loading machinery — confirm against callers).
VOCAB_FILES_NAMES = dict(vocab_file="tokenizer.model")

# SentencePiece word-boundary marker; `GemmaTokenizer._decode` turns it back into a plain space.
SPIECE_UNDERLINE = "▁"

# Module-level logger from the project's logging utilities.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
class GemmaConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate a Gemma
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the Gemma-7B.
    e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        vocab_size (`int`, *optional*, defaults to 256000):
            Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`GemmaModel`]
        hidden_size (`int`, *optional*, defaults to 3072):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 24576):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 28):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 16):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If set to `None`, will default to `num_attention_heads`.
        head_dim (`int`, *optional*, defaults to 256):
            The attention head dimension. It is set independently of `hidden_size // num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The legacy activation function. It is overwritten by the `hidden_activation`.
        hidden_activation (`str` or `function`, *optional*):
            The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
            if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
        max_position_embeddings (`int`, *optional*, defaults to 8192):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id.
        eos_token_id (`int`, *optional*, defaults to 1):
            End of stream token id.
        bos_token_id (`int`, *optional*, defaults to 2):
            Beginning of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
    ```python
    >>> from transformers import GemmaModel, GemmaConfig
    >>> # Initializing a Gemma gemma-7b style configuration
    >>> configuration = GemmaConfig()
    >>> # Initializing a model from the gemma-7b style configuration
    >>> model = GemmaModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "gemma"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=256000,
        hidden_size=3072,
        intermediate_size=24576,
        num_hidden_layers=28,
        num_attention_heads=16,
        num_key_value_heads=16,
        head_dim=256,
        hidden_act="gelu_pytorch_tanh",
        hidden_activation=None,
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=2,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        # Honour the documented fallback: an explicit `None` means plain multi-head attention.
        # Previously `None` was stored as-is and later crashed the `num_heads // num_key_value_heads`
        # computation in the attention layers.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.num_key_value_heads = num_key_value_heads
        # `hidden_act` is kept only for backward compatibility; the MLP reads `hidden_activation`.
        self.hidden_act = hidden_act
        self.hidden_activation = hidden_activation
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
|
|
|
|
|
|
class GemmaTokenizer(LlamaTokenizer, PreTrainedTokenizer):
    """
    Construct a Gemma tokenizer backed by a SentencePiece model (loaded from `vocab_file`). The padding token defaults
    to `"<pad>"`.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<bos>"`):
            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
            token.
        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<eos>"`):
            The end of sequence token.
        pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<pad>"`):
            A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by
            attention mechanisms or loss computation.
        sp_model_kwargs (`Dict[str, Any]`, *optional*):
            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
            to set:

            - `enable_sampling`: Enable subword regularization.
            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.

              - `nbest_size = {0,1}`: No sampling is performed.
              - `nbest_size > 1`: samples from the nbest_size results.
              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
                using forward-filtering-and-backward-sampling algorithm.

            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
              BPE-dropout.

        add_bos_token (`bool`, *optional*, defaults to `True`):
            Whether or not to add a `bos_token` at the start of sequences.
        add_eos_token (`bool`, *optional*, defaults to `False`):
            Whether or not to add an `eos_token` at the end of sequences.
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
            Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like
            extra spaces.
        use_default_system_prompt (`bool`, *optional*, defaults to `False`):
            Whether or not the default system prompt for Gemma should be used.
        spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether or not to add spaces between special tokens.
    """

    def __init__(
        self,
        vocab_file,
        unk_token="<unk>",
        bos_token="<bos>",
        eos_token="<eos>",
        pad_token="<pad>",
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        add_bos_token=True,
        add_eos_token=False,
        clean_up_tokenization_spaces=False,
        use_default_system_prompt=False,
        spaces_between_special_tokens=False,
        **kwargs,
    ):
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        # Plain-string special tokens are wrapped so they are neither normalized nor split.
        bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
        eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
        unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
        pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token

        self.vocab_file = vocab_file
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self.use_default_system_prompt = use_default_system_prompt
        # The SentencePiece model must be loaded before the base initializer runs.
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(vocab_file)

        # Deliberately bypass `LlamaTokenizer.__init__` and initialize the generic base directly.
        PreTrainedTokenizer.__init__(
            self,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            sp_model_kwargs=sp_model_kwargs,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            use_default_system_prompt=use_default_system_prompt,
            spaces_between_special_tokens=spaces_between_special_tokens,
            **kwargs,
        )

    def get_spm_processor(self):
        raise AttributeError("Not needed for Gemma")

    def unk_token_length(self):
        raise AttributeError("Not needed for Gemma")

    def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
        """
        Args:
            text: TextInput
        Simply calls PreTrainedTokenizer's method
        """
        return PreTrainedTokenizer.tokenize(self, text, **kwargs)

    def _tokenize(self, text, **kwargs):
        """
        Args:
            text: TextInput
        Returns a tokenized string. The Gemma tokenizer never adds a prefix space.
        """
        return self.sp_model.encode(text, out_type=str)

    def _decode(
        self,
        token_ids: List[int],
        skip_special_tokens: bool = False,
        spaces_between_special_tokens: bool = False,
        **kwargs,
    ) -> str:
        """
        Decode `token_ids` into a string.

        Runs of ordinary ids are decoded by SentencePiece; added tokens are emitted verbatim. Pieces are joined with a
        space when `spaces_between_special_tokens` is set, and SentencePiece's word-boundary marker is replaced by a
        plain space in the final string.
        """
        # `all_special_ids` is a property on the base tokenizer (defined outside this file); resolve it once
        # per call instead of once per token id.
        skipped_ids = set(self.all_special_ids) if skip_special_tokens else set()

        sub_texts = []
        current_sub_text = []
        for ids in token_ids:
            if ids in skipped_ids:
                continue
            if ids in self._added_tokens_decoder:
                if current_sub_text:
                    sub_texts.append(self.sp_model.decode(current_sub_text))
                sub_texts.append(self._added_tokens_decoder[ids].content)
                current_sub_text = []
            else:
                current_sub_text.append(ids)
        if current_sub_text:
            sub_texts.append(self.sp_model.decode(current_sub_text))

        separator = " " if spaces_between_special_tokens else ""
        return separator.join(sub_texts).replace(SPIECE_UNDERLINE, " ")

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        pieces = []
        current_sub_tokens = []
        for token in tokens:
            # make sure that special tokens are not decoded using sentencepiece model
            if token in self._added_tokens_encoder:
                pieces.append(self.sp_model.decode(current_sub_tokens))
                pieces.append(token)
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
        pieces.append(self.sp_model.decode(current_sub_tokens))
        return "".join(pieces)
|
|
|
|
|
|
class GemmaRMSNorm(nn.Module):
    """RMS normalization whose learnable scale is stored as an offset from one (effective scale = 1 + weight)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Zero-initialised, so a freshly built layer applies a scale of exactly 1.
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        # Normalisation and scaling both happen in float32; the cast back to x's dtype happens last.
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        normed = self._norm(x.float())
        scale = 1.0 + self.weight.float()
        return (normed * scale).type_as(x)

    def extra_repr(self):
        shape = tuple(self.weight.shape)
        return f"{shape}, eps={self.eps}"
|
|
|
|
|
|
# Register GemmaRMSNorm in the shared ALL_LAYERNORM_LAYERS list so code that special-cases layer-norm
# modules also recognises it (the consumers of this list live outside this file).
ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm)
|
|
|
|
|
|
class GemmaRotaryEmbedding(nn.Module):
    """
    Rotary position embedding (RoPE): turns integer (or scaled float) positions into cos/sin tables.

    Args:
        dim: rotary dimension. The returned tables have last dimension `2 * ceil(dim / 2)`, i.e. `dim` for even `dim`.
        max_position_embeddings: reference context length; only the dynamic-NTK subclass acts on it.
        base: RoPE base period (theta).
        device: optional device on which the frequency buffer is created.
        scaling_factor: position scaling used by the linear / dynamic-NTK subclasses; 1.0 means no scaling.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Read by the scaled subclasses below; must exist on every instance.
        self.scaling_factor = scaling_factor
        exponents = torch.arange(0, self.dim, 2, dtype=torch.int64, device=device).float() / self.dim
        inv_freq = 1.0 / (self.base**exponents)
        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids, seq_len=None):
        """
        Args:
            x: tensor whose device and dtype determine those of the outputs (its values are not read).
            position_ids: `(batch, seq_len)` positions.
            seq_len: accepted for backward compatibility; ignored.
        Returns:
            `(cos, sin)`, each of shape `(batch, seq_len, 2 * ceil(dim / 2))` in `x.dtype`.
        """
        # `Tensor.to` is not in-place: use the returned, device-matched tensor.
        inv_freq = self.inv_freq.to(x.device)
        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


class GemmaLinearScalingRotaryEmbedding(GemmaRotaryEmbedding):
    """GemmaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def forward(self, x, position_ids):
        # difference to the original RoPE: positions are divided by `scaling_factor` (set in the base __init__)
        position_ids = position_ids.float() / self.scaling_factor
        cos, sin = super().forward(x, position_ids)
        return cos, sin


class GemmaDynamicNTKScalingRotaryEmbedding(GemmaRotaryEmbedding):
    """GemmaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Unscaled frequencies, restored whenever the input fits within `max_position_embeddings` again.
        self.register_buffer("original_inv_freq", self.inv_freq, persistent=False)

    def forward(self, x, position_ids):
        # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length,
        # and reset to the unscaled frequencies otherwise (so an earlier long input does not leak into later calls)
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (
                base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: this may break with compilation
        elif self.inv_freq is not self.original_inv_freq:
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)

        cos, sin = super().forward(x, position_ids)
        return cos, sin
|
|
|
|
|
|
class GemmaMLP(nn.Module):
    """Gated feed-forward block computing `down_proj(act(gate_proj(x)) * up_proj(x))`."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        hidden, inner = config.hidden_size, config.intermediate_size
        self.hidden_size = hidden
        self.intermediate_size = inner
        # Creation order (gate, up, down) is kept: it fixes parameter registration / init order.
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)

        # `hidden_act` is never consulted; a missing `hidden_activation` is filled in on the config itself.
        if config.hidden_activation is None:
            logger.warning_once(
                "`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.\n"
                "Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use\n"
                "`config.hidden_activation` if you want to override this behaviour.\n"
                "See https://github.com/huggingface/transformers/pull/29402 for more details."
            )
            config.hidden_activation = "gelu_pytorch_tanh"
        self.act_fn = ACT2FN[config.hidden_activation]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        linear = self.up_proj(x)
        return self.down_proj(gated * linear)
|
|
|
|
|
|
class GemmaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        # How many query heads share each key/value head (grouped-query attention).
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.scaling = 1 / math.sqrt(config.head_dim)

        if self.hidden_size % self.num_heads:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        # Module creation order (q, k, v, o, rotary) is kept: it fixes registration / init order.
        use_bias = config.attention_bias
        q_width = self.num_heads * self.head_dim
        kv_width = self.num_key_value_heads * self.head_dim
        self.q_proj = nn.Linear(self.hidden_size, q_width, bias=use_bias)
        self.k_proj = nn.Linear(self.hidden_size, kv_width, bias=use_bias)
        self.v_proj = nn.Linear(self.hidden_size, kv_width, bias=use_bias)
        self.o_proj = nn.Linear(q_width, self.hidden_size, bias=use_bias)
        self.rotary_emb = GemmaRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Eager attention over `hidden_states` of shape `(bsz, q_len, hidden_size)`.

        Returns `(attn_output, attn_weights, past_key_value)`; `attn_weights` is `None` unless `output_attentions`.
        """
        bsz, q_len, _ = hidden_states.size()

        # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
        queries = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        keys = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        values = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(values, position_ids)
        queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            update_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            keys, values = past_key_value.update(keys, values, self.layer_idx, update_kwargs)

        keys = repeat_kv(keys, self.num_key_value_groups)
        values = repeat_kv(values, self.num_key_value_groups)

        scores = torch.matmul(queries, keys.transpose(2, 3)) * self.scaling
        if attention_mask is not None:
            # no matter the length, we just slice the mask to the current key length
            scores = scores + attention_mask[:, :, :, : keys.shape[-2]]

        # upcast attention to fp32
        probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(queries.dtype)
        probs = nn.functional.dropout(probs, p=self.attention_dropout, training=self.training)
        context = torch.matmul(probs, values)

        expected_size = (bsz, self.num_heads, q_len, self.head_dim)
        if context.size() != expected_size:
            raise ValueError(f"`attn_output` should be of size {expected_size}, but is" f" {context.size()}")

        merged = context.transpose(1, 2).contiguous().view(bsz, q_len, -1)
        attn_output = self.o_proj(merged)

        return attn_output, (probs if output_attentions else None), past_key_value
|
|
|
|
|
|
class GemmaSdpaAttention(GemmaAttention):
    """
    Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `GemmaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """

    # Adapted from GemmaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        SDPA attention. Falls back to the eager `GemmaAttention.forward` when `output_attentions` is requested,
        since SDPA cannot return attention weights. Otherwise returns `(attn_output, None, past_key_value)`.
        """
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )

        bsz, q_len, _ = hidden_states.size()

        # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
        q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(v, position_ids)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            k, v = past_key_value.update(
                k, v, self.layer_idx, {"sin": sin, "cos": cos, "cache_position": cache_position}
            )

        k = repeat_kv(k, self.num_key_value_groups)
        v = repeat_kv(v, self.num_key_value_groups)

        causal_mask = None if attention_mask is None else attention_mask[:, :, :, : k.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if q.device.type == "cuda" and causal_mask is not None:
            q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an
        # inline conditional assignment in SDPA to support both torch.compile's dynamic shapes and full graph options.
        # An inline conditional prevents dynamic shapes from compiling.
        if causal_mask is None and q_len > 1:
            is_causal = True
        else:
            is_causal = False

        context = torch.nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        merged = context.transpose(1, 2).contiguous().view(bsz, q_len, -1)
        return self.o_proj(merged), None, past_key_value
|
|
|
|
|
|
class GemmaFlashAttention2(LlamaFlashAttention2, GemmaAttention):
    """
    Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Flash-attention forward pass.

        Returns:
            `(attn_output, None, past_key_value)`. Flash attention never materialises the attention weights, so the
            second element is always `None`, whatever `output_attentions` is.

        Raises:
            ValueError: if `past_key_value` is a `StaticCache`, which this implementation does not support.
        """
        if isinstance(past_key_value, StaticCache):
            raise ValueError(
                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Reshape to (bsz, num_heads, q_len, head_dim), the layout used by RoPE and the KV cache. The tensors are
        # transposed back to (bsz, q_len, num_heads, head_dim) for flash attention further down.
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.attention_dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (GemmaRMSNorm handles it correctly)
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                "The input hidden states seems to be silently casted in float32, this might be related to"
                " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            position_ids=position_ids,
            dropout=dropout_rate,
            sliding_window=getattr(self, "sliding_window", None),
            is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )
        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        # Attention weights are never available from the flash kernel.
        return attn_output, None, past_key_value
|
|
|
|
|
|
# Maps `config._attn_implementation` to the attention class instantiated by each decoder layer.
GEMMA_ATTENTION_CLASSES = {
    "eager": GemmaAttention,
    "flash_attention_2": GemmaFlashAttention2,
    "sdpa": GemmaSdpaAttention,
}
|
|
|
|
|
|
class GemmaDecoderLayer(LlamaDecoderLayer):
    """Gemma transformer block.

    Applies `input_layernorm` -> self-attention -> residual add, then
    `post_attention_layernorm` -> MLP -> residual add.
    """

    def __init__(self, config: GemmaConfig, layer_idx: int):
        # `LlamaDecoderLayer.__init__` takes `(config, layer_idx)`; calling it with `config` alone raises a
        # TypeError when this class is instantiated directly, so `layer_idx` is forwarded as well.
        super().__init__(config, layer_idx)
        # Gemma-specific submodules; these assignments override any same-named attributes set by the parent.
        self.self_attn = GEMMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
        self.mlp = GemmaMLP(config)
        self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            position_ids (`torch.LongTensor`, *optional*):
                Position indices of the input tokens, passed through to the attention module.
            past_key_value (`Cache`, *optional*):
                Cache of past key and value projection states, passed through to the attention module.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence
            kwargs (`dict`, *optional*):
                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
                into the model

        Returns:
            `tuple`: `(hidden_states,)`, extended with the attention weights when `output_attentions` is true and
            then with the cache object returned by the attention module when `use_cache` is true.
        """
        # Pre-norm self-attention sub-block with residual connection.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Pre-norm MLP sub-block with residual connection.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        if use_cache:
            outputs += (present_key_value,)
        return outputs
|
|
|
|
|
|
class GemmaModel(LlamaModel):
    """Decoder stack of `GemmaDecoderLayer`s with scaled input embeddings and a final `GemmaRMSNorm`."""

    def __init__(self, config: GemmaConfig):
        super().__init__(config)
        decoder_stack = [GemmaDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(decoder_stack)
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # The parent constructor sets up a model-level `rotary_emb`; Gemma does not use one at the
        # modeling level yet (nothing in `forward` below reads it), so it is removed.
        del self.rotary_emb
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the embeddings and the decoder stack.

        Exactly one of `input_ids` / `inputs_embeds` must be given, otherwise a `ValueError` is raised.
        Flags left as `None` fall back to the matching attribute of `self.config`. When caching is on and
        `past_key_values` is not a `Cache` (e.g. `None` or a legacy tuple of tuples), a `DynamicCache` is
        used internally and the returned cache is converted back to the legacy format.

        Returns:
            `BaseModelOutputWithPast` when `return_dict` is true, otherwise the tuple
            `(last_hidden_state, past_key_values, hidden_states, attentions)` with `None` entries dropped.
        """
        config = self.config
        if output_attentions is None:
            output_attentions = config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = config.output_hidden_states
        if use_cache is None:
            use_cache = config.use_cache
        if return_dict is None:
            return_dict = config.use_return_dict

        # Both given, or both missing, is an error.
        if (input_ids is not None) == (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Backward compatibility for non-`Cache` inputs: wrap them in a `DynamicCache` and remember to
        # hand the cache back in the legacy format.
        return_legacy_cache = bool(use_cache) and not isinstance(past_key_values, Cache)
        if return_legacy_cache:
            if past_key_values is not None:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                logger.warning_once(
                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                )
            else:
                past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        # Gemma scales the input embeddings by sqrt(hidden_size). The factor is created in the activation
        # dtype, so it is rounded like the activations (e.g. sqrt(3072) ~= 55.4256 becomes 55.5 in bfloat16).
        # See https://github.com/huggingface/transformers/pull/29402
        normalizer = torch.tensor(config.hidden_size**0.5, dtype=inputs_embeds.dtype)
        hidden_states = inputs_embeds * normalizer

        hidden_history = [] if output_hidden_states else None
        attention_history = [] if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if hidden_history is not None:
                hidden_history.append(hidden_states)

            if self.gradient_checkpointing and self.training:
                # Positional arguments follow the order of `GemmaDecoderLayer.forward`.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            # Each layer returns (hidden_states[, attn_weights][, cache]).
            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
            if attention_history is not None:
                attention_history.append(layer_outputs[1])

        hidden_states = self.norm(hidden_states)
        if hidden_history is not None:
            # The final, normalized output is reported as the last hidden state.
            hidden_history.append(hidden_states)

        all_hidden_states = None if hidden_history is None else tuple(hidden_history)
        all_self_attns = None if attention_history is None else tuple(attention_history)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if return_dict:
            return BaseModelOutputWithPast(
                last_hidden_state=hidden_states,
                past_key_values=next_cache,
                hidden_states=all_hidden_states,
                attentions=all_self_attns,
            )
        return tuple(
            item for item in (hidden_states, next_cache, all_hidden_states, all_self_attns) if item is not None
        )
|
|
|
|
|
|
# Example where we subclass the Llama head and only override `__init__` (to plug in `GemmaModel`) and `forward`
|
|
class GemmaForCausalLM(LlamaForCausalLM):
    """Causal language-modeling head on top of a `GemmaModel` backbone."""

    def __init__(self, config):
        super().__init__(config)
        self.model = GemmaModel(config)
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor`, *optional*):
                Target token ids. When given, `loss` is computed with
                `self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)`.
            num_logits_to_keep (`int`, *optional*, defaults to 0):
                Only the last `num_logits_to_keep` positions are projected through `lm_head`. `0` keeps every
                position (the slice `-0:` starts at index 0).
            loss_kwargs (*optional*):
                Extra keyword arguments forwarded unchanged to `self.loss_function` (e.g. `num_items_in_batch`).
                NOTE: assumes the configured loss function accepts them.

        ```python
        >>> from transformers import AutoTokenizer, GemmaForCausalLM

        >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")

        >>> prompt = "What is your favorite condiment?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is your favorite condiment?"
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # The backbone yields (last_hidden_state, past_key_values, hidden_states, attentions); in tuple mode
        # entries that are None are dropped.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        # Project only the positions we need (all of them when num_logits_to_keep == 0).
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
            # Forward the extra loss arguments (e.g. `num_items_in_batch`) so the loss can be normalized
            # correctly under gradient accumulation.
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
|
|
class GemmaForSequenceClassification(LlamaForSequenceClassification):
    """Llama sequence-classification head running on a `GemmaModel` backbone."""

    def __init__(self, config):
        super().__init__(config)
        # Install the Gemma backbone as `self.model` (overriding whatever the parent assigned there),
        # then call `post_init()`.
        gemma_backbone = GemmaModel(config)
        self.model = gemma_backbone
        self.post_init()
|
|
|
|
|
|
class GemmaForTokenClassification(LlamaForTokenClassification):
    """Llama token-classification head running on a `GemmaModel` backbone."""

    def __init__(self, config):
        super().__init__(config)
        # Install the Gemma backbone as `self.model` (overriding whatever the parent assigned there),
        # then call `post_init()`.
        gemma_backbone = GemmaModel(config)
        self.model = gemma_backbone
        self.post_init()
|
|
|
|
|
|
# Public names this module exports via `from ... import *`. Kept as a plain list literal so that
# tools reading the export list without executing the module can parse it.
__all__ = [
    "GemmaConfig",
    "GemmaTokenizer",
    "GemmaModel",
    "GemmaForCausalLM",
    "GemmaForSequenceClassification",
    "GemmaForTokenClassification",
]
|