extract attention weights from GPT

This commit is contained in:
thomwolf 2019-05-02 18:31:26 +02:00
parent db98a4a48b
commit e211785ada

View File

@ -253,7 +253,7 @@ class Conv1D(nn.Module):
class Attention(nn.Module):
def __init__(self, nx, n_ctx, config, scale=False):
def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False):
super(Attention, self).__init__()
n_state = nx # in Attention: n_state=768 (nx=n_embd)
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
@ -262,6 +262,7 @@ class Attention(nn.Module):
self.n_head = config.n_head
self.split_size = n_state
self.scale = scale
self.output_attentions = output_attentions
self.c_attn = Conv1D(n_state * 3, 1, nx)
self.c_proj = Conv1D(n_state, 1, nx)
self.attn_dropout = nn.Dropout(config.attn_pdrop)
@ -278,6 +279,8 @@ class Attention(nn.Module):
w = nn.Softmax(dim=-1)(w)
w = self.attn_dropout(w)
if self.output_attentions:
return w, torch.matmul(w, v)
return torch.matmul(w, v)
def merge_heads(self, x):
@ -300,9 +303,13 @@ class Attention(nn.Module):
key = self.split_heads(key, k=True)
value = self.split_heads(value)
a = self._attn(query, key, value)
if self.output_attentions:
attentions, a = a
a = self.merge_heads(a)
a = self.c_proj(a)
a = self.resid_dropout(a)
if self.output_attentions:
return attentions, a
return a
@ -322,19 +329,24 @@ class MLP(nn.Module):
class Block(nn.Module):
def __init__(self, n_ctx, config, scale=False):
def __init__(self, n_ctx, config, scale=False, output_attentions=False):
super(Block, self).__init__()
nx = config.n_embd
self.attn = Attention(nx, n_ctx, config, scale)
self.output_attentions = output_attentions
self.attn = Attention(nx, n_ctx, config, scale, output_attentions)
self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
self.mlp = MLP(4 * nx, config)
self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
def forward(self, x):
a = self.attn(x)
if self.output_attentions:
attentions, a = a
n = self.ln_1(x + a)
m = self.mlp(n)
h = self.ln_2(n + m)
if self.output_attentions:
return attentions, h
return h
@ -591,12 +603,13 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
```
"""
def __init__(self, config):
def __init__(self, config, output_attentions=False):
super(OpenAIGPTModel, self).__init__(config)
self.output_attentions = output_attentions
self.tokens_embed = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
self.drop = nn.Dropout(config.embd_pdrop)
block = Block(config.n_ctx, config, scale=True)
block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions)
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
self.apply(self.init_weights)
@ -639,9 +652,16 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
# Add the position information to the input embeddings
# h = e.sum(dim=2)
hidden_states = inputs_embeds + position_embeds + token_type_embeds
all_attentions = []
for block in self.h:
hidden_states = block(hidden_states)
if self.output_attentions:
attentions, hidden_states = block(hidden_states)
all_attentions.append(attentions)
else:
hidden_states = block(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)
if self.output_attentions:
return all_attentions, hidden_states.view(*output_shape)
return hidden_states.view(*output_shape)
@ -701,9 +721,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
```
"""
def __init__(self, config):
def __init__(self, config, output_attentions=False):
super(OpenAIGPTLMHeadModel, self).__init__(config)
self.transformer = OpenAIGPTModel(config)
self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions)
self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
self.apply(self.init_weights)
@ -716,6 +736,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None):
hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
if self.transformer.output_attentions:
all_attentions, hidden_states = hidden_states
lm_logits = self.lm_head(hidden_states)
if lm_labels is not None:
# Shift so that tokens < n predict n
@ -726,6 +748,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1))
return loss
if self.transformer.output_attentions:
return all_attentions, lm_logits
return lm_logits
@ -790,9 +814,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
```
"""
def __init__(self, config):
def __init__(self, config, output_attentions=False):
super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
self.transformer = OpenAIGPTModel(config)
self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions)
self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config)
self.apply(self.init_weights)
@ -806,6 +830,8 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None):
hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
if self.transformer.output_attentions:
all_attentions, hidden_states = hidden_states
lm_logits = self.lm_head(hidden_states)
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
losses = []
@ -819,4 +845,6 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
if losses:
return losses
if self.transformer.output_attentions:
return all_attentions, lm_logits, mc_logits
return lm_logits, mc_logits