mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 02:31:11 +06:00
Llama/GPTNeoX: add RoPE scaling (#24653)
* add rope_scaling * tmp commit * add gptneox * add tests * GPTNeoX can now handle long inputs, so the pipeline test was wrong * Update src/transformers/models/open_llama/configuration_open_llama.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * remove ntk * remove redundant validation --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
9342c8fb82
commit
34d9409427
@ -78,6 +78,15 @@ class GPTNeoXConfig(PretrainedConfig):
|
||||
use_parallel_residual (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use a "parallel" formulation in each Transformer layer, which can provide a slight training
|
||||
speedup at large scales (e.g. 20B).
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling
|
||||
strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format
|
||||
is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
||||
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
|
||||
these scaling strategies behave:
|
||||
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
|
||||
experimental feature, subject to breaking API changes in future versions.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
@ -115,6 +124,7 @@ class GPTNeoXConfig(PretrainedConfig):
|
||||
eos_token_id=2,
|
||||
tie_word_embeddings=False,
|
||||
use_parallel_residual=True,
|
||||
rope_scaling=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||
@ -135,7 +145,32 @@ class GPTNeoXConfig(PretrainedConfig):
|
||||
self.use_cache = use_cache
|
||||
self.tie_word_embeddings = tie_word_embeddings
|
||||
self.use_parallel_residual = use_parallel_residual
|
||||
self.rope_scaling = rope_scaling
|
||||
self._rope_scaling_validation()
|
||||
|
||||
if self.hidden_size % self.num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
"The hidden size is not divisble by the number of attention heads! Make sure to update them!"
|
||||
)
|
||||
|
||||
# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
|
||||
def _rope_scaling_validation(self):
|
||||
"""
|
||||
Validate the `rope_scaling` configuration.
|
||||
"""
|
||||
if self.rope_scaling is None:
|
||||
return
|
||||
|
||||
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
||||
raise ValueError(
|
||||
"`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, "
|
||||
f"got {self.rope_scaling}"
|
||||
)
|
||||
rope_scaling_type = self.rope_scaling.get("type", None)
|
||||
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
||||
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
|
||||
raise ValueError(
|
||||
f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
||||
)
|
||||
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
||||
raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
|
||||
|
@ -86,6 +86,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
|
||||
class GPTNeoXAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
if self.hidden_size % self.num_attention_heads != 0:
|
||||
@ -94,18 +95,11 @@ class GPTNeoXAttention(nn.Module):
|
||||
)
|
||||
self.head_size = self.hidden_size // self.num_attention_heads
|
||||
self.rotary_ndims = int(self.head_size * config.rotary_pct)
|
||||
max_positions = config.max_position_embeddings
|
||||
self.register_buffer(
|
||||
"bias",
|
||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
|
||||
1, 1, max_positions, max_positions
|
||||
),
|
||||
persistent=False,
|
||||
)
|
||||
self._init_bias(config.max_position_embeddings)
|
||||
|
||||
self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
|
||||
self.rotary_emb = RotaryEmbedding(
|
||||
self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base
|
||||
)
|
||||
self._init_rope()
|
||||
|
||||
self.register_buffer(
|
||||
"norm_factor",
|
||||
torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()),
|
||||
@ -113,9 +107,44 @@ class GPTNeoXAttention(nn.Module):
|
||||
)
|
||||
self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
|
||||
self.attention_dropout = nn.Dropout(config.attention_dropout)
|
||||
|
||||
def _init_bias(self, max_positions, device=None):
|
||||
self.register_buffer(
|
||||
"bias",
|
||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
|
||||
1, 1, max_positions, max_positions
|
||||
),
|
||||
persistent=False,
|
||||
)
|
||||
if device is not None:
|
||||
self.bias = self.bias.to(device)
|
||||
|
||||
def _init_rope(self):
|
||||
if self.config.rope_scaling is None:
|
||||
self.rotary_emb = GPTNeoXRotaryEmbedding(
|
||||
self.rotary_ndims, self.config.max_position_embeddings, base=self.config.rotary_emb_base
|
||||
)
|
||||
else:
|
||||
scaling_type = self.config.rope_scaling["type"]
|
||||
scaling_factor = self.config.rope_scaling["factor"]
|
||||
if scaling_type == "linear":
|
||||
self.rotary_emb = GPTNeoXLinearScalingRotaryEmbedding(
|
||||
self.rotary_ndims,
|
||||
self.config.max_position_embeddings,
|
||||
base=self.config.rotary_emb_base,
|
||||
scaling_factor=scaling_factor,
|
||||
)
|
||||
elif scaling_type == "dynamic":
|
||||
self.rotary_emb = GPTNeoXDynamicNTKScalingRotaryEmbedding(
|
||||
self.rotary_ndims,
|
||||
self.config.max_position_embeddings,
|
||||
base=self.config.rotary_emb_base,
|
||||
scaling_factor=scaling_factor,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
@ -210,6 +239,9 @@ class GPTNeoXAttention(nn.Module):
|
||||
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
|
||||
key_length = key.size(-2)
|
||||
|
||||
# dynamically increase the causal mask with the key length, if needed.
|
||||
if key_length > self.bias.shape[-1]:
|
||||
self._init_bias(key_length, device=key.device)
|
||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
|
||||
|
||||
query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
|
||||
@ -258,15 +290,23 @@ def attention_mask_func(attention_scores, ltor_mask):
|
||||
return attention_scores
|
||||
|
||||
|
||||
class RotaryEmbedding(torch.nn.Module):
|
||||
class GPTNeoXRotaryEmbedding(torch.nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings, base=10000, device=None):
|
||||
super().__init__()
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
|
||||
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
||||
self.register_buffer("inv_freq", inv_freq)
|
||||
|
||||
# Build here to make `torch.jit.trace` work.
|
||||
self.max_seq_len_cached = max_position_embeddings
|
||||
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
||||
self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
||||
|
||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
@ -275,16 +315,54 @@ class RotaryEmbedding(torch.nn.Module):
|
||||
|
||||
def forward(self, x, seq_len=None):
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
|
||||
if seq_len > self.max_seq_len_cached:
|
||||
self._set_cos_sin_cache(seq_len=seq_len, device=x.device)
|
||||
return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device)
|
||||
|
||||
|
||||
class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
|
||||
"""GPTNeoXRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
||||
|
||||
def __init__(self, dim, max_position_embeddings, base=10000, device=None, scaling_factor=1.0):
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(dim, max_position_embeddings, base, device)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
||||
t = t / self.scaling_factor
|
||||
|
||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.cos_cached = emb.cos()[None, None, :, :]
|
||||
self.sin_cached = emb.sin()[None, None, :, :]
|
||||
|
||||
|
||||
class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
|
||||
"""GPTNeoXRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
||||
|
||||
def __init__(self, dim, max_position_embeddings, base=10000, device=None, scaling_factor=1.0):
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(dim, max_position_embeddings, base, device)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device):
|
||||
self.max_seq_len_cached = seq_len
|
||||
|
||||
if seq_len > self.max_position_embeddings:
|
||||
base = self.base * (
|
||||
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
|
||||
) ** (self.dim / (self.dim - 2))
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
||||
self.register_buffer("inv_freq", inv_freq)
|
||||
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
||||
|
||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.cos_cached = emb.cos()[None, None, :, :]
|
||||
self.sin_cached = emb.sin()[None, None, :, :]
|
||||
return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
|
@ -238,16 +238,24 @@ class GPTNeoXJapaneseAttention(nn.Module):
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.gpt_neox.modeling_gpt_neox.RotaryEmbedding
|
||||
# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding
|
||||
class RotaryEmbedding(torch.nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings, base=10000, device=None):
|
||||
super().__init__()
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
|
||||
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
||||
self.register_buffer("inv_freq", inv_freq)
|
||||
|
||||
# Build here to make `torch.jit.trace` work.
|
||||
self.max_seq_len_cached = max_position_embeddings
|
||||
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
||||
self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
||||
|
||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
@ -256,15 +264,8 @@ class RotaryEmbedding(torch.nn.Module):
|
||||
|
||||
def forward(self, x, seq_len=None):
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
|
||||
if seq_len > self.max_seq_len_cached:
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
|
||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
||||
self.cos_cached = emb.cos()[None, None, :, :]
|
||||
self.sin_cached = emb.sin()[None, None, :, :]
|
||||
self._set_cos_sin_cache(seq_len=seq_len, device=x.device)
|
||||
return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device)
|
||||
|
||||
|
||||
|
@ -64,6 +64,15 @@ class LlamaConfig(PretrainedConfig):
|
||||
relevant if `config.is_decoder=True`.
|
||||
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
|
||||
Whether to tie weight embeddings
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling
|
||||
strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format
|
||||
is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
||||
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
|
||||
these scaling strategies behave:
|
||||
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
|
||||
experimental feature, subject to breaking API changes in future versions.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
@ -97,6 +106,7 @@ class LlamaConfig(PretrainedConfig):
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
tie_word_embeddings=False,
|
||||
rope_scaling=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
@ -109,6 +119,9 @@ class LlamaConfig(PretrainedConfig):
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_scaling = rope_scaling
|
||||
self._rope_scaling_validation()
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
@ -116,3 +129,24 @@ class LlamaConfig(PretrainedConfig):
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _rope_scaling_validation(self):
|
||||
"""
|
||||
Validate the `rope_scaling` configuration.
|
||||
"""
|
||||
if self.rope_scaling is None:
|
||||
return
|
||||
|
||||
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
||||
raise ValueError(
|
||||
"`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, "
|
||||
f"got {self.rope_scaling}"
|
||||
)
|
||||
rope_scaling_type = self.rope_scaling.get("type", None)
|
||||
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
||||
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
|
||||
raise ValueError(
|
||||
f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
||||
)
|
||||
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
||||
raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
|
||||
|
@ -91,36 +91,84 @@ class LlamaRMSNorm(nn.Module):
|
||||
class LlamaRotaryEmbedding(torch.nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||
super().__init__()
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
|
||||
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
||||
self.register_buffer("inv_freq", inv_freq)
|
||||
|
||||
# Build here to make `torch.jit.trace` work.
|
||||
self.max_seq_len_cached = max_position_embeddings
|
||||
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
||||
self._set_cos_sin_cache(
|
||||
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
|
||||
)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
||||
|
||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
dtype = torch.get_default_dtype()
|
||||
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
|
||||
|
||||
def forward(self, x, seq_len=None):
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
|
||||
if seq_len > self.max_seq_len_cached:
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
|
||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
||||
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
|
||||
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
||||
|
||||
return (
|
||||
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
||||
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
||||
)
|
||||
|
||||
|
||||
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
|
||||
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
||||
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(dim, max_position_embeddings, base, device)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
||||
t = t / self.scaling_factor
|
||||
|
||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
|
||||
|
||||
|
||||
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
|
||||
"""LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
||||
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(dim, max_position_embeddings, base, device)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
|
||||
if seq_len > self.max_position_embeddings:
|
||||
base = self.base * (
|
||||
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
|
||||
) ** (self.dim / (self.dim - 2))
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
||||
self.register_buffer("inv_freq", inv_freq)
|
||||
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
||||
|
||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
@ -176,7 +224,24 @@ class LlamaAttention(nn.Module):
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
||||
self._init_rope()
|
||||
|
||||
def _init_rope(self):
|
||||
if self.config.rope_scaling is None:
|
||||
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
|
||||
else:
|
||||
scaling_type = self.config.rope_scaling["type"]
|
||||
scaling_factor = self.config.rope_scaling["factor"]
|
||||
if scaling_type == "linear":
|
||||
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
|
||||
self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
|
||||
)
|
||||
elif scaling_type == "dynamic":
|
||||
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
|
||||
self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
@ -67,6 +67,15 @@ class OpenLlamaConfig(PretrainedConfig):
|
||||
relevant if `config.is_decoder=True`.
|
||||
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
|
||||
Whether to tie weight embeddings
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling
|
||||
strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format
|
||||
is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
||||
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
|
||||
these scaling strategies behave:
|
||||
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
|
||||
experimental feature, subject to breaking API changes in future versions.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
@ -104,6 +113,7 @@ class OpenLlamaConfig(PretrainedConfig):
|
||||
attention_dropout_prob=0.1,
|
||||
use_stable_embedding=True,
|
||||
shared_input_output_embedding=True,
|
||||
rope_scaling=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
@ -123,6 +133,9 @@ class OpenLlamaConfig(PretrainedConfig):
|
||||
self.attention_dropout_prob = attention_dropout_prob
|
||||
self.use_stable_embedding = use_stable_embedding
|
||||
self.shared_input_output_embedding = shared_input_output_embedding
|
||||
self.rope_scaling = rope_scaling
|
||||
self._rope_scaling_validation()
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
@ -130,3 +143,25 @@ class OpenLlamaConfig(PretrainedConfig):
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
|
||||
def _rope_scaling_validation(self):
|
||||
"""
|
||||
Validate the `rope_scaling` configuration.
|
||||
"""
|
||||
if self.rope_scaling is None:
|
||||
return
|
||||
|
||||
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
||||
raise ValueError(
|
||||
"`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, "
|
||||
f"got {self.rope_scaling}"
|
||||
)
|
||||
rope_scaling_type = self.rope_scaling.get("type", None)
|
||||
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
||||
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
|
||||
raise ValueError(
|
||||
f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
||||
)
|
||||
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
||||
raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
|
||||
|
@ -102,36 +102,86 @@ class OpenLlamaRMSNorm(nn.Module):
|
||||
class OpenLlamaRotaryEmbedding(torch.nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||
super().__init__()
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
|
||||
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
||||
self.register_buffer("inv_freq", inv_freq)
|
||||
|
||||
# Build here to make `torch.jit.trace` work.
|
||||
self.max_seq_len_cached = max_position_embeddings
|
||||
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
||||
self._set_cos_sin_cache(
|
||||
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
|
||||
)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
||||
|
||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
dtype = torch.get_default_dtype()
|
||||
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
|
||||
|
||||
def forward(self, x, seq_len=None):
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
|
||||
if seq_len > self.max_seq_len_cached:
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
|
||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
||||
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
|
||||
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
||||
|
||||
return (
|
||||
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
||||
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->OpenLlama
|
||||
class OpenLlamaLinearScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
|
||||
"""OpenLlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
||||
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(dim, max_position_embeddings, base, device)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
||||
t = t / self.scaling_factor
|
||||
|
||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->OpenLlama
|
||||
class OpenLlamaDynamicNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
|
||||
"""OpenLlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
||||
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(dim, max_position_embeddings, base, device)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
|
||||
if seq_len > self.max_position_embeddings:
|
||||
base = self.base * (
|
||||
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
|
||||
) ** (self.dim / (self.dim - 2))
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
||||
self.register_buffer("inv_freq", inv_freq)
|
||||
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
||||
|
||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
@ -190,7 +240,27 @@ class OpenLlamaAttention(nn.Module):
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
||||
self.rotary_emb = OpenLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
|
||||
self._init_rope()
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->OpenLlama
|
||||
def _init_rope(self):
|
||||
if self.config.rope_scaling is None:
|
||||
self.rotary_emb = OpenLlamaRotaryEmbedding(
|
||||
self.head_dim, max_position_embeddings=self.max_position_embeddings
|
||||
)
|
||||
else:
|
||||
scaling_type = self.config.rope_scaling["type"]
|
||||
scaling_factor = self.config.rope_scaling["factor"]
|
||||
if scaling_type == "linear":
|
||||
self.rotary_emb = OpenLlamaLinearScalingRotaryEmbedding(
|
||||
self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
|
||||
)
|
||||
elif scaling_type == "dynamic":
|
||||
self.rotary_emb = OpenLlamaDynamicNTKScalingRotaryEmbedding(
|
||||
self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
@ -17,7 +17,9 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import AutoTokenizer, GPTNeoXConfig, is_torch_available
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoTokenizer, GPTNeoXConfig, is_torch_available, set_seed
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
@ -49,7 +51,7 @@ class GPTNeoXModelTester:
|
||||
use_token_type_ids=True,
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
hidden_size=64,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
@ -298,6 +300,37 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
@parameterized.expand([("linear",), ("dynamic",)])
|
||||
def test_model_rope_scaling(self, scaling_type):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
short_input = ids_tensor([1, 10], config.vocab_size)
|
||||
long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
|
||||
|
||||
set_seed(42) # Fixed seed at init time so the two models get the same random weights
|
||||
original_model = GPTNeoXModel(config)
|
||||
original_model.to(torch_device)
|
||||
original_model.eval()
|
||||
original_short_output = original_model(short_input).last_hidden_state
|
||||
original_long_output = original_model(long_input).last_hidden_state
|
||||
|
||||
set_seed(42) # Fixed seed at init time so the two models get the same random weights
|
||||
config.rope_scaling = {"type": scaling_type, "factor": 10.0}
|
||||
scaled_model = GPTNeoXModel(config)
|
||||
scaled_model.to(torch_device)
|
||||
scaled_model.eval()
|
||||
scaled_short_output = scaled_model(short_input).last_hidden_state
|
||||
scaled_long_output = scaled_model(long_input).last_hidden_state
|
||||
|
||||
# Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
|
||||
# maximum sequence length, so the outputs for the short input should match.
|
||||
if scaling_type == "dynamic":
|
||||
self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
|
||||
else:
|
||||
self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
|
||||
|
||||
# The output should be different for long inputs
|
||||
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
|
||||
|
||||
|
||||
@require_torch
|
||||
class GPTNeoXLanguageGenerationTest(unittest.TestCase):
|
||||
|
@ -17,7 +17,9 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import LlamaConfig, is_torch_available
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import LlamaConfig, is_torch_available, set_seed
|
||||
from transformers.testing_utils import require_torch, torch_device
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
@ -332,3 +334,34 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
@unittest.skip("LLaMA buffers include complex numbers, which breaks this test")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
@parameterized.expand([("linear",), ("dynamic",)])
|
||||
def test_model_rope_scaling(self, scaling_type):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
short_input = ids_tensor([1, 10], config.vocab_size)
|
||||
long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
|
||||
|
||||
set_seed(42) # Fixed seed at init time so the two models get the same random weights
|
||||
original_model = LlamaModel(config)
|
||||
original_model.to(torch_device)
|
||||
original_model.eval()
|
||||
original_short_output = original_model(short_input).last_hidden_state
|
||||
original_long_output = original_model(long_input).last_hidden_state
|
||||
|
||||
set_seed(42) # Fixed seed at init time so the two models get the same random weights
|
||||
config.rope_scaling = {"type": scaling_type, "factor": 10.0}
|
||||
scaled_model = LlamaModel(config)
|
||||
scaled_model.to(torch_device)
|
||||
scaled_model.eval()
|
||||
scaled_short_output = scaled_model(short_input).last_hidden_state
|
||||
scaled_long_output = scaled_model(long_input).last_hidden_state
|
||||
|
||||
# Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
|
||||
# maximum sequence length, so the outputs for the short input should match.
|
||||
if scaling_type == "dynamic":
|
||||
self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
|
||||
else:
|
||||
self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
|
||||
|
||||
# The output should be different for long inputs
|
||||
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
|
||||
|
@ -17,7 +17,9 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import OpenLlamaConfig, is_torch_available
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import OpenLlamaConfig, is_torch_available, set_seed
|
||||
from transformers.testing_utils import require_torch, torch_device
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
@ -335,3 +337,34 @@ class OpenLlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
@unittest.skip("Open-Llama buffers include complex numbers, which breaks this test")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
@parameterized.expand([("linear",), ("dynamic",)])
|
||||
def test_model_rope_scaling(self, scaling_type):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
short_input = ids_tensor([1, 10], config.vocab_size)
|
||||
long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
|
||||
|
||||
set_seed(42) # Fixed seed at init time so the two models get the same random weights
|
||||
original_model = OpenLlamaModel(config)
|
||||
original_model.to(torch_device)
|
||||
original_model.eval()
|
||||
original_short_output = original_model(short_input).last_hidden_state
|
||||
original_long_output = original_model(long_input).last_hidden_state
|
||||
|
||||
set_seed(42) # Fixed seed at init time so the two models get the same random weights
|
||||
config.rope_scaling = {"type": scaling_type, "factor": 10.0}
|
||||
scaled_model = OpenLlamaModel(config)
|
||||
scaled_model.to(torch_device)
|
||||
scaled_model.eval()
|
||||
scaled_short_output = scaled_model(short_input).last_hidden_state
|
||||
scaled_long_output = scaled_model(long_input).last_hidden_state
|
||||
|
||||
# Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
|
||||
# maximum sequence length, so the outputs for the short input should match.
|
||||
if scaling_type == "dynamic":
|
||||
self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
|
||||
else:
|
||||
self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
|
||||
|
||||
# The output should be different for long inputs
|
||||
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
|
||||
|
@ -240,7 +240,7 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
||||
# We don't care about infinite range models.
|
||||
# They already work.
|
||||
# Skip this test for XGLM, since it uses sinusoidal positional embeddings which are resized on-the-fly.
|
||||
EXTRA_MODELS_CAN_HANDLE_LONG_INPUTS = ["RwkvForCausalLM", "XGLMForCausalLM"]
|
||||
EXTRA_MODELS_CAN_HANDLE_LONG_INPUTS = ["RwkvForCausalLM", "XGLMForCausalLM", "GPTNeoXForCausalLM"]
|
||||
if (
|
||||
tokenizer.model_max_length < 10000
|
||||
and text_generator.model.__class__.__name__ not in EXTRA_MODELS_CAN_HANDLE_LONG_INPUTS
|
||||
|
Loading…
Reference in New Issue
Block a user