Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-30 17:52:35 +06:00)

Commit 65c49bb27e (parent 39c38b2ea0): adding TF 2.0 adaptive softmax with logits + loss outputs
@@ -455,6 +455,8 @@ class TFBertMainLayer(tf.keras.layers.Layer):
        """
        raise NotImplementedError

    # def call(self, input_ids, attention_mask=None, token_type_ids=None,
    #          position_ids=None, head_mask=None, training=False):
    def call(self, inputs, training=False):
        if not isinstance(inputs, (dict, tuple, list)):
            input_ids = inputs
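For readers following the TF 2.0 port: Keras passes `call` a single `inputs` argument, so the new TF layers accept a plain tensor, a tuple/list, or a dict of named tensors, as the hunk above shows. A minimal sketch of that dispatch pattern (the helper name is illustrative, not from the commit):

# Editor's sketch of the inputs-dispatch pattern; `unpack_inputs` is a hypothetical name.
def unpack_inputs(inputs):
    attention_mask = None
    if isinstance(inputs, (tuple, list)):
        input_ids = inputs[0]
        attention_mask = inputs[1] if len(inputs) > 1 else None
    elif isinstance(inputs, dict):
        input_ids = inputs.get('input_ids')
        attention_mask = inputs.get('attention_mask')
    else:
        input_ids = inputs
    return input_ids, attention_mask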
File diff suppressed because it is too large

279  pytorch_transformers/modeling_tf_transfo_xl_utilities.py  (new file)
@@ -0,0 +1,279 @@
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Utilities for the TF 2.0 Transformer XL model.
    Directly adapted from https://github.com/kimiyoung/transformer-xl.
"""

from collections import defaultdict

import numpy as np

import tensorflow as tf

from .modeling_tf_utils import shape_list
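`shape_list` is imported from the library's TF utilities and used below wherever a static shape may be partially unknown. The sketch here describes that helper's usual behavior (mixing static and dynamic dimensions); treat the exact implementation as an assumption, not the committed code:

# Editor's sketch of the shape_list helper's typical behavior (assumption).
import tensorflow as tf

def shape_list(x):
    """Return the tensor shape as a Python list, falling back to dynamic
    dimensions (tf.shape) wherever the static dimension is None."""
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]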
class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                 keep_order=False, **kwargs):
        super(TFAdaptiveSoftmaxMask, self).__init__(**kwargs)

        self.n_token = n_token
        self.d_embed = d_embed
        self.d_proj = d_proj

        self.cutoffs = cutoffs + [n_token]
        self.cutoff_ends = [0] + self.cutoffs
        self.div_val = div_val

        self.shortlist_size = self.cutoffs[0]
        self.n_clusters = len(self.cutoffs) - 1
        self.head_size = self.shortlist_size + self.n_clusters
        self.keep_order = keep_order

        self.out_layers = []
        self.out_projs = []
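A small illustration of how the constructor's bookkeeping partitions the vocabulary, using the cutoffs configured in the test file added further below (values are illustrative):

# Editor's illustration of the vocabulary partition, using the test file's values.
n_token = 99
cutoffs = [10, 50, 80]

full_cutoffs = cutoffs + [n_token]          # [10, 50, 80, 99]
cutoff_ends = [0] + full_cutoffs            # [0, 10, 50, 80, 99]
n_clusters = len(full_cutoffs) - 1          # 3 tail clusters
shortlist_size = full_cutoffs[0]            # 10 head tokens
head_size = shortlist_size + n_clusters     # 13 head logits: 10 tokens + 3 cluster entries

# Token-id ranges covered by each bucket: head is [0, 10), tails are [10, 50), [50, 80), [80, 99).
buckets = [(cutoff_ends[i], cutoff_ends[i + 1]) for i in range(len(full_cutoffs))]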
    def build(self, input_shape):
        if self.n_clusters > 0:
            self.cluster_weight = self.add_weight(shape=(self.n_clusters, self.d_embed),
                                                  initializer='zeros',
                                                  trainable=True,
                                                  name='cluster_weight')
            self.cluster_bias = self.add_weight(shape=(self.n_clusters,),
                                                initializer='zeros',
                                                trainable=True,
                                                name='cluster_bias')

        if self.div_val == 1:
            for i in range(len(self.cutoffs)):
                if self.d_proj != self.d_embed:
                    weight = self.add_weight(shape=(self.d_embed, self.d_proj),
                                             initializer='zeros',
                                             trainable=True,
                                             name='out_projs_._{}'.format(i))
                    self.out_projs.append(weight)
                else:
                    self.out_projs.append(None)
                weight = self.add_weight(shape=(self.n_token, self.d_embed,),
                                         initializer='zeros',
                                         trainable=True,
                                         name='out_layers_._{}_._weight'.format(i))
                bias = self.add_weight(shape=(self.n_token,),
                                       initializer='zeros',
                                       trainable=True,
                                       name='out_layers_._{}_._bias'.format(i))
                self.out_layers.append((weight, bias))
        else:
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
                d_emb_i = self.d_embed // (self.div_val ** i)

                weight = self.add_weight(shape=(d_emb_i, self.d_proj),
                                         initializer='zeros',
                                         trainable=True,
                                         name='out_projs_._{}'.format(i))
                self.out_projs.append(weight)
                weight = self.add_weight(shape=(r_idx-l_idx, d_emb_i,),
                                         initializer='zeros',
                                         trainable=True,
                                         name='out_layers_._{}_._weight'.format(i))
                bias = self.add_weight(shape=(r_idx-l_idx,),
                                       initializer='zeros',
                                       trainable=True,
                                       name='out_layers_._{}_._bias'.format(i))
                self.out_layers.append((weight, bias))
        super(TFAdaptiveSoftmaxMask, self).build(input_shape)
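With `div_val > 1`, each successive cluster uses a geometrically smaller embedding, so the weights built above shrink per bucket. A quick illustration with the test configuration (d_embed=32, d_proj=32, div_val=2, cutoffs=[10, 50, 80], n_token=99):

# Editor's illustration of the per-cluster weight shapes created in build() when div_val > 1.
d_embed, d_proj, div_val = 32, 32, 2
n_token, cutoffs = 99, [10, 50, 80]
cutoff_ends = [0] + cutoffs + [n_token]

shapes = []
for i in range(len(cutoffs) + 1):
    l_idx, r_idx = cutoff_ends[i], cutoff_ends[i + 1]
    d_emb_i = d_embed // (div_val ** i)                        # 32, 16, 8, 4
    shapes.append({'out_projs': (d_emb_i, d_proj),             # projection back to d_proj
                   'out_layers_weight': (r_idx - l_idx, d_emb_i),
                   'out_layers_bias': (r_idx - l_idx,)})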
    @staticmethod
    def _logit(x, W, b, proj=None):
        y = x
        if proj is not None:
            y = tf.einsum('ibd,ed->ibe', y, proj)
        return tf.einsum('ibd,nd->ibn', y, W) + b

    @staticmethod
    def _gather_logprob(logprob, target):
        lp_size = tf.shape(logprob)
        r = tf.range(lp_size[0])
        idx = tf.stack([r, target], 1)
        return tf.gather_nd(logprob, idx)
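A quick sketch of what `_gather_logprob` computes: for a [rows, vocab] matrix of log-probabilities and a vector of target ids, it gathers one log-probability per row. The values below are purely illustrative:

# Editor's sketch: _gather_logprob picks logprob[i, target[i]] for every row i.
import tensorflow as tf

logprob = tf.math.log(tf.constant([[0.7, 0.2, 0.1],
                                   [0.1, 0.1, 0.8]]))
target = tf.constant([0, 2])
idx = tf.stack([tf.range(tf.shape(logprob)[0]), target], 1)   # [[0, 0], [1, 2]]
picked = tf.gather_nd(logprob, idx)                           # log(0.7), log(0.8)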
    def call(self, inputs, return_mean=True, training=False):
        hidden, target = inputs
        head_logprob = 0
        if self.n_clusters == 0:
            output = self._logit(hidden, self.out_layers[0][0], self.out_layers[0][1], self.out_projs[0])
            if target is not None:
                loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target, logits=output)
            out = tf.nn.log_softmax(output, axis=-1)
        else:
            hidden_sizes = shape_list(hidden)
            out = []
            loss = tf.zeros(hidden_sizes[:2], dtype=tf.float32)
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                if target is not None:
                    mask = (target >= l_idx) & (target < r_idx)
                    mask_idx = tf.where(mask)
                    cur_target = tf.boolean_mask(target, mask) - l_idx

                if self.div_val == 1:
                    cur_W = self.out_layers[0][0][l_idx:r_idx]
                    cur_b = self.out_layers[0][1][l_idx:r_idx]
                else:
                    cur_W = self.out_layers[i][0]
                    cur_b = self.out_layers[i][1]

                if i == 0:
                    cur_W = tf.concat([cur_W, self.cluster_weight], 0)
                    cur_b = tf.concat([cur_b, self.cluster_bias], 0)

                    head_logit = self._logit(hidden, cur_W, cur_b, self.out_projs[0])
                    head_logprob = tf.nn.log_softmax(head_logit)
                    out.append(head_logprob[..., :self.cutoffs[0]])
                    if target is not None:
                        cur_head_logprob = tf.boolean_mask(head_logprob, mask)
                        cur_logprob = self._gather_logprob(cur_head_logprob, cur_target)
                else:
                    tail_logit = self._logit(hidden, cur_W, cur_b, self.out_projs[i])
                    tail_logprob = tf.nn.log_softmax(tail_logit)

                    cluster_prob_idx = self.cutoffs[0] + i - 1  # No probability for the head cluster
                    logprob_i = head_logprob[..., cluster_prob_idx, None] + tail_logprob
                    out.append(logprob_i)
                    if target is not None:
                        cur_head_logprob = tf.boolean_mask(head_logprob, mask)
                        cur_tail_logprob = tf.boolean_mask(tail_logprob, mask)
                        cur_logprob = self._gather_logprob(cur_tail_logprob, cur_target)
                        cur_logprob += cur_head_logprob[:, self.cutoff_ends[1] + i - 1]
                if target is not None:
                    loss += tf.scatter_nd(mask_idx, -cur_logprob, tf.cast(tf.shape(loss), dtype=tf.int64))
            out = tf.concat(out, axis=-1)

        if target is not None:
            if return_mean:
                loss = tf.reduce_mean(loss)
            # Add the training-time loss value to the layer using `self.add_loss()`.
            self.add_loss(loss)

            # Log the loss as a metric (we could log arbitrary metrics,
            # including different metrics for training and inference).
            self.add_metric(loss, name=self.name, aggregation='mean' if return_mean else '')

        return out

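To show what the commit message means by "logits + loss outputs": the layer returns token log-probabilities and registers the cross-entropy via `add_loss`. A minimal usage sketch under assumed shapes ([seq_len, batch] hidden states, as in the einsum patterns above); the concrete numbers are illustrative, not from the commit:

# Editor's usage sketch for TFAdaptiveSoftmaxMask (illustrative values).
import tensorflow as tf
# assuming the class above is importable, e.g.
# from pytorch_transformers.modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask

seq_len, batch, d_model, n_token = 7, 4, 32, 99
crit = TFAdaptiveSoftmaxMask(n_token=n_token, d_embed=32, d_proj=d_model,
                             cutoffs=[10, 50, 80], div_val=2)

hidden = tf.random.normal((seq_len, batch, d_model))
target = tf.random.uniform((seq_len, batch), maxval=n_token, dtype=tf.int32)

logprobs = crit([hidden, target])   # [seq_len, batch, n_token] log-probabilities
loss = crit.losses[-1]              # scalar mean NLL registered via self.add_loss(...)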
def mul_adaptive_logsoftmax(hidden, target, n_token, d_embed, d_proj, cutoffs,
                            params, tie_projs,
                            initializer=None, proj_initializer=None,
                            div_val=1, perms=None, proj_same_dim=True,
                            scope='adaptive_softmax',
                            **kwargs):
    def _logit(x, W, b, proj):
        y = x
        if x.shape.ndims == 3:
            if proj is not None:
                y = tf.einsum('ibd,ed->ibe', y, proj)
            return tf.einsum('ibd,nd->ibn', y, W) + b
        else:
            if proj is not None:
                y = tf.einsum('id,ed->ie', y, proj)
            return tf.einsum('id,nd->in', y, W) + b

    params_W, params_projs = params[0], params[1]

    with tf.variable_scope(scope):
        if len(cutoffs) == 0:
            softmax_b = tf.get_variable('bias', [n_token],
                                        initializer=tf.zeros_initializer())
            output = _logit(hidden, params_W, softmax_b, params_projs)
            nll = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target,
                                                                 logits=output)
            nll = tf.reduce_mean(nll)
        else:
            total_loss, total_cnt = 0, 0
            cutoff_ends = [0] + cutoffs + [n_token]
            for i in range(len(cutoff_ends) - 1):
                with tf.variable_scope('cutoff_{}'.format(i)):
                    l_idx, r_idx = cutoff_ends[i], cutoff_ends[i + 1]

                    cur_d_embed = d_embed // (div_val ** i)

                    if div_val == 1:
                        cur_W = params_W[l_idx: r_idx]
                    else:
                        cur_W = params_W[i]
                    cur_b = tf.get_variable('b', [r_idx - l_idx],
                                            initializer=tf.zeros_initializer())
                    if tie_projs[i]:
                        if div_val == 1:
                            cur_proj = params_projs
                        else:
                            cur_proj = params_projs[i]
                    else:
                        if (div_val == 1 or not proj_same_dim) and d_proj == cur_d_embed:
                            cur_proj = None
                        else:
                            cur_proj = tf.get_variable('proj', [cur_d_embed, d_proj],
                                                       initializer=proj_initializer)

                    if i == 0:
                        cluster_W = tf.get_variable('cluster_W', [len(cutoffs), d_embed],
                                                    initializer=tf.zeros_initializer())
                        cluster_b = tf.get_variable('cluster_b', [len(cutoffs)],
                                                    initializer=tf.zeros_initializer())
                        cur_W = tf.concat([cur_W, cluster_W], 0)
                        cur_b = tf.concat([cur_b, cluster_b], 0)

                        head_logit = _logit(hidden, cur_W, cur_b, cur_proj)

                        head_target = kwargs.get("head_target")
                        head_nll = tf.nn.sparse_softmax_cross_entropy_with_logits(
                            labels=head_target,
                            logits=head_logit)

                        masked_loss = head_nll * perms[i]
                        total_loss += tf.reduce_sum(masked_loss)
                        total_cnt += tf.reduce_sum(perms[i])

                        # head_logprob = tf.nn.log_softmax(head_logit)

                        # final_logprob = head_logprob * perms[i][:, :, None]
                        # final_target = tf.one_hot(target, tf.shape(head_logprob)[2])
                        # total_loss -= tf.einsum('ibn,ibn->', final_logprob, final_target)
                        # total_cnt += tf.reduce_sum(perms[i])
                    else:
                        cur_head_nll = tf.einsum('ib,ibk->k', head_nll, perms[i])

                        cur_hidden = tf.einsum('ibd,ibk->kd', hidden, perms[i])
                        tail_logit = _logit(cur_hidden, cur_W, cur_b, cur_proj)

                        tail_target = tf.einsum('ib,ibk->k', tf.to_float(target - l_idx),
                                                perms[i])
                        tail_nll = tf.nn.sparse_softmax_cross_entropy_with_logits(
                            labels=tf.to_int32(tail_target),
                            logits=tail_logit)

                        sum_nll = cur_head_nll + tail_nll
                        mask = tf.reduce_sum(perms[i], [0, 1])

                        masked_loss = sum_nll * mask
                        total_loss += tf.reduce_sum(masked_loss)
                        total_cnt += tf.reduce_sum(mask)

            nll = total_loss / total_cnt

    return nll
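Both implementations above rely on the same factorization: a tail token's log-probability is the log-probability of its cluster under the head softmax plus its log-probability within that cluster's tail softmax. A tiny numerical check of that identity (editor's sketch, values made up):

# Editor's sketch of the adaptive-softmax factorization, checked numerically.
import numpy as np

head = np.array([0.5, 0.2, 0.3])   # head softmax: 2 shortlist tokens + 1 cluster entry
tail = np.array([0.25, 0.75])      # softmax inside that cluster
# log p(tail token j) = log p(cluster) + log p(j | cluster)
log_p = np.log(head[2]) + np.log(tail[1])
assert np.isclose(np.exp(log_p), head[2] * tail[1])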
@@ -261,8 +261,8 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
        self.ffns = []
        self.layer_norm2 = []
        # if self.is_decoder:
        #     self.layer_norm15 = tf.keras.layers.LayerList()
        #     self.encoder_attn = tf.keras.layers.LayerList()
        #     self.layer_norm15 = []
        #     self.encoder_attn = []

        for i in range(self.n_layers):
            self.attentions.append(TFMultiHeadAttention(self.n_heads, self.dim, config=config, name='attentions_._{}'.format(i)))
@@ -229,102 +229,11 @@ class PositionwiseFF(nn.Module):
        return output


class MultiHeadAttn(nn.Module):
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 pre_lnorm=False, r_r_bias=None, r_w_bias=None, output_attentions=False):
        super(MultiHeadAttn, self).__init__()

        self.output_attentions = output_attentions
        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout

        self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
        self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)

        self.layer_norm = nn.LayerNorm(d_model)

        self.scale = 1 / (d_head ** 0.5)

        self.pre_lnorm = pre_lnorm

        if r_r_bias is None or r_w_bias is None:  # Biases are not shared
            self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
            self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
        else:
            self.r_r_bias = r_r_bias
            self.r_w_bias = r_w_bias

    def forward(self, h, attn_mask=None, mems=None, head_mask=None):
        ##### multihead attention
        # [hlen x bsz x n_head x d_head]

        if mems is not None:
            c = torch.cat([mems, h], 0)
        else:
            c = h

        if self.pre_lnorm:
            ##### layer normalization
            c = self.layer_norm(c)

        head_q = self.q_net(h)
        head_k, head_v = torch.chunk(self.kv_net(c), 2, -1)

        head_q = head_q.view(h.size(0), h.size(1), self.n_head, self.d_head)
        head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head)
        head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)

        # [qlen x klen x bsz x n_head]
        attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
        attn_score.mul_(self.scale)
        if attn_mask is not None and torch.sum(attn_mask).item():
            attn_mask = (attn_mask == 1)  # Switch to bool
            if attn_mask.dim() == 2:
                attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf'))
            elif attn_mask.dim() == 3:
                attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf'))

        # [qlen x klen x bsz x n_head]
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

        # Mask heads if we want to
        if head_mask is not None:
            attn_prob = attn_prob * head_mask

        # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        ##### linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            ##### residual connection
            outputs = [h + attn_out]
        else:
            ##### residual connection + layer normalization
            outputs = [self.layer_norm(h + attn_out)]

        if self.output_attentions:
            outputs.append(attn_prob)

        return outputs
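The einsum patterns above (reused by the relative-attention classes below) keep tensors in [length, batch, heads, dims] layout. A small shape-only sketch of the two contractions, with illustrative sizes:

# Editor's shape sketch for the attention einsums used above (illustrative sizes).
import torch

qlen, klen, bsz, n_head, d_head = 3, 5, 2, 4, 8
head_q = torch.randn(qlen, bsz, n_head, d_head)
head_k = torch.randn(klen, bsz, n_head, d_head)
head_v = torch.randn(klen, bsz, n_head, d_head)

attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))   # [qlen, klen, bsz, n_head]
attn_prob = torch.softmax(attn_score, dim=1)                     # normalize over the key axis
attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))  # [qlen, bsz, n_head, d_head]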
class RelMultiHeadAttn(nn.Module):
class RelPartialLearnableMultiHeadAttn(nn.Module):
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
                 r_r_bias=None, r_w_bias=None, output_attentions=False):
        super(RelMultiHeadAttn, self).__init__()
        super(RelPartialLearnableMultiHeadAttn, self).__init__()

        self.output_attentions = output_attentions
        self.n_head = n_head
@@ -351,36 +260,9 @@ class RelMultiHeadAttn(nn.Module):
            self.r_r_bias = r_r_bias
            self.r_w_bias = r_w_bias

    def _parallelogram_mask(self, h, w, left=False):
        mask = torch.ones((h, w)).byte()
        m = min(h, w)
        mask[:m,:m] = torch.triu(mask[:m,:m])
        mask[-m:,-m:] = torch.tril(mask[-m:,-m:])
        self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)

        if left:
            return mask
        else:
            return mask.flip(0)

    def _shift(self, x, qlen, klen, mask, left=False):
        if qlen > 1:
            zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)),
                                   device=x.device, dtype=x.dtype)
        else:
            zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype)

        if left:
            mask = mask.flip(1)
            x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1)
        else:
            x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1)

        x = x_padded.masked_select(mask[:,:,None,None]) \
                    .view(qlen, klen, x.size(2), x.size(3))

        return x

    def _rel_shift(self, x, zero_triu=False):
    def _rel_shift(self, x):
        zero_pad_shape = (x.size(0), 1) + x.size()[2:]
        zero_pad = torch.zeros(zero_pad_shape, device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=1)
@@ -390,21 +272,8 @@ class RelMultiHeadAttn(nn.Module):

        x = x_padded[1:].view_as(x)

        if zero_triu:
            ones = torch.ones((x.size(0), x.size(1)))
            x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None]

        return x
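The relative-shift trick kept in `_rel_shift` realigns the [qlen, klen] relative-position scores by padding one column of zeros and reshaping. An editor's numerical sketch of the same manipulation on a tiny tensor (extra batch/head dims added to match the method's layout):

# Editor's sketch of the _rel_shift padding-and-reshape trick on a tiny tensor.
import torch

qlen, klen = 3, 3
x = torch.arange(qlen * klen, dtype=torch.float).view(qlen, klen, 1, 1)  # [qlen, klen, bsz=1, n_head=1]

zero_pad = torch.zeros((qlen, 1, 1, 1))
x_padded = torch.cat([zero_pad, x], dim=1)      # pad one column of zeros on the left
x_padded = x_padded.view(klen + 1, qlen, 1, 1)  # fold the pad into the row dimension
shifted = x_padded[1:].view_as(x)               # drop the first row and reshape back
# Row i is now shifted left by (qlen - 1 - i); entries that run off the end pick up
# padding / next-row values, which the causal attention mask discards in practice.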
    def forward(self, w, r, attn_mask=None, mems=None):
        raise NotImplementedError


class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
    def __init__(self, *args, **kwargs):
        super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs)

        self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)

    def forward(self, w, r, attn_mask=None, mems=None, head_mask=None):
        qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)

@@ -488,138 +357,6 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):

        return outputs


class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
    def __init__(self, *args, **kwargs):
        super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs)

    def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None, head_mask=None):
        # r_emb: [klen, n_head, d_head], used for term B
        # r_w_bias: [n_head, d_head], used for term C
        # r_bias: [klen, n_head], used for term D

        qlen, bsz = w.size(0), w.size(1)

        if mems is not None:
            cat = torch.cat([mems, w], 0)
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(cat))
            else:
                w_heads = self.qkv_net(cat)
            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)

            w_head_q = w_head_q[-qlen:]
        else:
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(w))
            else:
                w_heads = self.qkv_net(w)
            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)

        klen = w_head_k.size(0)

        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)
        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)
        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)

        if klen > r_emb.size(0):
            r_emb_pad = r_emb[0:1].expand(klen-r_emb.size(0), -1, -1)
            r_emb = torch.cat([r_emb_pad, r_emb], 0)
            r_bias_pad = r_bias[0:1].expand(klen-r_bias.size(0), -1)
            r_bias = torch.cat([r_bias_pad, r_bias], 0)
        else:
            r_emb = r_emb[-klen:]
            r_bias = r_bias[-klen:]

        #### compute attention score
        rw_head_q = w_head_q + r_w_bias[None]                        # qlen x bsz x n_head x d_head

        AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))  # qlen x klen x bsz x n_head
        B_ = torch.einsum('ibnd,jnd->ijbn', (w_head_q, r_emb))       # qlen x klen x bsz x n_head
        D_ = r_bias[None, :, None]                                   # 1    x klen x 1   x n_head
        BD = self._rel_shift(B_ + D_)

        # [qlen x klen x bsz x n_head]
        attn_score = AC + BD
        attn_score.mul_(self.scale)

        #### compute attention probability
        if attn_mask is not None and torch.sum(attn_mask).item():
            attn_mask = (attn_mask == 1)  # Switch to bool
            if attn_mask.dim() == 2:
                attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf'))
            elif attn_mask.dim() == 3:
                attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf'))

        # [qlen x klen x bsz x n_head]
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

        if head_mask is not None:
            attn_prob = attn_prob * head_mask

        #### compute attention vector
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))

        # [qlen x bsz x n_head x d_head]
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        ##### linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            ##### residual connection
            outputs = [w + attn_out]
        else:
            ##### residual connection + layer normalization
            outputs = [self.layer_norm(w + attn_out)]

        if self.output_attentions:
            outputs.append(attn_prob)

        return outputs
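For orientation, the AC/BD names in the attention code above (and in the retained RelPartialLearnableMultiHeadAttn) follow the Transformer-XL decomposition of the score into content and position terms. Schematically, with R the relative-position embeddings, b a relative bias, and r_w the content bias (r_w_bias in the code); this is the editor's schematic of the paper's equation, not code from the commit:

A^{\mathrm{rel}}_{i,j} = \underbrace{(q_i + r_w)^{\top} k_j}_{AC\ \text{(content terms)}} + \underbrace{\big[\mathrm{rel\_shift}\,(q^{\top} R + b)\big]_{i,j}}_{BD\ \text{(position terms)}}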
class DecoderLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
        super(DecoderLayer, self).__init__()

        self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                     pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, dec_attn_mask=None, mems=None, head_mask=None):

        attn_outputs = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
                                     mems=mems, head_mask=head_mask)
        ff_output = self.pos_ff(attn_outputs[0])

        outputs = [ff_output] + attn_outputs[1:]

        return outputs


class RelLearnableDecoderLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, dropout,
                 **kwargs):
        super(RelLearnableDecoderLayer, self).__init__()

        self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
                                                  **kwargs)
        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                     pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None, head_mask=None):

        attn_outputs = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
                                     attn_mask=dec_attn_mask,
                                     mems=mems, head_mask=head_mask)
        ff_output = self.pos_ff(attn_outputs[0])

        outputs = [ff_output] + attn_outputs[1:]

        return outputs


class RelPartialLearnableDecoderLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, dropout,
@@ -643,7 +380,6 @@ class RelPartialLearnableDecoderLayer(nn.Module):
        return outputs


class AdaptiveEmbedding(nn.Module):
    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                 sample_softmax=False):
@@ -767,9 +503,6 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
            if hasattr(m, 'r_bias'):
                self._init_bias(m.r_bias)

    def set_num_special_tokens(self, num_special_tokens):
        pass


TRANSFO_XL_START_DOCSTRING = r""" The Transformer-XL model was proposed in
    `Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`_
@@ -882,43 +615,16 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
                        r_r_bias=None if config.untie_r else self.r_r_bias,
                        output_attentions=self.output_attentions)
                )
        elif config.attn_type == 1: # learnable embeddings
            for i in range(config.n_layer):
                self.layers.append(
                    RelLearnableDecoderLayer(
                        config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
                        tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len,
                        dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
                        r_w_bias=None if config.untie_r else self.r_w_bias,
                        r_r_bias=None if config.untie_r else self.r_r_bias,
                        output_attentions=self.output_attentions)
                )
        elif config.attn_type in [2, 3]: # absolute embeddings
            for i in range(config.n_layer):
                self.layers.append(
                    DecoderLayer(
                        config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
                        dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
                        r_w_bias=None if config.untie_r else self.r_w_bias,
                        r_r_bias=None if config.untie_r else self.r_r_bias,
                        output_attentions=self.output_attentions)
                )
        else: # learnable embeddings and absolute embeddings are not used in our pretrained checkpoints
            raise NotImplementedError # Removed them to avoid maintaining dead code

        self.same_length = config.same_length
        self.clamp_len = config.clamp_len

        if self.attn_type == 0: # default attention
            self.pos_emb = PositionalEmbedding(self.d_model)
        elif self.attn_type == 1: # learnable
            self.r_emb = nn.Parameter(torch.FloatTensor(
                self.n_layer, self.max_klen, self.n_head, self.d_head))
            self.r_bias = nn.Parameter(torch.FloatTensor(
                self.n_layer, self.max_klen, self.n_head))
        elif self.attn_type == 2: # absolute standard
            self.pos_emb = PositionalEmbedding(self.d_model)
        elif self.attn_type == 3: # absolute deeper SA
            self.r_emb = nn.Parameter(torch.FloatTensor(
                self.n_layer, self.max_klen, self.n_head, self.d_head))
        else: # learnable embeddings and absolute embeddings
            raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint

        self.init_weights()
@@ -973,8 +679,15 @@ class TransfoXLModel(TransfoXLPreTrainedModel):

        return new_mems

    def _forward(self, dec_inp, mems=None, head_mask=None):
        qlen, bsz = dec_inp.size()
    def forward(self, input_ids, mems=None, head_mask=None):
        # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
        # so we transpose here from shape [bsz, len] to shape [len, bsz]
        input_ids = input_ids.transpose(0, 1).contiguous()

        if mems is None:
            mems = self.init_mems(input_ids)

        qlen, bsz = input_ids.size()

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
@@ -991,7 +704,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
        else:
            head_mask = [None] * self.n_layer

        word_emb = self.word_emb(dec_inp)
        word_emb = self.word_emb(input_ids)

        mlen = mems[0].size(0) if mems is not None else 0
        klen = mlen + qlen
@@ -1028,64 +741,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
                core_out = layer_outputs[0]
                if self.output_attentions:
                    attentions.append(layer_outputs[1])
        elif self.attn_type == 1: # learnable
            core_out = self.drop(word_emb)
            for i, layer in enumerate(self.layers):
                hids.append(core_out)
                if self.clamp_len > 0:
                    r_emb = self.r_emb[i][-self.clamp_len :]
                    r_bias = self.r_bias[i][-self.clamp_len :]
                else:
                    r_emb, r_bias = self.r_emb[i], self.r_bias[i]

                mems_i = None if mems is None else mems[i]
                layer_outputs = layer(core_out, r_emb, self.r_w_bias[i],
                                      r_bias, dec_attn_mask=dec_attn_mask,
                                      mems=mems_i, head_mask=head_mask[i])
                core_out = layer_outputs[0]
                if self.output_attentions:
                    attentions.append(layer_outputs[1])
        elif self.attn_type == 2: # absolute
            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb + pos_emb[-qlen:])

            for i, layer in enumerate(self.layers):
                hids.append(core_out)
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and i == 0:
                    mems_i += pos_emb[:mlen]
                layer_outputs = layer(core_out, dec_attn_mask=dec_attn_mask,
                                      mems=mems_i, head_mask=head_mask[i])
                core_out = layer_outputs[0]
                if self.output_attentions:
                    attentions.append(layer_outputs[1])
        elif self.attn_type == 3:
            core_out = self.drop(word_emb)

            for i, layer in enumerate(self.layers):
                hids.append(core_out)
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and mlen > 0:
                    cur_emb = self.r_emb[i][:-qlen]
                    cur_size = cur_emb.size(0)
                    if cur_size < mlen:
                        cur_emb_pad = cur_emb[0:1].expand(mlen-cur_size, -1, -1)
                        cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)
                    else:
                        cur_emb = cur_emb[-mlen:]
                    mems_i += cur_emb.view(mlen, 1, -1)
                core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)

                layer_outputs = layer(core_out, dec_attn_mask=dec_attn_mask,
                                      mems=mems_i, head_mask=head_mask[i])
                core_out = layer_outputs[0]
                if self.output_attentions:
                    attentions.append(layer_outputs[1])
        else: # learnable embeddings and absolute embeddings
            raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint

        core_out = self.drop(core_out)

@@ -1102,16 +759,6 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
            # Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
            attentions = list(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
            outputs.append(attentions)
        return outputs  # last hidden state, new_mems, (all hidden states), (all attentions)

    def forward(self, input_ids, mems=None, head_mask=None):
        # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
        # so we transpose here from shape [bsz, len] to shape [len, bsz]
        input_ids = input_ids.transpose(0, 1).contiguous()

        if mems is None:
            mems = self.init_mems(input_ids)
        outputs = self._forward(input_ids, mems=mems, head_mask=head_mask)

        return outputs  # last hidden state, new_mems, (all hidden states), (all attentions)
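The merged forward above standardizes the library's [batch, seq_len] input convention while keeping Transformer-XL's internal [seq_len, batch] layout. A trivial editor's sketch of the round-trip, for orientation only:

# Editor's sketch of the [bsz, len] <-> [len, bsz] convention used in forward above.
import torch

input_ids = torch.randint(0, 99, (4, 7))           # library convention: [batch, seq_len]
internal = input_ids.transpose(0, 1).contiguous()  # Transformer-XL convention: [seq_len, batch]
assert internal.shape == (7, 4)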
@@ -131,10 +131,14 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):

        def create_and_check_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
            model = TFBertModel(config=config)
            # inputs = {'input_ids': input_ids,
            #           'attention_mask': input_mask,
            #           'token_type_ids': token_type_ids}
            # sequence_output, pooled_output = model(**inputs)
            inputs = {'input_ids': input_ids,
                      'attention_mask': input_mask,
                      'token_type_ids': token_type_ids}
            sequence_output, pooled_output = model(inputs)
            sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)

            inputs = [input_ids, input_mask]
            sequence_output, pooled_output = model(inputs)
217  pytorch_transformers/tests/modeling_tf_transfo_xl_test.py  (new file)
@@ -0,0 +1,217 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest
import random
import shutil
import pytest

from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester

from pytorch_transformers import TransfoXLConfig, is_tf_available

if is_tf_available():
    import tensorflow as tf
    from pytorch_transformers.modeling_tf_transfo_xl import (TFTransfoXLModel,
                                                             TFTransfoXLLMHeadModel,
                                                             TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
else:
    pytestmark = pytest.mark.skip("Require TensorFlow")


class TFTransfoXLModelTest(TFCommonTestCases.TFCommonModelTester):

    all_model_classes = (TFTransfoXLModel, TFTransfoXLLMHeadModel) if is_tf_available() else ()
    test_pruning = False
    test_torchscript = False
    test_resize_embeddings = False

    class TFTransfoXLModelTester(object):

        def __init__(self,
                     parent,
                     batch_size=13,
                     seq_length=7,
                     mem_len=30,
                     clamp_len=15,
                     is_training=True,
                     use_labels=True,
                     vocab_size=99,
                     cutoffs=[10, 50, 80],
                     hidden_size=32,
                     d_embed=32,
                     num_attention_heads=4,
                     d_head=8,
                     d_inner=128,
                     div_val=2,
                     num_hidden_layers=5,
                     scope=None,
                     seed=1,
                     ):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.mem_len = mem_len
            self.key_len = seq_length + mem_len
            self.clamp_len = clamp_len
            self.is_training = is_training
            self.use_labels = use_labels
            self.vocab_size = vocab_size
            self.cutoffs = cutoffs
            self.hidden_size = hidden_size
            self.d_embed = d_embed
            self.num_attention_heads = num_attention_heads
            self.d_head = d_head
            self.d_inner = d_inner
            self.div_val = div_val
            self.num_hidden_layers = num_hidden_layers
            self.scope = scope
            self.seed = seed

        def prepare_config_and_inputs(self):
            input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
            input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

            lm_labels = None
            if self.use_labels:
                lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

            config = TransfoXLConfig(
                vocab_size_or_config_json_file=self.vocab_size,
                mem_len=self.mem_len,
                clamp_len=self.clamp_len,
                cutoffs=self.cutoffs,
                d_model=self.hidden_size,
                d_embed=self.d_embed,
                n_head=self.num_attention_heads,
                d_head=self.d_head,
                d_inner=self.d_inner,
                div_val=self.div_val,
                n_layer=self.num_hidden_layers)

            return (config, input_ids_1, input_ids_2, lm_labels)

        def set_seed(self):
            random.seed(self.seed)
            tf.random.set_seed(self.seed)

        def create_and_check_transfo_xl_model(self, config, input_ids_1, input_ids_2, lm_labels):
            model = TFTransfoXLModel(config)

            hidden_states_1, mems_1 = model(input_ids_1)

            inputs = {'input_ids': input_ids_2,
                      'mems': mems_1}

            hidden_states_2, mems_2 = model(inputs)

            result = {
                "hidden_states_1": hidden_states_1.numpy(),
                "mems_1": [mem.numpy() for mem in mems_1],
                "hidden_states_2": hidden_states_2.numpy(),
                "mems_2": [mem.numpy() for mem in mems_2],
            }

            self.parent.assertListEqual(
                list(result["hidden_states_1"].shape),
                [self.batch_size, self.seq_length, self.hidden_size])
            self.parent.assertListEqual(
                list(result["hidden_states_2"].shape),
                [self.batch_size, self.seq_length, self.hidden_size])
            self.parent.assertListEqual(
                list(list(mem.shape) for mem in result["mems_1"]),
                [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
            self.parent.assertListEqual(
                list(list(mem.shape) for mem in result["mems_2"]),
                [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)

        def create_and_check_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels):
            model = TFTransfoXLLMHeadModel(config)

            lm_logits_1, mems_1 = model(input_ids_1)

            inputs = {'input_ids': input_ids_1,
                      'labels': lm_labels}
            _, mems_1 = model(inputs)

            lm_logits_2, mems_2 = model([input_ids_2, mems_1])

            inputs = {'input_ids': input_ids_1,
                      'mems': mems_1,
                      'labels': lm_labels}

            _, mems_2 = model(inputs)

            result = {
                "mems_1": [mem.numpy() for mem in mems_1],
                "lm_logits_1": lm_logits_1.numpy(),
                "mems_2": [mem.numpy() for mem in mems_2],
                "lm_logits_2": lm_logits_2.numpy(),
            }

            self.parent.assertListEqual(
                list(result["lm_logits_1"].shape),
                [self.batch_size, self.seq_length, self.vocab_size])
            self.parent.assertListEqual(
                list(list(mem.shape) for mem in result["mems_1"]),
                [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)

            self.parent.assertListEqual(
                list(result["lm_logits_2"].shape),
                [self.batch_size, self.seq_length, self.vocab_size])
            self.parent.assertListEqual(
                list(list(mem.shape) for mem in result["mems_2"]),
                [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)

        def prepare_config_and_inputs_for_common(self):
            config_and_inputs = self.prepare_config_and_inputs()
            (config, input_ids_1, input_ids_2, lm_labels) = config_and_inputs
            inputs_dict = {'input_ids': input_ids_1}
            return config, inputs_dict

    def setUp(self):
        self.model_tester = TFTransfoXLModelTest.TFTransfoXLModelTester(self)
        self.config_tester = ConfigTester(self, config_class=TransfoXLConfig, d_embed=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_transfo_xl_model(self):
        self.model_tester.set_seed()
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_transfo_xl_model(*config_and_inputs)

    def test_transfo_xl_lm_head(self):
        self.model_tester.set_seed()
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_transfo_xl_lm_head(*config_and_inputs)

    @pytest.mark.slow
    def test_model_from_pretrained(self):
        cache_dir = "/tmp/pytorch_transformers_test/"
        for model_name in list(TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            model = TFTransfoXLModel.from_pretrained(model_name, cache_dir=cache_dir)
            shutil.rmtree(cache_dir)
            self.assertIsNotNone(model)


if __name__ == "__main__":
    unittest.main()