# transformers/pytorch_pretrained_bert/modeling_transfo_xl.py
#
# 1108 lines
# 44 KiB
# Python

# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HugginFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Transformer XL model.
Adapted from https://github.com/kimiyoung/transformer-xl.
In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py
"""
import os
import copy
import json
import math
import logging
import tarfile
import tempfile
import shutil
import collections
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter
from .modeling import BertLayerNorm as LayerNorm
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
from .file_utils import cached_path
logger = logging.getLogger(__name__)
# Shortcut model name -> URL of the pretrained weights; `from_pretrained`
# loads this file with `torch.load` and treats it as a state dict.
PRETRAINED_MODEL_ARCHIVE_MAP = {
    'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-pytorch_model.bin",
}
# Shortcut model name -> URL of the matching JSON configuration file.
PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-transfo_xl_config.json",
}
# File names looked up inside a local model directory (or URL prefix) when
# `from_pretrained` is given a path instead of a shortcut name.
CONFIG_NAME = 'transfo_xl_config.json'
WEIGHTS_NAME = 'pytorch_model.bin'
class TransfoXLConfig(object):
    """Configuration class to store the configuration of a `TransfoXLModel`.
    """
    def __init__(self,
                 vocab_size_or_config_json_file=267735,
                 cutoffs=(20000, 40000, 200000),
                 d_model=1024,
                 d_embed=1024,
                 n_head=16,
                 d_head=64,
                 d_inner=4096,
                 div_val=4,
                 pre_lnorm=False,
                 n_layer=18,
                 tgt_len=128,
                 ext_len=0,
                 mem_len=1600,
                 clamp_len=1000,
                 same_length=True,
                 proj_share_all_but_first=True,
                 attn_type=0,
                 sample_softmax=-1,
                 adaptive=True,
                 tie_weight=True,
                 dropout=0.1,
                 dropatt=0.0,
                 untie_r=True,
                 init="normal",
                 init_range=0.01,
                 proj_init_std=0.01,
                 init_std=0.02):
        """Constructs TransfoXLConfig.

        Args:
            vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `TransfoXLModel`
                (int) or the path to a JSON configuration file (str). When a path is
                given, every key of the JSON object becomes an attribute and all other
                arguments are ignored.
            cutoffs: cutoffs for the adaptive input/softmax clusters (any sequence;
                stored as a new list).
            d_model: Dimensionality of the model's hidden states.
            d_embed: Dimensionality of the embeddings.
            n_head: Number of attention heads for each attention layer.
            d_head: Dimensionality of each attention head.
            d_inner: Inner dimension of the position-wise feed-forward layers.
            div_val: divisor applied to the embedding size of each successive cluster
                in the adaptive input and softmax.
            pre_lnorm: apply LayerNorm to the input instead of the output.
            n_layer: Number of hidden layers in the Transformer encoder.
            tgt_len: number of tokens to predict.
            ext_len: length of the extended context.
            mem_len: number of previous hidden states retained as memory.
            clamp_len: use the same pos embeddings after clamp_len.
            same_length: use the same attention length for all tokens.
            proj_share_all_but_first: True to share all but the first projection,
                False not to share any.
            attn_type: attention type. 0 for Transformer-XL, 1 for Shaw et al,
                2 for Vaswani et al, 3 for Al Rfou et al.
            sample_softmax: number of samples in sampled softmax (<= 0 disables it).
            adaptive: use adaptive softmax.
            tie_weight: tie the word embedding and softmax weights.
            dropout: The dropout probability for all fully connected
                layers in the embeddings, encoder, and pooler.
            dropatt: The dropout ratio for the attention probabilities.
            untie_r: untie relative position biases (one pair per layer).
            init: parameter initializer to use ('normal' or 'uniform').
            init_range: parameters initialized by U(-init_range, init_range).
            proj_init_std: projection parameters initialized by N(0, proj_init_std).
            init_std: parameters initialized by N(0, init_std).

        Raises:
            ValueError: if `vocab_size_or_config_json_file` is neither an int nor a str.
        """
        if isinstance(vocab_size_or_config_json_file, str):
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.n_token = vocab_size_or_config_json_file
            # Always store a fresh list so callers' sequences are never aliased.
            self.cutoffs = list(cutoffs)
            self.tie_weight = tie_weight
            # The first projection is never shared; the remaining ones
            # (one per cutoff) are shared only when requested.
            self.tie_projs = [False] + [bool(proj_share_all_but_first)] * len(self.cutoffs)
            self.d_model = d_model
            self.d_embed = d_embed
            self.d_head = d_head
            self.d_inner = d_inner
            self.div_val = div_val
            self.pre_lnorm = pre_lnorm
            self.n_layer = n_layer
            self.n_head = n_head
            self.tgt_len = tgt_len
            self.ext_len = ext_len
            self.mem_len = mem_len
            self.same_length = same_length
            self.attn_type = attn_type
            self.clamp_len = clamp_len
            self.sample_softmax = sample_softmax
            self.adaptive = adaptive
            self.dropout = dropout
            self.dropatt = dropatt
            self.untie_r = untie_r
            self.init = init
            self.init_range = init_range
            self.proj_init_std = proj_init_std
            self.init_std = init_std
        else:
            raise ValueError("First argument must be either a vocabulary size (int) "
                             "or the path to a pretrained model config file (str)")

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a config from a Python dictionary of parameters.

        Defaults are filled in first, then every key of `json_object` overrides
        (or adds) the corresponding attribute.
        """
        config = cls(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a config from a JSON file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary (deep copy)."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string (sorted keys, 2-space indent)."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
class PositionalEmbedding(nn.Module):
    """Sinusoidal position embedding table.

    The output's last dimension holds the sines of `pos * inv_freq` followed by
    the cosines, so it has size `2 * len(inv_freq)`.
    """
    def __init__(self, demb):
        super(PositionalEmbedding, self).__init__()
        self.demb = demb
        exponents = torch.arange(0.0, demb, 2.0) / demb
        self.register_buffer('inv_freq', 1 / (10000 ** exponents))

    def forward(self, pos_seq, bsz=None):
        """Embed a 1-D sequence of positions.

        Returns a tensor of shape [len(pos_seq) x 1 x 2*len(inv_freq)], or
        [len(pos_seq) x bsz x ...] (an expanded view) when `bsz` is given.
        """
        angles = torch.ger(pos_seq, self.inv_freq)
        table = torch.cat([angles.sin(), angles.cos()], dim=-1)[:, None, :]
        if bsz is None:
            return table
        return table.expand(-1, bsz, -1)
class PositionwiseFF(nn.Module):
    """Position-wise feed-forward sub-layer with a residual connection.

    Computes Linear(d_model->d_inner) -> ReLU -> Dropout -> Linear(d_inner->d_model)
    -> Dropout. With `pre_lnorm` the LayerNorm is applied to the FF input and the
    residual sum is returned un-normalized; otherwise LayerNorm is applied to the
    residual sum.
    """
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        super(PositionwiseFF, self).__init__()
        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout
        layers = [
            nn.Linear(d_model, d_inner),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout),
        ]
        self.CoreNet = nn.Sequential(*layers)
        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        if not self.pre_lnorm:
            # post-norm: feed-forward, residual add, then LayerNorm
            return self.layer_norm(inp + self.CoreNet(inp))
        # pre-norm: LayerNorm feeds the feed-forward; residual uses the raw input
        return self.CoreNet(self.layer_norm(inp)) + inp
