transformers/pytorch_pretrained_bert/modeling_transfo_xl.py
2019-01-15 16:07:25 +01:00

1163 lines
46 KiB
Python

# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HugginFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Transformer XL model.
Directly adapted from https://github.com/kimiyoung/transformer-xl.
In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py
"""
import os
import copy
import json
import math
import logging
import tarfile
import tempfile
import shutil
import collections
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter
from .modeling import BertLayerNorm as LayerNorm
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
from .file_utils import cached_path
logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {
'transfo-xl': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl.tar.gz",
}
CONFIG_NAME = 'transfo_xl_config.json'
WEIGHTS_NAME = 'pytorch_model.bin'
class TransfoXLConfig(object):
"""Configuration class to store the configuration of a `TransfoXLModel`.
"""
def __init__(self,
vocab_size_or_config_json_file=267735,
cutoffs=[20000, 40000, 200000],
d_model=1024,
d_embed=1024,
n_head=16,
d_head=64,
d_inner=4096,
div_val=4,
pre_lnorm=False,
n_layer=18,
tgt_len=256,
ext_len=0,
mem_len=256,
same_length=False,
attn_type=0,
clamp_len=-1,
sample_softmax=-1,
adaptive=True,
tie_weight=True,
dropout=0.1,
dropatt=0.0,
untie_r=True,
init="normal",
init_range=0.01,
proj_init_std=0.01,
init_std=0.02):
"""Constructs TransfoXLConfig.
Args:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `TransfoXLModel` or a configuration json file.
cutoffs: cutoffs for the adaptive softmax
d_model: Dimensionality of the model's hidden states.
d_embed: Dimensionality of the embeddings
d_head: Dimensionality of the model's heads.
div_val: divident value for adapative input and softmax
pre_lnorm: apply LayerNorm to the input instead of the output
d_inner: Inner dimension in FF
n_layer: Number of hidden layers in the Transformer encoder.
n_head: Number of attention heads for each attention layer in
the Transformer encoder.
tgt_len: number of tokens to predict
ext_len: length of the extended context
mem_len: length of the retained previous heads
same_length: use the same attn length for all tokens
attn_type: attention type. 0 for Transformer-XL, 1 for Shaw et al, 2 for Vaswani et al, 3 for Al Rfou et al.
clamp_len: use the same pos embeddings after clamp_len
sample_softmax: number of samples in sampled softmax
adaptive: use adaptive softmax
tie_weight: tie the word embedding and softmax weights
dropout: The dropout probabilitiy for all fully connected
layers in the embeddings, encoder, and pooler.
dropatt: The dropout ratio for the attention probabilities.
untie_r: untie relative position biases
embd_pdrop: The dropout ratio for the embeddings.
init: parameter initializer to use
init_range: parameters initialized by U(-init_range, init_range).
proj_init_std: parameters initialized by N(0, init_std)
init_std: parameters initialized by N(0, init_std)
"""
if isinstance(vocab_size_or_config_json_file, str):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.n_token = vocab_size_or_config_json_file
self.cutoffs = []
self.cutoffs.extend(cutoffs)
self.tie_weight = tie_weight
self.tie_projs = [False] + [True] * len(self.cutoffs)
self.d_model = d_model
self.d_embed = d_embed
self.d_head = d_head
self.d_inner = d_inner
self.div_val = div_val
self.pre_lnorm = pre_lnorm
self.n_layer = n_layer
self.n_head = n_head
self.tgt_len = tgt_len
self.ext_len = ext_len
self.mem_len = mem_len
self.same_length = same_length
self.attn_type = attn_type
self.clamp_len = clamp_len
self.sample_softmax = sample_softmax
self.adaptive = adaptive
self.dropout = dropout
self.dropatt = dropatt
self.untie_r = untie_r
self.init = init
self.init_range = init_range
self.proj_init_std = proj_init_std
self.init_std = init_std
else:
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
@classmethod
def from_dict(cls, json_object):
"""Constructs a `TransfoXLConfig` from a Python dictionary of parameters."""
config = TransfoXLConfig(vocab_size_or_config_json_file=-1)
for key, value in json_object.items():
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `TransfoXLConfig` from a json file of parameters."""
with open(json_file, "r", encoding='utf-8') as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def __repr__(self):
return str(self.to_json_string())
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
class PositionalEmbedding(nn.Module):
def __init__(self, demb):
super(PositionalEmbedding, self).__init__()
self.demb = demb
inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
self.register_buffer('inv_freq', inv_freq)
def forward(self, pos_seq, bsz=None):
sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
if bsz is not None:
return pos_emb[:,None,:].expand(-1, bsz, -1)
else:
return pos_emb[:,None,:]
class PositionwiseFF(nn.Module):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
super(PositionwiseFF, self).__init__()
self.d_model = d_model
self.d_inner = d_inner
self.dropout = dropout
self.CoreNet = nn.Sequential(
nn.Linear(d_model, d_inner), nn.ReLU(inplace=True),
nn.Dropout(dropout),
nn.Linear(d_inner, d_model),
nn.Dropout(dropout),
)
self.layer_norm = nn.LayerNorm(d_model)
self.pre_lnorm = pre_lnorm
def forward(self, inp):
if self.pre_lnorm:
##### layer normalization + positionwise feed-forward
core_out = self.CoreNet(self.layer_norm(inp))
##### residual connection
output = core_out + inp
else:
##### positionwise feed-forward
core_out = self.CoreNet(inp)
##### residual connection + layer normalization
output = self.layer_norm(inp + core_out)
return output
class MultiHeadAttn(nn.Module):
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
pre_lnorm=False, r_r_bias=None, r_w_bias=None):
super(MultiHeadAttn, self).__init__()
self.n_head = n_head
self.d_model = d_model
self.d_head = d_head
self.dropout = dropout
self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)
self.drop = nn.Dropout(dropout)
self.dropatt = nn.Dropout(dropatt)
self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
self.layer_norm = nn.LayerNorm(d_model)
self.scale = 1 / (d_head ** 0.5)
self.pre_lnorm = pre_lnorm
if r_r_bias is None or r_w_bias is None: # Biases are not shared
self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
else:
self.r_r_bias = r_r_bias
self.r_w_bias = r_w_bias
def forward(self, h, attn_mask=None, mems=None):
##### multihead attention
# [hlen x bsz x n_head x d_head]
if mems is not None:
c = torch.cat([mems, h], 0)
else:
c = h
if self.pre_lnorm:
##### layer normalization
c = self.layer_norm(c)
head_q = self.q_net(h)
head_k, head_v = torch.chunk(self.kv_net(c), 2, -1)
head_q = head_q.view(h.size(0), h.size(1), self.n_head, self.d_head)
head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head)
head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)
# [qlen x klen x bsz x n_head]
attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
attn_score.mul_(self.scale)
if attn_mask is not None and attn_mask.any().item():
if attn_mask.dim() == 2:
attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf'))
elif attn_mask.dim() == 3:
attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf'))
# [qlen x klen x bsz x n_head]
attn_prob = F.softmax(attn_score, dim=1)
attn_prob = self.dropatt(attn_prob)
# [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
attn_vec = attn_vec.contiguous().view(
attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
##### linear projection
attn_out = self.o_net(attn_vec)
attn_out = self.drop(attn_out)
if self.pre_lnorm:
##### residual connection
output = h + attn_out
else:
##### residual connection + layer normalization
output = self.layer_norm(h + attn_out)
return output
class RelMultiHeadAttn(nn.Module):
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
r_r_bias=None, r_w_bias=None):
super(RelMultiHeadAttn, self).__init__()
self.n_head = n_head
self.d_model = d_model
self.d_head = d_head
self.dropout = dropout
self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)
self.drop = nn.Dropout(dropout)
self.dropatt = nn.Dropout(dropatt)
self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
self.layer_norm = nn.LayerNorm(d_model)
self.scale = 1 / (d_head ** 0.5)
self.pre_lnorm = pre_lnorm
if r_r_bias is None or r_w_bias is None: # Biases are not shared
self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
else:
self.r_r_bias = r_r_bias
self.r_w_bias = r_w_bias
def _parallelogram_mask(self, h, w, left=False):
mask = torch.ones((h, w)).byte()
m = min(h, w)
mask[:m,:m] = torch.triu(mask[:m,:m])
mask[-m:,-m:] = torch.tril(mask[-m:,-m:])
if left:
return mask
else:
return mask.flip(0)
def _shift(self, x, qlen, klen, mask, left=False):
if qlen > 1:
zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)),
device=x.device, dtype=x.dtype)
else:
zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype)
if left:
mask = mask.flip(1)
x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1)
else:
x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1)
x = x_padded.masked_select(mask[:,:,None,None]) \
.view(qlen, klen, x.size(2), x.size(3))
return x
def _rel_shift(self, x, zero_triu=False):
zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
device=x.device, dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=1)
x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])
x = x_padded[1:].view_as(x)
if zero_triu:
ones = torch.ones((x.size(0), x.size(1)))
x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None]
return x
def forward(self, w, r, attn_mask=None, mems=None):
raise NotImplementedError
class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
def __init__(self, *args, **kwargs):
super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)
def forward(self, w, r, attn_mask=None, mems=None):
qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)
if mems is not None:
cat = torch.cat([mems, w], 0)
if self.pre_lnorm:
w_heads = self.qkv_net(self.layer_norm(cat))
else:
w_heads = self.qkv_net(cat)
r_head_k = self.r_net(r)
w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
w_head_q = w_head_q[-qlen:]
else:
if self.pre_lnorm:
w_heads = self.qkv_net(self.layer_norm(w))
else:
w_heads = self.qkv_net(w)
r_head_k = self.r_net(r)
w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
klen = w_head_k.size(0)
w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head
w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head
w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head
r_head_k = r_head_k.view(rlen, self.n_head, self.d_head) # qlen x n_head x d_head
#### compute attention score
rw_head_q = w_head_q + self.r_w_bias # qlen x bsz x n_head x d_head
AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head
rr_head_q = w_head_q + self.r_r_bias
BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k)) # qlen x klen x bsz x n_head
BD = self._rel_shift(BD)
# [qlen x klen x bsz x n_head]
attn_score = AC + BD
attn_score.mul_(self.scale)
#### compute attention probability
if attn_mask is not None and attn_mask.any().item():
if attn_mask.dim() == 2:
attn_score = attn_score.float().masked_fill(
attn_mask[None,:,:,None], -float('inf')).type_as(attn_score)
elif attn_mask.dim() == 3:
attn_score = attn_score.float().masked_fill(
attn_mask[:,:,:,None], -float('inf')).type_as(attn_score)
# [qlen x klen x bsz x n_head]
attn_prob = F.softmax(attn_score, dim=1)
attn_prob = self.dropatt(attn_prob)
#### compute attention vector
attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))
# [qlen x bsz x n_head x d_head]
attn_vec = attn_vec.contiguous().view(
attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
##### linear projection
attn_out = self.o_net(attn_vec)
attn_out = self.drop(attn_out)
if self.pre_lnorm:
##### residual connection
output = w + attn_out
else:
##### residual connection + layer normalization
output = self.layer_norm(w + attn_out)
return output
class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
def __init__(self, *args, **kwargs):
super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None):
# r_emb: [klen, n_head, d_head], used for term B
# r_w_bias: [n_head, d_head], used for term C
# r_bias: [klen, n_head], used for term D
qlen, bsz = w.size(0), w.size(1)
if mems is not None:
cat = torch.cat([mems, w], 0)
if self.pre_lnorm:
w_heads = self.qkv_net(self.layer_norm(cat))
else:
w_heads = self.qkv_net(cat)
w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
w_head_q = w_head_q[-qlen:]
else:
if self.pre_lnorm:
w_heads = self.qkv_net(self.layer_norm(w))
else:
w_heads = self.qkv_net(w)
w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
klen = w_head_k.size(0)
w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)
w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)
w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)
if klen > r_emb.size(0):
r_emb_pad = r_emb[0:1].expand(klen-r_emb.size(0), -1, -1)
r_emb = torch.cat([r_emb_pad, r_emb], 0)
r_bias_pad = r_bias[0:1].expand(klen-r_bias.size(0), -1)
r_bias = torch.cat([r_bias_pad, r_bias], 0)
else:
r_emb = r_emb[-klen:]
r_bias = r_bias[-klen:]
#### compute attention score
rw_head_q = w_head_q + r_w_bias[None] # qlen x bsz x n_head x d_head
AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head
B_ = torch.einsum('ibnd,jnd->ijbn', (w_head_q, r_emb)) # qlen x klen x bsz x n_head
D_ = r_bias[None, :, None] # 1 x klen x 1 x n_head
BD = self._rel_shift(B_ + D_)
# [qlen x klen x bsz x n_head]
attn_score = AC + BD
attn_score.mul_(self.scale)
#### compute attention probability
if attn_mask is not None and attn_mask.any().item():
if attn_mask.dim() == 2:
attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf'))
elif attn_mask.dim() == 3:
attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf'))
# [qlen x klen x bsz x n_head]
attn_prob = F.softmax(attn_score, dim=1)
attn_prob = self.dropatt(attn_prob)
#### compute attention vector
attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))
# [qlen x bsz x n_head x d_head]
attn_vec = attn_vec.contiguous().view(
attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
##### linear projection
attn_out = self.o_net(attn_vec)
attn_out = self.drop(attn_out)
if self.pre_lnorm:
##### residual connection
output = w + attn_out
else:
##### residual connection + layer normalization
output = self.layer_norm(w + attn_out)
return output
class DecoderLayer(nn.Module):
def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
super(DecoderLayer, self).__init__()
self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, dec_attn_mask=None, mems=None):
output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
mems=mems)
output = self.pos_ff(output)
return output
class RelLearnableDecoderLayer(nn.Module):
def __init__(self, n_head, d_model, d_head, d_inner, dropout,
**kwargs):
super(RelLearnableDecoderLayer, self).__init__()
self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
**kwargs)
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
output = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
attn_mask=dec_attn_mask,
mems=mems)
output = self.pos_ff(output)
return output
class RelPartialLearnableDecoderLayer(nn.Module):
def __init__(self, n_head, d_model, d_head, d_inner, dropout,
**kwargs):
super(RelPartialLearnableDecoderLayer, self).__init__()
self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
d_head, dropout, **kwargs)
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, r, dec_attn_mask=None, mems=None):
output = self.dec_attn(dec_inp, r,
attn_mask=dec_attn_mask,
mems=mems)
output = self.pos_ff(output)
return output
class AdaptiveEmbedding(nn.Module):
def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
sample_softmax=False):
super(AdaptiveEmbedding, self).__init__()
self.n_token = n_token
self.d_embed = d_embed
self.cutoffs = cutoffs + [n_token]
self.div_val = div_val
self.d_proj = d_proj
self.emb_scale = d_proj ** 0.5
self.cutoff_ends = [0] + self.cutoffs
self.emb_layers = nn.ModuleList()
self.emb_projs = nn.ParameterList()
if div_val == 1:
self.emb_layers.append(
nn.Embedding(n_token, d_embed, sparse=sample_softmax>0)
)
if d_proj != d_embed:
self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_embed)))
else:
for i in range(len(self.cutoffs)):
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
d_emb_i = d_embed // (div_val ** i)
self.emb_layers.append(nn.Embedding(r_idx-l_idx, d_emb_i))
self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_emb_i)))
def forward(self, inp):
if self.div_val == 1:
embed = self.emb_layers[0](inp)
if self.d_proj != self.d_embed:
embed = F.linear(embed, self.emb_projs[0])
else:
param = next(self.parameters())
inp_flat = inp.view(-1)
emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
dtype=param.dtype, device=param.device)
for i in range(len(self.cutoffs)):
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
indices_i = mask_i.nonzero().squeeze()
if indices_i.numel() == 0:
continue
inp_i = inp_flat.index_select(0, indices_i) - l_idx
emb_i = self.emb_layers[i](inp_i)
emb_i = F.linear(emb_i, self.emb_projs[i])
emb_flat.index_copy_(0, indices_i, emb_i)
embed = emb_flat.view(*inp.size(), self.d_proj)
embed.mul_(self.emb_scale)
return embed
class MemTransformerLM(nn.Module):
def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
dropout, dropatt, tie_weight=True, d_embed=None,
div_val=1, tie_projs=[False], pre_lnorm=False,
tgt_len=None, ext_len=None, mem_len=None,
cutoffs=[], adapt_inp=False, untie_r=False,
same_length=False, attn_type=0, clamp_len=-1,
sample_softmax=-1, **kwargs):
super(MemTransformerLM, self).__init__()
self.n_token = n_token
d_embed = d_model if d_embed is None else d_embed
self.d_embed = d_embed
self.d_model = d_model
self.n_head = n_head
self.d_head = d_head
self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
div_val=div_val)
self.drop = nn.Dropout(dropout)
self.n_layer = n_layer
self.tgt_len = tgt_len
self.mem_len = mem_len
self.ext_len = ext_len
self.max_klen = tgt_len + ext_len + mem_len
self.attn_type = attn_type
if not untie_r:
self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
self.layers = nn.ModuleList()
if attn_type == 0: # the default attention
for i in range(n_layer):
self.layers.append(
RelPartialLearnableDecoderLayer(
n_head, d_model, d_head, d_inner, dropout,
tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
dropatt=dropatt, pre_lnorm=pre_lnorm,
r_w_bias=None if untie_r else self.r_w_bias,
r_r_bias=None if untie_r else self.r_r_bias)
)
elif attn_type == 1: # learnable embeddings
for i in range(n_layer):
self.layers.append(
RelLearnableDecoderLayer(
n_head, d_model, d_head, d_inner, dropout,
tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
dropatt=dropatt, pre_lnorm=pre_lnorm,
r_w_bias=None if untie_r else self.r_w_bias,
r_r_bias=None if untie_r else self.r_r_bias)
)
elif attn_type in [2, 3]: # absolute embeddings
for i in range(n_layer):
self.layers.append(
DecoderLayer(
n_head, d_model, d_head, d_inner, dropout,
dropatt=dropatt, pre_lnorm=pre_lnorm,
r_w_bias=None if untie_r else self.r_w_bias,
r_r_bias=None if untie_r else self.r_r_bias)
)
self.sample_softmax = sample_softmax
# use sampled softmax
if sample_softmax > 0:
self.out_layer = nn.Linear(d_model, n_token)
if tie_weight:
self.out_layer.weight = self.word_emb.weight
self.tie_weight = tie_weight
self.sampler = LogUniformSampler(n_token, sample_softmax)
# use adaptive softmax (including standard softmax)
else:
self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
cutoffs, div_val=div_val)
if tie_weight:
for i in range(len(self.crit.out_layers)):
self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight
if tie_projs:
for i, tie_proj in enumerate(tie_projs):
if tie_proj and div_val == 1 and d_model != d_embed:
self.crit.out_projs[i] = self.word_emb.emb_projs[0]
elif tie_proj and div_val != 1:
self.crit.out_projs[i] = self.word_emb.emb_projs[i]
self.same_length = same_length
self.clamp_len = clamp_len
if self.attn_type == 0: # default attention
self.pos_emb = PositionalEmbedding(self.d_model)
elif self.attn_type == 1: # learnable
self.r_emb = nn.Parameter(torch.Tensor(
self.n_layer, self.max_klen, self.n_head, self.d_head))
self.r_bias = nn.Parameter(torch.Tensor(
self.n_layer, self.max_klen, self.n_head))
elif self.attn_type == 2: # absolute standard
self.pos_emb = PositionalEmbedding(self.d_model)
elif self.attn_type == 3: # absolute deeper SA
self.r_emb = nn.Parameter(torch.Tensor(
self.n_layer, self.max_klen, self.n_head, self.d_head))
def backward_compatible(self):
self.sample_softmax = -1
def reset_length(self, tgt_len, ext_len, mem_len):
self.tgt_len = tgt_len
self.mem_len = mem_len
self.ext_len = ext_len
def init_mems(self):
if self.mem_len > 0:
mems = []
param = next(self.parameters())
for i in range(self.n_layer+1):
empty = torch.empty(0, dtype=param.dtype, device=param.device)
mems.append(empty)
return mems
else:
return None
def _update_mems(self, hids, mems, qlen, mlen):
# does not deal with None
if mems is None: return None
# mems is not None
assert len(hids) == len(mems), 'len(hids) != len(mems)'
# There are `mlen + qlen` steps that can be cached into mems
# For the next step, the last `ext_len` of the `qlen` tokens
# will be used as the extended context. Hence, we only cache
# the tokens from `mlen + qlen - self.ext_len - self.mem_len`
# to `mlen + qlen - self.ext_len`.
with torch.no_grad():
new_mems = []
end_idx = mlen + max(0, qlen - 0 - self.ext_len)
beg_idx = max(0, end_idx - self.mem_len)
for i in range(len(hids)):
cat = torch.cat([mems[i], hids[i]], dim=0)
new_mems.append(cat[beg_idx:end_idx].detach())
return new_mems
def _forward(self, dec_inp, mems=None):
qlen, bsz = dec_inp.size()
word_emb = self.word_emb(dec_inp)
mlen = mems[0].size(0) if mems is not None else 0
klen = mlen + qlen
if self.same_length:
all_ones = word_emb.new_ones(qlen, klen)
mask_len = klen - self.mem_len
if mask_len > 0:
mask_shift_len = qlen - mask_len
else:
mask_shift_len = qlen
dec_attn_mask = (torch.triu(all_ones, 1+mlen)
+ torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None] # -1
else:
dec_attn_mask = torch.triu(
word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]
hids = []
if self.attn_type == 0: # default
pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
dtype=word_emb.dtype)
if self.clamp_len > 0:
pos_seq.clamp_(max=self.clamp_len)
pos_emb = self.pos_emb(pos_seq)
core_out = self.drop(word_emb)
pos_emb = self.drop(pos_emb)
hids.append(core_out)
for i, layer in enumerate(self.layers):
mems_i = None if mems is None else mems[i]
core_out = layer(core_out, pos_emb, self.r_w_bias,
self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
hids.append(core_out)
elif self.attn_type == 1: # learnable
core_out = self.drop(word_emb)
hids.append(core_out)
for i, layer in enumerate(self.layers):
if self.clamp_len > 0:
r_emb = self.r_emb[i][-self.clamp_len :]
r_bias = self.r_bias[i][-self.clamp_len :]
else:
r_emb, r_bias = self.r_emb[i], self.r_bias[i]
mems_i = None if mems is None else mems[i]
core_out = layer(core_out, r_emb, self.r_w_bias[i],
r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
hids.append(core_out)
elif self.attn_type == 2: # absolute
pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
dtype=word_emb.dtype)
if self.clamp_len > 0:
pos_seq.clamp_(max=self.clamp_len)
pos_emb = self.pos_emb(pos_seq)
core_out = self.drop(word_emb + pos_emb[-qlen:])
hids.append(core_out)
for i, layer in enumerate(self.layers):
mems_i = None if mems is None else mems[i]
if mems_i is not None and i == 0:
mems_i += pos_emb[:mlen]
core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
mems=mems_i)
hids.append(core_out)
elif self.attn_type == 3:
core_out = self.drop(word_emb)
hids.append(core_out)
for i, layer in enumerate(self.layers):
mems_i = None if mems is None else mems[i]
if mems_i is not None and mlen > 0:
cur_emb = self.r_emb[i][:-qlen]
cur_size = cur_emb.size(0)
if cur_size < mlen:
cur_emb_pad = cur_emb[0:1].expand(mlen-cur_size, -1, -1)
cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)
else:
cur_emb = cur_emb[-mlen:]
mems_i += cur_emb.view(mlen, 1, -1)
core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)
core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
mems=mems_i)
hids.append(core_out)
core_out = self.drop(core_out)
new_mems = self._update_mems(hids, mems, mlen, qlen)
return core_out, new_mems
def forward(self, data, target, *mems):
# nn.DataParallel does not allow size(0) tensors to be broadcasted.
# So, have to initialize size(0) mems inside the model forward.
# Moreover, have to return new_mems to allow nn.DataParallel to piece
# them together.
if not mems: mems = self.init_mems()
tgt_len = target.size(0)
hidden, new_mems = self._forward(data, mems=mems)
pred_hid = hidden[-tgt_len:]
if self.sample_softmax > 0 and self.training:
assert self.tie_weight
logit = sample_logits(self.word_emb,
self.out_layer.bias, target, pred_hid, self.sampler)
loss = -F.log_softmax(logit, -1)[:, :, 0]
else:
loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1))
loss = loss.view(tgt_len, -1)
if new_mems is None:
return [loss]
else:
return [loss] + new_mems
class TransfoXLPreTrainedModel(nn.Module):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
def __init__(self, config, *inputs, **kwargs):
super(TransfoXLPreTrainedModel, self).__init__()
if not isinstance(config, TransfoXLConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `TransfoXLConfig`. "
"To create a model from a pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
self.config = config
def init_weight(self, weight):
if self.config.init == 'uniform':
nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)
elif self.config.init == 'normal':
nn.init.normal_(weight, 0.0, self.config.init_std)
def init_bias(self, bias):
nn.init.constant_(bias, 0.0)
def init_weights(self, m):
""" Initialize the weights.
"""
classname = m.__class__.__name__
if classname.find('Linear') != -1:
if hasattr(m, 'weight') and m.weight is not None:
self.init_weight(m.weight)
if hasattr(m, 'bias') and m.bias is not None:
self.init_bias(m.bias)
elif classname.find('AdaptiveEmbedding') != -1:
if hasattr(m, 'emb_projs'):
for i in range(len(m.emb_projs)):
if m.emb_projs[i] is not None:
nn.init.normal_(m.emb_projs[i], 0.0, self.config.proj_init_std)
elif classname.find('Embedding') != -1:
if hasattr(m, 'weight'):
self.init_weight(m.weight)
elif classname.find('ProjectedAdaptiveLogSoftmax') != -1:
if hasattr(m, 'cluster_weight') and m.cluster_weight is not None:
self.init_weight(m.cluster_weight)
if hasattr(m, 'cluster_bias') and m.cluster_bias is not None:
self.init_bias(m.cluster_bias)
if hasattr(m, 'out_projs'):
for i in range(len(m.out_projs)):
if m.out_projs[i] is not None:
nn.init.normal_(m.out_projs[i], 0.0, self.config.proj_init_std)
elif classname.find('LayerNorm') != -1:
if hasattr(m, 'weight'):
nn.init.normal_(m.weight, 1.0, self.config.init_std)
if hasattr(m, 'bias') and m.bias is not None:
self.init_bias(m.bias)
elif classname.find('TransformerLM') != -1:
if hasattr(m, 'r_emb'):
self.init_weight(m.r_emb)
if hasattr(m, 'r_w_bias'):
self.init_weight(m.r_w_bias)
if hasattr(m, 'r_r_bias'):
self.init_weight(m.r_r_bias)
if hasattr(m, 'r_bias'):
self.init_bias(m.r_bias)
def set_num_special_tokens(self, num_special_tokens):
pass
@classmethod
def from_pretrained(cls, pretrained_model_name, num_special_tokens=0, state_dict=None, cache_dir=None,
*inputs, **kwargs):
"""
Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `transfo-xl`
- a path or url to a pretrained model archive containing:
. `transfo_xl_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
*inputs, **kwargs: additional input for the specific Bert class
(ex: num_labels for BertForSequenceClassification)
"""
if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
else:
archive_file = pretrained_model_name
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except FileNotFoundError:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name,
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
archive_file))
return None
if resolved_archive_file == archive_file:
logger.info("loading archive file {}".format(archive_file))
else:
logger.info("loading archive file {} from cache at {}".format(
archive_file, resolved_archive_file))
tempdir = None
if os.path.isdir(resolved_archive_file):
serialization_dir = resolved_archive_file
else:
# Extract archive to temp dir
tempdir = tempfile.mkdtemp()
logger.info("extracting archive file {} to temp dir {}".format(
resolved_archive_file, tempdir))
with tarfile.open(resolved_archive_file, 'r:gz') as archive:
archive.extractall(tempdir)
serialization_dir = tempdir
# Load config
config_file = os.path.join(serialization_dir, CONFIG_NAME)
config = TransfoXLConfig.from_json_file(config_file)
logger.info("Model config {}".format(config))
# Instantiate model.
model = cls(config, *inputs, **kwargs)
if state_dict is None:
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
state_dict = torch.load(weights_path)
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if 'gamma' in key:
new_key = key.replace('gamma', 'weight')
if 'beta' in key:
new_key = key.replace('beta', 'bias')
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
load(model.transformer if hasattr(model, 'transformer') else model, prefix='')
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys))
if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys))
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
model.__class__.__name__, "\n\t".join(error_msgs)))
# Add additional embeddings for special tokens if needed
if num_special_tokens != config.n_special:
model.set_num_special_tokens(num_special_tokens)
if tempdir:
# Clean up temp dir
shutil.rmtree(tempdir)
return model
class TransfoXLModel(TransfoXLPreTrainedModel):
""" Transformer XL model
From "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
by Zihang Dai*, Zhilin Yang*, Yiming Yang, William W. Cohen, Jaime Carbonell,
Quoc V. Le, Ruslan Salakhutdinov (*: equal contribution)
Params:
config: a TransfoXLConfig class instance with the configuration to build a new model
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
were d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[
`position_ids`: an optional torch.LongTensor with the same shape as input_ids
with the position indices (selected in the range [config.vocab_size + config.n_special, config.vocab_size + config.n_special + config.n_ctx - 1[.
`token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
You can use it to add a third embedding (the previous two being the word and position embeddings)
to each token in the sentence.
Outputs:
`hidden_states`: the encoded-hidden-states at the top of the model
as a torch.FloatTensor of size [batch_size, sequence_length, hidden_size]
(or more generally [d_1, ..., d_n, hidden_size] were d_1 ... d_n are the dimension of input_ids)
Example usage:
```python
# Already been converted into BPE token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
config = modeling_transfo_xl.TransfoXLConfig()
model = modeling_transfo_xl.TransfoXLModel(config)
hidden_states = model(input_ids)
```
"""
def __init__(self, config):
super(TransfoXLModel, self).__init__(config)
self.transformer = MemTransformerLM(**config.to_dict())
self.apply(self.init_weights)
def forward(self, input_ids, position_ids=None, token_type_ids=None):
return self.transformer(input_ids, position_ids, token_type_ids)