Refactor GPT2 (#11225)
* refactor GPT2
* fix mlp and head pruning
* address Sylvains comments
* apply suggestion from code review

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Parent: 893e51a53f
Commit: edca520d0f
@@ -102,6 +102,8 @@ class GPT2Config(PretrainedConfig):
             and :class:`~transformers.TFGPT2DoubleHeadsModel`.
             The dropout ratio to be used after the projection and activation.
+        scale_attn_weights (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            Scale attention weights by dividing by sqrt(hidden_size).
         gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.
         use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
@@ -144,6 +146,7 @@ class GPT2Config(PretrainedConfig):
         summary_activation=None,
         summary_proj_to_labels=True,
         summary_first_dropout=0.1,
+        scale_attn_weights=True,
         gradient_checkpointing=False,
         use_cache=True,
         bos_token_id=50256,
@@ -171,6 +174,7 @@ class GPT2Config(PretrainedConfig):
         self.summary_first_dropout = summary_first_dropout
         self.summary_proj_to_labels = summary_proj_to_labels
         self.gradient_checkpointing = gradient_checkpointing
+        self.scale_attn_weights = scale_attn_weights
         self.use_cache = use_cache
 
         self.bos_token_id = bos_token_id
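For context, here is a minimal sketch (not part of the diff) of how the new `scale_attn_weights` flag added in the hunks above can be set on a config; the values shown are simply the defaults from this commit, and the checkpoint-free instantiation is illustrative only.

from transformers import GPT2Config, GPT2Model

# Illustrative only: build a config with the flag introduced in this commit.
# scale_attn_weights=True (the default) divides the raw attention scores by the
# square root of the value dimension (see the new `_attn` further down);
# setting it to False keeps the scores unscaled.
config = GPT2Config(scale_attn_weights=True, use_cache=True)
model = GPT2Model(config)  # GPT2Attention reads config.scale_attn_weights in __init__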
@@ -122,37 +122,47 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
     return model
 
 
-class Attention(nn.Module):
-    def __init__(self, nx, n_ctx, config, scale=False, is_cross_attention=False):
+class GPT2Attention(nn.Module):
+    def __init__(self, config, is_cross_attention=False):
         super().__init__()
 
-        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
-        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
-        assert n_state % config.n_head == 0
+        max_positions = config.max_position_embeddings
         self.register_buffer(
-            "bias", torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx)
+            "bias",
+            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
+                1, 1, max_positions, max_positions
+            ),
         )
         self.register_buffer("masked_bias", torch.tensor(-1e4))
-        self.n_head = config.n_head
-        self.split_size = n_state
-        self.scale = scale
+
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_heads
+        self.split_size = self.embed_dim
+        if self.head_dim * self.num_heads != self.embed_dim:
+            raise ValueError(
+                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
+            )
+
+        self.scale_attn_weights = config.scale_attn_weights
         self.is_cross_attention = is_cross_attention
+
         if self.is_cross_attention:
-            self.c_attn = Conv1D(2 * n_state, nx)
-            self.q_attn = Conv1D(n_state, nx)
+            self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
+            self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
         else:
-            self.c_attn = Conv1D(3 * n_state, nx)
-        self.c_proj = Conv1D(n_state, nx)
+            self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
+        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
+
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
+
         self.pruned_heads = set()
 
     def prune_heads(self, heads):
         if len(heads) == 0:
             return
-        heads, index = find_pruneable_heads_and_indices(
-            heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
-        )
+        heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
         index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
 
         # Prune conv1d layers
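As a usage note (mine, not part of the commit): head pruning on the refactored attention still goes through the standard `prune_heads` API of `PreTrainedModel`, which dispatches to `GPT2Attention.prune_heads` above. A minimal sketch, assuming the usual `gpt2` checkpoint is available:

from transformers import GPT2Model

model = GPT2Model.from_pretrained("gpt2")

# Prune heads 0 and 2 in layer 0 and head 1 in layer 5.
# Internally this now uses num_heads / head_dim instead of the old
# n_head / (split_size // n_head) bookkeeping.
model.prune_heads({0: [0, 2], 5: [1]})
print(model.h[0].attn.num_heads)  # 12 - 2 = 10 heads remain in layer 0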
@@ -160,49 +170,52 @@ class Attention(nn.Module):
         self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
 
         # Update hyper params
-        self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
-        self.n_head = self.n_head - len(heads)
+        self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
+        self.num_heads = self.num_heads - len(heads)
         self.pruned_heads = self.pruned_heads.union(heads)
 
-    def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False):
-        w = torch.matmul(q, k)
-        if self.scale:
-            w = w / (float(v.size(-1)) ** 0.5)
-        nd, ns = w.size(-2), w.size(-1)
+    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
+        attn_weights = torch.matmul(query, key.transpose(-1, -2))
+
+        if self.scale_attn_weights:
+            attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)
 
         if not self.is_cross_attention:
             # if only "normal" attention layer implements causal mask
-            mask = self.bias[:, :, ns - nd : ns, :ns]
-            w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))
+            query_length, key_length = query.size(-2), key.size(-2)
+            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+            attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
 
         if attention_mask is not None:
             # Apply the attention mask
-            w = w + attention_mask
+            attn_weights = attn_weights + attention_mask
 
-        w = nn.Softmax(dim=-1)(w)
-        w = self.attn_dropout(w)
+        attn_weights = nn.Softmax(dim=-1)(attn_weights)
+        attn_weights = self.attn_dropout(attn_weights)
 
         # Mask heads if we want to
         if head_mask is not None:
-            w = w * head_mask
+            attn_weights = attn_weights * head_mask
 
-        outputs = (torch.matmul(w, v),)
-        if output_attentions:
-            outputs += (w,)
-        return outputs
+        attn_output = torch.matmul(attn_weights, value)
 
-    def merge_heads(self, x):
-        x = x.permute(0, 2, 1, 3).contiguous()
-        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
-        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states
+        return attn_output, attn_weights
 
-    def split_heads(self, x, k=False):
-        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
-        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
-        if k:
-            return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
-        else:
-            return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
+    def _split_heads(self, tensor, num_heads, attn_head_size):
+        """
+        Splits hidden_size dim into attn_head_size and num_heads
+        """
+        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
+        tensor = tensor.view(*new_shape)
+        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
+
+    def _merge_heads(self, tensor, num_heads, attn_head_size):
+        """
+        Merges attn_head_size dim and num_attn_heads dim into hidden_size
+        """
+        tensor = tensor.permute(0, 2, 1, 3).contiguous()
+        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
+        return tensor.view(new_shape)
 
     def forward(
         self,
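To make the new `_attn` path above concrete, here is a small self-contained sketch (not from the diff) of the same computation on dummy tensors; the shapes and the -1e4 masking value mirror the code, but the tensors themselves are made up.

import torch

batch, num_heads, seq_len, head_dim = 2, 12, 5, 64
query = torch.randn(batch, num_heads, seq_len, head_dim)
key = torch.randn(batch, num_heads, seq_len, head_dim)
value = torch.randn(batch, num_heads, seq_len, head_dim)

# Raw scores, scaled by sqrt(head_dim) as with scale_attn_weights=True.
attn_weights = torch.matmul(query, key.transpose(-1, -2)) / (head_dim ** 0.5)

# Causal mask: position i may only attend to positions <= i.
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
attn_weights = torch.where(causal_mask, attn_weights, torch.tensor(-1e4))

attn_weights = attn_weights.softmax(dim=-1)
attn_output = torch.matmul(attn_weights, value)  # (batch, num_heads, seq_len, head_dim)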
@@ -216,65 +229,77 @@ class Attention(nn.Module):
         output_attentions=False,
     ):
         if encoder_hidden_states is not None:
-            assert hasattr(
-                self, "q_attn"
-            ), "If class is used as cross attention, the weights `q_attn` have to be defined. Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`."
+            if not hasattr(self, "q_attn"):
+                raise ValueError(
+                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
+                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
+                )
+
             query = self.q_attn(hidden_states)
             key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
             attention_mask = encoder_attention_mask
         else:
             query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
 
-        query = self.split_heads(query)
-        key = self.split_heads(key, k=True)
-        value = self.split_heads(value)
+        query = self._split_heads(query, self.num_heads, self.head_dim)
+        key = self._split_heads(key, self.num_heads, self.head_dim)
+        value = self._split_heads(value, self.num_heads, self.head_dim)
 
         if layer_past is not None:
-            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
-            key = torch.cat((past_key, key), dim=-1)
+            past_key, past_value = layer_past
+            key = torch.cat((past_key, key), dim=-2)
             value = torch.cat((past_value, value), dim=-2)
 
         if use_cache is True:
-            present = (key.transpose(-2, -1), value)  # transpose to have same shapes
+            present = (key, value)
         else:
            present = None
 
-        attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)
-        a = attn_outputs[0]
+        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
 
-        a = self.merge_heads(a)
-        a = self.c_proj(a)
-        a = self.resid_dropout(a)
+        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+        attn_output = self.c_proj(attn_output)
+        attn_output = self.resid_dropout(attn_output)
 
-        return (a, present) + attn_outputs[1:]  # a, present, (attentions)
+        outputs = (attn_output, present)
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs  # a, present, (attentions)
 
 
-class MLP(nn.Module):
-    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
+class GPT2MLP(nn.Module):
+    def __init__(self, intermediate_size, config):
         super().__init__()
-        nx = config.n_embd
-        self.c_fc = Conv1D(n_state, nx)
-        self.c_proj = Conv1D(nx, n_state)
+        embed_dim = config.hidden_size
+        self.c_fc = Conv1D(intermediate_size, embed_dim)
+        self.c_proj = Conv1D(embed_dim, intermediate_size)
         self.act = ACT2FN[config.activation_function]
         self.dropout = nn.Dropout(config.resid_pdrop)
 
-    def forward(self, x):
-        h = self.act(self.c_fc(x))
-        h2 = self.c_proj(h)
-        return self.dropout(h2)
+    def forward(self, hidden_states):
+        hidden_states = self.c_fc(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.c_proj(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        return hidden_states
 
 
-class Block(nn.Module):
-    def __init__(self, n_ctx, config, scale=False):
+class GPT2Block(nn.Module):
+    def __init__(self, config):
         super().__init__()
-        hidden_size = config.n_embd
+        hidden_size = config.hidden_size
         inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
+
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.attn = Attention(hidden_size, n_ctx, config, scale)
+        self.attn = GPT2Attention(config)
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
         if config.add_cross_attention:
-            self.crossattention = Attention(hidden_size, n_ctx, config, scale, is_cross_attention=True)
+            self.crossattention = GPT2Attention(config, is_cross_attention=True)
             self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.mlp = MLP(inner_dim, config)
+
+        self.mlp = GPT2MLP(inner_dim, config)
 
     def forward(
         self,
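One behavioral detail worth calling out (my summary, not text from the commit): the cached `present` entry changes from a transposed key of shape `(batch, num_heads, head_dim, seq_len)` to a plain `(key, value)` pair, both `(batch, num_heads, seq_len, head_dim)`, so past keys are now concatenated along `dim=-2` just like the values. A small sketch with dummy tensors:

import torch

batch, num_heads, head_dim = 1, 12, 64

# Cached keys/values for 4 previously generated tokens (new layout: seq_len on dim -2).
past_key = torch.randn(batch, num_heads, 4, head_dim)
past_value = torch.randn(batch, num_heads, 4, head_dim)

# Key/value for the single new token being decoded.
key = torch.randn(batch, num_heads, 1, head_dim)
value = torch.randn(batch, num_heads, 1, head_dim)

key = torch.cat((past_key, key), dim=-2)        # (1, 12, 5, 64)
value = torch.cat((past_value, value), dim=-2)  # (1, 12, 5, 64)
present = (key, value)  # stored as-is, with no transpose round-trip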
@@ -287,8 +312,10 @@ class Block(nn.Module):
         use_cache=False,
         output_attentions=False,
     ):
+        residual = hidden_states
+        hidden_states = self.ln_1(hidden_states)
         attn_outputs = self.attn(
-            self.ln_1(hidden_states),
+            hidden_states,
             layer_past=layer_past,
             attention_mask=attention_mask,
             head_mask=head_mask,
@@ -298,15 +325,19 @@ class Block(nn.Module):
         attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
         outputs = attn_outputs[1:]
         # residual connection
-        hidden_states = attn_output + hidden_states
+        hidden_states = attn_output + residual
 
         if encoder_hidden_states is not None:
             # add one self-attention block for cross-attention
-            assert hasattr(
-                self, "crossattention"
-            ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+            if not hasattr(self, "crossattention"):
+                raise ValueError(
+                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
+                    "cross-attention layers by setting `config.add_cross_attention=True`"
+                )
+            residual = hidden_states
+            hidden_states = self.ln_cross_attn(hidden_states)
             cross_attn_outputs = self.crossattention(
-                self.ln_cross_attn(hidden_states),
+                hidden_states,
                 attention_mask=attention_mask,
                 head_mask=head_mask,
                 encoder_hidden_states=encoder_hidden_states,
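For readers unfamiliar with the cross-attention branch touched above, here is a hedged usage sketch (not part of the diff): a GPT2 decoder configured with `add_cross_attention=True` receives `encoder_hidden_states` in `forward`, which is what routes execution through `self.crossattention`. The sequence lengths and the fake encoder output below are arbitrary.

import torch
from transformers import GPT2Config, GPT2Model

config = GPT2Config(add_cross_attention=True)
decoder = GPT2Model(config)  # randomly initialized, illustration only

input_ids = torch.tensor([[50256, 318, 257]])                   # (batch=1, seq_len=3)
encoder_hidden_states = torch.randn(1, 7, config.hidden_size)   # stand-in encoder output

outputs = decoder(input_ids=input_ids, encoder_hidden_states=encoder_hidden_states)
print(outputs.last_hidden_state.shape)  # torch.Size([1, 3, 768])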
@@ -315,12 +346,14 @@ class Block(nn.Module):
             )
             attn_output = cross_attn_outputs[0]
             # residual connection
-            hidden_states = hidden_states + attn_output
+            hidden_states = residual + attn_output
             outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights
 
-        feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states))
+        residual = hidden_states
+        hidden_states = self.ln_2(hidden_states)
+        feed_forward_hidden_states = self.mlp(hidden_states)
         # residual connection
-        hidden_states = hidden_states + feed_forward_hidden_states
+        hidden_states = residual + feed_forward_hidden_states
 
         if use_cache:
             outputs = (hidden_states,) + outputs
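The block-forward hunks above are a mechanical rewrite of the same pre-LayerNorm residual pattern: save the input, normalize, run the sub-layer, then add the saved input back. A minimal standalone sketch of that ordering (illustrative, using a toy linear layer instead of the real attention/MLP modules):

import torch
from torch import nn

hidden_size = 768
ln = nn.LayerNorm(hidden_size)
sublayer = nn.Linear(hidden_size, hidden_size)  # stand-in for attn or mlp

hidden_states = torch.randn(1, 3, hidden_size)

residual = hidden_states                   # keep the un-normalized input
hidden_states = ln(hidden_states)          # pre-LayerNorm
hidden_states = sublayer(hidden_states)    # sub-layer on the normalized input
hidden_states = residual + hidden_states   # residual connection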
@@ -390,8 +423,8 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.
 
-            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
-            heads.
+            GPT2Attentions weights after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
     """
 
     loss: Optional[torch.FloatTensor] = None
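A quick sketch of what that docstring describes (mine, not from the diff): requesting `output_attentions=True` returns one attention tensor per layer with the shape given above. The checkpoint name and input ids are illustrative.

import torch
from transformers import GPT2Model

model = GPT2Model.from_pretrained("gpt2")
input_ids = torch.tensor([[15496, 995]])  # (batch_size=1, sequence_length=2)

outputs = model(input_ids, output_attentions=True)
print(len(outputs.attentions))      # 12, one tensor per layer
print(outputs.attentions[0].shape)  # torch.Size([1, 12, 2, 2]) = (batch, num_heads, seq, seq)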
@@ -539,11 +572,14 @@ class GPT2Model(GPT2PreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
 
-        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
-        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
+        self.embed_dim = config.hidden_size
+
+        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
+        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
+
         self.drop = nn.Dropout(config.embd_pdrop)
-        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
-        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.h = nn.ModuleList([GPT2Block(config) for _ in range(config.num_hidden_layers)])
+        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
         self.init_weights()
 
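Editorial note: the refactor consistently swaps the GPT-2 specific config names for the generic aliases (`hidden_size`, `num_attention_heads`, `max_position_embeddings`, `num_hidden_layers`) that GPT2Config exposes on top of `n_embd`, `n_head`, `n_positions`, and `n_layer`. A hedged sketch showing that both spellings read the same underlying values:

from transformers import GPT2Config

config = GPT2Config()  # defaults: n_embd=768, n_head=12, n_positions=1024, n_layer=12

# The generic names used throughout the refactored modeling code resolve to the
# GPT-2 specific attributes, so either spelling can be used interchangeably.
assert config.hidden_size == config.n_embd == 768
assert config.num_attention_heads == config.n_head == 12
assert config.max_position_embeddings == config.n_positions == 1024
assert config.num_hidden_layers == config.n_layer == 12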
@@ -654,7 +690,7 @@ class GPT2Model(GPT2PreTrainedModel):
             position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
             position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
 
-        # Attention mask.
+        # GPT2Attention mask.
         if attention_mask is not None:
             assert batch_size > 0, "batch_size has to be defined and > 0"
             attention_mask = attention_mask.view(batch_size, -1)