class MultiHeadAttn(nn.Module):
    """Multi-head self-attention (no relative positions) with residual connection.

    Args:
        n_head, d_model, d_head: number of heads, model width, per-head width.
        dropout: dropout on the projected attention output.
        dropatt: dropout on the attention probabilities.
        pre_lnorm: if True, LayerNorm is applied to the key/value input only and
            the residual output is returned un-normalized; otherwise LayerNorm is
            applied to the residual output.
        r_r_bias, r_w_bias: optional shared bias parameters. When either is None a
            new pair of uninitialized [n_head x d_head] parameters is created. They
            are stored but not used by `forward`.
    """
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 pre_lnorm=False, r_r_bias=None, r_w_bias=None):
        super(MultiHeadAttn, self).__init__()
        self.n_head, self.d_model, self.d_head = n_head, d_model, d_head
        self.dropout = dropout
        inner = n_head * d_head
        self.q_net = nn.Linear(d_model, inner, bias=False)
        self.kv_net = nn.Linear(d_model, 2 * inner, bias=False)
        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(inner, d_model, bias=False)
        self.layer_norm = nn.LayerNorm(d_model)
        self.scale = 1 / (d_head ** 0.5)
        self.pre_lnorm = pre_lnorm
        shared = r_r_bias is not None and r_w_bias is not None
        if shared:
            self.r_r_bias = r_r_bias
            self.r_w_bias = r_w_bias
        else:
            self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
            self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))

    def forward(self, h, attn_mask=None, mems=None):
        """Attend from `h` to `[mems; h]`.

        `h` is viewed as [qlen x bsz x n_head x d_head] after projection. A 3-D
        `attn_mask` is broadcast as [qlen x klen x bsz-or-1]; a 2-D mask is
        broadcast as [klen x bsz] over queries (presumably a key padding mask —
        confirm with callers). Truthy mask entries are excluded.
        """
        context = h if mems is None else torch.cat([mems, h], 0)
        if self.pre_lnorm:
            context = self.layer_norm(context)

        qlen, bsz = h.size(0), h.size(1)
        klen, kbsz = context.size(0), context.size(1)
        head_q = self.q_net(h)
        head_k, head_v = torch.chunk(self.kv_net(context), 2, -1)
        head_q = head_q.view(qlen, bsz, self.n_head, self.d_head)
        head_k = head_k.view(klen, kbsz, self.n_head, self.d_head)
        head_v = head_v.view(klen, kbsz, self.n_head, self.d_head)

        # scores: [qlen x klen x bsz x n_head]
        attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
        attn_score.mul_(self.scale)
        if attn_mask is not None and attn_mask.any().item():
            if attn_mask.dim() == 2:
                attn_score.masked_fill_(attn_mask[None, :, :, None], -float('inf'))
            elif attn_mask.dim() == 3:
                attn_score.masked_fill_(attn_mask[:, :, :, None], -float('inf'))

        # normalize over keys (dim 1)
        attn_prob = self.dropatt(F.softmax(attn_score, dim=1))

        # [qlen x klen x bsz x n_head] x [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        attn_out = self.drop(self.o_net(attn_vec))
        if self.pre_lnorm:
            return h + attn_out
        return self.layer_norm(h + attn_out)
