Refactor GPT2 (#11225)

* refactor GPT2

* fix mlp and head pruning

* address Sylvain's comments

* apply suggestion from code review

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Suraj Patil 2021-04-13 21:15:24 +05:30 committed by GitHub
parent 893e51a53f
commit edca520d0f
2 changed files with 130 additions and 90 deletions

src/transformers/models/gpt2/configuration_gpt2.py

@@ -102,6 +102,8 @@ class GPT2Config(PretrainedConfig):
             and :class:`~transformers.TFGPT2DoubleHeadsModel`.

             The dropout ratio to be used after the projection and activation.
+        scale_attn_weights (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            Scale attention weights by dividing by sqrt(hidden_size).
         gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.
         use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
@@ -144,6 +146,7 @@ class GPT2Config(PretrainedConfig):
         summary_activation=None,
         summary_proj_to_labels=True,
         summary_first_dropout=0.1,
+        scale_attn_weights=True,
         gradient_checkpointing=False,
         use_cache=True,
         bos_token_id=50256,
@@ -171,6 +174,7 @@ class GPT2Config(PretrainedConfig):
         self.summary_first_dropout = summary_first_dropout
         self.summary_proj_to_labels = summary_proj_to_labels
         self.gradient_checkpointing = gradient_checkpointing
+        self.scale_attn_weights = scale_attn_weights
         self.use_cache = use_cache
         self.bos_token_id = bos_token_id
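
The new `scale_attn_weights` flag defaults to `True`, keeping the standard GPT-2 behaviour of dividing attention scores by the square root of the per-head dimension (`value.size(-1)` in `_attn` below); setting it to `False` skips that division. A minimal sketch of toggling it through the config (the small sizes are illustrative, not part of the diff):

from transformers import GPT2Config, GPT2Model

# Default: attention scores are scaled before the softmax.
config = GPT2Config()
assert config.scale_attn_weights is True

# Illustrative: a small, randomly initialized variant with unscaled attention.
unscaled_config = GPT2Config(n_embd=256, n_layer=4, n_head=4, scale_attn_weights=False)
model = GPT2Model(unscaled_config)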

src/transformers/models/gpt2/modeling_gpt2.py

@@ -122,37 +122,47 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
     return model


-class Attention(nn.Module):
-    def __init__(self, nx, n_ctx, config, scale=False, is_cross_attention=False):
+class GPT2Attention(nn.Module):
+    def __init__(self, config, is_cross_attention=False):
         super().__init__()
-        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
-        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
-        assert n_state % config.n_head == 0
+        max_positions = config.max_position_embeddings
         self.register_buffer(
-            "bias", torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx)
+            "bias",
+            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
+                1, 1, max_positions, max_positions
+            ),
         )
         self.register_buffer("masked_bias", torch.tensor(-1e4))
-        self.n_head = config.n_head
-        self.split_size = n_state
-        self.scale = scale
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_heads
+        self.split_size = self.embed_dim
+        if self.head_dim * self.num_heads != self.embed_dim:
+            raise ValueError(
+                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
+            )
+        self.scale_attn_weights = config.scale_attn_weights
         self.is_cross_attention = is_cross_attention
         if self.is_cross_attention:
-            self.c_attn = Conv1D(2 * n_state, nx)
-            self.q_attn = Conv1D(n_state, nx)
+            self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
+            self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
         else:
-            self.c_attn = Conv1D(3 * n_state, nx)
-        self.c_proj = Conv1D(n_state, nx)
+            self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
+        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
         self.pruned_heads = set()

     def prune_heads(self, heads):
         if len(heads) == 0:
             return
-        heads, index = find_pruneable_heads_and_indices(
-            heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
-        )
+        heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
         index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

         # Prune conv1d layers
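
In the refactored `__init__`, the sizes previously passed in as `nx`/`n_ctx` are read from the config, and the old `assert` on head divisibility becomes an explicit `ValueError`. A rough sketch of the shapes involved, using the gpt2-small dimensions purely as an illustrative assumption:

import torch

# Illustrative gpt2-small sizes (config.hidden_size=768, config.num_attention_heads=12).
batch, seq_len, embed_dim, num_heads = 2, 5, 768, 12
head_dim = embed_dim // num_heads  # 64; __init__ raises ValueError if this division is not exact

# c_attn is Conv1D(3 * embed_dim, embed_dim): one projection producing query, key and value.
qkv = torch.randn(batch, seq_len, 3 * embed_dim)
query, key, value = qkv.split(embed_dim, dim=2)  # mirrors forward(): split(self.split_size, dim=2)
print(query.shape)  # torch.Size([2, 5, 768])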
@@ -160,49 +170,52 @@ class Attention(nn.Module):
         self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

         # Update hyper params
-        self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
-        self.n_head = self.n_head - len(heads)
+        self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
+        self.num_heads = self.num_heads - len(heads)
         self.pruned_heads = self.pruned_heads.union(heads)

-    def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False):
-        w = torch.matmul(q, k)
-        if self.scale:
-            w = w / (float(v.size(-1)) ** 0.5)
-        nd, ns = w.size(-2), w.size(-1)
+    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
+        attn_weights = torch.matmul(query, key.transpose(-1, -2))
+        if self.scale_attn_weights:
+            attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)

         if not self.is_cross_attention:
             # if only "normal" attention layer implements causal mask
-            mask = self.bias[:, :, ns - nd : ns, :ns]
-            w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))
+            query_length, key_length = query.size(-2), key.size(-2)
+            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+            attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))

         if attention_mask is not None:
             # Apply the attention mask
-            w = w + attention_mask
+            attn_weights = attn_weights + attention_mask

-        w = nn.Softmax(dim=-1)(w)
-        w = self.attn_dropout(w)
+        attn_weights = nn.Softmax(dim=-1)(attn_weights)
+        attn_weights = self.attn_dropout(attn_weights)

         # Mask heads if we want to
         if head_mask is not None:
-            w = w * head_mask
+            attn_weights = attn_weights * head_mask

-        outputs = (torch.matmul(w, v),)
-        if output_attentions:
-            outputs += (w,)
-        return outputs
+        attn_output = torch.matmul(attn_weights, value)

-    def merge_heads(self, x):
-        x = x.permute(0, 2, 1, 3).contiguous()
-        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
-        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states
+        return attn_output, attn_weights

-    def split_heads(self, x, k=False):
-        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
-        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
-        if k:
-            return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
-        else:
-            return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
+    def _split_heads(self, tensor, num_heads, attn_head_size):
+        """
+        Splits hidden_size dim into attn_head_size and num_heads
+        """
+        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
+        tensor = tensor.view(*new_shape)
+        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
+
+    def _merge_heads(self, tensor, num_heads, attn_head_size):
+        """
+        Merges attn_head_size dim and num_attn_heads dim into hidden_size
+        """
+        tensor = tensor.permute(0, 2, 1, 3).contiguous()
+        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
+        return tensor.view(new_shape)

     def forward(
         self,
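
The `k=True` branch of the old `split_heads` is gone: keys now stay in the same `(batch, head, seq_length, head_features)` layout as queries and values, and `_attn` transposes the key inside the matmul instead. A standalone sketch of the reshape round-trip and the causal-mask slice (all shapes here are illustrative):

import torch

def split_heads(tensor, num_heads, attn_head_size):
    # (batch, seq, hidden) -> (batch, head, seq, head_dim), mirroring GPT2Attention._split_heads.
    new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
    return tensor.view(new_shape).permute(0, 2, 1, 3)

def merge_heads(tensor, num_heads, attn_head_size):
    # Inverse reshape, mirroring GPT2Attention._merge_heads.
    tensor = tensor.permute(0, 2, 1, 3).contiguous()
    return tensor.view(tensor.size()[:-2] + (num_heads * attn_head_size,))

x = torch.randn(2, 5, 768)                      # (batch, seq, hidden)
heads = split_heads(x, 12, 64)                  # (2, 12, 5, 64)
assert torch.equal(merge_heads(heads, 12, 64), x)

# Causal-mask slice used in _attn when a cache is present: only the last
# query_length rows of the triangular mask are needed.
query_length, key_length = 1, 6
bias = torch.tril(torch.ones(1024, 1024, dtype=torch.uint8)).view(1, 1, 1024, 1024)
causal_mask = bias[:, :, key_length - query_length : key_length, :key_length].bool()  # (1, 1, 1, 6)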
@@ -216,65 +229,77 @@ class Attention(nn.Module):
         output_attentions=False,
     ):
         if encoder_hidden_states is not None:
-            assert hasattr(
-                self, "q_attn"
-            ), "If class is used as cross attention, the weights `q_attn` have to be defined. Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`."
+            if not hasattr(self, "q_attn"):
+                raise ValueError(
+                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
+                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
+                )
             query = self.q_attn(hidden_states)
             key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
             attention_mask = encoder_attention_mask
         else:
             query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

-        query = self.split_heads(query)
-        key = self.split_heads(key, k=True)
-        value = self.split_heads(value)
+        query = self._split_heads(query, self.num_heads, self.head_dim)
+        key = self._split_heads(key, self.num_heads, self.head_dim)
+        value = self._split_heads(value, self.num_heads, self.head_dim)

         if layer_past is not None:
-            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
-            key = torch.cat((past_key, key), dim=-1)
+            past_key, past_value = layer_past
+            key = torch.cat((past_key, key), dim=-2)
             value = torch.cat((past_value, value), dim=-2)

         if use_cache is True:
-            present = (key.transpose(-2, -1), value)  # transpose to have same shapes
+            present = (key, value)
         else:
             present = None

-        attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)
-        a = attn_outputs[0]
+        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

-        a = self.merge_heads(a)
-        a = self.c_proj(a)
-        a = self.resid_dropout(a)
+        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+        attn_output = self.c_proj(attn_output)
+        attn_output = self.resid_dropout(attn_output)

-        return (a, present) + attn_outputs[1:]  # a, present, (attentions)
+        outputs = (attn_output, present)
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs  # a, present, (attentions)


-class MLP(nn.Module):
-    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
+class GPT2MLP(nn.Module):
+    def __init__(self, intermediate_size, config):
         super().__init__()
-        nx = config.n_embd
-        self.c_fc = Conv1D(n_state, nx)
-        self.c_proj = Conv1D(nx, n_state)
+        embed_dim = config.hidden_size
+        self.c_fc = Conv1D(intermediate_size, embed_dim)
+        self.c_proj = Conv1D(embed_dim, intermediate_size)
         self.act = ACT2FN[config.activation_function]
         self.dropout = nn.Dropout(config.resid_pdrop)

-    def forward(self, x):
-        h = self.act(self.c_fc(x))
-        h2 = self.c_proj(h)
-        return self.dropout(h2)
+    def forward(self, hidden_states):
+        hidden_states = self.c_fc(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.c_proj(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        return hidden_states


-class Block(nn.Module):
-    def __init__(self, n_ctx, config, scale=False):
+class GPT2Block(nn.Module):
+    def __init__(self, config):
         super().__init__()
-        hidden_size = config.n_embd
+        hidden_size = config.hidden_size
         inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.attn = Attention(hidden_size, n_ctx, config, scale)
+        self.attn = GPT2Attention(config)
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         if config.add_cross_attention:
-            self.crossattention = Attention(hidden_size, n_ctx, config, scale, is_cross_attention=True)
+            self.crossattention = GPT2Attention(config, is_cross_attention=True)
             self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.mlp = MLP(inner_dim, config)
+        self.mlp = GPT2MLP(inner_dim, config)

     def forward(
         self,
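
Because keys are no longer cached in transposed form, `past_key` now shares the `(batch, head, past_length, head_dim)` layout of `past_value`, and both are concatenated along the sequence dimension (`dim=-2`). A sketch of how the cache grows during generation, with purely illustrative shapes:

import torch

batch, num_heads, head_dim = 1, 12, 64

# Cache from previous steps: both tensors use the (batch, head, seq, head_dim) layout.
past_key = torch.randn(batch, num_heads, 6, head_dim)
past_value = torch.randn(batch, num_heads, 6, head_dim)

# The new step processes a single token.
key = torch.randn(batch, num_heads, 1, head_dim)
value = torch.randn(batch, num_heads, 1, head_dim)

key = torch.cat((past_key, key), dim=-2)        # (1, 12, 7, 64)
value = torch.cat((past_value, value), dim=-2)  # (1, 12, 7, 64)
present = (key, value)                          # returned when use_cache is True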
@@ -287,8 +312,10 @@ class Block(nn.Module):
         use_cache=False,
         output_attentions=False,
     ):
+        residual = hidden_states
+        hidden_states = self.ln_1(hidden_states)
         attn_outputs = self.attn(
-            self.ln_1(hidden_states),
+            hidden_states,
             layer_past=layer_past,
             attention_mask=attention_mask,
             head_mask=head_mask,
@@ -298,15 +325,19 @@ class Block(nn.Module):
         attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
         outputs = attn_outputs[1:]
         # residual connection
-        hidden_states = attn_output + hidden_states
+        hidden_states = attn_output + residual

         if encoder_hidden_states is not None:
             # add one self-attention block for cross-attention
-            assert hasattr(
-                self, "crossattention"
-            ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+            if not hasattr(self, "crossattention"):
+                raise ValueError(
+                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
+                    "cross-attention layers by setting `config.add_cross_attention=True`"
+                )
+            residual = hidden_states
+            hidden_states = self.ln_cross_attn(hidden_states)
             cross_attn_outputs = self.crossattention(
-                self.ln_cross_attn(hidden_states),
+                hidden_states,
                 attention_mask=attention_mask,
                 head_mask=head_mask,
                 encoder_hidden_states=encoder_hidden_states,
@@ -315,12 +346,14 @@ class Block(nn.Module):
             )
             attn_output = cross_attn_outputs[0]
             # residual connection
-            hidden_states = hidden_states + attn_output
+            hidden_states = residual + attn_output
             outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights

-        feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states))
+        residual = hidden_states
+        hidden_states = self.ln_2(hidden_states)
+        feed_forward_hidden_states = self.mlp(hidden_states)
         # residual connection
-        hidden_states = hidden_states + feed_forward_hidden_states
+        hidden_states = residual + feed_forward_hidden_states

         if use_cache:
             outputs = (hidden_states,) + outputs
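
The block now saves `residual` explicitly before each LayerNorm instead of nesting the `self.ln_*` calls inside the sub-layer invocation; GPT-2 remains pre-LN, only the dataflow is spelled out. A minimal sketch of the pattern, with `sublayer` standing in for attention or the MLP (names and sizes are illustrative):

import torch
from torch import nn

hidden_size = 768  # illustrative
ln = nn.LayerNorm(hidden_size)
sublayer = nn.Linear(hidden_size, hidden_size)  # stand-in for GPT2Attention / GPT2MLP

def pre_ln_step(hidden_states):
    residual = hidden_states
    hidden_states = ln(hidden_states)       # normalize the sub-layer input
    hidden_states = sublayer(hidden_states)
    return residual + hidden_states          # residual connection around the sub-layer

out = pre_ln_step(torch.randn(2, 5, hidden_size))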
@@ -390,8 +423,8 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
             Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
             sequence_length, sequence_length)`.

-            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
-            heads.
+            GPT2Attentions weights after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
     """

     loss: Optional[torch.FloatTensor] = None
@@ -539,11 +572,14 @@ class GPT2Model(GPT2PreTrainedModel):
     def __init__(self, config):
         super().__init__(config)

-        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
-        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
+        self.embed_dim = config.hidden_size
+
+        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
+        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
+
         self.drop = nn.Dropout(config.embd_pdrop)
-        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
-        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.h = nn.ModuleList([GPT2Block(config) for _ in range(config.num_hidden_layers)])
+        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

         self.init_weights()
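
`GPT2Model.__init__` now reads the generic config aliases (`hidden_size`, `max_position_embeddings`, `num_hidden_layers`) rather than the GPT-2-specific `n_embd`/`n_positions`/`n_layer` names; on `GPT2Config` these resolve to the same values. Since the commit message also mentions fixing head pruning, here is a hedged sketch that exercises that path (the layer and head indices are arbitrary):

from transformers import GPT2Config, GPT2Model

config = GPT2Config(n_embd=256, n_positions=128, n_layer=4, n_head=4)
assert config.hidden_size == 256 and config.num_hidden_layers == 4

model = GPT2Model(config)
# Prune heads 0 and 1 in layer 0 and head 2 in layer 3; routed through GPT2Attention.prune_heads.
model.prune_heads({0: [0, 1], 3: [2]})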
@@ -654,7 +690,7 @@ class GPT2Model(GPT2PreTrainedModel):
             position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
             position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

-        # Attention mask.
+        # GPT2Attention mask.
         if attention_mask is not None:
             assert batch_size > 0, "batch_size has to be defined and > 0"
             attention_mask = attention_mask.view(batch_size, -1)
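
For context on the additive mask that `_attn` receives (`attn_weights = attn_weights + attention_mask`): the model turns a 0/1 padding mask into large negative biases before it reaches the attention layer. A hedged, standalone illustration of that convention, not part of the diff above:

import torch

# A 0/1 padding mask for two sequences; the second one has one padded position.
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]], dtype=torch.float)

# Broadcastable additive form: 0.0 where attending is allowed, a large negative value where it is not.
extended_mask = attention_mask[:, None, None, :]        # (batch, 1, 1, key_length)
extended_mask = (1.0 - extended_mask) * -10000.0

scores = torch.randn(2, 12, 4, 4)                        # illustrative (batch, head, query, key) scores
probs = torch.softmax(scores + extended_mask, dim=-1)    # padded keys get ~zero probability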