mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Falcon: batched generation (#26137)
This commit is contained in:
parent
95a904104e
commit
a796f7eea6
@ -67,6 +67,7 @@ def rotate_half(x):
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# TODO (joao): Is this the same implementation as in Llama? If so, let's make them the same and add the copy facilities
|
||||
class FalconRotaryEmbedding(nn.Module):
|
||||
"""Implementation of RotaryEmbedding from GPT-NeoX.
|
||||
This implementation is designed to operate on queries and keys that are compatible with `[batch_size,
|
||||
@ -99,19 +100,40 @@ class FalconRotaryEmbedding(nn.Module):
|
||||
self.cos_cached = self.cos_cached.type(dtype)
|
||||
self.sin_cached = self.sin_cached.type(dtype)
|
||||
|
||||
def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
|
||||
def cos_sin(
|
||||
self, seq_len: int, past_key_values_length: int, position_ids: torch.Tensor, device="cpu", dtype=torch.bfloat16
|
||||
) -> torch.Tensor:
|
||||
total_length = seq_len + past_key_values_length
|
||||
if total_length > self.seq_len_cached:
|
||||
self._set_cos_sin_cache(total_length, device, dtype)
|
||||
return (
|
||||
self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length],
|
||||
self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length],
|
||||
)
|
||||
# Gather cos, sin at the designated position ids
|
||||
cos = self.cos_cached.squeeze(0)[position_ids] # [bs, seq_len, dim]
|
||||
sin = self.sin_cached.squeeze(0)[position_ids] # [bs, seq_len, dim]
|
||||
return cos, sin
|
||||
|
||||
def forward(self, query, key, past_key_values_length=0):
|
||||
batch, seq_len, head_dim = query.shape
|
||||
cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype)
|
||||
return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
|
||||
def forward(self, query, key, past_key_values_length, position_ids):
|
||||
_, seq_len, _ = query.shape
|
||||
cos, sin = self.cos_sin(seq_len, past_key_values_length, position_ids, query.device, query.dtype)
|
||||
# Query and key's shapes are [bs * num_heads, seq_len, dim], might need manual expansion. Ifs and elses used to
|
||||
# avoid unnecessary repeat_interleave operations.
|
||||
query_expansion_factor = int(query.shape[0] / cos.shape[0])
|
||||
if query_expansion_factor > 1:
|
||||
query_cos = torch.repeat_interleave(cos, query_expansion_factor, dim=0)
|
||||
query_sin = torch.repeat_interleave(sin, query_expansion_factor, dim=0)
|
||||
else:
|
||||
query_cos, query_sin = cos, sin
|
||||
|
||||
key_expansion_factor = int(key.shape[0] / cos.shape[0])
|
||||
if key_expansion_factor > 1:
|
||||
if key_expansion_factor != query_expansion_factor:
|
||||
key_cos = torch.repeat_interleave(cos, key_expansion_factor, dim=0)
|
||||
key_sin = torch.repeat_interleave(sin, key_expansion_factor, dim=0)
|
||||
else:
|
||||
key_cos, key_sin = query_cos, query_sin
|
||||
else:
|
||||
key_cos, key_sin = cos, sin
|
||||
|
||||
return (query * query_cos) + (rotate_half(query) * query_sin), (key * key_cos) + (rotate_half(key) * key_sin)
|
||||
|
||||
|
||||
class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
|
||||
@ -270,7 +292,7 @@ class FalconAttention(nn.Module):
|
||||
f" {self.num_heads})."
|
||||
)
|
||||
|
||||
self.maybe_rotary = self._init_rope() if config.rotary else lambda q, k, t: (q, k)
|
||||
self.maybe_rotary = self._init_rope() if config.rotary else lambda q, k, t, p: (q, k)
|
||||
|
||||
# Layer-wise attention scaling
|
||||
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
|
||||
@ -378,6 +400,7 @@ class FalconAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
alibi: Optional[torch.Tensor],
|
||||
attention_mask: torch.Tensor,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
@ -399,7 +422,7 @@ class FalconAttention(nn.Module):
|
||||
value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
|
||||
|
||||
past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
|
||||
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
|
||||
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length, position_ids)
|
||||
|
||||
if layer_past is not None:
|
||||
past_key, past_value = layer_past
|
||||
@ -415,7 +438,8 @@ class FalconAttention(nn.Module):
|
||||
else:
|
||||
present = None
|
||||
|
||||
attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
|
||||
float_min = torch.finfo(query_layer.dtype).min
|
||||
attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float_min).to(query_layer.dtype)
|
||||
|
||||
query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
||||
key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
|
||||
@ -536,6 +560,7 @@ class FalconDecoderLayer(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
alibi: Optional[torch.Tensor],
|
||||
attention_mask: torch.Tensor,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
@ -554,6 +579,7 @@ class FalconDecoderLayer(nn.Module):
|
||||
attention_layernorm_out,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
alibi=alibi,
|
||||
head_mask=head_mask,
|
||||
use_cache=use_cache,
|
||||
@ -632,6 +658,11 @@ FALCON_INPUTS_DOCSTRING = r"""
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||
config.n_positions - 1]`.
|
||||
|
||||
[What are position IDs?](../glossary#position-ids)
|
||||
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
||||
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
||||
|
||||
@ -836,6 +867,7 @@ class FalconModel(FalconPreTrainedModel):
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
@ -892,6 +924,14 @@ class FalconModel(FalconPreTrainedModel):
|
||||
alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
|
||||
else:
|
||||
alibi = None
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
causal_mask = self._prepare_attn_mask(
|
||||
attention_mask,
|
||||
@ -922,6 +962,7 @@ class FalconModel(FalconPreTrainedModel):
|
||||
hidden_states,
|
||||
alibi,
|
||||
causal_mask,
|
||||
position_ids,
|
||||
head_mask[i],
|
||||
)
|
||||
else:
|
||||
@ -929,6 +970,7 @@ class FalconModel(FalconPreTrainedModel):
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask[i],
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
@ -988,13 +1030,23 @@ class FalconForCausalLM(FalconPreTrainedModel):
|
||||
input_ids: torch.LongTensor,
|
||||
past_key_values: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
if past_key_values is not None:
|
||||
input_ids = input_ids[:, -1:]
|
||||
|
||||
# Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE.
|
||||
if not self.transformer.use_alibi and attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"position_ids": position_ids,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"attention_mask": attention_mask,
|
||||
@ -1011,6 +1063,7 @@ class FalconForCausalLM(FalconPreTrainedModel):
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
@ -1032,6 +1085,7 @@ class FalconForCausalLM(FalconPreTrainedModel):
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
|
@ -19,8 +19,16 @@ import unittest
|
||||
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoConfig, AutoModel, AutoTokenizer, FalconConfig, is_torch_available, set_seed
|
||||
from transformers.testing_utils import CaptureLogger, require_torch, slow, tooslow, torch_device
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
FalconConfig,
|
||||
is_torch_available,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.testing_utils import CaptureLogger, require_bitsandbytes, require_torch, slow, tooslow, torch_device
|
||||
from transformers.utils import logging as transformers_logging
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
@ -502,6 +510,35 @@ class FalconLanguageGenerationTest(unittest.TestCase):
|
||||
outputs_cache = model.generate(**inputs, do_sample=False, max_new_tokens=20, use_cache=True)
|
||||
self.assertTrue((outputs_cache - outputs_no_cache).sum().item() == 0)
|
||||
|
||||
@require_bitsandbytes
|
||||
@slow
|
||||
def test_batched_generation(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b", padding_side="left")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"tiiuae/falcon-7b",
|
||||
device_map="auto",
|
||||
load_in_4bit=True,
|
||||
)
|
||||
|
||||
test_text = "A sequence: 1, 2" # should generate the rest of the sequence
|
||||
|
||||
unpadded_inputs = tokenizer([test_text], return_tensors="pt").to("cuda:0")
|
||||
unpadded_inputs.pop("token_type_ids")
|
||||
unpadded_gen_out = model.generate(**unpadded_inputs, max_new_tokens=20)
|
||||
unpadded_gen_text = tokenizer.batch_decode(unpadded_gen_out, skip_special_tokens=True)
|
||||
|
||||
dummy_text = "This is a longer text " * 2 # forces left-padding on `test_text`
|
||||
padded_inputs = tokenizer([test_text, dummy_text], return_tensors="pt", padding=True).to("cuda:0")
|
||||
padded_inputs.pop("token_type_ids")
|
||||
padded_gen_out = model.generate(**padded_inputs, max_new_tokens=20)
|
||||
padded_gen_text = tokenizer.batch_decode(padded_gen_out, skip_special_tokens=True)
|
||||
|
||||
expected_output = "A sequence: 1, 2, 3, 4, 5, 6, 7, 8, "
|
||||
self.assertLess(unpadded_inputs.input_ids.shape[-1], padded_inputs.input_ids.shape[-1]) # left-padding exists
|
||||
self.assertEqual(unpadded_gen_text[0], expected_output)
|
||||
self.assertEqual(padded_gen_text[0], expected_output)
|
||||
|
||||
|
||||
# TODO Lysandre: Remove this in version v4.34
|
||||
class FalconOverrideTest(unittest.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user