class RelMultiHeadAttn(nn.Module):
    """Base class for the relative-position multi-head attention variants.

    Holds the fused QKV projection, output projection, LayerNorm and dropouts,
    plus the relative-shift helpers. Subclasses implement `forward`.

    When either `r_r_bias` or `r_w_bias` is None, a new pair of uninitialized
    [n_head x d_head] parameters is created; otherwise the given (shared)
    parameters are used. `tgt_len`, `ext_len` and `mem_len` are accepted but
    unused here.
    """
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
                 r_r_bias=None, r_w_bias=None):
        super(RelMultiHeadAttn, self).__init__()
        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout
        self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)
        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
        self.layer_norm = nn.LayerNorm(d_model)
        self.scale = 1 / (d_head ** 0.5)
        self.pre_lnorm = pre_lnorm
        if r_r_bias is None or r_w_bias is None:  # Biases are not shared
            self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
            self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
        else:
            self.r_r_bias = r_r_bias
            self.r_w_bias = r_w_bias

    def _parallelogram_mask(self, h, w, left=False):
        """Build an [h x w] uint8 mask of ones whose top-left min(h, w) square is
        upper-triangular and bottom-right square is lower-triangular; flipped
        along dim 0 unless `left`."""
        mask = torch.ones((h, w)).byte()
        m = min(h, w)
        mask[:m, :m] = torch.triu(mask[:m, :m])
        mask[-m:, -m:] = torch.tril(mask[-m:, -m:])
        if left:
            return mask
        else:
            return mask.flip(0)

    def _shift(self, x, qlen, klen, mask, left=False):
        """Pad `x` along dim 1 with `qlen - 1` zeros (on the left if `left`,
        else on the right), expand dim 0 to `qlen`, and keep the entries
        selected by `mask`, reshaped to [qlen x klen x ...]."""
        if qlen > 1:
            zero_pad = torch.zeros((x.size(0), qlen - 1, x.size(2), x.size(3)),
                                   device=x.device, dtype=x.dtype)
        else:
            zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype)
        if left:
            mask = mask.flip(1)
            x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1)
        else:
            x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1)
        x = x_padded.masked_select(mask[:, :, None, None]) \
                    .view(qlen, klen, x.size(2), x.size(3))
        return x

    def _rel_shift(self, x, zero_triu=False):
        """Relative shift of a [qlen x klen x ...] score tensor.

        Prepends one zero column, reinterprets the buffer as
        [klen + 1 x qlen x ...] and drops the first row, so row i ends up
        shifted left by (qlen - 1 - i) positions (relative -> absolute key
        alignment). With `zero_triu`, entries above the diagonal offset
        `klen - qlen` are zeroed.
        """
        zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
                               device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=1)
        x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])
        x = x_padded[1:].view_as(x)
        if zero_triu:
            # Build the triangle on x's device/dtype: a default (CPU, float32)
            # tensor fails on CUDA inputs and silently promotes other dtypes.
            ones = torch.ones((x.size(0), x.size(1)), device=x.device, dtype=x.dtype)
            x = x * torch.tril(ones, x.size(1) - x.size(0))[:, :, None, None]
        return x

    def forward(self, w, r, attn_mask=None, mems=None):
        raise NotImplementedError
