[WIP] GPT Neo cleanup (#10985)

* better names

* add attention mixin

* all slow tests in one class

* make helper methods static so we can test

* add local attention tests

* better names

* doc

* apply review suggestions
Suraj Patil 2021-04-06 21:54:15 +05:30 committed by GitHub
parent 76800fb8e6
commit 2a8115f083
2 changed files with 392 additions and 216 deletions

src/transformers/models/gpt_neo/modeling_gpt_neo.py

@@ -130,7 +130,130 @@ def load_tf_weights_in_gpt_neo(model, config, gpt_neo_checkpoint_path):
return model
class GPTNeoSelfAttention(nn.Module):
class GPTNeoAttentionMixin:
"""
A few attention-related utilities for the attention modules in GPT Neo, to be used as a mixin.
"""
@staticmethod
def _get_block_length_and_num_blocks(seq_length, window_size):
"""
Computes ``block_length`` and ``num_blocks`` such that ``seq_length`` becomes evenly divisible by
``block_length``.
"""
block_length = window_size
while seq_length % block_length != 0:
block_length -= 1
num_blocks = seq_length // block_length
return block_length, num_blocks
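# A minimal usage sketch (illustrative; the import path is the one exercised by the tests in this diff):
from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoAttentionMixin

# 8 is already divisible by the window size, so block_length stays at 4 and there are 2 blocks
print(GPTNeoAttentionMixin._get_block_length_and_num_blocks(8, 4))  # (4, 2)
# 7 is not divisible by 4, so block_length shrinks until it divides seq_length
print(GPTNeoAttentionMixin._get_block_length_and_num_blocks(7, 4))  # (1, 7)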
@staticmethod
def _look_back(tensor, block_length, window_size, pad_value=0, is_key_value=True):
"""
Used to implement attention between consecutive blocks. This method assumes that dim 1 of :obj:`tensor`
represents the :obj:`seq_length` dimension. It splits the :obj:`seq_length` dimension into :obj:`num_blocks`
blocks of size :obj:`window_size` + :obj:`block_length`, padding the :obj:`seq_length` dimension if necessary.
Example::
tensor: torch.tensor([[[ 0.4983], [ 2.6918], [-0.0071], [ 1.0492], [-1.8348], [ 0.7672], [ 0.2986], [ 0.0285]]])
with shape (1, 8, 1)
block_length = window_size = 4
_look_back =>
torch.tensor([[[[ 0.0000], [ 0.0000], [ 0.0000], [ 0.0000], [ 0.4983], [ 2.6918], [-0.0071], [ 1.0492]],
[[ 0.4983], [ 2.6918], [-0.0071], [ 1.0492], [-1.8348], [ 0.7672], [ 0.2986], [ 0.0285]]]])
Args:
tensor (:obj:`torch.Tensor`): tensor of shape :obj:`[batch_size, seq_length, hidden_dim]` or :obj:`[batch_size, seq_length]`
block_length (:obj:`int`): An integer specifying the length of each block, used as a step size when creating the blocks.
window_size (:obj:`int`): An integer specifying the size of the attention window, used to calculate the final block size when creating the block.
pad_value (:obj:`int`): An integer specifying the value to use when padding the :obj:`tensor`.
is_key_value (:obj:`bool`): A boolean indicating if the :obj:`tensor` is a key/value tensor.
Returns:
tensor of shape :obj:`[batch_size, num_blocks, window_size + block_length, ...]` if :obj:`is_key_value` is
:obj:`True` else a tensor of shape :obj:`[batch_size, window_size + block_length, num_blocks, ...]`
"""
if len(tensor.shape) == 3:
padding_side = (0, 0, window_size, 0)
elif len(tensor.shape) == 2:
padding_side = (window_size, 0)
else:
raise ValueError(f"Input tensor rank should be one of [2, 3], but is: {len(tensor.shape)}")
padded_tensor = F.pad(tensor, padding_side, value=pad_value)
padded_tensor = padded_tensor.unfold(dimension=1, size=window_size + block_length, step=block_length)
if is_key_value:
padded_tensor = padded_tensor.transpose(-2, -1)
return padded_tensor
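# A minimal sketch reproducing the docstring example above (illustrative values):
import torch

from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoAttentionMixin

hidden_states = torch.randn(1, 8, 1)  # (batch_size, seq_length, hidden_dim)
blocks = GPTNeoAttentionMixin._look_back(hidden_states, block_length=4, window_size=4)
# two blocks of size window_size + block_length; block 0 starts with window_size zero-padded positions
print(blocks.shape)  # torch.Size([1, 2, 8, 1])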
def _split_heads(self, tensor, num_heads, attn_head_size):
"""
Splits hidden_size dim into attn_head_size and num_heads
"""
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(*new_shape)
if len(tensor.shape) == 5:
return tensor.permute(0, 1, 3, 2, 4) # (batch, blocks, head, block_length, head_features)
elif len(tensor.shape) == 4:
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
else:
raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
def _merge_heads(self, tensor, num_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden_size
"""
if len(tensor.shape) == 5:
tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
elif len(tensor.shape) == 4:
tensor = tensor.permute(0, 2, 1, 3).contiguous()
else:
raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
return tensor.view(new_shape)
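# A minimal round-trip sketch (illustrative shapes; the mixin is stateless, so it can be instantiated directly):
import torch

from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoAttentionMixin

mixin = GPTNeoAttentionMixin()
hidden_states = torch.randn(2, 10, 64)  # (batch, seq_length, hidden_size) with hidden_size = 8 heads * 8 dims
heads = mixin._split_heads(hidden_states, num_heads=8, attn_head_size=8)  # (2, 8, 10, 8)
merged = mixin._merge_heads(heads, num_heads=8, attn_head_size=8)  # back to (2, 10, 64)
assert torch.equal(merged, hidden_states)  # splitting then merging heads is lossless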
def _split_seq_length_dim_to(self, tensors, dim_factor_1, dim_factor_2, hidden_size):
"""
Splits sequence length dim of tensors into `dim_factor_1` and `dim_factor_2` dims
"""
batch_size = tensors.shape[0]
split_dim_shape = (batch_size, dim_factor_1, dim_factor_2)
if len(tensors.shape) == 3:
return torch.reshape(tensors, split_dim_shape + (hidden_size,))
elif len(tensors.shape) == 2:
return torch.reshape(tensors, split_dim_shape)
else:
raise ValueError(f"Input vector rank should be one of [2, 3], but is: {len(tensors.shape)}")
def _attn(self, query, key, value, causal_mask, masked_bias, attn_dropout, attention_mask=None, head_mask=None):
# Keep the attention weights computation in fp32 to avoid overflow issues
query = query.to(torch.float32)
key = key.to(torch.float32)
attn_weights = torch.matmul(query, key.transpose(-1, -2))
attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
if attention_mask is not None:
# Apply the attention mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.Softmax(dim=-1)(attn_weights)
attn_weights = attn_weights.to(value.dtype)
attn_weights = attn_dropout(attn_weights)
# Mask heads if we want to
if head_mask is not None:
attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, value)
return attn_output, attn_weights
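# A minimal sketch of the shared attention core with hand-built inputs (illustrative; the real modules pass in
# their registered causal-mask/masked-bias buffers and their own dropout module):
import torch
from torch import nn

from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoAttentionMixin

batch, num_heads, seq_length, head_dim = 1, 2, 4, 8
query = torch.randn(batch, num_heads, seq_length, head_dim)
key = torch.randn(batch, num_heads, seq_length, head_dim)
value = torch.randn(batch, num_heads, seq_length, head_dim)
causal_mask = torch.tril(torch.ones(seq_length, seq_length, dtype=torch.bool))[None, None, :, :]
masked_bias = torch.tensor(-1e9)

mixin = GPTNeoAttentionMixin()
attn_output, attn_weights = mixin._attn(query, key, value, causal_mask, masked_bias, nn.Dropout(0.0))
print(attn_output.shape, attn_weights.shape)  # torch.Size([1, 2, 4, 8]) torch.Size([1, 2, 4, 4])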
class GPTNeoSelfAttention(nn.Module, GPTNeoAttentionMixin):
def __init__(self, config):
super().__init__()
@@ -149,56 +272,16 @@ class GPTNeoSelfAttention(nn.Module):
self.embed_dim = config.hidden_size
self.num_heads = config.num_heads
self.head_dim = self.embed_dim // self.num_heads
assert (
self.head_dim * self.num_heads == self.embed_dim
), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
)
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False):
# Keep the attention weights computation in fp32 to avoid overflow issues
q = q.to(torch.float32)
k = k.to(torch.float32)
attn_weights = torch.matmul(q, k)
nd, ns = attn_weights.size(-2), attn_weights.size(-1)
mask = self.bias[:, :, ns - nd : ns, :ns]
attn_weights = torch.where(mask.bool(), attn_weights, self.masked_bias.to(attn_weights.dtype))
if attention_mask is not None:
# Apply the attention mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.Softmax(dim=-1)(attn_weights)
attn_weights = attn_weights.to(v.dtype)
attn_weights = self.attn_dropout(attn_weights)
# Mask heads if we want to
if head_mask is not None:
attn_weights = attn_weights * head_mask
outputs = (torch.matmul(attn_weights, v),)
if output_attentions:
outputs += (attn_weights,)
return outputs
def merge_heads(self, x):
x = x.permute(0, 2, 1, 3).contiguous()
new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states
def split_heads(self, x, k=False):
new_x_shape = x.size()[:-1] + (self.num_heads, x.size(-1) // self.num_heads)
x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
if k:
return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length)
else:
return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
def forward(
self,
hidden_states,
@@ -213,31 +296,40 @@ class GPTNeoSelfAttention(nn.Module):
key = self.k_proj(hidden_states)
value = self.v_proj(hidden_states)
query = self.split_heads(query)
key = self.split_heads(key, k=True)
value = self.split_heads(value)
query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)
if layer_past is not None:
past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below
key = torch.cat((past_key, key), dim=-1)
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
if use_cache is True:
present = (key.transpose(-2, -1), value) # transpose to have same shapes
present = (key, value)
else:
present = None
attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)
a = attn_outputs[0]
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
a = self.merge_heads(a)
a = self.out_proj(a)
a = self.resid_dropout(a)
attn_output, attn_weights = self._attn(
query, key, value, causal_mask, self.masked_bias, self.attn_dropout, attention_mask, head_mask
)
return (a, present) + attn_outputs[1:] # a, present, (attentions)
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.out_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights,)
return outputs # a, present, (attentions)
class GPTNeoLocalSelfAttention(nn.Module):
class GPTNeoLocalSelfAttention(nn.Module, GPTNeoAttentionMixin):
def __init__(self, config):
super().__init__()
@@ -249,9 +341,10 @@ class GPTNeoLocalSelfAttention(nn.Module):
self.embed_dim = config.hidden_size
self.num_heads = config.num_heads
self.head_dim = self.embed_dim // self.num_heads
assert (
self.head_dim * self.num_heads == self.embed_dim
), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
)
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
@@ -260,94 +353,39 @@ class GPTNeoLocalSelfAttention(nn.Module):
self.window_size = config.window_size
def shift(self, x, offset, pad_value=0, dim=2):
t = x.shape[1]
dims = (len(x.shape) - dim) * (0, 0)
padded_x = F.pad(x, (*dims, offset, 0), value=pad_value)
return padded_x[:, :t, ...]
def _create_attention_mask(self, batch_size, seq_length, num_blocks, block_length, device, attention_mask=None):
indices = torch.arange(seq_length, dtype=torch.long, device=device).repeat(batch_size, 1)
def look_around(self, x, block_length, window_size):
num_complete_blocks = window_size // block_length
query_indices = self._split_seq_length_dim_to(indices, num_blocks, block_length, self.embed_dim)
key_indices = self._look_back(indices, block_length, self.window_size, is_key_value=False)
parts = [x]
for i in range(1, num_complete_blocks + 1):
parts = [self.shift(x, i)] + parts
# create mask tensor such that each block contains a causal_mask for that block
causal_mask = torch.ge(query_indices.unsqueeze(-1), key_indices.unsqueeze(-2))
partial_size = window_size % block_length
if partial_size > 0:
margin = x[:, :, block_length - partial_size : block_length, ...]
parts = [self.shift(margin, num_complete_blocks + 1)] + parts
return torch.cat(parts, dim=2)
if attention_mask is None:
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long, device=device)
def split_heads(self, x, k=False):
new_x_shape = x.size()[:-1] + (self.num_heads, x.size(-1) // self.num_heads)
x = x.view(*new_x_shape)
if k:
return x.permute(0, 1, 3, 4, 2) # (batch, chunks, head, head_features, seq_length)
else:
return x.permute(0, 1, 3, 2, 4) # (batch, chunks, head, seq_length, head_features)
# A block can also be padded because of the _look_back operation
# look back into the attention_mask such that it will also get padded the same way
# and have 0s in the padded position
attention_mask = self._look_back(attention_mask, block_length, self.window_size, is_key_value=False)
attention_mask = attention_mask.unsqueeze(-2) # Add an extra dimension to account for hidden_dim
def merge_heads(self, x):
x = x.permute(0, 1, 3, 2, 4).contiguous()
new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
return x.view(*new_x_shape)
# Multiply the causal_mask with attention_mask so the padded positions (by _look_back operation)
# will contain 0s.
# This also makes sure that other positions ignored by the attention_mask will also be ignored
# in the causal_mask.
causal_mask = causal_mask * attention_mask
def _split_seq_length_dim_to(self, tensors, num_blocks, block_length):
return tensors.reshape(tensors.size()[0], num_blocks, block_length, -1)
# In GPT Neo's local attention, each window can attend to at most window_size tokens;
# the rest of the tokens should be ignored.
relative_position = key_indices.unsqueeze(-2) - query_indices.unsqueeze(-1)
visible = torch.gt(relative_position, -self.window_size)
def create_attention_mask(self, bs, seq_len, windows, block_length, attention_mask):
ticker = torch.arange(seq_len)[None, :]
b_t = ticker.reshape(1, windows, block_length)
causal_mask = causal_mask * visible
causal_mask = causal_mask.unsqueeze(-3).bool() # Add an extra dimension to account for num_heads
bq_t = b_t
bq_k = self.look_around(b_t, block_length, self.window_size)
# compute attn mask
# this matches the original implementation in mesh-tensorflow
# https://github.com/tensorflow/mesh/blob/8bd599a21bad01cef1300a8735c17306ce35db6e/mesh_tensorflow/transformer/attention.py#L805
relative_position = bq_k.unsqueeze(-2) - bq_t.unsqueeze(-1)
relative_position = relative_position.transpose(-1, -2)
sequence_id = torch.ones(bs, seq_len)
q_seq = sequence_id.reshape(-1, windows, block_length)
m_seq = sequence_id.reshape(-1, windows, block_length)
m_seq = self.look_around(m_seq, block_length, self.window_size)
if attention_mask is not None:
attention_mask = attention_mask.to(m_seq.device)
attention_mask = attention_mask.reshape(-1, windows, block_length)
attention_mask = self.look_around(attention_mask, block_length, self.window_size)
m_seq *= attention_mask
visible = torch.eq(q_seq.unsqueeze(-1), m_seq.unsqueeze(-2)).transpose(-1, -2)
visible = torch.logical_and(visible, torch.gt(relative_position, -self.window_size))
mask = torch.logical_and(visible, torch.less_equal(relative_position, 0)).transpose(-1, -2).unsqueeze(2)
return mask
def _attn(self, q, k, v, causal_mask, head_mask=None, output_attentions=False):
# attn
# Keep the attention weights computation in fp32 to avoid overflow issues
q = q.to(torch.float32)
k = k.to(torch.float32)
attn_weights = torch.matmul(q, k)
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
attn_weights = nn.Softmax(dim=-1)(attn_weights)
attn_weights = attn_weights.to(v.dtype)
attn_weights = self.attn_dropout(attn_weights)
# Mask heads if we want to
if head_mask is not None:
attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, v)
outputs = (attn_output,)
if output_attentions:
outputs += (attn_weights,)
return outputs
return causal_mask
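# A minimal sketch mirroring the new unit test further down this diff (the checkpoint name is the one used there):
import torch

from transformers import GPTNeoConfig
from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoAttentionMixin, GPTNeoLocalSelfAttention

config = GPTNeoConfig.from_pretrained("valhalla/gpt-neo-random-tiny")
layer = GPTNeoLocalSelfAttention(config)
batch_size, seq_length = 2, 8
block_length, num_blocks = GPTNeoAttentionMixin._get_block_length_and_num_blocks(seq_length, config.window_size)
causal_mask = layer._create_attention_mask(batch_size, seq_length, num_blocks, block_length, "cpu")
# (batch_size, num_blocks, 1, block_length, window_size + block_length); the leading 1 broadcasts over num_heads
print(causal_mask.shape)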
def forward(
self,
@@ -371,51 +409,58 @@ class GPTNeoLocalSelfAttention(nn.Module):
key = self.k_proj(key_value_hidden_states)
value = self.v_proj(key_value_hidden_states)
# compute block length and windows
bs, seq_len = hidden_states.shape[:2]
full_seq_length = seq_len + past_length
block_length = self.window_size
while full_seq_length % block_length != 0:
block_length -= 1
num_blocks = full_seq_length // block_length
# compute block length and num_blocks
batch_size, seq_length = hidden_states.shape[:2]
full_seq_length = seq_length + past_length
block_length, num_blocks = self._get_block_length_and_num_blocks(full_seq_length, self.window_size)
# create buckets
if layer_past is not None:
# we just need 1 window with block_length 1 when caching is enabled
query = self._split_seq_length_dim_to(query, 1, 1)
# we just need 1 block with block_length 1 when caching is enabled
query = self._split_seq_length_dim_to(query, 1, 1, self.embed_dim)
else:
query = self._split_seq_length_dim_to(query, num_blocks, block_length)
query = self._split_seq_length_dim_to(query, num_blocks, block_length, self.embed_dim)
key = self._split_seq_length_dim_to(key, num_blocks, block_length)
value = self._split_seq_length_dim_to(value, num_blocks, block_length)
key = self._look_back(key, block_length, self.window_size)
value = self._look_back(value, block_length, self.window_size)
key = self.look_around(key, block_length, self.window_size)
value = self.look_around(value, block_length, self.window_size)
# select key/value vectors only for the last window
# select key/value vectors only for the last block
if layer_past is not None:
key = key[:, -1:, ...]
value = value[:, -1:, ...]
query = self.split_heads(query)
key = self.split_heads(key, k=True)
value = self.split_heads(value)
query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)
mask = self.create_attention_mask(bs, full_seq_length, num_blocks, block_length, attention_mask)
mask = self._create_attention_mask(
batch_size, full_seq_length, num_blocks, block_length, hidden_states.device, attention_mask
)
if layer_past is not None:
mask = mask[:, -1:, :, -1:, :] # only take the mask for the last window
mask = mask.to(hidden_states.device)
mask = mask[:, -1:, :, -1:, :] # only take the mask for the last block
# attn
attn_outputs = self._attn(query, key, value, mask, head_mask, output_attentions)
attn = attn_outputs[0]
attn_output, attn_weights = self._attn(
query,
key,
value,
causal_mask=mask,
masked_bias=self.masked_bias,
attn_dropout=self.attn_dropout,
head_mask=head_mask,
)
attn = self.merge_heads(attn)
attn = attn.reshape(bs, seq_len, self.embed_dim)
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = attn_output.reshape(batch_size, seq_length, self.embed_dim)
attn = self.out_proj(attn)
attn = self.resid_dropout(attn)
return (attn,) + attn_outputs[1:]
attn_output = self.out_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
outputs = (attn_output,)
if output_attentions:
outputs += (attn_weights,)
return outputs # a, (attentions)
class GPTNeoAttention(nn.Module):
@@ -464,7 +509,7 @@ class GPTNeoAttention(nn.Module):
return outputs
class MLP(nn.Module):
class GPTNeoMLP(nn.Module):
def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * hidden_size
super().__init__()
embed_dim = config.hidden_size
@@ -473,13 +518,15 @@ class MLP(nn.Module):
self.act = ACT2FN[config.activation_function]
self.dropout = nn.Dropout(config.resid_dropout)
def forward(self, x):
h = self.act(self.c_fc(x))
h2 = self.c_proj(h)
return self.dropout(h2)
def forward(self, hidden_states):
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.c_proj(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class Block(nn.Module):
class GPTNeoBlock(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
hidden_size = config.hidden_size
@@ -487,7 +534,7 @@ class Block(nn.Module):
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPTNeoAttention(config, layer_id)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = MLP(inner_dim, config)
self.mlp = GPTNeoMLP(inner_dim, config)
def forward(
self,
@@ -498,8 +545,10 @@ class Block(nn.Module):
use_cache=False,
output_attentions=False,
):
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_outputs = self.attn(
self.ln_1(hidden_states),
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask,
@@ -509,11 +558,13 @@ class Block(nn.Module):
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
outputs = attn_outputs[1:]
# residual connection
hidden_states = attn_output + hidden_states
hidden_states = attn_output + residual
feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states))
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
feed_forward_hidden_states = self.mlp(hidden_states)
# residual connection
hidden_states = hidden_states + feed_forward_hidden_states
hidden_states = residual + feed_forward_hidden_states
if use_cache:
outputs = (hidden_states,) + outputs
@@ -638,7 +689,7 @@ GPT_NEO_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
"The bare GPTNeo Model transformer outputting raw hidden-states without any specific head on top.",
"The bare GPT Neo Model transformer outputting raw hidden-states without any specific head on top.",
GPT_NEO_START_DOCSTRING,
)
class GPTNeoModel(GPTNeoPreTrainedModel):
@@ -649,7 +700,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.drop = nn.Dropout(config.embed_dropout)
self.h = nn.ModuleList([Block(config, layer_id=i) for i in range(config.num_layers)])
self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i) for i in range(config.num_layers)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.init_weights()

tests/test_modeling_gpt_neo.py

@@ -18,6 +18,7 @@
import unittest
from transformers import is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
@@ -35,6 +36,7 @@ if is_torch_available():
GPTNeoForCausalLM,
GPTNeoModel,
)
from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoAttentionMixin, GPTNeoLocalSelfAttention
class GPTNeoModelTester:
@@ -430,11 +432,164 @@ class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
# check attn size
self.assertListEqual(shapes, expected_shape)
@require_torch
class GPTNeoLocalAttentionTest(unittest.TestCase):
def _get_hidden_states(self):
return torch.tensor(
[
[
[0.4983, -0.7584, -1.6944, 0.5440],
[2.6918, 0.4206, 0.4176, 0.2055],
[-0.0071, -0.0405, -1.4920, -0.3630],
[1.0492, 0.1599, -1.7648, 0.2419],
[-1.8348, 2.0514, -0.1946, 0.3203],
[0.7672, -1.1600, -1.7118, -0.9056],
[0.2986, 0.5372, 0.7729, -0.1927],
[0.0285, 0.2629, -1.1156, -1.1992],
]
],
dtype=torch.float32,
device=torch_device,
)
def test_look_back(self):
hidden_states = self._get_hidden_states()
batch_size, seq_length, hidden_size = hidden_states.shape
# check when seq_length is divisible by window_size
window_size = 4
block_length, num_block = GPTNeoAttentionMixin._get_block_length_and_num_blocks(seq_length, window_size)
blocked_hidden_states = GPTNeoAttentionMixin._look_back(hidden_states, block_length, window_size)
expected_shape = [batch_size, num_block, window_size + block_length, hidden_size]
self.assertListEqual(list(blocked_hidden_states.shape), expected_shape)
# The last block should contain the last (window_size + block_length) hidden_states
self.assertTrue(
torch.all(blocked_hidden_states[:, -1, ...] == hidden_states[:, -(window_size + block_length) :, ...])
)
# check when seq_length is not divisible by window_size
window_size = 3
block_length, num_block = GPTNeoAttentionMixin._get_block_length_and_num_blocks(seq_length, window_size)
blocked_hidden_states = GPTNeoAttentionMixin._look_back(hidden_states, block_length, window_size)
expected_shape = [batch_size, num_block, window_size + block_length, hidden_size]
self.assertListEqual(list(blocked_hidden_states.shape), expected_shape)
# The last block should contain the last (window_size + block_length) hidden_states
self.assertTrue(
torch.all(blocked_hidden_states[:, -1, ...] == hidden_states[:, -(window_size + block_length) :, ...])
)
# check when window_size is > seq_length
window_size = 19
block_length, num_block = GPTNeoAttentionMixin._get_block_length_and_num_blocks(seq_length, window_size)
blocked_hidden_states = GPTNeoAttentionMixin._look_back(hidden_states, block_length, window_size)
expected_shape = [batch_size, num_block, window_size + block_length, hidden_size]
self.assertListEqual(list(blocked_hidden_states.shape), expected_shape)
# when window_size > seq_length, num_blocks becomes 1, in this case
# the first window_size values in blocked_hidden_states are all zeros
# and the last block_length values are equal to the hidden_states
values = blocked_hidden_states[:, -1, :window_size, ...]
expected_values = torch.zeros_like(values)
self.assertTrue(torch.all(values == expected_values))
self.assertTrue(torch.all(blocked_hidden_states[:, -1, -block_length:, ...] == hidden_states))
def test_create_attention_mask(self):
config = GPTNeoConfig.from_pretrained("valhalla/gpt-neo-random-tiny")
layer = GPTNeoLocalSelfAttention(config)
window_size = config.window_size
batch_size, seq_length = 8, 1
block_length, num_blocks = GPTNeoAttentionMixin._get_block_length_and_num_blocks(seq_length, window_size)
causal_mask = layer._create_attention_mask(batch_size, seq_length, num_blocks, block_length, torch_device)
# check shapes
expected_shape = [batch_size, num_blocks, 1, block_length, window_size + block_length]
self.assertListEqual(list(causal_mask.shape), expected_shape)
# first window_size tokens in the first block are always padded
# and should not be attended
self.assertTrue(torch.all(causal_mask[:, 0, :, :, :window_size] == 0))
# each window can attend at most window_size tokens
self.assertTrue(torch.all(torch.sum(causal_mask, dim=4) <= config.window_size))
# check if user provided attention_mask is handled correctly
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long, device=torch_device)
attention_mask[:, -3:] = 0 # don't attend last 3 tokens
causal_mask = layer._create_attention_mask(
batch_size, seq_length, num_blocks, block_length, torch_device, attention_mask
)
# last 3 tokens will be in the last block and should have 0s in causal_mask
self.assertTrue(torch.all(causal_mask[:, -1, :, :, -3:] == 0))
# check shapes
expected_shape = [batch_size, num_blocks, 1, block_length, window_size + block_length]
self.assertListEqual(list(causal_mask.shape), expected_shape)
# first window_size tokens in the first block are always padded
# and should not be attended
self.assertTrue(torch.all(causal_mask[:, 0, :, :, :window_size] == 0))
# each window can attend at most window_size tokens
self.assertTrue(torch.all(torch.sum(causal_mask, dim=4) <= config.window_size))
def test_local_attn_probs(self):
model = GPTNeoModel.from_pretrained("valhalla/gpt-neo-random-tiny").eval()
layer = model.h[1].attn.attention.to(torch_device)
hidden_states = self._get_hidden_states()
hidden_states = torch.cat([hidden_states, hidden_states - 0.5], dim=2)
batch_size, seq_length, hidden_size = hidden_states.shape
mask_tokens = 3
attention_mask = torch.ones(batch_size, seq_length, device=torch_device, dtype=torch.long)
attention_mask[:, -mask_tokens:] = 0 # don't attend the last mask_tokens
_, attn_probs = layer(hidden_states, attention_mask=attention_mask, output_attentions=True)
# the last 3 tokens will be in the last block, and should have 0 attn_probs
self.assertTrue(torch.all(attn_probs[:, -1, :, -mask_tokens:, -mask_tokens:] == 0))
# the first config.window_size tokens in the first block are always padded
# and should have 0 attn_probs
self.assertTrue(torch.all(attn_probs[:, 0, :, : model.config.window_size, : model.config.window_size] == 0))
@require_torch
class GPTNeoModelLanguageGenerationTest(unittest.TestCase):
@cached_property
def model(self):
return GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B").to(torch_device)
@cached_property
def tokenizer(self):
return GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
@slow
def test_lm_generate_gpt_neo(self):
for checkpointing in [True, False]:
model = self.model
model.config.gradient_checkpointing = checkpointing
input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog
# fmt: off
# The dog-eared copy of the book, which is a collection of essays by the late author,
expected_output_ids = [464, 3290, 12, 3380, 4866, 286, 262, 1492, 11, 543, 318, 257, 4947, 286, 27126, 416, 262, 2739, 1772, 11]
# fmt: on
output_ids = model.generate(input_ids, do_sample=False)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
@slow
def test_gpt_neo_sample(self):
model = self.model
tokenizer = self.tokenizer
torch.manual_seed(0)
tokenized = tokenizer("Today is a nice day and", return_tensors="pt", return_token_type_ids=True)
input_ids = tokenized.input_ids.to(torch_device)
output_ids = model.generate(input_ids, do_sample=True)
output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
EXPECTED_OUTPUT_STR = "Today is a nice day and if you dont get the memo here is what you can"
self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
@slow
def test_batch_generation(self):
model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
model.to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = self.model
tokenizer = self.tokenizer
tokenizer.padding_side = "left"
@@ -479,33 +634,3 @@ class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
for model_name in GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = GPTNeoModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@require_torch
class GPTNeoModelLanguageGenerationTest(unittest.TestCase):
@slow
def test_lm_generate_gpt_neo(self):
for checkpointing in [True, False]:
model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B", gradient_checkpointing=checkpointing)
model.to(torch_device)
input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog
# fmt: off
expected_output_ids = [464, 3290, 12, 3380, 4866, 286, 262, 1492, 11, 543, 318, 257, 4947, 286, 27126, 416, 262, 2739, 1772, 11] # The dog-eared copy of the book, which is a collection of essays by the late author,
# fmt: on
output_ids = model.generate(input_ids, do_sample=False)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
@slow
def test_gpt_neo_sample(self):
tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
model.to(torch_device)
torch.manual_seed(0)
tokenized = tokenizer("Today is a nice day and", return_tensors="pt", return_token_type_ids=True)
input_ids = tokenized.input_ids.to(torch_device)
output_ids = model.generate(input_ids, do_sample=True)
output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
EXPECTED_OUTPUT_STR = "Today is a nice day and if you dont get the memo here is what you can"
self.assertEqual(output_str, EXPECTED_OUTPUT_STR)