Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)

WIP XLNet

This commit is contained in:
parent f851fb55ca
commit 32aabe8c33
@@ -95,7 +95,7 @@ except (ImportError, AssertionError):

if _tf_available:
    logger.info("TensorFlow version {} available.".format(tf.__version__))

-    from .modeling_tf_utils import TFPreTrainedModel
+    from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary
    from .modeling_tf_auto import (TFAutoModel, TFAutoModelForSequenceClassification, TFAutoModelForQuestionAnswering,
                                   TFAutoModelWithLMHead)
@@ -107,7 +107,7 @@ if _tf_available:
                                   load_bert_pt_weights_in_tf2,
                                   TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP)

-    from .modeling_tf_gpt2 import (TFGPT2PreTrainedModel, TFGPT2MainLayer, TFGPT2Embeddings,
+    from .modeling_tf_gpt2 import (TFGPT2PreTrainedModel, TFGPT2MainLayer,
                                   TFGPT2Model, TFGPT2LMHeadModel, TFGPT2DoubleHeadsModel,
                                   load_gpt2_pt_weights_in_tf2,
                                   TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
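With this change the TF utility classes are exported at the package root, so callers can import them directly. A minimal sketch (assuming a TensorFlow 2.x install, so that _tf_available is True on this branch):

    # Hypothetical downstream import of the newly exported TF classes
    from pytorch_transformers import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, TFAutoModelWithLMHead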
@@ -54,6 +54,7 @@ class PretrainedConfig(object):
        self.output_attentions = kwargs.pop('output_attentions', False)
        self.output_hidden_states = kwargs.pop('output_hidden_states', False)
        self.torchscript = kwargs.pop('torchscript', False)
+        self.use_bfloat16 = kwargs.pop('use_bfloat16', False)
        self.pruned_heads = kwargs.pop('pruned_heads', {})

    def save_pretrained(self, save_directory):
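Because PretrainedConfig pops these flags from **kwargs with defaults, any derived config that forwards unused keyword arguments to its parent picks the new flag up automatically. A hedged sketch, assuming GPT2Config forwards extra kwargs to PretrainedConfig (as the library's configs generally do):

    from pytorch_transformers import GPT2Config

    config = GPT2Config(use_bfloat16=True)   # forwarded to PretrainedConfig.__init__ via **kwargs
    assert config.use_bfloat16 is True
    assert config.torchscript is False       # default when not passed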
@@ -28,7 +28,8 @@ from io import open

import numpy as np
import tensorflow as tf

-from .modeling_tf_utils import TFPreTrainedModel, TFConv1D, TFSequenceSummary, shape_list
+from .modeling_tf_utils import (TFPreTrainedModel, TFConv1D, TFSharedEmbeddings,
+                                TFSequenceSummary, shape_list)
from .configuration_gpt2 import GPT2Config
from .file_utils import add_start_docstrings
@@ -65,6 +66,7 @@ def load_gpt2_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):

    symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
    weight_value_tuples = []
+    all_pytorch_weights = set(list(state_dict.keys()))
    for symbolic_weight in symbolic_weights:
        name = symbolic_weight.name
        name = name.replace(':0', '')
@@ -100,13 +102,13 @@ def load_gpt2_pt_weights_in_tf2(tf_model, config, pytorch_checkpoint_path):

        weight_value_tuples.append((symbolic_weight, array))

        state_dict.pop(name)
+        all_pytorch_weights.discard(name)

    K.batch_set_value(weight_value_tuples)

    tfo = tf_model(tf_inputs, training=False)  # Make sure restore ops are run

-    assert not state_dict, "Weights not loaded: {}".format(list(state_dict.keys()))
+    logger.info("Weights or buffers not loaded from PyTorch model: {}".format(all_pytorch_weights))

    return tf_model
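The loader matches each Keras symbolic weight to a PyTorch state_dict entry by name, collects the (weight, value) pairs, assigns them in a single K.batch_set_value call, and now only logs (rather than asserts on) any PyTorch weights that were never matched. A minimal sketch of how it would be driven; the checkpoint path is a placeholder and the default GPT2Config is used only for illustration:

    from pytorch_transformers import GPT2Config
    from pytorch_transformers.modeling_tf_gpt2 import TFGPT2Model, load_gpt2_pt_weights_in_tf2

    config = GPT2Config()                       # default GPT-2 hyperparameters
    tf_model = TFGPT2Model(config)
    # Copy weights from a local PyTorch checkpoint (placeholder path)
    tf_model = load_gpt2_pt_weights_in_tf2(tf_model, config, './pytorch_model.bin')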
@@ -267,65 +269,6 @@ class TFBlock(tf.keras.layers.Layer):
        outputs = [x] + output_attn[1:]
        return outputs  # x, present, (attentions)

-class TFGPT2Embeddings(tf.keras.layers.Layer):
-    """Construct the embeddings from word, position and token_type embeddings.
-    """
-    def __init__(self, config, **kwargs):
-        super(TFGPT2Embeddings, self).__init__(**kwargs)
-        self.vocab_size = config.vocab_size
-        self.hidden_size = config.hidden_size
-
-    def build(self, input_shape):
-        """Build shared word embedding layer
-        Shared weights logic adapted from
-        https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
-        """
-        self.weight = self.add_weight(
-            "weight",
-            shape=[self.vocab_size, self.hidden_size],
-            initializer=tf.random_normal_initializer(
-                mean=0., stddev=self.hidden_size**-0.5))
-        super(TFGPT2Embeddings, self).build(input_shape)
-
-    def call(self, inputs, mode="embedding"):
-        """Get token embeddings of inputs.
-        Args:
-            inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
-            mode: string, a valid value is one of "embedding" and "linear".
-        Returns:
-            outputs: (1) If mode == "embedding", output embedding tensor, float32 with
-                shape [batch_size, length, embedding_size]; (2) mode == "linear", output
-                linear tensor, float32 with shape [batch_size, length, vocab_size].
-        Raises:
-            ValueError: if mode is not valid.
-
-        Shared weights logic adapted from
-        https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
-        """
-        if mode == "embedding":
-            return self._embedding(inputs)
-        elif mode == "linear":
-            return self._linear(inputs)
-        else:
-            raise ValueError("mode {} is not valid.".format(mode))
-
-    def _embedding(self, input_ids):
-        """Applies embedding based on inputs tensor."""
-        return tf.gather(self.weight, input_ids)
-
-    def _linear(self, inputs):
-        """Computes logits by running inputs through a linear layer.
-        Args:
-            inputs: A float32 tensor with shape [..., hidden_size]
-        Returns:
-            float32 tensor with shape [..., vocab_size].
-        """
-        first_dims = shape_list(inputs)[:-1]
-
-        x = tf.reshape(inputs, [-1, self.hidden_size])
-        logits = tf.matmul(x, self.weight, transpose_b=True)
-
-        return tf.reshape(logits, first_dims + [self.vocab_size])

class TFGPT2MainLayer(tf.keras.layers.Layer):
    def __init__(self, config, *inputs, **kwargs):
@@ -336,10 +279,13 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
        self.vocab_size = config.vocab_size
        self.n_embd = config.n_embd

-        self.wte = TFGPT2Embeddings(config, name='wte')
+        self.wte = TFSharedEmbeddings(config.vocab_size, config.hidden_size, name='wte')
        self.wpe = tf.keras.layers.Embedding(config.n_positions, config.n_embd, name='wpe')
        self.drop = tf.keras.layers.Dropout(config.embd_pdrop)
-        self.h = [TFBlock(config.n_ctx, config, scale=True, name='h_{}'.format(i)) for i in range(config.n_layer)]
+        self.h = [TFBlock(config.n_ctx,
+                          config,
+                          scale=True,
+                          name='h_{}'.format(i)) for i in range(config.n_layer)]
        self.ln_f = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='ln_f')

    def _resize_token_embeddings(self, new_num_tokens):
@@ -288,6 +288,69 @@ class TFConv1D(tf.keras.layers.Layer):
        return x


+class TFSharedEmbeddings(tf.keras.layers.Layer):
+    """Construct shared token embeddings.
+    """
+    def __init__(self, vocab_size, hidden_size, initializer_range=None, **kwargs):
+        super(TFSharedEmbeddings, self).__init__(**kwargs)
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.initializer_range = initializer_range
+
+    def build(self, input_shape):
+        """Build shared word embedding layer
+        Shared weights logic adapted from
+        https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
+        """
+        initializer_range = self.hidden_size**-0.5 if self.initializer_range is None else self.initializer_range
+        self.weight = self.add_weight(
+            "weight",
+            shape=[self.vocab_size, self.hidden_size],
+            initializer=tf.random_normal_initializer(
+                mean=0., stddev=initializer_range))
+        super(TFSharedEmbeddings, self).build(input_shape)
+
+    def call(self, inputs, mode="embedding"):
+        """Get token embeddings of inputs.
+        Args:
+            inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
+            mode: string, a valid value is one of "embedding" and "linear".
+        Returns:
+            outputs: (1) If mode == "embedding", output embedding tensor, float32 with
+                shape [batch_size, length, embedding_size]; (2) mode == "linear", output
+                linear tensor, float32 with shape [batch_size, length, vocab_size].
+        Raises:
+            ValueError: if mode is not valid.
+
+        Shared weights logic adapted from
+        https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
+        """
+        if mode == "embedding":
+            return self._embedding(inputs)
+        elif mode == "linear":
+            return self._linear(inputs)
+        else:
+            raise ValueError("mode {} is not valid.".format(mode))
+
+    def _embedding(self, input_ids):
+        """Applies embedding based on inputs tensor."""
+        return tf.gather(self.weight, input_ids)
+
+    def _linear(self, inputs):
+        """Computes logits by running inputs through a linear layer.
+        Args:
+            inputs: A float32 tensor with shape [..., hidden_size]
+        Returns:
+            float32 tensor with shape [..., vocab_size].
+        """
+        first_dims = shape_list(inputs)[:-1]
+
+        x = tf.reshape(inputs, [-1, self.hidden_size])
+        logits = tf.matmul(x, self.weight, transpose_b=True)
+
+        return tf.reshape(logits, first_dims + [self.vocab_size])


class TFSequenceSummary(tf.keras.layers.Layer):
    r""" Compute a single vector summary of a sequence hidden states according to various possibilities:
        Args of the config class:
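The two call modes of TFSharedEmbeddings share one weight matrix: "embedding" gathers rows for token ids, while "linear" multiplies hidden states by the transposed matrix to produce vocabulary logits (the tying now used for the GPT-2 wte above). A standalone sketch with arbitrary sizes:

    import tensorflow as tf
    from pytorch_transformers.modeling_tf_utils import TFSharedEmbeddings

    emb = TFSharedEmbeddings(vocab_size=99, hidden_size=32)
    ids = tf.constant([[3, 17, 42]])              # (batch=1, length=3) token ids
    hidden = emb(ids, mode="embedding")           # shape (1, 3, 32): gathers rows of emb.weight
    logits = emb(hidden, mode="linear")           # shape (1, 3, 99): same weight, transposed matmul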
pytorch_transformers/modeling_tf_xlnet.py (new file, 1121 lines): diff suppressed because it is too large.
@@ -262,7 +262,7 @@ class TFCommonTestCases:
        # self.assertEqual(len(params_tied_2), len(params_tied))


-def ids_tensor(shape, vocab_size, rng=None, name=None):
+def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=tf.int32):
    """Creates a random int32 tensor of the shape within the vocab size."""
    if rng is None:
        rng = random.Random()
@@ -275,7 +275,7 @@ def ids_tensor(shape, vocab_size, rng=None, name=None):
    for _ in range(total_dims):
        values.append(rng.randint(0, vocab_size - 1))

-    return tf.constant(values, shape=shape)
+    return tf.constant(values, shape=shape, dtype=dtype)


class TFModelUtilsTest(unittest.TestCase):
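The added dtype argument is what lets the new XLNet test build floating-point masks with the same helper, for example (shape values arbitrary):

    input_mask = ids_tensor([13, 7], 2, dtype=tf.float32)   # random 0./1. mask of shape (13, 7)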
pytorch_transformers/tests/modeling_tf_xlnet_test.py (new file, 341 lines)
@@ -0,0 +1,341 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import unittest
import json
import random
import shutil
import pytest

from pytorch_transformers import XLNetConfig, is_tf_available

if is_tf_available():
    import tensorflow as tf

    from pytorch_transformers.modeling_tf_xlnet import (TFXLNetModel, TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
    # XLNetLMHeadModel,
    # XLNetForSequenceClassification, XLNetForQuestionAnswering)
else:
    pytestmark = pytest.mark.skip("Require TensorFlow")

from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester


class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):

    all_model_classes = (TFXLNetModel, ) if is_tf_available() else ()
    # all_model_classes = (TFXLNetModel, TFXLNetLMHeadModel,
    #                      TFXLNetForSequenceClassification, TFXLNetForQuestionAnswering) if is_tf_available() else ()
    test_pruning = False

    class TFXLNetModelTester(object):

        def __init__(self,
                     parent,
                     batch_size=13,
                     seq_length=7,
                     mem_len=10,
                     clamp_len=-1,
                     reuse_len=15,
                     is_training=True,
                     use_labels=True,
                     vocab_size=99,
                     cutoffs=[10, 50, 80],
                     hidden_size=32,
                     num_attention_heads=4,
                     d_inner=128,
                     num_hidden_layers=5,
                     max_position_embeddings=10,
                     type_sequence_label_size=2,
                     untie_r=True,
                     bi_data=False,
                     same_length=False,
                     initializer_range=0.05,
                     seed=1,
                     type_vocab_size=2,
                     ):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.mem_len = mem_len
            # self.key_len = seq_length + mem_len
            self.clamp_len = clamp_len
            self.reuse_len = reuse_len
            self.is_training = is_training
            self.use_labels = use_labels
            self.vocab_size = vocab_size
            self.cutoffs = cutoffs
            self.hidden_size = hidden_size
            self.num_attention_heads = num_attention_heads
            self.d_inner = d_inner
            self.num_hidden_layers = num_hidden_layers
            self.max_position_embeddings = max_position_embeddings
            self.bi_data = bi_data
            self.untie_r = untie_r
            self.same_length = same_length
            self.initializer_range = initializer_range
            self.seed = seed
            self.type_vocab_size = type_vocab_size
            self.type_sequence_label_size = type_sequence_label_size

        def prepare_config_and_inputs(self):
            input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
            input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
            segment_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
            input_mask = ids_tensor([self.batch_size, self.seq_length], 2, dtype=tf.float32)

            input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
            perm_mask = tf.zeros((self.batch_size, self.seq_length + 1, self.seq_length), dtype=tf.float32)
            perm_mask_last = tf.ones((self.batch_size, self.seq_length + 1, 1), dtype=tf.float32)
            perm_mask = tf.concat([perm_mask, perm_mask_last], axis=-1)
            # perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
            target_mapping = tf.zeros((self.batch_size, 1, self.seq_length), dtype=tf.float32)
            target_mapping_last = tf.ones((self.batch_size, 1, 1), dtype=tf.float32)
            target_mapping = tf.concat([target_mapping, target_mapping_last], axis=-1)
            # target_mapping[:, 0, -1] = 1.0  # predict last token

            sequence_labels = None
            lm_labels = None
            is_impossible_labels = None
            if self.use_labels:
                lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
                sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
                is_impossible_labels = ids_tensor([self.batch_size], 2, dtype=tf.float32)

            config = XLNetConfig(
                vocab_size_or_config_json_file=self.vocab_size,
                d_model=self.hidden_size,
                n_head=self.num_attention_heads,
                d_inner=self.d_inner,
                n_layer=self.num_hidden_layers,
                untie_r=self.untie_r,
                max_position_embeddings=self.max_position_embeddings,
                mem_len=self.mem_len,
                clamp_len=self.clamp_len,
                same_length=self.same_length,
                reuse_len=self.reuse_len,
                bi_data=self.bi_data,
                initializer_range=self.initializer_range,
                num_labels=self.type_sequence_label_size)

            return (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
                    target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels)

        def set_seed(self):
            random.seed(self.seed)
            tf.random.set_seed(self.seed)

        def create_and_check_xlnet_base_model(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
                                              target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
            model = TFXLNetModel(config)

            inputs = {'input_ids': input_ids_1,
                      'input_mask': input_mask,
                      'token_type_ids': segment_ids}

            _, _ = model(inputs)

            inputs = [input_ids_1, input_mask]

            outputs, mems_1 = model(inputs)

            result = {
                "mems_1": [mem.numpy() for mem in mems_1],
                "outputs": outputs.numpy(),
            }

            self.parent.assertListEqual(
                list(result["outputs"].shape),
                [self.batch_size, self.seq_length, self.hidden_size])
            self.parent.assertListEqual(
                list(list(mem.shape) for mem in result["mems_1"]),
                [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)

        def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
                                           target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
            pass
            # model = XLNetLMHeadModel(config)
            # model.eval()

            # loss_1, all_logits_1, mems_1 = model(input_ids_1, token_type_ids=segment_ids, labels=lm_labels)

            # loss_2, all_logits_2, mems_2 = model(input_ids_2, token_type_ids=segment_ids, labels=lm_labels, mems=mems_1)

            # logits, _ = model(input_ids_q, perm_mask=perm_mask, target_mapping=target_mapping)

            # result = {
            #     "loss_1": loss_1,
            #     "mems_1": mems_1,
            #     "all_logits_1": all_logits_1,
            #     "loss_2": loss_2,
            #     "mems_2": mems_2,
            #     "all_logits_2": all_logits_2,
            # }

            # self.parent.assertListEqual(
            #     list(result["loss_1"].size()),
            #     [])
            # self.parent.assertListEqual(
            #     list(result["all_logits_1"].size()),
            #     [self.batch_size, self.seq_length, self.vocab_size])
            # self.parent.assertListEqual(
            #     list(list(mem.size()) for mem in result["mems_1"]),
            #     [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)

            # self.parent.assertListEqual(
            #     list(result["loss_2"].size()),
            #     [])
            # self.parent.assertListEqual(
            #     list(result["all_logits_2"].size()),
            #     [self.batch_size, self.seq_length, self.vocab_size])
            # self.parent.assertListEqual(
            #     list(list(mem.size()) for mem in result["mems_2"]),
            #     [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)

        def create_and_check_xlnet_qa(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
                                      target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
            pass
            # model = XLNetForQuestionAnswering(config)
            # model.eval()

            # outputs = model(input_ids_1)
            # start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs

            # outputs = model(input_ids_1, start_positions=sequence_labels,
            #                 end_positions=sequence_labels,
            #                 cls_index=sequence_labels,
            #                 is_impossible=is_impossible_labels,
            #                 p_mask=input_mask)

            # outputs = model(input_ids_1, start_positions=sequence_labels,
            #                 end_positions=sequence_labels,
            #                 cls_index=sequence_labels,
            #                 is_impossible=is_impossible_labels)

            # total_loss, mems = outputs

            # outputs = model(input_ids_1, start_positions=sequence_labels,
            #                 end_positions=sequence_labels)

            # total_loss, mems = outputs

            # result = {
            #     "loss": total_loss,
            #     "start_top_log_probs": start_top_log_probs,
            #     "start_top_index": start_top_index,
            #     "end_top_log_probs": end_top_log_probs,
            #     "end_top_index": end_top_index,
            #     "cls_logits": cls_logits,
            #     "mems": mems,
            # }

            # self.parent.assertListEqual(
            #     list(result["loss"].size()),
            #     [])
            # self.parent.assertListEqual(
            #     list(result["start_top_log_probs"].size()),
            #     [self.batch_size, model.config.start_n_top])
            # self.parent.assertListEqual(
            #     list(result["start_top_index"].size()),
            #     [self.batch_size, model.config.start_n_top])
            # self.parent.assertListEqual(
            #     list(result["end_top_log_probs"].size()),
            #     [self.batch_size, model.config.start_n_top * model.config.end_n_top])
            # self.parent.assertListEqual(
            #     list(result["end_top_index"].size()),
            #     [self.batch_size, model.config.start_n_top * model.config.end_n_top])
            # self.parent.assertListEqual(
            #     list(result["cls_logits"].size()),
            #     [self.batch_size])
            # self.parent.assertListEqual(
            #     list(list(mem.size()) for mem in result["mems"]),
            #     [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)

        def create_and_check_xlnet_sequence_classif(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
                                                    target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
            pass
            # model = XLNetForSequenceClassification(config)
            # model.eval()

            # logits, mems_1 = model(input_ids_1)
            # loss, logits, mems_1 = model(input_ids_1, labels=sequence_labels)

            # result = {
            #     "loss": loss,
            #     "mems_1": mems_1,
            #     "logits": logits,
            # }

            # self.parent.assertListEqual(
            #     list(result["loss"].size()),
            #     [])
            # self.parent.assertListEqual(
            #     list(result["logits"].size()),
            #     [self.batch_size, self.type_sequence_label_size])
            # self.parent.assertListEqual(
            #     list(list(mem.size()) for mem in result["mems_1"]),
            #     [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)

        def prepare_config_and_inputs_for_common(self):
            config_and_inputs = self.prepare_config_and_inputs()
            (config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
             target_mapping, segment_ids, lm_labels,
             sequence_labels, is_impossible_labels) = config_and_inputs
            inputs_dict = {'input_ids': input_ids_1}
            return config, inputs_dict


    def setUp(self):
        self.model_tester = TFXLNetModelTest.TFXLNetModelTester(self)
        self.config_tester = ConfigTester(self, config_class=XLNetConfig, d_inner=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_xlnet_base_model(self):
        self.model_tester.set_seed()
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlnet_base_model(*config_and_inputs)

    def test_xlnet_lm_head(self):
        self.model_tester.set_seed()
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlnet_lm_head(*config_and_inputs)

    def test_xlnet_sequence_classif(self):
        self.model_tester.set_seed()
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlnet_sequence_classif(*config_and_inputs)

    def test_xlnet_qa(self):
        self.model_tester.set_seed()
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_xlnet_qa(*config_and_inputs)

    @pytest.mark.slow
    def test_model_from_pretrained(self):
        cache_dir = "/tmp/pytorch_transformers_test/"
        for model_name in list(TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            model = TFXLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
            shutil.rmtree(cache_dir)
            self.assertIsNotNone(model)


if __name__ == "__main__":
    unittest.main()
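A quick way to exercise this new test module programmatically, equivalent to its unittest.main() entry point (a sketch assuming the package and its tests are importable in the current environment; without TensorFlow the module-level skip marker applies):

    import unittest

    from pytorch_transformers.tests import modeling_tf_xlnet_test

    suite = unittest.defaultTestLoader.loadTestsFromModule(modeling_tf_xlnet_test)
    unittest.TextTestRunner(verbosity=2).run(suite)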