# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Transformer-XL model.
    Directly adapted from https://github.com/kimiyoung/transformer-xl.
    In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py
"""

import os
import copy
import json
import math
import logging
import tarfile
import tempfile
import shutil
import collections

import torch
import torch.nn as nn
import torch.nn.functional as F  # used below for F.softmax, F.linear and F.log_softmax
from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter

from .modeling import BertLayerNorm as LayerNorm
# LogUniformSampler is needed by the sampled-softmax branch of MemTransformerLM below.
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits, LogUniformSampler
from .file_utils import cached_path

logger = logging.getLogger(__name__)

PRETRAINED_MODEL_ARCHIVE_MAP = {
    'transfo-xl': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl.tar.gz",
}
CONFIG_NAME = 'transfo_xl_config.json'
WEIGHTS_NAME = 'pytorch_model.bin'

class TransfoXLConfig(object):
    """Configuration class to store the configuration of a `TransfoXLModel`.
    """
    def __init__(self,
                 vocab_size_or_config_json_file=267735,
                 cutoffs=[20000, 40000, 200000],
                 d_model=1024,
                 d_embed=1024,
                 n_head=16,
                 d_head=64,
                 d_inner=4096,
                 div_val=4,
                 pre_lnorm=False,
                 n_layer=18,
                 tgt_len=256,
                 ext_len=0,
                 mem_len=256,
                 same_length=False,
                 attn_type=0,
                 clamp_len=-1,
                 sample_softmax=-1,
                 adaptive=True,
                 tie_weight=True,
                 dropout=0.1,
                 dropatt=0.0,
                 untie_r=True,
                 init="normal",
                 init_range=0.01,
                 proj_init_std=0.01,
                 init_std=0.02):
        """Constructs TransfoXLConfig.

        Args:
            vocab_size_or_config_json_file: Vocabulary size of `input_ids` in `TransfoXLModel` or a configuration json file.
            cutoffs: cutoffs for the adaptive softmax.
            d_model: Dimensionality of the model's hidden states.
            d_embed: Dimensionality of the embeddings.
            d_head: Dimensionality of the model's heads.
            div_val: divisor value for the adaptive input and softmax.
            pre_lnorm: apply LayerNorm to the input instead of the output.
            d_inner: Inner dimension of the position-wise feed-forward layers.
            n_layer: Number of hidden layers in the Transformer encoder.
            n_head: Number of attention heads for each attention layer in the Transformer encoder.
            tgt_len: number of tokens to predict.
            ext_len: length of the extended context.
            mem_len: length of the retained previous hidden states (the memory).
            same_length: use the same attention length for all tokens.
            attn_type: attention type. 0 for Transformer-XL, 1 for Shaw et al, 2 for Vaswani et al, 3 for Al Rfou et al.
            clamp_len: use the same pos embeddings after clamp_len.
            sample_softmax: number of samples in the sampled softmax.
            adaptive: use adaptive softmax.
            tie_weight: tie the word embedding and softmax weights.
            dropout: The dropout probability for all fully connected
                layers in the embeddings, encoder, and pooler.
            dropatt: The dropout ratio for the attention probabilities.
            untie_r: untie relative position biases.
            init: parameter initializer to use.
            init_range: parameters initialized by U(-init_range, init_range).
            proj_init_std: projection parameters initialized by N(0, proj_init_std).
            init_std: parameters initialized by N(0, init_std).
        """
        if isinstance(vocab_size_or_config_json_file, str):
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.n_token = vocab_size_or_config_json_file
            self.cutoffs = []
            self.cutoffs.extend(cutoffs)
            self.tie_weight = tie_weight
            self.tie_projs = [False] + [True] * len(self.cutoffs)
            self.d_model = d_model
            self.d_embed = d_embed
            self.d_head = d_head
            self.d_inner = d_inner
            self.div_val = div_val
            self.pre_lnorm = pre_lnorm
            self.n_layer = n_layer
            self.n_head = n_head
            self.tgt_len = tgt_len
            self.ext_len = ext_len
            self.mem_len = mem_len
            self.same_length = same_length
            self.attn_type = attn_type
            self.clamp_len = clamp_len
            self.sample_softmax = sample_softmax
            self.adaptive = adaptive
            self.dropout = dropout
            self.dropatt = dropatt
            self.untie_r = untie_r
            self.init = init
            self.init_range = init_range
            self.proj_init_std = proj_init_std
            self.init_std = init_std
        else:
            raise ValueError("First argument must be either a vocabulary size (int) "
                             "or the path to a pretrained model config file (str)")

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `TransfoXLConfig` from a Python dictionary of parameters."""
        config = TransfoXLConfig(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `TransfoXLConfig` from a json file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

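# Usage sketch for TransfoXLConfig (the json file name below is hypothetical):
#
#     config = TransfoXLConfig(vocab_size_or_config_json_file=267735, mem_len=640)
#     config = TransfoXLConfig.from_json_file("transfo_xl_config.json")
#     print(config.to_json_string())
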
class PositionalEmbedding(nn.Module):
    def __init__(self, demb):
        super(PositionalEmbedding, self).__init__()

        self.demb = demb

        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, pos_seq, bsz=None):
        sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)

        if bsz is not None:
            return pos_emb[:, None, :].expand(-1, bsz, -1)
        else:
            return pos_emb[:, None, :]

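# Shape sketch: for a position sequence of length klen, PositionalEmbedding(demb)
# returns a [klen, 1, demb] tensor ([klen, bsz, demb] when bsz is given), e.g.
# PositionalEmbedding(1024)(torch.arange(15.0, -1, -1.0)) has shape [16, 1, 1024].
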
class PositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        super(PositionwiseFF, self).__init__()

        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        self.CoreNet = nn.Sequential(
            nn.Linear(d_model, d_inner), nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout),
        )

        self.layer_norm = nn.LayerNorm(d_model)

        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        if self.pre_lnorm:
            ##### layer normalization + positionwise feed-forward
            core_out = self.CoreNet(self.layer_norm(inp))

            ##### residual connection
            output = core_out + inp
        else:
            ##### positionwise feed-forward
            core_out = self.CoreNet(inp)

            ##### residual connection + layer normalization
            output = self.layer_norm(inp + core_out)

        return output

class MultiHeadAttn(nn.Module):
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 pre_lnorm=False, r_r_bias=None, r_w_bias=None):
        super(MultiHeadAttn, self).__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout

        self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
        self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)

        self.layer_norm = nn.LayerNorm(d_model)

        self.scale = 1 / (d_head ** 0.5)

        self.pre_lnorm = pre_lnorm

        if r_r_bias is None or r_w_bias is None:  # Biases are not shared
            self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
            self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
        else:
            self.r_r_bias = r_r_bias
            self.r_w_bias = r_w_bias

    def forward(self, h, attn_mask=None, mems=None):
        ##### multihead attention
        # [hlen x bsz x n_head x d_head]

        if mems is not None:
            c = torch.cat([mems, h], 0)
        else:
            c = h

        if self.pre_lnorm:
            ##### layer normalization
            c = self.layer_norm(c)

        head_q = self.q_net(h)
        head_k, head_v = torch.chunk(self.kv_net(c), 2, -1)

        head_q = head_q.view(h.size(0), h.size(1), self.n_head, self.d_head)
        head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head)
        head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)

        # [qlen x klen x bsz x n_head]
        attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
        attn_score.mul_(self.scale)
        if attn_mask is not None and attn_mask.any().item():
            if attn_mask.dim() == 2:
                attn_score.masked_fill_(attn_mask[None, :, :, None], -float('inf'))
            elif attn_mask.dim() == 3:
                attn_score.masked_fill_(attn_mask[:, :, :, None], -float('inf'))

        # [qlen x klen x bsz x n_head]
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

        # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        ##### linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            ##### residual connection
            output = h + attn_out
        else:
            ##### residual connection + layer normalization
            output = self.layer_norm(h + attn_out)

        return output

class RelMultiHeadAttn(nn.Module):
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
                 r_r_bias=None, r_w_bias=None):
        super(RelMultiHeadAttn, self).__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout

        self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)

        self.layer_norm = nn.LayerNorm(d_model)

        self.scale = 1 / (d_head ** 0.5)

        self.pre_lnorm = pre_lnorm

        if r_r_bias is None or r_w_bias is None:  # Biases are not shared
            self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
            self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
        else:
            self.r_r_bias = r_r_bias
            self.r_w_bias = r_w_bias

    def _parallelogram_mask(self, h, w, left=False):
        mask = torch.ones((h, w)).byte()
        m = min(h, w)
        mask[:m, :m] = torch.triu(mask[:m, :m])
        mask[-m:, -m:] = torch.tril(mask[-m:, -m:])

        if left:
            return mask
        else:
            return mask.flip(0)

    def _shift(self, x, qlen, klen, mask, left=False):
        if qlen > 1:
            zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)),
                                   device=x.device, dtype=x.dtype)
        else:
            zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype)

        if left:
            mask = mask.flip(1)
            x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1)
        else:
            x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1)

        x = x_padded.masked_select(mask[:, :, None, None]) \
                    .view(qlen, klen, x.size(2), x.size(3))

        return x

    def _rel_shift(self, x, zero_triu=False):
        zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
                               device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=1)

        x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])

        x = x_padded[1:].view_as(x)

        if zero_triu:
            ones = torch.ones((x.size(0), x.size(1)))
            x = x * torch.tril(ones, x.size(1) - x.size(0))[:, :, None, None]

        return x

    def forward(self, w, r, attn_mask=None, mems=None):
        raise NotImplementedError

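# Sketch of the _rel_shift trick (illustrative, for scores of shape [qlen, klen, ...]
# computed against relative positions klen-1 ... 0): prepending a zero column, viewing
# the result as [klen+1, qlen, ...] and dropping the first row shifts row i left by
# (qlen - 1 - i), so entry (i, j) ends up holding the score for relative distance
# mlen + i - j, i.e. query position minus key position, without materializing a
# separate set of relative embeddings per query row.
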
class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
    def __init__(self, *args, **kwargs):
        super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs)

        self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)

    def forward(self, w, r, attn_mask=None, mems=None):
        qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)

        if mems is not None:
            cat = torch.cat([mems, w], 0)
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(cat))
            else:
                w_heads = self.qkv_net(cat)
            r_head_k = self.r_net(r)

            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
            w_head_q = w_head_q[-qlen:]
        else:
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(w))
            else:
                w_heads = self.qkv_net(w)
            r_head_k = self.r_net(r)

            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)

        klen = w_head_k.size(0)

        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)       # qlen x bsz x n_head x d_head
        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)       # klen x bsz x n_head x d_head
        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)       # klen x bsz x n_head x d_head

        r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)            # rlen x n_head x d_head

        #### compute attention score
        rw_head_q = w_head_q + self.r_w_bias                                # qlen x bsz x n_head x d_head
        AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))         # qlen x klen x bsz x n_head

        rr_head_q = w_head_q + self.r_r_bias
        BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k))          # qlen x klen x bsz x n_head
        BD = self._rel_shift(BD)

        # [qlen x klen x bsz x n_head]
        attn_score = AC + BD
        attn_score.mul_(self.scale)

        #### compute attention probability
        if attn_mask is not None and attn_mask.any().item():
            if attn_mask.dim() == 2:
                attn_score = attn_score.float().masked_fill(
                    attn_mask[None, :, :, None], -float('inf')).type_as(attn_score)
            elif attn_mask.dim() == 3:
                attn_score = attn_score.float().masked_fill(
                    attn_mask[:, :, :, None], -float('inf')).type_as(attn_score)

        # [qlen x klen x bsz x n_head]
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

        #### compute attention vector
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))

        # [qlen x bsz x n_head x d_head]
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        ##### linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            ##### residual connection
            output = w + attn_out
        else:
            ##### residual connection + layer normalization
            output = self.layer_norm(w + attn_out)

        return output

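# In the relative-attention decomposition of the Transformer-XL paper, AC collects the
# content terms (a) + (c), with r_w_bias playing the role of the global content bias "u",
# and BD collects the position terms (b) + (d), with r_r_bias as the global position
# bias "v" and r_head_k as the projected relative position embeddings.
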
class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
    def __init__(self, *args, **kwargs):
        super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs)

    def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None):
        # r_emb: [klen, n_head, d_head], used for term B
        # r_w_bias: [n_head, d_head], used for term C
        # r_bias: [klen, n_head], used for term D

        qlen, bsz = w.size(0), w.size(1)

        if mems is not None:
            cat = torch.cat([mems, w], 0)
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(cat))
            else:
                w_heads = self.qkv_net(cat)
            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)

            w_head_q = w_head_q[-qlen:]
        else:
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(w))
            else:
                w_heads = self.qkv_net(w)
            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)

        klen = w_head_k.size(0)

        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)
        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)
        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)

        if klen > r_emb.size(0):
            r_emb_pad = r_emb[0:1].expand(klen-r_emb.size(0), -1, -1)
            r_emb = torch.cat([r_emb_pad, r_emb], 0)
            r_bias_pad = r_bias[0:1].expand(klen-r_bias.size(0), -1)
            r_bias = torch.cat([r_bias_pad, r_bias], 0)
        else:
            r_emb = r_emb[-klen:]
            r_bias = r_bias[-klen:]

        #### compute attention score
        rw_head_q = w_head_q + r_w_bias[None]                               # qlen x bsz x n_head x d_head

        AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))         # qlen x klen x bsz x n_head
        B_ = torch.einsum('ibnd,jnd->ijbn', (w_head_q, r_emb))              # qlen x klen x bsz x n_head
        D_ = r_bias[None, :, None]                                          # 1 x klen x 1 x n_head
        BD = self._rel_shift(B_ + D_)

        # [qlen x klen x bsz x n_head]
        attn_score = AC + BD
        attn_score.mul_(self.scale)

        #### compute attention probability
        if attn_mask is not None and attn_mask.any().item():
            if attn_mask.dim() == 2:
                attn_score.masked_fill_(attn_mask[None, :, :, None], -float('inf'))
            elif attn_mask.dim() == 3:
                attn_score.masked_fill_(attn_mask[:, :, :, None], -float('inf'))

        # [qlen x klen x bsz x n_head]
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

        #### compute attention vector
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))

        # [qlen x bsz x n_head x d_head]
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        ##### linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            ##### residual connection
            output = w + attn_out
        else:
            ##### residual connection + layer normalization
            output = self.layer_norm(w + attn_out)

        return output

class DecoderLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
        super(DecoderLayer, self).__init__()

        self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                     pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, dec_attn_mask=None, mems=None):

        output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
                               mems=mems)
        output = self.pos_ff(output)

        return output

class RelLearnableDecoderLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, dropout,
                 **kwargs):
        super(RelLearnableDecoderLayer, self).__init__()

        self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
                                                  **kwargs)
        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                     pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):

        output = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
                               attn_mask=dec_attn_mask,
                               mems=mems)
        output = self.pos_ff(output)

        return output

class RelPartialLearnableDecoderLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, dropout,
                 **kwargs):
        super(RelPartialLearnableDecoderLayer, self).__init__()

        self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
                                                         d_head, dropout, **kwargs)
        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                     pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, r, dec_attn_mask=None, mems=None):

        output = self.dec_attn(dec_inp, r,
                               attn_mask=dec_attn_mask,
                               mems=mems)
        output = self.pos_ff(output)

        return output

class AdaptiveEmbedding(nn.Module):
    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                 sample_softmax=False):
        super(AdaptiveEmbedding, self).__init__()

        self.n_token = n_token
        self.d_embed = d_embed

        self.cutoffs = cutoffs + [n_token]
        self.div_val = div_val
        self.d_proj = d_proj

        self.emb_scale = d_proj ** 0.5

        self.cutoff_ends = [0] + self.cutoffs

        self.emb_layers = nn.ModuleList()
        self.emb_projs = nn.ParameterList()
        if div_val == 1:
            self.emb_layers.append(
                nn.Embedding(n_token, d_embed, sparse=sample_softmax > 0)
            )
            if d_proj != d_embed:
                self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_embed)))
        else:
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
                d_emb_i = d_embed // (div_val ** i)
                self.emb_layers.append(nn.Embedding(r_idx-l_idx, d_emb_i))
                self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_emb_i)))

    def forward(self, inp):
        if self.div_val == 1:
            embed = self.emb_layers[0](inp)
            if self.d_proj != self.d_embed:
                embed = F.linear(embed, self.emb_projs[0])
        else:
            param = next(self.parameters())
            inp_flat = inp.view(-1)
            emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
                                   dtype=param.dtype, device=param.device)
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]

                mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
                indices_i = mask_i.nonzero().squeeze()

                if indices_i.numel() == 0:
                    continue

                inp_i = inp_flat.index_select(0, indices_i) - l_idx
                emb_i = self.emb_layers[i](inp_i)
                emb_i = F.linear(emb_i, self.emb_projs[i])

                emb_flat.index_copy_(0, indices_i, emb_i)

            embed = emb_flat.view(*inp.size(), self.d_proj)

        embed.mul_(self.emb_scale)

        return embed

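# Sizing sketch with the default TransfoXLConfig values (n_token=267735,
# cutoffs=[20000, 40000, 200000], div_val=4, d_embed=d_proj=1024): the four vocabulary
# clusters [0, 20000), [20000, 40000), [40000, 200000) and [200000, 267735) get
# embedding widths 1024, 256, 64 and 16 respectively, and each cluster is projected
# back to d_proj=1024 before the output is scaled by sqrt(d_proj).
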
class MemTransformerLM(nn.Module):
    def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
                 dropout, dropatt, tie_weight=True, d_embed=None,
                 div_val=1, tie_projs=[False], pre_lnorm=False,
                 tgt_len=None, ext_len=None, mem_len=None,
                 cutoffs=[], adapt_inp=False, untie_r=False,
                 same_length=False, attn_type=0, clamp_len=-1,
                 sample_softmax=-1, **kwargs):
        super(MemTransformerLM, self).__init__()
        self.n_token = n_token

        d_embed = d_model if d_embed is None else d_embed
        self.d_embed = d_embed
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_head

        self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
                                          div_val=div_val)

        self.drop = nn.Dropout(dropout)

        self.n_layer = n_layer

        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len
        self.max_klen = tgt_len + ext_len + mem_len

        self.attn_type = attn_type

        if not untie_r:
            self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
            self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))

        self.layers = nn.ModuleList()
        if attn_type == 0:  # the default attention
            for i in range(n_layer):
                self.layers.append(
                    RelPartialLearnableDecoderLayer(
                        n_head, d_model, d_head, d_inner, dropout,
                        tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                        dropatt=dropatt, pre_lnorm=pre_lnorm,
                        r_w_bias=None if untie_r else self.r_w_bias,
                        r_r_bias=None if untie_r else self.r_r_bias)
                )
        elif attn_type == 1:  # learnable embeddings
            for i in range(n_layer):
                self.layers.append(
                    RelLearnableDecoderLayer(
                        n_head, d_model, d_head, d_inner, dropout,
                        tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                        dropatt=dropatt, pre_lnorm=pre_lnorm,
                        r_w_bias=None if untie_r else self.r_w_bias,
                        r_r_bias=None if untie_r else self.r_r_bias)
                )
        elif attn_type in [2, 3]:  # absolute embeddings
            for i in range(n_layer):
                self.layers.append(
                    DecoderLayer(
                        n_head, d_model, d_head, d_inner, dropout,
                        dropatt=dropatt, pre_lnorm=pre_lnorm,
                        r_w_bias=None if untie_r else self.r_w_bias,
                        r_r_bias=None if untie_r else self.r_r_bias)
                )

        self.sample_softmax = sample_softmax
        # use sampled softmax
        if sample_softmax > 0:
            self.out_layer = nn.Linear(d_model, n_token)
            if tie_weight:
                self.out_layer.weight = self.word_emb.weight
            self.tie_weight = tie_weight
            self.sampler = LogUniformSampler(n_token, sample_softmax)

        # use adaptive softmax (including standard softmax)
        else:
            self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
                                                    cutoffs, div_val=div_val)

            if tie_weight:
                for i in range(len(self.crit.out_layers)):
                    self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight

            if tie_projs:
                for i, tie_proj in enumerate(tie_projs):
                    if tie_proj and div_val == 1 and d_model != d_embed:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[0]
                    elif tie_proj and div_val != 1:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[i]

        self.same_length = same_length
        self.clamp_len = clamp_len

        if self.attn_type == 0:  # default attention
            self.pos_emb = PositionalEmbedding(self.d_model)
        elif self.attn_type == 1:  # learnable
            self.r_emb = nn.Parameter(torch.Tensor(
                self.n_layer, self.max_klen, self.n_head, self.d_head))
            self.r_bias = nn.Parameter(torch.Tensor(
                self.n_layer, self.max_klen, self.n_head))
        elif self.attn_type == 2:  # absolute standard
            self.pos_emb = PositionalEmbedding(self.d_model)
        elif self.attn_type == 3:  # absolute deeper SA
            self.r_emb = nn.Parameter(torch.Tensor(
                self.n_layer, self.max_klen, self.n_head, self.d_head))

    def backward_compatible(self):
        self.sample_softmax = -1

    def reset_length(self, tgt_len, ext_len, mem_len):
        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len

    def init_mems(self):
        if self.mem_len > 0:
            mems = []
            param = next(self.parameters())
            for i in range(self.n_layer+1):
                empty = torch.empty(0, dtype=param.dtype, device=param.device)
                mems.append(empty)

            return mems
        else:
            return None

    def _update_mems(self, hids, mems, qlen, mlen):
        # does not deal with None
        if mems is None: return None

        # mems is not None
        assert len(hids) == len(mems), 'len(hids) != len(mems)'

        # There are `mlen + qlen` steps that can be cached into mems
        # For the next step, the last `ext_len` of the `qlen` tokens
        # will be used as the extended context. Hence, we only cache
        # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
        # to `mlen + qlen - self.ext_len`.
        with torch.no_grad():
            new_mems = []
            end_idx = mlen + max(0, qlen - 0 - self.ext_len)
            beg_idx = max(0, end_idx - self.mem_len)
            for i in range(len(hids)):

                cat = torch.cat([mems[i], hids[i]], dim=0)
                new_mems.append(cat[beg_idx:end_idx].detach())

        return new_mems

    def _forward(self, dec_inp, mems=None):
        qlen, bsz = dec_inp.size()

        word_emb = self.word_emb(dec_inp)

        mlen = mems[0].size(0) if mems is not None else 0
        klen = mlen + qlen
        if self.same_length:
            all_ones = word_emb.new_ones(qlen, klen)
            mask_len = klen - self.mem_len
            if mask_len > 0:
                mask_shift_len = qlen - mask_len
            else:
                mask_shift_len = qlen
            dec_attn_mask = (torch.triu(all_ones, 1+mlen)
                             + torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None]  # -1
        else:
            dec_attn_mask = torch.triu(
                word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:, :, None]

        hids = []
        if self.attn_type == 0:  # default
            pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb)
            pos_emb = self.drop(pos_emb)

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                # the r_w_bias / r_r_bias terms are held by each layer's attention module
                core_out = layer(core_out, pos_emb, dec_attn_mask=dec_attn_mask,
                                 mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 1:  # learnable
            core_out = self.drop(word_emb)
            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                if self.clamp_len > 0:
                    r_emb = self.r_emb[i][-self.clamp_len:]
                    r_bias = self.r_bias[i][-self.clamp_len:]
                else:
                    r_emb, r_bias = self.r_emb[i], self.r_bias[i]

                mems_i = None if mems is None else mems[i]
                core_out = layer(core_out, r_emb, self.r_w_bias[i],
                                 r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 2:  # absolute
            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb + pos_emb[-qlen:])

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and i == 0:
                    mems_i += pos_emb[:mlen]
                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                 mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 3:
            core_out = self.drop(word_emb)

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and mlen > 0:
                    cur_emb = self.r_emb[i][:-qlen]
                    cur_size = cur_emb.size(0)
                    if cur_size < mlen:
                        cur_emb_pad = cur_emb[0:1].expand(mlen-cur_size, -1, -1)
                        cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)
                    else:
                        cur_emb = cur_emb[-mlen:]
                    mems_i += cur_emb.view(mlen, 1, -1)
                core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)

                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                 mems=mems_i)
                hids.append(core_out)

        core_out = self.drop(core_out)

        new_mems = self._update_mems(hids, mems, mlen, qlen)

        return core_out, new_mems

    def forward(self, data, target, *mems):
        # nn.DataParallel does not allow size(0) tensors to be broadcasted.
        # So, have to initialize size(0) mems inside the model forward.
        # Moreover, have to return new_mems to allow nn.DataParallel to piece
        # them together.
        if not mems: mems = self.init_mems()

        tgt_len = target.size(0)
        hidden, new_mems = self._forward(data, mems=mems)

        pred_hid = hidden[-tgt_len:]
        if self.sample_softmax > 0 and self.training:
            assert self.tie_weight
            logit = sample_logits(self.word_emb,
                                  self.out_layer.bias, target, pred_hid, self.sampler)
            loss = -F.log_softmax(logit, -1)[:, :, 0]
        else:
            loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1))
            loss = loss.view(tgt_len, -1)

        if new_mems is None:
            return [loss]
        else:
            return [loss] + new_mems

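# Memory bookkeeping sketch (illustrative numbers): with qlen=tgt_len=128, mem_len=256
# and ext_len=0, each forward pass concatenates the cached mems with the fresh hidden
# states (klen = mlen + qlen) and keeps the last 256 positions of that concatenation as
# the new mems; the cached segment is detached under torch.no_grad(), so no gradients
# flow into previous segments.
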
class TransfoXLPreTrainedModel(nn.Module):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    def __init__(self, config, *inputs, **kwargs):
        super(TransfoXLPreTrainedModel, self).__init__()
        if not isinstance(config, TransfoXLConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `TransfoXLConfig`. "
                "To create a model from a pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        self.config = config

    def init_weight(self, weight):
        if self.config.init == 'uniform':
            nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)
        elif self.config.init == 'normal':
            nn.init.normal_(weight, 0.0, self.config.init_std)

    def init_bias(self, bias):
        nn.init.constant_(bias, 0.0)

    def init_weights(self, m):
        """ Initialize the weights.
        """
        classname = m.__class__.__name__
        if classname.find('Linear') != -1:
            if hasattr(m, 'weight') and m.weight is not None:
                self.init_weight(m.weight)
            if hasattr(m, 'bias') and m.bias is not None:
                self.init_bias(m.bias)
        elif classname.find('AdaptiveEmbedding') != -1:
            if hasattr(m, 'emb_projs'):
                for i in range(len(m.emb_projs)):
                    if m.emb_projs[i] is not None:
                        nn.init.normal_(m.emb_projs[i], 0.0, self.config.proj_init_std)
        elif classname.find('Embedding') != -1:
            if hasattr(m, 'weight'):
                self.init_weight(m.weight)
        elif classname.find('ProjectedAdaptiveLogSoftmax') != -1:
            if hasattr(m, 'cluster_weight') and m.cluster_weight is not None:
                self.init_weight(m.cluster_weight)
            if hasattr(m, 'cluster_bias') and m.cluster_bias is not None:
                self.init_bias(m.cluster_bias)
            if hasattr(m, 'out_projs'):
                for i in range(len(m.out_projs)):
                    if m.out_projs[i] is not None:
                        nn.init.normal_(m.out_projs[i], 0.0, self.config.proj_init_std)
        elif classname.find('LayerNorm') != -1:
            if hasattr(m, 'weight'):
                nn.init.normal_(m.weight, 1.0, self.config.init_std)
            if hasattr(m, 'bias') and m.bias is not None:
                self.init_bias(m.bias)
        elif classname.find('TransformerLM') != -1:
            if hasattr(m, 'r_emb'):
                self.init_weight(m.r_emb)
            if hasattr(m, 'r_w_bias'):
                self.init_weight(m.r_w_bias)
            if hasattr(m, 'r_r_bias'):
                self.init_weight(m.r_r_bias)
            if hasattr(m, 'r_bias'):
                self.init_bias(m.r_bias)

    def set_num_special_tokens(self, num_special_tokens):
        pass

    @classmethod
    def from_pretrained(cls, pretrained_model_name, num_special_tokens=0, state_dict=None, cache_dir=None,
                        *inputs, **kwargs):
        """
        Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        Params:
            pretrained_model_name: either:
                - a str with the name of a pre-trained model to load selected in the list of:
                    . `transfo-xl`
                - a path or url to a pretrained model archive containing:
                    . `transfo_xl_config.json` a configuration file for the model
                    . `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance
            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
            state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of the pre-trained weights.
            *inputs, **kwargs: additional inputs for the specific Transformer-XL subclass.
        """
        if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
            archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
        else:
            archive_file = pretrained_model_name
        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
        except FileNotFoundError:
            logger.error(
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url but couldn't find any file "
                "associated to this path or url.".format(
                    pretrained_model_name,
                    ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
                    archive_file))
            return None
        if resolved_archive_file == archive_file:
            logger.info("loading archive file {}".format(archive_file))
        else:
            logger.info("loading archive file {} from cache at {}".format(
                archive_file, resolved_archive_file))
        tempdir = None
        if os.path.isdir(resolved_archive_file):
            serialization_dir = resolved_archive_file
        else:
            # Extract archive to temp dir
            tempdir = tempfile.mkdtemp()
            logger.info("extracting archive file {} to temp dir {}".format(
                resolved_archive_file, tempdir))
            with tarfile.open(resolved_archive_file, 'r:gz') as archive:
                archive.extractall(tempdir)
            serialization_dir = tempdir
        # Load config
        config_file = os.path.join(serialization_dir, CONFIG_NAME)
        config = TransfoXLConfig.from_json_file(config_file)
        logger.info("Model config {}".format(config))
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None:
            weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
            state_dict = torch.load(weights_path)

        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')
        load(model.transformer if hasattr(model, 'transformer') else model, prefix='')
        if len(missing_keys) > 0:
            logger.info("Weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, missing_keys))
        if len(unexpected_keys) > 0:
            logger.info("Weights from pretrained model not used in {}: {}".format(
                model.__class__.__name__, unexpected_keys))
        if len(error_msgs) > 0:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                model.__class__.__name__, "\n\t".join(error_msgs)))
        # Add additional embeddings for special tokens if needed
        # (Transformer-XL configs may not define `n_special`, hence the getattr default)
        if num_special_tokens != getattr(config, 'n_special', 0):
            model.set_num_special_tokens(num_special_tokens)
        if tempdir:
            # Clean up temp dir
            shutil.rmtree(tempdir)
        return model

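# Loading sketch (the cache directory below is hypothetical):
#
#     model = TransfoXLModel.from_pretrained('transfo-xl', cache_dir='/tmp/transfo_xl_cache')
#
# The archive is fetched through `cached_path`, unpacked to a temp dir, and the config
# and weights are read from `transfo_xl_config.json` and `pytorch_model.bin`.
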
class TransfoXLModel(TransfoXLPreTrainedModel):
    """ Transformer-XL model.
    From "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
    by Zihang Dai*, Zhilin Yang*, Yiming Yang, William W. Cohen, Jaime Carbonell,
    Quoc V. Le, Ruslan Salakhutdinov (*: equal contribution)

    Params:
        config: a TransfoXLConfig class instance with the configuration to build a new model

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
            where d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, config.vocab_size[
        `position_ids`: an optional torch.LongTensor with the same shape as input_ids
            with the position indices (selected in the range [config.vocab_size + config.n_special, config.vocab_size + config.n_special + config.n_ctx - 1[.
        `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids.
            You can use it to add a third embedding to each token in the sequence
            (the previous two being the word and position embeddings).

    Outputs:
        `hidden_states`: the encoded hidden states at the top of the model
            as a torch.FloatTensor of size [batch_size, sequence_length, hidden_size]
            (or more generally [d_1, ..., d_n, hidden_size] where d_1 ... d_n are the dimensions of input_ids)

    Example usage:
    ```python
    # Already been converted into BPE token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])

    config = modeling_transfo_xl.TransfoXLConfig()

    model = modeling_transfo_xl.TransfoXLModel(config)
    hidden_states = model(input_ids)
    ```
    """
    def __init__(self, config):
        super(TransfoXLModel, self).__init__(config)
        self.transformer = MemTransformerLM(**config.to_dict())
        self.apply(self.init_weights)

    def forward(self, input_ids, position_ids=None, token_type_ids=None):
        return self.transformer(input_ids, position_ids, token_type_ids)