import copy
import json
import math
import re
import collections

import numpy as np
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter

from .modeling import BertLayerNorm as LayerNorm


def gelu(x):
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


def swish(x):
    return x * torch.sigmoid(x)


ACT_FNS = {
    'relu': nn.functional.relu,  # functional form, called directly like gelu and swish
    'swish': swish,
    'gelu': gelu
}


class Conv1D(nn.Module):
    def __init__(self, nf, rf, nx):
        super(Conv1D, self).__init__()
        self.rf = rf
        self.nf = nf
        if rf == 1:  # faster 1x1 conv
            w = torch.empty(nx, nf)
            nn.init.normal_(w, std=0.02)
            self.w = Parameter(w)
            self.b = Parameter(torch.zeros(nf))
        else:  # was used to train LM
            raise NotImplementedError

    def forward(self, x):
        if self.rf == 1:
            size_out = x.size()[:-1] + (self.nf,)
            x = torch.addmm(self.b, x.view(-1, x.size(-1)), self.w)
            x = x.view(*size_out)
        else:
            raise NotImplementedError
        return x


class Attention(nn.Module):
    def __init__(self, nx, n_ctx, cfg, scale=False):
        super(Attention, self).__init__()
        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % cfg.n_head == 0
        self.register_buffer('b', torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = cfg.n_head
        self.split_size = n_state
        self.scale = scale
        self.c_attn = Conv1D(n_state * 3, 1, nx)
        self.c_proj = Conv1D(n_state, 1, nx)
        self.attn_dropout = nn.Dropout(cfg.attn_pdrop)
        self.resid_dropout = nn.Dropout(cfg.resid_pdrop)

    def _attn(self, q, k, v):
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        # TF implem method: mask_attn_weights
        # (the causal mask covers the full n_ctx, so inputs are expected to be padded to n_ctx)
        w = w * self.b + -1e9 * (1 - self.b)
        w = nn.Softmax(dim=-1)(w)
        w = self.attn_dropout(w)
        return torch.matmul(w, v)

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)
        else:
            return x.permute(0, 2, 1, 3)

    def forward(self, x):
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        a = self._attn(query, key, value)
        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a)
        return a


class MLP(nn.Module):
    def __init__(self, n_state, cfg):  # in MLP: n_state=3072 (4 * n_embd)
        super(MLP, self).__init__()
        nx = cfg.n_embd
        self.c_fc = Conv1D(n_state, 1, nx)
        self.c_proj = Conv1D(nx, 1, n_state)
        self.act = ACT_FNS[cfg.afn]
        self.dropout = nn.Dropout(cfg.resid_pdrop)

    def forward(self, x):
        h = self.act(self.c_fc(x))
        h2 = self.c_proj(h)
        return self.dropout(h2)


class Block(nn.Module):
    def __init__(self, n_ctx, cfg, scale=False):
        super(Block, self).__init__()
        nx = cfg.n_embd
        self.attn = Attention(nx, n_ctx, cfg, scale)
        self.ln_1 = LayerNorm(nx)
        self.mlp = MLP(4 * nx, cfg)
        self.ln_2 = LayerNorm(nx)

    def forward(self, x):
        a = self.attn(x)
        n = self.ln_1(x + a)
        m = self.mlp(n)
        h = self.ln_2(n + m)
        return h
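# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch only, not part of the original file.
# `_demo_block_shapes` is a hypothetical helper showing the shape contract of
# `Block`: it takes (batch, seq_len, n_embd) and returns the same shape. Since
# the causal mask buffer is built for the full context and never cropped,
# seq_len is expected to equal n_ctx (sequences are padded up to n_ctx).
# It is never called by the library code itself.
# ---------------------------------------------------------------------------
def _demo_block_shapes():
    from types import SimpleNamespace
    cfg = SimpleNamespace(n_embd=768, n_head=12, attn_pdrop=0.1, resid_pdrop=0.1, afn='gelu')
    n_ctx = 128
    block = Block(n_ctx, cfg, scale=True)
    x = torch.zeros(2, n_ctx, cfg.n_embd)  # (batch, seq_len=n_ctx, n_embd)
    h = block(x)                           # masked self-attention + MLP, both residual
    assert h.shape == x.shape
    return h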
class TransformerModel(nn.Module):
    """ Transformer model """

    def __init__(self, cfg, vocab=40990, n_ctx=512):
        super(TransformerModel, self).__init__()
        self.vocab = vocab
        self.embed = nn.Embedding(vocab, cfg.n_embd)
        self.drop = nn.Dropout(cfg.embd_pdrop)
        block = Block(n_ctx, cfg, scale=True)
        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(cfg.n_layer)])

        nn.init.normal_(self.embed.weight, std=0.02)

    def forward(self, x):
        x = x.view(-1, x.size(-2), x.size(-1))
        e = self.embed(x)
        # Add the position information to the input embeddings
        h = e.sum(dim=2)
        for block in self.h:
            h = block(h)
        return h


class LMHead(nn.Module):
    """ Language Model Head for the transformer """

    def __init__(self, model, cfg):
        super(LMHead, self).__init__()
        self.n_embd = cfg.n_embd
        embed_shape = model.embed.weight.shape
        self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
        self.decoder.weight = model.embed.weight  # Tied weights

    def forward(self, h):
        # Truncated Language modeling logits (we remove the last token)
        h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
        lm_logits = self.decoder(h_trunc)
        return lm_logits


class MultipleChoiceHead(nn.Module):
    """ Classifier Head for the transformer """

    def __init__(self, clf_token, cfg):
        super(MultipleChoiceHead, self).__init__()
        self.n_embd = cfg.n_embd
        self.clf_token = clf_token
        self.dropout = nn.Dropout2d(cfg.clf_pdrop)  # To reproduce the noise_shape parameter of TF implementation
        self.linear = nn.Linear(cfg.n_embd, 1)

        nn.init.normal_(self.linear.weight, std=0.02)
        nn.init.normal_(self.linear.bias, 0)

    def forward(self, h, x):
        # Classification logits
        clf_h = h.view(-1, self.n_embd)
        flat = x[..., 0].contiguous().view(-1)
        clf_h = clf_h[flat == self.clf_token, :]
        clf_h = clf_h.view(-1, x.size(1), self.n_embd, 1)
        # This double transposition is there to replicate the behavior
        # of the noise_shape argument in the tensorflow implementation.
        # For more details, see
        # https://github.com/huggingface/pytorch-openai-transformer-lm/issues/11
        clf_h = self.dropout(clf_h.transpose(1, 2)).transpose(1, 2)
        clf_h = clf_h.contiguous().view(-1, self.n_embd)
        clf_logits = self.linear(clf_h)

        return clf_logits.view(-1, x.size(1))


class ClfHead(nn.Module):
    """Classification Head for the transformer

    TODO: test this class."""

    def __init__(self, clf_token, cfg, n_class):
        super(ClfHead, self).__init__()
        self.n_embd = cfg.n_embd
        self.clf_token = clf_token
        self.dropout = nn.Dropout(cfg.clf_pdrop)
        self.linear = nn.Linear(cfg.n_embd, n_class)

        nn.init.normal_(self.linear.weight, std=0.02)
        nn.init.normal_(self.linear.bias, 0)

    def forward(self, h, x):
        clf_h = h.view(-1, self.n_embd)
        flat = x[..., 0].contiguous().view(-1)
        clf_h = clf_h[flat == self.clf_token, :]
        clf_h = self.dropout(clf_h)
        clf_logits = self.linear(clf_h)

        return clf_logits


class SimilarityHead(nn.Module):
    """ Similarity Head for the transformer

    TODO: test this class."""

    def __init__(self, clf_token, cfg):
        super(SimilarityHead, self).__init__()
        self.n_embd = cfg.n_embd
        self.clf_token = clf_token
        self.dropout = nn.Dropout(cfg.clf_pdrop)
        self.linear = nn.Linear(cfg.n_embd, 1)

        nn.init.normal_(self.linear.weight, std=0.02)
        nn.init.normal_(self.linear.bias, 0)

    def forward(self, h, x):
        sim_h = h.view(-1, self.n_embd)
        flat = x[..., 0].contiguous().view(-1)
        sim_h = sim_h[flat == self.clf_token, :]
        sim_h = self.dropout(sim_h)
        # Each example is fed as the two orderings of the sentence pair;
        # sum the two _classify_ representations element-wise before scoring.
        sim_h = sim_h.view(-1, 2, self.n_embd).sum(dim=1)
        sim_logits = self.linear(sim_h)

        return sim_logits
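# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch only, not part of the original file.
# `_demo_encode_positions` is a hypothetical helper showing the input layout
# TransformerModel.forward assumes: x[..., 0] holds token indices, x[..., 1]
# holds position indices, and the two embeddings are added by `e.sum(dim=2)`.
# It assumes the repo's convention that the last n_ctx rows of the shared
# embedding table are the position embeddings (hence vocab=40990 by default).
# It is never called by the library code itself.
# ---------------------------------------------------------------------------
def _demo_encode_positions(token_ids, vocab=40990, n_ctx=512):
    # token_ids: LongTensor of shape (batch, n_seqs, seq_len), seq_len <= n_ctx
    batch, n_seqs, seq_len = token_ids.shape
    assert seq_len <= n_ctx
    x = torch.zeros(batch, n_seqs, seq_len, 2, dtype=torch.long)
    x[..., 0] = token_ids
    x[..., 1] = torch.arange(vocab - n_ctx, vocab - n_ctx + seq_len)  # position ids
    return x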
class DoubleHeadModel(nn.Module):
    """ Transformer with language model and task specific heads """

    def __init__(self, cfg, clf_token, task_head_type, vocab=40990, n_ctx=512):
        super(DoubleHeadModel, self).__init__()
        self.transformer = TransformerModel(cfg, vocab=vocab, n_ctx=n_ctx)
        self.lm_head = LMHead(self.transformer, cfg)
        if isinstance(task_head_type, str):
            if task_head_type == 'multiple_choice':
                self.task_head = MultipleChoiceHead(clf_token, cfg)
            elif task_head_type == 'similarity':
                self.task_head = SimilarityHead(clf_token, cfg)
            elif task_head_type == 'inference':
                # the three classes correspond to entailment, contradiction and neutral.
                self.task_head = ClfHead(clf_token, cfg, 3)
            else:
                raise ValueError("task_head_type is expected to be 'multiple_choice', "
                                 "'similarity', 'inference' or ('classification', n_class), "
                                 f"got {task_head_type}.")
        elif isinstance(task_head_type, collections.abc.Sequence) and len(task_head_type) == 2 and \
                task_head_type[0] == 'classification':
            n_class = task_head_type[1]
            self.task_head = ClfHead(clf_token, cfg, n_class)
        else:
            raise ValueError("task_head_type is expected to be 'multiple_choice', "
                             "'similarity', 'inference' or ('classification', n_class), "
                             f"got {task_head_type}.")

    def forward(self, x):
        h = self.transformer(x)
        lm_logits = self.lm_head(h)
        task_logits = self.task_head(h, x)
        return lm_logits, task_logits


class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


DEFAULT_CONFIG = dotdict({
    'n_embd': 768,
    'n_head': 12,
    'n_layer': 12,
    'embd_pdrop': 0.1,
    'attn_pdrop': 0.1,
    'resid_pdrop': 0.1,
    'afn': 'gelu',
    'clf_pdrop': 0.1})
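# ---------------------------------------------------------------------------
# Editor's note: illustrative smoke-test sketch only, not part of the original
# file.  The clf_token id, the sizes and the input layout below are assumptions
# chosen to exercise DoubleHeadModel end to end: each sequence carries exactly
# one classifier token (required by MultipleChoiceHead's reshape), sequences are
# padded to n_ctx (required by the uncropped causal mask), and position ids
# occupy the last n_ctx rows of the embedding table.  Because of the relative
# import at the top, run it as a module (python -m <package>.<this_module>),
# not as a standalone script.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    cfg = DEFAULT_CONFIG
    vocab, n_ctx = 40990, 64           # small context to keep the sketch fast
    batch, n_choices = 2, 2
    clf_token = 40477                  # hypothetical id of the _classify_ token

    model = DoubleHeadModel(cfg, clf_token, 'multiple_choice', vocab=vocab, n_ctx=n_ctx)

    x = torch.zeros(batch, n_choices, n_ctx, 2, dtype=torch.long)
    x[..., 0] = torch.randint(0, 40000, (batch, n_choices, n_ctx))  # token ids
    x[..., -1, 0] = clf_token                                       # one _classify_ token per sequence
    x[..., 1] = torch.arange(vocab - n_ctx, vocab)                  # position ids

    lm_logits, task_logits = model(x)
    print(lm_logits.shape)    # (batch * n_choices * (n_ctx - 1), vocab)
    print(task_logits.shape)  # (batch, n_choices)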