Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
[BLOOM] Clean modeling code (#18344)
* Cleanup some code
* Improve signatures
* Try to reduce the number of reshape/copies
* I don't think we actually need the layer_num scaling trick
* No need for duplication
* Try to fix beam_search
* Fix beam search
* Removing layer num normalization seems to be breaking
* Not sure self.layer_number normalization actually matters
* Try and be backward compatible
* Try to fix beam_search
* Revert attempt to be backward compatible
* Improve documentation on past_key_values format
* Optimize the device allocation in case of hidden_states in multiple devices
* No need to manually cast the values to a specific device
* Rename with long version of variables
* Improve type hinting
* Add comment that explains that some methods return views
* Actually i think the attention casting only makes sense when we use torch.float16
* We don't actually need layer_number to be passed anymore
* Fix FX test
* Bypass torch.baddbmm
* Apply suggestions from code review
* Add comment about support for torchScript v1.11
* fix ONNX support for bloom (#18456)

Co-authored-by: Niklas Muennighoff <n.muennighoff@gmail.com>
Co-authored-by: Nouamane Tazi <nouamane98@gmail.com>
parent 02b176c4ce
commit b69a62d579
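Note on the cache layout this commit settles on (a minimal sketch, not part of the diff; the sizes below are made up for illustration): each layer's `past_key_values` entry folds the heads into the batch dimension, and the key is stored pre-transposed so it can be fed straight to a batched matmul.

import torch

# Illustrative sizes only (hypothetical, not taken from any BLOOM checkpoint)
batch_size, num_heads, head_dim, past_length = 2, 16, 64, 8

# One (past_key, past_value) tuple per layer, as documented later in this diff:
#   past_key:   [batch_size * num_heads, head_dim, kv_length]
#   past_value: [batch_size * num_heads, kv_length, head_dim]
past_key = torch.zeros(batch_size * num_heads, head_dim, past_length)
past_value = torch.zeros(batch_size * num_heads, past_length, head_dim)
layer_past = (past_key, past_value)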
@@ -214,14 +214,19 @@ class BloomOnnxConfig(OnnxConfigWithPast):
         batch, seqlen = common_inputs["input_ids"].shape
         # Not using the same length for past_key_values
         past_key_values_length = seqlen + 2
-        past_shape = (
-            batch,
-            self.num_attention_heads,
-            past_key_values_length,
-            self._config.hidden_size // self.num_attention_heads,
-        )
+        head_dim = self._config.hidden_size // self.num_attention_heads
+        past_key_shape = (
+            batch * self.num_attention_heads,
+            head_dim,
+            past_key_values_length,
+        )
+        past_value_shape = (
+            batch * self.num_attention_heads,
+            past_key_values_length,
+            head_dim,
+        )
         ordered_inputs["past_key_values"] = [
-            (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
+            (torch.zeros(past_key_shape), torch.zeros(past_value_shape)) for _ in range(self.num_layers)
         ]

         ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
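For context on why the ONNX dummy cache above uses two different shapes: the key is already laid out as [batch * heads, head_dim, length], so the attention-score product needs no transpose. A hedged sanity check with made-up sizes (not part of the commit):

import torch

batch, num_heads, head_dim, q_len, past_len = 2, 4, 8, 3, 5

past_key = torch.zeros(batch * num_heads, head_dim, past_len)    # past_key_shape
past_value = torch.zeros(batch * num_heads, past_len, head_dim)  # past_value_shape

query = torch.randn(batch * num_heads, q_len, head_dim)
scores = torch.bmm(query, past_key)                   # [batch * num_heads, q_len, past_len]
context = torch.bmm(scores.softmax(-1), past_value)   # [batch * num_heads, q_len, head_dim]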
@@ -16,12 +16,13 @@

 import math
 import warnings
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
+from torch.nn import functional as F

 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
 from ...modeling_outputs import (
@@ -52,102 +53,100 @@ BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = [
 ]


-def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+def _make_causal_mask(
+    input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
+) -> torch.BoolTensor:
     """
-    Make causal mask used for bi-directional self-attention.
+    Make causal mask used for self-attention.
     """
     batch_size, target_length = input_ids_shape
-    mask = torch.full((target_length, target_length), torch.finfo(dtype).min)
-    mask_cond = torch.arange(mask.size(-1))
-    intermediate_mask = mask_cond < (mask_cond + 1).view(mask.size(-1), 1)
-    mask.masked_fill_(intermediate_mask, 0)
-    mask = mask.to(dtype)
+    mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
+    # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
+    seq_ids = torch.arange(target_length, device=device)
+    mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]

     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(target_length, past_key_values_length, dtype=dtype), mask], dim=-1)
+        mask[:, :past_key_values_length] = False

     expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
     return expanded_mask


-def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None):
+def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
     """
-    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+    Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
     """
-    batch_size, source_length = mask.size()
-    tgt_len = tgt_len if tgt_len is not None else source_length
+    batch_size, src_length = mask.shape
+    tgt_length = tgt_length if tgt_length is not None else src_length

-    expanded_mask = mask[:, None, None, :].expand(batch_size, 1, tgt_len, source_length).to(dtype)
-
-    inverted_mask = 1.0 - expanded_mask
-
-    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+    expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
+    return expanded_mask.expand(batch_size, 1, tgt_length, src_length)


-def build_alibi_tensor(attention_mask: torch.Tensor, n_head: int, dtype, device) -> torch.Tensor:
+def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
     """
     Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
     relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
     `softmax(l+a) = softmax(l)`. Based on
     https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
     TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.

     Args:
-    Returns tensor shaped (batch_size * n_head, 1, max_seq_len)
+    Returns tensor shaped (batch_size * num_heads, 1, max_seq_len)
         attention_mask (`torch.Tensor`):
             Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
-        n_head (`int`, *required*):
+        num_heads (`int`, *required*):
             number of heads
         dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
             dtype of the output tensor
-        device (`torch.device`, *optional*, default=`torch.device('cpu')`):
-            device of the output alibi tensor
     """
-    closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
-    base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32)
-    powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
+    batch_size, seq_length = attention_mask.shape
+    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
+    base = torch.tensor(
+        2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
+    )
+    powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
     slopes = torch.pow(base, powers)

-    if closest_power_of_2 != n_head:
+    if closest_power_of_2 != num_heads:
         extra_base = torch.tensor(
-            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32
+            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
         )
-        num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2)
-        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
+        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
+        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
         slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

     # Note: alibi will added to the attention bias that will be applied to the query, key product of attention
     # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
-    # => here we set (batch_size=1, num_heads=n_head, query_length=1, key_length=max_length)
+    # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
     # => the query_length dimension will then be broadcasted correctly
     # This is more or less identical to T5's relative position bias:
     # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
-    # batch_size = 1, n_head = n_head, query_length

-    arange_tensor = (attention_mask.cumsum(-1)[:, None, :].to(device) - 1) * attention_mask[:, None]
-    alibi = slopes.unsqueeze(-1) * arange_tensor
-    alibi = alibi * attention_mask[:, None]
-    return alibi.reshape(alibi.shape[0] * n_head, 1, -1).to(dtype)
+    arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
+    alibi = slopes[..., None] * arange_tensor
+    return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)


-def dropout_add(x, residual, prob, training):
+def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
     """
     Dropout add function

     Args:
         x (`torch.tensor`, *required*):
             input tensor
-        residual (`torch.tensor`, *rquired*):
+        residual (`torch.tensor`, *required*):
             esidual tensor
         prob (`float`, *required*):
             dropout probability
         training (`bool`, *required*):
             training mode
     """
-    out = nn.functional.dropout(x, p=prob, training=training)
+    out = F.dropout(x, p=prob, training=training)
     out = residual + out
     return out


-def bloom_gelu_forward(x):
+def bloom_gelu_forward(x: torch.Tensor) -> torch.Tensor:
     """
     Custom bias GELU function. Adapted from Megatron-DeepSpeed code. Here we use a simple implementation (inference) to
     make the model jitable.
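As a side note on the `closest_power_of_2` logic above, a hedged sketch (not part of the commit) of the slope computation for a head count that is not a power of two; `num_heads = 12` is purely illustrative:

import math
import torch

num_heads = 12  # illustrative, not a BLOOM config value
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))  # 8
base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3)))
slopes = torch.pow(torch.tensor(base), torch.arange(1, 1 + closest_power_of_2))

if closest_power_of_2 != num_heads:
    extra_base = 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3)))
    num_remaining = min(closest_power_of_2, num_heads - closest_power_of_2)  # 4 extra heads
    extra_powers = torch.arange(1, 1 + 2 * num_remaining, 2)
    slopes = torch.cat([slopes, torch.pow(torch.tensor(extra_base), extra_powers)])

print(slopes.shape)  # torch.Size([12]), one slope per head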
@@ -159,7 +158,7 @@ def bloom_gelu_forward(x):
     return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))


-def bloom_gelu_back(g, x):
+def bloom_gelu_back(g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
     """
     gradient of tanh approximation of gelu gradient of actual gelu is: 0.5 * (1. + torch.erf(x * 0.70710678)) +
     0.3989423 * x * torch.exp(-0.5 * x * x)
@@ -179,12 +178,12 @@ def bloom_gelu_back(g, x):

 class GeLUFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, input):
+    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
         ctx.save_for_backward(input)
         return bloom_gelu_forward(input)

     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
         input = ctx.saved_tensors
         tmp = bloom_gelu_back(grad_output, input)
         return tmp
@@ -197,13 +196,12 @@ class BloomGelu(nn.Module):
     copied from Megatron-DeepSpeed code and adapted for our needs

     See here why autograd functions are not torchscriptable: https://github.com/pytorch/pytorch/issues/22329
-
     """

     def __init__(self):
         super().__init__()

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.training:
             return GeLUFunction.apply(x)
         else:
@@ -211,7 +209,7 @@ class BloomGelu(nn.Module):


 class BloomAttention(nn.Module):
-    def __init__(self, config, layer_number=None):
+    def __init__(self, config: BloomConfig):
         super().__init__()

         self.pretraining_tp = config.pretraining_tp
@@ -230,106 +228,131 @@ class BloomAttention(nn.Module):
             )

-        # Layer-wise attention scaling
-        self.layer_number = max(1, layer_number)
-        self.norm_factor = math.sqrt(self.head_dim) * self.layer_number
+        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
+        self.beta = 1.0

         self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
         self.dense = nn.Linear(self.hidden_size, self.hidden_size)
         self.attention_dropout = nn.Dropout(config.attention_dropout)

-    def _split_heads(self, fused_qkv):
+    def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
-        Split the last dimension into (num_heads, head_dim)
-        """
-        new_tensor_shape = fused_qkv.size()[:-1] + (self.num_heads, 3 * self.head_dim)
-        # new_tensor_shape = (fused_qkv.size(1), fused_qkv.size(0)*fused_qkv.size(2), fused_qkv.size(-1))
-        # fused_qkv = fused_qkv.transpose(1, 0)
-        fused_qkv = fused_qkv.reshape(new_tensor_shape)
-        # fused_qkv = fused_qkv.permute(0, 2, 1, 3)
-        return torch.split(fused_qkv, self.head_dim, -1)
+        Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
+        storage as `fused_qkv`

-    def _merge_heads(self, x):
+        Args:
+            fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
+
+        Returns:
+            query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
+            value: [batch_size, seq_length, num_heads, head_dim]
+        """
+        batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
+        fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
+        return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
+
+    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Merge heads together over the last dimenstion
+
+        Args:
+            x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
+
+        Returns:
+            torch.tensor: [batch_size, seq_length, num_heads * head_dim]
+        """
         # What we want to achieve is:
-        # batch_size * num_heads, seq_len, head_dim -> batch_size, seq_len, num_heads * head_dim
+        # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
+        batch_size_and_num_heads, seq_length, _ = x.shape
+        batch_size = batch_size_and_num_heads // self.num_heads

         # First view to decompose the batch size
-        # batch_size*num_heads, seq_len, head_dim -> batch_size, num_heads, seq_len, head_dim
-        x = x.view(x.size(0) // self.num_heads, self.num_heads, x.size(1), self.head_dim)
+        # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
+        x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)

-        # batch_size, num_heads, seq_len, head_dim -> batch_size, seq_len, num_heads, head_dim
+        # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
         x = x.permute(0, 2, 1, 3)

-        # batch_size, seq_len, num_heads, head_dim -> batch_size, seq_len, num_heads * head_dim
-        return x.reshape(x.size(0), x.size(1), self.num_heads * self.head_dim)
+        # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
+        return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)

     def forward(
         self,
-        hidden_states,
-        residual,
-        layer_past=None,
-        attention_mask=None,
-        alibi=None,
-        head_mask=None,
-        use_cache=False,
-        output_attentions=False,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        alibi: torch.Tensor,
+        attention_mask: torch.Tensor,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        use_cache: bool = False,
+        output_attentions: bool = False,
     ):
-        alibi = alibi.to(hidden_states.device)  # to make the model possible to run under accelerate
         fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]

         # 3 x [batch_size, seq_length, num_heads, head_dim]
         (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
+
+        batch_size, q_length, _, _ = query_layer.shape
+
+        query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
+        key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length)
+        value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
         if layer_past is not None:
             past_key, past_value = layer_past
-            # concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim]
-            key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1)
-            value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1)
+            # concatenate along seq_length dimension:
+            # - key: [batch_size * self.num_heads, head_dim, kv_length]
+            # - value: [batch_size * self.num_heads, kv_length, head_dim]
+            key_layer = torch.cat((past_key, key_layer), dim=2)
+            value_layer = torch.cat((past_value, value_layer), dim=1)
+
+        _, _, kv_length = key_layer.shape

         if use_cache is True:
             present = (key_layer, value_layer)
         else:
             present = None

-        beta = 1.0 / self.layer_number
+        # [batch_size * num_heads, q_length, kv_length]
+        # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
+        matmul_result = alibi.baddbmm(
+            batch1=query_layer,
+            batch2=key_layer,
+            beta=self.beta,
+            alpha=self.inv_norm_factor,
+        )

-        # # [batch_size*num_heads, head_dim, q_length] x [batch_size*num_heads, head_dim, k_length] -> [batch_size*num_heads, q_length, k_length]
-        matmul_result = (1.0 / self.norm_factor) * torch.bmm(
-            query_layer.transpose(1, 2).reshape(-1, query_layer.shape[1], query_layer.shape[3]),
-            key_layer.permute(0, 2, 3, 1).reshape(-1, key_layer.shape[3], key_layer.shape[1]),
-        ) + beta * alibi
+        # change view to [batch_size, num_heads, q_length, kv_length]
+        attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)

-        # change view to [batch_size, num_heads, q_length, k_length]
-        attention_scores = matmul_result.view(-1, self.num_heads, matmul_result.size(1), matmul_result.size(2))
-
-        # We replace the scaled softmax by just a few line of code - [batch_size, num_heads, q_length, k_length]
+        # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
         input_dtype = attention_scores.dtype
-        attn_weights = (attention_scores * self.layer_number) + attention_mask
-        attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
-        attention_probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
-        attention_probs = attention_probs * (~attention_mask.to(torch.bool))
-        # [batch_size, num_heads, q_length, k_length]
+        # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
+        if input_dtype == torch.float16:
+            attention_scores = attention_scores.to(torch.float)
+        attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
+        attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
+
+        # [batch_size, num_heads, q_length, kv_length]
         attention_probs = self.attention_dropout(attention_probs)

         if head_mask is not None:
             attention_probs = attention_probs * head_mask

-        # change view [batch_size x num_heads, q_length, k_length]
-        attention_probs_reshaped = attention_probs.view(matmul_result.shape)
+        # change view [batch_size x num_heads, q_length, kv_length]
+        attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)

         # matmul: [batch_size * num_heads, q_length, head_dim]
-        context_layer = torch.bmm(
-            attention_probs_reshaped, value_layer.transpose(1, 2).reshape(-1, value_layer.size(1), value_layer.size(3))
-        )
+        context_layer = torch.bmm(attention_probs_reshaped, value_layer)

         # change view [batch_size, num_heads, q_length, head_dim]
         context_layer = self._merge_heads(context_layer)

         # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
         if self.pretraining_tp > 1 and self.slow_but_exact:
-            slices = context_layer.shape[-1] / self.pretraining_tp
+            slices = self.hidden_size / self.pretraining_tp
             output_tensor = torch.zeros_like(context_layer)
             for i in range(self.pretraining_tp):
-                output_tensor = output_tensor + nn.functional.linear(
+                output_tensor = output_tensor + F.linear(
                     context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
                     self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
                 )
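One way to read the `baddbmm` change above: `alibi.baddbmm(batch1=q, batch2=k, beta=1.0, alpha=inv_norm_factor)` computes `1.0 * alibi + inv_norm_factor * (q @ k)` in a single fused call. A hedged sanity-check sketch with toy shapes (not part of the commit):

import torch

batch_heads, q_len, kv_len, head_dim = 8, 4, 6, 16
inv_norm_factor = 1.0 / head_dim**0.5

query = torch.randn(batch_heads, q_len, head_dim)
key = torch.randn(batch_heads, head_dim, kv_len)  # stored pre-transposed, like key_layer
alibi = torch.randn(batch_heads, 1, kv_len)       # broadcasts over the q_len dimension

fused = alibi.baddbmm(batch1=query, batch2=key, beta=1.0, alpha=inv_norm_factor)
manual = alibi + inv_norm_factor * torch.bmm(query, key)
assert torch.allclose(fused, manual, atol=1e-5)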
@@ -346,7 +369,7 @@ class BloomAttention(nn.Module):


 class BloomMLP(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: BloomConfig):
         super().__init__()
         hidden_size = config.hidden_size

@@ -357,14 +380,14 @@ class BloomMLP(nn.Module):
         self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
         self.hidden_dropout = config.hidden_dropout

-    def forward(self, hidden_states, residual):
+    def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
         hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))

         if self.pretraining_tp > 1 and self.slow_but_exact:
             intermediate_output = torch.zeros_like(residual)
             slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
             for i in range(self.pretraining_tp):
-                intermediate_output = intermediate_output + nn.functional.linear(
+                intermediate_output = intermediate_output + F.linear(
                     hidden_states[:, :, int(i * slices) : int((i + 1) * slices)],
                     self.dense_4h_to_h.weight[:, int(i * slices) : int((i + 1) * slices)],
                 )
@@ -377,13 +400,13 @@ class BloomMLP(nn.Module):


 class BloomBlock(nn.Module):
-    def __init__(self, config, layer_number=None):
+    def __init__(self, config: BloomConfig):
         super().__init__()
         hidden_size = config.hidden_size

         self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.n_head = config.n_head
-        self.self_attention = BloomAttention(config, layer_number=layer_number)
+        self.num_heads = config.n_head
+        self.self_attention = BloomAttention(config)
         self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

         self.mlp = BloomMLP(config)
@@ -393,13 +416,13 @@

     def forward(
         self,
-        hidden_states,
-        layer_past=None,
-        attention_mask=None,
-        head_mask=None,
-        use_cache=False,
-        output_attentions=False,
-        alibi=None,
+        hidden_states: torch.Tensor,
+        alibi: torch.Tensor,
+        attention_mask: torch.Tensor,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        use_cache: bool = False,
+        output_attentions: bool = False,
     ):
         # hidden_states: [batch_size, seq_length, hidden_size]

@@ -462,9 +485,9 @@ class BloomPreTrainedModel(PreTrainedModel):
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)

-    def _init_weights(self, module):
+    def _init_weights(self, module: nn.Module):
         """Initialize the weights."""
-        if isinstance(module, (nn.Linear)):
+        if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
@@ -478,7 +501,7 @@ class BloomPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
         if isinstance(module, BloomModel):
             module.gradient_checkpointing = value

@@ -501,9 +524,8 @@ BLOOM_START_DOCSTRING = r"""
 BLOOM_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
-            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
-            `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
-            sequence tokens in the vocabulary.
+            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
+            (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.

             If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
             `input_ids`.
@@ -516,6 +538,10 @@ BLOOM_INPUTS_DOCSTRING = r"""
             Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
             `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
             their past given to this model should not be passed as `input_ids` as they have already been computed.
+
+            Each element of `past_key_values` is a tuple (past_key, past_value):
+            - past_key: [batch_size * num_heads, head_dim, kv_length]
+            - past_value: [batch_size * num_heads, kv_length, head_dim]
         attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

@@ -555,19 +581,18 @@ BLOOM_INPUTS_DOCSTRING = r"""
     BLOOM_START_DOCSTRING,
 )
 class BloomModel(BloomPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: BloomConfig):
         super().__init__(config)

         self.embed_dim = config.hidden_size
-        self.n_head = config.n_head
+        self.num_heads = config.n_head

         # Embedding + LN Embedding
         self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
-
         self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

         # Transformer blocks
-        self.h = nn.ModuleList([BloomBlock(config, layer_number=i) for i in range(config.num_hidden_layers)])
+        self.h = nn.ModuleList([BloomBlock(config) for _ in range(config.num_hidden_layers)])

         # Final Layer Norm
         self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -580,25 +605,29 @@ class BloomModel(BloomPreTrainedModel):
     def get_input_embeddings(self):
         return self.word_embeddings

-    def _prepare_attn_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
+    def _prepare_attn_mask(
+        self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
+    ) -> torch.BoolTensor:
         # create causal mask
-        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
         combined_attention_mask = None
-        if input_shape[-1] > 1:
-            combined_attention_mask = _make_causal_mask(
-                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-            ).to(attention_mask.device)
+        device = attention_mask.device
+        _, src_length = input_shape

-        if attention_mask is not None:
-            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
-            combined_attention_mask = (
-                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+        if src_length > 1:
+            combined_attention_mask = _make_causal_mask(
+                input_shape, device=device, past_key_values_length=past_key_values_length
             )

+        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
+        expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
+        combined_attention_mask = (
+            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
+        )
+
         return combined_attention_mask

-    def set_input_embeddings(self, new_embeddings):
+    def set_input_embeddings(self, new_embeddings: torch.Tensor):
         self.word_embeddings = new_embeddings

     @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
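A hedged toy example of the boolean masking convention used above (`True` means "do not attend"), combining a causal mask with a padding mask via `|` as `_prepare_attn_mask` does; shapes and values are illustrative only, not part of the commit:

import torch

seq_length = 4
# Causal part: True strictly above the diagonal (future positions are masked)
seq_ids = torch.arange(seq_length)
causal = seq_ids[:, None] < seq_ids[None, :]          # [seq_length, seq_length]

# Padding part: attention_mask uses 1 for real tokens, 0 for padding
attention_mask = torch.tensor([[1, 1, 1, 0]])
padding = ~attention_mask.to(torch.bool)              # True where padded

combined = causal[None, None, :, :] | padding[:, None, None, :]
# combined[b, 0, i, j] is True when position j must be ignored when attending from i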
@@ -610,17 +639,17 @@
     )
     def forward(
         self,
-        input_ids=None,
-        past_key_values=None,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
         **deprecated_arguments
-    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
+    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
         if deprecated_arguments.pop("position_ids", False) is not False:
             # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
             warnings.warn(
@@ -641,10 +670,9 @@ class BloomModel(BloomPreTrainedModel):
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
-            input_ids = input_ids.view(-1, input_shape[-1])
+            batch_size, seq_length = input_ids.shape
         elif inputs_embeds is not None:
-            input_shape = inputs_embeds.size()[:-1]
+            batch_size, seq_length, _ = inputs_embeds.shape
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

@@ -653,8 +681,8 @@ class BloomModel(BloomPreTrainedModel):

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
-        # attention_probs has shape bsz x n_head x N x N
-        # head_mask has shape n_layer x batch x n_head x N x N
+        # attention_probs has shape batch_size x num_heads x N x N
+        # head_mask has shape n_layer x batch x num_heads x N x N
         head_mask = self.get_head_mask(head_mask, self.config.n_layer)

         if inputs_embeds is None:
@@ -662,27 +690,28 @@ class BloomModel(BloomPreTrainedModel):

         hidden_states = self.word_embeddings_layernorm(inputs_embeds)

-        output_shape = input_shape + (hidden_states.size(-1),)
-
         presents = () if use_cache else None
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None

         # Compute alibi tensor: check build_alibi_tensor documentation
-        current_sequence_length = hidden_states.shape[1]
+        seq_length_with_past = seq_length
         past_key_values_length = 0
         if past_key_values[0] is not None:
-            past_key_values_length = past_key_values[0][0].shape[1]
-            current_sequence_length += past_key_values_length
-
+            past_key_values_length = past_key_values[0][0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
         if attention_mask is None:
-            attention_mask = torch.ones((hidden_states.shape[0], current_sequence_length), device=hidden_states.device)
+            attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
         else:
             attention_mask = attention_mask.to(hidden_states.device)

-        alibi = build_alibi_tensor(attention_mask, self.n_head, hidden_states.dtype, hidden_states.device)
+        alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)

-        causal_mask = self._prepare_attn_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length)
+        causal_mask = self._prepare_attn_mask(
+            attention_mask,
+            input_shape=(batch_size, seq_length),
+            past_key_values_length=past_key_values_length,
+        )

         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):

@@ -700,14 +729,14 @@ class BloomModel(BloomPreTrainedModel):
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
                         # None for past_key_value
-                        return module(*inputs, use_cache, output_attentions, alibi)
+                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)

                     return custom_forward

                 outputs = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(block),
                     hidden_states,
-                    None,
+                    alibi,
                     causal_mask,
                     head_mask[i],
                 )
@@ -735,8 +764,6 @@ class BloomModel(BloomPreTrainedModel):
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)

-        hidden_states = hidden_states.view(output_shape)
-
         if not return_dict:
             return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

@@ -758,7 +785,7 @@ class BloomModel(BloomPreTrainedModel):
 class BloomForCausalLM(BloomPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]

-    def __init__(self, config):
+    def __init__(self, config: BloomConfig):
         super().__init__(config)
         self.transformer = BloomModel(config)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
@@ -769,16 +796,20 @@ class BloomForCausalLM(BloomPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head

-    def set_output_embeddings(self, new_embeddings):
+    def set_output_embeddings(self, new_embeddings: torch.Tensor):
         self.lm_head = new_embeddings

-    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
-        # only last token for inputs_ids if past is defined in kwargs
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        past: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs
+    ) -> dict:
+        # only last token for input_ids if past is not None
         if past:
             input_ids = input_ids[:, -1].unsqueeze(-1)

-        attention_mask = kwargs.get("attention_mask", None)
-
         return {
             "input_ids": input_ids,
             "past_key_values": past,
@@ -795,16 +826,16 @@ class BloomForCausalLM(BloomPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        past_key_values=None,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
         **deprecated_arguments
     ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
         r"""
@@ -845,9 +876,12 @@ class BloomForCausalLM(BloomPreTrainedModel):
             # Shift so that tokens < n predict n
             shift_logits = lm_logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
+            batch_size, seq_length, vocab_size = shift_logits.shape
             # Flatten the tokens
             loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+            loss = loss_fct(
+                shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
+            )

         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
@@ -862,14 +896,36 @@ class BloomForCausalLM(BloomPreTrainedModel):
         )

     @staticmethod
-    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
+    def _reorder_cache(
+        past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
         """
         This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
         [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
         beam_idx at every generation step.
+
+        Output shares the same memory storage as `past`.
         """
+        batch_size_times_num_heads, head_dim, seq_length = past[0][0].shape
+        batch_size = len(beam_idx)
+        num_heads = batch_size_times_num_heads // batch_size
+        # Get a copy of `beam_idx` on all the devices where we need those indices.
+        device_to_beam_idx = {
+            past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
+        }
+        # key: layer_past[0] [batch_size * num_heads, head_dim, seq_length]
+        # value: layer_past[1] [batch_size * num_heads, seq_length, head_dim]
         return tuple(
-            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
+            (
+                layer_past[0]
+                .view(batch_size, num_heads, head_dim, seq_length)
+                .index_select(0, device_to_beam_idx[layer_past[0].device])
+                .view(batch_size_times_num_heads, head_dim, seq_length),
+                layer_past[1]
+                .view(batch_size, num_heads, seq_length, head_dim)
+                .index_select(0, device_to_beam_idx[layer_past[0].device])
+                .view(batch_size_times_num_heads, seq_length, head_dim),
+            )
             for layer_past in past
        )

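The reordering trick above is view, then index_select on the batch axis, then view back, so the heads that were folded into the batch dimension follow their beam. A hedged miniature version (toy sizes, not part of the commit):

import torch

batch_size, num_heads, head_dim, seq_length = 2, 3, 4, 5
beam_idx = torch.tensor([1, 0])  # swap the two beams (illustrative)

past_key = torch.randn(batch_size * num_heads, head_dim, seq_length)

reordered = (
    past_key.view(batch_size, num_heads, head_dim, seq_length)
    .index_select(0, beam_idx)
    .view(batch_size * num_heads, head_dim, seq_length)
)
assert reordered.shape == past_key.shape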
@@ -892,7 +948,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
 class BloomForSequenceClassification(BloomPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]

-    def __init__(self, config):
+    def __init__(self, config: BloomConfig):
         super().__init__(config)
         self.num_labels = config.num_labels
         self.transformer = BloomModel(config)
@@ -910,16 +966,16 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        past_key_values=None,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
         **deprecated_arguments
     ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
         r"""
@@ -966,7 +1022,7 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
                 sequence_lengths = -1
             else:
                 if input_ids is not None:
-                    sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
+                    sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1
                 else:
                     sequence_lengths = -1
                     logger.warning(
@@ -994,7 +1050,7 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
                 loss = loss_fct(pooled_logits, labels)
             elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
-                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+                loss = loss_fct(pooled_logits, labels)
             elif self.config.problem_type == "multi_label_classification":
                 loss_fct = BCEWithLogitsLoss()
                 loss = loss_fct(pooled_logits, labels)
@@ -1021,7 +1077,7 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
 class BloomForTokenClassification(BloomPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]

-    def __init__(self, config):
+    def __init__(self, config: BloomConfig):
         super().__init__(config)
         self.num_labels = config.num_labels

@@ -1047,16 +1103,16 @@ class BloomForTokenClassification(BloomPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        past_key_values=None,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
        **deprecated_arguments
     ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
         r"""
@@ -1095,8 +1151,11 @@ class BloomForTokenClassification(BloomPreTrainedModel):

         loss = None
         if labels is not None:
+            batch_size, seq_length = labels.shape
             loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            loss = loss_fct(
+                logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
+            )

         if not return_dict:
             output = (logits,) + transformer_outputs[2:]
