Falcon: Add RoPE scaling (#25878)
Parent: 024acd271b
Commit: 53e2fd785b
@@ -154,14 +154,14 @@ class OpenLlamaConfig(PretrainedConfig):

         if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
             raise ValueError(
-                "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, "
+                "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
                 f"got {self.rope_scaling}"
             )
         rope_scaling_type = self.rope_scaling.get("type", None)
         rope_scaling_factor = self.rope_scaling.get("factor", None)
         if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
             raise ValueError(
-                f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
+                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
             )
         if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
             raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
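
Why the message fix matters: the validation reads the `type` key, so a user who followed the old error text and passed `name` would keep hitting the same error. A minimal sketch (hypothetical usage; assumes `OpenLlamaConfig` is importable from the top-level `transformers` namespace and runs this validation in its constructor):

    from transformers import OpenLlamaConfig

    # The dict has two fields, so the length check passes, but `.get("type")`
    # returns None and the type check fails. Before this commit the error text
    # itself told users to pass `name`, steering them straight back into the
    # same failure; after the fix it correctly points at `type`.
    OpenLlamaConfig(rope_scaling={"name": "linear", "factor": 2.0})  # raises ValueError
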
@@ -72,6 +72,19 @@ class FalconConfig(PretrainedConfig):
             instead, as in the original Transformer architecture. Ignored when `new_decoder_architecture` is `True`.
         bias (`bool`, *optional*, defaults to `False`):
             Whether to use bias on Linear layers.
+        max_position_embeddings (`int`, *optional*, defaults to 2048):
+            The maximum sequence length that this model might ever be used with, when `alibi` is `False`. Pretrained
+            Falcon models with RoPE support up to 2048 tokens.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format
+            is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+            `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
+            these scaling strategies behave:
+            https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
+            experimental feature, subject to breaking API changes in future versions.
         bos_token_id (`int`, *optional*, defaults to 11):
             The id of the "beginning-of-sequence" token.
         eos_token_id (`int`, *optional*, defaults to 11):
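
For reference, a minimal sketch of how the new fields fit together (values are illustrative, not taken from a released checkpoint):

    from transformers import FalconConfig

    # Enable dynamic NTK scaling with a 4x factor. `max_position_embeddings`
    # stays at the pretraining value (2048); as the docstring above says, it
    # should not be bumped to the new target length.
    config = FalconConfig(
        alibi=False,  # RoPE (and therefore rope_scaling) only applies when alibi is off
        rope_theta=10000.0,
        max_position_embeddings=2048,
        rope_scaling={"type": "dynamic", "factor": 4.0},
    )
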
@@ -111,6 +124,9 @@ class FalconConfig(PretrainedConfig):
         multi_query=True,
         parallel_attn=True,
         bias=False,
+        max_position_embeddings=2048,
+        rope_theta=10000.0,
+        rope_scaling=None,
         bos_token_id=11,
         eos_token_id=11,
         **kwargs,
@@ -135,6 +151,10 @@ class FalconConfig(PretrainedConfig):
         self.multi_query = multi_query  # Ignored when new_decoder_architecture is True
         self.parallel_attn = parallel_attn
         self.bias = bias
+        self.max_position_embeddings = max_position_embeddings
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self._rope_scaling_validation()

         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

@@ -145,3 +165,27 @@ class FalconConfig(PretrainedConfig):
     @property
     def rotary(self):
        return not self.alibi
+
+    def _rope_scaling_validation(self):
+        """
+        Validate the `rope_scaling` configuration.
+        """
+        if self.rope_scaling is None:
+            return
+
+        if self.alibi:
+            raise ValueError("`rope_scaling` is not supported when `alibi` is `True`.")
+
+        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
+            raise ValueError(
+                "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
+                f"got {self.rope_scaling}"
+            )
+        rope_scaling_type = self.rope_scaling.get("type", None)
+        rope_scaling_factor = self.rope_scaling.get("factor", None)
+        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
+            raise ValueError(
+                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
+            )
+        if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
+            raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
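
A quick sketch of what this validation accepts and rejects (illustrative values, using the constructor arguments added above):

    from transformers import FalconConfig

    # Accepted: exactly two fields, a known type, and a float factor > 1.
    FalconConfig(rope_scaling={"type": "linear", "factor": 2.0})

    # Rejected: the factor must be a *float* strictly greater than 1.0, so an
    # int (or any value <= 1.0) raises the ValueError from the last check.
    FalconConfig(rope_scaling={"type": "linear", "factor": 2})  # ValueError
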
@@ -71,32 +71,36 @@ class FalconRotaryEmbedding(nn.Module):
     n_heads_per_partition, seq_len, head_dim]` (e.g. MinGPTAttention format).
     """

-    def __init__(self, head_dim: int, base=10000):
+    def __init__(self, head_dim: int, base=10000, max_position_embeddings=2048):
         super().__init__()
-        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
+        self.base = base
+        self.max_position_embeddings = max_position_embeddings
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, head_dim, 2).float() / head_dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.head_dim = head_dim
         self.seq_len_cached = -1
         self.cos_cached: torch.Tensor | None = None
         self.sin_cached: torch.Tensor | None = None

+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.seq_len_cached = seq_len
+        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        emb = torch.cat((freqs, freqs), dim=-1).to(device)
+
+        if dtype in [torch.float16, torch.bfloat16]:
+            emb = emb.float()
+
+        self.cos_cached = emb.cos()[None, :, :]
+        self.sin_cached = emb.sin()[None, :, :]
+
+        self.cos_cached = self.cos_cached.type(dtype)
+        self.sin_cached = self.sin_cached.type(dtype)
+
     def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
         total_length = seq_len + past_key_values_length
         if total_length > self.seq_len_cached:
-            self.seq_len_cached = total_length
-            t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
-            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            emb = torch.cat((freqs, freqs), dim=-1).to(device)
-
-            if dtype in [torch.float16, torch.bfloat16]:
-                emb = emb.float()
-
-            self.cos_cached = emb.cos()[None, :, :]
-            self.sin_cached = emb.sin()[None, :, :]
-
-            self.cos_cached = self.cos_cached.type(dtype)
-            self.sin_cached = self.sin_cached.type(dtype)
-
+            self._set_cos_sin_cache(total_length, device, dtype)
         return (
             self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length],
             self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length],
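
Illustrative usage of the refactored cache, assuming the class as defined in this diff: `cos_sin` only rebuilds the tables when the requested total length exceeds what is cached, and `_set_cos_sin_cache` computes them in float32 before casting back to half precision.

    import torch
    from transformers.models.falcon.modeling_falcon import FalconRotaryEmbedding

    rotary = FalconRotaryEmbedding(head_dim=64, base=10000, max_position_embeddings=2048)

    # First call builds the cache for 128 positions.
    cos, sin = rotary.cos_sin(seq_len=128, past_key_values_length=0, device="cpu", dtype=torch.bfloat16)
    print(cos.shape)  # torch.Size([1, 128, 64])

    # A follow-up decoding step fits inside the cache, so only a slice is returned.
    cos, sin = rotary.cos_sin(seq_len=1, past_key_values_length=16, device="cpu", dtype=torch.bfloat16)
    print(cos.shape)  # torch.Size([1, 1, 64])
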
@@ -108,6 +112,66 @@ class FalconRotaryEmbedding(nn.Module):
         return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)


+class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
+    """FalconRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
+    def __init__(self, head_dim: int, base=10000, max_position_embeddings=2048, scaling_factor=1.0):
+        self.scaling_factor = scaling_factor
+        super().__init__(head_dim, base, max_position_embeddings)
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.seq_len_cached = seq_len
+        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
+        # This line is the only difference from FalconRotaryEmbedding._set_cos_sin_cache
+        t = t / self.scaling_factor
+
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        emb = torch.cat((freqs, freqs), dim=-1).to(device)
+
+        if dtype in [torch.float16, torch.bfloat16]:
+            emb = emb.float()
+
+        self.cos_cached = emb.cos()[None, :, :]
+        self.sin_cached = emb.sin()[None, :, :]
+
+        self.cos_cached = self.cos_cached.type(dtype)
+        self.sin_cached = self.sin_cached.type(dtype)
+
+
+class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding):
+    """
+    FalconRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
+    """
+
+    def __init__(self, head_dim: int, base=10000, max_position_embeddings=2048, scaling_factor=1.0):
+        self.scaling_factor = scaling_factor
+        super().__init__(head_dim, base, max_position_embeddings)
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.seq_len_cached = seq_len
+
+        # This if block is the only difference from FalconRotaryEmbedding._set_cos_sin_cache
+        if seq_len > self.max_position_embeddings:
+            base = self.base * (
+                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
+            ) ** (self.head_dim / (self.head_dim - 2))
+            inv_freq = 1.0 / (base ** (torch.arange(0, self.head_dim, 2).float().to(device) / self.head_dim))
+            self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        emb = torch.cat((freqs, freqs), dim=-1).to(device)
+
+        if dtype in [torch.float16, torch.bfloat16]:
+            emb = emb.float()
+
+        self.cos_cached = emb.cos()[None, :, :]
+        self.sin_cached = emb.sin()[None, :, :]
+
+        self.cos_cached = self.cos_cached.type(dtype)
+        self.sin_cached = self.sin_cached.type(dtype)
+
+
 def _make_causal_mask(
     input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
 ) -> torch.BoolTensor:
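
To make the two strategies concrete, a small numeric sketch (values are illustrative): linear scaling divides the position indices by the factor, while dynamic NTK scaling leaves positions alone and instead enlarges the RoPE base once the sequence outgrows `max_position_embeddings`, using the formula from the `if` block above.

    import torch

    head_dim = 64
    base = 10000.0
    max_position_embeddings = 2048
    scaling_factor = 2.0
    seq_len = 4096  # twice the pretraining context

    # Dynamic NTK: the effective base grows with the overshoot, stretching the
    # low-frequency components the most and the high-frequency ones the least.
    adjusted_base = base * (
        (scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)
    ) ** (head_dim / (head_dim - 2))
    print(round(adjusted_base))  # ~31082, i.e. about 3.1x the original base

    # Linear scaling: positions are divided by the factor, so position 4095 is
    # fed to RoPE as 2047.5 and stays inside the pretrained range.
    t = torch.arange(seq_len) / scaling_factor
    print(t[-1].item())  # 2047.5
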
@@ -191,6 +255,7 @@ class FalconAttention(nn.Module):
     def __init__(self, config: FalconConfig):
         super().__init__()

+        self.config = config
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
@@ -203,7 +268,7 @@ class FalconAttention(nn.Module):
                 f" {self.num_heads})."
             )

-        self.maybe_rotary = FalconRotaryEmbedding(config.head_dim) if config.rotary else lambda q, k, t: (q, k)
+        self.maybe_rotary = self._init_rope() if config.rotary else lambda q, k, t: (q, k)

         # Layer-wise attention scaling
         self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
@@ -221,6 +286,34 @@ class FalconAttention(nn.Module):
         self.attention_dropout = nn.Dropout(config.attention_dropout)
         self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1

+    def _init_rope(self):
+        if self.config.rope_scaling is None:
+            rotary_emb = FalconRotaryEmbedding(
+                self.head_dim,
+                base=self.config.rope_theta,
+                max_position_embeddings=self.config.max_position_embeddings,
+            )
+        else:
+            scaling_type = self.config.rope_scaling["type"]
+            scaling_factor = self.config.rope_scaling["factor"]
+            if scaling_type == "linear":
+                rotary_emb = FalconLinearScalingRotaryEmbedding(
+                    self.head_dim,
+                    base=self.config.rope_theta,
+                    max_position_embeddings=self.config.max_position_embeddings,
+                    scaling_factor=scaling_factor,
+                )
+            elif scaling_type == "dynamic":
+                rotary_emb = FalconDynamicNTKScalingRotaryEmbedding(
+                    self.head_dim,
+                    base=self.config.rope_theta,
+                    max_position_embeddings=self.config.max_position_embeddings,
+                    scaling_factor=scaling_factor,
+                )
+            else:
+                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+        return rotary_emb
+
     def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Split the last dimension into (num_heads, head_dim), results share same memory storage as `fused_qkv`
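
In short, `_init_rope` maps the config onto one of the three embedding classes above, and alibi models skip rotary embeddings entirely via the identity lambda. A hedged sketch of the dispatch (tiny illustrative config; assumes `FalconAttention` is importable from `transformers.models.falcon.modeling_falcon`):

    from transformers import FalconConfig
    from transformers.models.falcon.modeling_falcon import FalconAttention

    config = FalconConfig(
        hidden_size=64,
        num_attention_heads=4,
        alibi=False,
        rope_scaling={"type": "dynamic", "factor": 2.0},
    )
    attn = FalconAttention(config)
    print(type(attn.maybe_rotary).__name__)  # FalconDynamicNTKScalingRotaryEmbedding
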
@@ -163,14 +163,14 @@ class GPTNeoXConfig(PretrainedConfig):

         if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
             raise ValueError(
-                "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, "
+                "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
                 f"got {self.rope_scaling}"
             )
         rope_scaling_type = self.rope_scaling.get("type", None)
         rope_scaling_factor = self.rope_scaling.get("factor", None)
         if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
             raise ValueError(
-                f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
+                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
             )
         if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
             raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
@@ -165,14 +165,14 @@ class LlamaConfig(PretrainedConfig):

         if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
             raise ValueError(
-                "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, "
+                "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
                 f"got {self.rope_scaling}"
             )
         rope_scaling_type = self.rope_scaling.get("type", None)
         rope_scaling_factor = self.rope_scaling.get("factor", None)
         if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
             raise ValueError(
-                f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
+                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
             )
         if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
             raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
@@ -17,7 +17,9 @@

 import unittest

-from transformers import AutoTokenizer, FalconConfig, is_torch_available
+from parameterized import parameterized
+
+from transformers import AutoTokenizer, FalconConfig, is_torch_available, set_seed
 from transformers.testing_utils import require_torch, slow, torch_device

 from ...generation.test_utils import GenerationTesterMixin
@@ -410,6 +412,37 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
                     past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
                 )

+    @parameterized.expand([("linear",), ("dynamic",)])
+    def test_model_rope_scaling(self, scaling_type):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        short_input = ids_tensor([1, 10], config.vocab_size)
+        long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
+
+        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
+        original_model = FalconModel(config)
+        original_model.to(torch_device)
+        original_model.eval()
+        original_short_output = original_model(short_input).last_hidden_state
+        original_long_output = original_model(long_input).last_hidden_state
+
+        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
+        config.rope_scaling = {"type": scaling_type, "factor": 10.0}
+        scaled_model = FalconModel(config)
+        scaled_model.to(torch_device)
+        scaled_model.eval()
+        scaled_short_output = scaled_model(short_input).last_hidden_state
+        scaled_long_output = scaled_model(long_input).last_hidden_state
+
+        # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
+        # maximum sequence length, so the outputs for the short input should match.
+        if scaling_type == "dynamic":
+            self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
+        else:
+            self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
+
+        # The output should be different for long inputs
+        self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
+

 @require_torch
 class FalconLanguageGenerationTest(unittest.TestCase):
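
Beyond the unit test, the intended end-user workflow is to request scaling at load time. A hedged sketch (checkpoint id and generation settings are illustrative; assumes `from_pretrained` forwards unrecognized kwargs such as `rope_scaling` to `FalconConfig`, as it does for other config overrides):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "tiiuae/falcon-7b"  # illustrative checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Request dynamic NTK scaling with a 2x factor; `max_position_embeddings`
    # is left at its pretrained value, as the new docstring recommends.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        rope_scaling={"type": "dynamic", "factor": 2.0},
    )

    inputs = tokenizer("A very long prompt would go here...", return_tensors="pt")
    output_ids = model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
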