Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Refactor: Use Llama RoPE implementation for Falcon (#26933)
* Use Llama RoPE implementation for Falcon + Add copy functionalities
* Use standard cache format for Falcon
* Simplify apply_rotary_pos_emb, copy from Llama
* Remove unnecessary cache conversion test
  We don't need to convert any caches anymore!
* Resolve copy complaint
This commit is contained in:
parent e9a6c72b5e
commit 05ea7b79e6
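In short, the commit swaps Falcon's bespoke rotary module, which rotated queries and keys itself on a fused [batch_size * num_heads, seq_len, head_dim] layout, for the Llama-style split where the embedding module only returns cos/sin tables and apply_rotary_pos_emb rotates 4D [batch_size, num_heads, seq_len, head_dim] tensors. A minimal sketch of the new call pattern (editorial illustration only, assuming the FalconRotaryEmbedding and apply_rotary_pos_emb definitions from the diff below are in scope):

import torch

batch_size, num_heads, seq_len, head_dim = 2, 4, 8, 16

query = torch.randn(batch_size, num_heads, seq_len, head_dim)
key = torch.randn(batch_size, num_heads, seq_len, head_dim)
position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)

# New pattern: the module returns cos/sin, apply_rotary_pos_emb does the rotation.
rotary_emb = FalconRotaryEmbedding(head_dim, max_position_embeddings=2048, base=10000)
cos, sin = rotary_emb(key, seq_len=seq_len)
query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids)

# Old pattern (removed below): self.maybe_rotary(query, key, past_kv_length, position_ids)
# rotated [batch_size * num_heads, seq_len, head_dim] tensors internally.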
@@ -71,12 +71,43 @@ class FalconLinear(nn.Linear):
         return hidden_states + self.bias


-# rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
+# Copied from transformers.models.llama.modeling_llama.rotate_half
 def rotate_half(x):
-    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
     return torch.cat((-x2, x1), dim=-1)


+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`):
+            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+            used to pass offsetted position ids when working with a KV-cache.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
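For clarity, a small standalone illustration of how the helper added above is used (an editorial sketch, not part of the commit; it assumes the rotate_half and apply_rotary_pos_emb definitions from this hunk are in scope):

import torch

batch_size, num_heads, seq_len, head_dim = 1, 2, 4, 8

q = torch.randn(batch_size, num_heads, seq_len, head_dim)
k = torch.randn(batch_size, num_heads, seq_len, head_dim)

# cos/sin tables shaped [seq_len, head_dim], the layout produced by the rotary embedding modules below
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.einsum("i,j->ij", torch.arange(seq_len).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()

# position_ids can be offset by the past KV length when a cache is used
position_ids = torch.arange(seq_len).unsqueeze(0)

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
assert q_rot.shape == q.shape and k_rot.shape == k.shape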
@@ -90,138 +121,88 @@ def _get_unpad_data(attention_mask):
     )


-# TODO (joao): Is this the same implementation as in Llama? If so, let's make them the same and add the copy facilities
+# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Falcon
 class FalconRotaryEmbedding(nn.Module):
-    """Implementation of RotaryEmbedding from GPT-NeoX.
-    This implementation is designed to operate on queries and keys that are compatible with `[batch_size,
-    n_heads_per_partition, seq_len, head_dim]` (e.g. MinGPTAttention format).
-    """
-
-    def __init__(self, head_dim: int, base=10000, max_position_embeddings=2048):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
-        self.base = base
+
+        self.dim = dim
         self.max_position_embeddings = max_position_embeddings
-        inv_freq = 1.0 / (self.base ** (torch.arange(0, head_dim, 2).float() / head_dim))
+        self.base = base
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.head_dim = head_dim
-        self.seq_len_cached = -1
-        self.cos_cached: torch.Tensor | None = None
-        self.sin_cached: torch.Tensor | None = None
+
+        # Build here to make `torch.jit.trace` work.
+        self._set_cos_sin_cache(
+            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+        )

     def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.seq_len_cached = seq_len
-        t = torch.arange(seq_len, device=device).to(dtype)
+        self.max_seq_len_cached = seq_len
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-        emb = torch.cat((freqs, freqs), dim=-1).to(device)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

-        if dtype in [torch.float16, torch.bfloat16]:
-            emb = emb.float()
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if seq_len > self.max_seq_len_cached:
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

-        self.cos_cached = emb.cos()
-        self.sin_cached = emb.sin()
-
-        self.cos_cached = self.cos_cached.type(dtype)
-        self.sin_cached = self.sin_cached.type(dtype)
-
-    def cos_sin(
-        self, seq_len: int, past_key_values_length: int, position_ids: torch.Tensor, device="cpu", dtype=torch.bfloat16
-    ) -> torch.Tensor:
-        total_length = seq_len + past_key_values_length
-        if total_length > self.seq_len_cached:
-            self._set_cos_sin_cache(total_length, device, dtype)
-
-        # the cached tensors need to update their devices (for example, after we change the model's device)
-        self.cos_cached = self.cos_cached.to(device)
-        self.sin_cached = self.sin_cached.to(device)
-
-        # Gather cos, sin at the designated position ids
-        cos = self.cos_cached[position_ids]  # [bs, seq_len, dim]
-        sin = self.sin_cached[position_ids]  # [bs, seq_len, dim]
-        return cos, sin
-
-    def forward(self, query, key, past_key_values_length, position_ids):
-        _, seq_len, _ = query.shape
-        cos, sin = self.cos_sin(seq_len, past_key_values_length, position_ids, query.device, query.dtype)
-        # Query and key's shapes are [bs * num_heads, seq_len, dim], might need manual expansion. Ifs and elses used to
-        # avoid unnecessary repeat_interleave operations.
-        query_expansion_factor = int(query.shape[0] / cos.shape[0])
-        if query_expansion_factor > 1:
-            query_cos = torch.repeat_interleave(cos, query_expansion_factor, dim=0)
-            query_sin = torch.repeat_interleave(sin, query_expansion_factor, dim=0)
-        else:
-            query_cos, query_sin = cos, sin
-
-        key_expansion_factor = int(key.shape[0] / cos.shape[0])
-        if key_expansion_factor > 1:
-            if key_expansion_factor != query_expansion_factor:
-                key_cos = torch.repeat_interleave(cos, key_expansion_factor, dim=0)
-                key_sin = torch.repeat_interleave(sin, key_expansion_factor, dim=0)
-            else:
-                key_cos, key_sin = query_cos, query_sin
-        else:
-            key_cos, key_sin = cos, sin
-
-        return (query * query_cos) + (rotate_half(query) * query_sin), (key * key_cos) + (rotate_half(key) * key_sin)
+        return (
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
+        )


+# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Falcon
 class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
     """FalconRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

-    def __init__(self, head_dim: int, base=10000, max_position_embeddings=2048, scaling_factor=1.0):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
         self.scaling_factor = scaling_factor
-        super().__init__(head_dim, base, max_position_embeddings)
+        super().__init__(dim, max_position_embeddings, base, device)

     def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.seq_len_cached = seq_len
-        t = torch.arange(seq_len, device=device).to(dtype)
-        # This line is the only difference from FalconRotaryEmbedding._set_cos_sin_cache
+        self.max_seq_len_cached = seq_len
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
         t = t / self.scaling_factor

         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-        emb = torch.cat((freqs, freqs), dim=-1).to(device)
-
-        if dtype in [torch.float16, torch.bfloat16]:
-            emb = emb.float()
-
-        self.cos_cached = emb.cos()
-        self.sin_cached = emb.sin()
-
-        self.cos_cached = self.cos_cached.type(dtype)
-        self.sin_cached = self.sin_cached.type(dtype)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


+# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Falcon
 class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding):
-    """
-    FalconRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
-    """
+    """FalconRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

-    def __init__(self, head_dim: int, base=10000, max_position_embeddings=2048, scaling_factor=1.0):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
         self.scaling_factor = scaling_factor
-        super().__init__(head_dim, base, max_position_embeddings)
+        super().__init__(dim, max_position_embeddings, base, device)

     def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.seq_len_cached = seq_len
+        self.max_seq_len_cached = seq_len

-        # This if block is the only difference from FalconRotaryEmbedding._set_cos_sin_cache
         if seq_len > self.max_position_embeddings:
             base = self.base * (
                 (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
-            ) ** (self.head_dim / (self.head_dim - 2))
-            inv_freq = 1.0 / (base ** (torch.arange(0, self.head_dim, 2).float().to(device) / self.head_dim))
+            ) ** (self.dim / (self.dim - 2))
+            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
             self.register_buffer("inv_freq", inv_freq, persistent=False)

-        t = torch.arange(seq_len, device=device).to(dtype)
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-        emb = torch.cat((freqs, freqs), dim=-1).to(device)
-
-        if dtype in [torch.float16, torch.bfloat16]:
-            emb = emb.float()
-
-        self.cos_cached = emb.cos()
-        self.sin_cached = emb.sin()
-
-        self.cos_cached = self.cos_cached.type(dtype)
-        self.sin_cached = self.sin_cached.type(dtype)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


 def _prepare_4d_attention_mask(mask: torch.Tensor, past_key_values_length: int) -> torch.BoolTensor:
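The interface change in this hunk, illustrated (an editorial sketch, not part of the commit, assuming the class defined above is in scope): FalconRotaryEmbedding.forward no longer rotates queries and keys itself; like the Llama module it is copied from, it just returns the cached cos/sin tables.

import torch

head_dim, seq_len = 64, 10
rotary_emb = FalconRotaryEmbedding(head_dim, max_position_embeddings=2048, base=10000)

# New: x is only used for its dtype/device; seq_len is the total KV length (past + current).
x = torch.randn(1, 4, seq_len, head_dim)
cos, sin = rotary_emb(x, seq_len=seq_len)
print(cos.shape, sin.shape)  # torch.Size([10, 64]) torch.Size([10, 64])

# Old (removed): forward(query, key, past_key_values_length, position_ids) returned the
# already-rotated (query, key) pair in the fused [batch_size * num_heads, seq_len, head_dim] layout.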
@@ -293,6 +274,8 @@ class FalconAttention(nn.Module):
         self.head_dim = self.hidden_size // self.num_heads
         self.split_size = self.hidden_size
         self.hidden_dropout = config.hidden_dropout
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rope_theta = config.rope_theta
         self.is_causal = True

         if self.head_dim * self.num_heads != self.hidden_size:
@@ -301,7 +284,8 @@ class FalconAttention(nn.Module):
                 f" {self.num_heads})."
             )

-        self.maybe_rotary = self._init_rope() if config.rotary else lambda q, k, t, p: (q, k)
+        if config.rotary:
+            self._init_rope()

         # Layer-wise attention scaling
         self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
@@ -319,33 +303,33 @@ class FalconAttention(nn.Module):
         self.attention_dropout = nn.Dropout(config.attention_dropout)
         self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1

+    # Copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Falcon
     def _init_rope(self):
         if self.config.rope_scaling is None:
-            rotary_emb = FalconRotaryEmbedding(
+            self.rotary_emb = FalconRotaryEmbedding(
                 self.head_dim,
-                base=self.config.rope_theta,
-                max_position_embeddings=self.config.max_position_embeddings,
+                max_position_embeddings=self.max_position_embeddings,
+                base=self.rope_theta,
             )
         else:
             scaling_type = self.config.rope_scaling["type"]
             scaling_factor = self.config.rope_scaling["factor"]
             if scaling_type == "linear":
-                rotary_emb = FalconLinearScalingRotaryEmbedding(
+                self.rotary_emb = FalconLinearScalingRotaryEmbedding(
                     self.head_dim,
-                    base=self.config.rope_theta,
-                    max_position_embeddings=self.config.max_position_embeddings,
+                    max_position_embeddings=self.max_position_embeddings,
                     scaling_factor=scaling_factor,
+                    base=self.rope_theta,
                 )
             elif scaling_type == "dynamic":
-                rotary_emb = FalconDynamicNTKScalingRotaryEmbedding(
+                self.rotary_emb = FalconDynamicNTKScalingRotaryEmbedding(
                     self.head_dim,
-                    base=self.config.rope_theta,
-                    max_position_embeddings=self.config.max_position_embeddings,
+                    max_position_embeddings=self.max_position_embeddings,
                     scaling_factor=scaling_factor,
+                    base=self.rope_theta,
                 )
             else:
                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
-        return rotary_emb

     def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
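Usage note (a sketch, not part of the diff): which of the three classes _init_rope instantiates is selected by config.rope_scaling, using the same {"type", "factor"} contract as Llama.

from transformers import FalconConfig

# Plain RoPE (config.rope_scaling is None)
config_default = FalconConfig(rope_theta=10000.0)

# Linear scaling -> FalconLinearScalingRotaryEmbedding
config_linear = FalconConfig(rope_scaling={"type": "linear", "factor": 2.0})

# Dynamic NTK scaling -> FalconDynamicNTKScalingRotaryEmbedding
config_dynamic = FalconConfig(rope_scaling={"type": "dynamic", "factor": 2.0})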
@@ -428,35 +412,31 @@ class FalconAttention(nn.Module):

         batch_size, query_length, _, _ = query_layer.shape

-        query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
-        key_layer = key_layer.transpose(1, 2).reshape(
-            batch_size * num_kv_heads,
-            query_length,
-            self.head_dim,
-        )
-        value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
+        query_layer = query_layer.transpose(1, 2).reshape(batch_size, self.num_heads, query_length, self.head_dim)
+        key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
+        value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)

-        past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
-        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length, position_ids)
+        kv_seq_len = key_layer.shape[-2]
+        if layer_past is not None:
+            kv_seq_len += layer_past[0].shape[-2]
+        if alibi is None:
+            cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
+            query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)

         if layer_past is not None:
             past_key, past_value = layer_past
             # concatenate along seq_length dimension:
-            # - key: [batch_size * self.num_heads, kv_length, head_dim]
-            # - value: [batch_size * self.num_heads, kv_length, head_dim]
-            key_layer = torch.cat((past_key, key_layer), dim=1)
-            value_layer = torch.cat((past_value, value_layer), dim=1)
+            # - key: [batch_size, self.num_heads, kv_length, head_dim]
+            # - value: [batch_size, self.num_heads, kv_length, head_dim]
+            key_layer = torch.cat((past_key, key_layer), dim=-2)
+            value_layer = torch.cat((past_value, value_layer), dim=-2)

-        _, kv_length, _ = key_layer.shape
+        kv_length = key_layer.shape[-2]
         if use_cache:
             present = (key_layer, value_layer)
         else:
             present = None

-        query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
-        key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
-        value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
-
         if alibi is None:
             if hasattr(F, "scaled_dot_product_attention") and not output_attentions:
                 # TODO: deprecate this once we add FA2 support in Falcon
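The practical consequence of this hunk is the second bullet of the commit message: each cached (key, value) pair now stays in the standard [batch_size, num_kv_heads, kv_length, head_dim] layout instead of the fused RW layout [batch_size * num_heads, kv_length, head_dim]. A hedged sketch with a tiny randomly initialized config (illustrative shapes only, not part of the diff):

import torch
from transformers import FalconConfig, FalconForCausalLM

config = FalconConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2, num_attention_heads=4)
model = FalconForCausalLM(config).eval()

out = model(torch.randint(0, 100, (2, 5)), use_cache=True)
key, value = out.past_key_values[0]
print(key.shape)  # 4-D standard layout, e.g. torch.Size([2, 1, 5, 8]) for the default multi-query setup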
@@ -467,15 +447,15 @@ class FalconAttention(nn.Module):
                 )

                 attn_output = F.scaled_dot_product_attention(
-                    query_layer_, key_layer_, value_layer_, attention_mask, 0.0, is_causal=False
+                    query_layer, key_layer, value_layer, attention_mask, 0.0, is_causal=False
                 )
                 attention_scores = None
             else:
-                attention_scores = query_layer_ @ key_layer_.transpose(-1, -2)
+                attention_scores = query_layer @ key_layer.transpose(-1, -2)
                 attention_scores /= math.sqrt(self.head_dim)

                 attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype)
-                attn_output = attention_scores @ value_layer_
+                attn_output = attention_scores @ value_layer

             attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
             attn_output = attn_output.permute(0, 2, 1, 3)
@@ -489,7 +469,7 @@ class FalconAttention(nn.Module):
                 return output_tensor, present

         else:
-            matmul_result = query_layer_ @ key_layer_.transpose(-1, -2)
+            matmul_result = query_layer @ key_layer.transpose(-1, -2)

             # change view to [batch_size, num_heads, q_length, kv_length]
             attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
@@ -516,7 +496,7 @@ class FalconAttention(nn.Module):
             attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)

             # matmul: [batch_size * num_heads, q_length, head_dim]
-            context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1)
+            context_layer = (attention_probs_reshaped @ value_layer).flatten(0, 1)

             # change view [batch_size, q_length, num_heads * head_dim]
             context_layer = self._merge_heads(context_layer)
@@ -563,37 +543,27 @@ class FalconFlashAttention2(FalconAttention):

         batch_size, query_length, _, _ = query_layer.shape

-        query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
-        key_layer = key_layer.transpose(1, 2).reshape(
-            batch_size * num_kv_heads,
-            query_length,
-            self.head_dim,
-        )
-        value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
+        query_layer = query_layer.transpose(1, 2).reshape(batch_size, self.num_heads, query_length, self.head_dim)
+        key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
+        value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)

-        past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
-        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length, position_ids)
+        kv_seq_len = key_layer.shape[-2]
+        if layer_past is not None:
+            kv_seq_len += layer_past[0].shape[-2]
+        if alibi is None:
+            cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
+            query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)

         if layer_past is not None and use_cache:
             past_key, past_value = layer_past
             # concatenate along seq_length dimension:
-            # - key: [batch_size * self.num_heads, kv_length, head_dim]
-            # - value: [batch_size * self.num_heads, kv_length, head_dim]
-            key_layer = torch.cat((past_key, key_layer), dim=1)
-            value_layer = torch.cat((past_value, value_layer), dim=1)
-
-        _, kv_seq_length, _ = key_layer.shape
-
-        torch_dtype = query_layer.dtype
+            # - key: [batch_size, self.num_heads, kv_length, head_dim]
+            # - value: [batch_size, self.num_heads, kv_length, head_dim]
+            key_layer = torch.cat((past_key, key_layer), dim=-2)
+            value_layer = torch.cat((past_value, value_layer), dim=-2)

         past_key_value = (key_layer, value_layer) if use_cache else None

-        query_layer = (
-            query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim).transpose(1, 2).to(torch_dtype)
-        )
-        key_layer = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).to(torch_dtype)
-        value_layer = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).to(torch_dtype)
-
         if alibi is not None:
             raise ValueError("`alibi` is not supported when `use_flash_attn` is True")

@@ -940,42 +910,6 @@ class FalconPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

-    @staticmethod
-    def _convert_cache_to_standard_format(
-        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
-    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
-        """
-        Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
-        num_heads, ...]))
-        """
-        batch_size_times_num_heads, kv_length, head_dim = past_key_value[0][0].shape
-        # [batch_size * self.num_heads, kv_length, head_dim] -> [batch_size, num_heads, kv_length, head_dim]
-        # Note that don't want to use self.num_attention_heads because the number of heads may vary depending
-        # on whether we use multi_query attention.
-        num_heads = batch_size_times_num_heads // batch_size
-        return tuple(
-            (
-                layer_past[0].view(batch_size, num_heads, kv_length, head_dim),
-                layer_past[1].view(batch_size, num_heads, kv_length, head_dim),
-            )
-            for layer_past in past_key_value
-        )
-
-    @staticmethod
-    def _convert_to_rw_cache(
-        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
-    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
-        batch_size, num_heads, kv_length, head_dim = past_key_value[0][0].shape
-        batch_size_times_num_heads = batch_size * num_heads
-        # [batch_size, num_heads, kv_length, head_dim] -> [batch_size * num_heads, kv_length, head_dim]
-        return tuple(
-            (
-                layer_past[0].view(batch_size_times_num_heads, kv_length, head_dim),
-                layer_past[1].view(batch_size_times_num_heads, kv_length, head_dim),
-            )
-            for layer_past in past_key_value
-        )
-

 @add_start_docstrings(
     "The bare Falcon Model transformer outputting raw hidden-states without any specific head on top.",
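With the cache produced directly in the standard layout, these helpers have no callers left, which is why the commit can delete them outright. A short sketch of what that means downstream (illustrative only, hypothetical tiny config): the cache returned by one forward pass can be fed straight back in, with no RW-format round-trip.

import torch
from transformers import FalconConfig, FalconForCausalLM

config = FalconConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2, num_attention_heads=4)
model = FalconForCausalLM(config).eval()

input_ids = torch.randint(0, 100, (1, 4))
out = model(input_ids, use_cache=True)

# Feed the cache back for the next step without any conversion helper.
next_token = torch.randint(0, 100, (1, 1))
out2 = model(next_token, past_key_values=out.past_key_values, use_cache=True)
print(out2.past_key_values[0][0].shape[-2])  # kv_length grew from 4 to 5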
@@ -1046,8 +980,6 @@ class FalconModel(FalconPreTrainedModel):

         if past_key_values is None:
             past_key_values = tuple([None] * len(self.h))
-        else:
-            past_key_values = self._convert_to_rw_cache(past_key_values)

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
@@ -1073,7 +1005,7 @@ class FalconModel(FalconPreTrainedModel):
         # Compute alibi tensor: check build_alibi_tensor documentation
         past_key_values_length = 0
         if past_key_values[0] is not None:
-            past_key_values_length = past_key_values[0][0].shape[1]  # 1 because RW-cache, not standard format
+            past_key_values_length = past_key_values[0][0].shape[-2]

         if self.use_alibi:
             mask = (
@@ -1143,9 +1075,6 @@ class FalconModel(FalconPreTrainedModel):
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)

-        if presents is not None:
-            presents = self._convert_cache_to_standard_format(presents, batch_size)
-
         if not return_dict:
             return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

@@ -340,24 +340,6 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

-    def test_cache_conversions(self):
-        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        input_ids = input_dict["input_ids"]
-        model = FalconForCausalLM(config)
-        model.to(torch_device)
-        model.eval()
-        result = model(input_ids, use_cache=True)
-        batch_size = input_ids.shape[0]
-        rw_cache = model._convert_to_rw_cache(result.past_key_values)
-        standard_cache = model._convert_cache_to_standard_format(rw_cache, batch_size)
-        for layer in range(len(rw_cache)):
-            for tensor_idx in range(2):
-                self.assertTrue(rw_cache[layer][tensor_idx].ndim == 3)
-                self.assertTrue(result.past_key_values[layer][tensor_idx].ndim == 4)
-                self.assertTrue(
-                    torch.all(result.past_key_values[layer][tensor_idx] == standard_cache[layer][tensor_idx])
-                )
-
     def test_falcon_sequence_classification_model_for_multi_label(self):
         config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.num_labels = 3
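The deleted test only exercised the now-removed RW <-> standard round-trip, so nothing replaces it in the commit. If a guard were still wanted, a hypothetical check in the style of the surrounding tests could assert the standard cache shape directly (editorial sketch, not part of the diff):

    def test_cache_is_standard_format(self):
        # Hypothetical replacement check; names follow the removed test above.
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        model = FalconForCausalLM(config)
        model.to(torch_device)
        model.eval()
        result = model(input_dict["input_ids"], use_cache=True)
        for key, value in result.past_key_values:
            self.assertEqual(key.ndim, 4)  # [batch_size, num_kv_heads, kv_length, head_dim]
            self.assertEqual(value.ndim, 4)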