class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
    """Relative multi-head attention with global learnable biases.

    Scores are the sum of a content term (queries + `r_w_bias` against content
    keys) and a position term (queries + `r_r_bias` against projected relative
    position keys, re-aligned by `_rel_shift`).
    """
    def __init__(self, *args, **kwargs):
        super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
        # Projects the relative position embeddings to per-head keys.
        self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)

    def forward(self, w, r, attn_mask=None, mems=None):
        """Attend from `w` ([qlen x bsz x d_model]) to `[mems; w]`.

        `r` holds the relative position embeddings with leading dimension rlen;
        after `r_net` it must be viewable as [rlen x n_head x d_head] (so any
        middle dimension must be 1). rlen is expected to equal klen so the
        content and position terms can be added.
        """
        qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)

        # Keys/values cover [mems; w]; only the last qlen rows become queries.
        seq = w if mems is None else torch.cat([mems, w], 0)
        normed = self.layer_norm(seq) if self.pre_lnorm else seq
        w_heads = self.qkv_net(normed)
        r_head_k = self.r_net(r)

        q, k, v = torch.chunk(w_heads, 3, dim=-1)
        if mems is not None:
            q = q[-qlen:]
        klen = k.size(0)

        q = q.view(qlen, bsz, self.n_head, self.d_head)           # qlen x bsz x n_head x d_head
        k = k.view(klen, bsz, self.n_head, self.d_head)           # klen x bsz x n_head x d_head
        v = v.view(klen, bsz, self.n_head, self.d_head)           # klen x bsz x n_head x d_head
        r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)  # rlen x n_head x d_head

        # content term: [qlen x klen x bsz x n_head]
        ac = torch.einsum('ibnd,jbnd->ijbn', (q + self.r_w_bias, k))
        # position term, shifted from relative to absolute key positions
        bd = torch.einsum('ibnd,jnd->ijbn', (q + self.r_r_bias, r_head_k))
        bd = self._rel_shift(bd)

        attn_score = ac + bd
        attn_score.mul_(self.scale)

        # Masked positions get -inf (filled in float32, then cast back).
        if attn_mask is not None and attn_mask.any().item():
            if attn_mask.dim() == 2:
                attn_score = attn_score.float().masked_fill(
                    attn_mask[None, :, :, None], -float('inf')).type_as(attn_score)
            elif attn_mask.dim() == 3:
                attn_score = attn_score.float().masked_fill(
                    attn_mask[:, :, :, None], -float('inf')).type_as(attn_score)

        # normalize over keys (dim 1)
        attn_prob = self.dropatt(F.softmax(attn_score, dim=1))

        # -> [qlen x bsz x n_head x d_head] -> [qlen x bsz x n_head * d_head]
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, v))
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        attn_out = self.drop(self.o_net(attn_vec))
        if self.pre_lnorm:
            return w + attn_out
        return self.layer_norm(w + attn_out)
class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
    """Relative multi-head attention with caller-supplied learnable relative
    embeddings (`r_emb`) and biases (`r_w_bias`, `r_bias`)."""
    def __init__(self, *args, **kwargs):
        super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs)

    def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None):
        """Attend from `w` ([qlen x bsz x d_model]) to `[mems; w]`.

        Args:
            r_emb: [rlen, n_head, d_head] relative key embeddings (term B).
            r_w_bias: [n_head, d_head] content bias added to the queries (term C).
            r_bias: [rlen, n_head] relative bias (term D).
            attn_mask: truthy entries are masked out (3-D broadcast as
                [qlen x klen x bsz-or-1]; 2-D broadcast as [klen x bsz]).
            mems: optional memory prepended to `w` for keys/values.
        """
        qlen, bsz = w.size(0), w.size(1)

        seq = w if mems is None else torch.cat([mems, w], 0)
        w_heads = self.qkv_net(self.layer_norm(seq) if self.pre_lnorm else seq)
        q, k, v = torch.chunk(w_heads, 3, dim=-1)
        if mems is not None:
            q = q[-qlen:]
        klen = k.size(0)

        q = q.view(qlen, bsz, self.n_head, self.d_head)
        k = k.view(klen, bsz, self.n_head, self.d_head)
        v = v.view(klen, bsz, self.n_head, self.d_head)

        # Make the relative tables exactly klen long: left-pad by repeating the
        # first row when too short, otherwise keep the last klen rows.
        if klen > r_emb.size(0):
            emb_pad = r_emb[0:1].expand(klen - r_emb.size(0), -1, -1)
            r_emb = torch.cat([emb_pad, r_emb], 0)
            bias_pad = r_bias[0:1].expand(klen - r_bias.size(0), -1)
            r_bias = torch.cat([bias_pad, r_bias], 0)
        else:
            r_emb = r_emb[-klen:]
            r_bias = r_bias[-klen:]

        # [qlen x klen x bsz x n_head] terms
        ac = torch.einsum('ibnd,jbnd->ijbn', (q + r_w_bias[None], k))
        b_term = torch.einsum('ibnd,jnd->ijbn', (q, r_emb))
        d_term = r_bias[None, :, None]  # 1 x klen x 1 x n_head
        bd = self._rel_shift(b_term + d_term)

        attn_score = ac + bd
        attn_score.mul_(self.scale)

        if attn_mask is not None and attn_mask.any().item():
            if attn_mask.dim() == 2:
                attn_score.masked_fill_(attn_mask[None, :, :, None], -float('inf'))
            elif attn_mask.dim() == 3:
                attn_score.masked_fill_(attn_mask[:, :, :, None], -float('inf'))

        # normalize over keys (dim 1)
        attn_prob = self.dropatt(F.softmax(attn_score, dim=1))

        # -> [qlen x bsz x n_head x d_head] -> [qlen x bsz x n_head * d_head]
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, v))
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        attn_out = self.drop(self.o_net(attn_vec))
        if self.pre_lnorm:
            return w + attn_out
        return self.layer_norm(w + attn_out)
class DecoderLayer(nn.Module):
    """Decoder block for absolute positions: MultiHeadAttn then PositionwiseFF.

    Extra keyword arguments go to the attention module; `pre_lnorm` (None when
    absent) is also handed to the feed-forward sub-layer.
    """
    def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
        super(DecoderLayer, self).__init__()
        self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
        ff_pre_lnorm = kwargs.get('pre_lnorm')
        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, pre_lnorm=ff_pre_lnorm)

    def forward(self, dec_inp, dec_attn_mask=None, mems=None):
        attended = self.dec_attn(dec_inp, attn_mask=dec_attn_mask, mems=mems)
        return self.pos_ff(attended)
