Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Rename Phi-3 rope scaling type (#31436)
* renamed phi3 rope_scaling type
* fixed trailing whitespaces
* fixed test
* added warning
* fixed format
This commit is contained in:
parent bab32d6fe9
commit 034b477847
configuration_phi3.py
@@ -78,7 +78,7 @@ class Phi3Config(PretrainedConfig):
             The base period of the RoPE embeddings.
         rope_scaling (`dict`, *optional*):
             The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
-            contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be either `su` or `yarn` and
+            contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be `longrope` and
             the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size
             divided by the number of attention heads divided by 2.
         bos_token_id (`int`, *optional*, defaults to 1):
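Note (illustrative, not part of this commit): after the rename, a user-supplied `rope_scaling` dictionary is expected to look like the sketch below. The geometry assumes the Phi-3-mini shape (hidden_size 3072, 32 attention heads); the factor values themselves are placeholders.

from transformers import Phi3Config

config = Phi3Config(
    hidden_size=3072,
    num_attention_heads=32,
    max_position_embeddings=131072,
    original_max_position_embeddings=4096,
    rope_scaling={
        "type": "longrope",          # previously "su" or "yarn"
        "short_factor": [1.0] * 48,  # 3072 / 32 / 2 = 48 entries, placeholder values
        "long_factor": [1.5] * 48,
    },
)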
@@ -155,6 +155,7 @@ class Phi3Config(PretrainedConfig):
         self.use_cache = use_cache
         self.rope_theta = rope_theta
         self.rope_scaling = rope_scaling
+        self._rope_scaling_adjustment()
         self._rope_scaling_validation()
         self.sliding_window = sliding_window

@@ -166,6 +167,19 @@ class Phi3Config(PretrainedConfig):
             **kwargs,
         )

+    def _rope_scaling_adjustment(self):
+        """
+        Adjust the `type` of the `rope_scaling` configuration for backward compatibility.
+        """
+        if self.rope_scaling is None:
+            return
+
+        rope_scaling_type = self.rope_scaling.get("type", None)
+
+        # For backward compatibility if previous version used "su" or "yarn"
+        if rope_scaling_type is not None and rope_scaling_type in ["su", "yarn"]:
+            self.rope_scaling["type"] = "longrope"
+
     def _rope_scaling_validation(self):
         """
         Validate the `rope_scaling` configuration.
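Note (illustrative, not part of this commit): the adjustment above means configurations saved with the old type names keep loading unchanged from the user's point of view. A minimal sketch:

from transformers import Phi3Config

legacy_scaling = {"type": "su", "short_factor": [1.0] * 48, "long_factor": [1.5] * 48}
config = Phi3Config(rope_scaling=legacy_scaling)  # defaults: hidden_size=3072, 32 heads
print(config.rope_scaling["type"])  # "longrope" after _rope_scaling_adjustment()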
@@ -181,8 +195,8 @@ class Phi3Config(PretrainedConfig):
         rope_scaling_type = self.rope_scaling.get("type", None)
         rope_scaling_short_factor = self.rope_scaling.get("short_factor", None)
         rope_scaling_long_factor = self.rope_scaling.get("long_factor", None)
-        if rope_scaling_type is None or rope_scaling_type not in ["su", "yarn"]:
-            raise ValueError(f"`rope_scaling`'s type field must be one of ['su', 'yarn'], got {rope_scaling_type}")
+        if rope_scaling_type is None or rope_scaling_type not in ["longrope"]:
+            raise ValueError(f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}")
         if not (
             isinstance(rope_scaling_short_factor, list)
             and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
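Note (illustrative, not part of this commit): only the legacy names "su" and "yarn" are remapped; any other unknown type still fails validation, now with the updated message. A quick sketch:

from transformers import Phi3Config

try:
    Phi3Config(rope_scaling={"type": "linear", "short_factor": [1.0] * 48, "long_factor": [1.5] * 48})
except ValueError as err:
    print(err)  # `rope_scaling`'s type field must be one of ['longrope'], got linear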
modeling_phi3.py
@@ -16,6 +16,7 @@
 """PyTorch Phi-3 model."""

 import math
+import warnings
 from typing import List, Optional, Tuple, Union

 import torch
@@ -106,6 +107,11 @@ class Phi3RotaryEmbedding(nn.Module):

 class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
     def __init__(self, dim, config, device=None):
+        warnings.warn(
+            "The class Phi3SuScaledRotaryEmbedding is deprecated and will be removed in version 5 of Transformers. Please"
+            " use Phi3LongRoPEScaledRotaryEmbedding instead.",
+            FutureWarning,
+        )
         super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)

         self.short_factor = config.rope_scaling["short_factor"]
@@ -119,13 +125,10 @@ class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
             ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
         else:
             ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
-
         inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
         self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
-
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
-
         # Force float32 since bfloat16 loses precision on long contexts
         # See https://github.com/huggingface/transformers/pull/29285
         device_type = x.device.type
@@ -133,13 +136,11 @@ class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
-
             scale = self.max_position_embeddings / self.original_max_position_embeddings
             if scale <= 1.0:
                 scaling_factor = 1.0
             else:
                 scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
-
             cos = emb.cos() * scaling_factor
             sin = emb.sin() * scaling_factor
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
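Note (illustrative, not part of this commit): the magnitude correction applied to `cos`/`sin` above is shared by the deprecated classes and the new `Phi3LongRoPEScaledRotaryEmbedding`. A worked example, assuming the 128k-context numbers:

import math

original_max_position_embeddings = 4096
max_position_embeddings = 131072

scale = max_position_embeddings / original_max_position_embeddings  # 32.0
scaling_factor = math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))
print(round(scaling_factor, 4))  # 1.1902, i.e. cos/sin are scaled up by ~19%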
@@ -147,6 +148,10 @@ class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):

 class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
     def __init__(self, dim, config, device=None):
+        warnings.warn(
+            "The class Phi3YarnScaledRotaryEmbedding is deprecated and will be removed in version 5 of Transformers",
+            FutureWarning,
+        )
         super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)

         self.short_factor = config.rope_scaling["short_factor"]
@@ -186,6 +191,47 @@ class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


+class Phi3LongRoPEScaledRotaryEmbedding(Phi3RotaryEmbedding):
+    def __init__(self, dim, config, device=None):
+        super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
+
+        self.short_factor = config.rope_scaling["short_factor"]
+        self.long_factor = config.rope_scaling["long_factor"]
+        self.original_max_position_embeddings = config.original_max_position_embeddings
+
+    @torch.no_grad()
+    def forward(self, x, position_ids, seq_len=None):
+        seq_len = torch.max(position_ids) + 1
+        if seq_len > self.original_max_position_embeddings:
+            ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
+        else:
+            ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
+
+        inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
+        self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
+
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+
+            scale = self.max_position_embeddings / self.original_max_position_embeddings
+            if scale <= 1.0:
+                scaling_factor = 1.0
+            else:
+                scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
+
+            cos = emb.cos() * scaling_factor
+            sin = emb.sin() * scaling_factor
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
 # Copied from transformers.models.llama.modeling_llama.rotate_half
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
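Note (illustrative, not part of this commit): a minimal sketch of driving the new class directly on this branch of `modeling_phi3.py`; shapes assume the default Phi-3-mini head size of 96, and the factor values are placeholders.

import torch
from transformers import Phi3Config
from transformers.models.phi3.modeling_phi3 import Phi3LongRoPEScaledRotaryEmbedding

config = Phi3Config()  # defaults: hidden_size=3072, 32 heads -> head_dim=96
config.rope_scaling = {"type": "longrope", "short_factor": [1.0] * 48, "long_factor": [1.5] * 48}

rope = Phi3LongRoPEScaledRotaryEmbedding(dim=96, config=config)
x = torch.randn(1, 8, 96)                    # only dtype/device of x are used
position_ids = torch.arange(8).unsqueeze(0)  # (batch, seq_len)
cos, sin = rope(x, position_ids)
print(cos.shape, sin.shape)                  # torch.Size([1, 8, 96]) for both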
@@ -300,10 +346,8 @@ class Phi3Attention(nn.Module):
             )
         else:
             scaling_type = self.config.rope_scaling["type"]
-            if scaling_type == "su":
-                self.rotary_emb = Phi3SuScaledRotaryEmbedding(self.head_dim, self.config)
-            elif scaling_type == "yarn":
-                self.rotary_emb = Phi3YarnScaledRotaryEmbedding(self.head_dim, self.config)
+            if scaling_type == "longrope":
+                self.rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(self.head_dim, self.config)
             else:
                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

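Note (illustrative, not part of this commit): because `Phi3Config._rope_scaling_adjustment()` rewrites the legacy names before a model is built, the branch above only ever sees "longrope". A small sketch with a deliberately tiny, made-up geometry:

from transformers import Phi3Config
from transformers.models.phi3.modeling_phi3 import Phi3Model

config = Phi3Config(
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=1,
    num_attention_heads=8,
    num_key_value_heads=8,
    max_position_embeddings=8192,
    original_max_position_embeddings=4096,
    rope_scaling={"type": "su", "short_factor": [1.0] * 4, "long_factor": [1.5] * 4},
)
model = Phi3Model(config)
print(type(model.layers[0].self_attn.rotary_emb).__name__)  # Phi3LongRoPEScaledRotaryEmbedding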
test_modeling_phi3.py
@@ -362,7 +362,7 @@ class Phi3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

-    @parameterized.expand([("su",), ("yarn",)])
+    @parameterized.expand([("longrope",)])
     def test_model_rope_scaling_from_config(self, scaling_type):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
         short_input = ids_tensor([1, 10], config.vocab_size)
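Note (illustrative, not part of this commit): a quick reminder of what `@parameterized.expand` does with the renamed entry; it generates one test method per tuple and passes the tuple entries as arguments.

import unittest
from parameterized import parameterized

class RopeScalingNameDemo(unittest.TestCase):
    @parameterized.expand([("longrope",)])
    def test_scaling_type(self, scaling_type):
        # With a single tuple, exactly one parametrized test is generated
        self.assertEqual(scaling_type, "longrope")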