Rename Phi-3 rope scaling type (#31436)

* renamed phi3 rope_scaling type

* fixed trailing whitespaces

* fixed test

* added warning

* fixed format
Amit Garg, 2024-07-23 03:33:22 -07:00, committed by GitHub
commit 034b477847 (parent bab32d6fe9)
3 changed files with 71 additions and 13 deletions

src/transformers/models/phi3/configuration_phi3.py

@@ -78,7 +78,7 @@ class Phi3Config(PretrainedConfig):
             The base period of the RoPE embeddings.
         rope_scaling (`dict`, *optional*):
             The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
-            contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be either `su` or `yarn` and
+            contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be `longrope` and
             the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size
             divided by the number of attention heads divided by 2.
         bos_token_id (`int`, *optional*, defaults to 1):
@@ -155,6 +155,7 @@ class Phi3Config(PretrainedConfig):
         self.use_cache = use_cache
         self.rope_theta = rope_theta
         self.rope_scaling = rope_scaling
+        self._rope_scaling_adjustment()
         self._rope_scaling_validation()
         self.sliding_window = sliding_window
@@ -166,6 +167,19 @@ class Phi3Config(PretrainedConfig):
             **kwargs,
         )

+    def _rope_scaling_adjustment(self):
+        """
+        Adjust the `type` of the `rope_scaling` configuration for backward compatibility.
+        """
+        if self.rope_scaling is None:
+            return
+
+        rope_scaling_type = self.rope_scaling.get("type", None)
+
+        # For backward compatibility if previous version used "su" or "yarn"
+        if rope_scaling_type is not None and rope_scaling_type in ["su", "yarn"]:
+            self.rope_scaling["type"] = "longrope"
+
     def _rope_scaling_validation(self):
         """
         Validate the `rope_scaling` configuration.
@@ -181,8 +195,8 @@ class Phi3Config(PretrainedConfig):
         rope_scaling_type = self.rope_scaling.get("type", None)
         rope_scaling_short_factor = self.rope_scaling.get("short_factor", None)
         rope_scaling_long_factor = self.rope_scaling.get("long_factor", None)
-        if rope_scaling_type is None or rope_scaling_type not in ["su", "yarn"]:
-            raise ValueError(f"`rope_scaling`'s type field must be one of ['su', 'yarn'], got {rope_scaling_type}")
+        if rope_scaling_type is None or rope_scaling_type not in ["longrope"]:
+            raise ValueError(f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}")
         if not (
             isinstance(rope_scaling_short_factor, list)
             and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
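
Taken together, the config changes mean new checkpoints should declare `rope_scaling["type"] = "longrope"`, while older configs that still say `"su"` or `"yarn"` are remapped before validation runs. A minimal sketch of that behavior, assuming a transformers install that includes this commit; the tiny sizes are illustrative, chosen so that the factor lists have hidden_size // num_attention_heads // 2 entries:

# Minimal sketch, assuming a transformers build that includes this commit.
# Sizes are illustrative: hidden_size // num_attention_heads // 2 = 64 // 4 // 2 = 8,
# so both factor lists need 8 entries.
from transformers import Phi3Config

factors = [1.0] * 8

# New-style config: the canonical type name passes validation.
config = Phi3Config(
    hidden_size=64,
    num_attention_heads=4,
    max_position_embeddings=256,
    original_max_position_embeddings=128,
    rope_scaling={"type": "longrope", "short_factor": factors, "long_factor": factors},
)
print(config.rope_scaling["type"])  # longrope

# Legacy config: "su" (or "yarn") is rewritten by _rope_scaling_adjustment()
# before _rope_scaling_validation() runs, so old checkpoints keep loading.
legacy = Phi3Config(
    hidden_size=64,
    num_attention_heads=4,
    max_position_embeddings=256,
    original_max_position_embeddings=128,
    rope_scaling={"type": "su", "short_factor": factors, "long_factor": factors},
)
print(legacy.rope_scaling["type"])  # longrope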

src/transformers/models/phi3/modeling_phi3.py

@@ -16,6 +16,7 @@
 """PyTorch Phi-3 model."""

 import math
+import warnings
 from typing import List, Optional, Tuple, Union

 import torch
@@ -106,6 +107,11 @@ class Phi3RotaryEmbedding(nn.Module):
 class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
     def __init__(self, dim, config, device=None):
+        warnings.warn(
+            "The class Phi3SuScaledRotaryEmbedding is deprecated and will be removed in version 5 of Transformers. Please"
+            " use Phi3LongRoPEScaledRotaryEmbedding instead.",
+            FutureWarning,
+        )
         super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)

         self.short_factor = config.rope_scaling["short_factor"]
@@ -119,13 +125,10 @@ class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
             ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
         else:
             ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
         inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
         self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
         # Force float32 since bfloat16 loses precision on long contexts
         # See https://github.com/huggingface/transformers/pull/29285
         device_type = x.device.type
@@ -133,13 +136,11 @@ class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             scale = self.max_position_embeddings / self.original_max_position_embeddings
             if scale <= 1.0:
                 scaling_factor = 1.0
             else:
                 scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
             cos = emb.cos() * scaling_factor
             sin = emb.sin() * scaling_factor
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
@@ -147,6 +148,10 @@ class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
 class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
     def __init__(self, dim, config, device=None):
+        warnings.warn(
+            "The class Phi3YarnScaledRotaryEmbedding is deprecated and will be removed in version 5 of Transformers",
+            FutureWarning,
+        )
         super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)

         self.short_factor = config.rope_scaling["short_factor"]
@@ -186,6 +191,47 @@ class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

+class Phi3LongRoPEScaledRotaryEmbedding(Phi3RotaryEmbedding):
+    def __init__(self, dim, config, device=None):
+        super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
+
+        self.short_factor = config.rope_scaling["short_factor"]
+        self.long_factor = config.rope_scaling["long_factor"]
+        self.original_max_position_embeddings = config.original_max_position_embeddings
+
+    @torch.no_grad()
+    def forward(self, x, position_ids, seq_len=None):
+        seq_len = torch.max(position_ids) + 1
+        if seq_len > self.original_max_position_embeddings:
+            ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
+        else:
+            ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
+
+        inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
+        self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
+
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+
+            scale = self.max_position_embeddings / self.original_max_position_embeddings
+            if scale <= 1.0:
+                scaling_factor = 1.0
+            else:
+                scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
+
+            cos = emb.cos() * scaling_factor
+            sin = emb.sin() * scaling_factor
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
 # Copied from transformers.models.llama.modeling_llama.rotate_half
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
@@ -300,10 +346,8 @@ class Phi3Attention(nn.Module):
             )
         else:
             scaling_type = self.config.rope_scaling["type"]
-            if scaling_type == "su":
-                self.rotary_emb = Phi3SuScaledRotaryEmbedding(self.head_dim, self.config)
-            elif scaling_type == "yarn":
-                self.rotary_emb = Phi3YarnScaledRotaryEmbedding(self.head_dim, self.config)
+            if scaling_type == "longrope":
+                self.rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(self.head_dim, self.config)
             else:
                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
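
The new `Phi3LongRoPEScaledRotaryEmbedding` keeps the same math as the class it replaces: the inverse frequencies are rescaled by the short or long factors, and cos/sin are multiplied by sqrt(1 + ln(scale) / ln(original_max_position_embeddings)) once the context is extended. A small worked example of that attention-scaling factor; the 131072 / 4096 pair below is an assumption taken from the Phi-3-mini-128k configuration, substitute your checkpoint's values:

# Worked example of the cos/sin scaling used by Phi3LongRoPEScaledRotaryEmbedding
# (and the legacy "su" class it replaces). The context lengths are assumptions
# borrowed from the Phi-3-mini-128k config.
import math

max_position_embeddings = 131072           # extended context length
original_max_position_embeddings = 4096    # pretraining context length

scale = max_position_embeddings / original_max_position_embeddings  # 32.0
if scale <= 1.0:
    scaling_factor = 1.0                   # no extension, plain RoPE
else:
    scaling_factor = math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))

print(round(scaling_factor, 4))  # ~1.1902, applied to both cos and sin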

tests/models/phi3/test_modeling_phi3.py

@@ -362,7 +362,7 @@ class Phi3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

-    @parameterized.expand([("su",), ("yarn",)])
+    @parameterized.expand([("longrope",)])
     def test_model_rope_scaling_from_config(self, scaling_type):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
         short_input = ids_tensor([1, 10], config.vocab_size)
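
After this change the RoPE-scaling test is parameterized over `longrope` only, and the attention layer now instantiates `Phi3LongRoPEScaledRotaryEmbedding` for that type. A rough sketch of the end-to-end path the test exercises, not its exact body; the tiny model dimensions are illustrative:

# Rough sketch of the pattern the test exercises, not its exact body.
# Dimensions are deliberately tiny; pad_token_id is set to stay inside vocab_size.
import torch
from transformers import Phi3Config, Phi3Model

factors = [1.0] * 8  # hidden_size // num_attention_heads // 2 = 64 // 4 // 2
config = Phi3Config(
    vocab_size=1024,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    max_position_embeddings=256,
    original_max_position_embeddings=64,
    pad_token_id=0,
    rope_scaling={"type": "longrope", "short_factor": factors, "long_factor": factors},
)

model = Phi3Model(config).eval()
# The attention layers pick the renamed rotary embedding for type "longrope".
print(type(model.layers[0].self_attn.rotary_emb).__name__)  # Phi3LongRoPEScaledRotaryEmbedding

with torch.no_grad():
    out = model(torch.randint(0, config.vocab_size, (1, 10)))
print(out.last_hidden_state.shape)  # torch.Size([1, 10, 64])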