class RelLearnableDecoderLayer(nn.Module):
    """Decoder block with RelLearnableMultiHeadAttn then PositionwiseFF.

    Extra keyword arguments go to the attention module; `pre_lnorm` (None when
    absent) is also handed to the feed-forward sub-layer.
    """
    def __init__(self, n_head, d_model, d_head, d_inner, dropout,
                 **kwargs):
        super(RelLearnableDecoderLayer, self).__init__()
        self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
        ff_pre_lnorm = kwargs.get('pre_lnorm')
        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, pre_lnorm=ff_pre_lnorm)

    def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
        attended = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
                                 attn_mask=dec_attn_mask, mems=mems)
        return self.pos_ff(attended)
class RelPartialLearnableDecoderLayer(nn.Module):
    """Decoder block with RelPartialLearnableMultiHeadAttn then PositionwiseFF.

    Extra keyword arguments go to the attention module; `pre_lnorm` (None when
    absent) is also handed to the feed-forward sub-layer.
    """
    def __init__(self, n_head, d_model, d_head, d_inner, dropout,
                 **kwargs):
        super(RelPartialLearnableDecoderLayer, self).__init__()
        self.dec_attn = RelPartialLearnableMultiHeadAttn(
            n_head, d_model, d_head, dropout, **kwargs)
        ff_pre_lnorm = kwargs.get('pre_lnorm')
        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, pre_lnorm=ff_pre_lnorm)

    def forward(self, dec_inp, r, dec_attn_mask=None, mems=None):
        attended = self.dec_attn(dec_inp, r, attn_mask=dec_attn_mask, mems=mems)
        return self.pos_ff(attended)
class AdaptiveEmbedding(nn.Module):
    """Token embedding with per-cluster sizes, projected to `d_proj`.

    The vocabulary is split at `cutoffs` into clusters. With `div_val == 1` a
    single table of width `d_embed` is used (projected only if
    `d_proj != d_embed`). Otherwise cluster i uses width
    `d_embed // div_val ** i` and its own [d_proj x d_emb_i] projection.
    The result is scaled by `sqrt(d_proj)`.

    Note: projection parameters are created uninitialized (`torch.Tensor`).
    """
    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                 sample_softmax=False):
        super(AdaptiveEmbedding, self).__init__()
        self.n_token = n_token
        self.d_embed = d_embed
        # list() accepts any sequence (e.g. a tuple) of cutoffs.
        self.cutoffs = list(cutoffs) + [n_token]
        self.div_val = div_val
        self.d_proj = d_proj
        self.emb_scale = d_proj ** 0.5
        self.cutoff_ends = [0] + self.cutoffs

        self.emb_layers = nn.ModuleList()
        self.emb_projs = nn.ParameterList()
        if div_val == 1:
            self.emb_layers.append(
                nn.Embedding(n_token, d_embed, sparse=sample_softmax > 0)
            )
            if d_proj != d_embed:
                self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_embed)))
        else:
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                d_emb_i = d_embed // (div_val ** i)
                self.emb_layers.append(nn.Embedding(r_idx - l_idx, d_emb_i))
                self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_emb_i)))

    def forward(self, inp):
        """Embed integer token ids of any shape; returns [*inp.shape x d_proj]."""
        if self.div_val == 1:
            embed = self.emb_layers[0](inp)
            if self.d_proj != self.d_embed:
                embed = F.linear(embed, self.emb_projs[0])
        else:
            param = next(self.parameters())
            # contiguous() so non-contiguous id tensors can be flattened too
            inp_flat = inp.contiguous().view(-1)
            emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
                                   dtype=param.dtype, device=param.device)
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]

                mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
                # squeeze(1), not squeeze(): keep a 1-D index even when exactly
                # one token falls into this cluster.
                indices_i = mask_i.nonzero().squeeze(1)

                if indices_i.numel() == 0:
                    continue

                inp_i = inp_flat.index_select(0, indices_i) - l_idx
                emb_i = self.emb_layers[i](inp_i)
                emb_i = F.linear(emb_i, self.emb_projs[i])

                emb_flat.index_copy_(0, indices_i, emb_i)

            embed = emb_flat.view(*inp.size(), self.d_proj)

        embed.mul_(self.emb_scale)

        return embed
