mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Compile compatibilty for decoder-only models (#32617)
* squash into one commit * add qwen2-vl for rope standardization * fix mistral compile * fix qwen2-vl * fix-copies
This commit is contained in:
parent
eedd21b9e7
commit
65bb284448
@ -1629,13 +1629,14 @@ class GenerationMixin:
|
||||
|
||||
# Set pad token if unset (and there are conditions to do so)
|
||||
if pad_token_tensor is None and eos_token_tensor is not None:
|
||||
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
|
||||
logger.warning(
|
||||
"The attention mask and the pad token id were not set. As a consequence, you may observe "
|
||||
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
|
||||
)
|
||||
if not is_torchdynamo_compiling():
|
||||
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
|
||||
logger.warning(
|
||||
"The attention mask and the pad token id were not set. As a consequence, you may observe "
|
||||
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
|
||||
)
|
||||
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")
|
||||
pad_token_tensor = eos_token_tensor[0]
|
||||
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")
|
||||
|
||||
# Sanity checks/warnings
|
||||
if self.config.is_encoder_decoder and decoder_start_token_tensor is None:
|
||||
|
@ -326,14 +326,11 @@ class BloomAttention(nn.Module):
|
||||
|
||||
# reshape qkv for further computations
|
||||
query_layer = query_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
|
||||
key_layer = key_layer.reshape(batch_size * self.num_heads, -1, self.head_dim).transpose(1, 2)
|
||||
key_layer = key_layer.reshape(batch_size * self.num_heads, -1, self.head_dim).transpose(-1, -2)
|
||||
value_layer = value_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
|
||||
|
||||
kv_length = cache_position[-1] + 1 # cache position is 0-indexed while length should start from 1
|
||||
|
||||
# [batch_size * num_heads, q_length, kv_length]
|
||||
# we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
|
||||
matmul_result = alibi.baddbmm(
|
||||
attention_scores = alibi.baddbmm(
|
||||
batch1=query_layer,
|
||||
batch2=key_layer,
|
||||
beta=self.beta,
|
||||
@ -341,9 +338,9 @@ class BloomAttention(nn.Module):
|
||||
)
|
||||
|
||||
# change view to [batch_size, num_heads, q_length, kv_length]
|
||||
attn_weights = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
|
||||
attn_weights = attention_scores.view(batch_size, self.num_heads, q_length, -1)
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, :kv_length]
|
||||
causal_mask = attention_mask[:, :, :, : key_layer.shape[-1]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype
|
||||
@ -356,7 +353,7 @@ class BloomAttention(nn.Module):
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
# change view [batch_size x num_heads, q_length, kv_length]
|
||||
attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)
|
||||
attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, -1)
|
||||
|
||||
# matmul: [batch_size * num_heads, q_length, head_dim]
|
||||
context_layer = torch.bmm(attention_probs_reshaped, value_layer)
|
||||
@ -496,6 +493,8 @@ class BloomPreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = ["BloomBlock"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
_supports_quantized_cache = True
|
||||
|
||||
def __init__(self, *inputs, **kwargs):
|
||||
super().__init__(*inputs, **kwargs)
|
||||
@ -895,9 +894,25 @@ class BloomForCausalLM(BloomPreTrainedModel):
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
|
||||
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the
|
||||
# input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in
|
||||
# the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
|
||||
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
|
||||
|
||||
# This part differs from other models because BLOOM needs a 2D mask to construct alibi tensor
|
||||
# The only difference is the usage of 2D instead of 4D mask, but the shape will be static
|
||||
if isinstance(past_key_values, StaticCache) and attention_mask is not None:
|
||||
target_length = past_key_values.get_max_length()
|
||||
batch_size, seq_length = attention_mask.shape
|
||||
diff = target_length - seq_length
|
||||
|
||||
new_attn_mask = torch.zeros(batch_size, diff, device=attention_mask.device, dtype=attention_mask.dtype)
|
||||
attention_mask = torch.cat(
|
||||
[attention_mask, new_attn_mask],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
|
@ -77,13 +77,42 @@ class FalconConfig(PretrainedConfig):
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
|
||||
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
|
||||
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
||||
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
|
||||
these scaling strategies behave:
|
||||
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
|
||||
experimental feature, subject to breaking API changes in future versions.
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
bos_token_id (`int`, *optional*, defaults to 11):
|
||||
The id of the "beginning-of-sequence" token.
|
||||
eos_token_id (`int`, *optional*, defaults to 11):
|
||||
@ -167,7 +196,6 @@ class FalconConfig(PretrainedConfig):
|
||||
self.ffn_hidden_size = hidden_size * 4
|
||||
else:
|
||||
self.ffn_hidden_size = ffn_hidden_size
|
||||
self._rope_scaling_validation()
|
||||
|
||||
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||
|
||||
@ -178,26 +206,3 @@ class FalconConfig(PretrainedConfig):
|
||||
@property
|
||||
def rotary(self):
|
||||
return not self.alibi
|
||||
|
||||
def _rope_scaling_validation(self):
|
||||
"""
|
||||
Validate the `rope_scaling` configuration.
|
||||
"""
|
||||
if self.rope_scaling is None:
|
||||
return
|
||||
|
||||
if self.alibi:
|
||||
raise ValueError("`rope_scaling` is not supported when `alibi` is `True`.")
|
||||
|
||||
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
||||
raise ValueError(
|
||||
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
|
||||
)
|
||||
rope_scaling_type = self.rope_scaling.get("type", None)
|
||||
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
||||
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
|
||||
raise ValueError(
|
||||
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
||||
)
|
||||
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
||||
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
|
||||
|
@ -35,6 +35,7 @@ from ...modeling_outputs import (
|
||||
SequenceClassifierOutputWithPast,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...pytorch_utils import is_torch_greater_or_equal_than_2_0
|
||||
from ...utils import (
|
||||
@ -133,8 +134,8 @@ def rotate_half(x):
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
@ -142,9 +143,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`):
|
||||
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
||||
used to pass offsetted position ids when working with a KV-cache.
|
||||
position_ids (`torch.Tensor`, *optional*):
|
||||
Deprecated and unused.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
@ -155,97 +155,126 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
|
||||
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Falcon
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Falcon
|
||||
class FalconRotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||
def __init__(
|
||||
self,
|
||||
dim=None,
|
||||
max_position_embeddings=2048,
|
||||
base=10000,
|
||||
device=None,
|
||||
scaling_factor=1.0,
|
||||
rope_type="default",
|
||||
config: Optional[FalconConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
# TODO (joao): remove the `if` below, only used for BC
|
||||
self.rope_kwargs = {}
|
||||
if config is None:
|
||||
logger.warning_once(
|
||||
"`FalconRotaryEmbedding` can now be fully parameterized by passing the model config through the "
|
||||
"`config` argument. All other arguments will be removed in v4.46"
|
||||
)
|
||||
self.rope_kwargs = {
|
||||
"rope_type": rope_type,
|
||||
"factor": scaling_factor,
|
||||
"dim": dim,
|
||||
"base": base,
|
||||
"max_position_embeddings": max_position_embeddings,
|
||||
}
|
||||
self.rope_type = rope_type
|
||||
self.max_seq_len_cached = max_position_embeddings
|
||||
self.original_max_seq_len = max_position_embeddings
|
||||
else:
|
||||
# BC: "rope_type" was originally "type"
|
||||
if config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
# Build here to make `torch.jit.trace` work.
|
||||
self._set_cos_sin_cache(
|
||||
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
|
||||
)
|
||||
def _dynamic_frequency_update(self, position_ids, device):
|
||||
"""
|
||||
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
||||
1 - growing beyond the cached sequence length (allow scaling)
|
||||
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
||||
"""
|
||||
seq_len = torch.max(position_ids) + 1
|
||||
if seq_len > self.max_seq_len_cached: # growth
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
||||
)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||
self.max_seq_len_cached = seq_len
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||
self.max_seq_len_cached = self.original_max_seq_len
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids):
|
||||
if "dynamic" in self.rope_type:
|
||||
self._dynamic_frequency_update(position_ids, device=x.device)
|
||||
|
||||
def forward(self, x, seq_len=None):
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
if seq_len > self.max_seq_len_cached:
|
||||
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
||||
# Core RoPE block
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
||||
return (
|
||||
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
||||
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
||||
)
|
||||
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
||||
cos = cos * self.attention_scaling
|
||||
sin = sin * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Falcon
|
||||
# TODO @joao no longer copied from LLama after static cache, fix me (copied -> Copied)
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Falcon
|
||||
class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
|
||||
"""FalconRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
||||
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(dim, max_position_embeddings, base, device)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
||||
t = t / self.scaling_factor
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||
def __init__(self, *args, **kwargs):
|
||||
logger.warning_once(
|
||||
"`FalconLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
|
||||
"`FalconRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
|
||||
)
|
||||
kwargs["rope_type"] = "linear"
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Falcon
|
||||
# TODO @joao no longer copied from LLama after static cache, fix me (copied -> Copied)
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Falcon
|
||||
class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding):
|
||||
"""FalconRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
||||
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(dim, max_position_embeddings, base, device)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
|
||||
if seq_len > self.max_position_embeddings:
|
||||
base = self.base * (
|
||||
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
|
||||
) ** (self.dim / (self.dim - 2))
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||
def __init__(self, *args, **kwargs):
|
||||
logger.warning_once(
|
||||
"`FalconDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
|
||||
"`FalconRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
|
||||
"__init__)."
|
||||
)
|
||||
kwargs["rope_type"] = "dynamic"
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
|
||||
@ -324,9 +353,6 @@ class FalconAttention(nn.Module):
|
||||
f" {self.num_heads})."
|
||||
)
|
||||
|
||||
if config.rotary:
|
||||
self._init_rope()
|
||||
|
||||
# Layer-wise attention scaling
|
||||
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
|
||||
self.beta = self.inv_norm_factor
|
||||
@ -343,32 +369,9 @@ class FalconAttention(nn.Module):
|
||||
self.attention_dropout = nn.Dropout(config.attention_dropout)
|
||||
self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
|
||||
|
||||
def _init_rope(self):
|
||||
if self.config.rope_scaling is None:
|
||||
self.rotary_emb = FalconRotaryEmbedding(
|
||||
self.head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
else:
|
||||
scaling_type = self.config.rope_scaling["type"]
|
||||
scaling_factor = self.config.rope_scaling["factor"]
|
||||
if scaling_type == "linear":
|
||||
self.rotary_emb = FalconLinearScalingRotaryEmbedding(
|
||||
self.head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
scaling_factor=scaling_factor,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
elif scaling_type == "dynamic":
|
||||
self.rotary_emb = FalconDynamicNTKScalingRotaryEmbedding(
|
||||
self.head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
scaling_factor=scaling_factor,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
||||
# TODO (raushan): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
|
||||
if config.rotary:
|
||||
self.rotary_emb = FalconRotaryEmbedding(config=self.config)
|
||||
|
||||
def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
@ -438,6 +441,7 @@ class FalconAttention(nn.Module):
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
):
|
||||
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
|
||||
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
|
||||
@ -450,18 +454,18 @@ class FalconAttention(nn.Module):
|
||||
key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
|
||||
value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
|
||||
|
||||
kv_seq_len = key_layer.shape[-2]
|
||||
if layer_past is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += layer_past.get_seq_length(self.layer_idx)
|
||||
if alibi is None:
|
||||
cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
|
||||
query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_layer, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin)
|
||||
|
||||
if layer_past is not None:
|
||||
cache_kwargs = {"cache_position": cache_position}
|
||||
@ -597,6 +601,7 @@ class FalconFlashAttention2(FalconAttention):
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
):
|
||||
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
|
||||
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
|
||||
@ -609,18 +614,18 @@ class FalconFlashAttention2(FalconAttention):
|
||||
key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
|
||||
value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
|
||||
|
||||
kv_seq_len = key_layer.shape[-2]
|
||||
if layer_past is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += layer_past.get_seq_length(self.layer_idx)
|
||||
if alibi is None:
|
||||
cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
|
||||
query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_layer, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin)
|
||||
|
||||
if layer_past is not None:
|
||||
cache_kwargs = {"cache_position": cache_position}
|
||||
@ -743,6 +748,7 @@ class FalconDecoderLayer(nn.Module):
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
**kwargs,
|
||||
):
|
||||
residual = hidden_states
|
||||
@ -764,6 +770,7 @@ class FalconDecoderLayer(nn.Module):
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
attention_output = attn_outputs[0]
|
||||
@ -969,6 +976,8 @@ class FalconModel(FalconPreTrainedModel):
|
||||
# Final Layer Norm
|
||||
self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
|
||||
self.rotary_emb = FalconRotaryEmbedding(config=config)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
@ -1065,6 +1074,9 @@ class FalconModel(FalconPreTrainedModel):
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
next_decoder_cache = None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
@ -1085,6 +1097,7 @@ class FalconModel(FalconPreTrainedModel):
|
||||
use_cache,
|
||||
output_attentions,
|
||||
cache_position,
|
||||
position_embeddings,
|
||||
)
|
||||
else:
|
||||
outputs = block(
|
||||
@ -1097,6 +1110,7 @@ class FalconModel(FalconPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
alibi=alibi,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
|
@ -15,6 +15,7 @@
|
||||
"""GPTNeoX model configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
@ -74,13 +75,42 @@ class GPTNeoXConfig(PretrainedConfig):
|
||||
Whether to use a "parallel" formulation in each Transformer layer, which can provide a slight training
|
||||
speedup at large scales (e.g. 20B).
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
|
||||
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
|
||||
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
||||
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
|
||||
these scaling strategies behave:
|
||||
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
|
||||
experimental feature, subject to breaking API changes in future versions.
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
attention_bias (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
|
||||
@ -136,7 +166,9 @@ class GPTNeoXConfig(PretrainedConfig):
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.rotary_pct = rotary_pct
|
||||
self.partial_rotary_factor = rotary_pct
|
||||
self.rotary_emb_base = rotary_emb_base
|
||||
self.rope_theta = rotary_emb_base
|
||||
self.attention_dropout = attention_dropout
|
||||
self.hidden_dropout = hidden_dropout
|
||||
self.classifier_dropout = classifier_dropout
|
||||
@ -147,29 +179,13 @@ class GPTNeoXConfig(PretrainedConfig):
|
||||
self.use_parallel_residual = use_parallel_residual
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self._rope_scaling_validation()
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
|
||||
if self.hidden_size % self.num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
"The hidden size is not divisble by the number of attention heads! Make sure to update them!"
|
||||
)
|
||||
|
||||
def _rope_scaling_validation(self):
|
||||
"""
|
||||
Validate the `rope_scaling` configuration.
|
||||
"""
|
||||
if self.rope_scaling is None:
|
||||
return
|
||||
|
||||
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
||||
raise ValueError(
|
||||
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
|
||||
)
|
||||
rope_scaling_type = self.rope_scaling.get("type", None)
|
||||
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
||||
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
|
||||
raise ValueError(
|
||||
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
||||
)
|
||||
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
||||
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
|
||||
|
@ -38,8 +38,14 @@ from ...modeling_outputs import (
|
||||
SequenceClassifierOutputWithPast,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import get_torch_version, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging
|
||||
from ...utils import (
|
||||
get_torch_version,
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
logging,
|
||||
)
|
||||
from .configuration_gpt_neox import GPTNeoXConfig
|
||||
|
||||
|
||||
@ -151,10 +157,11 @@ class GPTNeoXAttention(nn.Module):
|
||||
)
|
||||
self.head_size = self.hidden_size // self.num_attention_heads
|
||||
self.rotary_ndims = int(self.head_size * config.rotary_pct)
|
||||
self.rope_theta = config.rotary_emb_base
|
||||
self._init_bias(config.max_position_embeddings)
|
||||
|
||||
self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
|
||||
self._init_rope()
|
||||
self.rotary_emb = GPTNeoXRotaryEmbedding(config=self.config)
|
||||
|
||||
if layer_idx is None:
|
||||
logger.warning_once(
|
||||
@ -180,31 +187,6 @@ class GPTNeoXAttention(nn.Module):
|
||||
if device is not None:
|
||||
self.bias = self.bias.to(device)
|
||||
|
||||
def _init_rope(self):
|
||||
if self.config.rope_scaling is None:
|
||||
self.rotary_emb = GPTNeoXRotaryEmbedding(
|
||||
self.rotary_ndims, self.config.max_position_embeddings, base=self.config.rotary_emb_base
|
||||
)
|
||||
else:
|
||||
scaling_type = self.config.rope_scaling["type"]
|
||||
scaling_factor = self.config.rope_scaling["factor"]
|
||||
if scaling_type == "linear":
|
||||
self.rotary_emb = GPTNeoXLinearScalingRotaryEmbedding(
|
||||
self.rotary_ndims,
|
||||
self.config.max_position_embeddings,
|
||||
base=self.config.rotary_emb_base,
|
||||
scaling_factor=scaling_factor,
|
||||
)
|
||||
elif scaling_type == "dynamic":
|
||||
self.rotary_emb = GPTNeoXDynamicNTKScalingRotaryEmbedding(
|
||||
self.rotary_ndims,
|
||||
self.config.max_position_embeddings,
|
||||
base=self.config.rotary_emb_base,
|
||||
scaling_factor=scaling_factor,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
@ -216,10 +198,15 @@ class GPTNeoXAttention(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
):
|
||||
# Apply attention-specific projections and rope
|
||||
query, key, value, present = self._attn_projections_and_rope(
|
||||
hidden_states=hidden_states, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache
|
||||
hidden_states=hidden_states,
|
||||
position_ids=position_ids,
|
||||
layer_past=layer_past,
|
||||
use_cache=use_cache,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
# Compute attention
|
||||
@ -267,6 +254,7 @@ class GPTNeoXAttention(nn.Module):
|
||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
):
|
||||
# Compute QKV
|
||||
# Attention heads [batch, seq_len, hidden_size]
|
||||
@ -289,19 +277,17 @@ class GPTNeoXAttention(nn.Module):
|
||||
key_rot = key[..., : self.rotary_ndims]
|
||||
key_pass = key[..., self.rotary_ndims :]
|
||||
|
||||
# Compute token offset for rotary embeddings (when decoding)
|
||||
seq_len = key.shape[-2]
|
||||
if layer_past is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
seq_len += layer_past.get_seq_length(self.layer_idx)
|
||||
|
||||
cos, sin = self.rotary_emb(value, seq_len=seq_len)
|
||||
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
|
||||
query = torch.cat((query, query_pass), dim=-1)
|
||||
key = torch.cat((key, key_pass), dim=-1)
|
||||
|
||||
@ -310,7 +296,7 @@ class GPTNeoXAttention(nn.Module):
|
||||
cache_kwargs = {
|
||||
"sin": sin,
|
||||
"cos": cos,
|
||||
"partial_rotation_size": self.rotary_emb.dim,
|
||||
"partial_rotation_size": self.rotary_ndims,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
key, value = layer_past.update(key, value, self.layer_idx, cache_kwargs)
|
||||
@ -395,6 +381,7 @@ class GPTNeoXFlashAttention2(GPTNeoXAttention):
|
||||
use_cache: Optional[bool] = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
):
|
||||
# Apply attention-specific projections and rope
|
||||
query, key, value, present = self._attn_projections_and_rope(
|
||||
@ -403,6 +390,7 @@ class GPTNeoXFlashAttention2(GPTNeoXAttention):
|
||||
layer_past=layer_past,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
query_length = query.shape[-2]
|
||||
@ -496,6 +484,7 @@ class GPTNeoXSdpaAttention(GPTNeoXAttention):
|
||||
use_cache: Optional[bool] = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
):
|
||||
if output_attentions or head_mask is not None:
|
||||
logger.warning_once(
|
||||
@ -524,6 +513,7 @@ class GPTNeoXSdpaAttention(GPTNeoXAttention):
|
||||
layer_past=layer_past,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
causal_mask = attention_mask
|
||||
@ -570,90 +560,119 @@ def attention_mask_func(attention_scores, ltor_mask):
|
||||
return attention_scores
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->GPTNeoX
|
||||
class GPTNeoXRotaryEmbedding(nn.Module):
|
||||
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding.__init__
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||
def __init__(
|
||||
self,
|
||||
dim=None,
|
||||
max_position_embeddings=2048,
|
||||
base=10000,
|
||||
device=None,
|
||||
scaling_factor=1.0,
|
||||
rope_type="default",
|
||||
config: Optional[GPTNeoXConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
# TODO (joao): remove the `if` below, only used for BC
|
||||
self.rope_kwargs = {}
|
||||
if config is None:
|
||||
logger.warning_once(
|
||||
"`GPTNeoXRotaryEmbedding` can now be fully parameterized by passing the model config through the "
|
||||
"`config` argument. All other arguments will be removed in v4.46"
|
||||
)
|
||||
self.rope_kwargs = {
|
||||
"rope_type": rope_type,
|
||||
"factor": scaling_factor,
|
||||
"dim": dim,
|
||||
"base": base,
|
||||
"max_position_embeddings": max_position_embeddings,
|
||||
}
|
||||
self.rope_type = rope_type
|
||||
self.max_seq_len_cached = max_position_embeddings
|
||||
self.original_max_seq_len = max_position_embeddings
|
||||
else:
|
||||
# BC: "rope_type" was originally "type"
|
||||
if config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
# Build here to make `torch.jit.trace` work.
|
||||
self._set_cos_sin_cache(
|
||||
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
|
||||
)
|
||||
def _dynamic_frequency_update(self, position_ids, device):
|
||||
"""
|
||||
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
||||
1 - growing beyond the cached sequence length (allow scaling)
|
||||
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
||||
"""
|
||||
seq_len = torch.max(position_ids) + 1
|
||||
if seq_len > self.max_seq_len_cached: # growth
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
||||
)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||
self.max_seq_len_cached = seq_len
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||
self.max_seq_len_cached = self.original_max_seq_len
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos(), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin(), persistent=False)
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids):
|
||||
if "dynamic" in self.rope_type:
|
||||
self._dynamic_frequency_update(position_ids, device=x.device)
|
||||
|
||||
def forward(self, x, seq_len=None):
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
if seq_len > self.max_seq_len_cached:
|
||||
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
||||
# Core RoPE block
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
||||
return (
|
||||
self.cos_cached[:seq_len],
|
||||
self.sin_cached[:seq_len],
|
||||
)
|
||||
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
||||
cos = cos * self.attention_scaling
|
||||
sin = sin * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding.__init__
|
||||
# TODO @gante bring compatibility back
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->GPTNeoX
|
||||
class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
|
||||
"""GPTNeoXRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
||||
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(dim, max_position_embeddings, base, device)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
||||
t = t / self.scaling_factor
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos(), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin(), persistent=False)
|
||||
def __init__(self, *args, **kwargs):
|
||||
logger.warning_once(
|
||||
"`GPTNeoXLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
|
||||
"`GPTNeoXRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
|
||||
)
|
||||
kwargs["rope_type"] = "linear"
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->GPTNeoX
|
||||
class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
|
||||
"""GPTNeoXRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
||||
|
||||
# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding.__init__
|
||||
# TODO @gante no longer copied from
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(dim, max_position_embeddings, base, device)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
|
||||
if seq_len > self.max_position_embeddings:
|
||||
base = self.base * (
|
||||
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
|
||||
) ** (self.dim / (self.dim - 2))
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos(), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin(), persistent=False)
|
||||
def __init__(self, *args, **kwargs):
|
||||
logger.warning_once(
|
||||
"`GPTNeoXDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
|
||||
"`GPTNeoXRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
|
||||
"__init__)."
|
||||
)
|
||||
kwargs["rope_type"] = "dynamic"
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
@ -663,8 +682,8 @@ def rotate_half(x):
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
@ -672,9 +691,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`):
|
||||
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
||||
used to pass offsetted position ids when working with a KV-cache.
|
||||
position_ids (`torch.Tensor`, *optional*):
|
||||
Deprecated and unused.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
@ -685,8 +703,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
|
||||
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
@ -734,6 +752,7 @@ class GPTNeoXLayer(nn.Module):
|
||||
layer_past: Optional[Cache] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
):
|
||||
attention_layer_outputs = self.attention(
|
||||
self.input_layernorm(hidden_states),
|
||||
@ -744,6 +763,7 @@ class GPTNeoXLayer(nn.Module):
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
attn_output = attention_layer_outputs[0] # output_attn: attn_output, present, (attn_weights)
|
||||
attn_output = self.post_attention_dropout(attn_output)
|
||||
@ -860,6 +880,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
|
||||
self.emb_dropout = nn.Dropout(config.hidden_dropout)
|
||||
self.layers = nn.ModuleList([GPTNeoXLayer(config, i) for i in range(config.num_hidden_layers)])
|
||||
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.rotary_emb = GPTNeoXRotaryEmbedding(config=config)
|
||||
|
||||
self._attn_implementation = config._attn_implementation
|
||||
|
||||
@ -952,6 +973,9 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
hidden_states = self.emb_dropout(inputs_embeds)
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
next_decoder_cache = None
|
||||
all_attentions = () if output_attentions else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
@ -972,6 +996,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
|
||||
None,
|
||||
output_attentions,
|
||||
cache_position,
|
||||
position_embeddings,
|
||||
)
|
||||
else:
|
||||
outputs = layer(
|
||||
@ -983,6 +1008,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
hidden_states = outputs[0]
|
||||
if use_cache is True:
|
||||
@ -1183,7 +1209,7 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
# can't be copied from llama, gpt-neox has emebd_out and not lm_head
|
||||
# can't be copied from llama, gpt-neox has embed_out and not lm_head
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
|
@ -15,6 +15,7 @@
|
||||
"""GPTNeoX Japanese model configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
@ -59,6 +60,43 @@ class GPTNeoXJapaneseConfig(PretrainedConfig):
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
attention_dropout (`float`, *optional*, defaults to 0.1):
|
||||
The dropout ratio for the attention.
|
||||
hidden_dropout (`float`, *optional*, defaults to 0.0):
|
||||
@ -96,6 +134,7 @@ class GPTNeoXJapaneseConfig(PretrainedConfig):
|
||||
use_cache=True,
|
||||
bos_token_id=31996,
|
||||
eos_token_id=31999,
|
||||
rope_scaling=None,
|
||||
attention_dropout=0.1,
|
||||
hidden_dropout=0.0,
|
||||
**kwargs,
|
||||
@ -109,9 +148,17 @@ class GPTNeoXJapaneseConfig(PretrainedConfig):
|
||||
self.intermediate_multiple_size = intermediate_multiple_size
|
||||
self.hidden_act = hidden_act
|
||||
self.rotary_pct = rotary_pct
|
||||
self.partial_rotary_factor = rotary_pct
|
||||
self.rotary_emb_base = rotary_emb_base
|
||||
self.rope_theta = rotary_emb_base
|
||||
self.initializer_range = initializer_range
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_dropout = attention_dropout
|
||||
self.hidden_dropout = hidden_dropout
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
|
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch GPTNeoX model."""
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@ -22,8 +23,11 @@ from torch import Tensor, nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import logging
|
||||
from .configuration_gpt_neox_japanese import GPTNeoXJapaneseConfig
|
||||
@ -35,6 +39,60 @@ _CHECKPOINT_FOR_DOC = "abeja/gpt-neox-japanese-2.7b"
|
||||
_CONFIG_FOR_DOC = "GPTNeoXJapaneseConfig"
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
min_dtype: float,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
device (`torch.device`):
|
||||
The device to plcae the 4D attention mask on.
|
||||
min_dtype (`float`):
|
||||
The minimum value representable with the dtype `dtype`.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
@ -45,6 +103,9 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "gpt_neox_japanese"
|
||||
_no_split_modules = ["GPTNeoXJapaneseLayer"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_cache_class = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
@ -62,19 +123,24 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
|
||||
|
||||
|
||||
class GPTNeoXJapaneseAttention(nn.Module):
|
||||
def __init__(self, config, use_bias=False):
|
||||
def __init__(self, config, use_bias=False, layer_idx=None):
|
||||
super().__init__()
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_size = self.hidden_size // self.num_attention_heads
|
||||
if layer_idx is None:
|
||||
logger.warning_once(
|
||||
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
|
||||
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
|
||||
"when creating this class."
|
||||
)
|
||||
|
||||
self.layer_idx = layer_idx
|
||||
self.rotary_ndims = int(self.head_size * config.rotary_pct)
|
||||
self.rotary_emb = RotaryEmbedding(
|
||||
self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base
|
||||
)
|
||||
self.max_positions = config.max_position_embeddings
|
||||
self.rope_theta = config.rotary_emb_base
|
||||
self.rotary_emb = GPTNeoXJapaneseRotaryEmbedding(config=config)
|
||||
self.attention_dropout = nn.Dropout(config.attention_dropout)
|
||||
self.norm_factor = torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype())
|
||||
self.norm_factor = math.sqrt(self.head_size)
|
||||
|
||||
self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=False)
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
|
||||
@ -84,15 +150,16 @@ class GPTNeoXJapaneseAttention(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask=None,
|
||||
layer_past=None,
|
||||
use_cache=False,
|
||||
output_attentions=False,
|
||||
hidden_states: torch.FloatTensor,
|
||||
attention_mask: torch.FloatTensor,
|
||||
position_ids: torch.LongTensor,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
layer_past: Optional[Cache] = None,
|
||||
use_cache: Optional[bool] = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
):
|
||||
has_layer_past = layer_past is not None and layer_past[0].numel() > 0
|
||||
|
||||
# Compute QKV
|
||||
# Attention heads [batch, seq_len, hidden_size]
|
||||
# --> [batch, seq_len, (np * 3 * head_size)]
|
||||
@ -114,24 +181,29 @@ class GPTNeoXJapaneseAttention(nn.Module):
|
||||
key_rot = key[..., : self.rotary_ndims]
|
||||
key_pass = key[..., self.rotary_ndims :]
|
||||
|
||||
# Compute token offset for rotary embeddings (when decoding)
|
||||
seq_len = key.shape[-2]
|
||||
offset = 0
|
||||
if has_layer_past:
|
||||
offset = layer_past[0].shape[-2]
|
||||
seq_len += offset
|
||||
cos, sin = self.rotary_emb(value, seq_len=seq_len)
|
||||
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, offset=offset)
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
|
||||
query = torch.cat((query, query_pass), dim=-1)
|
||||
key = torch.cat((key, key_pass), dim=-1)
|
||||
|
||||
# Cache QKV values
|
||||
if has_layer_past:
|
||||
past_key = layer_past[0]
|
||||
past_value = layer_past[1]
|
||||
key = torch.cat((past_key, key), dim=-2)
|
||||
value = torch.cat((past_value, value), dim=-2)
|
||||
present = (key, value) if use_cache else None
|
||||
if layer_past is not None:
|
||||
cache_kwargs = {
|
||||
"sin": sin,
|
||||
"cos": cos,
|
||||
"partial_rotation_size": self.rotary_ndims,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
key, value = layer_past.update(key, value, self.layer_idx, cache_kwargs)
|
||||
|
||||
# Compute attention
|
||||
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
||||
@ -140,7 +212,7 @@ class GPTNeoXJapaneseAttention(nn.Module):
|
||||
attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
|
||||
attn_output = self.dense(attn_output)
|
||||
|
||||
outputs = (attn_output, present)
|
||||
outputs = (attn_output, layer_past)
|
||||
if output_attentions:
|
||||
outputs += (attn_weights,)
|
||||
|
||||
@ -171,24 +243,16 @@ class GPTNeoXJapaneseAttention(nn.Module):
|
||||
# -> [bs, seq_len, hidden_size]
|
||||
return tensor
|
||||
|
||||
def _create_causal_mask(self, key_length, query_length):
|
||||
causal_mask = torch.tril(
|
||||
torch.ones((self.max_positions, self.max_positions), dtype=torch.bool).view(
|
||||
1, 1, self.max_positions, self.max_positions
|
||||
)
|
||||
)
|
||||
return causal_mask[:, :, key_length - query_length : key_length, :key_length]
|
||||
|
||||
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
|
||||
# q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
|
||||
# compute causal mask from causal mask buffer
|
||||
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
|
||||
key_length = key.size(-2)
|
||||
|
||||
causal_mask = self._create_causal_mask(key_length, query_length)
|
||||
|
||||
query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
|
||||
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
|
||||
|
||||
# [batch_size * num_heads, q_length, kv_length]
|
||||
attn_scores = torch.zeros(
|
||||
batch_size * num_attention_heads,
|
||||
query_length,
|
||||
@ -196,27 +260,20 @@ class GPTNeoXJapaneseAttention(nn.Module):
|
||||
dtype=query.dtype,
|
||||
device=key.device,
|
||||
)
|
||||
attn_scores = torch.baddbmm(
|
||||
attention_scores = torch.baddbmm(
|
||||
attn_scores,
|
||||
query,
|
||||
key.transpose(1, 2),
|
||||
beta=1.0,
|
||||
alpha=(torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor),
|
||||
alpha=1.0 / self.norm_factor,
|
||||
)
|
||||
attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
|
||||
|
||||
mask_value = torch.finfo(attn_scores.dtype).min
|
||||
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
||||
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
||||
mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device)
|
||||
causal_mask = causal_mask.to(attn_scores.device)
|
||||
attn_scores = torch.where(causal_mask, attn_scores, mask_value)
|
||||
attention_scores = attention_scores.view(batch_size, num_attention_heads, query_length, -1)
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attention_scores = attention_scores + causal_mask
|
||||
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask
|
||||
attn_scores = attn_scores + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_scores, dim=-1)
|
||||
attn_weights = nn.functional.softmax(attention_scores, dim=-1)
|
||||
attn_weights = self.attention_dropout(attn_weights)
|
||||
attn_weights = attn_weights.to(value.dtype)
|
||||
|
||||
@ -228,42 +285,92 @@ class GPTNeoXJapaneseAttention(nn.Module):
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding with GPTNeoXRotaryEmbedding->RotaryEmbedding
|
||||
class RotaryEmbedding(nn.Module):
|
||||
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding.__init__
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||
# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding with GPTNeoX->GPTNeoXJapanese
|
||||
class GPTNeoXJapaneseRotaryEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim=None,
|
||||
max_position_embeddings=2048,
|
||||
base=10000,
|
||||
device=None,
|
||||
scaling_factor=1.0,
|
||||
rope_type="default",
|
||||
config: Optional[GPTNeoXJapaneseConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
# TODO (joao): remove the `if` below, only used for BC
|
||||
self.rope_kwargs = {}
|
||||
if config is None:
|
||||
logger.warning_once(
|
||||
"`GPTNeoXJapaneseRotaryEmbedding` can now be fully parameterized by passing the model config through the "
|
||||
"`config` argument. All other arguments will be removed in v4.46"
|
||||
)
|
||||
self.rope_kwargs = {
|
||||
"rope_type": rope_type,
|
||||
"factor": scaling_factor,
|
||||
"dim": dim,
|
||||
"base": base,
|
||||
"max_position_embeddings": max_position_embeddings,
|
||||
}
|
||||
self.rope_type = rope_type
|
||||
self.max_seq_len_cached = max_position_embeddings
|
||||
self.original_max_seq_len = max_position_embeddings
|
||||
else:
|
||||
# BC: "rope_type" was originally "type"
|
||||
if config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
# Build here to make `torch.jit.trace` work.
|
||||
self._set_cos_sin_cache(
|
||||
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
|
||||
)
|
||||
def _dynamic_frequency_update(self, position_ids, device):
|
||||
"""
|
||||
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
||||
1 - growing beyond the cached sequence length (allow scaling)
|
||||
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
||||
"""
|
||||
seq_len = torch.max(position_ids) + 1
|
||||
if seq_len > self.max_seq_len_cached: # growth
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
||||
)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||
self.max_seq_len_cached = seq_len
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||
self.max_seq_len_cached = self.original_max_seq_len
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos(), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin(), persistent=False)
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids):
|
||||
if "dynamic" in self.rope_type:
|
||||
self._dynamic_frequency_update(position_ids, device=x.device)
|
||||
|
||||
def forward(self, x, seq_len=None):
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
if seq_len > self.max_seq_len_cached:
|
||||
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
||||
# Core RoPE block
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
||||
return (
|
||||
self.cos_cached[:seq_len],
|
||||
self.sin_cached[:seq_len],
|
||||
)
|
||||
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
||||
cos = cos * self.attention_scaling
|
||||
sin = sin * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
@ -273,9 +380,29 @@ def rotate_half(x):
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
|
||||
cos = cos[..., offset : q.shape[-2] + offset, :]
|
||||
sin = sin[..., offset : q.shape[-2] + offset, :]
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`, *optional*):
|
||||
Deprecated and unused.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
@ -325,18 +452,23 @@ class GPTNeoXJapaneseLayer(nn.Module):
|
||||
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
# activate bias only last layer
|
||||
self.attention = GPTNeoXJapaneseAttention(config=config, use_bias=layer_number == config.num_hidden_layers - 1)
|
||||
self.attention = GPTNeoXJapaneseAttention(
|
||||
config=config, use_bias=layer_number == config.num_hidden_layers - 1, layer_idx=layer_number
|
||||
)
|
||||
self.mlp = GPTNeoXJapaneseMLP(config)
|
||||
self.hidden_dropout = config.hidden_dropout
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
use_cache=False,
|
||||
layer_past=None,
|
||||
output_attentions=False,
|
||||
hidden_states: Optional[torch.FloatTensor],
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = False,
|
||||
layer_past: Optional[Cache] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
):
|
||||
residual = hidden_states
|
||||
ln_out = self.input_layernorm(hidden_states)
|
||||
@ -347,6 +479,9 @@ class GPTNeoXJapaneseLayer(nn.Module):
|
||||
head_mask=head_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
position_ids=position_ids,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
attn_output = attention_layer_outputs[0] # output_attn: a, present, (attentions)
|
||||
outputs = attention_layer_outputs[1:]
|
||||
@ -419,6 +554,26 @@ GPT_NEOX_JAPANESE_INPUTS_DOCSTRING = r"""
|
||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
||||
is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
|
||||
model's internal embedding lookup matrix.
|
||||
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
||||
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
||||
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
||||
|
||||
Two formats are allowed:
|
||||
- a [`~cache_utils.Cache`] instance;
|
||||
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
||||
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
||||
cache format.
|
||||
|
||||
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
||||
legacy cache format will be returned.
|
||||
|
||||
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
||||
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
||||
of shape `(batch_size, sequence_length)`.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
@ -427,6 +582,10 @@ GPT_NEOX_JAPANESE_INPUTS_DOCSTRING = r"""
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
||||
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
||||
the complete sequence length.
|
||||
"""
|
||||
|
||||
|
||||
@ -444,6 +603,7 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
|
||||
[GPTNeoXJapaneseLayer(config=config, layer_number=i) for i in range(config.num_hidden_layers)]
|
||||
)
|
||||
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.rotary_emb = GPTNeoXJapaneseRotaryEmbedding(config=config)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
@ -460,24 +620,17 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
r"""
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
||||
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
||||
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
@ -502,40 +655,35 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
||||
input_shape = input_ids.size()
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError(
|
||||
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
batch_size, seq_length = input_shape
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_in(input_ids)
|
||||
|
||||
if past_key_values is None:
|
||||
past_key_values = tuple([None] * self.config.num_hidden_layers)
|
||||
use_legacy_cache = False
|
||||
if use_cache and not isinstance(past_key_values, Cache):
|
||||
use_legacy_cache = True
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
if not self.training:
|
||||
logger.warning_once(
|
||||
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
|
||||
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
|
||||
)
|
||||
|
||||
# Attention mask.
|
||||
if attention_mask is not None:
|
||||
if not batch_size > 0:
|
||||
raise ValueError("batch_size has to be defined and > 0")
|
||||
attention_mask = attention_mask.view(batch_size, -1)
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
attention_mask = attention_mask[:, None, None, :]
|
||||
seq_length = inputs_embeds.shape[1]
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device)
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
)
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
@ -543,29 +691,32 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_in(input_ids)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
presents = () if use_cache else None
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
next_decoder_cache = None
|
||||
all_attentions = () if output_attentions else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)):
|
||||
for i, layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask[i],
|
||||
layer_past=layer_past,
|
||||
layer_past=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
hidden_states = outputs[0]
|
||||
if use_cache is True:
|
||||
presents = presents + (outputs[1],)
|
||||
next_decoder_cache = outputs[1]
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (outputs[2 if use_cache else 1],)
|
||||
|
||||
@ -574,16 +725,87 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
next_cache = None
|
||||
if use_cache:
|
||||
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attentions] if v is not None)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=presents,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype, device = input_tensor.dtype, input_tensor.device
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_static_cache:
|
||||
target_length = past_key_values.get_max_length()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
min_dtype=min_dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type == "cuda"
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""GPTNeoXJapanese Model with a `language modeling` head on top for Classifier Model fine-tuning.""",
|
||||
@ -614,35 +836,22 @@ class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel):
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors are
|
||||
only required when the model is used as a decoder in a Sequence to Sequence model.
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
|
||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
||||
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
||||
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
||||
`[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
|
||||
ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
|
||||
Returns:
|
||||
|
||||
@ -668,6 +877,7 @@ class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel):
|
||||
outputs = self.gpt_neox_japanese(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values=past_key_values,
|
||||
@ -675,6 +885,7 @@ class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
@ -703,18 +914,76 @@ class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel):
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
|
||||
input_shape = input_ids.shape
|
||||
# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM.prepare_inputs_for_generation
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
position_ids=None,
|
||||
use_cache=True,
|
||||
**kwargs,
|
||||
):
|
||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
if past_key_values is not None:
|
||||
if inputs_embeds is not None: # Exception 1
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||
input_ids = input_ids[:, cache_position]
|
||||
|
||||
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.new_ones(input_shape)
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# cut decoder_input_ids if past is used
|
||||
if past_key_values and past_key_values[0] is not None:
|
||||
input_ids = input_ids[:, -1:]
|
||||
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
|
||||
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
|
||||
|
||||
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
|
||||
else:
|
||||
# The clone here is for the same reason as for `position_ids`.
|
||||
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
|
||||
|
||||
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
|
||||
if model_inputs["inputs_embeds"] is not None:
|
||||
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
|
||||
device = model_inputs["inputs_embeds"].device
|
||||
else:
|
||||
batch_size, sequence_length = model_inputs["input_ids"].shape
|
||||
device = model_inputs["input_ids"].device
|
||||
|
||||
dtype = self.embed_out.weight.dtype
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
|
||||
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=past_key_values.get_max_length(),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
min_dtype=min_dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": use_cache,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
def _reorder_cache(self, past_key_values, beam_idx):
|
||||
reordered_past = ()
|
||||
|
@ -1018,6 +1018,7 @@ class IdeficsPreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"]
|
||||
_supports_sdpa = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of Idefics isn't meant for training from scratch - only
|
||||
|
@ -149,7 +149,7 @@ class LlamaRotaryEmbedding(nn.Module):
|
||||
if config is None:
|
||||
logger.warning_once(
|
||||
"`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the "
|
||||
"`config` argument. All other arguments will be removed in v4.45"
|
||||
"`config` argument. All other arguments will be removed in v4.46"
|
||||
)
|
||||
self.rope_kwargs = {
|
||||
"rope_type": rope_type,
|
||||
@ -224,7 +224,7 @@ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
logger.warning_once(
|
||||
"`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use "
|
||||
"`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
|
||||
"`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
|
||||
)
|
||||
kwargs["rope_type"] = "linear"
|
||||
@ -236,7 +236,7 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
logger.warning_once(
|
||||
"`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use "
|
||||
"`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
|
||||
"`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
|
||||
"__init__)."
|
||||
)
|
||||
@ -353,7 +353,7 @@ class LlamaAttention(nn.Module):
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
|
||||
|
||||
# TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers)
|
||||
# TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
|
||||
self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
|
||||
|
||||
def forward(
|
||||
@ -365,7 +365,7 @@ class LlamaAttention(nn.Module):
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
@ -400,7 +400,7 @@ class LlamaAttention(nn.Module):
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
@ -473,7 +473,7 @@ class LlamaFlashAttention2(LlamaAttention):
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if isinstance(past_key_value, StaticCache):
|
||||
raise ValueError(
|
||||
@ -500,7 +500,7 @@ class LlamaFlashAttention2(LlamaAttention):
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
@ -586,7 +586,7 @@ class LlamaSdpaAttention(LlamaAttention):
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions:
|
||||
@ -620,7 +620,7 @@ class LlamaSdpaAttention(LlamaAttention):
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
@ -695,7 +695,7 @@ class LlamaDecoderLayer(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
|
@ -871,7 +871,7 @@ class MistralModel(MistralPreTrainedModel):
|
||||
# to infer the attention mask.
|
||||
|
||||
# cache_position must be valid here no matter which cache we use
|
||||
past_seen_tokens = cache_position[0] if past_key_values is not None else 0
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
|
||||
|
||||
|
@ -848,7 +848,8 @@ MIXTRAL_START_DOCSTRING = r"""
|
||||
"The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
|
||||
MIXTRAL_START_DOCSTRING,
|
||||
)
|
||||
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel with Qwen2->Mixtral
|
||||
# copied from transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel with Qwen2->Mixtral
|
||||
# TODO (Raushan): bring back copied after compile compatibility
|
||||
class MixtralPreTrainedModel(PreTrainedModel):
|
||||
config_class = MixtralConfig
|
||||
base_model_prefix = "model"
|
||||
|
@ -589,7 +589,7 @@ class NemotronDecoderLayer(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
|
@ -222,7 +222,7 @@ class OlmoeRotaryEmbedding(nn.Module):
|
||||
if config is None:
|
||||
logger.warning_once(
|
||||
"`OlmoeRotaryEmbedding` can now be fully parameterized by passing the model config through the "
|
||||
"`config` argument. All other arguments will be removed in v4.45"
|
||||
"`config` argument. All other arguments will be removed in v4.46"
|
||||
)
|
||||
self.rope_kwargs = {
|
||||
"rope_type": rope_type,
|
||||
|
@ -15,6 +15,7 @@
|
||||
"""Persimmon model configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
@ -60,13 +61,42 @@ class PersimmonConfig(PretrainedConfig):
|
||||
rope_theta (`float`, *optional*, defaults to 25000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
|
||||
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
|
||||
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
||||
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
|
||||
these scaling strategies behave:
|
||||
https://www.reddit.com/r/LocalPersimmon/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This
|
||||
is an experimental feature, subject to breaking API changes in future versions.
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
qk_layernorm (`bool`, *optional*, default to `True`):
|
||||
Whether or not to normalize the Queries and Keys after projecting the hidden states
|
||||
hidden_dropout (`float`, *optional*, default to 0.0):
|
||||
@ -128,7 +158,11 @@ class PersimmonConfig(PretrainedConfig):
|
||||
self.hidden_dropout = hidden_dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.partial_rotary_factor = partial_rotary_factor
|
||||
self._rope_scaling_validation()
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
@ -137,23 +171,3 @@ class PersimmonConfig(PretrainedConfig):
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _rope_scaling_validation(self):
|
||||
"""
|
||||
Validate the `rope_scaling` configuration.
|
||||
"""
|
||||
if self.rope_scaling is None:
|
||||
return
|
||||
|
||||
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
||||
raise ValueError(
|
||||
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
|
||||
)
|
||||
rope_scaling_type = self.rope_scaling.get("type", None)
|
||||
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
||||
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
|
||||
raise ValueError(
|
||||
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
||||
)
|
||||
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
||||
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
|
||||
|
@ -36,6 +36,7 @@ from ...modeling_outputs import (
|
||||
SequenceClassifierOutputWithPast,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
||||
from .configuration_persimmon import PersimmonConfig
|
||||
@ -100,88 +101,119 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
return causal_mask
|
||||
|
||||
|
||||
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Persimmon
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Persimmon
|
||||
class PersimmonRotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||
def __init__(
|
||||
self,
|
||||
dim=None,
|
||||
max_position_embeddings=2048,
|
||||
base=10000,
|
||||
device=None,
|
||||
scaling_factor=1.0,
|
||||
rope_type="default",
|
||||
config: Optional[PersimmonConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
# TODO (joao): remove the `if` below, only used for BC
|
||||
self.rope_kwargs = {}
|
||||
if config is None:
|
||||
logger.warning_once(
|
||||
"`PersimmonRotaryEmbedding` can now be fully parameterized by passing the model config through the "
|
||||
"`config` argument. All other arguments will be removed in v4.46"
|
||||
)
|
||||
self.rope_kwargs = {
|
||||
"rope_type": rope_type,
|
||||
"factor": scaling_factor,
|
||||
"dim": dim,
|
||||
"base": base,
|
||||
"max_position_embeddings": max_position_embeddings,
|
||||
}
|
||||
self.rope_type = rope_type
|
||||
self.max_seq_len_cached = max_position_embeddings
|
||||
self.original_max_seq_len = max_position_embeddings
|
||||
else:
|
||||
# BC: "rope_type" was originally "type"
|
||||
if config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
# Build here to make `torch.jit.trace` work.
|
||||
self._set_cos_sin_cache(
|
||||
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
|
||||
)
|
||||
def _dynamic_frequency_update(self, position_ids, device):
|
||||
"""
|
||||
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
||||
1 - growing beyond the cached sequence length (allow scaling)
|
||||
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
||||
"""
|
||||
seq_len = torch.max(position_ids) + 1
|
||||
if seq_len > self.max_seq_len_cached: # growth
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
||||
)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||
self.max_seq_len_cached = seq_len
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||
self.max_seq_len_cached = self.original_max_seq_len
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids):
|
||||
if "dynamic" in self.rope_type:
|
||||
self._dynamic_frequency_update(position_ids, device=x.device)
|
||||
|
||||
def forward(self, x, seq_len=None):
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
if seq_len > self.max_seq_len_cached:
|
||||
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
||||
# Core RoPE block
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
||||
return (
|
||||
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
||||
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
||||
)
|
||||
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
||||
cos = cos * self.attention_scaling
|
||||
sin = sin * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->Persimmon
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Persimmon
|
||||
class PersimmonLinearScalingRotaryEmbedding(PersimmonRotaryEmbedding):
|
||||
"""PersimmonRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
||||
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(dim, max_position_embeddings, base, device)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
||||
t = t / self.scaling_factor
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||
def __init__(self, *args, **kwargs):
|
||||
logger.warning_once(
|
||||
"`PersimmonLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
|
||||
"`PersimmonRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
|
||||
)
|
||||
kwargs["rope_type"] = "linear"
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->Persimmon
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Persimmon
|
||||
class PersimmonDynamicNTKScalingRotaryEmbedding(PersimmonRotaryEmbedding):
|
||||
"""PersimmonRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
||||
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(dim, max_position_embeddings, base, device)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
|
||||
if seq_len > self.max_position_embeddings:
|
||||
base = self.base * (
|
||||
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
|
||||
) ** (self.dim / (self.dim - 2))
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||
def __init__(self, *args, **kwargs):
|
||||
logger.warning_once(
|
||||
"`PersimmonDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
|
||||
"`PersimmonRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
|
||||
"__init__)."
|
||||
)
|
||||
kwargs["rope_type"] = "dynamic"
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
||||
@ -192,8 +224,8 @@ def rotate_half(x):
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
@ -201,9 +233,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`):
|
||||
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
||||
used to pass offsetted position ids when working with a KV-cache.
|
||||
position_ids (`torch.Tensor`, *optional*):
|
||||
Deprecated and unused.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
@ -214,8 +245,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
|
||||
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
@ -253,9 +284,8 @@ class PersimmonAttention(nn.Module):
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
self.partial_rotary_factor = config.partial_rotary_factor
|
||||
self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor)
|
||||
self.is_causal = True
|
||||
|
||||
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||
@ -275,34 +305,7 @@ class PersimmonAttention(nn.Module):
|
||||
config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
|
||||
)
|
||||
self.attention_dropout = nn.Dropout(config.attention_dropout)
|
||||
self._init_rope()
|
||||
|
||||
def _init_rope(self):
|
||||
if self.config.rope_scaling is None:
|
||||
self.rotary_emb = PersimmonRotaryEmbedding(
|
||||
int(self.partial_rotary_factor * self.head_dim),
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
else:
|
||||
scaling_type = self.config.rope_scaling["type"]
|
||||
scaling_factor = self.config.rope_scaling["factor"]
|
||||
if scaling_type == "linear":
|
||||
self.rotary_emb = PersimmonLinearScalingRotaryEmbedding(
|
||||
int(self.partial_rotary_factor * self.head_dim),
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
scaling_factor=scaling_factor,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
elif scaling_type == "dynamic":
|
||||
self.rotary_emb = PersimmonDynamicNTKScalingRotaryEmbedding(
|
||||
int(self.partial_rotary_factor * self.head_dim),
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
scaling_factor=scaling_factor,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
||||
self.rotary_emb = PersimmonRotaryEmbedding(config=self.config)
|
||||
|
||||
def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
@ -329,6 +332,7 @@ class PersimmonAttention(nn.Module):
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
@ -347,28 +351,28 @@ class PersimmonAttention(nn.Module):
|
||||
value_states = value_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
|
||||
# Partial rotary embedding
|
||||
query_rot, query_pass = (
|
||||
query_states[..., : self.rotary_emb.dim],
|
||||
query_states[..., self.rotary_emb.dim :],
|
||||
query_states[..., : self.rotary_ndims],
|
||||
query_states[..., self.rotary_ndims :],
|
||||
)
|
||||
key_rot, key_pass = (
|
||||
key_states[..., : self.rotary_emb.dim],
|
||||
key_states[..., self.rotary_emb.dim :],
|
||||
key_states[..., : self.rotary_ndims],
|
||||
key_states[..., self.rotary_ndims :],
|
||||
)
|
||||
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
|
||||
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
|
||||
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
|
||||
|
||||
# [batch_size, seq_length, num_heads, head_dim]
|
||||
query_states = torch.cat((query_rot, query_pass), dim=-1)
|
||||
@ -379,19 +383,13 @@ class PersimmonAttention(nn.Module):
|
||||
cache_kwargs = {
|
||||
"sin": sin,
|
||||
"cos": cos,
|
||||
"partial_rotation_size": self.rotary_emb.dim,
|
||||
"partial_rotation_size": self.rotary_ndims,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
@ -438,6 +436,7 @@ class PersimmonDecoderLayer(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
@ -447,7 +446,6 @@ class PersimmonDecoderLayer(nn.Module):
|
||||
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
|
||||
`[0, config.n_positions - 1]`.
|
||||
|
||||
[What are position IDs?](../glossary#position-ids)
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*):
|
||||
cached past key and value projection states
|
||||
@ -457,6 +455,11 @@ class PersimmonDecoderLayer(nn.Module):
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence
|
||||
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
|
||||
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
||||
with `head_dim` being the embedding dimension of each attention head.
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
@ -472,6 +475,7 @@ class PersimmonDecoderLayer(nn.Module):
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
@ -522,6 +526,8 @@ class PersimmonPreTrainedModel(PreTrainedModel):
|
||||
_no_split_modules = ["PersimmonDecoderLayer"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_cache_class = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
@ -633,6 +639,8 @@ class PersimmonModel(PersimmonPreTrainedModel):
|
||||
)
|
||||
self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
self.rotary_emb = PersimmonRotaryEmbedding(config=config)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
@ -703,6 +711,9 @@ class PersimmonModel(PersimmonPreTrainedModel):
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
@ -722,6 +733,7 @@ class PersimmonModel(PersimmonPreTrainedModel):
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
position_embeddings,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
@ -732,6 +744,7 @@ class PersimmonModel(PersimmonPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
@ -16,6 +16,7 @@
|
||||
"""Phi model configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
@ -75,13 +76,42 @@ class PhiConfig(PretrainedConfig):
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
|
||||
strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format
|
||||
is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
||||
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
|
||||
these scaling strategies behave:
|
||||
https://www.reddit.com/r/LocalPersimmon/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This
|
||||
is an experimental feature, subject to breaking API changes in future versions.
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
partial_rotary_factor (`float`, *optional*, defaults to 0.5):
|
||||
Percentage of the query and keys which will have rotary embedding.
|
||||
qk_layernorm (`bool`, *optional*, defaults to `False`):
|
||||
@ -156,7 +186,11 @@ class PhiConfig(PretrainedConfig):
|
||||
self.rope_scaling = rope_scaling
|
||||
self.partial_rotary_factor = partial_rotary_factor
|
||||
self.qk_layernorm = qk_layernorm
|
||||
self._rope_scaling_validation()
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
|
||||
super().__init__(
|
||||
bos_token_id=bos_token_id,
|
||||
@ -164,23 +198,3 @@ class PhiConfig(PretrainedConfig):
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _rope_scaling_validation(self):
|
||||
"""
|
||||
Validate the `rope_scaling` configuration.
|
||||
"""
|
||||
if self.rope_scaling is None:
|
||||
return
|
||||
|
||||
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
||||
raise ValueError(
|
||||
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
|
||||
)
|
||||
rope_scaling_type = self.rope_scaling.get("type", None)
|
||||
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
||||
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
|
||||
raise ValueError(
|
||||
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
||||
)
|
||||
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
||||
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
|
||||
|
@ -33,6 +33,7 @@ from ...modeling_outputs import (
|
||||
SequenceClassifierOutputWithPast,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
add_code_sample_docstrings,
|
||||
@ -112,88 +113,119 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
return causal_mask
|
||||
|
||||
|
||||
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Phi
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Phi
|
||||
class PhiRotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||
def __init__(
|
||||
self,
|
||||
dim=None,
|
||||
max_position_embeddings=2048,
|
||||
base=10000,
|
||||
device=None,
|
||||
scaling_factor=1.0,
|
||||
rope_type="default",
|
||||
config: Optional[PhiConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
# TODO (joao): remove the `if` below, only used for BC
|
||||
self.rope_kwargs = {}
|
||||
if config is None:
|
||||
logger.warning_once(
|
||||
"`PhiRotaryEmbedding` can now be fully parameterized by passing the model config through the "
|
||||
"`config` argument. All other arguments will be removed in v4.46"
|
||||
)
|
||||
self.rope_kwargs = {
|
||||
"rope_type": rope_type,
|
||||
"factor": scaling_factor,
|
||||
"dim": dim,
|
||||
"base": base,
|
||||
"max_position_embeddings": max_position_embeddings,
|
||||
}
|
||||
self.rope_type = rope_type
|
||||
self.max_seq_len_cached = max_position_embeddings
|
||||
self.original_max_seq_len = max_position_embeddings
|
||||
else:
|
||||
# BC: "rope_type" was originally "type"
|
||||
if config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
# Build here to make `torch.jit.trace` work.
|
||||
self._set_cos_sin_cache(
|
||||
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
|
||||
)
|
||||
def _dynamic_frequency_update(self, position_ids, device):
|
||||
"""
|
||||
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
||||
1 - growing beyond the cached sequence length (allow scaling)
|
||||
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
||||
"""
|
||||
seq_len = torch.max(position_ids) + 1
|
||||
if seq_len > self.max_seq_len_cached: # growth
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
||||
)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||
self.max_seq_len_cached = seq_len
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||
self.max_seq_len_cached = self.original_max_seq_len
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids):
|
||||
if "dynamic" in self.rope_type:
|
||||
self._dynamic_frequency_update(position_ids, device=x.device)
|
||||
|
||||
def forward(self, x, seq_len=None):
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
if seq_len > self.max_seq_len_cached:
|
||||
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
||||
# Core RoPE block
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
||||
return (
|
||||
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
||||
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
||||
)
|
||||
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
||||
cos = cos * self.attention_scaling
|
||||
sin = sin * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->Phi
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Phi
|
||||
class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
|
||||
"""PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
||||
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(dim, max_position_embeddings, base, device)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
||||
t = t / self.scaling_factor
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||
def __init__(self, *args, **kwargs):
|
||||
logger.warning_once(
|
||||
"`PhiLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
|
||||
"`PhiRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
|
||||
)
|
||||
kwargs["rope_type"] = "linear"
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->Phi
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Phi
|
||||
class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
|
||||
"""PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
||||
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(dim, max_position_embeddings, base, device)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
|
||||
if seq_len > self.max_position_embeddings:
|
||||
base = self.base * (
|
||||
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
|
||||
) ** (self.dim / (self.dim - 2))
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||
def __init__(self, *args, **kwargs):
|
||||
logger.warning_once(
|
||||
"`PhiDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
|
||||
"`PhiRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
|
||||
"__init__)."
|
||||
)
|
||||
kwargs["rope_type"] = "dynamic"
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
||||
@ -204,8 +236,8 @@ def rotate_half(x):
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
@ -213,9 +245,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`):
|
||||
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
||||
used to pass offsetted position ids when working with a KV-cache.
|
||||
position_ids (`torch.Tensor`, *optional*):
|
||||
Deprecated and unused.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
@ -226,8 +257,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
|
||||
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
@ -282,9 +313,8 @@ class PhiAttention(nn.Module):
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
self.partial_rotary_factor = config.partial_rotary_factor
|
||||
self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor)
|
||||
self.is_causal = True
|
||||
|
||||
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||
@ -307,34 +337,7 @@ class PhiAttention(nn.Module):
|
||||
config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
|
||||
)
|
||||
|
||||
self._init_rope()
|
||||
|
||||
def _init_rope(self):
|
||||
if self.config.rope_scaling is None:
|
||||
self.rotary_emb = PhiRotaryEmbedding(
|
||||
int(self.partial_rotary_factor * self.head_dim),
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
else:
|
||||
scaling_type = self.config.rope_scaling["type"]
|
||||
scaling_factor = self.config.rope_scaling["factor"]
|
||||
if scaling_type == "linear":
|
||||
self.rotary_emb = PhiLinearScalingRotaryEmbedding(
|
||||
int(self.partial_rotary_factor * self.head_dim),
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
scaling_factor=scaling_factor,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
elif scaling_type == "dynamic":
|
||||
self.rotary_emb = PhiDynamicNTKScalingRotaryEmbedding(
|
||||
int(self.partial_rotary_factor * self.head_dim),
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
scaling_factor=scaling_factor,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
||||
self.rotary_emb = PhiRotaryEmbedding(config=self.config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -345,6 +348,7 @@ class PhiAttention(nn.Module):
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
@ -360,28 +364,28 @@ class PhiAttention(nn.Module):
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
|
||||
# Partial rotary embedding
|
||||
query_rot, query_pass = (
|
||||
query_states[..., : self.rotary_emb.dim],
|
||||
query_states[..., self.rotary_emb.dim :],
|
||||
query_states[..., : self.rotary_ndims],
|
||||
query_states[..., self.rotary_ndims :],
|
||||
)
|
||||
key_rot, key_pass = (
|
||||
key_states[..., : self.rotary_emb.dim],
|
||||
key_states[..., self.rotary_emb.dim :],
|
||||
key_states[..., : self.rotary_ndims],
|
||||
key_states[..., self.rotary_ndims :],
|
||||
)
|
||||
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
|
||||
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
|
||||
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
|
||||
|
||||
# [batch_size, seq_length, num_heads, head_dim]
|
||||
query_states = torch.cat((query_rot, query_pass), dim=-1)
|
||||
@ -391,7 +395,7 @@ class PhiAttention(nn.Module):
|
||||
cache_kwargs = {
|
||||
"sin": sin,
|
||||
"cos": cos,
|
||||
"partial_rotation_size": self.rotary_emb.dim,
|
||||
"partial_rotation_size": self.rotary_ndims,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
@ -404,12 +408,6 @@ class PhiAttention(nn.Module):
|
||||
query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
|
||||
) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights += causal_mask
|
||||
@ -462,6 +460,7 @@ class PhiFlashAttention2(PhiAttention):
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# PhiFlashAttention2 attention does not support output_attentions
|
||||
@ -485,22 +484,28 @@ class PhiFlashAttention2(PhiAttention):
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
|
||||
# Partial rotary embedding
|
||||
query_rot, query_pass = (
|
||||
query_states[..., : self.rotary_emb.dim],
|
||||
query_states[..., self.rotary_emb.dim :],
|
||||
query_states[..., : self.rotary_ndims],
|
||||
query_states[..., self.rotary_ndims :],
|
||||
)
|
||||
key_rot, key_pass = (
|
||||
key_states[..., : self.rotary_emb.dim],
|
||||
key_states[..., self.rotary_emb.dim :],
|
||||
key_states[..., : self.rotary_ndims],
|
||||
key_states[..., self.rotary_ndims :],
|
||||
)
|
||||
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
|
||||
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
|
||||
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
|
||||
|
||||
# [batch_size, seq_length, num_heads, head_dim]
|
||||
query_states = torch.cat((query_rot, query_pass), dim=-1)
|
||||
@ -510,7 +515,7 @@ class PhiFlashAttention2(PhiAttention):
|
||||
cache_kwargs = {
|
||||
"sin": sin,
|
||||
"cos": cos,
|
||||
"partial_rotation_size": self.rotary_emb.dim,
|
||||
"partial_rotation_size": self.rotary_ndims,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
@ -591,6 +596,7 @@ class PhiSdpaAttention(PhiAttention):
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
@ -623,28 +629,28 @@ class PhiSdpaAttention(PhiAttention):
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
|
||||
# Partial rotary embedding
|
||||
query_rot, query_pass = (
|
||||
query_states[..., : self.rotary_emb.dim],
|
||||
query_states[..., self.rotary_emb.dim :],
|
||||
query_states[..., : self.rotary_ndims],
|
||||
query_states[..., self.rotary_ndims :],
|
||||
)
|
||||
key_rot, key_pass = (
|
||||
key_states[..., : self.rotary_emb.dim],
|
||||
key_states[..., self.rotary_emb.dim :],
|
||||
key_states[..., : self.rotary_ndims],
|
||||
key_states[..., self.rotary_ndims :],
|
||||
)
|
||||
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
|
||||
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
|
||||
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
|
||||
|
||||
# [batch_size, seq_length, num_heads, head_dim]
|
||||
query_states = torch.cat((query_rot, query_pass), dim=-1)
|
||||
@ -654,7 +660,7 @@ class PhiSdpaAttention(PhiAttention):
|
||||
cache_kwargs = {
|
||||
"sin": sin,
|
||||
"cos": cos,
|
||||
"partial_rotation_size": self.rotary_emb.dim,
|
||||
"partial_rotation_size": self.rotary_ndims,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
@ -719,6 +725,7 @@ class PhiDecoderLayer(nn.Module):
|
||||
use_cache: Optional[bool] = False,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
@ -739,6 +746,9 @@ class PhiDecoderLayer(nn.Module):
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence
|
||||
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
|
||||
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
||||
with `head_dim` being the embedding dimension of each attention head.
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
||||
into the model
|
||||
@ -757,6 +767,7 @@ class PhiDecoderLayer(nn.Module):
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
attn_outputs = self.resid_dropout(attn_outputs)
|
||||
|
||||
@ -803,6 +814,8 @@ class PhiPreTrainedModel(PreTrainedModel):
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
_supports_quantized_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
@ -914,6 +927,7 @@ class PhiModel(PhiPreTrainedModel):
|
||||
[PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.rotary_emb = PhiRotaryEmbedding(config=config)
|
||||
|
||||
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
||||
self._use_sdpa = config._attn_implementation == "sdpa"
|
||||
@ -989,6 +1003,9 @@ class PhiModel(PhiPreTrainedModel):
|
||||
inputs_embeds = self.embed_dropout(inputs_embeds)
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
@ -1008,6 +1025,7 @@ class PhiModel(PhiPreTrainedModel):
|
||||
use_cache,
|
||||
past_key_values,
|
||||
cache_position,
|
||||
position_embeddings,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
@ -1018,6 +1036,7 @@ class PhiModel(PhiPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
@ -15,6 +15,7 @@
|
||||
"""Qwen2 model configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
@ -66,6 +67,43 @@ class Qwen2Config(PretrainedConfig):
|
||||
Whether the model's input and output word embeddings should be tied.
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
use_sliding_window (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use sliding window attention.
|
||||
sliding_window (`int`, *optional*, defaults to 4096):
|
||||
@ -106,6 +144,7 @@ class Qwen2Config(PretrainedConfig):
|
||||
use_cache=True,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000.0,
|
||||
rope_scaling=None,
|
||||
use_sliding_window=False,
|
||||
sliding_window=4096,
|
||||
max_window_layers=28,
|
||||
@ -132,7 +171,13 @@ class Qwen2Config(PretrainedConfig):
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_dropout = attention_dropout
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
|
||||
super().__init__(
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
|
@ -36,6 +36,7 @@ from ...modeling_outputs import (
|
||||
SequenceClassifierOutputWithPast,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
@ -135,41 +136,92 @@ class Qwen2RMSNorm(nn.Module):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||
|
||||
|
||||
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2
|
||||
class Qwen2RotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||
def __init__(
|
||||
self,
|
||||
dim=None,
|
||||
max_position_embeddings=2048,
|
||||
base=10000,
|
||||
device=None,
|
||||
scaling_factor=1.0,
|
||||
rope_type="default",
|
||||
config: Optional[Qwen2Config] = None,
|
||||
):
|
||||
super().__init__()
|
||||
# TODO (joao): remove the `if` below, only used for BC
|
||||
self.rope_kwargs = {}
|
||||
if config is None:
|
||||
logger.warning_once(
|
||||
"`Qwen2RotaryEmbedding` can now be fully parameterized by passing the model config through the "
|
||||
"`config` argument. All other arguments will be removed in v4.46"
|
||||
)
|
||||
self.rope_kwargs = {
|
||||
"rope_type": rope_type,
|
||||
"factor": scaling_factor,
|
||||
"dim": dim,
|
||||
"base": base,
|
||||
"max_position_embeddings": max_position_embeddings,
|
||||
}
|
||||
self.rope_type = rope_type
|
||||
self.max_seq_len_cached = max_position_embeddings
|
||||
self.original_max_seq_len = max_position_embeddings
|
||||
else:
|
||||
# BC: "rope_type" was originally "type"
|
||||
if config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
# Build here to make `torch.jit.trace` work.
|
||||
self._set_cos_sin_cache(
|
||||
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
|
||||
)
|
||||
def _dynamic_frequency_update(self, position_ids, device):
|
||||
"""
|
||||
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
||||
1 - growing beyond the cached sequence length (allow scaling)
|
||||
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
||||
"""
|
||||
seq_len = torch.max(position_ids) + 1
|
||||
if seq_len > self.max_seq_len_cached: # growth
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
||||
)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||
self.max_seq_len_cached = seq_len
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||
self.max_seq_len_cached = self.original_max_seq_len
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids):
|
||||
if "dynamic" in self.rope_type:
|
||||
self._dynamic_frequency_update(position_ids, device=x.device)
|
||||
|
||||
def forward(self, x, seq_len=None):
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
if seq_len > self.max_seq_len_cached:
|
||||
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
||||
# Core RoPE block
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
||||
return (
|
||||
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
||||
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
||||
)
|
||||
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
||||
cos = cos * self.attention_scaling
|
||||
sin = sin * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
||||
@ -180,8 +232,8 @@ def rotate_half(x):
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
@ -189,9 +241,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`):
|
||||
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
||||
used to pass offsetted position ids when working with a KV-cache.
|
||||
position_ids (`torch.Tensor`, *optional*):
|
||||
Deprecated and unused.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
@ -202,8 +253,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
|
||||
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
@ -259,7 +310,6 @@ class Qwen2Attention(nn.Module):
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
self.is_causal = True
|
||||
self.attention_dropout = config.attention_dropout
|
||||
@ -274,11 +324,7 @@ class Qwen2Attention(nn.Module):
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
||||
|
||||
self.rotary_emb = Qwen2RotaryEmbedding(
|
||||
self.head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
self.rotary_emb = Qwen2RotaryEmbedding(config=self.config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -289,6 +335,7 @@ class Qwen2Attention(nn.Module):
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
@ -300,17 +347,17 @@ class Qwen2Attention(nn.Module):
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
@ -321,13 +368,6 @@ class Qwen2Attention(nn.Module):
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
@ -381,6 +421,7 @@ class Qwen2FlashAttention2(Qwen2Attention):
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
):
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
@ -392,28 +433,22 @@ class Qwen2FlashAttention2(Qwen2Attention):
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
|
||||
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
||||
rotary_seq_len = (
|
||||
max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len
|
||||
)
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
|
||||
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# Activate slicing cache only if the config has a value `sliding_windows` attribute
|
||||
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
|
||||
kv_seq_len = key_states.shape[-2] + cache_position[0]
|
||||
if (
|
||||
getattr(self.config, "sliding_window", None) is not None
|
||||
and kv_seq_len > self.config.sliding_window
|
||||
@ -504,7 +539,6 @@ class Qwen2FlashAttention2(Qwen2Attention):
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Qwen2
|
||||
class Qwen2SdpaAttention(Qwen2Attention):
|
||||
"""
|
||||
Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
||||
@ -522,6 +556,7 @@ class Qwen2SdpaAttention(Qwen2Attention):
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
@ -548,12 +583,17 @@ class Qwen2SdpaAttention(Qwen2Attention):
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
@ -627,6 +667,7 @@ class Qwen2DecoderLayer(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
@ -643,6 +684,9 @@ class Qwen2DecoderLayer(nn.Module):
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
|
||||
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
||||
with `head_dim` being the embedding dimension of each attention head.
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
||||
into the model
|
||||
@ -661,6 +705,7 @@ class Qwen2DecoderLayer(nn.Module):
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
@ -711,6 +756,8 @@ class Qwen2PreTrainedModel(PreTrainedModel):
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_cache_class = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
@ -822,6 +869,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
|
||||
)
|
||||
self._attn_implementation = config._attn_implementation
|
||||
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.rotary_emb = Qwen2RotaryEmbedding(config=config)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
# Initialize weights and apply final processing
|
||||
@ -893,6 +941,9 @@ class Qwen2Model(Qwen2PreTrainedModel):
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
@ -912,6 +963,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
position_embeddings,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
@ -922,6 +974,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
@ -15,6 +15,7 @@
|
||||
"""Qwen2MoE model configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
@ -66,6 +67,43 @@ class Qwen2MoeConfig(PretrainedConfig):
|
||||
Whether the model's input and output word embeddings should be tied.
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
use_sliding_window (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use sliding window attention.
|
||||
sliding_window (`int`, *optional*, defaults to 4096):
|
||||
@ -127,6 +165,7 @@ class Qwen2MoeConfig(PretrainedConfig):
|
||||
use_cache=True,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000.0,
|
||||
rope_scaling=None,
|
||||
use_sliding_window=False,
|
||||
sliding_window=4096,
|
||||
max_window_layers=28,
|
||||
@ -158,7 +197,13 @@ class Qwen2MoeConfig(PretrainedConfig):
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_dropout = attention_dropout
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
|
||||
# MoE arguments
|
||||
self.decoder_sparse_step = decoder_sparse_step
|
||||
|
@ -37,6 +37,7 @@ from ...modeling_outputs import (
|
||||
SequenceClassifierOutputWithPast,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
@ -211,41 +212,92 @@ class Qwen2MoeRMSNorm(nn.Module):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||
|
||||
|
||||
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2Moe
|
||||
class Qwen2MoeRotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||
def __init__(
|
||||
self,
|
||||
dim=None,
|
||||
max_position_embeddings=2048,
|
||||
base=10000,
|
||||
device=None,
|
||||
scaling_factor=1.0,
|
||||
rope_type="default",
|
||||
config: Optional[Qwen2MoeConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
# TODO (joao): remove the `if` below, only used for BC
|
||||
self.rope_kwargs = {}
|
||||
if config is None:
|
||||
logger.warning_once(
|
||||
"`Qwen2MoeRotaryEmbedding` can now be fully parameterized by passing the model config through the "
|
||||
"`config` argument. All other arguments will be removed in v4.46"
|
||||
)
|
||||
self.rope_kwargs = {
|
||||
"rope_type": rope_type,
|
||||
"factor": scaling_factor,
|
||||
"dim": dim,
|
||||
"base": base,
|
||||
"max_position_embeddings": max_position_embeddings,
|
||||
}
|
||||
self.rope_type = rope_type
|
||||
self.max_seq_len_cached = max_position_embeddings
|
||||
self.original_max_seq_len = max_position_embeddings
|
||||
else:
|
||||
# BC: "rope_type" was originally "type"
|
||||
if config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
# Build here to make `torch.jit.trace` work.
|
||||
self._set_cos_sin_cache(
|
||||
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
|
||||
)
|
||||
def _dynamic_frequency_update(self, position_ids, device):
|
||||
"""
|
||||
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
||||
1 - growing beyond the cached sequence length (allow scaling)
|
||||
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
||||
"""
|
||||
seq_len = torch.max(position_ids) + 1
|
||||
if seq_len > self.max_seq_len_cached: # growth
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
||||
)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||
self.max_seq_len_cached = seq_len
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||
self.max_seq_len_cached = self.original_max_seq_len
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids):
|
||||
if "dynamic" in self.rope_type:
|
||||
self._dynamic_frequency_update(position_ids, device=x.device)
|
||||
|
||||
def forward(self, x, seq_len=None):
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
if seq_len > self.max_seq_len_cached:
|
||||
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
||||
# Core RoPE block
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
||||
return (
|
||||
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
||||
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
||||
)
|
||||
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
||||
cos = cos * self.attention_scaling
|
||||
sin = sin * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
||||
@ -256,8 +308,8 @@ def rotate_half(x):
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
@ -265,9 +317,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`):
|
||||
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
||||
used to pass offsetted position ids when working with a KV-cache.
|
||||
position_ids (`torch.Tensor`, *optional*):
|
||||
Deprecated and unused.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
@ -278,8 +329,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
|
||||
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
@ -337,7 +388,6 @@ class Qwen2MoeAttention(nn.Module):
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
self.is_causal = True
|
||||
self.attention_dropout = config.attention_dropout
|
||||
@ -352,12 +402,9 @@ class Qwen2MoeAttention(nn.Module):
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
||||
|
||||
self.rotary_emb = Qwen2MoeRotaryEmbedding(
|
||||
self.head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
self.rotary_emb = Qwen2MoeRotaryEmbedding(config=self.config)
|
||||
|
||||
# Ignore copy
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -367,6 +414,7 @@ class Qwen2MoeAttention(nn.Module):
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
@ -378,16 +426,17 @@ class Qwen2MoeAttention(nn.Module):
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
|
||||
if past_key_value is not None:
|
||||
@ -400,12 +449,6 @@ class Qwen2MoeAttention(nn.Module):
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
@ -460,6 +503,7 @@ class Qwen2MoeFlashAttention2(Qwen2MoeAttention):
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
):
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
@ -471,28 +515,22 @@ class Qwen2MoeFlashAttention2(Qwen2MoeAttention):
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
|
||||
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
||||
rotary_seq_len = (
|
||||
max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len
|
||||
)
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
|
||||
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# Activate slicing cache only if the config has a value `sliding_windows` attribute
|
||||
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
|
||||
kv_seq_len = key_states.shape[-2] + cache_position[0]
|
||||
if (
|
||||
getattr(self.config, "sliding_window", None) is not None
|
||||
and kv_seq_len > self.config.sliding_window
|
||||
@ -583,7 +621,7 @@ class Qwen2MoeFlashAttention2(Qwen2MoeAttention):
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Qwen2Moe
|
||||
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2SdpaAttention with Qwen2->Qwen2Moe
|
||||
class Qwen2MoeSdpaAttention(Qwen2MoeAttention):
|
||||
"""
|
||||
Qwen2Moe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
||||
@ -601,6 +639,7 @@ class Qwen2MoeSdpaAttention(Qwen2MoeAttention):
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
@ -627,12 +666,17 @@ class Qwen2MoeSdpaAttention(Qwen2MoeAttention):
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
@ -770,6 +814,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
||||
output_router_logits: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
@ -789,6 +834,9 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
|
||||
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
||||
with `head_dim` being the embedding dimension of each attention head.
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
||||
into the model
|
||||
@ -807,6 +855,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
@ -980,6 +1029,7 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
|
||||
)
|
||||
self._attn_implementation = config._attn_implementation
|
||||
self.norm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.rotary_emb = Qwen2MoeRotaryEmbedding(config=config)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
# Initialize weights and apply final processing
|
||||
@ -1055,6 +1105,9 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
@ -1076,6 +1129,7 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
|
||||
output_router_logits,
|
||||
use_cache,
|
||||
cache_position,
|
||||
position_embeddings,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
@ -1087,6 +1141,7 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
|
||||
output_router_logits=output_router_logits,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
@ -18,6 +18,7 @@ import os
|
||||
from typing import Union
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
@ -128,13 +129,42 @@ class Qwen2VLConfig(PretrainedConfig):
|
||||
vision_config (`Dict`, *optional*):
|
||||
The config for the visual encoder initialization.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
|
||||
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
|
||||
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
||||
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
|
||||
these scaling strategies behave:
|
||||
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
|
||||
experimental feature, subject to breaking API changes in future versions.
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
|
||||
```python
|
||||
>>> from transformers import Qwen2VLForConditionalGeneration, Qwen2VLConfig
|
||||
@ -203,4 +233,13 @@ class Qwen2VLConfig(PretrainedConfig):
|
||||
self.attention_dropout = attention_dropout
|
||||
self.rope_scaling = rope_scaling
|
||||
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||
# and change type from 'mrope' to 'default'
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
if self.rope_scaling["type"] == "mrope":
|
||||
self.rope_scaling["type"] = "default"
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
|
||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||
|
@ -38,6 +38,7 @@ from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
ModelOutput,
|
||||
)
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
@ -102,41 +103,92 @@ class Qwen2VLCausalLMOutputWithPast(ModelOutput):
|
||||
rope_deltas: Optional[torch.LongTensor] = None
|
||||
|
||||
|
||||
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding
|
||||
class Qwen2RotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||
class Qwen2VLRotaryEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim=None,
|
||||
max_position_embeddings=2048,
|
||||
base=10000,
|
||||
device=None,
|
||||
scaling_factor=1.0,
|
||||
rope_type="default",
|
||||
config: Optional[Qwen2VLConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
# TODO (joao): remove the `if` below, only used for BC
|
||||
self.rope_kwargs = {}
|
||||
if config is None:
|
||||
logger.warning_once(
|
||||
"`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the "
|
||||
"`config` argument. All other arguments will be removed in v4.46"
|
||||
)
|
||||
self.rope_kwargs = {
|
||||
"rope_type": rope_type,
|
||||
"factor": scaling_factor,
|
||||
"dim": dim,
|
||||
"base": base,
|
||||
"max_position_embeddings": max_position_embeddings,
|
||||
}
|
||||
self.rope_type = rope_type
|
||||
self.max_seq_len_cached = max_position_embeddings
|
||||
self.original_max_seq_len = max_position_embeddings
|
||||
else:
|
||||
# BC: "rope_type" was originally "type"
|
||||
if config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
# Build here to make `torch.jit.trace` work.
|
||||
self._set_cos_sin_cache(
|
||||
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
|
||||
)
|
||||
def _dynamic_frequency_update(self, position_ids, device):
|
||||
"""
|
||||
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
||||
1 - growing beyond the cached sequence length (allow scaling)
|
||||
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
||||
"""
|
||||
seq_len = torch.max(position_ids) + 1
|
||||
if seq_len > self.max_seq_len_cached: # growth
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
||||
)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||
self.max_seq_len_cached = seq_len
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||
self.max_seq_len_cached = self.original_max_seq_len
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids):
|
||||
if "dynamic" in self.rope_type:
|
||||
self._dynamic_frequency_update(position_ids, device=x.device)
|
||||
|
||||
def forward(self, x, seq_len=None):
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
if seq_len > self.max_seq_len_cached:
|
||||
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
||||
# Core RoPE block. In contrast to other models, Qwen2_VL has different position ids for thw grids
|
||||
# So we expand the inv_freq to shape (3, ...)
|
||||
inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
|
||||
position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
|
||||
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
||||
return (
|
||||
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
||||
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
||||
)
|
||||
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
||||
cos = cos * self.attention_scaling
|
||||
sin = sin * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
||||
@ -147,7 +199,7 @@ def rotate_half(x):
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_multimodal_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section, unsqueeze_dim=1):
|
||||
def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
|
||||
|
||||
Explanation:
|
||||
@ -179,8 +231,6 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section,
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos[position_ids]
|
||||
sin = sin[position_ids]
|
||||
mrope_section = mrope_section * 2
|
||||
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
|
||||
unsqueeze_dim
|
||||
@ -525,7 +575,7 @@ class Qwen2VLAttention(nn.Module):
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
||||
|
||||
self.rotary_emb = Qwen2RotaryEmbedding(
|
||||
self.rotary_emb = Qwen2VLRotaryEmbedding(
|
||||
self.head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
base=self.rope_theta,
|
||||
@ -540,6 +590,7 @@ class Qwen2VLAttention(nn.Module):
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
@ -553,16 +604,20 @@ class Qwen2VLAttention(nn.Module):
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
kv_seq_len += cache_position[0] + 1
|
||||
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_multimodal_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids, self.rope_scaling["mrope_section"]
|
||||
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
|
||||
)
|
||||
|
||||
if past_key_value is not None:
|
||||
@ -627,6 +682,7 @@ class Qwen2VLFlashAttention2(Qwen2VLAttention):
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
):
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
@ -649,14 +705,19 @@ class Qwen2VLFlashAttention2(Qwen2VLAttention):
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
|
||||
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
||||
rotary_seq_len = (
|
||||
max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len
|
||||
)
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
|
||||
query_states, key_states = apply_multimodal_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids, self.rope_scaling["mrope_section"]
|
||||
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
|
||||
)
|
||||
|
||||
if past_key_value is not None:
|
||||
@ -768,6 +829,7 @@ class Qwen2VLSdpaAttention(Qwen2VLAttention):
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
@ -797,9 +859,18 @@ class Qwen2VLSdpaAttention(Qwen2VLAttention):
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_multimodal_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids, self.rope_scaling["mrope_section"]
|
||||
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
|
||||
)
|
||||
|
||||
if past_key_value is not None:
|
||||
@ -874,6 +945,7 @@ class Qwen2VLDecoderLayer(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
@ -890,6 +962,9 @@ class Qwen2VLDecoderLayer(nn.Module):
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
|
||||
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
||||
with `head_dim` being the embedding dimension of each attention head.
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
||||
into the model
|
||||
@ -908,6 +983,7 @@ class Qwen2VLDecoderLayer(nn.Module):
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
@ -1061,6 +1137,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
||||
)
|
||||
self._attn_implementation = config._attn_implementation
|
||||
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.rotary_emb = Qwen2VLRotaryEmbedding(config=config)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
# Initialize weights and apply final processing
|
||||
@ -1123,6 +1200,9 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
@ -1142,6 +1222,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
position_embeddings,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
@ -1152,6 +1233,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
@ -15,6 +15,7 @@
|
||||
"""StableLM model configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
@ -71,13 +72,42 @@ class StableLmConfig(PretrainedConfig):
|
||||
rope_theta (`float`, *optional*, defaults to `10000.0`):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
|
||||
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
|
||||
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
||||
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
|
||||
these scaling strategies behave:
|
||||
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This
|
||||
is an experimental feature, subject to breaking API changes in future versions.
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
use_qkv_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not the model should use bias for qkv layers.
|
||||
qk_layernorm (`bool`, *optional*, defaults to `False`):
|
||||
@ -155,7 +185,11 @@ class StableLmConfig(PretrainedConfig):
|
||||
self.hidden_dropout = hidden_dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.partial_rotary_factor = partial_rotary_factor
|
||||
self._rope_scaling_validation()
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
|
||||
super().__init__(
|
||||
bos_token_id=bos_token_id,
|
||||
@ -163,23 +197,3 @@ class StableLmConfig(PretrainedConfig):
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _rope_scaling_validation(self):
|
||||
"""
|
||||
Validate the `rope_scaling` configuration.
|
||||
"""
|
||||
if self.rope_scaling is None:
|
||||
return
|
||||
|
||||
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
||||
raise ValueError(
|
||||
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
|
||||
)
|
||||
rope_scaling_type = self.rope_scaling.get("type", None)
|
||||
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
||||
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
|
||||
raise ValueError(
|
||||
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
||||
)
|
||||
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
||||
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
|
||||
|
@ -36,6 +36,7 @@ from ...modeling_outputs import (
|
||||
SequenceClassifierOutputWithPast,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
@ -111,88 +112,119 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
return causal_mask
|
||||
|
||||
|
||||
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->StableLm
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->StableLm
|
||||
class StableLmRotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||
def __init__(
|
||||
self,
|
||||
dim=None,
|
||||
max_position_embeddings=2048,
|
||||
base=10000,
|
||||
device=None,
|
||||
scaling_factor=1.0,
|
||||
rope_type="default",
|
||||
config: Optional[StableLmConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
# TODO (joao): remove the `if` below, only used for BC
|
||||
self.rope_kwargs = {}
|
||||
if config is None:
|
||||
logger.warning_once(
|
||||
"`StableLmRotaryEmbedding` can now be fully parameterized by passing the model config through the "
|
||||
"`config` argument. All other arguments will be removed in v4.46"
|
||||
)
|
||||
self.rope_kwargs = {
|
||||
"rope_type": rope_type,
|
||||
"factor": scaling_factor,
|
||||
"dim": dim,
|
||||
"base": base,
|
||||
"max_position_embeddings": max_position_embeddings,
|
||||
}
|
||||
self.rope_type = rope_type
|
||||
self.max_seq_len_cached = max_position_embeddings
|
||||
self.original_max_seq_len = max_position_embeddings
|
||||
else:
|
||||
# BC: "rope_type" was originally "type"
|
||||
if config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
# Build here to make `torch.jit.trace` work.
|
||||
self._set_cos_sin_cache(
|
||||
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
|
||||
)
|
||||
def _dynamic_frequency_update(self, position_ids, device):
|
||||
"""
|
||||
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
||||
1 - growing beyond the cached sequence length (allow scaling)
|
||||
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
||||
"""
|
||||
seq_len = torch.max(position_ids) + 1
|
||||
if seq_len > self.max_seq_len_cached: # growth
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
||||
)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||
self.max_seq_len_cached = seq_len
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||
self.max_seq_len_cached = self.original_max_seq_len
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids):
|
||||
if "dynamic" in self.rope_type:
|
||||
self._dynamic_frequency_update(position_ids, device=x.device)
|
||||
|
||||
def forward(self, x, seq_len=None):
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
if seq_len > self.max_seq_len_cached:
|
||||
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
||||
# Core RoPE block
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
||||
return (
|
||||
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
||||
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
||||
)
|
||||
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
||||
cos = cos * self.attention_scaling
|
||||
sin = sin * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->StableLm
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->StableLm
|
||||
class StableLmLinearScalingRotaryEmbedding(StableLmRotaryEmbedding):
|
||||
"""StableLmRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
||||
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(dim, max_position_embeddings, base, device)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
||||
t = t / self.scaling_factor
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||
def __init__(self, *args, **kwargs):
|
||||
logger.warning_once(
|
||||
"`StableLmLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
|
||||
"`StableLmRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
|
||||
)
|
||||
kwargs["rope_type"] = "linear"
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->StableLm
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->StableLm
|
||||
class StableLmDynamicNTKScalingRotaryEmbedding(StableLmRotaryEmbedding):
|
||||
"""StableLmRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
||||
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(dim, max_position_embeddings, base, device)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
|
||||
if seq_len > self.max_position_embeddings:
|
||||
base = self.base * (
|
||||
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
|
||||
) ** (self.dim / (self.dim - 2))
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||
def __init__(self, *args, **kwargs):
|
||||
logger.warning_once(
|
||||
"`StableLmDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
|
||||
"`StableLmRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
|
||||
"__init__)."
|
||||
)
|
||||
kwargs["rope_type"] = "dynamic"
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
||||
@ -203,8 +235,8 @@ def rotate_half(x):
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
@ -212,9 +244,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`):
|
||||
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
||||
used to pass offsetted position ids when working with a KV-cache.
|
||||
position_ids (`torch.Tensor`, *optional*):
|
||||
Deprecated and unused.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
@ -225,8 +256,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
|
||||
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
@ -294,9 +325,8 @@ class StableLmAttention(nn.Module):
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
self.partial_rotary_factor = config.partial_rotary_factor
|
||||
self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor)
|
||||
self.is_causal = True
|
||||
|
||||
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||
@ -317,35 +347,7 @@ class StableLmAttention(nn.Module):
|
||||
)
|
||||
|
||||
self.attention_dropout = nn.Dropout(config.attention_dropout)
|
||||
self._init_rope()
|
||||
|
||||
# Copied from transformers.models.persimmon.modeling_persimmon.PersimmonAttention._init_rope with Persimmon->StableLm
|
||||
def _init_rope(self):
|
||||
if self.config.rope_scaling is None:
|
||||
self.rotary_emb = StableLmRotaryEmbedding(
|
||||
int(self.partial_rotary_factor * self.head_dim),
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
else:
|
||||
scaling_type = self.config.rope_scaling["type"]
|
||||
scaling_factor = self.config.rope_scaling["factor"]
|
||||
if scaling_type == "linear":
|
||||
self.rotary_emb = StableLmLinearScalingRotaryEmbedding(
|
||||
int(self.partial_rotary_factor * self.head_dim),
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
scaling_factor=scaling_factor,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
elif scaling_type == "dynamic":
|
||||
self.rotary_emb = StableLmDynamicNTKScalingRotaryEmbedding(
|
||||
int(self.partial_rotary_factor * self.head_dim),
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
scaling_factor=scaling_factor,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
||||
self.rotary_emb = StableLmRotaryEmbedding(config=self.config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -356,6 +358,7 @@ class StableLmAttention(nn.Module):
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
@ -371,28 +374,28 @@ class StableLmAttention(nn.Module):
|
||||
query_states = self.q_layernorm(query_states)
|
||||
key_states = self.k_layernorm(key_states)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
|
||||
# Partial rotary embedding
|
||||
query_rot, query_pass = (
|
||||
query_states[..., : self.rotary_emb.dim],
|
||||
query_states[..., self.rotary_emb.dim :],
|
||||
query_states[..., : self.rotary_ndims],
|
||||
query_states[..., self.rotary_ndims :],
|
||||
)
|
||||
key_rot, key_pass = (
|
||||
key_states[..., : self.rotary_emb.dim],
|
||||
key_states[..., self.rotary_emb.dim :],
|
||||
key_states[..., : self.rotary_ndims],
|
||||
key_states[..., self.rotary_ndims :],
|
||||
)
|
||||
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
|
||||
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
|
||||
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
|
||||
|
||||
# [batch_size, seq_length, num_heads, head_dim]
|
||||
query_states = torch.cat((query_rot, query_pass), dim=-1)
|
||||
@ -403,7 +406,7 @@ class StableLmAttention(nn.Module):
|
||||
cache_kwargs = {
|
||||
"sin": sin,
|
||||
"cos": cos,
|
||||
"partial_rotation_size": self.rotary_emb.dim,
|
||||
"partial_rotation_size": self.rotary_ndims,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
@ -414,12 +417,6 @@ class StableLmAttention(nn.Module):
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights += causal_mask
|
||||
@ -457,6 +454,7 @@ class StableLmSdpaAttention(StableLmAttention):
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
@ -487,28 +485,28 @@ class StableLmSdpaAttention(StableLmAttention):
|
||||
query_states = self.q_layernorm(query_states)
|
||||
key_states = self.k_layernorm(key_states)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
|
||||
# Partial rotary embedding
|
||||
query_rot, query_pass = (
|
||||
query_states[..., : self.rotary_emb.dim],
|
||||
query_states[..., self.rotary_emb.dim :],
|
||||
query_states[..., : self.rotary_ndims],
|
||||
query_states[..., self.rotary_ndims :],
|
||||
)
|
||||
key_rot, key_pass = (
|
||||
key_states[..., : self.rotary_emb.dim],
|
||||
key_states[..., self.rotary_emb.dim :],
|
||||
key_states[..., : self.rotary_ndims],
|
||||
key_states[..., self.rotary_ndims :],
|
||||
)
|
||||
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
|
||||
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
|
||||
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
|
||||
|
||||
# [batch_size, seq_length, num_heads, head_dim]
|
||||
query_states = torch.cat((query_rot, query_pass), dim=-1)
|
||||
@ -519,7 +517,7 @@ class StableLmSdpaAttention(StableLmAttention):
|
||||
cache_kwargs = {
|
||||
"sin": sin,
|
||||
"cos": cos,
|
||||
"partial_rotation_size": self.rotary_emb.dim,
|
||||
"partial_rotation_size": self.rotary_ndims,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
@ -586,6 +584,7 @@ class StableLmFlashAttention2(StableLmAttention):
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# StableLmFlashAttention2 attention does not support output_attentions
|
||||
@ -609,27 +608,27 @@ class StableLmFlashAttention2(StableLmAttention):
|
||||
query_states = self.q_layernorm(query_states)
|
||||
key_states = self.k_layernorm(key_states)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
|
||||
# Partial rotary embedding
|
||||
query_rot, query_pass = (
|
||||
query_states[..., : self.rotary_emb.dim],
|
||||
query_states[..., self.rotary_emb.dim :],
|
||||
query_states[..., : self.rotary_ndims],
|
||||
query_states[..., self.rotary_ndims :],
|
||||
)
|
||||
key_rot, key_pass = (
|
||||
key_states[..., : self.rotary_emb.dim],
|
||||
key_states[..., self.rotary_emb.dim :],
|
||||
key_states[..., : self.rotary_ndims],
|
||||
key_states[..., self.rotary_ndims :],
|
||||
)
|
||||
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
|
||||
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
|
||||
|
||||
# [batch_size, seq_length, num_heads, head_dim]
|
||||
query_states = torch.cat((query_rot, query_pass), dim=-1)
|
||||
@ -639,7 +638,7 @@ class StableLmFlashAttention2(StableLmAttention):
|
||||
cache_kwargs = {
|
||||
"sin": sin,
|
||||
"cos": cos,
|
||||
"partial_rotation_size": self.rotary_emb.dim,
|
||||
"partial_rotation_size": self.rotary_ndims,
|
||||
"cache_position": cache_position,
|
||||
}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
@ -702,6 +701,7 @@ class StableLmDecoderLayer(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
@ -722,7 +722,10 @@ class StableLmDecoderLayer(nn.Module):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
Indices depicting the position of the input sequence tokens in the sequence
|
||||
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
|
||||
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
||||
with `head_dim` being the embedding dimension of each attention head.
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
@ -738,6 +741,7 @@ class StableLmDecoderLayer(nn.Module):
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
# copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer.forward
|
||||
@ -798,6 +802,7 @@ class StableLmPreTrainedModel(PreTrainedModel):
|
||||
_supports_cache_class = True
|
||||
_supports_sdpa = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
@ -908,6 +913,7 @@ class StableLmModel(StableLmPreTrainedModel):
|
||||
[StableLmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.rotary_emb = StableLmRotaryEmbedding(config=config)
|
||||
|
||||
self._attn_implementation = config._attn_implementation
|
||||
self.gradient_checkpointing = False
|
||||
@ -980,6 +986,9 @@ class StableLmModel(StableLmPreTrainedModel):
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
@ -999,6 +1008,7 @@ class StableLmModel(StableLmPreTrainedModel):
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
position_embeddings,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
@ -1009,6 +1019,7 @@ class StableLmModel(StableLmPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
@ -15,6 +15,7 @@
|
||||
"""Starcoder2 model configuration"""
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
@ -69,6 +70,43 @@ class Starcoder2Config(PretrainedConfig):
|
||||
The id of the "end-of-sequence" token.
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`List[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
sliding_window (`int`, *optional*):
|
||||
Sliding window attention window size. If not specified, will default to `None` (no sliding window).
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
@ -113,6 +151,7 @@ class Starcoder2Config(PretrainedConfig):
|
||||
bos_token_id=50256,
|
||||
eos_token_id=50256,
|
||||
rope_theta=10000.0,
|
||||
rope_scaling=None,
|
||||
sliding_window=None,
|
||||
attention_dropout=0.0,
|
||||
residual_dropout=0.0,
|
||||
@ -134,9 +173,15 @@ class Starcoder2Config(PretrainedConfig):
|
||||
self.norm_epsilon = norm_epsilon
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_dropout = attention_dropout
|
||||
self.residual_dropout = residual_dropout
|
||||
self.embedding_dropout = embedding_dropout
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
|
||||
super().__init__(
|
||||
bos_token_id=bos_token_id,
|
||||
|
@ -36,6 +36,7 @@ from ...modeling_outputs import (
|
||||
SequenceClassifierOutputWithPast,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
@ -112,41 +113,92 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
return causal_mask
|
||||
|
||||
|
||||
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Starcoder2
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Starcoder2
|
||||
class Starcoder2RotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||
def __init__(
|
||||
self,
|
||||
dim=None,
|
||||
max_position_embeddings=2048,
|
||||
base=10000,
|
||||
device=None,
|
||||
scaling_factor=1.0,
|
||||
rope_type="default",
|
||||
config: Optional[Starcoder2Config] = None,
|
||||
):
|
||||
super().__init__()
|
||||
# TODO (joao): remove the `if` below, only used for BC
|
||||
self.rope_kwargs = {}
|
||||
if config is None:
|
||||
logger.warning_once(
|
||||
"`Starcoder2RotaryEmbedding` can now be fully parameterized by passing the model config through the "
|
||||
"`config` argument. All other arguments will be removed in v4.46"
|
||||
)
|
||||
self.rope_kwargs = {
|
||||
"rope_type": rope_type,
|
||||
"factor": scaling_factor,
|
||||
"dim": dim,
|
||||
"base": base,
|
||||
"max_position_embeddings": max_position_embeddings,
|
||||
}
|
||||
self.rope_type = rope_type
|
||||
self.max_seq_len_cached = max_position_embeddings
|
||||
self.original_max_seq_len = max_position_embeddings
|
||||
else:
|
||||
# BC: "rope_type" was originally "type"
|
||||
if config.rope_scaling is not None:
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
# Build here to make `torch.jit.trace` work.
|
||||
self._set_cos_sin_cache(
|
||||
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
|
||||
)
|
||||
def _dynamic_frequency_update(self, position_ids, device):
|
||||
"""
|
||||
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
||||
1 - growing beyond the cached sequence length (allow scaling)
|
||||
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
||||
"""
|
||||
seq_len = torch.max(position_ids) + 1
|
||||
if seq_len > self.max_seq_len_cached: # growth
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(
|
||||
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
||||
)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||
self.max_seq_len_cached = seq_len
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
||||
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||
self.max_seq_len_cached = self.original_max_seq_len
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids):
|
||||
if "dynamic" in self.rope_type:
|
||||
self._dynamic_frequency_update(position_ids, device=x.device)
|
||||
|
||||
def forward(self, x, seq_len=None):
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
if seq_len > self.max_seq_len_cached:
|
||||
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
||||
# Core RoPE block
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
|
||||
return (
|
||||
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
||||
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
||||
)
|
||||
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
||||
cos = cos * self.attention_scaling
|
||||
sin = sin * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
||||
@ -157,8 +209,8 @@ def rotate_half(x):
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
@ -166,9 +218,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`):
|
||||
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
||||
used to pass offsetted position ids when working with a KV-cache.
|
||||
position_ids (`torch.Tensor`, *optional*):
|
||||
Deprecated and unused.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
@ -179,8 +230,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
|
||||
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
@ -238,7 +289,6 @@ class Starcoder2Attention(nn.Module):
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
self.use_bias = config.use_bias
|
||||
self.is_causal = True
|
||||
@ -255,11 +305,7 @@ class Starcoder2Attention(nn.Module):
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.use_bias)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=self.use_bias)
|
||||
|
||||
self.rotary_emb = Starcoder2RotaryEmbedding(
|
||||
self.head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
base=self.rope_theta,
|
||||
)
|
||||
self.rotary_emb = Starcoder2RotaryEmbedding(config=self.config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -270,6 +316,7 @@ class Starcoder2Attention(nn.Module):
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
@ -281,17 +328,17 @@ class Starcoder2Attention(nn.Module):
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
@ -302,13 +349,6 @@ class Starcoder2Attention(nn.Module):
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights += causal_mask
|
||||
@ -362,6 +402,7 @@ class Starcoder2FlashAttention2(Starcoder2Attention):
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
):
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
@ -373,28 +414,22 @@ class Starcoder2FlashAttention2(Starcoder2Attention):
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
|
||||
# Because the input can be padded, the absolute sequence length depends on the max position id.
|
||||
rotary_seq_len = (
|
||||
max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len
|
||||
)
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
|
||||
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# Activate slicing cache only if the config has a value `sliding_windows` attribute
|
||||
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
|
||||
kv_seq_len = key_states.shape[-2] + cache_position[0]
|
||||
if (
|
||||
getattr(self.config, "sliding_window", None) is not None
|
||||
and kv_seq_len > self.config.sliding_window
|
||||
@ -495,6 +530,7 @@ class Starcoder2SdpaAttention(Starcoder2Attention):
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
@ -521,12 +557,17 @@ class Starcoder2SdpaAttention(Starcoder2Attention):
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
if position_embeddings is None:
|
||||
logger.warning_once(
|
||||
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||
"removed and `position_embeddings` will be mandatory."
|
||||
)
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
else:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
@ -599,6 +640,7 @@ class Starcoder2DecoderLayer(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
@ -615,6 +657,9 @@ class Starcoder2DecoderLayer(nn.Module):
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
|
||||
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
||||
with `head_dim` being the embedding dimension of each attention head.
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
||||
into the model
|
||||
@ -633,6 +678,7 @@ class Starcoder2DecoderLayer(nn.Module):
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
@ -684,6 +730,8 @@ class Starcoder2PreTrainedModel(PreTrainedModel):
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_cache_class = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
@ -796,6 +844,7 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
|
||||
)
|
||||
self._attn_implementation = config._attn_implementation
|
||||
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
|
||||
self.rotary_emb = Starcoder2RotaryEmbedding(config=config)
|
||||
self.gradient_checkpointing = False
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
@ -867,6 +916,9 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
|
||||
hidden_states = inputs_embeds
|
||||
hidden_states = nn.functional.dropout(hidden_states, p=self.embedding_dropout, training=self.training)
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
@ -886,6 +938,7 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
position_embeddings,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
@ -896,6 +949,7 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
@ -514,6 +514,10 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
|
||||
self.assertListEqual(generated_text, EXPECTED_GENERATIONS)
|
||||
|
||||
@unittest.skip("Bloom needs a 2D attention for alibi")
|
||||
def test_custom_4d_attention_mask(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class BloomEmbeddingTest(unittest.TestCase):
|
||||
|
@ -461,6 +461,10 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
||||
|
||||
# Inputs
|
||||
x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device
|
||||
position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
|
||||
position_ids_short = position_ids_short.unsqueeze(0)
|
||||
position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
|
||||
position_ids_long = position_ids_long.unsqueeze(0)
|
||||
|
||||
# Sanity check original RoPE
|
||||
original_rope = FalconRotaryEmbedding(
|
||||
@ -468,10 +472,10 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
base=config.rope_theta,
|
||||
).to(torch_device)
|
||||
original_cos_short, original_sin_short = original_rope(x, short_input_length)
|
||||
original_cos_long, original_sin_long = original_rope(x, long_input_length)
|
||||
torch.testing.assert_close(original_cos_short, original_cos_long[:short_input_length, :])
|
||||
torch.testing.assert_close(original_sin_short, original_sin_long[:short_input_length, :])
|
||||
original_cos_short, original_sin_short = original_rope(x, position_ids_short)
|
||||
original_cos_long, original_sin_long = original_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
|
||||
torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])
|
||||
|
||||
# Sanity check linear RoPE scaling
|
||||
# New position "x" should match original position with index "x/scaling_factor"
|
||||
@ -481,14 +485,14 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
||||
base=config.rope_theta,
|
||||
scaling_factor=scaling_factor,
|
||||
).to(torch_device)
|
||||
linear_cos_short, linear_sin_short = linear_scaling_rope(x, short_input_length)
|
||||
linear_cos_long, linear_sin_long = linear_scaling_rope(x, long_input_length)
|
||||
torch.testing.assert_close(linear_cos_short, linear_cos_long[:short_input_length, :])
|
||||
torch.testing.assert_close(linear_sin_short, linear_sin_long[:short_input_length, :])
|
||||
linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
|
||||
linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
|
||||
torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :])
|
||||
for new_position in range(0, long_input_length, scaling_factor):
|
||||
original_position = int(new_position // scaling_factor)
|
||||
torch.testing.assert_close(linear_cos_long[new_position, :], original_cos_long[original_position, :])
|
||||
torch.testing.assert_close(linear_sin_long[new_position, :], original_sin_long[original_position, :])
|
||||
torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :])
|
||||
torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :])
|
||||
|
||||
# Sanity check Dynamic NTK RoPE scaling
|
||||
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
|
||||
@ -499,8 +503,8 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
||||
base=config.rope_theta,
|
||||
scaling_factor=scaling_factor,
|
||||
).to(torch_device)
|
||||
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, short_input_length)
|
||||
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, long_input_length)
|
||||
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
|
||||
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(ntk_cos_short, original_cos_short)
|
||||
torch.testing.assert_close(ntk_sin_short, original_sin_short)
|
||||
with self.assertRaises(AssertionError):
|
||||
|
@ -382,6 +382,10 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
|
||||
# Inputs
|
||||
x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device
|
||||
position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
|
||||
position_ids_short = position_ids_short.unsqueeze(0)
|
||||
position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
|
||||
position_ids_long = position_ids_long.unsqueeze(0)
|
||||
|
||||
# Sanity check original RoPE
|
||||
original_rope = GPTNeoXRotaryEmbedding(
|
||||
@ -389,10 +393,10 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
base=config.rotary_emb_base,
|
||||
).to(torch_device)
|
||||
original_cos_short, original_sin_short = original_rope(x, short_input_length)
|
||||
original_cos_long, original_sin_long = original_rope(x, long_input_length)
|
||||
torch.testing.assert_close(original_cos_short, original_cos_long[:short_input_length, :])
|
||||
torch.testing.assert_close(original_sin_short, original_sin_long[:short_input_length, :])
|
||||
original_cos_short, original_sin_short = original_rope(x, position_ids_short)
|
||||
original_cos_long, original_sin_long = original_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
|
||||
torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])
|
||||
|
||||
# Sanity check linear RoPE scaling
|
||||
# New position "x" should match original position with index "x/scaling_factor"
|
||||
@ -402,14 +406,14 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
base=config.rotary_emb_base,
|
||||
scaling_factor=scaling_factor,
|
||||
).to(torch_device)
|
||||
linear_cos_short, linear_sin_short = linear_scaling_rope(x, short_input_length)
|
||||
linear_cos_long, linear_sin_long = linear_scaling_rope(x, long_input_length)
|
||||
torch.testing.assert_close(linear_cos_short, linear_cos_long[:short_input_length, :])
|
||||
torch.testing.assert_close(linear_sin_short, linear_sin_long[:short_input_length, :])
|
||||
linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
|
||||
linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
|
||||
torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :])
|
||||
for new_position in range(0, long_input_length, scaling_factor):
|
||||
original_position = int(new_position // scaling_factor)
|
||||
torch.testing.assert_close(linear_cos_long[new_position, :], original_cos_long[original_position, :])
|
||||
torch.testing.assert_close(linear_sin_long[new_position, :], original_sin_long[original_position, :])
|
||||
torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :])
|
||||
torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :])
|
||||
|
||||
# Sanity check Dynamic NTK RoPE scaling
|
||||
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
|
||||
@ -420,8 +424,8 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
base=config.rotary_emb_base,
|
||||
scaling_factor=scaling_factor,
|
||||
).to(torch_device)
|
||||
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, short_input_length)
|
||||
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, long_input_length)
|
||||
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
|
||||
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(ntk_cos_short, original_cos_short)
|
||||
torch.testing.assert_close(ntk_sin_short, original_sin_short)
|
||||
with self.assertRaises(AssertionError):
|
||||
|
@ -20,6 +20,7 @@ from transformers import GPTNeoXJapaneseConfig, is_torch_available
|
||||
from transformers.models.gpt_neox_japanese.tokenization_gpt_neox_japanese import GPTNeoXJapaneseTokenizer
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
@ -56,6 +57,8 @@ class GPTNeoXJapaneseModelTester:
|
||||
initializer_range=0.02,
|
||||
num_labels=3,
|
||||
num_choices=4,
|
||||
bos_token_id=1,
|
||||
eos_token_id=0,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
@ -81,6 +84,8 @@ class GPTNeoXJapaneseModelTester:
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.scope = scope
|
||||
self.eos_token_id = eos_token_id
|
||||
self.bos_token_id = bos_token_id
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
@ -112,6 +117,8 @@ class GPTNeoXJapaneseModelTester:
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
eos_token_id=self.eos_token_id,
|
||||
bos_token_id=self.bos_token_id,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_decoder(self):
|
||||
@ -189,7 +196,7 @@ class GPTNeoXJapaneseModelTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class GPTNeoXModelJapaneseTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
class GPTNeoXModelJapaneseTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (GPTNeoXJapaneseModel, GPTNeoXJapaneseForCausalLM) if is_torch_available() else ()
|
||||
all_generative_model_classes = (GPTNeoXJapaneseForCausalLM,) if is_torch_available() else ()
|
||||
pipeline_model_mapping = (
|
||||
@ -257,3 +264,7 @@ class GPTNeoXModelJapaneseTest(ModelTesterMixin, PipelineTesterMixin, unittest.T
|
||||
generated_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
predicted_outputs += generated_string
|
||||
self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
|
||||
|
||||
@unittest.skip("GPTNeoXJapanese applies bias to attention scores")
|
||||
def test_custom_4d_attention_mask(self):
|
||||
pass
|
||||
|
@ -433,6 +433,10 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
|
||||
# Inputs
|
||||
x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device
|
||||
position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
|
||||
position_ids_short = position_ids_short.unsqueeze(0)
|
||||
position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
|
||||
position_ids_long = position_ids_long.unsqueeze(0)
|
||||
|
||||
# Sanity check original RoPE
|
||||
original_rope = PersimmonRotaryEmbedding(
|
||||
@ -440,10 +444,10 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
base=config.rope_theta,
|
||||
).to(torch_device)
|
||||
original_cos_short, original_sin_short = original_rope(x, short_input_length)
|
||||
original_cos_long, original_sin_long = original_rope(x, long_input_length)
|
||||
torch.testing.assert_close(original_cos_short, original_cos_long[:short_input_length, :])
|
||||
torch.testing.assert_close(original_sin_short, original_sin_long[:short_input_length, :])
|
||||
original_cos_short, original_sin_short = original_rope(x, position_ids_short)
|
||||
original_cos_long, original_sin_long = original_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
|
||||
torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])
|
||||
|
||||
# Sanity check linear RoPE scaling
|
||||
# New position "x" should match original position with index "x/scaling_factor"
|
||||
@ -453,14 +457,14 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
base=config.rope_theta,
|
||||
scaling_factor=scaling_factor,
|
||||
).to(torch_device)
|
||||
linear_cos_short, linear_sin_short = linear_scaling_rope(x, short_input_length)
|
||||
linear_cos_long, linear_sin_long = linear_scaling_rope(x, long_input_length)
|
||||
torch.testing.assert_close(linear_cos_short, linear_cos_long[:short_input_length, :])
|
||||
torch.testing.assert_close(linear_sin_short, linear_sin_long[:short_input_length, :])
|
||||
linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
|
||||
linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
|
||||
torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :])
|
||||
for new_position in range(0, long_input_length, scaling_factor):
|
||||
original_position = int(new_position // scaling_factor)
|
||||
torch.testing.assert_close(linear_cos_long[new_position, :], original_cos_long[original_position, :])
|
||||
torch.testing.assert_close(linear_sin_long[new_position, :], original_sin_long[original_position, :])
|
||||
torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :])
|
||||
torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :])
|
||||
|
||||
# Sanity check Dynamic NTK RoPE scaling
|
||||
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
|
||||
@ -471,8 +475,8 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
base=config.rope_theta,
|
||||
scaling_factor=scaling_factor,
|
||||
).to(torch_device)
|
||||
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, short_input_length)
|
||||
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, long_input_length)
|
||||
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
|
||||
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(ntk_cos_short, original_cos_short)
|
||||
torch.testing.assert_close(ntk_sin_short, original_sin_short)
|
||||
with self.assertRaises(AssertionError):
|
||||
|
@ -409,6 +409,10 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
|
||||
# Inputs
|
||||
x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device
|
||||
position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
|
||||
position_ids_short = position_ids_short.unsqueeze(0)
|
||||
position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
|
||||
position_ids_long = position_ids_long.unsqueeze(0)
|
||||
|
||||
# Sanity check original RoPE
|
||||
original_rope = PhiRotaryEmbedding(
|
||||
@ -416,10 +420,10 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
base=config.rope_theta,
|
||||
).to(torch_device)
|
||||
original_cos_short, original_sin_short = original_rope(x, short_input_length)
|
||||
original_cos_long, original_sin_long = original_rope(x, long_input_length)
|
||||
torch.testing.assert_close(original_cos_short, original_cos_long[:short_input_length, :])
|
||||
torch.testing.assert_close(original_sin_short, original_sin_long[:short_input_length, :])
|
||||
original_cos_short, original_sin_short = original_rope(x, position_ids_short)
|
||||
original_cos_long, original_sin_long = original_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
|
||||
torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])
|
||||
|
||||
# Sanity check linear RoPE scaling
|
||||
# New position "x" should match original position with index "x/scaling_factor"
|
||||
@ -429,14 +433,14 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
base=config.rope_theta,
|
||||
scaling_factor=scaling_factor,
|
||||
).to(torch_device)
|
||||
linear_cos_short, linear_sin_short = linear_scaling_rope(x, short_input_length)
|
||||
linear_cos_long, linear_sin_long = linear_scaling_rope(x, long_input_length)
|
||||
torch.testing.assert_close(linear_cos_short, linear_cos_long[:short_input_length, :])
|
||||
torch.testing.assert_close(linear_sin_short, linear_sin_long[:short_input_length, :])
|
||||
linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
|
||||
linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
|
||||
torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :])
|
||||
for new_position in range(0, long_input_length, scaling_factor):
|
||||
original_position = int(new_position // scaling_factor)
|
||||
torch.testing.assert_close(linear_cos_long[new_position, :], original_cos_long[original_position, :])
|
||||
torch.testing.assert_close(linear_sin_long[new_position, :], original_sin_long[original_position, :])
|
||||
torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :])
|
||||
torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :])
|
||||
|
||||
# Sanity check Dynamic NTK RoPE scaling
|
||||
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
|
||||
@ -447,8 +451,8 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
base=config.rope_theta,
|
||||
scaling_factor=scaling_factor,
|
||||
).to(torch_device)
|
||||
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, short_input_length)
|
||||
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, long_input_length)
|
||||
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
|
||||
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(ntk_cos_short, original_cos_short)
|
||||
torch.testing.assert_close(ntk_sin_short, original_sin_short)
|
||||
with self.assertRaises(AssertionError):
|
||||
|
@ -420,6 +420,10 @@ class StableLmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
||||
|
||||
# Inputs
|
||||
x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device
|
||||
position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device)
|
||||
position_ids_short = position_ids_short.unsqueeze(0)
|
||||
position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device)
|
||||
position_ids_long = position_ids_long.unsqueeze(0)
|
||||
|
||||
# Sanity check original RoPE
|
||||
original_rope = StableLmRotaryEmbedding(
|
||||
@ -427,10 +431,10 @@ class StableLmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
base=config.rope_theta,
|
||||
).to(torch_device)
|
||||
original_cos_short, original_sin_short = original_rope(x, short_input_length)
|
||||
original_cos_long, original_sin_long = original_rope(x, long_input_length)
|
||||
torch.testing.assert_close(original_cos_short, original_cos_long[:short_input_length, :])
|
||||
torch.testing.assert_close(original_sin_short, original_sin_long[:short_input_length, :])
|
||||
original_cos_short, original_sin_short = original_rope(x, position_ids_short)
|
||||
original_cos_long, original_sin_long = original_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :])
|
||||
torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :])
|
||||
|
||||
# Sanity check linear RoPE scaling
|
||||
# New position "x" should match original position with index "x/scaling_factor"
|
||||
@ -440,14 +444,14 @@ class StableLmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
||||
base=config.rope_theta,
|
||||
scaling_factor=scaling_factor,
|
||||
).to(torch_device)
|
||||
linear_cos_short, linear_sin_short = linear_scaling_rope(x, short_input_length)
|
||||
linear_cos_long, linear_sin_long = linear_scaling_rope(x, long_input_length)
|
||||
torch.testing.assert_close(linear_cos_short, linear_cos_long[:short_input_length, :])
|
||||
torch.testing.assert_close(linear_sin_short, linear_sin_long[:short_input_length, :])
|
||||
linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short)
|
||||
linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :])
|
||||
torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :])
|
||||
for new_position in range(0, long_input_length, scaling_factor):
|
||||
original_position = int(new_position // scaling_factor)
|
||||
torch.testing.assert_close(linear_cos_long[new_position, :], original_cos_long[original_position, :])
|
||||
torch.testing.assert_close(linear_sin_long[new_position, :], original_sin_long[original_position, :])
|
||||
torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :])
|
||||
torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :])
|
||||
|
||||
# Sanity check Dynamic NTK RoPE scaling
|
||||
# Scaling should only be observed after a long input is fed. We can observe that the frequencies increase
|
||||
@ -458,8 +462,8 @@ class StableLmModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
||||
base=config.rope_theta,
|
||||
scaling_factor=scaling_factor,
|
||||
).to(torch_device)
|
||||
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, short_input_length)
|
||||
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, long_input_length)
|
||||
ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short)
|
||||
ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long)
|
||||
torch.testing.assert_close(ntk_cos_short, original_cos_short)
|
||||
torch.testing.assert_close(ntk_sin_short, original_sin_short)
|
||||
with self.assertRaises(AssertionError):
|
||||
|
@ -469,6 +469,7 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
||||
"RwkvForCausalLM",
|
||||
"XGLMForCausalLM",
|
||||
"GPTNeoXForCausalLM",
|
||||
"GPTNeoXJapaneseForCausalLM",
|
||||
"FuyuForCausalLM",
|
||||
]
|
||||
if (
|
||||
|
@ -4640,7 +4640,7 @@ class ModelTesterMixin:
|
||||
if not model_class._supports_static_cache:
|
||||
self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks")
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if getattr(config, "sliding_window", 0) > 0:
|
||||
if getattr(config, "sliding_window", 0) is not None and getattr(config, "sliding_window", 0) > 0:
|
||||
self.skipTest(f"{model_class.__name__} with sliding window attention is not supported by this test")
|
||||
model = model_class(config).to(device=torch_device, dtype=torch.float32)
|
||||
|
||||
@ -4689,7 +4689,7 @@ class ModelTesterMixin:
|
||||
self.skipTest(f"{model_class.__name__} does not support cache class")
|
||||
|
||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if getattr(config, "sliding_window", 0) > 0:
|
||||
if getattr(config, "sliding_window", 0) is not None and getattr(config, "sliding_window", 0) > 0:
|
||||
self.skipTest(f"{model_class.__name__} with sliding window attention is not supported by this test")
|
||||
|
||||
model = model_class(config).to(device=torch_device, dtype=torch.float32)
|
||||
|
Loading…
Reference in New Issue
Block a user