mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
[Llama ROPE
] Fix torch export but also slow downs in forward (#29198)
* remove control flow
* update gptneox
* update ....
* nits
* Actually let's just break. Otherwise we are silently failing which imo is not optimal
* version BC
* fix tests
* fix eager causal
* nit
* add a test
* style
* nits
* nits
* more nits for the test
* update and fix
* make sure cuda graphs are not skipped
* read token is needed for meta llama
* update!
* fiixup
* compile test should be slow
* fix thet fix copies
* stle 🫠
This commit is contained in:
parent
7c87f3577e
commit
8a8a0a4ae0
@ -563,10 +563,11 @@ class GPTNeoXRotaryEmbedding(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding.__init__
|
||||
# TODO @gante bring compatibility back
|
||||
class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
|
||||
"""GPTNeoXRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding.__init__
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(dim, max_position_embeddings, base, device)
|
||||
@ -586,7 +587,8 @@ class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
|
||||
class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
|
||||
"""GPTNeoXRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding.__init__
|
||||
# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding.__init__
|
||||
# TODO @gante no longer copied from
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(dim, max_position_embeddings, base, device)
|
||||
|
@ -92,54 +92,55 @@ ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
|
||||
|
||||
|
||||
class LlamaRotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||
super().__init__()
|
||||
self.scaling_factor = scaling_factor
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
# For BC we register cos and sin cached
|
||||
self.max_seq_len_cached = max_position_embeddings
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
||||
t = t / self.scaling_factor
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
|
||||
self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)
|
||||
|
||||
@property
|
||||
def sin_cached(self):
|
||||
logger.warning_once(
|
||||
"The sin_cached attribute will be removed in 4.40. Bear in mind that its contents changed in v4.38. Use "
|
||||
"the forward method of RoPE from now on instead."
|
||||
"The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
|
||||
"the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
|
||||
)
|
||||
return self._sin_cached
|
||||
|
||||
@property
|
||||
def cos_cached(self):
|
||||
logger.warning_once(
|
||||
"The cos_cached attribute will be removed in 4.40. Bear in mind that its contents changed in v4.38. Use "
|
||||
"the forward method of RoPE from now on instead."
|
||||
"The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
|
||||
"the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
|
||||
)
|
||||
return self._cos_cached
|
||||
|
||||
def forward(self, x, position_ids, seq_len=None):
|
||||
if seq_len is not None:
|
||||
logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.40.")
|
||||
logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.39.")
|
||||
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos().to(dtype=x.dtype)
|
||||
sin = emb.sin().to(dtype=x.dtype)
|
||||
# backwards compatibility
|
||||
self._cos_cached = cos
|
||||
self._sin_cached = sin
|
||||
return cos, sin
|
||||
return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
|
||||
|
||||
|
||||
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
|
||||
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
||||
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(dim, max_position_embeddings, base, device)
|
||||
|
||||
def forward(self, x, position_ids, seq_len=None):
|
||||
# difference to the original RoPE: a scaling factor is aplied to the position ids
|
||||
position_ids = position_ids.float() / self.scaling_factor
|
||||
@ -150,10 +151,6 @@ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
|
||||
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
|
||||
"""LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
||||
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(dim, max_position_embeddings, base, device)
|
||||
|
||||
def forward(self, x, position_ids, seq_len=None):
|
||||
# difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
|
||||
seq_len = torch.max(position_ids) + 1
|
||||
@ -367,6 +364,7 @@ class LlamaAttention(nn.Module):
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask
|
||||
if cache_position is not None:
|
||||
causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
@ -20,10 +20,12 @@ import unittest
|
||||
import pytest
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import LlamaConfig, is_torch_available, set_seed
|
||||
from transformers import LlamaConfig, StaticCache, is_torch_available, logging, set_seed
|
||||
from transformers.testing_utils import (
|
||||
CaptureLogger,
|
||||
require_bitsandbytes,
|
||||
require_flash_attn,
|
||||
require_read_token,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
require_torch_gpu,
|
||||
@ -595,6 +597,56 @@ class LlamaIntegrationTest(unittest.TestCase):
|
||||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_read_token
|
||||
def test_compile_static_cache(self):
|
||||
NUM_TOKENS_TO_GENERATE = 40
|
||||
EXPECTED_TEXT_COMPLETION = [
|
||||
"Simply put, the theory of relativity states that 1) the speed of light is constant, 2) the speed of light is the same for all observers, and 3) the laws of physics are the same for all observers.",
|
||||
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
|
||||
]
|
||||
prompts = [
|
||||
"Simply put, the theory of relativity states that ",
|
||||
"My favorite all time favorite condiment is ketchup.",
|
||||
]
|
||||
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
|
||||
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="sequential")
|
||||
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
|
||||
|
||||
def decode_one_tokens(model, cur_token, input_pos, cache_position):
|
||||
logits = model(
|
||||
cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True
|
||||
)[0]
|
||||
new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
|
||||
return new_token
|
||||
|
||||
batch_size, seq_length = inputs["input_ids"].shape
|
||||
with torch.no_grad():
|
||||
model._setup_cache(StaticCache, 2, max_cache_len=4096)
|
||||
cache_position = torch.arange(seq_length, device=torch_device)
|
||||
generated_ids = torch.zeros(
|
||||
batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device
|
||||
)
|
||||
generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int)
|
||||
|
||||
logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0]
|
||||
next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
|
||||
generated_ids[:, seq_length] = next_token[:, 0]
|
||||
|
||||
decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
|
||||
cache_position = torch.tensor([seq_length + 1], device=torch_device)
|
||||
for _ in range(1, NUM_TOKENS_TO_GENERATE):
|
||||
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
|
||||
with CaptureLogger(logging.get_logger(__name__)) as cl:
|
||||
next_token = decode_one_tokens(model, next_token.clone(), None, cache_position)
|
||||
self.assertNotIn("skipping cudagraphs due to", cl.out)
|
||||
generated_ids[:, cache_position] = next_token.int()
|
||||
cache_position += 1
|
||||
|
||||
text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||
|
||||
|
||||
@require_torch
|
||||
class CodeLlamaIntegrationTest(unittest.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user