class TransfoXLPreTrainedModel(nn.Module):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    def __init__(self, config, *inputs, **kwargs):
        super(TransfoXLPreTrainedModel, self).__init__()
        if not isinstance(config, TransfoXLConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `TransfoXLConfig`. "
                "To create a model from a pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        self.config = config

    def init_weight(self, weight):
        """Initialize `weight` in place: U(-init_range, init_range) when
        `config.init == 'uniform'`, N(0, init_std) when 'normal'; otherwise untouched."""
        if self.config.init == 'uniform':
            nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)
        elif self.config.init == 'normal':
            nn.init.normal_(weight, 0.0, self.config.init_std)

    def init_bias(self, bias):
        """Set `bias` to zero in place."""
        nn.init.constant_(bias, 0.0)

    def init_weights(self, m):
        """Initialize the parameters of module `m` (intended for `nn.Module.apply`).

        Dispatch is on `m`'s class name: Linear, AdaptiveEmbedding, Embedding,
        ProjectedAdaptiveLogSoftmax and LayerNorm have dedicated rules. Any other
        module gets its relative-attention parameters (`r_emb`, `r_w_bias`,
        `r_r_bias`, `r_bias`) initialized if it has them. This covers both the
        model (shared biases) and attention layers (untied biases), which the
        previous check on a 'TransformerLM' class name never matched.
        """
        classname = m.__class__.__name__
        if 'Linear' in classname:
            if getattr(m, 'weight', None) is not None:
                self.init_weight(m.weight)
            if getattr(m, 'bias', None) is not None:
                self.init_bias(m.bias)
        elif 'AdaptiveEmbedding' in classname:
            if hasattr(m, 'emb_projs'):
                for proj in m.emb_projs:
                    if proj is not None:
                        nn.init.normal_(proj, 0.0, self.config.proj_init_std)
        elif 'Embedding' in classname:
            if getattr(m, 'weight', None) is not None:
                self.init_weight(m.weight)
        elif 'ProjectedAdaptiveLogSoftmax' in classname:
            if getattr(m, 'cluster_weight', None) is not None:
                self.init_weight(m.cluster_weight)
            if getattr(m, 'cluster_bias', None) is not None:
                self.init_bias(m.cluster_bias)
            if hasattr(m, 'out_projs'):
                for proj in m.out_projs:
                    if proj is not None:
                        nn.init.normal_(proj, 0.0, self.config.proj_init_std)
        elif 'LayerNorm' in classname:
            # LayerNorm without affine parameters has weight/bias set to None.
            if getattr(m, 'weight', None) is not None:
                nn.init.normal_(m.weight, 1.0, self.config.init_std)
            if getattr(m, 'bias', None) is not None:
                self.init_bias(m.bias)
        else:
            if getattr(m, 'r_emb', None) is not None:
                self.init_weight(m.r_emb)
            if getattr(m, 'r_w_bias', None) is not None:
                self.init_weight(m.r_w_bias)
            if getattr(m, 'r_r_bias', None) is not None:
                self.init_weight(m.r_r_bias)
            if getattr(m, 'r_bias', None) is not None:
                self.init_bias(m.r_bias)

    def set_num_special_tokens(self, num_special_tokens):
        pass

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
                        *inputs, **kwargs):
        """
        Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        Params:
            pretrained_model_name_or_path: either:
                - a str with the name of a pre-trained model to load selected in the list of:
                    . `transfo-xl-wt103`
                - a path or url to a pretrained model archive containing:
                    . `transfo_xl_config.json` a configuration file for the model
                    . `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance
            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
            state_dict: an optional state dictionary (collections.OrderedDict object) to use instead
                of the pre-trained weights (the config file is still resolved and loaded).
            *inputs, **kwargs: additional inputs forwarded to the model class constructor.

        Returns:
            The model with the pretrained weights loaded, or None when the weight/config
            files could not be found (an error is logged).

        Raises:
            RuntimeError: if loading the state dict reported errors.
        """
        if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
            archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
            config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
        else:
            archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
        except FileNotFoundError:
            logger.error(
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url but couldn't find files {} and {} "
                "at this path or url.".format(
                    pretrained_model_name_or_path,
                    ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
                    pretrained_model_name_or_path,
                    archive_file, config_file))
            return None
        if resolved_archive_file == archive_file and resolved_config_file == config_file:
            logger.info("loading weights file {}".format(archive_file))
            logger.info("loading configuration file {}".format(config_file))
        else:
            logger.info("loading weights file {} from cache at {}".format(
                archive_file, resolved_archive_file))
            logger.info("loading configuration file {} from cache at {}".format(
                config_file, resolved_config_file))
        # Load config
        config = TransfoXLConfig.from_json_file(resolved_config_file)
        logger.info("Model config {}".format(config))
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None:
            # Map to CPU so checkpoints saved from GPU can be read on CPU-only
            # machines; values are copied into the model's own parameters below.
            state_dict = torch.load(resolved_archive_file, map_location='cpu')

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            # Recursively load each submodule's slice of the state dict (strict=True).
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        load(model, prefix='')
        if len(missing_keys) > 0:
            logger.info("Weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, missing_keys))
        if len(unexpected_keys) > 0:
            logger.info("Weights from pretrained model not used in {}: {}".format(
                model.__class__.__name__, unexpected_keys))
        if len(error_msgs) > 0:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                model.__class__.__name__, "\n\t".join(error_msgs)))
        return model
