mirror of https://github.com/huggingface/transformers.git

commit ba10065c4b (parent 076a207935)
update model, conversion script, tests and template
@@ -26,9 +26,9 @@ from transformers import XxxConfig, XxxForPreTraining, load_tf_weights_in_xxx
 import logging
 logging.basicConfig(level=logging.INFO)
 
-def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, xxx_config_file, pytorch_dump_path):
+def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
     # Initialise PyTorch model
-    config = XxxConfig.from_json_file(xxx_config_file)
+    config = XxxConfig.from_json_file(config_file)
     print("Building PyTorch model from configuration: {}".format(str(config)))
     model = XxxForPreTraining(config)
 
@@ -48,11 +48,11 @@ if __name__ == "__main__":
                        type = str,
                        required = True,
                        help = "Path to the TensorFlow checkpoint path.")
-    parser.add_argument("--xxx_config_file",
+    parser.add_argument("--config_file",
                        default = None,
                        type = str,
                        required = True,
-                       help = "The config json file corresponding to the pre-trained XXX model. \n"
+                       help = "The config json file corresponding to the pre-trained model. \n"
                            "This specifies the model architecture.")
     parser.add_argument("--pytorch_dump_path",
                        default = None,
@@ -61,5 +61,5 @@ if __name__ == "__main__":
                        help = "Path to the output PyTorch model.")
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
-                                     args.xxx_config_file,
+                                     args.config_file,
                                     args.pytorch_dump_path)
@@ -97,6 +97,7 @@ if is_torch_available():
                                        DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
    from .modeling_encoder_decoder import PreTrainedEncoderDecoder, Model2Model
    from .modeling_t5 import (T5PreTrainedModel, T5Model, T5WithLMHeadModel,
+                              load_tf_weights_in_t5,
                              T5_PRETRAINED_MODEL_ARCHIVE_MAP)

    # Optimization
@@ -57,8 +57,7 @@ class T5Config(PretrainedConfig):
            (e.g., 512 or 1024 or 2048).
        type_vocab_size: The vocabulary size of the `token_type_ids` passed into
            `T5Model`.
-        initializer_range: The sttdev of the truncated_normal_initializer for
-            initializing all weight matrices.
+        initializer_factor: A factor for initializing all weight matrices (should be kept to 1.0, used for initialization testing).
        layer_norm_eps: The epsilon used by LayerNorm.
    """
    pretrained_config_archive_map = T5_PRETRAINED_CONFIG_ARCHIVE_MAP
@@ -67,25 +66,27 @@ class T5Config(PretrainedConfig):
                 vocab_size_or_config_json_file=32128,
                 n_positions=512,
                 d_model=512,
+                 d_kv=64,
                 d_ff=2048,
-                 num_layers=12,
-                 num_heads=12,
+                 num_layers=6,
+                 num_heads=8,
                 relative_attention_num_buckets=32,
                 dropout_rate=0.1,
                 layer_norm_epsilon=1e-6,
-                 initializer_range=0.02,
+                 initializer_factor=1.0,
                 **kwargs):
        super(T5Config, self).__init__(**kwargs)
        self.vocab_size = vocab_size_or_config_json_file if isinstance(vocab_size_or_config_json_file, int) else -1
        self.n_positions = n_positions
        self.d_model = d_model
+        self.d_kv = d_kv
        self.d_ff = d_ff
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.dropout_rate = dropout_rate
        self.layer_norm_epsilon = layer_norm_epsilon
-        self.initializer_range = initializer_range
+        self.initializer_factor = initializer_factor

        if isinstance(vocab_size_or_config_json_file, six.string_types):
            with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
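Note (not part of the diff): a quick sketch of the new configuration surface. The defaults above correspond to the small T5 layout, and d_model is tied to num_heads * d_kv (512 == 8 * 64), which T5Attention asserts further down in this commit.

from transformers import T5Config

# Sketch only: constructing a config with the defaults introduced above.
config = T5Config(vocab_size_or_config_json_file=32128,
                  d_model=512,
                  d_kv=64,
                  d_ff=2048,
                  num_layers=6,
                  num_heads=8,
                  relative_attention_num_buckets=32,
                  dropout_rate=0.1,
                  layer_norm_epsilon=1e-6,
                  initializer_factor=1.0)

# d_model is expected to equal num_heads * d_kv.
assert config.d_model == config.num_heads * config.d_kv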
@@ -21,16 +21,16 @@ from __future__ import print_function
 import argparse
 import torch
 
-from transformers import T5Config, T5ForPreTraining, load_tf_weights_in_t5
+from transformers import T5Config, T5Model, load_tf_weights_in_t5
 
 import logging
 logging.basicConfig(level=logging.INFO)
 
-def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, t5_config_file, pytorch_dump_path):
+def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
     # Initialise PyTorch model
-    config = T5Config.from_json_file(t5_config_file)
+    config = T5Config.from_json_file(config_file)
     print("Building PyTorch model from configuration: {}".format(str(config)))
-    model = T5ForPreTraining(config)
+    model = T5Model(config)
 
     # Load weights from tf checkpoint
     load_tf_weights_in_t5(model, config, tf_checkpoint_path)
@@ -48,7 +48,7 @@ if __name__ == "__main__":
                        type = str,
                        required = True,
                        help = "Path to the TensorFlow checkpoint path.")
-    parser.add_argument("--t5_config_file",
+    parser.add_argument("--config_file",
                        default = None,
                        type = str,
                        required = True,
@@ -61,5 +61,5 @@ if __name__ == "__main__":
                        help = "Path to the output PyTorch model.")
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
-                                     args.t5_config_file,
+                                     args.config_file,
                                     args.pytorch_dump_path)
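Note (not part of the diff): the renamed --config_file flag feeds straight into the function above, so the whole conversion boils down to roughly the following. The paths are hypothetical, TensorFlow must be installed for the weight loading step, and saving the state dict is assumed to mirror what the script does after loading.

import torch
from transformers import T5Config, T5Model, load_tf_weights_in_t5

# Hypothetical paths standing in for --tf_checkpoint_path, --config_file, --pytorch_dump_path.
tf_checkpoint_path = "/tmp/t5/model.ckpt"
config_file = "/tmp/t5/config.json"
pytorch_dump_path = "/tmp/t5/pytorch_model.bin"

config = T5Config.from_json_file(config_file)
model = T5Model(config)
load_tf_weights_in_t5(model, config, tf_checkpoint_path)   # reads the TF checkpoint
torch.save(model.state_dict(), pytorch_dump_path)          # dump the converted weights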
@@ -65,34 +65,40 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
     # Load weights from TF model
     init_vars = tf.train.list_variables(tf_path)
     names = []
-    arrays = []
+    tf_weights = {}
     for name, shape in init_vars:
         logger.info("Loading TF weight {} with shape {}".format(name, shape))
         array = tf.train.load_variable(tf_path, name)
         names.append(name)
-        arrays.append(array)
+        tf_weights[name] = array
 
-    for name, array in zip(names, arrays):
-        name = name.split('/')
+    for txt_name in names:
+        name = txt_name.split('/')
         # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
         # which are not required for using pretrained model
         if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
             logger.info("Skipping {}".format("/".join(name)))
+            tf_weights.pop(txt_name, None)
             continue
+        if '_slot_' in name[-1]:
+            logger.info("Skipping {}".format("/".join(name)))
+            tf_weights.pop(txt_name, None)
+            continue
         pointer = model
+        array = tf_weights[txt_name]
         for m_name in name:
             if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
                 l = re.split(r'_(\d+)', m_name)
             else:
                 l = [m_name]
-            if l[0] == 'kernel' or l[0] == 'gamma':
+            if l[0] in ['kernel', 'scale', 'embedding']:
                 pointer = getattr(pointer, 'weight')
-            elif l[0] == 'output_bias' or l[0] == 'beta':
-                pointer = getattr(pointer, 'bias')
-            elif l[0] == 'output_weights':
-                pointer = getattr(pointer, 'weight')
-            elif l[0] == 'squad':
-                pointer = getattr(pointer, 'classifier')
+            # elif l[0] == 'scale':
+            #     pointer = getattr(pointer, 'weight')
+            # elif l[0] == 'output_bias' or l[0] == 'beta':
+            #     pointer = getattr(pointer, 'bias')
+            # elif l[0] == 'squad':
+            #     pointer = getattr(pointer, 'classifier')
             else:
                 try:
                     pointer = getattr(pointer, l[0])
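Note (not part of the diff): to make the skip rules and the pointer walk concrete, here is how a single TF variable name would be parsed under the same logic. The variable name below is made up for illustration; real checkpoint names may differ.

import re

txt_name = "encoder/block_000/layer_000/SelfAttention/q/kernel"   # hypothetical name
name = txt_name.split('/')

skip = any(n in ["adam_v", "adam_m", "global_step"] for n in name) or '_slot_' in name[-1]
print("skip:", skip)                       # False for this name

for m_name in name:
    if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
        l = re.split(r'_(\d+)', m_name)    # 'block_000' -> ['block', '000', ''] -> attribute + index
    else:
        l = [m_name]                       # 'kernel' -> ['kernel'] -> mapped to the `weight` attribute
    print(l)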
@@ -102,9 +108,10 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
             if len(l) >= 2:
                 num = int(l[1])
                 pointer = pointer[num]
-        if m_name[-11:] == '_embeddings':
+        if l[0] not in ['kernel', 'scale', 'embedding']:
             pointer = getattr(pointer, 'weight')
-        elif m_name == 'kernel':
+        if l[0] != 'embedding':
+            logger.info("Transposing numpy weight of shape {} for {}".format(array.shape, name))
             array = np.transpose(array)
         try:
             assert pointer.shape == array.shape
@@ -112,7 +119,11 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
             e.args += (pointer.shape, array.shape)
             raise
         logger.info("Initialize PyTorch weight {}".format(name))
-        pointer.data = torch.from_numpy(array)
+        pointer.data = torch.from_numpy(array.astype(np.float32))
+        tf_weights.pop(txt_name, None)
 
+    logger.info("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
+    # logger.info("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
     return model
 
 
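Note (not part of the diff): the switch from a plain arrays list to a tf_weights dict is what makes the final report possible: every variable that is skipped or copied is popped, so whatever remains at the end was never mapped to a PyTorch parameter. A minimal sketch of that bookkeeping pattern, with made-up variable names:

# Made-up variable names; the real keys come from tf.train.list_variables().
tf_weights = {"shared/embedding": 0, "encoder/block_000/layer_000/SelfAttention/q": 1, "global_step": 2}

for txt_name in list(tf_weights):
    name = txt_name.split('/')
    if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
        tf_weights.pop(txt_name, None)     # skipped variables are dropped from the report too
        continue
    # ... copy tf_weights[txt_name] into the matching PyTorch parameter ...
    tf_weights.pop(txt_name, None)         # mark the variable as consumed

print("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))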
@@ -163,10 +174,13 @@ class T5Attention(nn.Module):
         self.output_attentions = config.output_attentions
         self.relative_attention_num_buckets = config.relative_attention_num_buckets
         self.dim = config.d_model
+        self.d_kv = config.d_kv
         self.n_heads = config.num_heads
         self.dropout = config.dropout_rate
+        assert self.dim % self.n_heads == 0
+        assert self.dim // self.n_heads == self.d_kv
 
         # Mesh TensorFlow initialization to avoid scaling before softmax
         self.q = nn.Linear(self.dim, self.dim, bias=False)
         self.k = nn.Linear(self.dim, self.dim, bias=False)
         self.v = nn.Linear(self.dim, self.dim, bias=False)
@@ -312,8 +326,9 @@ class T5Attention(nn.Module):
             scores += position_bias
 
         if mask is not None:
-            mask = (mask == 0).expand_as(scores)      # (bs, n_heads, qlen, klen)
-            scores.masked_fill_(mask, -float('inf'))  # (bs, n_heads, qlen, klen)
+            scores += mask
+            # mask = (mask == 0).expand_as(scores)      # (bs, n_heads, qlen, klen)
+            # scores.masked_fill_(mask, -float('inf'))  # (bs, n_heads, qlen, klen)
 
         weights = F.softmax(scores.float(), dim=-1).type_as(scores)           # (bs, n_heads, qlen, klen)
         weights = F.dropout(weights, p=self.dropout, training=self.training)  # (bs, n_heads, qlen, klen)
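Note (not part of the diff): replacing masked_fill_ with `scores += mask` assumes the mask has already been converted into an additive bias: 0 where attention is allowed, a large negative value where it is not. A sketch of that convention on toy shapes (the exact negative constant used elsewhere in the library may differ):

import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0]])                       # (bs, klen), 1 = attend, 0 = masked
mask = (1.0 - attention_mask[:, None, None, :].float()) * -1e9     # additive form, (bs, 1, 1, klen)

scores = torch.zeros(1, 8, 4, 4)                                    # (bs, n_heads, qlen, klen)
scores = scores + mask                                              # broadcasts over heads and query positions
weights = F.softmax(scores.float(), dim=-1).type_as(scores)
print(weights[0, 0, 0])                                             # masked position gets ~0 probability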
@@ -378,34 +393,35 @@ class T5Block(nn.Module):
     def __init__(self, config, has_relative_attention_bias=False):
         super(T5Block, self).__init__()
         self.is_decoder = config.is_decoder
-        self.layer_000 = T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)
+        self.layer = nn.ModuleList()
+        self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
         if self.is_decoder:
-            self.layer_001 = T5LayerCrossAttention(config, has_relative_attention_bias=has_relative_attention_bias)
-            self.layer_002 = T5LayerFF(config)
+            self.layer.append(T5LayerCrossAttention(config, has_relative_attention_bias=has_relative_attention_bias))
+            self.layer.append(T5LayerFF(config))
         else:
-            self.layer_001 = T5LayerFF(config)
+            self.layer.append(T5LayerFF(config))
 
     def forward(self, hidden_states, attention_mask=None, position_bias=None,
                 encoder_hidden_states=None, encoder_attention_mask=None, encoder_decoder_position_bias=None,
                 head_mask=None):
-        self_attention_outputs = self.layer_000(hidden_states,
+        self_attention_outputs = self.layer[0](hidden_states,
                                                attention_mask=attention_mask,
                                                position_bias=position_bias,
                                                head_mask=head_mask)
         hidden_states = self_attention_outputs[0]
         outputs = self_attention_outputs[1:]
 
-        if self.is_decoder:
-            cross_attention_outputs = self.layer_001(hidden_states,
+        if not self.is_decoder:
+            hidden_states = self.layer[1](hidden_states)
+        else:
+            cross_attention_outputs = self.layer[1](hidden_states,
                                                     kv=encoder_hidden_states,
                                                     attention_mask=encoder_attention_mask,
                                                     position_bias=encoder_decoder_position_bias,
                                                     head_mask=head_mask)
             hidden_states = cross_attention_outputs[0]
             outputs = cross_attention_outputs[1:] + outputs
-            hidden_states = self.layer_002(hidden_states)
-        else:
-            hidden_states = self.layer_001(hidden_states)
+            hidden_states = self.layer[2](hidden_states)
 
         outputs = (hidden_states,) + outputs  # add attentions if we output them
         return outputs
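Note (not part of the diff): moving the sub-layers from numbered attributes (layer_000, layer_001, layer_002) into a single nn.ModuleList named `layer` lets the block index them positionally (self.layer[0], self.layer[1], ...) and keeps encoder and decoder blocks structurally uniform. A small sketch of the resulting parameter naming, using toy modules rather than the real T5 sub-layers:

import torch.nn as nn

class TinyBlock(nn.Module):
    def __init__(self, is_decoder=False):
        super(TinyBlock, self).__init__()
        self.layer = nn.ModuleList()
        self.layer.append(nn.Linear(4, 4))        # stands in for T5LayerSelfAttention
        if is_decoder:
            self.layer.append(nn.Linear(4, 4))    # stands in for T5LayerCrossAttention
        self.layer.append(nn.Linear(4, 4))        # stands in for T5LayerFF

for n, _ in TinyBlock(is_decoder=True).named_parameters():
    print(n)   # layer.0.weight, layer.0.bias, layer.1.weight, ...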
@@ -422,15 +438,36 @@ class T5PreTrainedModel(PreTrainedModel):
 
     def _init_weights(self, module):
         """ Initialize the weights """
-        if isinstance(module, (nn.Linear, nn.Embedding)):
-            # Slightly different from the TF version which uses truncated_normal for initialization
-            # cf https://github.com/pytorch/pytorch/pull/5617
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-        elif isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
-        if isinstance(module, nn.Linear) and module.bias is not None:
+        factor = self.config.initializer_factor  # Used for testing weights initialization
+        if isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
+            module.weight.data.fill_(factor*1.0)
+        elif isinstance(module, T5Model):
+            # Mesh TensorFlow embeddings initialization
+            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
+            module.shared.weight.data.normal_(mean=0.0, std=factor*1.0)
+        elif isinstance(module, T5DenseReluDense):
+            # Mesh TensorFlow FF initialization
+            # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
+            # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
+            module.wi.weight.data.normal_(mean=0.0, std=factor*((self.config.d_model) ** -0.5))
+            if hasattr(module.wi, 'bias') and module.wi.bias is not None:
+                module.wi.bias.data.zero_()
+            module.wo.weight.data.normal_(mean=0.0, std=factor*((self.config.d_ff) ** -0.5))
+            if hasattr(module.wo, 'bias') and module.wo.bias is not None:
+                module.wo.bias.data.zero_()
+        elif isinstance(module, T5Attention):
+            # Mesh TensorFlow attention initialization to avoid scaling before softmax
+            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
+            d_model = self.config.d_model
+            d_kv = self.config.d_kv
+            n_heads = self.config.num_heads
+            module.q.weight.data.normal_(mean=0.0, std=factor*((d_model * d_kv) ** -0.5))
+            module.k.weight.data.normal_(mean=0.0, std=factor*(d_model ** -0.5))
+            module.v.weight.data.normal_(mean=0.0, std=factor*(d_model ** -0.5))
+            module.o.weight.data.normal_(mean=0.0, std=factor*((n_heads * d_kv) ** -0.5))
+            if module.has_relative_attention_bias:
+                module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor*((d_model) ** -0.5))
 
 
 class T5Stack(T5PreTrainedModel):
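Note (not part of the diff): for the default configuration in this commit (d_model=512, d_kv=64, num_heads=8, d_ff=2048) the Mesh-TensorFlow-style standard deviations above work out as follows; scaling everything by initializer_factor is what lets the tests zero the whole initialization.

factor, d_model, d_kv, n_heads, d_ff = 1.0, 512, 64, 8, 2048

std_q      = factor * (d_model * d_kv) ** -0.5   # ~0.0055
std_k      = factor * d_model ** -0.5            # ~0.0442
std_v      = factor * d_model ** -0.5            # ~0.0442
std_o      = factor * (n_heads * d_kv) ** -0.5   # ~0.0442
std_wi     = factor * d_model ** -0.5            # ~0.0442
std_wo     = factor * d_ff ** -0.5               # ~0.0221
std_shared = factor * 1.0                        # shared embedding matrix
print(std_q, std_k, std_v, std_o, std_wi, std_wo, std_shared)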
@@ -440,7 +477,7 @@ class T5Stack(T5PreTrainedModel):
         self.output_hidden_states = config.output_hidden_states
         self.is_decoder = config.is_decoder
 
-        self.blocks = nn.ModuleList([T5Block(config, has_relative_attention_bias=bool(i == 0))
+        self.block = nn.ModuleList([T5Block(config, has_relative_attention_bias=bool(i == 0))
                                     for i in range(config.num_layers)])
         self.final_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.dropout = nn.Dropout(config.dropout_rate)
@@ -518,7 +555,7 @@ class T5Stack(T5PreTrainedModel):
         all_attentions = ()
         position_bias = None
         encoder_decoder_position_bias = None
-        for i, layer_module in enumerate(self.blocks):
+        for i, layer_module in enumerate(self.block):
             if self.output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
@@ -724,9 +761,10 @@ class T5WithLMHeadModel(T5PreTrainedModel):
     """
     def __init__(self, config):
         super(T5WithLMHeadModel, self).__init__(config)
+        self.model_dim = config.d_model
 
         self.transformer = T5Model(config)
-        self.lm_head = nn.Linear(config.d_model, config.vocab_size)
+        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
 
         self.init_weights()
 
@@ -738,15 +776,18 @@ class T5WithLMHeadModel(T5PreTrainedModel):
         outputs = self.transformer(**kwargs)
 
         sequence_output = outputs[0]
+        # Rescale output before projecting on vocab
+        # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+        sequence_output = sequence_output * (self.model_dim ** -0.5)
         lm_logits = self.lm_head(sequence_output)
 
-        outputs = (lm_logits,) + outputs[2:]  # Add hidden states and attention if they are here
+        outputs = (lm_logits,) + outputs[1:]  # Add hidden states and attention if they are here
         if lm_labels is not None:
             shift_logits = lm_logits[..., :-1, :].contiguous()
             shift_labels = lm_labels[..., 1:].contiguous()
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))
-            outputs = (loss,) + outputs
+            outputs = (loss,) + outputs  # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
 
         return outputs  # (lm_loss), lm_logits, (hidden_states), (attentions)
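Note (not part of the diff): two details worth spelling out from the hunk above. The decoder output is rescaled by d_model ** -0.5 before the now bias-free LM head, and the loss shifts logits and labels by one position so that position t predicts token t+1. A toy-shaped sketch of the loss part:

import torch
from torch.nn import CrossEntropyLoss

lm_logits = torch.randn(1, 4, 10)             # (bs, seq_len, vocab_size), toy sizes
lm_labels = torch.tensor([[2, 5, 1, -1]])     # -1 positions are ignored by the loss

shift_logits = lm_logits[..., :-1, :].contiguous()   # predictions for positions 0..seq-2
shift_labels = lm_labels[..., 1:].contiguous()       # targets are the next tokens
loss_fct = CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
print(loss)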
@@ -59,7 +59,7 @@ else:
 def _config_zero_init(config):
     configs_no_init = copy.deepcopy(config)
     for key in configs_no_init.__dict__.keys():
-        if '_range' in key or '_std' in key:
+        if '_range' in key or '_std' in key or 'initializer_factor' in key:
             setattr(configs_no_init, key, 0.0)
     return configs_no_init
 
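Note (not part of the diff): with 'initializer_factor' added to this filter, a config that defines it (as the new T5Config does) comes out of _config_zero_init with initializer_factor = 0.0, so every std computed in T5PreTrainedModel._init_weights above becomes zero and the initialization test can check for exact zeros. A self-contained sketch of the filter on a dummy config object:

import copy

class DummyConfig(object):
    def __init__(self):
        self.initializer_range = 0.02
        self.initializer_factor = 1.0
        self.d_model = 512

configs_no_init = copy.deepcopy(DummyConfig())
for key in configs_no_init.__dict__.keys():
    if '_range' in key or '_std' in key or 'initializer_factor' in key:
        setattr(configs_no_init, key, 0.0)
print(configs_no_init.__dict__)   # both initializer fields zeroed, d_model untouched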
@@ -83,20 +83,24 @@ class CommonTestCases:
             model.eval()
             with torch.no_grad():
                 outputs = model(**inputs_dict)
+                out_2 = outputs[0].numpy()
+                out_2[np.isnan(out_2)] = 0
 
             with TemporaryDirectory() as tmpdirname:
                 model.save_pretrained(tmpdirname)
                 model = model_class.from_pretrained(tmpdirname)
+
                 with torch.no_grad():
                     after_outputs = model(**inputs_dict)
 
-                # Make sure we don't have nans
+                # # Make sure we don't have nans
                 out_1 = after_outputs[0].numpy()
-                out_2 = outputs[0].numpy()
-                out_1 = out_1[~np.isnan(out_1)]
-                out_2 = out_2[~np.isnan(out_2)]
-                max_diff = np.amax(np.abs(out_1 - out_2))
-                self.assertLessEqual(max_diff, 1e-5)
+                out_1[np.isnan(out_1)] = 0
+
+                out_1 = out_1 - out_2
+                amax = np.amax(out_1)
+                amin = np.amin(out_1)
+                self.assertLessEqual(max(amax, -amin), 1e-5)
 
         def test_initialization(self):
             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
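Note (not part of the diff): the reworked check zeroes NaNs in place on both outputs and then bounds the signed difference from both ends, which matches the old max-abs-diff assertion when neither side contains NaNs. A sketch in plain NumPy:

import numpy as np

out_1 = np.array([0.1, np.nan, -0.3])
out_2 = np.array([0.1, 0.0, -0.3000001])
out_1[np.isnan(out_1)] = 0
out_2[np.isnan(out_2)] = 0

diff = out_1 - out_2
amax, amin = np.amax(diff), np.amin(diff)
assert max(amax, -amin) <= 1e-5        # same bound as np.amax(np.abs(diff)) <= 1e-5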
@@ -127,24 +131,25 @@ class CommonTestCases:
             model = model_class(config)
             model.eval()
             outputs = model(**inputs_dict)
-            self_attentions = outputs[-1]
+            attentions = outputs[-1]
             self.assertEqual(model.config.output_attentions, True)
             self.assertEqual(model.config.output_hidden_states, False)
-            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
+            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
             self.assertListEqual(
-                list(self_attentions[0].shape[-3:]),
+                list(attentions[0].shape[-3:]),
                 [self.model_tester.num_attention_heads,
                  self.model_tester.seq_length,
                  self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
             out_len = len(outputs)
 
             if self.is_encoder_decoder:
-                cross_attentions = outputs[-2]
+                self.assertEqual(out_len % 2, 0)
+                decoder_attentions = outputs[(out_len // 2)-1]
                 self.assertEqual(model.config.output_attentions, True)
                 self.assertEqual(model.config.output_hidden_states, False)
-                self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
+                self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
                 self.assertListEqual(
-                    list(cross_attentions[0].shape[-3:]),
+                    list(decoder_attentions[0].shape[-3:]),
                     [self.model_tester.num_attention_heads,
                      self.model_tester.seq_length,
                      self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
@@ -57,7 +57,7 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
                 d_ff=37,
                 relative_attention_num_buckets=8,
                 dropout_rate=0.1,
-                 initializer_range=0.02,
+                 initializer_factor=0.002,
                 scope=None,
                 ):
        self.parent = parent
@@ -74,7 +74,7 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
        self.d_ff = d_ff
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.dropout_rate = dropout_rate
-        self.initializer_range = initializer_range
+        self.initializer_factor = initializer_factor
        self.scope = scope

    def prepare_config_and_inputs(self):
@@ -93,11 +93,12 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
            n_positions=self.n_positions,
            d_model=self.hidden_size,
            d_ff=self.d_ff,
+            d_kv=self.hidden_size // self.num_attention_heads,
            num_layers=self.num_hidden_layers,
            num_heads=self.num_attention_heads,
            relative_attention_num_buckets=self.relative_attention_num_buckets,
            dropout_rate=self.dropout_rate,
-            initializer_range=self.initializer_range)
+            initializer_factor=self.initializer_factor)

        return (config, input_ids, input_mask, token_labels)

@@ -130,8 +131,9 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
    def create_and_check_t5_with_lm_head(self, config, input_ids, input_mask, token_labels):
        model = T5WithLMHeadModel(config=config)
        model.eval()
-        loss, prediction_scores = model(encoder_input_ids=input_ids, decoder_input_ids=input_ids,
+        outputs = model(encoder_input_ids=input_ids, decoder_input_ids=input_ids,
                                        decoder_attention_mask=input_mask, decoder_lm_labels=token_labels)
+        loss, prediction_scores = outputs[0], outputs[1]
        result = {
            "loss": loss,
            "prediction_scores": prediction_scores,
@@ -18,6 +18,7 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 
 import logging
 import os
+from shutil import copyfile
 
 from .tokenization_utils import PreTrainedTokenizer
 