Bloom Optimize operations (#17866)

* fix tolerance for a bloom slow test

* enhance alibi padding

- get rid of for loops
- deal better with padded batched inputs
- avoid useless cpu/gpu communication when creating alibi (see the sketch below)
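For reference, a minimal sketch of the vectorized, padding-aware alibi construction these bullets describe. The helper name and the precomputed `slopes` are illustrative assumptions; the real `build_alibi_tensor` in the diff below also derives the slopes from `n_head` and casts to the model dtype.

```python
import torch

def alibi_from_mask(attention_mask: torch.Tensor, slopes: torch.Tensor) -> torch.Tensor:
    """Vectorized alibi bias for a possibly left-padded batch.

    attention_mask: (batch_size, seq_len), 1 for real tokens, 0 for padding.
    slopes:         (n_head,) per-head slopes, assumed precomputed.
    Returns:        (batch_size * n_head, 1, seq_len), built directly on the mask's
                    device, so no extra CPU/GPU round-trip is needed.
    """
    # Token positions counted over real tokens only; padding positions stay at 0.
    positions = (attention_mask.cumsum(dim=-1) - 1).clamp(min=0) * attention_mask
    # (batch, 1, seq_len) * (1, n_head, 1) -> (batch, n_head, seq_len), no Python loops.
    bias = positions[:, None, :] * slopes[None, :, None]
    return bias.reshape(-1, 1, attention_mask.shape[-1])

# Left-padded first sequence, 4 heads with dummy slopes.
mask = torch.tensor([[0.0, 0.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0, 1.0]])
slopes = torch.tensor([0.5, 0.25, 0.125, 0.0625])
print(alibi_from_mask(mask, slopes).shape)  # torch.Size([8, 1, 5])
```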

Co-authored-by: justheuristic <justheuristic@gmail.com>

* optimize attention mask
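Roughly, the optimization is to build the causal mask and the padding mask once per forward pass and hand the combined additive mask to every layer, instead of recomputing masks inside each attention block. A toy sketch of that combination follows; the function name and shapes are illustrative, not the exact `_prepare_attn_mask` from the diff below.

```python
import torch

def combined_attn_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """attention_mask: (batch, seq_len) with 1 = keep, 0 = padding.
    Returns an additive mask (batch, 1, seq_len, seq_len): 0 where attention is
    allowed, the dtype minimum where it is not."""
    seq_len = attention_mask.shape[-1]
    min_value = torch.finfo(dtype).min
    # Causal part: query i may only attend to keys j <= i.
    causal = torch.full((seq_len, seq_len), min_value, dtype=dtype).triu(diagonal=1)
    causal = causal[None, None, :, :]                                         # (1, 1, q, k)
    # Padding part: keys that are padding get the minimum value for every query.
    padding = (1.0 - attention_mask[:, None, None, :].to(dtype)) * min_value  # (b, 1, 1, k)
    # Clamp so stacked minimum values do not overflow to -inf.
    return (causal + padding).clamp(min=min_value)

mask = torch.tensor([[0, 1, 1, 1]])
print(combined_attn_mask(mask, torch.float32).shape)  # torch.Size([1, 1, 4, 4])
```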

* fix scaled softmax limit values
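A toy illustration of the limit-value issue (shapes are hypothetical): when causal and padding biases stack, masked logits can overflow past the fp16 range to -inf, and a fully masked row then turns into NaNs after softmax; clamping at `torch.finfo(dtype).min` first, as the new attention code below does, keeps everything finite.

```python
import torch

min_val = torch.finfo(torch.float16).min
scores = torch.randn(1, 4, 4, dtype=torch.float16)
# Two stacked masking biases push masked logits past the representable fp16 minimum (-inf).
biased = scores + min_val + min_val
# Clamp at the dtype minimum so softmax never sees -inf.
biased = torch.max(biased, torch.tensor(min_val, dtype=torch.float16))
probs = torch.softmax(biased.float(), dim=-1).to(scores.dtype)
print(torch.isfinite(probs).all())  # tensor(True)
```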

* optimize building alibi tensor

Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>

* fix attention_mask shape when it's None

* minor fixes

- fix docstring + arg names

* remove colons in docstring

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* apply suggestion

* remove unused arg

* refactor a bit

- use [:, None] for consistency

* refactor attention block

Co-authored-by: Nouamane Tazi <nouamane98@gmail.com>

* quick fixes

* first attempt

* refactor attention block and fix all tests except "test_simple_generation"

- added comments to better explain attention block
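To make those comments easier to follow, here is the head split/merge reshaping in toy form; dimensions are illustrative, and the real logic lives in `_split_heads` / `_merge_heads` in the diff below.

```python
import torch

batch, seq, n_head, head_dim = 2, 5, 4, 8
hidden = n_head * head_dim
fused_qkv = torch.randn(batch, seq, 3 * hidden)  # output of the fused query_key_value projection

# Split: (batch, seq, 3*hidden) -> (batch, seq, n_head, 3*head_dim) -> three (batch, seq, n_head, head_dim)
q, k, v = fused_qkv.view(batch, seq, n_head, 3 * head_dim).split(head_dim, dim=-1)

# Merge: (batch*n_head, seq, head_dim) -> (batch, n_head, seq, head_dim)
#        -> (batch, seq, n_head, head_dim) -> (batch, seq, hidden)
context = torch.randn(batch * n_head, seq, head_dim)
merged = context.view(batch, n_head, seq, head_dim).permute(0, 2, 1, 3).reshape(batch, seq, hidden)
print(q.shape, merged.shape)  # torch.Size([2, 5, 4, 8]) torch.Size([2, 5, 32])
```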

* remove debug lines and add TODO comment

* change `torch.bmm` to `torch.baddbmm`
- fixes `test_simple_generation` but breaks `test_batch_generation_padd`
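For context, `torch.baddbmm(alibi, q, k_t, beta=beta, alpha=alpha)` computes `beta * alibi + alpha * (q @ k_t)` in a single fused call, which is what the attention-score computation is switched to here. The shapes below are illustrative; in fp32 the fused and unfused forms agree, while fp16 on GPU can differ slightly, which is what the later commits work around.

```python
import torch

bsz_heads, q_len, k_len, head_dim = 8, 5, 5, 16
q = torch.randn(bsz_heads, q_len, head_dim)
k = torch.randn(bsz_heads, k_len, head_dim)
alibi = torch.randn(bsz_heads, 1, k_len)  # broadcasts over the query dimension
alpha, beta = 1.0 / head_dim ** 0.5, 1.0

fused = torch.baddbmm(alibi, q, k.transpose(1, 2), beta=beta, alpha=alpha)
unfused = beta * alibi + alpha * torch.bmm(q, k.transpose(1, 2))
print(torch.allclose(fused, unfused, atol=1e-5))  # True in fp32
```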

* styling

* all tests are passing now
- use `bmm`
- add explanation for `allow_fp16_reduced_precision_reduction`
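The explanation refers to a real PyTorch switch available in recent versions; a hypothetical way to pin the behaviour in a test would be the snippet below, though the merged tests instead sidestep the issue by running the slow generation checks in fp32.

```python
import torch

# On some GPUs, fp16 GEMMs (torch.bmm / torch.baddbmm) are allowed to accumulate in
# reduced precision by default, which can change greedy-decoding outputs for small models.
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
```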

Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>

* styling

Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>

* fix support for accelerate

Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* remove attn softmax in fp32

* refactor comments

* refactor a bit

- remove warning message
- remove print on test

* refer to pytorch t5

* change the slow tests

- do the tests in fp32
- remove some comments
- keep large comments
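Concretely, running the slow tests in fp32 just means dropping `torch_dtype="auto"` when loading the checkpoint, so the model keeps the default `torch.float32` weights instead of the precision declared by the checkpoint (model id taken from the tests in this commit):

```python
from transformers import BloomForCausalLM

# Loads in whatever precision the checkpoint declares.
model_ckpt_dtype = BloomForCausalLM.from_pretrained("bigscience/bloom-350m", torch_dtype="auto")

# Omitting torch_dtype keeps the default torch.float32, which is what the updated slow tests use.
model_fp32 = BloomForCausalLM.from_pretrained("bigscience/bloom-350m")
```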

* update expected output for `test_simple_generation`
- we now test using fp32

* make style + change comments a bit

* fix dtype padd test

Co-authored-by: justheuristic <justheuristic@gmail.com>
Co-authored-by: Nouamane Tazi <nouamane98@gmail.com>
Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Authored by Younes Belkada on 2022-07-11 19:16:13 +02:00 (committed via GitHub)
parent 5ff6f853d7
commit a462fc9232
3 changed files with 160 additions and 239 deletions

src/transformers/models/bloom/configuration_bloom.py

@ -72,9 +72,6 @@ class BloomConfig(PretrainedConfig):
If set to `True`, it will skip bias add for each linear layer in the transformer blocks
skip_bias_add_qkv (`bool`, *optional*, defaults to `False`):
If set to `True`, it will skip bias add for the first linear layer in the transformer blocks
attention_softmax_in_fp32 (`bool`, *optional*, defaults to `True`):
If set to `True` and the `dtype` is set to `float16` it will scale the input of the Softmax function to
`fp32`
hidden_dropout (`float`, *optional*, defaults to 0.1):
Dropout rate of the dropout function on the bias dropout.
attention_dropout (`float`, *optional*, defaults to 0.1):
@ -128,7 +125,6 @@ class BloomConfig(PretrainedConfig):
hidden_size=64,
n_layer=2,
n_head=8,
masked_softmax_fusion=True,
layer_norm_epsilon=1e-5,
initializer_range=0.02,
use_cache=False,
@ -137,7 +133,6 @@ class BloomConfig(PretrainedConfig):
apply_residual_connection_post_layernorm=False,
hidden_dropout=0.0,
attention_dropout=0.0,
attention_softmax_in_fp32=True,
pretraining_tp=1, # TP rank used when training with megatron
dtype="bfloat16",
slow_but_exact=False,
@ -147,7 +142,6 @@ class BloomConfig(PretrainedConfig):
self.hidden_size = hidden_size
self.n_layer = n_layer
self.n_head = n_head
self.masked_softmax_fusion = masked_softmax_fusion
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.use_cache = use_cache
@ -155,7 +149,6 @@ class BloomConfig(PretrainedConfig):
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout
self.attention_softmax_in_fp32 = attention_softmax_in_fp32
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id

src/transformers/models/bloom/modeling_bloom.py

@ -51,49 +51,38 @@ BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
"""Split a tensor along its last dimension.
Args:
tensor: ([`torch.tensor`], *required*):
input tensor to split
num_partitions ([`int`], *required*):
number of partitions to split the tensor
contiguous_split_chunks ([`bool`], *optional*, default=`False`)::
If True, make each chunk contiguous in memory.
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
"""
# Get the size and dimension.
last_dim = tensor.dim() - 1
numerator, denominator = tensor.size()[last_dim], num_partitions
if not (numerator % denominator == 0):
raise ValueError(f"{numerator} is not divisible by {denominator}")
last_dim_size = numerator // denominator
# Split.
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
# Note: torch.split does not create contiguous tensors by default.
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)
Make causal mask used for bi-directional self-attention.
"""
batch_size, target_length = input_ids_shape
mask = torch.full((target_length, target_length), torch.finfo(dtype).min)
mask_cond = torch.arange(mask.size(-1))
intermediate_mask = mask_cond < (mask_cond + 1).view(mask.size(-1), 1)
mask.masked_fill_(intermediate_mask, 0)
mask = mask.to(dtype)
return tensor_list
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(target_length, past_key_values_length, dtype=dtype), mask], dim=-1)
expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
return expanded_mask
def attention_mask_func(attention_scores, attention_mask, causal_mask):
attention_mask_bool = ~attention_mask.bool()
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
batch_size, source_length = mask.size()
tgt_len = tgt_len if tgt_len is not None else source_length
query_length, key_length, n_heads = attention_scores.size(2), attention_scores.size(3), attention_scores.size(1)
padded_causal_mask = torch.logical_or(
attention_mask_bool[:, None, key_length - query_length : key_length, None],
~causal_mask[:, :, key_length - query_length : key_length, :key_length].bool(),
)
padded_causal_mask = torch.logical_or(padded_causal_mask, attention_mask_bool[:, None, None, :key_length])
# Make use of floats
return (
attention_scores.masked_fill_(padded_causal_mask.expand(-1, n_heads, -1, -1), -10000.0),
padded_causal_mask,
)
expanded_mask = mask[:, None, None, :].expand(batch_size, 1, tgt_len, source_length).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
def build_alibi_tensor(max_seq_len, n_head, device, dtype=torch.bfloat16):
def build_alibi_tensor(attention_mask: torch.Tensor, n_head: int, dtype, device) -> torch.Tensor:
"""
Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
@ -101,73 +90,41 @@ def build_alibi_tensor(max_seq_len, n_head, device, dtype=torch.bfloat16):
https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
Args:
Returns tensor shaped (n_head, 1, max_seq_len)
max_seq_len: (`int`, *required*):
max sequence length
n_head: (`int`, *required*):
Returns tensor shaped (batch_size * n_head, 1, max_seq_len)
attention_mask (`torch.Tensor`):
Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
n_head (`int`, *required*):
number of heads
dtype: (`torch.dtype`, *optional*, default=`torch.bfloat16`):
dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
dtype of the output tensor
device (`torch.device`, *optional*, default=`torch.device('cpu')`):
device of the output alibi tensor
"""
closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32)
powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
slopes = torch.pow(base, powers)
def get_slopes(n):
def get_slopes_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n))
return (
get_slopes_power_of_2(closest_power_of_2)
+ get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
)
slopes = torch.Tensor(get_slopes(n_head)).unsqueeze(1).unsqueeze(1)
arange_tensor = torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0)
alibi = slopes * arange_tensor.expand(n_head, -1, -1)
alibi = alibi.to(device=device, dtype=dtype)
return alibi
def pre_process_alibi_for_pad(alibi, attention_mask, num_heads):
"""
Args:
Pre-process the alibi tensor for padding.
alibi: ([`torch.tensor`], *required*):
alibi tensor to pre-process
attention_mask: ([`torch.tensor`], *required*):
attention mask to pre-process"""
# Sanity check if we are not inferring less tokens than the total sequence length
# This usually happens when the inference is done with past_key_values
# In this case we re-create the alibi tensor with the correct sequence length
if attention_mask.shape[-1] != alibi.shape[-1]:
alibi = build_alibi_tensor(attention_mask.shape[-1], num_heads, alibi.device, alibi.dtype).repeat(
attention_mask.shape[0], 1, 1
if closest_power_of_2 != n_head:
extra_base = torch.tensor(
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32
)
# Get the indexes of the padding tokens
index_x0, index_y0 = torch.where(attention_mask == 0.0)
index_x1, index_y1 = torch.where(attention_mask == 1.0)
num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2)
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
# Clone the embeddings - we can detach because the embeddings are not learned
# Get a reference tensor
slice_reference_alibi = build_alibi_tensor(alibi.shape[-1], num_heads, alibi.device, alibi.dtype)
# Note: alibi will be added to the attention bias that will be applied to the query-key product of attention
# => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
# => here we set (batch_size=1, num_heads=n_head, query_length=1, key_length=max_length)
# => the query_length dimension will then be broadcasted correctly
# This is more or less identical to T5's relative position bias:
# https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
# batch_size = 1, n_head = n_head, query_length
# Loop over the batch where the padding is and replace the alibi tensor by the reference tensor
# Only where you do not have padding. Replace padding tokens by zeros
# This operation can be seen as a shifting operation.
for i, index in enumerate(torch.unique(index_x0)):
slice_to_modify = torch.zeros_like(slice_reference_alibi)
index_shift = index_y1[index_x1 == index]
shift_value = len(index_shift)
slice_to_modify[:, :, index_shift] = slice_reference_alibi[:, :, :shift_value]
alibi[index * num_heads : (index + 1) * num_heads] = slice_to_modify
return alibi
arange_tensor = (attention_mask.cumsum(-1)[:, None, :].to(device) - 1) * attention_mask[:, None]
alibi = slopes.unsqueeze(-1) * arange_tensor
alibi = alibi * attention_mask[:, None]
return alibi.reshape(alibi.shape[0] * n_head, 1, -1).to(dtype)
def dropout_add(x, residual, prob, training):
@ -252,58 +209,6 @@ class BloomGelu(nn.Module):
return bloom_gelu_forward(x)
class BloomScaledSoftmax(nn.Module):
"""
fused operation: scaling + mask + softmax
Args:
input_in_fp16 (`bool`, *required*):
flag to indicate if input in fp16 data format.
input_in_bf16 (`bool`, *required*):
flag to indicate if input in bf16 data format.
scaled_masked_softmax_fusion (`bool`, *required*):
flag to indicate user want to use softmax fusion
mask_func (`function`, *required*):
mask function to be applied.
softmax_in_fp32 (`bool`, *required*):
if true, softmax is performed at fp32 precision.
scale (`float`, *required*):
scaling factor used in input tensor scaling.
"""
def __init__(self, scaled_masked_softmax_fusion, mask_func, softmax_in_fp32, scale):
super().__init__()
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
self.mask_func = mask_func
self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale
if not (self.scale is None or softmax_in_fp32):
raise ValueError("softmax should be in fp32 when scaled")
def forward(self, input, mask, max_positions):
input_dtype = input.dtype
input_in_16bit = input_dtype in [torch.float16, torch.bfloat16]
softmax_dtype = torch.float32 if self.softmax_in_fp32 else input_dtype
if self.scale is not None:
input = input * self.scale
if mask is None:
mask = torch.ones(input.shape[0], max_positions, dtype=torch.bool, device=input.device)
mask = mask.to(input.device)
seq_ids = torch.arange(max_positions, device=input.device)
causal_mask = (seq_ids[None, :] <= seq_ids[:, None]).view(1, 1, max_positions, max_positions).to(input.device)
mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
probs = nn.functional.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
if input_in_16bit and self.softmax_in_fp32:
probs = probs.to(dtype=input_dtype)
return probs
class BloomAttention(nn.Module):
def __init__(self, config, layer_number=None):
super().__init__()
@ -315,8 +220,6 @@ class BloomAttention(nn.Module):
self.num_heads = config.n_head
self.head_dim = self.hidden_size // self.num_heads
self.split_size = self.hidden_size
self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
self.masked_softmax_fusion = config.masked_softmax_fusion
self.hidden_dropout = config.hidden_dropout
if self.head_dim * self.num_heads != self.hidden_size:
@ -329,18 +232,35 @@ class BloomAttention(nn.Module):
self.layer_number = max(1, layer_number)
self.norm_factor = math.sqrt(self.head_dim) * self.layer_number
# Scaled Softmax
self.scale_mask_softmax = BloomScaledSoftmax(
self.masked_softmax_fusion,
attention_mask_func,
self.attention_softmax_in_fp32,
self.layer_number,
)
self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
self.dense = nn.Linear(self.hidden_size, self.hidden_size)
self.attention_dropout = nn.Dropout(config.attention_dropout)
def _split_heads(self, fused_qkv):
"""
Split the last dimension into (num_heads, head_dim)
"""
new_tensor_shape = fused_qkv.size()[:-1] + (self.num_heads, 3 * self.head_dim)
# new_tensor_shape = (fused_qkv.size(1), fused_qkv.size(0)*fused_qkv.size(2), fused_qkv.size(-1))
# fused_qkv = fused_qkv.transpose(1, 0)
fused_qkv = fused_qkv.reshape(*new_tensor_shape)
# fused_qkv = fused_qkv.permute(0, 2, 1, 3)
return torch.split(fused_qkv, self.head_dim, -1)
def _merge_heads(self, x):
# What we want to achieve is:
# batch_size * num_heads, seq_len, head_dim -> batch_size, seq_len, num_heads * head_dim
# First view to decompose the batch size
# batch_size*num_heads, seq_len, head_dim -> batch_size, num_heads, seq_len, head_dim
x = x.view(x.size(0) // self.num_heads, self.num_heads, x.size(1), self.head_dim)
# batch_size, num_heads, seq_len, head_dim -> batch_size, seq_len, num_heads, head_dim
x = x.permute(0, 2, 1, 3)
# batch_size, seq_len, num_heads, head_dim -> batch_size, seq_len, num_heads * head_dim
return x.reshape(x.size(0), x.size(1), self.num_heads * self.head_dim)
def forward(
self,
hidden_states,
@ -352,25 +272,15 @@ class BloomAttention(nn.Module):
use_cache=False,
output_attentions=False,
):
# hidden_states: [batch_size, seq_length, hidden_size]
# repeat alibi tensor with the batch size
alibi = alibi.repeat(hidden_states.shape[0], 1, 1).to(hidden_states.device)
alibi = alibi.to(hidden_states.device) # so the model can run under accelerate
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
# apply preprocessing if the input is padded
if attention_mask is not None and 0 in attention_mask:
alibi = pre_process_alibi_for_pad(alibi, attention_mask, self.num_heads)
mixed_x_layer = self.query_key_value(hidden_states)
# [batch_size, seq_length, 3 x hidden_size] --> [batch_size, seq_length, num_heads, 3 x head_dim]
new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_heads, 3 * self.head_dim)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [batch_size, seq_length, num_heads, 3 x head_dim] --> 3 [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim]
key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1)
value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1)
@ -379,66 +289,39 @@ class BloomAttention(nn.Module):
else:
present = None
# [batch_size, head_dim, q_length, k_length]
output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1))
# [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim]
query_layer = query_layer.transpose(1, 0).reshape(output_size[2], output_size[0] * output_size[1], -1)
# [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
key_layer = key_layer.transpose(1, 0).reshape(output_size[3], output_size[0] * output_size[1], -1)
# slice alibi tensor until the query length
sliced_alibi = alibi[: output_size[0] * output_size[1], :, : output_size[3]]
# Raw attention scores. [batch_size * num_heads, q_length, k_length]
beta = 1.0 / self.layer_number
matmul_result = torch.baddbmm(
sliced_alibi,
query_layer.transpose(1, 0),
key_layer.transpose(1, 0).transpose(1, 2),
beta=beta,
alpha=(1.0 / self.norm_factor),
)
# [batch_size*num_heads, q_length, head_dim] x [batch_size*num_heads, head_dim, k_length] -> [batch_size*num_heads, q_length, k_length]
matmul_result = (1.0 / self.norm_factor) * torch.bmm(
query_layer.transpose(1, 2).reshape(-1, query_layer.shape[1], query_layer.shape[3]),
key_layer.permute(0, 2, 3, 1).reshape(-1, key_layer.shape[3], key_layer.shape[1]),
) + beta * alibi
# change view to [batch_size, num_heads, q_length, k_length]
attention_scores = matmul_result.view(*output_size)
attention_scores = matmul_result.view(-1, self.num_heads, matmul_result.size(1), matmul_result.size(2))
# attention scores and attention mask [b, np, sq, sk]
max_positions = max(attention_scores.shape[-1], attention_scores.shape[-2])
attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, max_positions).to(
value_layer.dtype
)
# We replace the scaled softmax by just a few lines of code - [batch_size, num_heads, q_length, k_length]
input_dtype = attention_scores.dtype
attn_weights = (attention_scores * self.layer_number) + attention_mask
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
attention_probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
attention_probs = attention_probs * (~attention_mask.bool())
# [batch_size, num_heads, q_length, k_length]
attention_probs = self.attention_dropout(attention_probs)
if head_mask is not None:
attention_probs = attention_probs * head_mask
# context layer shape: [batch_size, num_heads, q_length, head_dim]
output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(0), value_layer.size(3))
# change view [k_length, batch_size x num_heads, head_dim]
value_layer = value_layer.transpose(1, 0).reshape(value_layer.size(1), output_size[0] * output_size[1], -1)
# change view [batch_size x num_heads, q_length, k_length]
attention_probs_reshaped = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
attention_probs_reshaped = attention_probs.view(*matmul_result.shape)
# matmul: [batch_size * num_heads, q_length, head_dim]
context_layer = torch.bmm(attention_probs_reshaped, value_layer.transpose(0, 1))
context_layer = torch.bmm(
attention_probs_reshaped, value_layer.transpose(1, 2).reshape(-1, value_layer.size(1), value_layer.size(3))
)
# change view [batch_size, num_heads, q_length, head_dim]
context_layer = context_layer.view(*output_size)
# [batch_size, num_heads, q_length, head_dim] --> [q_length, batch_size, num_heads, head_dim]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [q_length, batch_size, num_heads, head_dim] --> [q_length, batch_size, hidden_size]
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
context_layer = context_layer.view(*new_context_layer_shape)
# Output. [q_length, batch_size, hidden_size]
context_layer = self._merge_heads(context_layer)
# aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
if self.pretraining_tp > 1 and self.slow_but_exact:
@ -452,11 +335,9 @@ class BloomAttention(nn.Module):
else:
output_tensor = self.dense(context_layer)
output = output_tensor.transpose(1, 0)
output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
output = dropout_add(output, residual, self.hidden_dropout, self.training)
outputs = (output, present)
outputs = (output_tensor, present)
if output_attentions:
outputs += (attention_probs,)
@ -703,6 +584,24 @@ class BloomModel(BloomPreTrainedModel):
def get_input_embeddings(self):
return self.word_embeddings
def _prepare_attn_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(attention_mask.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def set_input_embeddings(self, new_embeddings):
self.word_embeddings = new_embeddings
@ -765,9 +664,19 @@ class BloomModel(BloomPreTrainedModel):
# Compute alibi tensor: check build_alibi_tensor documentation
current_sequence_length = hidden_states.shape[1]
past_key_values_length = 0
if past_key_values[0] is not None:
current_sequence_length += past_key_values[0][0].shape[1]
alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.device, hidden_states.dtype)
past_key_values_length = past_key_values[0][0].shape[1]
current_sequence_length += past_key_values_length
if attention_mask is None:
attention_mask = torch.ones((hidden_states.shape[0], current_sequence_length), device=hidden_states.device)
else:
attention_mask = attention_mask.to(hidden_states.device)
alibi = build_alibi_tensor(attention_mask, self.n_head, hidden_states.dtype, hidden_states.device)
causal_mask = self._prepare_attn_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
@ -793,14 +702,14 @@ class BloomModel(BloomPreTrainedModel):
create_custom_forward(block),
hidden_states,
None,
attention_mask,
causal_mask,
head_mask[i],
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
attention_mask=causal_mask,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
@ -877,7 +786,6 @@ class BloomForCausalLM(BloomPreTrainedModel):
"input_ids": input_ids,
"past_key_values": past,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
}

tests/models/bloom/test_modeling_bloom.py

@ -377,15 +377,34 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
@slow
@require_torch_gpu
def test_simple_generation(self):
# This test is a bit flaky. For some GPU architectures, pytorch sets by default allow_fp16_reduced_precision_reduction = True and some operations
# do not give the same results under this configuration, especially torch.baddbmm and torch.bmm. https://pytorch.org/docs/stable/notes/numerical_accuracy.html#fp16-on-mi200
# We keep the default allow_fp16_reduced_precision_reduction = True. Please see: https://pytorch.org/docs/stable/notes/cuda.html#reduced-precision-reduction-in-fp16-gemms
# This discrepancy is observed only with small models; larger models seem to be stable.
# Our conclusion is that these operations are flaky for small inputs but seem to be stable for larger inputs (for the functions `baddbmm` and `bmm`), and therefore for larger models.
# Here is a summary of an ablation study of our observations
# EXPECTED_OUTPUT = "I enjoy walking with my cute dog, and I love to watch the kids play. I am a very active person, and I am a very good listener. I am a very good person, and I am a very good person. I am a"
# 350m + allow_fp16_reduced_precision_reduction = False + torch.bmm ==> PASS
# 350m + allow_fp16_reduced_precision_reduction = False + torch.baddbmm ==> PASS
# 350m + allow_fp16_reduced_precision_reduction = True + torch.baddbmm ==> PASS
# 350m + allow_fp16_reduced_precision_reduction = True + torch.bmm ==> FAIL
# EXPECTED_OUTPUT = "I enjoy walking with my cute dog, but I also enjoy hiking, biking, and swimming. I love to cook and bake. I love to cook and bake. I love to cook and bake. I love to cook and bake. I love"
# >=760m + allow_fp16_reduced_precision_reduction = True + torch.baddbmm ==> PASS (for use_cache=True and use_cache=False)
# >=760m + allow_fp16_reduced_precision_reduction = True + torch.bmm ==> PASS
# >=760m + allow_fp16_reduced_precision_reduction = False + torch.bmm ==> PASS
path_350m = "bigscience/bloom-350m"
model = BloomForCausalLM.from_pretrained(path_350m, torch_dtype="auto", use_cache=True).cuda()
model = BloomForCausalLM.from_pretrained(path_350m, use_cache=True).cuda()
model = model.eval()
tokenizer = BloomTokenizerFast.from_pretrained(path_350m)
input_sentence = "I enjoy walking with my cute dog"
# This output has been obtained using the fp32 model on the huggingface DGX workstation - NVIDIA A100 GPU
EXPECTED_OUTPUT = (
"I enjoy walking with my cute dog, and I love to watch the kids play. I am a very active person, and I am"
" a very good listener. I am a very good person, and I am a very good person. I am a"
"I enjoy walking with my cute dog, and I love to watch the kids play with the kids. I am a very "
"active person, and I enjoy working out, and I am a very active person. I am a very active person, and I"
)
input_ids = tokenizer.encode(input_sentence, return_tensors="pt")
@ -397,7 +416,7 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
@require_torch_gpu
def test_batch_generation(self):
path_350m = "bigscience/bloom-350m"
model = BloomForCausalLM.from_pretrained(path_350m, torch_dtype="auto", use_cache=True).cuda()
model = BloomForCausalLM.from_pretrained(path_350m, use_cache=True).cuda()
model = model.eval()
tokenizer = BloomTokenizerFast.from_pretrained(path_350m, padding_side="left")
@ -416,8 +435,9 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
@slow
@require_torch_gpu
def test_batch_generation_padd(self):
path_350m = "bigscience/bloom-350m"
model = BloomForCausalLM.from_pretrained(path_350m, torch_dtype="auto", use_cache=True).cuda()
model = BloomForCausalLM.from_pretrained(path_350m, use_cache=True).cuda()
model = model.eval()
tokenizer = BloomTokenizerFast.from_pretrained(path_350m, padding_side="left")