diff --git a/src/transformers/models/bloom/configuration_bloom.py b/src/transformers/models/bloom/configuration_bloom.py index 23ecc6d9267..a33a6339b14 100644 --- a/src/transformers/models/bloom/configuration_bloom.py +++ b/src/transformers/models/bloom/configuration_bloom.py @@ -214,14 +214,19 @@ class BloomOnnxConfig(OnnxConfigWithPast): batch, seqlen = common_inputs["input_ids"].shape # Not using the same length for past_key_values past_key_values_length = seqlen + 2 - past_shape = ( - batch, + head_dim = self._config.hidden_size // self.num_attention_heads + past_key_shape = ( + batch * self.num_attention_heads, + head_dim, past_key_values_length, - self.num_attention_heads, - self._config.hidden_size // self.num_attention_heads, + ) + past_value_shape = ( + batch * self.num_attention_heads, + past_key_values_length, + head_dim, ) ordered_inputs["past_key_values"] = [ - (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers) + (torch.zeros(past_key_shape), torch.zeros(past_value_shape)) for _ in range(self.num_layers) ] ordered_inputs["attention_mask"] = common_inputs["attention_mask"] diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index afa289afe5b..a33054a3835 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -16,12 +16,13 @@ import math import warnings -from typing import Tuple, Union +from typing import Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss +from torch.nn import functional as F from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward from ...modeling_outputs import ( @@ -52,102 +53,100 @@ BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = [ ] -def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): +def _make_causal_mask( + input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int +) -> torch.BoolTensor: """ - Make causal mask used for bi-directional self-attention. + Make causal mask used for self-attention. """ batch_size, target_length = input_ids_shape - mask = torch.full((target_length, target_length), torch.finfo(dtype).min) - mask_cond = torch.arange(mask.size(-1)) - intermediate_mask = mask_cond < (mask_cond + 1).view(mask.size(-1), 1) - mask.masked_fill_(intermediate_mask, 0) - mask = mask.to(dtype) + mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device) + # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround + seq_ids = torch.arange(target_length, device=device) + mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :] if past_key_values_length > 0: - mask = torch.cat([torch.zeros(target_length, past_key_values_length, dtype=dtype), mask], dim=-1) + mask[:, :past_key_values_length] = False + expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length) return expanded_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None): +def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor: """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`. 
""" - batch_size, source_length = mask.size() - tgt_len = tgt_len if tgt_len is not None else source_length + batch_size, src_length = mask.shape + tgt_length = tgt_length if tgt_length is not None else src_length - expanded_mask = mask[:, None, None, :].expand(batch_size, 1, tgt_len, source_length).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + expanded_mask = ~(mask[:, None, None, :].to(torch.bool)) + return expanded_mask.expand(batch_size, 1, tgt_length, src_length) -def build_alibi_tensor(attention_mask: torch.Tensor, n_head: int, dtype, device) -> torch.Tensor: +def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor: """ Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value `softmax(l+a) = softmax(l)`. Based on https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 + TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly. Args: - Returns tensor shaped (batch_size * n_head, 1, max_seq_len) + Returns tensor shaped (batch_size * num_heads, 1, max_seq_len) attention_mask (`torch.Tensor`): Token-wise attention mask, this should be of shape (batch_size, max_seq_len). - n_head (`int`, *required*): + num_heads (`int`, *required*): number of heads dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`): dtype of the output tensor - device (`torch.device`, *optional*, default=`torch.device('cpu')`): - device of the output alibi tensor """ - closest_power_of_2 = 2 ** math.floor(math.log2(n_head)) - base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32) - powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32) + batch_size, seq_length = attention_mask.shape + closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) + base = torch.tensor( + 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32 + ) + powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32) slopes = torch.pow(base, powers) - if closest_power_of_2 != n_head: + if closest_power_of_2 != num_heads: extra_base = torch.tensor( - 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32 + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32 ) - num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2) - extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32) + num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) + extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32) slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) # Note: alibi will added to the attention bias that will be applied to the query, key product of attention # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length) - # => here we set (batch_size=1, num_heads=n_head, query_length=1, key_length=max_length) + # => here we set (batch_size=1, num_heads=num_heads, query_length=1, 
key_length=max_length) # => the query_length dimension will then be broadcasted correctly # This is more or less identical to T5's relative position bias: # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527 - # batch_size = 1, n_head = n_head, query_length - - arange_tensor = (attention_mask.cumsum(-1)[:, None, :].to(device) - 1) * attention_mask[:, None] - alibi = slopes.unsqueeze(-1) * arange_tensor - alibi = alibi * attention_mask[:, None] - return alibi.reshape(alibi.shape[0] * n_head, 1, -1).to(dtype) + arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :] + alibi = slopes[..., None] * arange_tensor + return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype) -def dropout_add(x, residual, prob, training): +def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor: """ Dropout add function Args: x (`torch.tensor`, *required*): input tensor - residual (`torch.tensor`, *rquired*): + residual (`torch.tensor`, *required*): residual tensor prob (`float`, *required*): dropout probability training (`bool`, *required*): training mode """ - out = nn.functional.dropout(x, p=prob, training=training) + out = F.dropout(x, p=prob, training=training) out = residual + out return out -def bloom_gelu_forward(x): +def bloom_gelu_forward(x: torch.Tensor) -> torch.Tensor: """ Custom bias GELU function. Adapted from Megatron-DeepSpeed code. Here we use a simple implementation (inference) to make the model jitable. @@ -159,7 +158,7 @@ def bloom_gelu_forward(x): return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) -def bloom_gelu_back(g, x): +def bloom_gelu_back(g: torch.Tensor, x: torch.Tensor) -> torch.Tensor: """ gradient of tanh approximation of gelu gradient of actual gelu is: 0.5 * (1. 
+ torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) @@ -179,12 +178,12 @@ def bloom_gelu_back(g, x): class GeLUFunction(torch.autograd.Function): @staticmethod - def forward(ctx, input): + def forward(ctx, input: torch.Tensor) -> torch.Tensor: ctx.save_for_backward(input) return bloom_gelu_forward(input) @staticmethod - def backward(ctx, grad_output): + def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: input = ctx.saved_tensors tmp = bloom_gelu_back(grad_output, input) return tmp @@ -197,13 +196,12 @@ class BloomGelu(nn.Module): copied from Megatron-DeepSpeed code and adapted for our needs See here why autograd functions are not torchscriptable: https://github.com/pytorch/pytorch/issues/22329 - """ def __init__(self): super().__init__() - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: if self.training: return GeLUFunction.apply(x) else: @@ -211,7 +209,7 @@ class BloomGelu(nn.Module): class BloomAttention(nn.Module): - def __init__(self, config, layer_number=None): + def __init__(self, config: BloomConfig): super().__init__() self.pretraining_tp = config.pretraining_tp @@ -230,106 +228,131 @@ class BloomAttention(nn.Module): ) # Layer-wise attention scaling - self.layer_number = max(1, layer_number) - self.norm_factor = math.sqrt(self.head_dim) * self.layer_number + self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) + self.beta = 1.0 self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True) self.dense = nn.Linear(self.hidden_size, self.hidden_size) self.attention_dropout = nn.Dropout(config.attention_dropout) - def _split_heads(self, fused_qkv): + def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ - Split the last dimension into (num_heads, head_dim) - """ - new_tensor_shape = fused_qkv.size()[:-1] + (self.num_heads, 3 * self.head_dim) - # new_tensor_shape = (fused_qkv.size(1), fused_qkv.size(0)*fused_qkv.size(2), fused_qkv.size(-1)) - # fused_qkv = fused_qkv.transpose(1, 0) - fused_qkv = fused_qkv.reshape(new_tensor_shape) - # fused_qkv = fused_qkv.permute(0, 2, 1, 3) - return torch.split(fused_qkv, self.head_dim, -1) + Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory + storage as `fused_qkv` - def _merge_heads(self, x): + Args: + fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim] + + Returns: + query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim] + value: [batch_size, seq_length, num_heads, head_dim] + """ + batch_size, seq_length, three_times_hidden_size = fused_qkv.shape + fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim) + return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :] + + def _merge_heads(self, x: torch.Tensor) -> torch.Tensor: + """ + Merge heads together over the last dimension + + Args: + x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim] + + Returns: + torch.tensor: [batch_size, seq_length, num_heads * head_dim] + """ # What we want to achieve is: - # batch_size * num_heads, seq_len, head_dim -> batch_size, seq_len, num_heads * head_dim + # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim + batch_size_and_num_heads, seq_length, _ = x.shape + batch_size = batch_size_and_num_heads // self.num_heads # First view to decompose the batch size - # batch_size*num_heads, 
seq_len, head_dim -> batch_size, num_heads, seq_len, head_dim - x = x.view(x.size(0) // self.num_heads, self.num_heads, x.size(1), self.head_dim) + # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim + x = x.view(batch_size, self.num_heads, seq_length, self.head_dim) - # batch_size, num_heads, seq_len, head_dim -> batch_size, seq_len, num_heads, head_dim + # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim x = x.permute(0, 2, 1, 3) - # batch_size, seq_len, num_heads, head_dim -> batch_size, seq_len, num_heads * head_dim - return x.reshape(x.size(0), x.size(1), self.num_heads * self.head_dim) + # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim + return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim) def forward( self, - hidden_states, - residual, - layer_past=None, - attention_mask=None, - alibi=None, - head_mask=None, - use_cache=False, - output_attentions=False, + hidden_states: torch.Tensor, + residual: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, ): - alibi = alibi.to(hidden_states.device) # to make the model possible to run under accelerate fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] # 3 x [batch_size, seq_length, num_heads, head_dim] (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + batch_size, q_length, _, _ = query_layer.shape + + query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) + key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length) + value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) if layer_past is not None: past_key, past_value = layer_past - # concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim] - key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1) - value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1) + # concatenate along seq_length dimension: + # - key: [batch_size * self.num_heads, head_dim, kv_length] + # - value: [batch_size * self.num_heads, kv_length, head_dim] + key_layer = torch.cat((past_key, key_layer), dim=2) + value_layer = torch.cat((past_value, value_layer), dim=1) + + _, _, kv_length = key_layer.shape if use_cache is True: present = (key_layer, value_layer) else: present = None - beta = 1.0 / self.layer_number + # [batch_size * num_heads, q_length, kv_length] + # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11 + matmul_result = alibi.baddbmm( + batch1=query_layer, + batch2=key_layer, + beta=self.beta, + alpha=self.inv_norm_factor, + ) - # # [batch_size*num_heads, head_dim, q_length] x [batch_size*num_heads, head_dim, k_length] -> [batch_size*num_heads, q_length, k_length] - matmul_result = (1.0 / self.norm_factor) * torch.bmm( - query_layer.transpose(1, 2).reshape(-1, query_layer.shape[1], query_layer.shape[3]), - key_layer.permute(0, 2, 3, 1).reshape(-1, key_layer.shape[3], key_layer.shape[1]), - ) + beta * alibi + # change view to [batch_size, num_heads, q_length, kv_length] + attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length) - 
# change view to [batch_size, num_heads, q_length, k_length] - attention_scores = matmul_result.view(-1, self.num_heads, matmul_result.size(1), matmul_result.size(2)) - - # We replace the scaled softmax by just a few line of code - [batch_size, num_heads, q_length, k_length] + # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] input_dtype = attention_scores.dtype - attn_weights = (attention_scores * self.layer_number) + attention_mask - attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) - attention_probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype) - attention_probs = attention_probs * (~attention_mask.to(torch.bool)) - # [batch_size, num_heads, q_length, k_length] + # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` + if input_dtype == torch.float16: + attention_scores = attention_scores.to(torch.float) + attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min) + attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype) + + # [batch_size, num_heads, q_length, kv_length] attention_probs = self.attention_dropout(attention_probs) if head_mask is not None: attention_probs = attention_probs * head_mask - # change view [batch_size x num_heads, q_length, k_length] - attention_probs_reshaped = attention_probs.view(matmul_result.shape) + # change view [batch_size x num_heads, q_length, kv_length] + attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length) # matmul: [batch_size * num_heads, q_length, head_dim] - context_layer = torch.bmm( - attention_probs_reshaped, value_layer.transpose(1, 2).reshape(-1, value_layer.size(1), value_layer.size(3)) - ) + context_layer = torch.bmm(attention_probs_reshaped, value_layer) # change view [batch_size, num_heads, q_length, head_dim] context_layer = self._merge_heads(context_layer) # aggregate results across tp ranks. 
See here: https://github.com/pytorch/pytorch/issues/76232 if self.pretraining_tp > 1 and self.slow_but_exact: - slices = context_layer.shape[-1] / self.pretraining_tp + slices = self.hidden_size / self.pretraining_tp output_tensor = torch.zeros_like(context_layer) for i in range(self.pretraining_tp): - output_tensor = output_tensor + nn.functional.linear( + output_tensor = output_tensor + F.linear( context_layer[:, :, int(i * slices) : int((i + 1) * slices)], self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], ) @@ -346,7 +369,7 @@ class BloomAttention(nn.Module): class BloomMLP(nn.Module): - def __init__(self, config): + def __init__(self, config: BloomConfig): super().__init__() hidden_size = config.hidden_size @@ -357,14 +380,14 @@ class BloomMLP(nn.Module): self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size) self.hidden_dropout = config.hidden_dropout - def forward(self, hidden_states, residual): + def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states)) if self.pretraining_tp > 1 and self.slow_but_exact: intermediate_output = torch.zeros_like(residual) slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp for i in range(self.pretraining_tp): - intermediate_output = intermediate_output + nn.functional.linear( + intermediate_output = intermediate_output + F.linear( hidden_states[:, :, int(i * slices) : int((i + 1) * slices)], self.dense_4h_to_h.weight[:, int(i * slices) : int((i + 1) * slices)], ) @@ -377,13 +400,13 @@ class BloomMLP(nn.Module): class BloomBlock(nn.Module): - def __init__(self, config, layer_number=None): + def __init__(self, config: BloomConfig): super().__init__() hidden_size = config.hidden_size self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.n_head = config.n_head - self.self_attention = BloomAttention(config, layer_number=layer_number) + self.num_heads = config.n_head + self.self_attention = BloomAttention(config) self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = BloomMLP(config) @@ -393,13 +416,13 @@ class BloomBlock(nn.Module): def forward( self, - hidden_states, - layer_past=None, - attention_mask=None, - head_mask=None, - use_cache=False, - output_attentions=False, - alibi=None, + hidden_states: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, ): # hidden_states: [batch_size, seq_length, hidden_size] @@ -462,9 +485,9 @@ class BloomPreTrainedModel(PreTrainedModel): def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) - def _init_weights(self, module): + def _init_weights(self, module: nn.Module): """Initialize the weights.""" - if isinstance(module, (nn.Linear)): + if isinstance(module, nn.Linear): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) @@ -478,7 +501,7 @@ class BloomPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False): if isinstance(module, BloomModel): module.gradient_checkpointing = value @@ -501,9 
+524,8 @@ BLOOM_START_DOCSTRING = r""" BLOOM_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input - sequence tokens in the vocabulary. + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]` + (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as `input_ids`. @@ -516,6 +538,10 @@ BLOOM_INPUTS_DOCSTRING = r""" Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have their past given to this model should not be passed as `input_ids` as they have already been computed. + + Each element of `past_key_values` is a tuple (past_key, past_value): + - past_key: [batch_size * num_heads, head_dim, kv_length] + - past_value: [batch_size * num_heads, kv_length, head_dim] attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: @@ -555,19 +581,18 @@ BLOOM_INPUTS_DOCSTRING = r""" BLOOM_START_DOCSTRING, ) class BloomModel(BloomPreTrainedModel): - def __init__(self, config): + def __init__(self, config: BloomConfig): super().__init__(config) self.embed_dim = config.hidden_size - self.n_head = config.n_head + self.num_heads = config.n_head # Embedding + LN Embedding self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim) - self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) # Transformer blocks - self.h = nn.ModuleList([BloomBlock(config, layer_number=i) for i in range(config.num_hidden_layers)]) + self.h = nn.ModuleList([BloomBlock(config) for _ in range(config.num_hidden_layers)]) # Final Layer Norm self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -580,25 +605,29 @@ class BloomModel(BloomPreTrainedModel): def get_input_embeddings(self): return self.word_embeddings - def _prepare_attn_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + def _prepare_attn_mask( + self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int + ) -> torch.BoolTensor: # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length] combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length - ).to(attention_mask.device) + device = attention_mask.device + _, src_length = input_shape - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + if src_length > 1: + combined_attention_mask = _make_causal_mask( + input_shape, device=device, past_key_values_length=past_key_values_length ) + # [batch_size, seq_length] -> 
[batch_size, 1, tgt_length, src_length] + expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask + ) + return combined_attention_mask - def set_input_embeddings(self, new_embeddings): + def set_input_embeddings(self, new_embeddings: torch.Tensor): self.word_embeddings = new_embeddings @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING) @@ -610,17 +639,17 @@ class BloomModel(BloomPreTrainedModel): ) def forward( self, - input_ids=None, - past_key_values=None, - attention_mask=None, - head_mask=None, - inputs_embeds=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, **deprecated_arguments - ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: if deprecated_arguments.pop("position_ids", False) is not False: # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` warnings.warn( @@ -641,10 +670,9 @@ class BloomModel(BloomPreTrainedModel): if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) + batch_size, seq_length = input_ids.shape elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length, _ = inputs_embeds.shape else: raise ValueError("You have to specify either input_ids or inputs_embeds") @@ -653,8 +681,8 @@ class BloomModel(BloomPreTrainedModel): # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_head x N x N - # head_mask has shape n_layer x batch x n_head x N x N + # attention_probs has shape batch_size x num_heads x N x N + # head_mask has shape n_layer x batch x num_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.n_layer) if inputs_embeds is None: @@ -662,27 +690,28 @@ class BloomModel(BloomPreTrainedModel): hidden_states = self.word_embeddings_layernorm(inputs_embeds) - output_shape = input_shape + (hidden_states.size(-1),) - presents = () if use_cache else None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None # Compute alibi tensor: check build_alibi_tensor documentation - current_sequence_length = hidden_states.shape[1] + seq_length_with_past = seq_length past_key_values_length = 0 if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[1] - current_sequence_length += past_key_values_length - + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length if attention_mask is None: - attention_mask = torch.ones((hidden_states.shape[0], current_sequence_length), 
device=hidden_states.device) + attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) else: attention_mask = attention_mask.to(hidden_states.device) - alibi = build_alibi_tensor(attention_mask, self.n_head, hidden_states.dtype, hidden_states.device) + alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) - causal_mask = self._prepare_attn_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length) + causal_mask = self._prepare_attn_mask( + attention_mask, + input_shape=(batch_size, seq_length), + past_key_values_length=past_key_values_length, + ) for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): @@ -700,14 +729,14 @@ class BloomModel(BloomPreTrainedModel): def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value - return module(*inputs, use_cache, output_attentions, alibi) + return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) return custom_forward outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, - None, + alibi, causal_mask, head_mask[i], ) @@ -735,8 +764,6 @@ class BloomModel(BloomPreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - hidden_states = hidden_states.view(output_shape) - if not return_dict: return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) @@ -758,7 +785,7 @@ class BloomModel(BloomPreTrainedModel): class BloomForCausalLM(BloomPreTrainedModel): _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"] - def __init__(self, config): + def __init__(self, config: BloomConfig): super().__init__(config) self.transformer = BloomModel(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) @@ -769,16 +796,20 @@ class BloomForCausalLM(BloomPreTrainedModel): def get_output_embeddings(self): return self.lm_head - def set_output_embeddings(self, new_embeddings): + def set_output_embeddings(self, new_embeddings: torch.Tensor): self.lm_head = new_embeddings - def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): - # only last token for inputs_ids if past is defined in kwargs + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs + ) -> dict: + # only last token for input_ids if past is not None if past: input_ids = input_ids[:, -1].unsqueeze(-1) - attention_mask = kwargs.get("attention_mask", None) - return { "input_ids": input_ids, "past_key_values": past, @@ -795,16 +826,16 @@ class BloomForCausalLM(BloomPreTrainedModel): ) def forward( self, - input_ids=None, - past_key_values=None, - attention_mask=None, - head_mask=None, - inputs_embeds=None, - labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, **deprecated_arguments ) -> 
Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: r""" @@ -845,9 +876,12 @@ class BloomForCausalLM(BloomPreTrainedModel): # Shift so that tokens < n predict n shift_logits = lm_logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() + batch_size, seq_length, vocab_size = shift_logits.shape # Flatten the tokens loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + loss = loss_fct( + shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length) + ) if not return_dict: output = (lm_logits,) + transformer_outputs[1:] @@ -862,14 +896,36 @@ class BloomForCausalLM(BloomPreTrainedModel): ) @staticmethod - def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: + def _reorder_cache( + past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: """ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct beam_idx at every generation step. + + Output shares the same memory storage as `past`. """ + batch_size_times_num_heads, head_dim, seq_length = past[0][0].shape + batch_size = len(beam_idx) + num_heads = batch_size_times_num_heads // batch_size + # Get a copy of `beam_idx` on all the devices where we need those indices. + device_to_beam_idx = { + past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past + } + # key: layer_past[0] [batch_size * num_heads, head_dim, seq_length] + # value: layer_past[1] [batch_size * num_heads, seq_length, head_dim] return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + ( + layer_past[0] + .view(batch_size, num_heads, head_dim, seq_length) + .index_select(0, device_to_beam_idx[layer_past[0].device]) + .view(batch_size_times_num_heads, head_dim, seq_length), + layer_past[1] + .view(batch_size, num_heads, seq_length, head_dim) + .index_select(0, device_to_beam_idx[layer_past[0].device]) + .view(batch_size_times_num_heads, seq_length, head_dim), + ) for layer_past in past ) @@ -892,7 +948,7 @@ class BloomForCausalLM(BloomPreTrainedModel): class BloomForSequenceClassification(BloomPreTrainedModel): _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"] - def __init__(self, config): + def __init__(self, config: BloomConfig): super().__init__(config) self.num_labels = config.num_labels self.transformer = BloomModel(config) @@ -910,16 +966,16 @@ class BloomForSequenceClassification(BloomPreTrainedModel): ) def forward( self, - input_ids=None, - past_key_values=None, - attention_mask=None, - head_mask=None, - inputs_embeds=None, - labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, 
**deprecated_arguments ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]: r""" @@ -966,7 +1022,7 @@ class BloomForSequenceClassification(BloomPreTrainedModel): sequence_lengths = -1 else: if input_ids is not None: - sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 + sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1 else: sequence_lengths = -1 logger.warning( @@ -994,7 +1050,7 @@ class BloomForSequenceClassification(BloomPreTrainedModel): loss = loss_fct(pooled_logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + loss = loss_fct(pooled_logits, labels) elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(pooled_logits, labels) @@ -1021,7 +1077,7 @@ class BloomForSequenceClassification(BloomPreTrainedModel): class BloomForTokenClassification(BloomPreTrainedModel): _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"] - def __init__(self, config): + def __init__(self, config: BloomConfig): super().__init__(config) self.num_labels = config.num_labels @@ -1047,16 +1103,16 @@ class BloomForTokenClassification(BloomPreTrainedModel): ) def forward( self, - input_ids=None, - past_key_values=None, - attention_mask=None, - head_mask=None, - inputs_embeds=None, - labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, **deprecated_arguments ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: r""" @@ -1095,8 +1151,11 @@ class BloomForTokenClassification(BloomPreTrainedModel): loss = None if labels is not None: + batch_size, seq_length = labels.shape loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + loss = loss_fct( + logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length) + ) if not return_dict: output = (logits,) + transformer_outputs[2:]
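The sketches below illustrate the conventions this patch introduces; none of the code is part of the diff, and every size, name and tensor in them is made up for demonstration. First, the ONNX config now builds its dummy `past_key_values` in the fused layout the reworked attention expects: keys pre-transposed as `[batch_size * num_heads, head_dim, kv_length]`, values as `[batch_size * num_heads, kv_length, head_dim]`.

```python
import torch

# Hypothetical sizes; the real ones come from the ONNX dummy inputs and BloomConfig.
batch, seq_length, num_attention_heads, hidden_size, num_layers = 2, 5, 8, 64, 3
past_key_values_length = seq_length + 2
head_dim = hidden_size // num_attention_heads

# key:   [batch * num_heads, head_dim, kv_length]  (already transposed for the baddbmm below)
# value: [batch * num_heads, kv_length, head_dim]
past_key_shape = (batch * num_attention_heads, head_dim, past_key_values_length)
past_value_shape = (batch * num_attention_heads, past_key_values_length, head_dim)

past_key_values = [
    (torch.zeros(past_key_shape), torch.zeros(past_value_shape)) for _ in range(num_layers)
]
# the updated docstring reads the cached length off `past_key_values[0][0].shape[2]`
assert past_key_values[0][0].shape[2] == past_key_values_length
```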
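The mask helpers now work entirely with boolean masks, where `True` means "do not attend": `_make_causal_mask` builds the causal part with an ONNX-friendly comparison instead of `triu`, `_expand_mask` simply inverts the padding mask, and `_prepare_attn_mask` merges the two with a logical OR. A self-contained re-derivation of that behaviour, written as a sketch rather than a copy of the patch:

```python
import torch


def make_causal_mask(input_ids_shape, device, past_key_values_length):
    # True where a query position must not attend (i.e. future tokens)
    batch_size, target_length = input_ids_shape
    seq_ids = torch.arange(target_length, device=device)
    causal = seq_ids[:, None] < seq_ids[None, :]  # [target_length, target_length]
    past = torch.zeros(target_length, past_key_values_length, dtype=torch.bool, device=device)
    mask = torch.cat([past, causal], dim=-1)  # cached positions are always visible
    return mask[None, None, :, :].expand(batch_size, 1, target_length, past_key_values_length + target_length)


def expand_mask(attention_mask, tgt_length):
    # the padding mask uses 1 = keep, so invert it to get True = masked
    batch_size, src_length = attention_mask.shape
    return (~attention_mask.to(torch.bool))[:, None, None, :].expand(batch_size, 1, tgt_length, src_length)


attention_mask = torch.tensor([[0, 1, 1, 1]])  # src_length = past (1) + new tokens (3); the first key position is padding
causal_mask = make_causal_mask((1, 3), attention_mask.device, past_key_values_length=1)
combined = expand_mask(attention_mask, tgt_length=3) | causal_mask
print(combined.int())  # 1 marks positions that `masked_fill` pushes to finfo.min before the softmax
```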
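`build_alibi_tensor` now takes the device from `attention_mask`, drops the separate `device` argument, and returns a `(batch_size * num_heads, 1, seq_length)` tensor of per-head linear biases. A quick shape/usage check, assuming the module path touched by this patch:

```python
import torch

from transformers.models.bloom.modeling_bloom import build_alibi_tensor

attention_mask = torch.tensor([[0, 1, 1, 1], [1, 1, 1, 1]])  # batch_size=2, seq_length=4
num_heads = 16

alibi = build_alibi_tensor(attention_mask, num_heads, dtype=torch.float32)
assert alibi.shape == (2 * num_heads, 1, 4)

# padded positions contribute 0; real positions ramp linearly, scaled by the head's slope
print(alibi.view(2, num_heads, 1, 4)[0, 0])
```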
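`bloom_gelu_forward` keeps an explicit tanh approximation (0.79788456 is roughly sqrt(2/pi)) so that `BloomGelu` stays jitable in inference mode. Assuming a PyTorch build that has the `approximate="tanh"` flag, it tracks the built-in GELU closely:

```python
import torch
import torch.nn.functional as F


def bloom_gelu_forward(x: torch.Tensor) -> torch.Tensor:
    # same tanh approximation as in the patch
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))


x = torch.linspace(-4.0, 4.0, steps=101)
assert torch.allclose(bloom_gelu_forward(x), F.gelu(x, approximate="tanh"), atol=1e-5)
```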
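`_split_heads` now views the fused QKV projection as `[batch, seq, num_heads, 3, head_dim]` and indexes the size-3 axis, so query, key and value are views into the projection rather than copies; the forward pass then folds heads into the batch dimension, and `_merge_heads` undoes the fold on the context layer. A shape walk-through with assumed sizes:

```python
import torch

batch_size, seq_length, num_heads, head_dim = 2, 5, 4, 8
fused_qkv = torch.randn(batch_size, seq_length, num_heads * 3 * head_dim)

# split: a view plus indexing -> three [batch, seq, num_heads, head_dim] tensors, no copies
qkv = fused_qkv.view(batch_size, seq_length, num_heads, 3, head_dim)
query, key, value = qkv[..., 0, :], qkv[..., 1, :], qkv[..., 2, :]

# fold heads into the batch dimension; keys end up transposed so baddbmm can consume them directly
query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_length, head_dim)
key = key.permute(0, 2, 3, 1).reshape(batch_size * num_heads, head_dim, seq_length)
value = value.transpose(1, 2).reshape(batch_size * num_heads, seq_length, head_dim)

# merge: undo the folding on the attention output
context = torch.randn(batch_size * num_heads, seq_length, head_dim)
merged = (
    context.view(batch_size, num_heads, seq_length, head_dim)
    .permute(0, 2, 1, 3)
    .reshape(batch_size, seq_length, num_heads * head_dim)
)
assert merged.shape == (batch_size, seq_length, num_heads * head_dim)
```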
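With the layer-wise scaling gone, the attention scores become a single fused `baddbmm`, `scores = 1.0 * alibi + inv_norm_factor * (query @ key)`, followed by `masked_fill` on the boolean mask and a float32 softmax. A sketch of the equivalence with made-up sizes:

```python
import math

import torch
from torch.nn import functional as F

batch_size, num_heads, q_length, kv_length, head_dim = 2, 4, 3, 3, 8
query = torch.randn(batch_size * num_heads, q_length, head_dim)
key = torch.randn(batch_size * num_heads, head_dim, kv_length)  # keys are kept transposed
alibi = torch.randn(batch_size * num_heads, 1, kv_length)
attention_mask = torch.zeros(batch_size, 1, q_length, kv_length, dtype=torch.bool)  # True = masked

inv_norm_factor = 1.0 / math.sqrt(head_dim)

# fused: scores = beta * alibi + alpha * (query @ key)
scores = alibi.baddbmm(batch1=query, batch2=key, beta=1.0, alpha=inv_norm_factor)
reference = alibi + inv_norm_factor * torch.bmm(query, key)
assert torch.allclose(scores, reference, atol=1e-5)

# boolean masking, softmax in float32, cast back (a no-op here since everything is float32 already)
scores = scores.view(batch_size, num_heads, q_length, kv_length)
scores = scores.masked_fill(attention_mask, torch.finfo(scores.dtype).min)
probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(scores.dtype)
assert torch.allclose(probs.sum(dim=-1), torch.ones(batch_size, num_heads, q_length))
```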
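The `slow_but_exact` branches (now routed through `F.linear`) reproduce the numerics of the tensor-parallel pretraining run by applying the dense projection slice by slice over the hidden dimension and summing the partial results, which equals one full matmul up to float error. A small sanity sketch under assumed sizes:

```python
import torch
from torch.nn import functional as F

hidden_size, pretraining_tp = 16, 4
dense = torch.nn.Linear(hidden_size, hidden_size)
context_layer = torch.randn(2, 3, hidden_size)

slices = hidden_size // pretraining_tp
output_tensor = torch.zeros(2, 3, hidden_size)
for i in range(pretraining_tp):
    # each step multiplies an input block with the matching column block of the weight
    output_tensor = output_tensor + F.linear(
        context_layer[:, :, i * slices : (i + 1) * slices],
        dense.weight[:, i * slices : (i + 1) * slices],
    )
output_tensor = output_tensor + dense.bias

assert torch.allclose(output_tensor, dense(context_layer), atol=1e-5)
```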
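Finally, because batch and heads are fused in the cache's first dimension, `_reorder_cache` unfolds that dimension, runs `index_select` over the beam indices, and folds it back. A stripped-down sketch of the key half of the logic, assuming a single device:

```python
import torch

batch_size, num_heads, head_dim, seq_length = 2, 4, 8, 5
past_key = torch.randn(batch_size * num_heads, head_dim, seq_length)
beam_idx = torch.tensor([1, 0])  # e.g. the two beams swap places

reordered_key = (
    past_key.view(batch_size, num_heads, head_dim, seq_length)
    .index_select(0, beam_idx)
    .view(batch_size * num_heads, head_dim, seq_length)
)
# beam 0 of the reordered cache is the old beam 1, for every head
assert torch.equal(reordered_key[:num_heads], past_key[num_heads : 2 * num_heads])
```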