# coding=utf-8 # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ TF 2.0 Transformer XL model. """ from __future__ import absolute_import, division, print_function, unicode_literals import os import json import math import logging import collections import sys from io import open import numpy as np import tensorflow as tf from .configuration_transfo_xl import TransfoXLConfig from .modeling_tf_utils import TFPreTrainedModel, TFConv1D, TFSequenceSummary, shape_list from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask from .file_utils import add_start_docstrings from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model logger = logging.getLogger(__name__) TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = { 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-tf_model.h5", } def load_transfo_xl_pt_weights_in_tf2(tf_model, pytorch_checkpoint_path): # build the network inputs_list = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]] tf_inputs = tf.constant(inputs_list) tfo = tf_model(tf_inputs, training=False) return load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=tf_inputs) class TFPositionalEmbedding(tf.keras.layers.Layer): def __init__(self, demb, **kwargs): super(TFPositionalEmbedding, self).__init__(**kwargs) self.inv_freq = 1 / (10000 ** (tf.range(0, demb, 2.0) / demb)) def call(self, pos_seq, bsz=None): sinusoid_inp = tf.einsum('i,j->ij', pos_seq, self.inv_freq) pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1) if bsz is not None: return tf.tile(pos_emb[:, None, :], [1, bsz, 1]) else: return pos_emb[:, None, :] class TFPositionwiseFF(tf.keras.layers.Layer): def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, **kwargs): super(TFPositionwiseFF, self).__init__(**kwargs) self.d_model = d_model self.d_inner = d_inner self.dropout = dropout self.layer_1 = tf.keras.layers.Dense(d_inner, activation=tf.nn.relu, name='CoreNet_._0') self.drop_1 = tf.keras.layers.Dropout(dropout) self.layer_2 = tf.keras.layers.Dense(d_model, name='CoreNet_._2') self.drop_2 = tf.keras.layers.Dropout(dropout) self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name='layer_norm') self.pre_lnorm = pre_lnorm def call(self, inp, training=False): if self.pre_lnorm: ##### layer normalization + positionwise feed-forward core_out = self.layer_norm(inp) core_out = self.layer_1(core_out) core_out = self.drop_1(core_out, training=training) core_out = self.layer_2(core_out) core_out = self.drop_2(core_out, training=training) ##### residual connection output = core_out + inp else: ##### positionwise feed-forward core_out = self.layer_1(inp) core_out = self.drop_1(core_out, training=training) core_out = self.layer_2(core_out) core_out = self.drop_2(core_out, training=training) ##### residual connection + layer normalization output = self.layer_norm(inp + core_out) return output class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer): def __init__(self, n_head, d_model, d_head, dropout, dropatt=0, tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False, r_r_bias=None, r_w_bias=None, output_attentions=False, **kwargs): super(TFRelPartialLearnableMultiHeadAttn, self).__init__(**kwargs) self.output_attentions = output_attentions self.n_head = n_head self.d_model = d_model self.d_head = d_head self.dropout = dropout self.qkv_net = tf.keras.layers.Dense(3 * n_head * d_head, use_bias=False, name='qkv_net') self.drop = tf.keras.layers.Dropout(dropout) self.dropatt = tf.keras.layers.Dropout(dropatt) self.o_net = tf.keras.layers.Dense(d_model, use_bias=False, name='o_net') self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name='layer_norm') self.scale = 1 / (d_head ** 0.5) self.pre_lnorm = pre_lnorm if r_r_bias is not None and r_w_bias is not None: # Biases are shared self.r_r_bias = r_r_bias self.r_w_bias = r_w_bias else: self.r_r_bias = None self.r_w_bias = None self.r_net = tf.keras.layers.Dense(self.n_head * self.d_head, use_bias=False, name='r_net') def build(self, input_shape): if self.r_r_bias is None or self.r_w_bias is None: # Biases are not shared self.r_r_bias = self.add_weight(shape=(self.n_head, self.d_head), trainable=True, name='r_r_bias') self.r_w_bias = self.add_weight(shape=(self.n_head, self.d_head), trainable=True, name='r_w_bias') super(TFRelPartialLearnableMultiHeadAttn, self).build(input_shape) def _rel_shift(self, x): x_size = shape_list(x) x = tf.pad(x, [[0, 0], [1, 0], [0, 0], [0, 0]]) x = tf.reshape(x, [x_size[1] + 1, x_size[0], x_size[2], x_size[3]]) x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1]) x = tf.reshape(x, x_size) return x def call(self, inputs, training=False): w, r, attn_mask, mems, head_mask = inputs qlen, rlen, bsz = shape_list(w)[0], shape_list(r)[0], shape_list(w)[1] if mems is not None: cat = tf.concat([mems, w], 0) if self.pre_lnorm: w_heads = self.qkv_net(self.layer_norm(cat)) else: w_heads = self.qkv_net(cat) r_head_k = self.r_net(r) w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, axis=-1) w_head_q = w_head_q[-qlen:] else: if self.pre_lnorm: w_heads = self.qkv_net(self.layer_norm(w)) else: w_heads = self.qkv_net(w) r_head_k = self.r_net(r) w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, axis=-1) klen = shape_list(w_head_k)[0] w_head_q = tf.reshape(w_head_q, (qlen, bsz, self.n_head, self.d_head)) # qlen x bsz x n_head x d_head w_head_k = tf.reshape(w_head_k, (klen, bsz, self.n_head, self.d_head)) # qlen x bsz x n_head x d_head w_head_v = tf.reshape(w_head_v, (klen, bsz, self.n_head, self.d_head)) # qlen x bsz x n_head x d_head r_head_k = tf.reshape(r_head_k, (rlen, self.n_head, self.d_head)) # qlen x n_head x d_head #### compute attention score rw_head_q = w_head_q + self.r_w_bias # qlen x bsz x n_head x d_head AC = tf.einsum('ibnd,jbnd->ijbn', rw_head_q, w_head_k) # qlen x klen x bsz x n_head rr_head_q = w_head_q + self.r_r_bias BD = tf.einsum('ibnd,jnd->ijbn', rr_head_q, r_head_k) # qlen x klen x bsz x n_head BD = self._rel_shift(BD) # [qlen x klen x bsz x n_head] attn_score = AC + BD attn_score = attn_score * self.scale #### compute attention probability if attn_mask is not None: attn_mask_t = attn_mask[:, :, None, None] attn_score = attn_score * (1 - attn_mask_t) - 1e30 * attn_mask_t # [qlen x klen x bsz x n_head] attn_prob = tf.nn.softmax(attn_score, axis=1) attn_prob = self.dropatt(attn_prob, training=training) # Mask heads if we want to if head_mask is not None: attn_prob = attn_prob * head_mask #### compute attention vector attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, w_head_v) # [qlen x bsz x n_head x d_head] attn_vec_sizes = shape_list(attn_vec) attn_vec = tf.reshape(attn_vec, (attn_vec_sizes[0], attn_vec_sizes[1], self.n_head * self.d_head)) ##### linear projection attn_out = self.o_net(attn_vec) attn_out = self.drop(attn_out, training=training) if self.pre_lnorm: ##### residual connection outputs = [w + attn_out] else: ##### residual connection + layer normalization outputs = [self.layer_norm(w + attn_out)] if self.output_attentions: outputs.append(attn_prob) return outputs class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer): def __init__(self, n_head, d_model, d_head, d_inner, dropout, tgt_len=None, ext_len=None, mem_len=None, dropatt=0., pre_lnorm=False, r_w_bias=None, r_r_bias=None, output_attentions=False, **kwargs): super(TFRelPartialLearnableDecoderLayer, self).__init__(**kwargs) self.dec_attn = TFRelPartialLearnableMultiHeadAttn(n_head, d_model, d_head, dropout, tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len, dropatt=dropatt, pre_lnorm=pre_lnorm, r_w_bias=r_w_bias, r_r_bias=r_r_bias, output_attentions=output_attentions, name='dec_attn') self.pos_ff = TFPositionwiseFF(d_model, d_inner, dropout, pre_lnorm=pre_lnorm, name='pos_ff') def call(self, inputs, training=False): dec_inp, r, dec_attn_mask, mems, head_mask = inputs attn_outputs = self.dec_attn([dec_inp, r, dec_attn_mask, mems, head_mask], training=training) ff_output = self.pos_ff(attn_outputs[0], training=training) outputs = [ff_output] + attn_outputs[1:] return outputs class TFAdaptiveEmbedding(tf.keras.layers.Layer): def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, sample_softmax=False, **kwargs): super(TFAdaptiveEmbedding, self).__init__(**kwargs) self.n_token = n_token self.d_embed = d_embed self.cutoffs = cutoffs + [n_token] self.div_val = div_val self.d_proj = d_proj self.emb_scale = d_proj ** 0.5 self.cutoff_ends = [0] + self.cutoffs self.emb_layers = [] self.emb_projs = [] if div_val == 1: raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint else: for i in range(len(self.cutoffs)): l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1] d_emb_i = d_embed // (div_val ** i) self.emb_layers.append(tf.keras.layers.Embedding(r_idx-l_idx, d_emb_i, name='emb_layers_._{}'.format(i))) def build(self, input_shape): for i in range(len(self.cutoffs)): d_emb_i = self.d_embed // (self.div_val ** i) self.emb_projs.append(self.add_weight(shape=(d_emb_i, self.d_proj), trainable=True, name='emb_projs._{}'.format(i))) super(TFAdaptiveEmbedding, self).build(input_shape) def call(self, inp): if self.div_val == 1: raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint else: inp_flat = tf.reshape(inp, (-1,)) emb_flat = tf.zeros([shape_list(inp_flat)[0], self.d_proj]) for i in range(len(self.cutoffs)): l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx) inp_i = tf.boolean_mask(inp_flat, mask_i) - l_idx emb_i = self.emb_layers[i](inp_i) emb_i = tf.einsum('id,de->ie', emb_i, self.emb_projs[i]) mask_idx = tf.cast(tf.where(mask_i), dtype=tf.int64) emb_flat += tf.scatter_nd(mask_idx, emb_i, tf.cast(tf.shape(emb_flat), dtype=tf.int64)) embed_shape = shape_list(inp) + [self.d_proj] embed = tf.reshape(emb_flat, embed_shape) embed *= self.emb_scale return embed class TFTransfoXLMainLayer(tf.keras.layers.Layer): def __init__(self, config, **kwargs): super(TFTransfoXLMainLayer, self).__init__(**kwargs) self.output_attentions = config.output_attentions self.output_hidden_states = config.output_hidden_states self.n_token = config.n_token self.d_embed = config.d_embed self.d_model = config.d_model self.n_head = config.n_head self.d_head = config.d_head self.untie_r = config.untie_r self.word_emb = TFAdaptiveEmbedding(config.n_token, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val, name='word_emb') self.drop = tf.keras.layers.Dropout(config.dropout) self.n_layer = config.n_layer self.tgt_len = config.tgt_len self.mem_len = config.mem_len self.ext_len = config.ext_len self.max_klen = config.tgt_len + config.ext_len + config.mem_len self.attn_type = config.attn_type self.layers = [] if config.attn_type == 0: # the default attention for i in range(config.n_layer): self.layers.append( TFRelPartialLearnableDecoderLayer( config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout, tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len, dropatt=config.dropatt, pre_lnorm=config.pre_lnorm, r_w_bias=None if self.untie_r else self.r_w_bias, r_r_bias=None if self.untie_r else self.r_r_bias, output_attentions=self.output_attentions, name='layers_._{}'.format(i)) ) else: # learnable embeddings and absolute embeddings raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint self.same_length = config.same_length self.clamp_len = config.clamp_len if self.attn_type == 0: # default attention self.pos_emb = TFPositionalEmbedding(self.d_model, name='pos_emb') else: # learnable embeddings and absolute embeddings raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint def build(self, input_shape): if not self.untie_r: self.r_w_bias = self.add_weight(shape=(self.n_head, self.d_head), initializer='zeros', trainable=True, name='r_w_bias') self.r_r_bias = self.add_weight(shape=(self.n_head, self.d_head), initializer='zeros', trainable=True, name='r_r_bias') super(TFTransfoXLMainLayer, self).build(input_shape) def _resize_token_embeddings(self, new_num_tokens): return self.word_emb def backward_compatible(self): self.sample_softmax = -1 def reset_length(self, tgt_len, ext_len, mem_len): self.tgt_len = tgt_len self.mem_len = mem_len self.ext_len = ext_len def _prune_heads(self, heads): raise NotImplementedError def init_mems(self, data): if self.mem_len > 0: mems = [] for i in range(self.n_layer): empty = tf.zeros([self.mem_len, shape_list(data)[1], self.d_model]) mems.append(empty) return mems else: return None def _update_mems(self, hids, mems, qlen, mlen): # does not deal with None if mems is None: return None # mems is not None assert len(hids) == len(mems), 'len(hids) != len(mems)' # There are `mlen + qlen` steps that can be cached into mems # For the next step, the last `ext_len` of the `qlen` tokens # will be used as the extended context. Hence, we only cache # the tokens from `mlen + qlen - self.ext_len - self.mem_len` # to `mlen + qlen - self.ext_len`. new_mems = [] end_idx = mlen + max(0, qlen - 0 - self.ext_len) beg_idx = max(0, end_idx - self.mem_len) for i in range(len(hids)): cat = tf.concat([mems[i], hids[i]], axis=0) tf.stop_gradient(cat) new_mems.append(cat[beg_idx:end_idx]) return new_mems def call(self, inputs, training=False): if not isinstance(inputs, (dict, tuple, list)): input_ids = inputs mems, head_mask = None, None elif isinstance(inputs, (tuple, list)): input_ids = inputs[0] mems = inputs[1] if len(inputs) > 1 else None head_mask = inputs[2] if len(inputs) > 2 else None assert len(inputs) <= 3, "Too many inputs." else: input_ids = inputs.get('input_ids') mems = inputs.get('mems', None) head_mask = inputs.get('head_mask', None) assert len(inputs) <= 3, "Too many inputs." # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library # so we transpose here from shape [bsz, len] to shape [len, bsz] input_ids = tf.transpose(input_ids, perm=(1, 0)) if mems is None: mems = self.init_mems(input_ids) qlen, bsz = shape_list(input_ids) # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer) # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head] if not head_mask is None: raise NotImplementedError else: head_mask = [None] * self.n_layer word_emb = self.word_emb(input_ids) mlen = shape_list(mems[0])[0] if mems is not None else 0 klen = mlen + qlen attn_mask = tf.ones([qlen, qlen]) mask_u = tf.linalg.band_part(attn_mask, 0, -1) mask_dia = tf.linalg.band_part(attn_mask, 0, 0) attn_mask_pad = tf.zeros([qlen, mlen]) dec_attn_mask = tf.concat([attn_mask_pad, mask_u - mask_dia], 1) if self.same_length: mask_l = tf.linalg.band_part(attn_mask, -1, 0) dec_attn_mask = tf.concat([dec_attn_mask[:, :qlen] + mask_l - mask_dia, dec_attn_mask[:, qlen:]], 1) # ::: PyTorch masking code for reference ::: # if self.same_length: # all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8) # mask_len = klen - self.mem_len # if mask_len > 0: # mask_shift_len = qlen - mask_len # else: # mask_shift_len = qlen # dec_attn_mask = (torch.triu(all_ones, 1+mlen) # + torch.tril(all_ones, -mask_shift_len))[:, :, None] # -1 # else: # dec_attn_mask = torch.triu( # word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1+mlen)[:,:,None] hids = [] attentions = [] if self.attn_type == 0: # default pos_seq = tf.range(klen-1, -1, -1.0) if self.clamp_len > 0: pos_seq = tf.minimum(pos_seq, self.clamp_len) pos_emb = self.pos_emb(pos_seq) core_out = self.drop(word_emb, training=training) pos_emb = self.drop(pos_emb, training=training) for i, layer in enumerate(self.layers): hids.append(core_out) mems_i = None if mems is None else mems[i] layer_outputs = layer([core_out, pos_emb, dec_attn_mask, mems_i, head_mask[i]], training=training) core_out = layer_outputs[0] if self.output_attentions: attentions.append(layer_outputs[1]) else: # learnable embeddings and absolute embeddings raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint core_out = self.drop(core_out, training=training) new_mems = self._update_mems(hids, mems, mlen, qlen) # We transpose back here to shape [bsz, len, hidden_dim] outputs = [tf.transpose(core_out, perm=(1, 0, 2)), new_mems] if self.output_hidden_states: # Add last layer and transpose to library standard shape [bsz, len, hidden_dim] hids.append(core_out) hids = list(tf.transpose(t, perm=(1, 0, 2)) for t in hids) outputs.append(hids) if self.output_attentions: # Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len] attentions = list(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions) outputs.append(attentions) return outputs # last hidden state, new_mems, (all hidden states), (all attentions) class TFTransfoXLPreTrainedModel(TFPreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for dowloading and loading pretrained models. """ config_class = TransfoXLConfig pretrained_model_archive_map = TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP load_pt_weights = load_transfo_xl_pt_weights_in_tf2 base_model_prefix = "transformer" TRANSFO_XL_START_DOCSTRING = r""" The Transformer-XL model was proposed in `Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`_ by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov. It's a causal (uni-directional) transformer with relative positioning (sinusoïdal) embeddings which can reuse previously computed hidden-states to attend to longer context (memory). This model also uses adaptive softmax inputs and outputs (tied). This model is a PyTorch `torch.tf.keras.layers.Layer`_ sub-class. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. .. _`Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`: https://arxiv.org/abs/1901.02860 .. _`torch.tf.keras.layers.Layer`: https://pytorch.org/docs/stable/nn.html#module Parameters: config (:class:`~pytorch_transformers.TransfoXLConfig`): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights. """ TRANSFO_XL_INPUTS_DOCSTRING = r""" Inputs: **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: Indices of input sequence tokens in the vocabulary. Transformer-XL is a model with relative position embeddings so you can either pad the inputs on the right or on the left. Indices can be obtained using :class:`pytorch_transformers.TransfoXLTokenizer`. See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. **mems**: (`optional`) list of ``torch.FloatTensor`` (one for each layer): that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see `mems` output below). Can be used to speed up sequential decoding and attend to longer context. **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``: Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**. """ @add_start_docstrings("The bare Bert Model transformer outputing raw hidden-states without any specific head on top.", TRANSFO_XL_START_DOCSTRING, TRANSFO_XL_INPUTS_DOCSTRING) class TFTransfoXLModel(TFTransfoXLPreTrainedModel): r""" Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)`` Sequence of hidden-states at the last layer of the model. **mems**: list of ``torch.FloatTensor`` (one for each layer): that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see `mems` input above). Can be used to speed up sequential decoding and attend to longer context. **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) of shape ``(batch_size, sequence_length, hidden_size)``: Hidden-states of the model at the output of each layer plus the initial embedding outputs. **attentions**: (`optional`, returned when ``config.output_attentions=True``) list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. Examples:: tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103') model = TransfoXLModel.from_pretrained('transfo-xl-wt103') input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 outputs = model(input_ids) last_hidden_states, mems = outputs[:2] """ def __init__(self, config, *inputs, **kwargs): super(TFTransfoXLModel, self).__init__(config, *inputs, **kwargs) self.transformer = TFTransfoXLMainLayer(config, name='transformer') def call(self, inputs, training=False, **kwargs): outputs = self.transformer(inputs, training=training, **kwargs) return outputs @add_start_docstrings("""The Transformer-XL Model with a language modeling head on top (adaptive softmax with weights tied to the adaptive input embeddings)""", TRANSFO_XL_START_DOCSTRING, TRANSFO_XL_INPUTS_DOCSTRING) class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel): r""" **lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids`` Indices are selected in ``[-1, 0, ..., config.vocab_size]`` All labels set to ``-1`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]`` Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: **loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: Language modeling loss. **prediction_scores**: ``None`` if ``lm_labels`` is provided else ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)`` Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). We don't output them when the loss is computed to speedup adaptive softmax decoding. **mems**: list of ``torch.FloatTensor`` (one for each layer): that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see `mems` input above). Can be used to speed up sequential decoding and attend to longer context. **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) of shape ``(batch_size, sequence_length, hidden_size)``: Hidden-states of the model at the output of each layer plus the initial embedding outputs. **attentions**: (`optional`, returned when ``config.output_attentions=True``) list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. Examples:: tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103') model = TransfoXLLMHeadModel.from_pretrained('transfo-xl-wt103') input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 outputs = model(input_ids) prediction_scores, mems = outputs[:2] """ def __init__(self, config): super(TFTransfoXLLMHeadModel, self).__init__(config) self.transformer = TFTransfoXLMainLayer(config, name='transformer') self.sample_softmax = config.sample_softmax # use sampled softmax if config.sample_softmax > 0: raise NotImplementedError # use adaptive softmax (including standard softmax) else: self.crit = TFAdaptiveSoftmaxMask(config.n_token, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val, name='crit') def reset_length(self, tgt_len, ext_len, mem_len): self.transformer.reset_length(tgt_len, ext_len, mem_len) def init_mems(self, data): return self.transformer.init_mems(data) def call(self, inputs, training=False): if not isinstance(inputs, (dict, tuple, list)): input_ids = inputs mems, head_mask, labels = None, None, None elif isinstance(inputs, (tuple, list)): input_ids = inputs[0] mems = inputs[1] if len(inputs) > 1 else None head_mask = inputs[2] if len(inputs) > 2 else None labels = inputs[3] if len(inputs) > 3 else None assert len(inputs) <= 4, "Too many inputs." else: input_ids = inputs.get('input_ids') mems = inputs.get('mems', None) head_mask = inputs.get('head_mask', None) labels = inputs.get('labels', None) assert len(inputs) <= 4, "Too many inputs." bsz, tgt_len = shape_list(input_ids)[:2] transformer_outputs = self.transformer([input_ids, mems, head_mask], training=training) last_hidden = transformer_outputs[0] pred_hid = last_hidden[:, -tgt_len:] outputs = transformer_outputs[1:] if self.sample_softmax > 0 and training: raise NotImplementedError else: # pred_hid = tf.reshape(pred_hid, (-1, shape_list(pred_hid)[-1])) softmax_output = self.crit([pred_hid, labels], training=training) # softmax_output = tf.reshape(softmax_output, (bsz, tgt_len, -1)) outputs = [softmax_output] + outputs return outputs # logits, new_mems, (all hidden states), (all attentions)