mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
parent
2af199c42b
commit
083e13b7c4
@ -150,7 +150,7 @@ def _compute_dynamic_ntk_parameters(
|
||||
attention_factor = 1.0 # Unused in this type of RoPE
|
||||
|
||||
# seq_len: default to max_position_embeddings, e.g. at init time
|
||||
seq_len = seq_len if seq_len is not None else max_position_embeddings
|
||||
seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings
|
||||
|
||||
# Compute the inverse frequencies
|
||||
base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
|
||||
@ -210,7 +210,7 @@ def _compute_yarn_parameters(
|
||||
high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
|
||||
return max(low, 0), min(high, dim - 1)
|
||||
|
||||
def linear_ramp_mask(min, max, dim):
|
||||
def linear_ramp_factor(min, max, dim):
|
||||
if min == max:
|
||||
max += 0.001 # Prevent singularity
|
||||
|
||||
@ -218,6 +218,8 @@ def _compute_yarn_parameters(
|
||||
ramp_func = torch.clamp(linear_func, 0, 1)
|
||||
return ramp_func
|
||||
|
||||
# Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
|
||||
# to expand the possible context length. In other words, interpolation = apply scaling factor.
|
||||
pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
|
||||
inv_freq_extrapolation = 1.0 / pos_freqs
|
||||
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
|
||||
@ -225,8 +227,11 @@ def _compute_yarn_parameters(
|
||||
low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
|
||||
|
||||
# Get n-dimensional rotational scaling corrected for extrapolation
|
||||
inv_freq_mask = 1 - linear_ramp_mask(low, high, dim // 2).float().to(device)
|
||||
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
|
||||
inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
|
||||
inv_freq = (
|
||||
inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
|
||||
+ inv_freq_extrapolation * inv_freq_extrapolation_factor
|
||||
)
|
||||
|
||||
return inv_freq, attention_factor
|
||||
|
||||
|
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import math
|
||||
import unittest
|
||||
|
||||
from transformers import LlamaConfig
|
||||
@ -116,5 +117,323 @@ class RopeTest(unittest.TestCase):
|
||||
kwargs_freqs = rope_fn(**rope_kwargs, device=device)[0]
|
||||
torch.testing.assert_close(config_freqs, kwargs_freqs)
|
||||
|
||||
def test_default_rope_numerically(self):
|
||||
# Note: some RoPE scaling methods start off by calling the default RoPE frequencies. If this test fails, then
|
||||
# multiple RoPE strategies will fail.
|
||||
# fmt: off
|
||||
EXPECTED_INV_FREQ = torch.tensor(
|
||||
[
|
||||
1.0000e+00, 8.6596e-01, 7.4989e-01, 6.4938e-01, 5.6234e-01, 4.8697e-01,
|
||||
4.2170e-01, 3.6517e-01, 3.1623e-01, 2.7384e-01, 2.3714e-01, 2.0535e-01,
|
||||
1.7783e-01, 1.5399e-01, 1.3335e-01, 1.1548e-01, 1.0000e-01, 8.6596e-02,
|
||||
7.4989e-02, 6.4938e-02, 5.6234e-02, 4.8697e-02, 4.2170e-02, 3.6517e-02,
|
||||
3.1623e-02, 2.7384e-02, 2.3714e-02, 2.0535e-02, 1.7783e-02, 1.5399e-02,
|
||||
1.3335e-02, 1.1548e-02, 1.0000e-02, 8.6596e-03, 7.4989e-03, 6.4938e-03,
|
||||
5.6234e-03, 4.8697e-03, 4.2170e-03, 3.6517e-03, 3.1623e-03, 2.7384e-03,
|
||||
2.3714e-03, 2.0535e-03, 1.7783e-03, 1.5399e-03, 1.3335e-03, 1.1548e-03,
|
||||
1.0000e-03, 8.6596e-04, 7.4989e-04, 6.4938e-04, 5.6234e-04, 4.8697e-04,
|
||||
4.2170e-04, 3.6517e-04, 3.1623e-04, 2.7384e-04, 2.3714e-04, 2.0535e-04,
|
||||
1.7783e-04, 1.5399e-04, 1.3335e-04, 1.1548e-04
|
||||
], device=torch_device
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# TODO(joao): numerical checks for the different RoPE fns
|
||||
# input sanity checks: if these change, the output will also change
|
||||
config = LlamaConfig()
|
||||
self.assertEqual(config.rope_scaling, None)
|
||||
self.assertEqual(config.hidden_size, 4096)
|
||||
self.assertEqual(config.num_attention_heads, 32)
|
||||
self.assertEqual(config.rope_theta, 10000.0)
|
||||
self.assertFalse(hasattr(config, "partial_rotary_factor"))
|
||||
|
||||
rope_fn = ROPE_INIT_FUNCTIONS["default"]
|
||||
inv_freq, attention_scale = rope_fn(config=config, device=torch_device)
|
||||
|
||||
self.assertEqual(attention_scale, 1.0) # attention scale is always 1 for default RoPE
|
||||
torch.testing.assert_close(inv_freq, EXPECTED_INV_FREQ)
|
||||
|
||||
def test_linear_rope_numerically(self):
|
||||
# This is a linear scaling strategy, the **frequencies** are scaled linearly with respect to the default
|
||||
# frequencies (= the inverse frequencies are scaled **inversely**)
|
||||
config = LlamaConfig()
|
||||
default_rope_fn = ROPE_INIT_FUNCTIONS["default"]
|
||||
default_inv_freq, _ = default_rope_fn(config=config, device=torch_device)
|
||||
|
||||
rope_fn = ROPE_INIT_FUNCTIONS["linear"]
|
||||
for factor in (2.0, 10.0, 20.0):
|
||||
config.rope_scaling = {"rope_type": "linear", "factor": factor}
|
||||
inv_freq, attention_scale = rope_fn(config=config, device=torch_device)
|
||||
self.assertEqual(attention_scale, 1.0) # attention scale is always 1 for linear RoPE
|
||||
torch.testing.assert_close(inv_freq, default_inv_freq / factor)
|
||||
|
||||
def test_dynamic_rope_numerically(self):
|
||||
# fmt: off
|
||||
EXPECTED_INV_FREQ = torch.tensor(
|
||||
[
|
||||
1.0000e+00, 8.0931e-01, 6.5498e-01, 5.3008e-01, 4.2900e-01, 3.4720e-01,
|
||||
2.8099e-01, 2.2741e-01, 1.8404e-01, 1.4895e-01, 1.2055e-01, 9.7558e-02,
|
||||
7.8955e-02, 6.3899e-02, 5.1714e-02, 4.1853e-02, 3.3872e-02, 2.7413e-02,
|
||||
2.2185e-02, 1.7955e-02, 1.4531e-02, 1.1760e-02, 9.5176e-03, 7.7027e-03,
|
||||
6.2339e-03, 5.0451e-03, 4.0831e-03, 3.3045e-03, 2.6744e-03, 2.1644e-03,
|
||||
1.7517e-03, 1.4176e-03, 1.1473e-03, 9.2852e-04, 7.5146e-04, 6.0817e-04,
|
||||
4.9220e-04, 3.9834e-04, 3.2238e-04, 2.6091e-04, 2.1115e-04, 1.7089e-04,
|
||||
1.3830e-04, 1.1193e-04, 9.0585e-05, 7.3312e-05, 5.9332e-05, 4.8018e-05,
|
||||
3.8861e-05, 3.1451e-05, 2.5453e-05, 2.0600e-05, 1.6672e-05, 1.3492e-05,
|
||||
1.0920e-05, 8.8374e-06, 7.1522e-06, 5.7883e-06, 4.6845e-06, 3.7912e-06,
|
||||
3.0683e-06, 2.4832e-06, 2.0097e-06, 1.6265e-06
|
||||
], device=torch_device
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# input sanity checks: if these change, the output will also change
|
||||
config = LlamaConfig()
|
||||
self.assertEqual(config.rope_scaling, None)
|
||||
self.assertEqual(config.hidden_size, 4096)
|
||||
self.assertEqual(config.num_attention_heads, 32)
|
||||
self.assertEqual(config.rope_theta, 10000.0)
|
||||
self.assertFalse(hasattr(config, "partial_rotary_factor"))
|
||||
|
||||
rope_fn = ROPE_INIT_FUNCTIONS["default"]
|
||||
default_inv_freq, _ = rope_fn(config=config, device=torch_device)
|
||||
|
||||
# Check 1: this is a dynamic scaling strategy, it will not scale unless we provide `seq_len` larger than the
|
||||
# model's original training sequence length
|
||||
rope_fn = ROPE_INIT_FUNCTIONS["dynamic"]
|
||||
for factor in (2.0, 10.0, 20.0):
|
||||
config.rope_scaling = {"rope_type": "dynamic", "factor": factor}
|
||||
inv_freq, attention_scale = rope_fn(config=config, device=torch_device)
|
||||
self.assertEqual(attention_scale, 1.0) # attention scale is always 1 for dynamic RoPE
|
||||
torch.testing.assert_close(inv_freq, default_inv_freq)
|
||||
|
||||
inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=1)
|
||||
torch.testing.assert_close(inv_freq, default_inv_freq)
|
||||
|
||||
# Check 2: if we provide `seq_len` larger than the model's original training sequence length, the frequencies
|
||||
# will scale up (i.e., the inverse frequencies will scale down).
|
||||
factor = 10.0
|
||||
config.rope_scaling = {"rope_type": "dynamic", "factor": factor}
|
||||
inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=16384)
|
||||
with self.assertRaises(AssertionError): # It is NOT a linear factor
|
||||
torch.testing.assert_close(inv_freq, default_inv_freq / factor)
|
||||
torch.testing.assert_close(inv_freq, EXPECTED_INV_FREQ)
|
||||
|
||||
def test_yarn_rope_numerically(self):
|
||||
# fmt: off
|
||||
EXPECTED_INV_FREQ = torch.tensor(
|
||||
[
|
||||
1.0000e+00, 8.6596e-01, 7.4989e-01, 6.4938e-01, 5.6234e-01, 4.8697e-01,
|
||||
4.2170e-01, 3.6517e-01, 3.1623e-01, 2.7384e-01, 2.3714e-01, 2.0535e-01,
|
||||
1.7783e-01, 1.5399e-01, 1.3335e-01, 1.1548e-01, 1.0000e-01, 8.3479e-02,
|
||||
6.9590e-02, 5.7925e-02, 4.8136e-02, 3.9931e-02, 3.3061e-02, 2.7315e-02,
|
||||
2.2515e-02, 1.8512e-02, 1.5177e-02, 1.2403e-02, 1.0101e-02, 8.1924e-03,
|
||||
6.6143e-03, 5.3120e-03, 4.2400e-03, 3.3599e-03, 2.6396e-03, 2.0520e-03,
|
||||
1.5746e-03, 1.1882e-03, 8.7713e-04, 6.2810e-04, 4.3007e-04, 2.7384e-04,
|
||||
2.3714e-04, 2.0535e-04, 1.7783e-04, 1.5399e-04, 1.3335e-04, 1.1548e-04,
|
||||
1.0000e-04, 8.6596e-05, 7.4989e-05, 6.4938e-05, 5.6234e-05, 4.8697e-05,
|
||||
4.2170e-05, 3.6517e-05, 3.1623e-05, 2.7384e-05, 2.3714e-05, 2.0535e-05,
|
||||
1.7783e-05, 1.5399e-05, 1.3335e-05, 1.1548e-05
|
||||
], device=torch_device
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# input sanity checks: if these change, the output will also change
|
||||
config = LlamaConfig()
|
||||
self.assertEqual(config.rope_scaling, None)
|
||||
self.assertEqual(config.hidden_size, 4096)
|
||||
self.assertEqual(config.num_attention_heads, 32)
|
||||
self.assertEqual(config.rope_theta, 10000.0)
|
||||
self.assertFalse(hasattr(config, "partial_rotary_factor"))
|
||||
|
||||
rope_fn = ROPE_INIT_FUNCTIONS["default"]
|
||||
default_inv_freq, _ = rope_fn(config=config, device=torch_device)
|
||||
|
||||
# Check 1: according to the paper, if `attention_factor` is not specified, then it has a specific default --
|
||||
# `0.1 * math.log(factor) + 1.0`
|
||||
rope_fn = ROPE_INIT_FUNCTIONS["yarn"]
|
||||
for factor in (2.0, 10.0, 20.0):
|
||||
config.rope_scaling = {"rope_type": "yarn", "factor": factor}
|
||||
_, attention_scale = rope_fn(config=config, device=torch_device)
|
||||
self.assertEqual(attention_scale, 0.1 * math.log(factor) + 1.0)
|
||||
|
||||
config.rope_scaling = {"rope_type": "yarn", "factor": factor, "attention_factor": 0.5}
|
||||
_, attention_scale = rope_fn(config=config, device=torch_device, seq_len=1)
|
||||
self.assertEqual(attention_scale, 0.5)
|
||||
|
||||
# Check 2: based on `beta_fast` and `beta_slow`, the frequencies will be scaled between 1 and `factor`.
|
||||
# Increasing `beta_fast` will make RoPE more interpolative (apply scaling), and the other way around.
|
||||
# `beta_slow` behaves the opposite way. Remember: `beta_fast` > `beta_slow`
|
||||
# (note: adds a margin to the test for numerical stability)
|
||||
factor = 10.0
|
||||
margin = 1e-8
|
||||
config.rope_scaling = {"rope_type": "yarn", "factor": factor, "beta_fast": 32, "beta_slow": 1}
|
||||
inv_freq, _ = rope_fn(config=config, device=torch_device)
|
||||
is_bounded_by_factor = [
|
||||
((default_inv_freq[idx] / factor) - margin) <= yarn_inv_freq_value <= (default_inv_freq[idx] + margin)
|
||||
for idx, yarn_inv_freq_value in enumerate(inv_freq)
|
||||
]
|
||||
self.assertTrue(all(is_bounded_by_factor))
|
||||
|
||||
# super high beta_fast = interpolation (i.e. scaling) in all but the first inverse frequency. The last ~20
|
||||
# values (empirically checked for `beta_fast` = 1000) should be very small to linear scaling
|
||||
config.rope_scaling = {"rope_type": "yarn", "factor": factor, "beta_fast": 1000, "beta_slow": 1}
|
||||
inv_freq, _ = rope_fn(config=config, device=torch_device)
|
||||
is_interpolating = [
|
||||
yarn_inv_freq_value < (default_inv_freq[idx] + margin) for idx, yarn_inv_freq_value in enumerate(inv_freq)
|
||||
]
|
||||
self.assertFalse(is_interpolating[0])
|
||||
self.assertTrue(all(is_interpolating[1:]))
|
||||
torch.testing.assert_close(inv_freq[-20:], default_inv_freq[-20:] / factor)
|
||||
|
||||
# Check 3: numerical snapshot to avoid regressions
|
||||
config.rope_scaling = {"rope_type": "yarn", "factor": factor, "beta_fast": 32, "beta_slow": 1}
|
||||
inv_freq, _ = rope_fn(config=config, device=torch_device)
|
||||
torch.testing.assert_close(inv_freq, EXPECTED_INV_FREQ)
|
||||
|
||||
def test_longrope_rope_numerically(self):
|
||||
# input sanity checks: if these change, the output will also change
|
||||
config = LlamaConfig()
|
||||
self.assertEqual(config.rope_scaling, None)
|
||||
self.assertEqual(config.hidden_size, 4096)
|
||||
self.assertEqual(config.num_attention_heads, 32)
|
||||
self.assertEqual(config.rope_theta, 10000.0)
|
||||
self.assertFalse(hasattr(config, "partial_rotary_factor"))
|
||||
|
||||
# longrope applies scaling on EACH inv frequency, `short_factor` or `long_factor`, depending on `factor`
|
||||
dim = config.hidden_size // config.num_attention_heads
|
||||
short_factor = [2.0] * (dim // 2) # scaling applied when factor == 1.0
|
||||
long_factor = torch.ones(dim // 2).cumsum(0).tolist() # scaling applied when factor > 1.0
|
||||
|
||||
rope_fn = ROPE_INIT_FUNCTIONS["default"]
|
||||
default_inv_freq, _ = rope_fn(config=config, device=torch_device)
|
||||
|
||||
# Check 1: according to the paper, if `attention_factor` is not specified, then it has a specific default --
|
||||
# `math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))`
|
||||
rope_fn = ROPE_INIT_FUNCTIONS["longrope"]
|
||||
max_position_embeddings = config.max_position_embeddings
|
||||
for factor in (2.0, 10.0, 20.0):
|
||||
config.rope_scaling = {
|
||||
"rope_type": "longrope",
|
||||
"factor": factor,
|
||||
"short_factor": short_factor,
|
||||
"long_factor": long_factor,
|
||||
}
|
||||
_, attention_scale = rope_fn(config=config, device=torch_device)
|
||||
self.assertEqual(attention_scale, math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings)))
|
||||
|
||||
config.rope_scaling = {
|
||||
"rope_type": "longrope",
|
||||
"factor": factor,
|
||||
"short_factor": short_factor,
|
||||
"long_factor": long_factor,
|
||||
"attention_factor": 0.5,
|
||||
}
|
||||
_, attention_scale = rope_fn(config=config, device=torch_device, seq_len=1)
|
||||
self.assertEqual(attention_scale, 0.5)
|
||||
|
||||
# Check 2: Factor == 1.0 -> short factor is applied to the default frequencies
|
||||
factor = 1.0
|
||||
config.rope_scaling = {
|
||||
"rope_type": "longrope",
|
||||
"factor": factor,
|
||||
"short_factor": short_factor,
|
||||
"long_factor": long_factor,
|
||||
}
|
||||
inv_freq, _ = rope_fn(config=config, device=torch_device)
|
||||
torch.testing.assert_close(inv_freq, default_inv_freq / torch.tensor(short_factor).to(torch_device))
|
||||
|
||||
# Check 3: Factor > 1.0 -> long factor is applied to the default frequencies
|
||||
factor = 10.0
|
||||
config.rope_scaling = {
|
||||
"rope_type": "longrope",
|
||||
"factor": factor,
|
||||
"short_factor": short_factor,
|
||||
"long_factor": long_factor,
|
||||
}
|
||||
inv_freq, _ = rope_fn(config=config, device=torch_device)
|
||||
torch.testing.assert_close(inv_freq, default_inv_freq / torch.tensor(long_factor).to(torch_device))
|
||||
|
||||
def test_llama3_rope_numerically(self):
|
||||
# fmt: off
|
||||
EXPECTED_INV_FREQ = torch.tensor(
|
||||
[
|
||||
1.0000e+00, 8.6596e-01, 7.4989e-01, 6.4938e-01, 5.6234e-01, 4.8697e-01,
|
||||
4.2170e-01, 3.6517e-01, 3.1623e-01, 2.7384e-01, 2.3714e-01, 2.0535e-01,
|
||||
1.7783e-01, 1.5399e-01, 1.3335e-01, 1.1548e-01, 1.0000e-01, 8.6596e-02,
|
||||
7.4989e-02, 6.4938e-02, 5.6234e-02, 4.8697e-02, 4.2170e-02, 3.6517e-02,
|
||||
3.1623e-02, 2.7384e-02, 2.3714e-02, 2.0535e-02, 1.7783e-02, 1.5399e-02,
|
||||
1.3335e-02, 1.0730e-02, 7.7785e-03, 5.6009e-03, 3.9991e-03, 2.8248e-03,
|
||||
1.9675e-03, 1.3449e-03, 8.9549e-04, 5.7363e-04, 3.4539e-04, 2.7384e-04,
|
||||
2.3714e-04, 2.0535e-04, 1.7783e-04, 1.5399e-04, 1.3335e-04, 1.1548e-04,
|
||||
1.0000e-04, 8.6596e-05, 7.4989e-05, 6.4938e-05, 5.6234e-05, 4.8697e-05,
|
||||
4.2170e-05, 3.6517e-05, 3.1623e-05, 2.7384e-05, 2.3714e-05, 2.0535e-05,
|
||||
1.7783e-05, 1.5399e-05, 1.3335e-05, 1.1548e-05
|
||||
], device=torch_device
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
# input sanity checks: if these change, the output will also change
|
||||
config = LlamaConfig()
|
||||
self.assertEqual(config.rope_scaling, None)
|
||||
self.assertEqual(config.hidden_size, 4096)
|
||||
self.assertEqual(config.num_attention_heads, 32)
|
||||
self.assertEqual(config.rope_theta, 10000.0)
|
||||
self.assertFalse(hasattr(config, "partial_rotary_factor"))
|
||||
|
||||
rope_fn = ROPE_INIT_FUNCTIONS["default"]
|
||||
default_inv_freq, _ = rope_fn(config=config, device=torch_device)
|
||||
|
||||
# Check 1: `attention_factor` is always 1
|
||||
rope_fn = ROPE_INIT_FUNCTIONS["llama3"]
|
||||
for factor in (2.0, 10.0, 20.0):
|
||||
config.rope_scaling = {
|
||||
"rope_type": "llama3",
|
||||
"factor": factor,
|
||||
"original_max_position_embeddings": 2048,
|
||||
"low_freq_factor": 1,
|
||||
"high_freq_factor": 4,
|
||||
}
|
||||
_, attention_scale = rope_fn(config=config, device=torch_device)
|
||||
self.assertEqual(attention_scale, 1.0)
|
||||
|
||||
# Check 2: based on `low_freq_factor` and `high_freq_factor`, the frequencies will be scaled between 1 and
|
||||
# `factor` (similar to yarn). Low frequencies get scaled by `factor`, high frequences see no change, medium
|
||||
# frequencies are scaled by a value in between. Changing `low_freq_factor` and `high_freq_factor` changes what
|
||||
# is considered low, medium, and high frequencies.
|
||||
factor = 10.0
|
||||
config.rope_scaling = {
|
||||
"rope_type": "llama3",
|
||||
"factor": factor,
|
||||
"original_max_position_embeddings": 2048,
|
||||
"low_freq_factor": 1,
|
||||
"high_freq_factor": 4,
|
||||
}
|
||||
inv_freq, _ = rope_fn(config=config, device=torch_device)
|
||||
is_bounded_by_factor = [
|
||||
(default_inv_freq[idx] / factor) <= llama3_inv_freq_value <= default_inv_freq[idx]
|
||||
for idx, llama3_inv_freq_value in enumerate(inv_freq)
|
||||
]
|
||||
self.assertTrue(all(is_bounded_by_factor))
|
||||
|
||||
# if we change `high_freq_factor` to a very high value, none is considered high-frequency -> ALL values will be
|
||||
# scaled
|
||||
config.rope_scaling = config.rope_scaling = {
|
||||
"rope_type": "llama3",
|
||||
"factor": factor,
|
||||
"original_max_position_embeddings": 2048,
|
||||
"low_freq_factor": 1,
|
||||
"high_freq_factor": 1000,
|
||||
}
|
||||
inv_freq, _ = rope_fn(config=config, device=torch_device)
|
||||
is_scaled = [yarn_inv_freq_value < default_inv_freq[idx] for idx, yarn_inv_freq_value in enumerate(inv_freq)]
|
||||
self.assertTrue(all(is_scaled))
|
||||
|
||||
# Check 3: numerical snapshot to avoid regressions
|
||||
config.rope_scaling = {
|
||||
"rope_type": "llama3",
|
||||
"factor": factor,
|
||||
"original_max_position_embeddings": 2048,
|
||||
"low_freq_factor": 1,
|
||||
"high_freq_factor": 4,
|
||||
}
|
||||
inv_freq, _ = rope_fn(config=config, device=torch_device)
|
||||
torch.testing.assert_close(inv_freq, EXPECTED_INV_FREQ)
|
||||
|
Loading…
Reference in New Issue
Block a user