class TransfoXLModel(TransfoXLPreTrainedModel):
    """Transformer-XL model with an adaptive (or sampled) softmax criterion.

    Built from a `TransfoXLConfig`. `forward(data, target=None, *mems)` runs the
    decoder stack over `data` (token ids unpacked as [qlen x bsz]) and returns a
    list: the final hidden states (when `target` is None) or the per-token loss,
    followed by the updated memories when memory is enabled (`mem_len > 0`).
    """
    def __init__(self, config):
        super(TransfoXLModel, self).__init__(config)
        self.n_token = config.n_token

        self.d_embed = config.d_embed
        self.d_model = config.d_model
        self.n_head = config.n_head
        self.d_head = config.d_head

        self.word_emb = AdaptiveEmbedding(config.n_token, config.d_embed, config.d_model, config.cutoffs,
                                          div_val=config.div_val)

        self.drop = nn.Dropout(config.dropout)

        self.n_layer = config.n_layer

        self.tgt_len = config.tgt_len
        self.mem_len = config.mem_len
        self.ext_len = config.ext_len
        self.max_klen = config.tgt_len + config.ext_len + config.mem_len

        self.attn_type = config.attn_type

        if not config.untie_r:
            # One bias pair shared by all layers.
            self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
            self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))

        self.layers = nn.ModuleList()
        if config.attn_type == 0:  # the default attention
            for i in range(config.n_layer):
                self.layers.append(
                    RelPartialLearnableDecoderLayer(
                        config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
                        tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len,
                        dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
                        r_w_bias=None if config.untie_r else self.r_w_bias,
                        r_r_bias=None if config.untie_r else self.r_r_bias)
                )
        elif config.attn_type == 1:  # learnable embeddings
            for i in range(config.n_layer):
                self.layers.append(
                    RelLearnableDecoderLayer(
                        config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
                        tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len,
                        dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
                        r_w_bias=None if config.untie_r else self.r_w_bias,
                        r_r_bias=None if config.untie_r else self.r_r_bias)
                )
        elif config.attn_type in [2, 3]:  # absolute embeddings
            for i in range(config.n_layer):
                self.layers.append(
                    DecoderLayer(
                        config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
                        dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
                        r_w_bias=None if config.untie_r else self.r_w_bias,
                        r_r_bias=None if config.untie_r else self.r_r_bias)
                )

        self.sample_softmax = config.sample_softmax
        # use sampled softmax
        if config.sample_softmax > 0:
            # LogUniformSampler is not imported at module level; import it here
            # so this path does not fail with a NameError.
            from .modeling_transfo_xl_utilities import LogUniformSampler
            self.out_layer = nn.Linear(config.d_model, config.n_token)
            if config.tie_weight:
                # AdaptiveEmbedding has no `.weight` attribute; tie to its first
                # (with div_val == 1, only) embedding table.
                self.out_layer.weight = self.word_emb.emb_layers[0].weight
            self.tie_weight = config.tie_weight
            self.sampler = LogUniformSampler(config.n_token, config.sample_softmax)
        # use adaptive softmax (including standard softmax)
        else:
            self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model,
                                                    config.cutoffs, div_val=config.div_val)
            if config.tie_weight:
                for i in range(len(self.crit.out_layers)):
                    self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight
            if config.tie_projs:
                for i, tie_proj in enumerate(config.tie_projs):
                    if tie_proj and config.div_val == 1 and config.d_model != config.d_embed:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[0]
                    elif tie_proj and config.div_val != 1:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[i]

        self.same_length = config.same_length
        self.clamp_len = config.clamp_len

        if self.attn_type == 0:  # default attention
            self.pos_emb = PositionalEmbedding(self.d_model)
        elif self.attn_type == 1:  # learnable
            self.r_emb = nn.Parameter(torch.Tensor(
                self.n_layer, self.max_klen, self.n_head, self.d_head))
            self.r_bias = nn.Parameter(torch.Tensor(
                self.n_layer, self.max_klen, self.n_head))
        elif self.attn_type == 2:  # absolute standard
            self.pos_emb = PositionalEmbedding(self.d_model)
        elif self.attn_type == 3:  # absolute deeper SA
            self.r_emb = nn.Parameter(torch.Tensor(
                self.n_layer, self.max_klen, self.n_head, self.d_head))

        # Several parameters above (r_w_bias, r_r_bias, r_emb, r_bias, the
        # embedding projections) are created from uninitialized torch.Tensor
        # storage; initialize everything according to the config.
        self.apply(self.init_weights)

    def backward_compatible(self):
        self.sample_softmax = -1

    def reset_length(self, tgt_len, ext_len, mem_len):
        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len

    def init_mems(self):
        """Return `n_layer + 1` empty memory tensors (one per entry of `hids`),
        or None when memory is disabled (`mem_len <= 0`)."""
        if self.mem_len > 0:
            mems = []
            param = next(self.parameters())
            for i in range(self.n_layer + 1):
                empty = torch.empty(0, dtype=param.dtype, device=param.device)
                mems.append(empty)

            return mems
        else:
            return None

    def _update_mems(self, hids, mems, qlen, mlen):
        """Build the next memories from the current hidden states.

        There are `mlen + qlen` steps that can be cached into mems. For the
        next step, the last `ext_len` of the `qlen` tokens will be used as the
        extended context. Hence, we only cache the tokens from
        `mlen + qlen - ext_len - mem_len` to `mlen + qlen - ext_len`.
        Returns None when `mems` is None; results are detached.
        """
        if mems is None:
            return None

        assert len(hids) == len(mems), 'len(hids) != len(mems)'

        with torch.no_grad():
            new_mems = []
            end_idx = mlen + max(0, qlen - self.ext_len)
            beg_idx = max(0, end_idx - self.mem_len)
            for mem, hid in zip(mems, hids):
                cat = torch.cat([mem, hid], dim=0)
                new_mems.append(cat[beg_idx:end_idx].detach())

        return new_mems

    def _forward(self, dec_inp, mems=None):
        """Run the decoder stack on `dec_inp` ([qlen x bsz] token ids).

        Returns `(core_out, new_mems)` where `core_out` is the (dropped-out)
        last hidden layer and `new_mems` the updated memories (or None).
        """
        qlen, bsz = dec_inp.size()

        word_emb = self.word_emb(dec_inp)

        mlen = mems[0].size(0) if mems is not None else 0
        klen = mlen + qlen

        # Truthy entries mark keys a query may not attend to. The `> 0`
        # comparison yields a bool mask on PyTorch >= 1.2 (where uint8 masks are
        # rejected by masked_fill) and a uint8 mask on older versions.
        all_ones = word_emb.new_ones(qlen, klen)
        if self.same_length:
            mask_len = klen - self.mem_len
            if mask_len > 0:
                mask_shift_len = qlen - mask_len
            else:
                mask_shift_len = qlen
            dec_attn_mask = (torch.triu(all_ones, 1 + mlen)
                             + torch.tril(all_ones, -mask_shift_len)) > 0
        else:
            dec_attn_mask = torch.triu(all_ones, diagonal=1 + mlen) > 0
        dec_attn_mask = dec_attn_mask[:, :, None]

        hids = []
        if self.attn_type == 0:  # default
            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb)
            pos_emb = self.drop(pos_emb)

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                core_out = layer(core_out, pos_emb, dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 1:  # learnable
            core_out = self.drop(word_emb)
            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                if self.clamp_len > 0:
                    r_emb = self.r_emb[i][-self.clamp_len:]
                    r_bias = self.r_bias[i][-self.clamp_len:]
                else:
                    r_emb, r_bias = self.r_emb[i], self.r_bias[i]

                mems_i = None if mems is None else mems[i]
                # NOTE(review): self.r_w_bias is only created when untie_r is False
                # and has shape [n_head x d_head], so indexing it per layer looks
                # inconsistent with the [n_head x d_head] argument the layer
                # expects — verify before using attn_type == 1.
                core_out = layer(core_out, r_emb, self.r_w_bias[i],
                                 r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 2:  # absolute
            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb + pos_emb[-qlen:])

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and mlen > 0 and i == 0:
                    # Out-of-place: must not modify the caller's memory tensor,
                    # which is reused by `_update_mems`.
                    mems_i = mems_i + pos_emb[:mlen]
                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                 mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 3:
            core_out = self.drop(word_emb)

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and mlen > 0:
                    cur_emb = self.r_emb[i][:-qlen]
                    cur_size = cur_emb.size(0)
                    if cur_size < mlen:
                        cur_emb_pad = cur_emb[0:1].expand(mlen - cur_size, -1, -1)
                        cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)
                    else:
                        cur_emb = cur_emb[-mlen:]
                    # Out-of-place: do not modify the caller's memory tensor.
                    mems_i = mems_i + cur_emb.view(mlen, 1, -1)
                # Out-of-place: `core_out` is also stored in `hids`, which feeds
                # the memory update.
                core_out = core_out + self.r_emb[i][-qlen:].view(qlen, 1, -1)

                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                 mems=mems_i)
                hids.append(core_out)

        core_out = self.drop(core_out)

        # Keywords guard against the argument-order mix-up between qlen and mlen.
        new_mems = self._update_mems(hids, mems, qlen=qlen, mlen=mlen)

        return core_out, new_mems

    def forward(self, data, target=None, *mems):
        # nn.DataParallel does not allow size(0) tensors to be broadcasted.
        # So, have to initialize size(0) mems inside the model forward.
        # Moreover, have to return new_mems to allow nn.DataParallel to piece
        # them together.
        if not mems:
            mems = self.init_mems()

        hidden, new_mems = self._forward(data, mems=mems)
        if target is None:
            if new_mems is None:
                return [hidden]
            else:
                return [hidden] + new_mems

        tgt_len = target.size(0)
        pred_hid = hidden[-tgt_len:]
        if self.sample_softmax > 0 and self.training:
            assert self.tie_weight
            logit = sample_logits(self.word_emb, self.out_layer.bias, target, pred_hid, self.sampler)
            loss = -F.log_softmax(logit, -1)[:, :, 0]
        else:
            loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1))
            loss = loss.view(tgt_len, -1)

        if new_mems is None:
            return [loss]
        else:
            return [loss] + new_mems