From edca520d0fd8f23e6a5dbf98c209f8da0e3e293c Mon Sep 17 00:00:00 2001
From: Suraj Patil
Date: Tue, 13 Apr 2021 21:15:24 +0530
Subject: [PATCH] Refactor GPT2 (#11225)

* refactor GPT2

* fix mlp and head pruning

* address Sylvains comments

* apply suggestion from code review

Co-authored-by: Lysandre Debut
---
 .../models/gpt2/configuration_gpt2.py         |   4 +
 src/transformers/models/gpt2/modeling_gpt2.py | 216 ++++++++++--------
 2 files changed, 130 insertions(+), 90 deletions(-)

diff --git a/src/transformers/models/gpt2/configuration_gpt2.py b/src/transformers/models/gpt2/configuration_gpt2.py
index 5c69e9dfe5f..00d7b88a4ff 100644
--- a/src/transformers/models/gpt2/configuration_gpt2.py
+++ b/src/transformers/models/gpt2/configuration_gpt2.py
@@ -102,6 +102,8 @@ class GPT2Config(PretrainedConfig):
             and :class:`~transformers.TFGPT2DoubleHeadsModel`.

             The dropout ratio to be used after the projection and activation.
+        scale_attn_weights (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            Scale attention weights by dividing by sqrt(hidden_size).
         gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.
         use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
@@ -144,6 +146,7 @@ class GPT2Config(PretrainedConfig):
         summary_activation=None,
         summary_proj_to_labels=True,
         summary_first_dropout=0.1,
+        scale_attn_weights=True,
         gradient_checkpointing=False,
         use_cache=True,
         bos_token_id=50256,
@@ -171,6 +174,7 @@ class GPT2Config(PretrainedConfig):
         self.summary_first_dropout = summary_first_dropout
         self.summary_proj_to_labels = summary_proj_to_labels
         self.gradient_checkpointing = gradient_checkpointing
+        self.scale_attn_weights = scale_attn_weights
         self.use_cache = use_cache

         self.bos_token_id = bos_token_id
diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index a7c4ba1277e..d78e6433050 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -122,37 +122,47 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
     return model


-class Attention(nn.Module):
-    def __init__(self, nx, n_ctx, config, scale=False, is_cross_attention=False):
+class GPT2Attention(nn.Module):
+    def __init__(self, config, is_cross_attention=False):
         super().__init__()
-        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
-        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
-        assert n_state % config.n_head == 0
+
+        max_positions = config.max_position_embeddings
         self.register_buffer(
-            "bias", torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx)
+            "bias",
+            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
+                1, 1, max_positions, max_positions
+            ),
         )
         self.register_buffer("masked_bias", torch.tensor(-1e4))
-        self.n_head = config.n_head
-        self.split_size = n_state
-        self.scale = scale
+
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_heads
+        self.split_size = self.embed_dim
+        if self.head_dim * self.num_heads != self.embed_dim:
+            raise ValueError(
+                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
+            )
+
+        self.scale_attn_weights = config.scale_attn_weights
         self.is_cross_attention = is_cross_attention
+
         if self.is_cross_attention:
-            self.c_attn = Conv1D(2 * n_state, nx)
-            self.q_attn = Conv1D(n_state, nx)
+            self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
+            self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
         else:
-            self.c_attn = Conv1D(3 * n_state, nx)
-        self.c_proj = Conv1D(n_state, nx)
+            self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
+        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
+
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
+
         self.pruned_heads = set()

     def prune_heads(self, heads):
         if len(heads) == 0:
             return
-        heads, index = find_pruneable_heads_and_indices(
-            heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
-        )
+        heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
         index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

         # Prune conv1d layers
@@ -160,49 +170,52 @@ class Attention(nn.Module):
         self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

         # Update hyper params
-        self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
-        self.n_head = self.n_head - len(heads)
+        self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
+        self.num_heads = self.num_heads - len(heads)
         self.pruned_heads = self.pruned_heads.union(heads)

