Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Bloom Optimize operations (#17866)
* fix tolerance for a bloom slow test
* enhance alibi padding
  - get rid of for loops
  - deal better with padded batched input
  - avoid useless cpu/gpu communication when creating alibi
  Co-authored-by: justheuristic <justheuristic@gmail.com>
* optimize attention mask
* fix scaled softmax limit values
* optimize building alibi tensor
  Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>
* fix attention_mask shape when it's None
* minor fixes - fix docstring + arg names
* remove colons in docstring
* Apply suggestions from code review
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* apply suggestion
* remove unused arg
* refactor a bit - use [:, None] for consistency
* refactor attention block
  Co-authored-by: Nouamane Tazi <nouamane98@gmail.com>
* quick fixes
* first attempt
* refactor attention block and fix all tests except "test_simple_generation"
  - added comments to better explain attention block
* remove debug lines and add TODO comment
* change `torch.bmm` to `torch.baddbmm` - fixes `test_simple_generation` but breaks `test_batch_generation_padd`
* styling
* all tests are passing now
  - use `bmm`
  - add explanation for `allow_fp16_reduced_precision_reduction`
  Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>
* styling
  Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>
* fix support for accelerate
  Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>
* Apply suggestions from code review
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* remove attn softmax in fp32
* refactor comments
* refactor a bit - remove warning message - remove print on test
* refer to pytorch t5
* change the slow tests
  - do the tests in fp32
  - remove some comments
  - keep large comments
* update expected output for `test_simple_generation` - we now test using fp32
* make style + change comments a bit
* fix dtype padd test

Co-authored-by: justheuristic <justheuristic@gmail.com>
Co-authored-by: Nouamane Tazi <nouamane98@gmail.com>
Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in: parent 5ff6f853d7, commit a462fc9232
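The commit message mentions `allow_fp16_reduced_precision_reduction` as the reason `torch.bmm`/`torch.baddbmm` results differ across GPU architectures for the small checkpoints in fp16. A minimal, hedged sketch of pinning that PyTorch backend flag while debugging the slow tests (the flag itself is a standard PyTorch setting; everything else here is illustrative):

import torch

# PyTorch may use reduced-precision accumulations inside fp16 GEMMs (bmm/baddbmm),
# which perturbs low-order bits of the attention scores on some GPU architectures.
# Pinning the flag makes small-model generations reproducible while investigating.
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False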
@@ -72,9 +72,6 @@ class BloomConfig(PretrainedConfig):
             If set to `True`, it will skip bias add for each linear layer in the transformer blocks
         skip_bias_add_qkv (`bool`, *optional*, defaults to `False`):
             If set to `True`, it will skip bias add for the first linear layer in the transformer blocks
-        attention_softmax_in_fp32 (`bool`, *optional*, defaults to `True`):
-            If set to `True` and the `dtype` is set to `float16` it will scale the input of the Softmax function to
-            `fp32`
         hidden_dropout (`float`, *optional*, defaults to 0.1):
             Dropout rate of the dropout function on the bias dropout.
         attention_dropout (`float`, *optional*, defaults to 0.1):
@@ -128,7 +125,6 @@ class BloomConfig(PretrainedConfig):
         hidden_size=64,
         n_layer=2,
         n_head=8,
-        masked_softmax_fusion=True,
         layer_norm_epsilon=1e-5,
         initializer_range=0.02,
         use_cache=False,
@@ -137,7 +133,6 @@ class BloomConfig(PretrainedConfig):
         apply_residual_connection_post_layernorm=False,
         hidden_dropout=0.0,
         attention_dropout=0.0,
-        attention_softmax_in_fp32=True,
         pretraining_tp=1,  # TP rank used when training with megatron
         dtype="bfloat16",
         slow_but_exact=False,
@@ -147,7 +142,6 @@ class BloomConfig(PretrainedConfig):
         self.hidden_size = hidden_size
         self.n_layer = n_layer
         self.n_head = n_head
-        self.masked_softmax_fusion = masked_softmax_fusion
         self.layer_norm_epsilon = layer_norm_epsilon
         self.initializer_range = initializer_range
         self.use_cache = use_cache
@@ -155,7 +149,6 @@ class BloomConfig(PretrainedConfig):
         self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
         self.hidden_dropout = hidden_dropout
         self.attention_dropout = attention_dropout
-        self.attention_softmax_in_fp32 = attention_softmax_in_fp32

         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id
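The hunks above drop `masked_softmax_fusion` and `attention_softmax_in_fp32` from `BloomConfig`. A small sketch of building a toy config after this change (the keyword names are taken from the hunks; the values are arbitrary and only meant for a quick smoke test):

from transformers import BloomConfig, BloomModel

# Tiny configuration; the fusion / fp32-softmax switches removed by this patch
# are simply no longer part of the signature.
config = BloomConfig(
    hidden_size=64,
    n_layer=2,
    n_head=8,
    hidden_dropout=0.0,
    attention_dropout=0.0,
    use_cache=False,
)
model = BloomModel(config)
print(sum(p.numel() for p in model.parameters()))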
@@ -51,49 +51,38 @@ BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = [
 ]


-def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
-    """Split a tensor along its last dimension.
-
-    Args:
-        tensor: ([`torch.tensor`], *required*):
-            input tensor to split
-        num_partitions ([`int`], *required*):
-            number of partitions to split the tensor
-        contiguous_split_chunks ([`bool`], *optional*, default=`False`)::
-            If True, make each chunk contiguous in memory.
-    """
-    # Get the size and dimension.
-    last_dim = tensor.dim() - 1
-    numerator, denominator = tensor.size()[last_dim], num_partitions
-    if not (numerator % denominator == 0):
-        raise ValueError(f"{numerator} is not divisible by {denominator}")
-    last_dim_size = numerator // denominator
-    # Split.
-    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
-    # Note: torch.split does not create contiguous tensors by default.
-    if contiguous_split_chunks:
-        return tuple(chunk.contiguous() for chunk in tensor_list)
-
-    return tensor_list
+def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+    """
+    Make causal mask used for bi-directional self-attention.
+    """
+    batch_size, target_length = input_ids_shape
+    mask = torch.full((target_length, target_length), torch.finfo(dtype).min)
+    mask_cond = torch.arange(mask.size(-1))
+    intermediate_mask = mask_cond < (mask_cond + 1).view(mask.size(-1), 1)
+    mask.masked_fill_(intermediate_mask, 0)
+    mask = mask.to(dtype)
+
+    if past_key_values_length > 0:
+        mask = torch.cat([torch.zeros(target_length, past_key_values_length, dtype=dtype), mask], dim=-1)
+    expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
+    return expanded_mask


-def attention_mask_func(attention_scores, attention_mask, causal_mask):
-    attention_mask_bool = ~attention_mask.bool()
-
-    query_length, key_length, n_heads = attention_scores.size(2), attention_scores.size(3), attention_scores.size(1)
-    padded_causal_mask = torch.logical_or(
-        attention_mask_bool[:, None, key_length - query_length : key_length, None],
-        ~causal_mask[:, :, key_length - query_length : key_length, :key_length].bool(),
-    )
-    padded_causal_mask = torch.logical_or(padded_causal_mask, attention_mask_bool[:, None, None, :key_length])
-    # Make use of floats
-    return (
-        attention_scores.masked_fill_(padded_causal_mask.expand(-1, n_heads, -1, -1), -10000.0),
-        padded_causal_mask,
-    )
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None):
+    """
+    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+    """
+    batch_size, source_length = mask.size()
+    tgt_len = tgt_len if tgt_len is not None else source_length
+
+    expanded_mask = mask[:, None, None, :].expand(batch_size, 1, tgt_len, source_length).to(dtype)
+
+    inverted_mask = 1.0 - expanded_mask
+
+    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


-def build_alibi_tensor(max_seq_len, n_head, device, dtype=torch.bfloat16):
+def build_alibi_tensor(attention_mask: torch.Tensor, n_head: int, dtype, device) -> torch.Tensor:
     """
     Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
     relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
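`_make_causal_mask` and `_expand_mask` both produce additive float masks (0 where attention is allowed, the dtype's most negative value where it is not), the convention used by other decoder-only models in the library. A rough usage sketch, assuming a transformers checkout that already contains this patch so the two private helpers are importable from the Bloom modeling module:

import torch
# Module-level helpers defined in the hunk above.
from transformers.models.bloom.modeling_bloom import _make_causal_mask, _expand_mask

batch_size, seq_len = 2, 4
attention_mask = torch.tensor([[1, 1, 1, 1], [0, 0, 1, 1]])  # second row is left-padded

causal = _make_causal_mask((batch_size, seq_len), torch.float32, past_key_values_length=0)
padding = _expand_mask(attention_mask, torch.float32, tgt_len=seq_len)

# Both are additive biases per (batch, 1, query, key) position: 0.0 where a query
# may attend, a very large negative value elsewhere; their sum is what
# _prepare_attn_mask (added further down) hands to the attention blocks.
combined = causal + padding
print(combined.shape)  # torch.Size([2, 1, 4, 4])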
@@ -101,73 +90,41 @@ def build_alibi_tensor(max_seq_len, n_head, device, dtype=torch.bfloat16):
     https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742

     Args:
-    Returns tensor shaped (n_head, 1, max_seq_len)
-        max_seq_len: (`int`, *required*):
-            max sequence length
-        n_head: (`int`, *required*):
+    Returns tensor shaped (batch_size * n_head, 1, max_seq_len)
+        attention_mask (`torch.Tensor`):
+            Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
+        n_head (`int`, *required*):
             number of heads
-        dtype: (`torch.dtype`, *optional*, default=`torch.bfloat16`):
+        dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
             dtype of the output tensor
         device (`torch.device`, *optional*, default=`torch.device('cpu')`):
             device of the output alibi tensor
     """
+    closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
+    base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32)
+    powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
+    slopes = torch.pow(base, powers)

-    def get_slopes(n):
-        def get_slopes_power_of_2(n):
-            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
-            ratio = start
-            return [start * ratio**i for i in range(n)]
-
-        if math.log2(n).is_integer():
-            return get_slopes_power_of_2(n)
-        else:
-            closest_power_of_2 = 2 ** math.floor(math.log2(n))
-            return (
-                get_slopes_power_of_2(closest_power_of_2)
-                + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
-            )
-
-    slopes = torch.Tensor(get_slopes(n_head)).unsqueeze(1).unsqueeze(1)
-    arange_tensor = torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0)
-    alibi = slopes * arange_tensor.expand(n_head, -1, -1)
-
-    alibi = alibi.to(device=device, dtype=dtype)
-
-    return alibi
-
-
-def pre_process_alibi_for_pad(alibi, attention_mask, num_heads):
-    """
-    Args:
-    Pre-process the alibi tensor for padding.
-        alibi: ([`torch.tensor`], *required*):
-            alibi tensor to pre-process
-        attention_mask: ([`torch.tensor`], *required*):
-            attention mask to pre-process"""
-
-    # Sanity check if we are not inferring less tokens than the total sequence length
-    # This usually happens when the inference is done with past_key_values
-    # In this case we re-create the alibi tensor with the correct sequence length
-    if attention_mask.shape[-1] != alibi.shape[-1]:
-        alibi = build_alibi_tensor(attention_mask.shape[-1], num_heads, alibi.device, alibi.dtype).repeat(
-            attention_mask.shape[0], 1, 1
-        )
-    # Get the indexes of the padding tokens
-    index_x0, index_y0 = torch.where(attention_mask == 0.0)
-    index_x1, index_y1 = torch.where(attention_mask == 1.0)
-
-    # Clone the embeddings - we can detach because the embeddings are not learned
-    # Get a refence tensor
-    slice_reference_alibi = build_alibi_tensor(alibi.shape[-1], num_heads, alibi.device, alibi.dtype)
-
-    # Loop over the batch where the padding is and replace the alibi tensor by the reference tensor
-    # Only where you do not have padding. Replace padding tokens by zeros
-    # This operation can be seen as a shifting operation.
-    for i, index in enumerate(torch.unique(index_x0)):
-        slice_to_modify = torch.zeros_like(slice_reference_alibi)
-        index_shift = index_y1[index_x1 == index]
-        shift_value = len(index_shift)
-        slice_to_modify[:, :, index_shift] = slice_reference_alibi[:, :, :shift_value]
-        alibi[index * num_heads : (index + 1) * num_heads] = slice_to_modify
-    return alibi
+    if closest_power_of_2 != n_head:
+        extra_base = torch.tensor(
+            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32
+        )
+        num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2)
+        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
+        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
+
+    # Note: alibi will added to the attention bias that will be applied to the query, key product of attention
+    # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
+    # => here we set (batch_size=1, num_heads=n_head, query_length=1, key_length=max_length)
+    # => the query_length dimension will then be broadcasted correctly
+    # This is more or less identical to T5's relative position bias:
+    # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
+    # batch_size = 1, n_head = n_head, query_length
+    arange_tensor = (attention_mask.cumsum(-1)[:, None, :].to(device) - 1) * attention_mask[:, None]
+    alibi = slopes.unsqueeze(-1) * arange_tensor
+    alibi = alibi * attention_mask[:, None]
+    return alibi.reshape(alibi.shape[0] * n_head, 1, -1).to(dtype)


 def dropout_add(x, residual, prob, training):
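The new `build_alibi_tensor` takes the padding mask directly, so positions are counted with a `cumsum` over non-padded tokens and no Python loop or extra CPU/GPU round trip is needed. A small sketch of calling it as redefined above, with the patched Bloom modeling module on the path (shapes follow the docstring; slope values come from the closest-power-of-two scheme):

import torch
from transformers.models.bloom.modeling_bloom import build_alibi_tensor

n_head = 8
attention_mask = torch.tensor([[1, 1, 1, 1], [0, 0, 1, 1]], dtype=torch.float32)  # second row is left-padded

alibi = build_alibi_tensor(attention_mask, n_head, dtype=torch.float32, device=torch.device("cpu"))
print(alibi.shape)  # torch.Size([16, 1, 4]) == (batch_size * n_head, 1, seq_len)

# Padded positions contribute a zero bias; real tokens get slope * position,
# where the position count restarts at 0 on the first non-padded token of each row.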
@@ -252,58 +209,6 @@ class BloomGelu(nn.Module):
         return bloom_gelu_forward(x)


-class BloomScaledSoftmax(nn.Module):
-    """
-    fused operation: scaling + mask + softmax
-
-    Args:
-        input_in_fp16 (`bool`, *required*):
-            flag to indicate if input in fp16 data format.
-        input_in_bf16 (`bool`, *required*):
-            flag to indicate if input in bf16 data format.
-        scaled_masked_softmax_fusion (`bool`, *required*):
-            flag to indicate user want to use softmax fusion
-        mask_func (`function`, *required*):
-            mask function to be applied.
-        softmax_in_fp32 (`bool`, *required*):
-            if true, softmax in performed at fp32 precision.
-        scale (`float`, *required*):
-            scaling factor used in input tensor scaling.
-    """
-
-    def __init__(self, scaled_masked_softmax_fusion, mask_func, softmax_in_fp32, scale):
-        super().__init__()
-        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
-        self.mask_func = mask_func
-        self.softmax_in_fp32 = softmax_in_fp32
-        self.scale = scale
-
-        if not (self.scale is None or softmax_in_fp32):
-            raise ValueError("softmax should be in fp32 when scaled")
-
-    def forward(self, input, mask, max_positions):
-        input_dtype = input.dtype
-        input_in_16bit = input_dtype in [torch.float16, torch.bfloat16]
-        softmax_dtype = torch.float32 if self.softmax_in_fp32 else input_dtype
-
-        if self.scale is not None:
-            input = input * self.scale
-
-        if mask is None:
-            mask = torch.ones(input.shape[0], max_positions, dtype=torch.bool, device=input.device)
-
-        mask = mask.to(input.device)
-        seq_ids = torch.arange(max_positions, device=input.device)
-        causal_mask = (seq_ids[None, :] <= seq_ids[:, None]).view(1, 1, max_positions, max_positions).to(input.device)
-        mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
-        probs = nn.functional.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
-
-        if input_in_16bit and self.softmax_in_fp32:
-            probs = probs.to(dtype=input_dtype)
-
-        return probs
-
-
 class BloomAttention(nn.Module):
     def __init__(self, config, layer_number=None):
         super().__init__()
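With `BloomScaledSoftmax` removed, the scaling, masking and softmax are done inline in `BloomAttention.forward` (see the attention hunk further down). A hedged sketch of the equivalent few-line computation, using made-up tensor names:

import torch
from torch import nn

# attention_scores: [batch_size, num_heads, q_length, k_length]
# causal_mask: additive float mask (0 where allowed, a very large negative value where masked)
def masked_softmax_sketch(attention_scores, causal_mask, layer_number):
    input_dtype = attention_scores.dtype
    attn_weights = attention_scores * layer_number + causal_mask
    # clamp to the dtype minimum so fully masked rows do not overflow to -inf
    attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
    probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
    # explicitly zero out masked key positions, as the patch does
    return probs * (~causal_mask.bool())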
@@ -315,8 +220,6 @@ class BloomAttention(nn.Module):
         self.num_heads = config.n_head
         self.head_dim = self.hidden_size // self.num_heads
         self.split_size = self.hidden_size
-        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
-        self.masked_softmax_fusion = config.masked_softmax_fusion
         self.hidden_dropout = config.hidden_dropout

         if self.head_dim * self.num_heads != self.hidden_size:
@@ -329,18 +232,35 @@ class BloomAttention(nn.Module):
         self.layer_number = max(1, layer_number)
         self.norm_factor = math.sqrt(self.head_dim) * self.layer_number

-        # Scaled Softmax
-        self.scale_mask_softmax = BloomScaledSoftmax(
-            self.masked_softmax_fusion,
-            attention_mask_func,
-            self.attention_softmax_in_fp32,
-            self.layer_number,
-        )
-
         self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
         self.dense = nn.Linear(self.hidden_size, self.hidden_size)
         self.attention_dropout = nn.Dropout(config.attention_dropout)

+    def _split_heads(self, fused_qkv):
+        """
+        Split the last dimension into (num_heads, head_dim)
+        """
+        new_tensor_shape = fused_qkv.size()[:-1] + (self.num_heads, 3 * self.head_dim)
+        # new_tensor_shape = (fused_qkv.size(1), fused_qkv.size(0)*fused_qkv.size(2), fused_qkv.size(-1))
+        # fused_qkv = fused_qkv.transpose(1, 0)
+        fused_qkv = fused_qkv.reshape(*new_tensor_shape)
+        # fused_qkv = fused_qkv.permute(0, 2, 1, 3)
+        return torch.split(fused_qkv, self.head_dim, -1)
+
+    def _merge_heads(self, x):
+        # What we want to achieve is:
+        # batch_size * num_heads, seq_len, head_dim -> batch_size, seq_len, num_heads * head_dim
+
+        # First view to decompose the batch size
+        # batch_size*num_heads, seq_len, head_dim -> batch_size, num_heads, seq_len, head_dim
+        x = x.view(x.size(0) // self.num_heads, self.num_heads, x.size(1), self.head_dim)
+
+        # batch_size, num_heads, seq_len, head_dim -> batch_size, seq_len, num_heads, head_dim
+        x = x.permute(0, 2, 1, 3)
+
+        # batch_size, seq_len, num_heads, head_dim -> batch_size, seq_len, num_heads * head_dim
+        return x.reshape(x.size(0), x.size(1), self.num_heads * self.head_dim)
+
     def forward(
         self,
         hidden_states,
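The new helpers keep query/key/value in `[batch_size, seq_length, num_heads, head_dim]` layout and fold the heads back at the end. A quick, hedged shape check of the same round trip with illustrative dimensions:

import torch

batch_size, seq_len, num_heads, head_dim = 2, 5, 8, 16
hidden_size = num_heads * head_dim

fused_qkv = torch.randn(batch_size, seq_len, 3 * hidden_size)  # output of query_key_value
q, k, v = torch.split(fused_qkv.view(batch_size, seq_len, num_heads, 3 * head_dim), head_dim, dim=-1)
print(q.shape)  # torch.Size([2, 5, 8, 16]) - same layout as _split_heads

context = torch.randn(batch_size * num_heads, seq_len, head_dim)  # output of the attention bmm
merged = (
    context.view(batch_size, num_heads, seq_len, head_dim)
    .permute(0, 2, 1, 3)
    .reshape(batch_size, seq_len, hidden_size)
)
print(merged.shape)  # torch.Size([2, 5, 128]) - same result as _merge_heads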
@@ -352,25 +272,15 @@ class BloomAttention(nn.Module):
         use_cache=False,
         output_attentions=False,
     ):
-        # hidden_states: [batch_size, seq_length, hidden_size]
-        # repeat alibi tensor with the batch size
-        alibi = alibi.repeat(hidden_states.shape[0], 1, 1).to(hidden_states.device)
-
-        # apply preprocessing if the input is padded
-        if attention_mask is not None and 0 in attention_mask:
-            alibi = pre_process_alibi_for_pad(alibi, attention_mask, self.num_heads)
-
-        mixed_x_layer = self.query_key_value(hidden_states)
-
-        # [batch_size, seq_length, 3 x hidden_size] --> [batch_size, seq_length, num_heads, 3 x head_dim]
-        new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_heads, 3 * self.head_dim)
-        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
-
-        # [batch_size, seq_length, num_heads, 3 x head_dim] --> 3 [batch_size, seq_length, num_heads, head_dim]
-        (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
+        alibi = alibi.to(hidden_states.device)  # to make the model possible to run under accelerate
+        fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
+
+        # 3 x [batch_size, seq_length, num_heads, head_dim]
+        (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

         if layer_past is not None:
             past_key, past_value = layer_past
             # concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim]
             key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1)
             value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1)
@@ -379,66 +289,39 @@ class BloomAttention(nn.Module):
         else:
             present = None

-        # [batch_size, head_dim, q_length, k_length]
-        output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1))
-
-        # [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim]
-        query_layer = query_layer.transpose(1, 0).reshape(output_size[2], output_size[0] * output_size[1], -1)
-
-        # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
-        key_layer = key_layer.transpose(1, 0).reshape(output_size[3], output_size[0] * output_size[1], -1)
-
-        # slice alibi tensor until the query length
-        sliced_alibi = alibi[: output_size[0] * output_size[1], :, : output_size[3]]
-
-        # Raw attention scores. [batch_size * num_heads, q_length, k_length]
         beta = 1.0 / self.layer_number

-        matmul_result = torch.baddbmm(
-            sliced_alibi,
-            query_layer.transpose(1, 0),
-            key_layer.transpose(1, 0).transpose(1, 2),
-            beta=beta,
-            alpha=(1.0 / self.norm_factor),
-        )
+        # # [batch_size*num_heads, head_dim, q_length] x [batch_size*num_heads, head_dim, k_length] -> [batch_size*num_heads, q_length, k_length]
+        matmul_result = (1.0 / self.norm_factor) * torch.bmm(
+            query_layer.transpose(1, 2).reshape(-1, query_layer.shape[1], query_layer.shape[3]),
+            key_layer.permute(0, 2, 3, 1).reshape(-1, key_layer.shape[3], key_layer.shape[1]),
+        ) + beta * alibi

         # change view to [batch_size, num_heads, q_length, k_length]
-        attention_scores = matmul_result.view(*output_size)
+        attention_scores = matmul_result.view(-1, self.num_heads, matmul_result.size(1), matmul_result.size(2))

-        # attention scores and attention mask [b, np, sq, sk]
-        max_positions = max(attention_scores.shape[-1], attention_scores.shape[-2])
-        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, max_positions).to(
-            value_layer.dtype
-        )
+        # We replace the scaled softmax by just a few line of code - [batch_size, num_heads, q_length, k_length]
+        input_dtype = attention_scores.dtype
+        attn_weights = (attention_scores * self.layer_number) + attention_mask
+        attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+        attention_probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
+        attention_probs = attention_probs * (~attention_mask.bool())
+
+        # [batch_size, num_heads, q_length, k_length]
         attention_probs = self.attention_dropout(attention_probs)

         if head_mask is not None:
             attention_probs = attention_probs * head_mask

-        # context layer shape: [batch_size, num_heads, q_length, head_dim]
-        output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(0), value_layer.size(3))
-
-        # change view [k_length, batch_size x num_heads, head_dim]
-        value_layer = value_layer.transpose(1, 0).reshape(value_layer.size(1), output_size[0] * output_size[1], -1)
-
         # change view [batch_size x num_heads, q_length, k_length]
-        attention_probs_reshaped = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
+        attention_probs_reshaped = attention_probs.view(*matmul_result.shape)

         # matmul: [batch_size * num_heads, q_length, head_dim]
-        context_layer = torch.bmm(attention_probs_reshaped, value_layer.transpose(0, 1))
+        context_layer = torch.bmm(
+            attention_probs_reshaped, value_layer.transpose(1, 2).reshape(-1, value_layer.size(1), value_layer.size(3))
+        )

         # change view [batch_size, num_heads, q_length, head_dim]
-        context_layer = context_layer.view(*output_size)
-
-        # [batchs_size, num_heads, q_length, head_dim] --> [q_length, batch_size, num_heads, head_dim]
-        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
-
-        # [q_length, batch_size, num_heads, head_dim] --> [q_length, batch_size, hidden_size]
-        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
-
-        context_layer = context_layer.view(*new_context_layer_shape)
-
-        # Output. [q_length, batch_size, hidden_size]
+        context_layer = self._merge_heads(context_layer)

         # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
         if self.pretraining_tp > 1 and self.slow_but_exact:
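Per the commit history, the score computation moved from `torch.baddbmm` (alibi as the `input` term with `beta`/`alpha`) to a plain `torch.bmm` plus an explicit `beta * alibi`, which behaved more consistently under fp16 reduced-precision reductions. Mathematically the two forms are the same; a small, hedged numerical check with toy tensors:

import torch

bsz_heads, q_len, k_len, head_dim = 4, 3, 3, 8
norm_factor, layer_number = head_dim ** 0.5, 2.0

query = torch.randn(bsz_heads, q_len, head_dim, dtype=torch.float64)
key_t = torch.randn(bsz_heads, head_dim, k_len, dtype=torch.float64)
alibi = torch.randn(bsz_heads, 1, k_len, dtype=torch.float64)  # broadcasts over the query dimension

beta = 1.0 / layer_number
baddbmm_result = torch.baddbmm(alibi, query, key_t, beta=beta, alpha=1.0 / norm_factor)
bmm_result = (1.0 / norm_factor) * torch.bmm(query, key_t) + beta * alibi

print(torch.allclose(baddbmm_result, bmm_result))  # True (in float64)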
@@ -452,11 +335,9 @@ class BloomAttention(nn.Module):
         else:
             output_tensor = self.dense(context_layer)

-        output = output_tensor.transpose(1, 0)
-
-        output = dropout_add(output, residual, self.hidden_dropout, self.training)
+        output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)

-        outputs = (output, present)
+        outputs = (output_tensor, present)
         if output_attentions:
             outputs += (attention_probs,)

@@ -703,6 +584,24 @@ class BloomModel(BloomPreTrainedModel):
     def get_input_embeddings(self):
         return self.word_embeddings

+    def _prepare_attn_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
+        # create causal mask
+        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+        combined_attention_mask = None
+        if input_shape[-1] > 1:
+            combined_attention_mask = _make_causal_mask(
+                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
+            ).to(attention_mask.device)
+
+        if attention_mask is not None:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            combined_attention_mask = (
+                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+            )
+
+        return combined_attention_mask
+
     def set_input_embeddings(self, new_embeddings):
         self.word_embeddings = new_embeddings

@@ -765,9 +664,19 @@ class BloomModel(BloomPreTrainedModel):

         # Compute alibi tensor: check build_alibi_tensor documentation
         current_sequence_length = hidden_states.shape[1]
+        past_key_values_length = 0
         if past_key_values[0] is not None:
-            current_sequence_length += past_key_values[0][0].shape[1]
-        alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.device, hidden_states.dtype)
+            past_key_values_length = past_key_values[0][0].shape[1]
+            current_sequence_length += past_key_values_length
+
+        if attention_mask is None:
+            attention_mask = torch.ones((hidden_states.shape[0], current_sequence_length), device=hidden_states.device)
+        else:
+            attention_mask = attention_mask.to(hidden_states.device)
+
+        alibi = build_alibi_tensor(attention_mask, self.n_head, hidden_states.dtype, hidden_states.device)
+
+        causal_mask = self._prepare_attn_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length)

         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):

@@ -793,14 +702,14 @@ class BloomModel(BloomPreTrainedModel):
                     create_custom_forward(block),
                     hidden_states,
                     None,
-                    attention_mask,
+                    causal_mask,
                     head_mask[i],
                 )
             else:
                 outputs = block(
                     hidden_states,
                     layer_past=layer_past,
-                    attention_mask=attention_mask,
+                    attention_mask=causal_mask,
                     head_mask=head_mask[i],
                     use_cache=use_cache,
                     output_attentions=output_attentions,
@@ -877,7 +786,6 @@ class BloomForCausalLM(BloomPreTrainedModel):
             "input_ids": input_ids,
             "past_key_values": past,
             "use_cache": kwargs.get("use_cache"),
-            "position_ids": position_ids,
             "attention_mask": attention_mask,
         }

@@ -377,15 +377,34 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
     @slow
     @require_torch_gpu
     def test_simple_generation(self):
+        # This test is a bit flaky. For some GPU architectures, pytorch sets by default allow_fp16_reduced_precision_reduction = True and some operations
+        # do not give the same results under this configuration, especially torch.baddmm and torch.bmm. https://pytorch.org/docs/stable/notes/numerical_accuracy.html#fp16-on-mi200
+        # We set allow_fp16_reduced_precision_reduction = True. Please see: https://pytorch.org/docs/stable/notes/cuda.html#reduced-precision-reduction-in-fp16-gemms
+        # This discrepancy is observed only when using small models and seems to be stable for larger models.
+        # Our conclusion is that these operations are flaky for small inputs but seems to be stable for larger inputs (for the functions `baddmm` and `bmm`), and therefore for larger models.
+
+        # Here is a summary of an ablation study of our observations
+        # EXPECTED_OUTPUT = "I enjoy walking with my cute dog, and I love to watch the kids play. I am a very active person, and I am a very good listener. I am a very good person, and I am a very good person. I am a"
+        # 350m + allow_fp16_reduced_precision_reduction = False + torch.bmm ==> PASS
+        # 350m + allow_fp16_reduced_precision_reduction = False + torch.baddm ==> PASS
+        # 350m + allow_fp16_reduced_precision_reduction = True + torch.baddm ==> PASS
+        # 350m + allow_fp16_reduced_precision_reduction = True + torch.bmm ==> FAIL
+
+        # EXPECTED_OUTPUT = "I enjoy walking with my cute dog, but I also enjoy hiking, biking, and swimming. I love to cook and bake. I love to cook and bake. I love to cook and bake. I love to cook and bake. I love"
+        # >=760m + allow_fp16_reduced_precision_reduction = True + torch.baddm ==> PASS (for use_cache=True and use_cache=False)
+        # >=760m + allow_fp16_reduced_precision_reduction = True + torch.bmm ==> PASS
+        # >=760m + allow_fp16_reduced_precision_reduction = False + torch.bmm ==> PASS
+
         path_350m = "bigscience/bloom-350m"
-        model = BloomForCausalLM.from_pretrained(path_350m, torch_dtype="auto", use_cache=True).cuda()
+        model = BloomForCausalLM.from_pretrained(path_350m, use_cache=True).cuda()
         model = model.eval()
         tokenizer = BloomTokenizerFast.from_pretrained(path_350m)

         input_sentence = "I enjoy walking with my cute dog"
+        # This output has been obtained using fp32 model on the huggingface DGX workstation - NVIDIA A100 GPU
         EXPECTED_OUTPUT = (
-            "I enjoy walking with my cute dog, and I love to watch the kids play. I am a very active person, and I am"
-            " a very good listener. I am a very good person, and I am a very good person. I am a"
+            "I enjoy walking with my cute dog, and I love to watch the kids play with the kids. I am a very "
+            "active person, and I enjoy working out, and I am a very active person. I am a very active person, and I"
         )

         input_ids = tokenizer.encode(input_sentence, return_tensors="pt")
@@ -397,7 +416,7 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
     @require_torch_gpu
     def test_batch_generation(self):
         path_350m = "bigscience/bloom-350m"
-        model = BloomForCausalLM.from_pretrained(path_350m, torch_dtype="auto", use_cache=True).cuda()
+        model = BloomForCausalLM.from_pretrained(path_350m, use_cache=True).cuda()
         model = model.eval()
         tokenizer = BloomTokenizerFast.from_pretrained(path_350m, padding_side="left")

@@ -416,8 +435,9 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
     @slow
     @require_torch_gpu
     def test_batch_generation_padd(self):

         path_350m = "bigscience/bloom-350m"
-        model = BloomForCausalLM.from_pretrained(path_350m, torch_dtype="auto", use_cache=True).cuda()
+        model = BloomForCausalLM.from_pretrained(path_350m, use_cache=True).cuda()
         model = model.eval()
         tokenizer = BloomTokenizerFast.from_pretrained(path_350m, padding_side="left")
