mirror of https://github.com/huggingface/transformers.git
extract attention weights from GPT
parent db98a4a48b
commit e211785ada
@@ -253,7 +253,7 @@ class Conv1D(nn.Module):
 
 
 class Attention(nn.Module):
-    def __init__(self, nx, n_ctx, config, scale=False):
+    def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False):
         super(Attention, self).__init__()
         n_state = nx  # in Attention: n_state=768 (nx=n_embd)
         # [switch nx => n_state from Block to Attention to keep identical to TF implem]
@@ -262,6 +262,7 @@ class Attention(nn.Module):
         self.n_head = config.n_head
         self.split_size = n_state
         self.scale = scale
+        self.output_attentions = output_attentions
         self.c_attn = Conv1D(n_state * 3, 1, nx)
         self.c_proj = Conv1D(n_state, 1, nx)
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
@@ -278,6 +279,8 @@ class Attention(nn.Module):
 
         w = nn.Softmax(dim=-1)(w)
         w = self.attn_dropout(w)
+        if self.output_attentions:
+            return w, torch.matmul(w, v)
         return torch.matmul(w, v)
 
     def merge_heads(self, x):
@@ -300,9 +303,13 @@ class Attention(nn.Module):
         key = self.split_heads(key, k=True)
         value = self.split_heads(value)
         a = self._attn(query, key, value)
+        if self.output_attentions:
+            attentions, a = a
         a = self.merge_heads(a)
         a = self.c_proj(a)
         a = self.resid_dropout(a)
+        if self.output_attentions:
+            return attentions, a
         return a
 
 
@@ -322,19 +329,24 @@ class MLP(nn.Module):
 
 
 class Block(nn.Module):
-    def __init__(self, n_ctx, config, scale=False):
+    def __init__(self, n_ctx, config, scale=False, output_attentions=False):
         super(Block, self).__init__()
         nx = config.n_embd
-        self.attn = Attention(nx, n_ctx, config, scale)
+        self.output_attentions = output_attentions
+        self.attn = Attention(nx, n_ctx, config, scale, output_attentions)
         self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
         self.mlp = MLP(4 * nx, config)
         self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
 
     def forward(self, x):
         a = self.attn(x)
+        if self.output_attentions:
+            attentions, a = a
         n = self.ln_1(x + a)
         m = self.mlp(n)
         h = self.ln_2(n + m)
+        if self.output_attentions:
+            return attentions, h
         return h
 
 
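The hunks above make `Attention._attn`, `Attention.forward`, and `Block.forward` return an `(attentions, output)` tuple instead of a single tensor when the module is built with `output_attentions=True`. A minimal sketch of the resulting calling pattern at the block level, with randomly initialized weights; it is illustrative only (not part of this commit) and assumes `OpenAIGPTConfig` and `Block` are importable from `pytorch_pretrained_bert.modeling_openai` with the standard GPT defaults (n_head=12, n_embd=768):

```python
# Illustrative sketch, not part of the commit: a Block built with
# output_attentions=True now returns (attentions, hidden_states).
import torch
from pytorch_pretrained_bert.modeling_openai import OpenAIGPTConfig, Block

config = OpenAIGPTConfig()  # assumed default GPT hyperparameters
block = Block(config.n_ctx, config, scale=True, output_attentions=True)
block.eval()

x = torch.zeros(1, 8, config.n_embd)  # dummy hidden states, shape (batch, seq_len, n_embd)
out = block(x)
if block.output_attentions:
    attentions, h = out   # attentions: (batch, n_head, seq_len, seq_len)
else:
    h = out

# With the default config this prints (1, 12, 8, 8) and (1, 8, 768).
print(attentions.shape, h.shape)
```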
@@ -591,12 +603,13 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
     ```
     """
 
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(OpenAIGPTModel, self).__init__(config)
+        self.output_attentions = output_attentions
         self.tokens_embed = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
         self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
         self.drop = nn.Dropout(config.embd_pdrop)
-        block = Block(config.n_ctx, config, scale=True)
+        block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions)
         self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
 
         self.apply(self.init_weights)
@@ -639,9 +652,16 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         # Add the position information to the input embeddings
         # h = e.sum(dim=2)
         hidden_states = inputs_embeds + position_embeds + token_type_embeds
+        all_attentions = []
         for block in self.h:
-            hidden_states = block(hidden_states)
+            if self.output_attentions:
+                attentions, hidden_states = block(hidden_states)
+                all_attentions.append(attentions)
+            else:
+                hidden_states = block(hidden_states)
         output_shape = input_shape + (hidden_states.size(-1),)
+        if self.output_attentions:
+            return all_attentions, hidden_states.view(*output_shape)
         return hidden_states.view(*output_shape)
 
 
@@ -701,9 +721,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
     ```
     """
 
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(OpenAIGPTLMHeadModel, self).__init__(config)
-        self.transformer = OpenAIGPTModel(config)
+        self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions)
         self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
         self.apply(self.init_weights)
 
@@ -716,6 +736,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
 
     def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None):
         hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
+        if self.transformer.output_attentions:
+            all_attentions, hidden_states = hidden_states
         lm_logits = self.lm_head(hidden_states)
         if lm_labels is not None:
             # Shift so that tokens < n predict n
@@ -726,6 +748,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))
             return loss
+        if self.transformer.output_attentions:
+            return all_attentions, lm_logits
         return lm_logits
 
 
@@ -790,9 +814,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
     ```
     """
 
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
-        self.transformer = OpenAIGPTModel(config)
+        self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions)
         self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
         self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config)
         self.apply(self.init_weights)
@@ -806,6 +830,8 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
 
     def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None):
         hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
+        if self.transformer.output_attentions:
+            all_attentions, hidden_states = hidden_states
         lm_logits = self.lm_head(hidden_states)
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
         losses = []
@@ -819,4 +845,6 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
             losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
         if losses:
             return losses
+        if self.transformer.output_attentions:
+            return all_attentions, lm_logits, mc_logits
         return lm_logits, mc_logits
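With the model-level hunks above applied, `OpenAIGPTModel`, `OpenAIGPTLMHeadModel`, and `OpenAIGPTDoubleHeadsModel` accept the same flag and, when it is set and no labels are passed, prepend the list of per-layer attention tensors to their usual return value. A minimal end-to-end sketch with randomly initialized weights (illustrative only, not part of the commit); it builds the config directly rather than going through `from_pretrained`, and assumes `OpenAIGPTConfig()` takes no required arguments and exposes `vocab_size`:

```python
# Illustrative sketch, not part of the commit: collecting per-layer attention
# maps from the LM-head model via the new output_attentions flag.
import torch
from pytorch_pretrained_bert.modeling_openai import OpenAIGPTConfig, OpenAIGPTLMHeadModel

config = OpenAIGPTConfig()  # assumed default GPT hyperparameters, random weights
model = OpenAIGPTLMHeadModel(config, output_attentions=True)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 16))  # dummy token ids, (batch, seq_len)
with torch.no_grad():
    # No lm_labels, so the forward pass returns (all_attentions, lm_logits).
    all_attentions, lm_logits = model(input_ids)

print(len(all_attentions))      # one attention tensor per layer (config.n_layer)
print(all_attentions[0].shape)  # (batch, n_head, seq_len, seq_len)
print(lm_logits.shape)          # (batch, seq_len, total_tokens_embeddings)
```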