-    def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False):
-        w = torch.matmul(q, k)
-        if self.scale:
-            w = w / (float(v.size(-1)) ** 0.5)
-        nd, ns = w.size(-2), w.size(-1)
+    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
+        attn_weights = torch.matmul(query, key.transpose(-1, -2))
+
+        if self.scale_attn_weights:
+            attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)

         if not self.is_cross_attention:
             # if only "normal" attention layer implements causal mask
-            mask = self.bias[:, :, ns - nd : ns, :ns]
-            w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))
+            query_length, key_length = query.size(-2), key.size(-2)
+            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+            attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))

         if attention_mask is not None:
             # Apply the attention mask
-            w = w + attention_mask
+            attn_weights = attn_weights + attention_mask

-        w = nn.Softmax(dim=-1)(w)
-        w = self.attn_dropout(w)
+        attn_weights = nn.Softmax(dim=-1)(attn_weights)
+        attn_weights = self.attn_dropout(attn_weights)

         # Mask heads if we want to
         if head_mask is not None:
-            w = w * head_mask
+            attn_weights = attn_weights * head_mask

-        outputs = (torch.matmul(w, v),)
-        if output_attentions:
-            outputs += (w,)
-        return outputs
+        attn_output = torch.matmul(attn_weights, value)

-    def merge_heads(self, x):
-        x = x.permute(0, 2, 1, 3).contiguous()
-        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
-        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states
+        return attn_output, attn_weights

-    def split_heads(self, x, k=False):
-        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
-        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
-        if k:
-            return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
-        else:
-            return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
+    def _split_heads(self, tensor, num_heads, attn_head_size):
+        """
+        Splits hidden_size dim into attn_head_size and num_heads
+        """
+        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
+        tensor = tensor.view(*new_shape)
+        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
+
+    def _merge_heads(self, tensor, num_heads, attn_head_size):
+        """
+        Merges attn_head_size dim and num_attn_heads dim into hidden_size
+        """
+        tensor = tensor.permute(0, 2, 1, 3).contiguous()
+        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
+        return tensor.view(new_shape)

     def forward(
         self,
@@ -216,65 +229,77 @@ class Attention(nn.Module):
         output_attentions=False,
     ):
         if encoder_hidden_states is not None:
-            assert hasattr(
-                self, "q_attn"
-            ), "If class is used as cross attention, the weights `q_attn` have to be defined. Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`."
+            if not hasattr(self, "q_attn"):
+                raise ValueError(
+                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
+                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
+                )
+
             query = self.q_attn(hidden_states)
             key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
             attention_mask = encoder_attention_mask
         else:
             query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

-        query = self.split_heads(query)
-        key = self.split_heads(key, k=True)
-        value = self.split_heads(value)
+        query = self._split_heads(query, self.num_heads, self.head_dim)
+        key = self._split_heads(key, self.num_heads, self.head_dim)
+        value = self._split_heads(value, self.num_heads, self.head_dim)
+
         if layer_past is not None:
-            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
-            key = torch.cat((past_key, key), dim=-1)
+            past_key, past_value = layer_past
+            key = torch.cat((past_key, key), dim=-2)
             value = torch.cat((past_value, value), dim=-2)

         if use_cache is True:
-            present = (key.transpose(-2, -1), value)  # transpose to have same shapes
+            present = (key, value)
         else:
             present = None

-        attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)
-        a = attn_outputs[0]
+        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

-        a = self.merge_heads(a)
-        a = self.c_proj(a)
-        a = self.resid_dropout(a)
+        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+        attn_output = self.c_proj(attn_output)
+        attn_output = self.resid_dropout(attn_output)

-        return (a, present) + attn_outputs[1:]  # a, present, (attentions)
+        outputs = (attn_output, present)
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs  # a, present, (attentions)


-class MLP(nn.Module):
-    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
+class GPT2MLP(nn.Module):
+    def __init__(self, intermediate_size, config):
         super().__init__()
-        nx = config.n_embd
-        self.c_fc = Conv1D(n_state, nx)
-        self.c_proj = Conv1D(nx, n_state)
+        embed_dim = config.hidden_size
+        self.c_fc = Conv1D(intermediate_size, embed_dim)
+        self.c_proj = Conv1D(embed_dim, intermediate_size)
         self.act = ACT2FN[config.activation_function]
         self.dropout = nn.Dropout(config.resid_pdrop)

