mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
🚨 Bloom support for cache class (#31445)
* bloom dynamic cache * bloom follows standard cache format * no skips for bloom anymore * use cache position when possible * clean up * codestyle * Update src/transformers/models/bloom/modeling_bloom.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/bloom/modeling_bloom.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/bloom/modeling_bloom.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * pr comments * isinstance fix * address comments * make musicgen test happy * [run-slow] bloom --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
44f6fdd74f
commit
f739687684
@ -378,19 +378,7 @@ def _crop_past_key_values(model, past_key_values, max_length):
|
||||
)
|
||||
)
|
||||
past_key_values = tuple(new_past)
|
||||
# bloom is special
|
||||
elif "bloom" in model.__class__.__name__.lower() or (
|
||||
model.config.architectures is not None and "bloom" in model.config.architectures[0].lower()
|
||||
):
|
||||
for idx in range(len(past_key_values)):
|
||||
new_past.append(
|
||||
(
|
||||
past_key_values[idx][0][:, :, :max_length],
|
||||
past_key_values[idx][1][:, :max_length, :],
|
||||
)
|
||||
)
|
||||
past_key_values = tuple(new_past)
|
||||
# gptbigcode is too
|
||||
# gptbigcode is special and stores kv in shape (batch_size, seq_len, dim), if it's a multi_query model
|
||||
elif "gptbigcode" in model.__class__.__name__.lower() or (
|
||||
model.config.architectures is not None and "gptbigcode" in model.config.architectures[0].lower()
|
||||
):
|
||||
@ -402,7 +390,6 @@ def _crop_past_key_values(model, past_key_values, max_length):
|
||||
past_key_values[idx] = past_key_values[idx][:, :, :max_length, :]
|
||||
elif isinstance(past_key_values, DynamicCache):
|
||||
past_key_values.crop(max_length)
|
||||
|
||||
elif past_key_values is not None:
|
||||
for idx in range(len(past_key_values)):
|
||||
new_past.append(
|
||||
|
@ -639,7 +639,7 @@ class GenerationMixin:
|
||||
|
||||
return input_ids, model_kwargs
|
||||
|
||||
def _extract_past_from_model_output(self, outputs: ModelOutput, standardize_cache_format: bool = False):
|
||||
def _extract_past_from_model_output(self, outputs: ModelOutput):
|
||||
past_key_values = None
|
||||
cache_name = "past_key_values"
|
||||
if "past_key_values" in outputs:
|
||||
@ -652,10 +652,6 @@ class GenerationMixin:
|
||||
past_key_values = outputs.cache_params
|
||||
cache_name = "cache_params"
|
||||
|
||||
# Bloom fix: standardizes the cache format when requested
|
||||
if standardize_cache_format and hasattr(self, "_convert_to_standard_cache"):
|
||||
batch_size = outputs.logits.shape[0]
|
||||
past_key_values = self._convert_to_standard_cache(past_key_values, batch_size=batch_size)
|
||||
return cache_name, past_key_values
|
||||
|
||||
def _update_model_kwargs_for_generation(
|
||||
@ -663,13 +659,10 @@ class GenerationMixin:
|
||||
outputs: ModelOutput,
|
||||
model_kwargs: Dict[str, Any],
|
||||
is_encoder_decoder: bool = False,
|
||||
standardize_cache_format: bool = False,
|
||||
num_new_tokens: int = 1,
|
||||
) -> Dict[str, Any]:
|
||||
# update past_key_values keeping its naming used in model code
|
||||
cache_name, cache = self._extract_past_from_model_output(
|
||||
outputs, standardize_cache_format=standardize_cache_format
|
||||
)
|
||||
cache_name, cache = self._extract_past_from_model_output(outputs)
|
||||
model_kwargs[cache_name] = cache
|
||||
if getattr(outputs, "state", None) is not None:
|
||||
model_kwargs["state"] = outputs.state
|
||||
@ -2558,7 +2551,6 @@ class GenerationMixin:
|
||||
outputs,
|
||||
model_kwargs,
|
||||
is_encoder_decoder=self.config.is_encoder_decoder,
|
||||
standardize_cache_format=True,
|
||||
)
|
||||
|
||||
if not sequential:
|
||||
@ -2723,7 +2715,7 @@ class GenerationMixin:
|
||||
next_past_key_values = selected_outputs["past_key_values"]
|
||||
|
||||
else:
|
||||
_, next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)
|
||||
_, next_past_key_values = self._extract_past_from_model_output(outputs)
|
||||
# Do it in-place layer per layer to save memory
|
||||
if isinstance(next_past_key_values, DynamicCache) or (
|
||||
isinstance(next_past_key_values, EncoderDecoderCache)
|
||||
@ -3033,7 +3025,7 @@ class GenerationMixin:
|
||||
past_key_values = self._reorder_cache(past_key_values, beam_idx)
|
||||
# Exception 2: models with different cache formats. These are limited to `DynamicCache` until their
|
||||
# cache format is standardized, to avoid adding complexity to the codebase.
|
||||
elif "bloom" in model_class or "gptbigcode" in model_class:
|
||||
elif "gptbigcode" in model_class:
|
||||
if not isinstance(past_key_values, (DynamicCache, EncoderDecoderCache)):
|
||||
raise ValueError(
|
||||
f"Using an unsupported cache format with {model_class}. Currently, it only supports the "
|
||||
@ -3161,7 +3153,6 @@ class GenerationMixin:
|
||||
for model_name in [
|
||||
"fsmt",
|
||||
"reformer",
|
||||
"bloom",
|
||||
"ctrl",
|
||||
"gpt_bigcode",
|
||||
"transo_xl",
|
||||
|
@ -24,8 +24,9 @@ from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
CausalLMOutputWithCrossAttentions,
|
||||
@ -170,7 +171,7 @@ class BloomGelu(nn.Module):
|
||||
|
||||
|
||||
class BloomAttention(nn.Module):
|
||||
def __init__(self, config: BloomConfig):
|
||||
def __init__(self, config: BloomConfig, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
|
||||
self.pretraining_tp = config.pretraining_tp
|
||||
@ -191,26 +192,37 @@ class BloomAttention(nn.Module):
|
||||
# Layer-wise attention scaling
|
||||
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
|
||||
self.beta = 1.0
|
||||
self.layer_idx = layer_idx
|
||||
if layer_idx is None:
|
||||
logger.warning_once(
|
||||
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
|
||||
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
|
||||
"when creating this class."
|
||||
)
|
||||
|
||||
self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
|
||||
self.dense = nn.Linear(self.hidden_size, self.hidden_size)
|
||||
self.attention_dropout = nn.Dropout(config.attention_dropout)
|
||||
|
||||
def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
def _reshape(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
|
||||
storage as `fused_qkv`
|
||||
Split the last dimension into (num_heads, head_dim) and reshapes to (bs, heads, len, dim) shape
|
||||
without making any copies, results share same memory storage as `fused_qkv`
|
||||
|
||||
Args:
|
||||
fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
|
||||
|
||||
Returns:
|
||||
query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
|
||||
value: [batch_size, seq_length, num_heads, head_dim]
|
||||
query: [batch_size, num_heads, seq_length, head_dim]
|
||||
key: [batch_size, num_heads, seq_length, head_dim]
|
||||
value: [batch_size, num_heads, seq_length, head_dim]
|
||||
"""
|
||||
batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
|
||||
fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
|
||||
return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
|
||||
query_layer = fused_qkv[..., 0, :].transpose(1, 2)
|
||||
key_layer = fused_qkv[..., 1, :].transpose(1, 2)
|
||||
value_layer = fused_qkv[..., 2, :].transpose(1, 2)
|
||||
return query_layer, key_layer, value_layer
|
||||
|
||||
def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
@ -243,35 +255,27 @@ class BloomAttention(nn.Module):
|
||||
residual: torch.Tensor,
|
||||
alibi: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
layer_past: Optional[Cache] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
batch_size, q_length, _ = hidden_states.shape
|
||||
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
|
||||
# 3 x [batch_size, num_heads, seq_length, head_dim]
|
||||
query_layer, key_layer, value_layer = self._reshape(fused_qkv)
|
||||
|
||||
# 3 x [batch_size, seq_length, num_heads, head_dim]
|
||||
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
|
||||
|
||||
batch_size, q_length, _, _ = query_layer.shape
|
||||
|
||||
query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
|
||||
key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length)
|
||||
value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
|
||||
if layer_past is not None:
|
||||
past_key, past_value = layer_past
|
||||
# concatenate along seq_length dimension:
|
||||
# - key: [batch_size * self.num_heads, head_dim, kv_length]
|
||||
# - value: [batch_size * self.num_heads, kv_length, head_dim]
|
||||
key_layer = torch.cat((past_key, key_layer), dim=2)
|
||||
value_layer = torch.cat((past_value, value_layer), dim=1)
|
||||
cache_kwargs = {"cache_position": cache_position}
|
||||
key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
|
||||
|
||||
_, _, kv_length = key_layer.shape
|
||||
# reshape qkv for further computations
|
||||
query_layer = query_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
|
||||
key_layer = key_layer.reshape(batch_size * self.num_heads, -1, self.head_dim).transpose(1, 2)
|
||||
value_layer = value_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
|
||||
|
||||
if use_cache is True:
|
||||
present = (key_layer, value_layer)
|
||||
else:
|
||||
present = None
|
||||
kv_length = cache_position[-1] + 1 # cache position is 0-indexed while length should start from 1
|
||||
|
||||
# [batch_size * num_heads, q_length, kv_length]
|
||||
# we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
|
||||
@ -283,15 +287,13 @@ class BloomAttention(nn.Module):
|
||||
)
|
||||
|
||||
# change view to [batch_size, num_heads, q_length, kv_length]
|
||||
attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
|
||||
attn_weights = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask[:, :, :, :kv_length]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
|
||||
input_dtype = attention_scores.dtype
|
||||
# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
|
||||
if input_dtype == torch.float16:
|
||||
attention_scores = attention_scores.to(torch.float)
|
||||
attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
|
||||
attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
|
||||
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype
|
||||
attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_layer.dtype)
|
||||
|
||||
# [batch_size, num_heads, q_length, kv_length]
|
||||
attention_probs = self.attention_dropout(attention_probs)
|
||||
@ -322,7 +324,7 @@ class BloomAttention(nn.Module):
|
||||
|
||||
output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
|
||||
|
||||
outputs = (output_tensor, present)
|
||||
outputs = (output_tensor, layer_past)
|
||||
if output_attentions:
|
||||
outputs += (attention_probs,)
|
||||
|
||||
@ -361,13 +363,13 @@ class BloomMLP(nn.Module):
|
||||
|
||||
|
||||
class BloomBlock(nn.Module):
|
||||
def __init__(self, config: BloomConfig):
|
||||
def __init__(self, config: BloomConfig, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
|
||||
self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.num_heads = config.n_head
|
||||
self.self_attention = BloomAttention(config)
|
||||
self.self_attention = BloomAttention(config, layer_idx)
|
||||
self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
|
||||
self.mlp = BloomMLP(config)
|
||||
@ -380,10 +382,11 @@ class BloomBlock(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
alibi: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
layer_past: Optional[Cache] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
# hidden_states: [batch_size, seq_length, hidden_size]
|
||||
|
||||
@ -406,6 +409,7 @@ class BloomBlock(nn.Module):
|
||||
head_mask=head_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
attention_output = attn_outputs[0]
|
||||
@ -428,7 +432,7 @@ class BloomBlock(nn.Module):
|
||||
else:
|
||||
outputs = (output,) + outputs[1:]
|
||||
|
||||
return outputs # hidden_states, present, attentions
|
||||
return outputs # hidden_states, past_kv, attentions
|
||||
|
||||
|
||||
class BloomPreTrainedModel(PreTrainedModel):
|
||||
@ -437,6 +441,7 @@ class BloomPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["BloomBlock"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_cache_class = True
|
||||
|
||||
def __init__(self, *inputs, **kwargs):
|
||||
super().__init__(*inputs, **kwargs)
|
||||
@ -457,45 +462,6 @@ class BloomPreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
@staticmethod
|
||||
def _convert_to_standard_cache(
|
||||
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
|
||||
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
|
||||
num_heads, ...]))
|
||||
"""
|
||||
batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
|
||||
num_heads = batch_size_times_num_heads // batch_size
|
||||
# key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
|
||||
# value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
|
||||
return tuple(
|
||||
(
|
||||
layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
|
||||
layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
|
||||
)
|
||||
for layer_past in past_key_value
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _convert_to_bloom_cache(
|
||||
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]],
|
||||
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
|
||||
"""
|
||||
batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
|
||||
batch_size_times_num_heads = batch_size * num_heads
|
||||
# key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
|
||||
# value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
|
||||
return tuple(
|
||||
(
|
||||
layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
|
||||
layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
|
||||
)
|
||||
for layer_past in past_key_value
|
||||
)
|
||||
|
||||
|
||||
BLOOM_START_DOCSTRING = r"""
|
||||
|
||||
@ -525,14 +491,23 @@ BLOOM_INPUTS_DOCSTRING = r"""
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
|
||||
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
|
||||
`past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
|
||||
their past given to this model should not be passed as `input_ids` as they have already been computed.
|
||||
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
||||
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
||||
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
||||
|
||||
Each element of `past_key_values` is a tuple (past_key, past_value):
|
||||
- past_key: [batch_size * num_heads, head_dim, kv_length]
|
||||
- past_value: [batch_size * num_heads, kv_length, head_dim]
|
||||
Two formats are allowed:
|
||||
- a [`~cache_utils.Cache`] instance;
|
||||
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
||||
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
||||
cache format.
|
||||
|
||||
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
||||
legacy cache format will be returned.
|
||||
|
||||
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
||||
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
||||
of shape `(batch_size, sequence_length)`.
|
||||
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
@ -564,6 +539,10 @@ BLOOM_INPUTS_DOCSTRING = r"""
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
||||
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
||||
the complete sequence length.
|
||||
"""
|
||||
|
||||
|
||||
@ -583,7 +562,7 @@ class BloomModel(BloomPreTrainedModel):
|
||||
self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
|
||||
# Transformer blocks
|
||||
self.h = nn.ModuleList([BloomBlock(config) for _ in range(config.num_hidden_layers)])
|
||||
self.h = nn.ModuleList([BloomBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
|
||||
|
||||
# Final Layer Norm
|
||||
self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
@ -611,7 +590,7 @@ class BloomModel(BloomPreTrainedModel):
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
||||
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.LongTensor] = None,
|
||||
@ -619,6 +598,7 @@ class BloomModel(BloomPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**deprecated_arguments,
|
||||
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
|
||||
if deprecated_arguments.pop("position_ids", False) is not False:
|
||||
@ -638,62 +618,59 @@ class BloomModel(BloomPreTrainedModel):
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError(
|
||||
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if past_key_values is None:
|
||||
past_key_values = tuple([None] * len(self.h))
|
||||
if self.gradient_checkpointing and self.training and use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
|
||||
# kept for BC (non `Cache` `past_key_values` inputs)
|
||||
use_legacy_cache = False
|
||||
if use_cache and not isinstance(past_key_values, Cache) and not self.training:
|
||||
use_legacy_cache = True
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
logger.warning_once(
|
||||
"Using `past_key_values` as a tuple is deprecated and will be removed in v4.45. "
|
||||
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
|
||||
)
|
||||
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
seq_length_with_past = seq_length + past_length
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(past_length, past_length + seq_length, device=inputs_embeds.device)
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape batch_size x num_heads x N x N
|
||||
# head_mask has shape n_layer x batch x num_heads x N x N
|
||||
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
|
||||
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
|
||||
|
||||
presents = () if use_cache else None
|
||||
next_decoder_cache = None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# Compute alibi tensor: check build_alibi_tensor documentation
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
if past_key_values[0] is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
|
||||
else:
|
||||
attention_mask = attention_mask.to(hidden_states.device)
|
||||
|
||||
alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
|
||||
|
||||
causal_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask,
|
||||
input_shape=(batch_size, seq_length),
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values_length=past_key_values_length,
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
)
|
||||
causal_mask = causal_mask.bool()
|
||||
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
for i, block in enumerate(self.h):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
@ -703,25 +680,27 @@ class BloomModel(BloomPreTrainedModel):
|
||||
hidden_states,
|
||||
alibi,
|
||||
causal_mask,
|
||||
layer_past,
|
||||
past_key_values,
|
||||
head_mask[i],
|
||||
use_cache,
|
||||
output_attentions,
|
||||
cache_position,
|
||||
)
|
||||
else:
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
layer_past=past_key_values,
|
||||
attention_mask=causal_mask,
|
||||
head_mask=head_mask[i],
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
alibi=alibi,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if use_cache is True:
|
||||
presents = presents + (outputs[1],)
|
||||
if use_cache:
|
||||
next_decoder_cache = outputs[1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||
@ -732,16 +711,103 @@ class BloomModel(BloomPreTrainedModel):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
next_cache = None
|
||||
if use_cache:
|
||||
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
||||
return tuple(
|
||||
v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None
|
||||
)
|
||||
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=presents,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool,
|
||||
):
|
||||
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
|
||||
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
||||
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
|
||||
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
|
||||
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype, device = input_tensor.dtype, input_tensor.device
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_static_cache:
|
||||
target_length = past_key_values.get_max_length()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
|
||||
if attention_mask.max() != 0:
|
||||
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type == "cuda"
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
@ -769,39 +835,34 @@ class BloomForCausalLM(BloomPreTrainedModel):
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
past_key_values: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
cache_position=None,
|
||||
use_cache=True,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
# only last tokens for input_ids if past is not None
|
||||
):
|
||||
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||
if past_key_values is not None:
|
||||
past_length = past_key_values[0][0].shape[2]
|
||||
|
||||
# Some generation methods already pass only the last input ID
|
||||
if input_ids.shape[1] > past_length:
|
||||
remove_prefix_length = past_length
|
||||
else:
|
||||
# Default to old behavior: keep only final ID
|
||||
remove_prefix_length = input_ids.shape[1] - 1
|
||||
|
||||
input_ids = input_ids[:, remove_prefix_length:]
|
||||
|
||||
# the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed
|
||||
if past_key_values[0][0].shape[0] == input_ids.shape[0]:
|
||||
past_key_values = self._convert_to_bloom_cache(past_key_values)
|
||||
if inputs_embeds is not None: # Exception 1
|
||||
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||
input_ids = input_ids[:, cache_position]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
if inputs_embeds is not None and cache_position[0] == 0:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"use_cache": use_cache,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
)
|
||||
@ -816,7 +877,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
||||
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
@ -825,6 +886,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**deprecated_arguments,
|
||||
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
|
||||
r"""
|
||||
@ -855,6 +917,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
@ -896,8 +959,6 @@ class BloomForCausalLM(BloomPreTrainedModel):
|
||||
|
||||
Output shares the same memory storage as `past`.
|
||||
"""
|
||||
standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))
|
||||
|
||||
# Get a copy of `beam_idx` on all the devices where we need those indices.
|
||||
device_to_beam_idx = {
|
||||
past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
|
||||
@ -907,9 +968,9 @@ class BloomForCausalLM(BloomPreTrainedModel):
|
||||
layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
|
||||
layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
|
||||
)
|
||||
for layer_past in standardized_past
|
||||
for layer_past in past
|
||||
)
|
||||
return self._convert_to_bloom_cache(reordered_past)
|
||||
return reordered_past
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
@ -946,7 +1007,7 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
||||
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
@ -1083,7 +1144,7 @@ class BloomForTokenClassification(BloomPreTrainedModel):
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
||||
past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
|
@ -2568,13 +2568,10 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
|
||||
outputs: ModelOutput,
|
||||
model_kwargs: Dict[str, Any],
|
||||
is_encoder_decoder: bool = False,
|
||||
standardize_cache_format: bool = False,
|
||||
model_inputs: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
# update past_key_values
|
||||
cache_name, cache = self._extract_past_from_model_output(
|
||||
outputs, standardize_cache_format=standardize_cache_format
|
||||
)
|
||||
cache_name, cache = self._extract_past_from_model_output(outputs)
|
||||
model_kwargs[cache_name] = cache
|
||||
|
||||
if getattr(outputs, "state", None) is not None:
|
||||
|
@ -252,7 +252,6 @@ class PersimmonAttention(nn.Module):
|
||||
else:
|
||||
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
||||
|
||||
# Copied from transformers.models.bloom.modeling_bloom.BloomAttention._split_heads
|
||||
def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
|
||||
|
@ -1096,7 +1096,6 @@ class GenerationTesterMixin:
|
||||
if any(
|
||||
model_name in model_class.__name__.lower()
|
||||
for model_name in [
|
||||
"bloom",
|
||||
"ctrl",
|
||||
"gptbigcode",
|
||||
"transo_xl",
|
||||
@ -1878,7 +1877,7 @@ class GenerationTesterMixin:
|
||||
# 2. Some old models still return `output.past_key_values` even without `use_cache=True`
|
||||
# 3. TODO (joao): A few models have different formats/types, skipping those until the cache refactor is
|
||||
# complete
|
||||
models_without_standard_cache = ("bloom", "ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba")
|
||||
models_without_standard_cache = ("ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba")
|
||||
has_standard_cache = not any(
|
||||
model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user