# transformers/pytorch_pretrained_bert/modeling_openai.py
# 2019-01-07 12:55:36 +01:00
#
# 303 lines
# 10 KiB
# Python

import collections
import collections.abc
import copy
import json
import math
import re

import numpy as np
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter

from .modeling import BertLayerNorm as LayerNorm
def gelu(x):
    """Tanh approximation of the Gaussian Error Linear Unit, applied elementwise."""
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


def swish(x):
    """Swish activation ``x * sigmoid(x)``, applied elementwise."""
    return x * torch.sigmoid(x)


# Activation functions selectable through ``cfg.afn`` (looked up by MLP).
# Every entry must be a callable that maps a tensor to a tensor. The Module
# class ``nn.ReLU`` does not qualify: calling it on a tensor constructs a new
# module instead of applying ReLU. The functional form is used instead.
ACT_FNS = {
    'relu': nn.functional.relu,
    'swish': swish,
    'gelu': gelu,
}
class Conv1D(nn.Module):
    """Affine map over the last axis, implemented as a "1x1 conv" (``x @ w + b``).

    Args:
        nf: number of output features.
        rf: receptive field; only ``rf == 1`` is implemented, any other value
            raises ``NotImplementedError``.
        nx: number of input features.

    The weight ``w`` has shape ``(nx, nf)`` and the bias ``b`` shape ``(nf,)``.
    ``w`` is drawn from N(0, 0.02^2) and ``b`` starts at zero.
    """

    def __init__(self, nf, rf, nx):
        super(Conv1D, self).__init__()
        self.rf = rf
        self.nf = nf
        if rf == 1:  # faster 1x1 conv
            w = torch.empty(nx, nf)
            nn.init.normal_(w, std=0.02)
            self.w = Parameter(w)
            self.b = Parameter(torch.zeros(nf))
        else:  # was used to train LM
            raise NotImplementedError

    def forward(self, x):
        """Return ``x @ w + b`` with shape ``x.shape[:-1] + (nf,)``."""
        if self.rf != 1:
            raise NotImplementedError
        size_out = x.size()[:-1] + (self.nf,)
        # reshape (not view) so non-contiguous inputs such as transposed
        # tensors are accepted; it only copies when a view is impossible.
        x = torch.addmm(self.b, x.reshape(-1, x.size(-1)), self.w)
        return x.view(*size_out)
class Attention(nn.Module):
    """Causal (lower-triangular masked) multi-head self-attention.

    Args:
        nx: input/output feature size (n_embd).
        n_ctx: maximum sequence length; sizes the causal-mask buffer ``b``.
        cfg: config providing ``n_head``, ``attn_pdrop`` and ``resid_pdrop``.
        scale: if True, divide attention scores by sqrt(head_dim).
    """

    def __init__(self, nx, n_ctx, cfg, scale=False):
        super(Attention, self).__init__()
        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % cfg.n_head == 0
        # Lower-triangular mask of ones, shaped (1, 1, n_ctx, n_ctx) so it
        # broadcasts over the batch and head axes.
        self.register_buffer('b', torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = cfg.n_head
        self.split_size = n_state
        self.scale = scale
        self.c_attn = Conv1D(n_state * 3, 1, nx)
        # Conv1D(nf_out, rf, nx_in): merged heads (n_state) -> nx. The two
        # sizes are equal here, so only readability depends on the order.
        self.c_proj = Conv1D(nx, 1, n_state)
        self.attn_dropout = nn.Dropout(cfg.attn_pdrop)
        self.resid_dropout = nn.Dropout(cfg.resid_pdrop)

    def _attn(self, q, k, v):
        """Masked softmax attention.

        Expects ``q``/``v`` as (batch, head, seq, head_dim) and ``k`` already
        transposed to (batch, head, head_dim, seq). These are the layouts that
        ``split_heads`` produces.
        """
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        # Slice the causal mask to the actual (query, key) lengths so that
        # sequences shorter than n_ctx broadcast correctly.
        b = self.b[:, :, :w.size(-2), :w.size(-1)]
        w = w * b + -1e9 * (1 - b)  # TF implem method: mask_attn_weights
        w = nn.functional.softmax(w, dim=-1)
        w = self.attn_dropout(w)
        return torch.matmul(w, v)

    def merge_heads(self, x):
        """(batch, head, seq, head_dim) -> (batch, seq, head * head_dim)."""
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        """(batch, seq, features) -> (batch, head, seq, head_dim).

        With ``k=True`` the result is (batch, head, head_dim, seq), ready to be
        the right operand of ``q @ k``.
        """
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)
        else:
            return x.permute(0, 2, 1, 3)

    def forward(self, x):
        """Self-attend over ``x`` (assumed (batch, seq, nx)); returns the same shape."""
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        a = self._attn(query, key, value)
        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a)
        return a