-    def forward(self, x):
-        h = self.act(self.c_fc(x))
-        h2 = self.c_proj(h)
-        return self.dropout(h2)
+    def forward(self, hidden_states):
+        hidden_states = self.c_fc(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.c_proj(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        return hidden_states


-class Block(nn.Module):
-    def __init__(self, n_ctx, config, scale=False):
+class GPT2Block(nn.Module):
+    def __init__(self, config):
         super().__init__()
-        hidden_size = config.n_embd
+        hidden_size = config.hidden_size
         inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
+
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.attn = Attention(hidden_size, n_ctx, config, scale)
+        self.attn = GPT2Attention(config)
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
         if config.add_cross_attention:
-            self.crossattention = Attention(hidden_size, n_ctx, config, scale, is_cross_attention=True)
+            self.crossattention = GPT2Attention(config, is_cross_attention=True)
             self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.mlp = MLP(inner_dim, config)
+
+        self.mlp = GPT2MLP(inner_dim, config)

     def forward(
         self,
@@ -287,8 +312,10 @@ class Block(nn.Module):
         use_cache=False,
         output_attentions=False,
     ):
+        residual = hidden_states
+        hidden_states = self.ln_1(hidden_states)
         attn_outputs = self.attn(
-            self.ln_1(hidden_states),
+            hidden_states,
             layer_past=layer_past,
             attention_mask=attention_mask,
             head_mask=head_mask,
@@ -298,15 +325,19 @@ class Block(nn.Module):
         attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
         outputs = attn_outputs[1:]
         # residual connection
-        hidden_states = attn_output + hidden_states
+        hidden_states = attn_output + residual

         if encoder_hidden_states is not None:
             # add one self-attention block for cross-attention
-            assert hasattr(
-                self, "crossattention"
-            ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+            if not hasattr(self, "crossattention"):
+                raise ValueError(
+                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
+                    "cross-attention layers by setting `config.add_cross_attention=True`"
+                )
+            residual = hidden_states
+            hidden_states = self.ln_cross_attn(hidden_states)
             cross_attn_outputs = self.crossattention(
-                self.ln_cross_attn(hidden_states),
+                hidden_states,
                 attention_mask=attention_mask,
                 head_mask=head_mask,
                 encoder_hidden_states=encoder_hidden_states,
@@ -315,12 +346,14 @@ class Block(nn.Module):
             )
             attn_output = cross_attn_outputs[0]
             # residual connection
-            hidden_states = hidden_states + attn_output
+            hidden_states = residual + attn_output
             outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights

-        feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states))
+        residual = hidden_states
+        hidden_states = self.ln_2(hidden_states)
+        feed_forward_hidden_states = self.mlp(hidden_states)
         # residual connection
-        hidden_states = hidden_states + feed_forward_hidden_states
+        hidden_states = residual + feed_forward_hidden_states

         if use_cache:
             outputs = (hidden_states,) + outputs
@@ -390,8 +423,8 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
             Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
             sequence_length, sequence_length)`.

-            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
-            heads.
+            GPT2Attentions weights after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
     """

     loss: Optional[torch.FloatTensor] = None
@@ -539,11 +572,14 @@ class GPT2Model(GPT2PreTrainedModel):

     def __init__(self, config):
         super().__init__(config)
-        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
-        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
+        self.embed_dim = config.hidden_size
+
+        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
+        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
+
         self.drop = nn.Dropout(config.embd_pdrop)
-        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
-        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+        self.h = nn.ModuleList([GPT2Block(config) for _ in range(config.num_hidden_layers)])
+        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

         self.init_weights()

@@ -654,7 +690,7 @@ class GPT2Model(GPT2PreTrainedModel):
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

-        # Attention mask.
+        # GPT2Attention mask.
         if attention_mask is not None:
             assert batch_size > 0, "batch_size has to be defined and > 0"
             attention_mask = attention_mask.view(batch_size, -1)