Refactor: Use Llama RoPE implementation for Falcon (#26933)

* Use Llama RoPE implementation for Falcon

+ Add copy functionalities

* Use standard cache format for Falcon

* Simplify apply_rotary_pos_emb, copy from Llama

* Remove unnecessary cache conversion test

We don't need to convert any caches anymore!

* Resolve copy complaint
Tom Aarsen authored on 2023-11-03 12:05:55 +01:00, committed by GitHub
parent e9a6c72b5e
commit 05ea7b79e6
2 changed files with 128 additions and 217 deletions

modeling_falcon.py

@@ -71,12 +71,43 @@ class FalconLinear(nn.Linear):
return hidden_states + self.bias
# rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
used to pass offsetted position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
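# --- Illustrative usage sketch, not part of this diff; toy sizes below are assumptions ---
# The cos/sin caches have shape [max_seq_len, head_dim]; indexing with position_ids and
# unsqueezing on dim 1 makes them broadcast against [batch, num_heads, seq_len, head_dim]
# queries/keys, exactly as the docstring above describes.
import torch
batch_size, num_heads, seq_len, head_dim = 2, 4, 6, 8
cos_cached = torch.randn(32, head_dim)                                    # stands in for the real cache
position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)  # [batch, seq_len]
cos = cos_cached[position_ids].unsqueeze(1)                               # [batch, 1, seq_len, head_dim]
q = torch.randn(batch_size, num_heads, seq_len, head_dim)
print((q * cos).shape)                                                    # torch.Size([2, 4, 6, 8])
# rotate_half pairs dimension i with dimension i + head_dim // 2:
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(torch.cat((-x[x.shape[-1] // 2 :], x[: x.shape[-1] // 2])))         # tensor([-3., -4.,  1.,  2.])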
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
@@ -90,138 +121,88 @@ def _get_unpad_data(attention_mask):
)
# TODO (joao): Is this the same implementation as in Llama? If so, let's make them the same and add the copy facilities
# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Falcon
class FalconRotaryEmbedding(nn.Module):
"""Implementation of RotaryEmbedding from GPT-NeoX.
This implementation is designed to operate on queries and keys that are compatible with `[batch_size,
n_heads_per_partition, seq_len, head_dim]` (e.g. MinGPTAttention format).
"""
def __init__(self, head_dim: int, base=10000, max_position_embeddings=2048):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
self.base = base
self.dim = dim
self.max_position_embeddings = max_position_embeddings
inv_freq = 1.0 / (self.base ** (torch.arange(0, head_dim, 2).float() / head_dim))
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.head_dim = head_dim
self.seq_len_cached = -1
self.cos_cached: torch.Tensor | None = None
self.sin_cached: torch.Tensor | None = None
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.seq_len_cached = seq_len
t = torch.arange(seq_len, device=device).to(dtype)
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(device)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
if dtype in [torch.float16, torch.bfloat16]:
emb = emb.float()
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
self.cos_cached = emb.cos()
self.sin_cached = emb.sin()
self.cos_cached = self.cos_cached.type(dtype)
self.sin_cached = self.sin_cached.type(dtype)
def cos_sin(
self, seq_len: int, past_key_values_length: int, position_ids: torch.Tensor, device="cpu", dtype=torch.bfloat16
) -> torch.Tensor:
total_length = seq_len + past_key_values_length
if total_length > self.seq_len_cached:
self._set_cos_sin_cache(total_length, device, dtype)
# the cached tensors need to update their devices (for example, after we change the model's device)
self.cos_cached = self.cos_cached.to(device)
self.sin_cached = self.sin_cached.to(device)
# Gather cos, sin at the designated position ids
cos = self.cos_cached[position_ids] # [bs, seq_len, dim]
sin = self.sin_cached[position_ids] # [bs, seq_len, dim]
return cos, sin
def forward(self, query, key, past_key_values_length, position_ids):
_, seq_len, _ = query.shape
cos, sin = self.cos_sin(seq_len, past_key_values_length, position_ids, query.device, query.dtype)
# Query and key's shapes are [bs * num_heads, seq_len, dim], might need manual expansion. Ifs and elses used to
# avoid unnecessary repeat_interleave operations.
query_expansion_factor = int(query.shape[0] / cos.shape[0])
if query_expansion_factor > 1:
query_cos = torch.repeat_interleave(cos, query_expansion_factor, dim=0)
query_sin = torch.repeat_interleave(sin, query_expansion_factor, dim=0)
else:
query_cos, query_sin = cos, sin
key_expansion_factor = int(key.shape[0] / cos.shape[0])
if key_expansion_factor > 1:
if key_expansion_factor != query_expansion_factor:
key_cos = torch.repeat_interleave(cos, key_expansion_factor, dim=0)
key_sin = torch.repeat_interleave(sin, key_expansion_factor, dim=0)
else:
key_cos, key_sin = query_cos, query_sin
else:
key_cos, key_sin = cos, sin
return (query * query_cos) + (rotate_half(query) * query_sin), (key * key_cos) + (rotate_half(key) * key_sin)
return (
self.cos_cached[:seq_len].to(dtype=x.dtype),
self.sin_cached[:seq_len].to(dtype=x.dtype),
)
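# --- Illustrative sketch, not part of this diff; toy values below are assumptions ---
# How the cos/sin cache above is built: each position t gets dim // 2 frequencies, which
# are duplicated with cat((freqs, freqs), -1) so that rotate_half can pair dimension i
# with dimension i + dim // 2 (the "different permutation" noted in the comment above).
import torch
dim, seq_len, base = 8, 6, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))  # [dim // 2]
t = torch.arange(seq_len, dtype=inv_freq.dtype)                     # [seq_len]
freqs = torch.einsum("i,j->ij", t, inv_freq)                        # [seq_len, dim // 2]
emb = torch.cat((freqs, freqs), dim=-1)                             # [seq_len, dim]
cos_cached, sin_cached = emb.cos(), emb.sin()
print(cos_cached.shape)                                             # torch.Size([6, 8])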
# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Falcon
class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
"""FalconRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
def __init__(self, head_dim: int, base=10000, max_position_embeddings=2048, scaling_factor=1.0):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(head_dim, base, max_position_embeddings)
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.seq_len_cached = seq_len
t = torch.arange(seq_len, device=device).to(dtype)
# This line is the only difference from FalconRotaryEmbedding._set_cos_sin_cache
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = t / self.scaling_factor
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(device)
if dtype in [torch.float16, torch.bfloat16]:
emb = emb.float()
self.cos_cached = emb.cos()
self.sin_cached = emb.sin()
self.cos_cached = self.cos_cached.type(dtype)
self.sin_cached = self.sin_cached.type(dtype)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Falcon
class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding):
"""
FalconRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
"""
"""FalconRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
def __init__(self, head_dim: int, base=10000, max_position_embeddings=2048, scaling_factor=1.0):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(head_dim, base, max_position_embeddings)
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.seq_len_cached = seq_len
self.max_seq_len_cached = seq_len
# This if block is the only difference from FalconRotaryEmbedding._set_cos_sin_cache
if seq_len > self.max_position_embeddings:
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.head_dim / (self.head_dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.head_dim, 2).float().to(device) / self.head_dim))
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(seq_len, device=device).to(dtype)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(device)
if dtype in [torch.float16, torch.bfloat16]:
emb = emb.float()
self.cos_cached = emb.cos()
self.sin_cached = emb.sin()
self.cos_cached = self.cos_cached.type(dtype)
self.sin_cached = self.sin_cached.type(dtype)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def _prepare_4d_attention_mask(mask: torch.Tensor, past_key_values_length: int) -> torch.BoolTensor:
@@ -293,6 +274,8 @@ class FalconAttention(nn.Module):
self.head_dim = self.hidden_size // self.num_heads
self.split_size = self.hidden_size
self.hidden_dropout = config.hidden_dropout
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
if self.head_dim * self.num_heads != self.hidden_size:
@@ -301,7 +284,8 @@ class FalconAttention(nn.Module):
f" {self.num_heads})."
)
self.maybe_rotary = self._init_rope() if config.rotary else lambda q, k, t, p: (q, k)
if config.rotary:
self._init_rope()
# Layer-wise attention scaling
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
@@ -319,33 +303,33 @@ class FalconAttention(nn.Module):
self.attention_dropout = nn.Dropout(config.attention_dropout)
self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
# Copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Falcon
def _init_rope(self):
if self.config.rope_scaling is None:
rotary_emb = FalconRotaryEmbedding(
self.rotary_emb = FalconRotaryEmbedding(
self.head_dim,
base=self.config.rope_theta,
max_position_embeddings=self.config.max_position_embeddings,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
rotary_emb = FalconLinearScalingRotaryEmbedding(
self.rotary_emb = FalconLinearScalingRotaryEmbedding(
self.head_dim,
base=self.config.rope_theta,
max_position_embeddings=self.config.max_position_embeddings,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
elif scaling_type == "dynamic":
rotary_emb = FalconDynamicNTKScalingRotaryEmbedding(
self.rotary_emb = FalconDynamicNTKScalingRotaryEmbedding(
self.head_dim,
base=self.config.rope_theta,
max_position_embeddings=self.config.max_position_embeddings,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
return rotary_emb
def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
@@ -428,35 +412,31 @@ class FalconAttention(nn.Module):
batch_size, query_length, _, _ = query_layer.shape
query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
key_layer = key_layer.transpose(1, 2).reshape(
batch_size * num_kv_heads,
query_length,
self.head_dim,
)
value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
query_layer = query_layer.transpose(1, 2).reshape(batch_size, self.num_heads, query_length, self.head_dim)
key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length, position_ids)
kv_seq_len = key_layer.shape[-2]
if layer_past is not None:
kv_seq_len += layer_past[0].shape[-2]
if alibi is None:
cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)
if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size * self.num_heads, kv_length, head_dim]
# - value: [batch_size * self.num_heads, kv_length, head_dim]
key_layer = torch.cat((past_key, key_layer), dim=1)
value_layer = torch.cat((past_value, value_layer), dim=1)
# - key: [batch_size, self.num_heads, kv_length, head_dim]
# - value: [batch_size, self.num_heads, kv_length, head_dim]
key_layer = torch.cat((past_key, key_layer), dim=-2)
value_layer = torch.cat((past_value, value_layer), dim=-2)
_, kv_length, _ = key_layer.shape
kv_length = key_layer.shape[-2]
if use_cache:
present = (key_layer, value_layer)
else:
present = None
query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
if alibi is None:
if hasattr(F, "scaled_dot_product_attention") and not output_attentions:
# TODO: deprecate this once we add FA2 support in Falcon
@@ -467,15 +447,15 @@ class FalconAttention(nn.Module):
)
attn_output = F.scaled_dot_product_attention(
query_layer_, key_layer_, value_layer_, attention_mask, 0.0, is_causal=False
query_layer, key_layer, value_layer, attention_mask, 0.0, is_causal=False
)
attention_scores = None
else:
attention_scores = query_layer_ @ key_layer_.transpose(-1, -2)
attention_scores = query_layer @ key_layer.transpose(-1, -2)
attention_scores /= math.sqrt(self.head_dim)
attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype)
attn_output = attention_scores @ value_layer_
attn_output = attention_scores @ value_layer
attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
attn_output = attn_output.permute(0, 2, 1, 3)
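# --- Illustrative sketch, not part of this diff; toy shapes below are assumptions ---
# With the standard 4-D layout [batch, num_heads, seq_len, head_dim], the query/key/value
# tensors can be fed to F.scaled_dot_product_attention directly, which is why the old
# batch_size * num_heads flattening (the *_layer_ variables) is no longer needed.
import torch
import torch.nn.functional as F
batch_size, num_heads, q_len, head_dim = 2, 4, 5, 8
q = torch.randn(batch_size, num_heads, q_len, head_dim)
k = torch.randn(batch_size, num_heads, q_len, head_dim)
v = torch.randn(batch_size, num_heads, q_len, head_dim)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
print(out.shape)  # torch.Size([2, 4, 5, 8])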
@@ -489,7 +469,7 @@ class FalconAttention(nn.Module):
return output_tensor, present
else:
matmul_result = query_layer_ @ key_layer_.transpose(-1, -2)
matmul_result = query_layer @ key_layer.transpose(-1, -2)
# change view to [batch_size, num_heads, q_length, kv_length]
attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
@@ -516,7 +496,7 @@ class FalconAttention(nn.Module):
attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)
# matmul: [batch_size * num_heads, q_length, head_dim]
context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1)
context_layer = (attention_probs_reshaped @ value_layer).flatten(0, 1)
# change view [batch_size, q_length, num_heads * head_dim]
context_layer = self._merge_heads(context_layer)
@@ -563,37 +543,27 @@ class FalconFlashAttention2(FalconAttention):
batch_size, query_length, _, _ = query_layer.shape
query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
key_layer = key_layer.transpose(1, 2).reshape(
batch_size * num_kv_heads,
query_length,
self.head_dim,
)
value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
query_layer = query_layer.transpose(1, 2).reshape(batch_size, self.num_heads, query_length, self.head_dim)
key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length, position_ids)
kv_seq_len = key_layer.shape[-2]
if layer_past is not None:
kv_seq_len += layer_past[0].shape[-2]
if alibi is None:
cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)
if layer_past is not None and use_cache:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size * self.num_heads, kv_length, head_dim]
# - value: [batch_size * self.num_heads, kv_length, head_dim]
key_layer = torch.cat((past_key, key_layer), dim=1)
value_layer = torch.cat((past_value, value_layer), dim=1)
_, kv_seq_length, _ = key_layer.shape
torch_dtype = query_layer.dtype
# - key: [batch_size, self.num_heads, kv_length, head_dim]
# - value: [batch_size, self.num_heads, kv_length, head_dim]
key_layer = torch.cat((past_key, key_layer), dim=-2)
value_layer = torch.cat((past_value, value_layer), dim=-2)
past_key_value = (key_layer, value_layer) if use_cache else None
query_layer = (
query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim).transpose(1, 2).to(torch_dtype)
)
key_layer = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).to(torch_dtype)
value_layer = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim).transpose(1, 2).to(torch_dtype)
if alibi is not None:
raise ValueError("`alibi` is not supported when `use_flash_attn` is True")
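# --- Illustrative note, not part of this diff; toy shapes below are assumptions ---
# Flash-attention kernels take [batch, seq_len, num_heads, head_dim], hence the
# transpose(1, 2) above from the standard [batch, num_heads, seq_len, head_dim] layout.
import torch
q = torch.randn(2, 4, 6, 8)      # [batch, num_heads, seq_len, head_dim]
print(q.transpose(1, 2).shape)   # torch.Size([2, 6, 4, 8])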
@@ -940,42 +910,6 @@ class FalconPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
@staticmethod
def _convert_cache_to_standard_format(
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
"""
Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
num_heads, ...]))
"""
batch_size_times_num_heads, kv_length, head_dim = past_key_value[0][0].shape
# [batch_size * self.num_heads, kv_length, head_dim] -> [batch_size, num_heads, kv_length, head_dim]
# Note that we don't want to use self.num_attention_heads because the number of heads may vary depending
# on whether we use multi_query attention.
num_heads = batch_size_times_num_heads // batch_size
return tuple(
(
layer_past[0].view(batch_size, num_heads, kv_length, head_dim),
layer_past[1].view(batch_size, num_heads, kv_length, head_dim),
)
for layer_past in past_key_value
)
@staticmethod
def _convert_to_rw_cache(
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
batch_size, num_heads, kv_length, head_dim = past_key_value[0][0].shape
batch_size_times_num_heads = batch_size * num_heads
# [batch_size, num_heads, kv_length, head_dim] -> [batch_size * num_heads, kv_length, head_dim]
return tuple(
(
layer_past[0].view(batch_size_times_num_heads, kv_length, head_dim),
layer_past[1].view(batch_size_times_num_heads, kv_length, head_dim),
)
for layer_past in past_key_value
)
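# --- Illustrative sketch, not part of this diff; toy sizes below are assumptions ---
# Why the two conversion helpers above can be deleted: the old RW cache layout
# [batch_size * num_heads, kv_length, head_dim] and the standard layout
# [batch_size, num_heads, kv_length, head_dim] hold identical data and differ only by a
# view, and Falcon now caches in the standard layout from the start.
import torch
batch_size, num_heads, kv_length, head_dim = 2, 4, 5, 8
rw_key = torch.randn(batch_size * num_heads, kv_length, head_dim)
std_key = rw_key.view(batch_size, num_heads, kv_length, head_dim)
assert torch.equal(std_key.reshape(batch_size * num_heads, kv_length, head_dim), rw_key)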
@add_start_docstrings(
"The bare Falcon Model transformer outputting raw hidden-states without any specific head on top.",
@@ -1046,8 +980,6 @@ class FalconModel(FalconPreTrainedModel):
if past_key_values is None:
past_key_values = tuple([None] * len(self.h))
else:
past_key_values = self._convert_to_rw_cache(past_key_values)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
@@ -1073,7 +1005,7 @@ class FalconModel(FalconPreTrainedModel):
# Compute alibi tensor: check build_alibi_tensor documentation
past_key_values_length = 0
if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[1] # 1 because RW-cache, not standard format
past_key_values_length = past_key_values[0][0].shape[-2]
if self.use_alibi:
mask = (
@@ -1143,9 +1075,6 @@ class FalconModel(FalconPreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if presents is not None:
presents = self._convert_cache_to_standard_format(presents, batch_size)
if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

test_modeling_falcon.py

@@ -340,24 +340,6 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
def test_cache_conversions(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = input_dict["input_ids"]
model = FalconForCausalLM(config)
model.to(torch_device)
model.eval()
result = model(input_ids, use_cache=True)
batch_size = input_ids.shape[0]
rw_cache = model._convert_to_rw_cache(result.past_key_values)
standard_cache = model._convert_cache_to_standard_format(rw_cache, batch_size)
for layer in range(len(rw_cache)):
for tensor_idx in range(2):
self.assertTrue(rw_cache[layer][tensor_idx].ndim == 3)
self.assertTrue(result.past_key_values[layer][tensor_idx].ndim == 4)
self.assertTrue(
torch.all(result.past_key_values[layer][tensor_idx] == standard_cache[layer][tensor_idx])
)
def test_falcon_sequence_classification_model_for_multi_label(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3