class MLP(nn.Module):
    """Position-wise feed-forward sub-layer: expand, activate, project back, dropout.

    Args:
        n_state: hidden width of the expansion (Block passes ``4 * n_embd``).
        cfg: config providing ``n_embd``, ``afn`` (a key of ACT_FNS) and
            ``resid_pdrop``.
    """

    def __init__(self, n_state, cfg):  # in MLP: n_state=3072 (4 * n_embd)
        super(MLP, self).__init__()
        width = cfg.n_embd
        self.c_fc = Conv1D(n_state, 1, width)
        self.c_proj = Conv1D(width, 1, n_state)
        self.act = ACT_FNS[cfg.afn]
        self.dropout = nn.Dropout(cfg.resid_pdrop)

    def forward(self, x):
        hidden = self.act(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
class Block(nn.Module):
    """One transformer layer with LayerNorm applied after each residual sum.

    ``forward`` computes ``n = ln_1(x + attn(x))`` and returns
    ``ln_2(n + mlp(n))``.

    Args:
        n_ctx: maximum sequence length, forwarded to Attention.
        cfg: config providing ``n_embd`` plus the fields Attention and MLP read.
        scale: forwarded to Attention (divide scores by sqrt(head_dim)).
    """

    def __init__(self, n_ctx, cfg, scale=False):
        super(Block, self).__init__()
        width = cfg.n_embd
        self.attn = Attention(width, n_ctx, cfg, scale)
        self.ln_1 = LayerNorm(width)
        self.mlp = MLP(4 * width, cfg)
        self.ln_2 = LayerNorm(width)

    def forward(self, x):
        normed = self.ln_1(x + self.attn(x))
        return self.ln_2(normed + self.mlp(normed))
class TransformerModel(nn.Module):
    """ Transformer model

    Embeds integer ids, sums the embeddings over the last input axis, applies
    embedding dropout, and runs the result through ``cfg.n_layer`` Blocks.

    Args:
        cfg: config providing ``n_embd``, ``embd_pdrop`` and ``n_layer``, plus
            the fields Block reads.
        vocab: number of rows in the embedding table.
        n_ctx: maximum sequence length, forwarded to each Block.
    """

    def __init__(self, cfg, vocab=40990, n_ctx=512):
        super(TransformerModel, self).__init__()
        self.vocab = vocab
        self.embed = nn.Embedding(vocab, cfg.n_embd)
        self.drop = nn.Dropout(cfg.embd_pdrop)
        # Build every layer separately so each Block gets its own random
        # initialisation. Deep-copying one Block would make all layers
        # start from identical weights.
        self.h = nn.ModuleList([Block(n_ctx, cfg, scale=True) for _ in range(cfg.n_layer)])
        nn.init.normal_(self.embed.weight, std=0.02)

    def forward(self, x):
        """Return hidden states of shape (N, seq, n_embd).

        ``x`` is an integer tensor whose last two axes are (seq, k). Any leading
        axes are flattened into N. The k ids at each position are embedded and
        summed. Per the note below these are token and position ids
        (presumably k == 2 -- confirm against callers).
        """
        x = x.view(-1, x.size(-2), x.size(-1))
        e = self.embed(x)
        # Add the position information to the input embeddings
        h = e.sum(dim=2)
        # Embedding dropout, rate cfg.embd_pdrop.
        h = self.drop(h)
        for block in self.h:
            h = block(h)
        return h
class LMHead(nn.Module):
    """ Language Model Head for the transformer

    A bias-free linear decoder whose weight *is* ``model.embed.weight`` (tied),
    so output logits range over every row of the embedding table.

    Args:
        model: object exposing an ``embed`` embedding module.
        cfg: config providing ``n_embd``.
    """

    def __init__(self, model, cfg):
        super(LMHead, self).__init__()
        self.n_embd = cfg.n_embd
        n_rows, n_features = model.embed.weight.shape
        self.decoder = nn.Linear(n_features, n_rows, bias=False)
        self.decoder.weight = model.embed.weight  # Tied weights

    def forward(self, h):
        """Return logits of shape (N * (seq - 1), n_rows).

        The last position of every sequence is dropped (truncated LM logits).
        """
        shifted = h[:, :-1].contiguous()
        return self.decoder(shifted.view(-1, self.n_embd))
class MultipleChoiceHead(nn.Module):
    """ Classifier Head for the transformer

    Produces one logit per choice from the hidden state at the ``clf_token``
    position of each choice sequence.

    Args:
        clf_token: id, matched against ``x[..., 0]``, that marks the position to classify.
        cfg: config providing ``n_embd`` and ``clf_pdrop``.
    """

    def __init__(self, clf_token, cfg):
        super(MultipleChoiceHead, self).__init__()
        self.n_embd = cfg.n_embd
        self.clf_token = clf_token
        self.dropout = nn.Dropout2d(cfg.clf_pdrop)  # To reproduce the noise_shape parameter of TF implementation
        self.linear = nn.Linear(cfg.n_embd, 1)
        nn.init.normal_(self.linear.weight, std=0.02)
        # Zero the bias explicitly. ``nn.init.normal_(bias, 0)`` only sets the
        # mean and would still draw the bias from N(0, 1).
        nn.init.constant_(self.linear.bias, 0)

    def forward(self, h, x):
        """Return logits of shape (batch, n_choices).

        Assumes ``x`` is (batch, n_choices, seq, k), that ``h`` holds the
        matching ``batch * n_choices * seq`` rows of ``n_embd`` features in
        the same order, and that each choice contains exactly one
        ``clf_token`` -- TODO confirm against callers.
        """
        # Classification logits
        clf_h = h.view(-1, self.n_embd)
        flat = x[..., 0].contiguous().view(-1)
        clf_h = clf_h[flat == self.clf_token, :]
        clf_h = clf_h.view(-1, x.size(1), self.n_embd, 1)
        # This double transposition is there to replicate the behavior
        # of the noise_shape argument in the tensorflow
        # implementation. For more details, see
        # https://github.com/huggingface/pytorch-openai-transformer-lm/issues/11
        clf_h = self.dropout(clf_h.transpose(1, 2)).transpose(1, 2)
        clf_h = clf_h.contiguous().view(-1, self.n_embd)
        clf_logits = self.linear(clf_h)
        return clf_logits.view(-1, x.size(1))
class ClfHead(nn.Module):
    """Classification Head for the transformer

    Applies dropout and a linear layer to the hidden state at every
    ``clf_token`` position.

    Args:
        clf_token: id, matched against ``x[..., 0]``, that marks the position to classify.
        cfg: config providing ``n_embd`` and ``clf_pdrop``.
        n_class: number of output logits per selected position.

    TODO: test this class."""

    def __init__(self, clf_token, cfg, n_class):
        super(ClfHead, self).__init__()
        self.n_embd = cfg.n_embd
        self.clf_token = clf_token
        self.dropout = nn.Dropout(cfg.clf_pdrop)
        self.linear = nn.Linear(cfg.n_embd, n_class)
        nn.init.normal_(self.linear.weight, std=0.02)
        # Zero the bias explicitly. ``nn.init.normal_(bias, 0)`` only sets the
        # mean and would still draw each class bias from N(0, 1).
        nn.init.constant_(self.linear.bias, 0)

    def forward(self, h, x):
        """Return logits of shape (number of clf_token positions, n_class)."""
        clf_h = h.view(-1, self.n_embd)
        flat = x[..., 0].contiguous().view(-1)
        clf_h = clf_h[flat == self.clf_token, :]
        clf_h = self.dropout(clf_h)
        clf_logits = self.linear(clf_h)
        return clf_logits
class SimilarityHead(nn.Module):
    """ Similarity Head for the transformer

    Sums the ``clf_token`` hidden states of all sequences that belong to one
    example (presumably the two orderings of a text pair -- confirm), then
    maps the sum to a single logit.

    Args:
        clf_token: id, matched against ``x[..., 0]``, that marks the position to use.
        cfg: config providing ``n_embd`` and ``clf_pdrop``.

    TODO: test this class."""

    def __init__(self, clf_token, cfg):
        super(SimilarityHead, self).__init__()
        self.n_embd = cfg.n_embd
        self.clf_token = clf_token
        self.dropout = nn.Dropout(cfg.clf_pdrop)
        self.linear = nn.Linear(cfg.n_embd, 1)
        nn.init.normal_(self.linear.weight, std=0.02)
        # Zero the bias explicitly. ``nn.init.normal_(bias, 0)`` only sets the
        # mean and would still draw the bias from N(0, 1).
        nn.init.constant_(self.linear.bias, 0)

    def forward(self, h, x):
        """Return logits of shape (batch, 1).

        Assumes ``x`` is (batch, n_seq, seq, k) with exactly one ``clf_token``
        per sequence, and that ``h`` holds the matching rows in the same
        order -- TODO confirm against callers.
        """
        sim_h = h.view(-1, self.n_embd)
        flat = x[..., 0].contiguous().view(-1)
        sim_h = sim_h[flat == self.clf_token, :]
        sim_h = self.dropout(sim_h)
        # Group the selected states per example (x.size(1) of them) and sum
        # over that group axis. Summing over dim 1 of the flat (rows, n_embd)
        # tensor would instead collapse the feature axis, and the linear layer
        # below would then reject its input.
        sim_h = sim_h.view(-1, x.size(1), self.n_embd).sum(dim=1)
        sim_logits = self.linear(sim_h)
        return sim_logits
class DoubleHeadModel(nn.Module):
    """ Transformer with language model and task specific heads

    Args:
        cfg: model config, passed to the transformer and to both heads.
        clf_token: id of the token whose hidden state feeds the task head.
        task_head_type: one of ``'multiple_choice'``, ``'similarity'``,
            ``'inference'`` (3-way ClfHead), or a 2-element sequence
            ``('classification', n_class)``.
        vocab: embedding-table size passed to TransformerModel.
        n_ctx: maximum sequence length passed to TransformerModel.

    Raises:
        ValueError: if ``task_head_type`` matches none of the forms above.
    """

    def __init__(self, cfg, clf_token, task_head_type, vocab=40990, n_ctx=512):
        super(DoubleHeadModel, self).__init__()
        self.transformer = TransformerModel(cfg, vocab=vocab, n_ctx=n_ctx)
        self.lm_head = LMHead(self.transformer, cfg)
        self.task_head = self._build_task_head(clf_token, cfg, task_head_type)

    @staticmethod
    def _build_task_head(clf_token, cfg, task_head_type):
        """Instantiate the task head selected by ``task_head_type``, or raise ValueError."""
        # str is itself a Sequence, so it must be checked first.
        if isinstance(task_head_type, str):
            if task_head_type == 'multiple_choice':
                return MultipleChoiceHead(clf_token, cfg)
            if task_head_type == 'similarity':
                return SimilarityHead(clf_token, cfg)
            if task_head_type == 'inference':
                # the three classes correspond to entailment, contradiction and neutral.
                return ClfHead(clf_token, cfg, 3)
        elif (isinstance(task_head_type, collections.abc.Sequence)
              and len(task_head_type) == 2
              and task_head_type[0] == 'classification'):
            return ClfHead(clf_token, cfg, task_head_type[1])
        raise ValueError("task_head_type is expected to be 'multiple_choice', "
                         "'similarity', 'inference' or ('classification', n_class), "
                         f"got {task_head_type}.")

    def forward(self, x):
        """Return ``(lm_logits, task_logits)`` for input ids ``x``."""
        h = self.transformer(x)
        lm_logits = self.lm_head(h)
        task_logits = self.task_head(h, x)
        return lm_logits, task_logits
class dotdict(dict):
    """A ``dict`` whose keys can also be read, written and deleted as attributes.

    Reading a missing attribute returns ``None`` (same as ``dict.get``)
    instead of raising. Deleting a missing attribute raises ``KeyError``.
    """

    def __getattr__(self, key, default=None):
        return self.get(key, default)

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        del self[key]


# Default hyper-parameters: a 12-layer, 12-head, 768-wide model using GELU.
DEFAULT_CONFIG = dotdict(
    n_embd=768,
    n_head=12,
    n_layer=12,
    embd_pdrop=0.1,
    attn_pdrop=0.1,
    resid_pdrop=0.1,
    afn='gelu',
    clf_pdrop=0.1,
)