mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)

adding tests and updating model

parent 73f2c342f5
commit 076a207935
transformers/__init__.py

@@ -42,6 +42,7 @@ from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE
from .tokenization_xlm import XLMTokenizer
from .tokenization_roberta import RobertaTokenizer
from .tokenization_distilbert import DistilBertTokenizer
from .tokenization_t5 import T5Tokenizer

# Configurations
from .configuration_utils import PretrainedConfig
@@ -52,10 +53,10 @@ from .configuration_transfo_xl import TransfoXLConfig, TRANSFO_XL_PRETRAINED_CON
from .configuration_gpt2 import GPT2Config, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_ctrl import CTRLConfig, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_xlnet import XLNetConfig, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_ctrl import CTRLConfig, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_xlm import XLMConfig, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_roberta import RobertaConfig, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_t5 import T5Config, T5_PRETRAINED_CONFIG_ARCHIVE_MAP

# Modeling
if is_torch_available():
@@ -69,10 +70,10 @@ if is_torch_available():
                               BertForTokenClassification, BertForQuestionAnswering,
                               load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP)
    from .modeling_openai import (OpenAIGPTPreTrainedModel, OpenAIGPTModel,
                                  OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
                                  load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
                                  OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
                                  load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
    from .modeling_transfo_xl import (TransfoXLPreTrainedModel, TransfoXLModel, TransfoXLLMHeadModel,
                                      load_tf_weights_in_transfo_xl, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
                                      load_tf_weights_in_transfo_xl, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
    from .modeling_gpt2 import (GPT2PreTrainedModel, GPT2Model,
                                GPT2LMHeadModel, GPT2DoubleHeadsModel,
                                load_tf_weights_in_gpt2, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
@@ -95,6 +96,8 @@ if is_torch_available():
                                    DistilBertForSequenceClassification, DistilBertForQuestionAnswering,
                                    DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
    from .modeling_encoder_decoder import PreTrainedEncoderDecoder, Model2Model
    from .modeling_t5 import (T5PreTrainedModel, T5Model, T5WithLMHeadModel,
                              T5_PRETRAINED_MODEL_ARCHIVE_MAP)

# Optimization
from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule,
transformers/configuration_t5.py

@@ -64,44 +64,29 @@ class T5Config(PretrainedConfig):
    pretrained_config_archive_map = T5_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
                 vocab_size_or_config_json_file=50257,
                 n_positions=1024,
                 n_ctx=1024,
                 n_embd=768,
                 n_layer=12,
                 n_head=12,
                 resid_pdrop=0.1,
                 embd_pdrop=0.1,
                 attn_pdrop=0.1,
                 layer_norm_epsilon=1e-5,
                 vocab_size_or_config_json_file=32128,
                 n_positions=512,
                 d_model=512,
                 d_ff=2048,
                 num_layers=12,
                 num_heads=12,
                 relative_attention_num_buckets=32,
                 dropout_rate=0.1,
                 layer_norm_epsilon=1e-6,
                 initializer_range=0.02,

                 num_labels=1,
                 summary_type='cls_index',
                 summary_use_proj=True,
                 summary_activation=None,
                 summary_proj_to_labels=True,
                 summary_first_dropout=0.1,
                 **kwargs):
        super(T5Config, self).__init__(**kwargs)
        self.vocab_size = vocab_size_or_config_json_file if isinstance(vocab_size_or_config_json_file, six.string_types) else -1
        self.n_ctx = n_ctx
        self.vocab_size = vocab_size_or_config_json_file if isinstance(vocab_size_or_config_json_file, int) else -1
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.d_model = d_model
        self.d_ff = d_ff
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.dropout_rate = dropout_rate
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range

        self.num_labels = num_labels
        self.summary_type = summary_type
        self.summary_use_proj = summary_use_proj
        self.summary_activation = summary_activation
        self.summary_first_dropout = summary_first_dropout
        self.summary_proj_to_labels = summary_proj_to_labels
        if isinstance(vocab_size_or_config_json_file, six.string_types):
            with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
                json_config = json.loads(reader.read())
@@ -119,12 +104,12 @@ class T5Config(PretrainedConfig):

    @property
    def hidden_size(self):
        return self.n_embd
        return self.d_model

    @property
    def num_attention_heads(self):
        return self.n_head
        return self.num_heads

    @property
    def num_hidden_layers(self):
        return self.n_layer
        return self.num_layers
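# Illustrative sketch (not part of the diff): constructing the reworked T5Config with the
# T5-specific arguments introduced above, using the defaults shown in this hunk, and
# checking the remapped convenience properties.
from transformers import T5Config

config = T5Config(vocab_size_or_config_json_file=32128,
                  n_positions=512,
                  d_model=512,
                  d_ff=2048,
                  num_layers=12,
                  num_heads=12,
                  relative_attention_num_buckets=32,
                  dropout_rate=0.1,
                  layer_norm_epsilon=1e-6,
                  initializer_range=0.02)
assert config.hidden_size == config.d_model            # property now maps to d_model
assert config.num_attention_heads == config.num_heads  # property now maps to num_heads
assert config.num_hidden_layers == config.num_layers   # property now maps to num_layers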
transformers/modeling_t5.py

@@ -20,8 +20,8 @@ import json
import logging
import math
import os
import math
import sys
import copy
import itertools
from io import open

@@ -30,7 +30,7 @@ from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss

from .modeling_utils import PreTrainedModel, prune_linear_layer
from .modeling_utils import PreTrainedModel
from .configuration_t5 import T5Config
from .file_utils import add_start_docstrings

@@ -127,7 +127,7 @@ class T5DenseReluDense(nn.Module):
        super(T5DenseReluDense, self).__init__()
        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        h = self.wi(hidden_states)
@@ -141,8 +141,8 @@ class T5LayerFF(nn.Module):
    def __init__(self, config):
        super(T5LayerFF, self).__init__()
        self.DenseReluDense = T5DenseReluDense(config)
        self.layer_norm = nn.LayerNorm(config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout)
        self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        norm_x = self.layer_norm(hidden_states)
@@ -157,6 +157,7 @@ class T5Attention(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False):
        super(T5Attention, self).__init__()
        self.layer_id = next(T5Attention.NEW_ID)
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias

        self.output_attentions = config.output_attentions
@@ -231,7 +232,7 @@ class T5Attention(nn.Module):
            ret += (n < 0).to(torch.long) * num_buckets  # mtf.to_int32(mtf.less(n, 0)) * num_buckets
            n = torch.abs(n)
        else:
            n = torch.max(n, 0)
            n = torch.max(n, torch.zeros_like(n))
        # now n is in the range [0, inf)

        # half of the buckets are for exact increments in positions
@@ -242,7 +243,7 @@ class T5Attention(nn.Module):
        val_if_large = max_exact + (
            torch.log(n.float() / max_exact)
            / math.log(max_distance / max_exact) * (num_buckets - max_exact)).to(torch.long)
        val_if_large = torch.min(val_if_large, num_buckets - 1)
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret
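# Illustrative sketch (not part of the diff) of the relative-position bucketing patched
# above, pulled out as a standalone function. The parts not visible in this hunk (the
# bidirectional split, max_exact and is_small) are reconstructed from the T5 bucketing
# scheme and are assumptions, not a copy of the file.
import math
import torch

def relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
    ret = 0
    n = -relative_position
    if bidirectional:
        num_buckets //= 2                                # assumption: half the buckets per direction
        ret += (n < 0).to(torch.long) * num_buckets      # offset for "key after query" positions
        n = torch.abs(n)
    else:
        n = torch.max(n, torch.zeros_like(n))
    # now n is in the range [0, inf)

    # half of the buckets are for exact increments in positions
    max_exact = num_buckets // 2                         # assumption
    is_small = n < max_exact                             # assumption

    # the other half are logarithmically bigger bins up to max_distance
    val_if_large = max_exact + (
        torch.log(n.float() / max_exact)
        / math.log(max_distance / max_exact) * (num_buckets - max_exact)).to(torch.long)
    val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

    ret += torch.where(is_small, n, val_if_large)
    return ret

# e.g. buckets for query position 0 attending to key positions 0..7
print(relative_position_bucket(torch.arange(8)[None, :] - torch.zeros(1, 1, dtype=torch.long)))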
@@ -259,7 +260,7 @@ class T5Attention(nn.Module):
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, qlen, klen)
        return values

    def forward(self, input, mask, kv=None, position_bias=None, cache=None, head_mask=None):
    def forward(self, input, mask=None, kv=None, position_bias=None, cache=None, head_mask=None):
        """
        Self-attention (if kv is None) or attention over source sentence (provided by kv).
        """
@@ -273,7 +274,6 @@ class T5Attention(nn.Module):
        # assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
        n_heads = self.n_heads
        dim_per_head = self.dim // n_heads
        mask_reshape = (bs, 1, qlen, klen) if mask.dim() == 3 else (bs, 1, 1, klen)

        def shape(x):
            """ projection """
@@ -311,8 +311,9 @@ class T5Attention(nn.Module):
            position_bias = self.compute_bias(qlen, klen)
        scores += position_bias

        mask = (mask == 0).view(mask_reshape).expand_as(scores)  # (bs, n_heads, qlen, klen)
        scores.masked_fill_(mask, -float('inf'))  # (bs, n_heads, qlen, klen)
        if mask is not None:
            mask = (mask == 0).expand_as(scores)  # (bs, n_heads, qlen, klen)
            scores.masked_fill_(mask, -float('inf'))  # (bs, n_heads, qlen, klen)

        weights = F.softmax(scores.float(), dim=-1).type_as(scores)  # (bs, n_heads, qlen, klen)
        weights = F.dropout(weights, p=self.dropout, training=self.training)  # (bs, n_heads, qlen, klen)
@@ -338,13 +339,13 @@ class T5LayerSelfAttention(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False):
        super(T5LayerSelfAttention, self).__init__()
        self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
        self.layer_norm = nn.LayerNorm(config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout)
        self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states, attention_mask=None, position_bias=None, head_mask=None):
        norm_x = self.layer_norm(hidden_states)
        attention_output = self.SelfAttention(norm_x,
                                              attention_mask=attention_mask,
                                              mask=attention_mask,
                                              position_bias=position_bias,
                                              head_mask=head_mask)
        y = attention_output[0]
@@ -357,14 +358,14 @@ class T5LayerCrossAttention(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False):
        super(T5LayerCrossAttention, self).__init__()
        self.EncDecAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
        self.layer_norm = nn.LayerNorm(config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout)
        self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states, kv, attention_mask=None, position_bias=None, head_mask=None):
        norm_x = self.layer_norm(hidden_states)
        attention_output = self.EncDecAttention(norm_x,
                                                mask=attention_mask,
                                                kv=kv,
                                                attention_mask=attention_mask,
                                                position_bias=position_bias,
                                                head_mask=head_mask)
        y = attention_output[0]
@@ -410,13 +411,41 @@ class T5Block(nn.Module):
        return outputs


class T5Stack(nn.Module):
class T5PreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for dowloading and loading pretrained models.
    """
    config_class = T5Config
    pretrained_model_archive_map = T5_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_t5
    base_model_prefix = "transformer"

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


class T5Stack(T5PreTrainedModel):
    def __init__(self, config):
        super(T5Stack, self).__init__()
        super(T5Stack, self).__init__(config)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.is_decoder = config.is_decoder

        self.blocks = nn.ModuleList([T5Block(config, has_relative_attention_bias=bool(i == 0))
                                     for i in range(config.num_layers)])
        self.final_layer_norm = nn.LayerNorm(config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout)
        self.final_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

        self.init_weights()

    def forward(self,
                hidden_states,
@@ -426,10 +455,10 @@ class T5Stack(nn.Module):
                head_mask=None):

        batch_size, seq_length = hidden_states.shape[0], hidden_states.shape[1]
        encoder_seq_length = encoder_hidden_states.shape[1] if encoder_hidden_states is not None else 0
        if attention_mask is None:
            attention_mask = torch.ones(batch_size, seq_length).to(hidden_states.device)
        if encoder_attention_mask is None:
        if self.is_decoder and encoder_attention_mask is None:
            encoder_seq_length = encoder_hidden_states.shape[1]
            encoder_attention_mask = torch.ones(batch_size, encoder_seq_length).to(hidden_states.device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
@@ -444,6 +473,7 @@ class T5Stack(nn.Module):
        if self.config.is_decoder:
            seq_ids = torch.arange(seq_length, device=hidden_states.device)
            causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
            causal_mask = causal_mask.to(attention_mask)
            extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
        else:
            extended_attention_mask = attention_mask[:, None, None, :]
@@ -456,15 +486,18 @@ class T5Stack(nn.Module):
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # If a 2D ou 3D attention mask is provided for the cross-attention
        # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
        if encoder_attention_mask.dim() == 3:
            encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
        if encoder_attention_mask.dim() == 2:
            encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
        if self.is_decoder:
            # If a 2D ou 3D attention mask is provided for the cross-attention
            # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
            if encoder_attention_mask.dim() == 3:
                encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
            if encoder_attention_mask.dim() == 2:
                encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]

        encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
            encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
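# Illustrative sketch (not part of the diff) of the additive-mask trick used above: a 2D
# padding mask (1 = keep, 0 = pad) is broadcast to (batch, heads, query, key) and turned
# into a large negative bias that softmax drives to ~0. Variable names and shapes below
# are chosen for illustration only.
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])              # (batch_size, seq_length)
extended = attention_mask[:, None, None, :].float()        # (batch_size, 1, 1, seq_length)
extended = (1.0 - extended) * -10000.0                     # 0 for real tokens, -10000 for padding
scores = torch.zeros(1, 2, 4, 4)                           # dummy (bs, n_heads, qlen, klen) scores
probs = torch.softmax(scores + extended, dim=-1)           # padded key gets ~0 attention weight
print(probs[0, 0, 0])                                      # roughly [0.333, 0.333, 0.333, 0.000]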
@@ -474,18 +507,18 @@ class T5Stack(nn.Module):
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
                head_mask = head_mask.expand(self.config.num_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to fload if need + fp16 compatibility
        else:
            head_mask = [None] * self.config.num_hidden_layers
            head_mask = [None] * self.config.num_layers

        all_hidden_states = ()
        all_attentions = ()
        position_bias = None
        encoder_decoder_position_bias = None
        for i, layer_module in enumerate(self.layer):
        for i, layer_module in enumerate(self.blocks):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

@@ -498,8 +531,9 @@ class T5Stack(nn.Module):
                                         head_mask=head_mask[i])
            hidden_states = layer_outputs[0]
            if i == 0:
                position_bias = layer_outputs[2] if len(layer_outputs) > 3 else None
                encoder_decoder_position_bias = layer_outputs[4] if len(layer_outputs) > 5 else None
                position_bias = layer_outputs[2 if self.output_attentions else 1]
                if self.is_decoder:
                    encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 2]

            if self.output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)
@@ -519,27 +553,6 @@ class T5Stack(nn.Module):
        return outputs  # last-layer hidden state, (all hidden states), (all attentions)


class T5PreTrainedModel(PreTrainedEncoderDecoder):
    """ An abstract class to handle weights initialization and
        a simple interface for dowloading and loading pretrained models.
    """
    config_class = T5Config
    pretrained_model_archive_map = T5_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_t5

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


T5_START_DOCSTRING = r""" The T5 model was proposed in
    `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer`_
    by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu.
@@ -620,7 +633,7 @@ class T5Model(T5PreTrainedModel):
    """
    def __init__(self, config):
        super(T5Model, self).__init__(config)
        self.shared = nn.Embeddings(config.vocab_size, config.d_model)
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        self.encoder = T5Stack(encoder_config)
@@ -631,7 +644,6 @@ class T5Model(T5PreTrainedModel):

        self.init_weights()

    @property
    def get_input_embeddings(self):
        return self.shared

@@ -646,17 +658,17 @@ class T5Model(T5PreTrainedModel):
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(self, encoder_input_ids, decoder_input_ids, **kwargs):
    def forward(self, **kwargs):
        # keyword arguments come in 3 flavors: encoder-specific (prefixed by
        # `encoder_`), decoder-specific (prefixed by `decoder_`) and those
        # that apply to the model as whole.
        # We let the specific kwargs override the common ones in case of conflict.
        kwargs_common = dict((k, v) for k, v in kwargs.items()
                             if not k.startswith("encoder_") and not k.startswith("decoder_"))
        kwargs_decoder = kwargs_common.copy()
        kwargs_encoder = kwargs_common.copy()
        kwargs_encoder.update(dict((k[len("encoder_") :], v) for k, v in kwargs.items() if k.startswith("encoder_")))
        kwargs_decoder.update(dict((k[len("decoder_") :], v) for k, v in kwargs.items() if k.startswith("decoder_")))
        kwargs_decoder = kwargs_common.copy()
        kwargs_encoder.update(dict((k[len("encoder_"):], v) for k, v in kwargs.items() if k.startswith("encoder_")))
        kwargs_decoder.update(dict((k[len("decoder_"):], v) for k, v in kwargs.items() if k.startswith("decoder_")))

        # Encode if needed (training, first prediction pass)
        encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
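# Illustrative sketch (not part of the diff) of the keyword-argument convention used by the
# new forward(**kwargs): `encoder_`-prefixed arguments go to the encoder stack, `decoder_`-
# prefixed ones to the decoder stack, and unprefixed ones to both. Config values are
# arbitrary small test sizes; the call mirrors the new tests in modeling_t5_test.py below.
import torch
from transformers import T5Config, T5Model

config = T5Config(vocab_size_or_config_json_file=99, n_positions=14, d_model=32,
                  d_ff=37, num_layers=2, num_heads=4, relative_attention_num_buckets=8)
model = T5Model(config)
model.eval()
input_ids = torch.randint(0, 99, (1, 7))
attention_mask = torch.ones(1, 7, dtype=torch.long)
encoder_output, decoder_output = model(encoder_input_ids=input_ids,
                                        decoder_input_ids=input_ids,
                                        decoder_attention_mask=attention_mask)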
@@ -680,7 +692,7 @@ class T5Model(T5PreTrainedModel):

@add_start_docstrings("""T5 Model with a `language modeling` head on top. """,
                      T5_START_DOCSTRING, T5_INPUTS_DOCSTRING)
class T5WithLMHead(T5PreTrainedModel):
class T5WithLMHeadModel(T5PreTrainedModel):
    r"""
        **lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for computing the masked language modeling loss.
@@ -704,14 +716,14 @@ class T5WithLMHead(T5PreTrainedModel):
    Examples::

        tokenizer = T5Tokenizer.from_pretrained('t5-base-uncased')
        model = T5ForMaskedLM.from_pretrained('t5-base-uncased')
        model = T5WithLMHeadModel.from_pretrained('t5-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, lm_labels=input_ids)
        loss, prediction_scores = outputs[:2]

    """
    def __init__(self, config):
        super(T5ForMaskedLM, self).__init__(config)
        super(T5WithLMHeadModel, self).__init__(config)

        self.transformer = T5Model(config)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size)
@@ -721,11 +733,12 @@ class T5WithLMHead(T5PreTrainedModel):
    def get_output_embeddings(self):
        return self.lm_head

    def forward(self, encoder_input_ids, decoder_input_ids, **kwargs):
        outputs = self.transformer(encoder_input_ids, decoder_input_ids, **kwargs)
    def forward(self, **kwargs):
        lm_labels = kwargs.pop('decoder_lm_labels', None)
        outputs = self.transformer(**kwargs)

        sequence_output = outputs[0]
        lm_logits = self.cls(sequence_output)
        lm_logits = self.lm_head(sequence_output)

        outputs = (lm_logits,) + outputs[2:]  # Add hidden states and attention if they are here
        if lm_labels is not None:
transformers/tests/modeling_common_test.py

@@ -73,6 +73,7 @@ class CommonTestCases:
        test_pruning = True
        test_resize_embeddings = True
        test_head_masking = True
        is_encoder_decoder = False

        def test_save_load(self):
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -114,10 +115,9 @@ class CommonTestCases:
            for model_class in self.all_model_classes:
                model = model_class(config)
                model.eval()
                first, second = model(inputs_dict["input_ids"])[0], model(inputs_dict["input_ids"])[0]
                first, second = model(**inputs_dict)[0], model(**inputs_dict)[0]
                self.assertEqual(first.ne(second).sum().item(), 0)


        def test_attention_outputs(self):
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

@@ -127,31 +127,42 @@ class CommonTestCases:
                model = model_class(config)
                model.eval()
                outputs = model(**inputs_dict)
                attentions = outputs[-1]
                self_attentions = outputs[-1]
                self.assertEqual(model.config.output_attentions, True)
                self.assertEqual(model.config.output_hidden_states, False)
                self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
                self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
                self.assertListEqual(
                    list(attentions[0].shape[-3:]),
                    list(self_attentions[0].shape[-3:]),
                    [self.model_tester.num_attention_heads,
                     self.model_tester.seq_length,
                     self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
                out_len = len(outputs)

                if self.is_encoder_decoder:
                    cross_attentions = outputs[-2]
                    self.assertEqual(model.config.output_attentions, True)
                    self.assertEqual(model.config.output_hidden_states, False)
                    self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
                    self.assertListEqual(
                        list(cross_attentions[0].shape[-3:]),
                        [self.model_tester.num_attention_heads,
                         self.model_tester.seq_length,
                         self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])

                # Check attention is always last and order is fine
                config.output_attentions = True
                config.output_hidden_states = True
                model = model_class(config)
                model.eval()
                outputs = model(**inputs_dict)
                self.assertEqual(out_len+1, len(outputs))
                self.assertEqual(out_len + (2 if self.is_encoder_decoder else 1), len(outputs))
                self.assertEqual(model.config.output_attentions, True)
                self.assertEqual(model.config.output_hidden_states, True)

                attentions = outputs[-1]
                self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
                self_attentions = outputs[-1]
                self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
                self.assertListEqual(
                    list(attentions[0].shape[-3:]),
                    list(self_attentions[0].shape[-3:]),
                    [self.model_tester.num_attention_heads,
                     self.model_tester.seq_length,
                     self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
@@ -214,7 +225,6 @@ class CommonTestCases:

            self.assertTrue(models_equal)


        def test_headmasking(self):
            if not self.test_head_masking:
                return
@@ -268,7 +278,6 @@ class CommonTestCases:
            self.assertNotEqual(
                attentions[-1][..., -1, :, :].flatten().sum().item(), 0.0)


        def test_head_pruning(self):
            if not self.test_pruning:
                return
@@ -411,7 +420,6 @@ class CommonTestCases:

            self.assertDictEqual(model.config.pruned_heads, {0: [0], 1: [1, 2], 2: [1, 2]})


        def test_hidden_states_output(self):
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
transformers/tests/modeling_t5_test.py (new file, 176 lines)

@@ -0,0 +1,176 @@
# coding=utf-8
# Copyright 2018 Google T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest
import shutil
import pytest

from transformers import is_torch_available

from .modeling_common_test import (CommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester

if is_torch_available():
    from transformers import (T5Config, T5Model, T5WithLMHeadModel)
    from transformers.modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_MAP
else:
    pytestmark = pytest.mark.skip("Require Torch")


class T5ModelTest(CommonTestCases.CommonModelTester):

    all_model_classes = (T5Model, T5WithLMHeadModel) if is_torch_available() else ()
    test_pruning = False
    test_torchscript = False
    test_resize_embeddings = False
    is_encoder_decoder = True

    class T5ModelTester(object):

        def __init__(self,
                     parent,
                     batch_size=13,
                     seq_length=7,
                     is_training=True,
                     use_input_mask=True,
                     use_labels=True,
                     vocab_size=99,
                     n_positions=14,
                     hidden_size=32,
                     num_hidden_layers=5,
                     num_attention_heads=4,
                     d_ff=37,
                     relative_attention_num_buckets=8,
                     dropout_rate=0.1,
                     initializer_range=0.02,
                     scope=None,
                     ):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.is_training = is_training
            self.use_input_mask = use_input_mask
            self.use_labels = use_labels
            self.vocab_size = vocab_size
            self.n_positions = n_positions
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.d_ff = d_ff
            self.relative_attention_num_buckets = relative_attention_num_buckets
            self.dropout_rate = dropout_rate
            self.initializer_range = initializer_range
            self.scope = scope

        def prepare_config_and_inputs(self):
            input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

            input_mask = None
            if self.use_input_mask:
                input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

            token_labels = None
            if self.use_labels:
                token_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

            config = T5Config(
                vocab_size_or_config_json_file=self.vocab_size,
                n_positions=self.n_positions,
                d_model=self.hidden_size,
                d_ff=self.d_ff,
                num_layers=self.num_hidden_layers,
                num_heads=self.num_attention_heads,
                relative_attention_num_buckets=self.relative_attention_num_buckets,
                dropout_rate=self.dropout_rate,
                initializer_range=self.initializer_range)

            return (config, input_ids, input_mask, token_labels)

        def check_loss_output(self, result):
            self.parent.assertListEqual(
                list(result["loss"].size()),
                [])

        def create_and_check_t5_model(self, config, input_ids, input_mask, token_labels):
            model = T5Model(config=config)
            model.eval()
            encoder_output, decoder_output = model(encoder_input_ids=input_ids,
                                                   decoder_input_ids=input_ids,
                                                   decoder_attention_mask=input_mask)
            encoder_output, decoder_output = model(encoder_input_ids=input_ids,
                                                   decoder_input_ids=input_ids)

            result = {
                "encoder_output": encoder_output,
                "decoder_output": decoder_output,
            }
            self.parent.assertListEqual(
                list(result["encoder_output"].size()),
                [self.batch_size, self.seq_length, self.hidden_size])
            self.parent.assertListEqual(
                list(result["decoder_output"].size()),
                [self.batch_size, self.seq_length, self.hidden_size])


        def create_and_check_t5_with_lm_head(self, config, input_ids, input_mask, token_labels):
            model = T5WithLMHeadModel(config=config)
            model.eval()
            loss, prediction_scores = model(encoder_input_ids=input_ids, decoder_input_ids=input_ids,
                                            decoder_attention_mask=input_mask, decoder_lm_labels=token_labels)
            result = {
                "loss": loss,
                "prediction_scores": prediction_scores,
            }
            self.parent.assertListEqual(
                list(result["prediction_scores"].size()),
                [self.batch_size, self.seq_length, self.vocab_size])
            self.check_loss_output(result)

        def prepare_config_and_inputs_for_common(self):
            config_and_inputs = self.prepare_config_and_inputs()
            (config, input_ids, input_mask, token_labels) = config_and_inputs
            inputs_dict = {'encoder_input_ids': input_ids,
                           'decoder_input_ids': input_ids,
                           'decoder_attention_mask': input_mask}
            return config, inputs_dict

    def setUp(self):
        self.model_tester = T5ModelTest.T5ModelTester(self)
        self.config_tester = ConfigTester(self, config_class=T5Config, d_model=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_t5_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_t5_model(*config_and_inputs)

    def test_with_lm_head(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_t5_with_lm_head(*config_and_inputs)

    @pytest.mark.slow
    def test_model_from_pretrained(self):
        cache_dir = "/tmp/transformers_test/"
        for model_name in list(T5_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            model = T5Model.from_pretrained(model_name, cache_dir=cache_dir)
            shutil.rmtree(cache_dir)
            self.assertIsNotNone(model)

if __name__ == "__main__":
    unittest.main()
transformers/tests/modeling_tf_t5_test.py (new file, 190 lines)

@@ -0,0 +1,190 @@
# coding=utf-8
# Copyright 2018 Google T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest
import shutil
import pytest
import sys

from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor)
from .configuration_common_test import ConfigTester

from transformers import T5Config, is_tf_available

if False:  # is_tf_available():
    import tensorflow as tf
    from transformers.modeling_tf_t5 import (TFT5Model, TFT5WithLMHeadModel, TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP)
else:
    pytestmark = pytest.mark.skip("Require TensorFlow")


class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester):

    all_model_classes = (TFT5Model, TFT5WithLMHeadModel) if False else ()  # is_tf_available() else ()

    class TFT5ModelTester(object):

        def __init__(self,
                     parent,
                     batch_size=13,
                     seq_length=7,
                     is_training=True,
                     use_input_mask=True,
                     use_token_type_ids=True,
                     use_labels=True,
                     vocab_size=99,
                     hidden_size=32,
                     num_hidden_layers=5,
                     num_attention_heads=4,
                     intermediate_size=37,
                     hidden_act="gelu",
                     hidden_dropout_prob=0.1,
                     attention_probs_dropout_prob=0.1,
                     max_position_embeddings=512,
                     type_vocab_size=16,
                     type_sequence_label_size=2,
                     initializer_range=0.02,
                     num_labels=3,
                     num_choices=4,
                     scope=None,
                     ):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.is_training = is_training
            self.use_input_mask = use_input_mask
            self.use_token_type_ids = use_token_type_ids
            self.use_labels = use_labels
            self.vocab_size = vocab_size
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.intermediate_size = intermediate_size
            self.hidden_act = hidden_act
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.type_sequence_label_size = type_sequence_label_size
            self.initializer_range = initializer_range
            self.num_labels = num_labels
            self.num_choices = num_choices
            self.scope = scope

        def prepare_config_and_inputs(self):
            input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

            input_mask = None
            if self.use_input_mask:
                input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

            token_type_ids = None
            if self.use_token_type_ids:
                token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

            sequence_labels = None
            token_labels = None
            choice_labels = None
            if self.use_labels:
                sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
                token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
                choice_labels = ids_tensor([self.batch_size], self.num_choices)

            config = T5Config(
                vocab_size_or_config_json_file=self.vocab_size,
                hidden_size=self.hidden_size,
                num_hidden_layers=self.num_hidden_layers,
                num_attention_heads=self.num_attention_heads,
                intermediate_size=self.intermediate_size,
                hidden_act=self.hidden_act,
                hidden_dropout_prob=self.hidden_dropout_prob,
                attention_probs_dropout_prob=self.attention_probs_dropout_prob,
                max_position_embeddings=self.max_position_embeddings,
                type_vocab_size=self.type_vocab_size,
                initializer_range=self.initializer_range)

            return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

        def create_and_check_t5_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
            model = TFT5Model(config=config)
            inputs = {'input_ids': input_ids,
                      'attention_mask': input_mask,
                      'token_type_ids': token_type_ids}
            sequence_output, pooled_output = model(inputs)

            inputs = [input_ids, input_mask]
            sequence_output, pooled_output = model(inputs)

            sequence_output, pooled_output = model(input_ids)

            result = {
                "sequence_output": sequence_output.numpy(),
                "pooled_output": pooled_output.numpy(),
            }
            self.parent.assertListEqual(
                list(result["sequence_output"].shape),
                [self.batch_size, self.seq_length, self.hidden_size])
            self.parent.assertListEqual(list(result["pooled_output"].shape), [self.batch_size, self.hidden_size])


        def create_and_check_t5_with_lm_head(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
            model = TFT5WithLMHeadModel(config=config)
            inputs = {'input_ids': input_ids,
                      'attention_mask': input_mask,
                      'token_type_ids': token_type_ids}
            prediction_scores, = model(inputs)
            result = {
                "prediction_scores": prediction_scores.numpy(),
            }
            self.parent.assertListEqual(
                list(result["prediction_scores"].shape),
                [self.batch_size, self.seq_length, self.vocab_size])


        def prepare_config_and_inputs_for_common(self):
            config_and_inputs = self.prepare_config_and_inputs()
            (config, input_ids, token_type_ids, input_mask,
             sequence_labels, token_labels, choice_labels) = config_and_inputs
            inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
            return config, inputs_dict

    def setUp(self):
        self.model_tester = TFT5ModelTest.TFT5ModelTester(self)
        self.config_tester = ConfigTester(self, config_class=T5Config, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_t5_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_t5_model(*config_and_inputs)

    def test_with_lm_head(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_t5_with_lm_head(*config_and_inputs)

    @pytest.mark.slow
    def test_model_from_pretrained(self):
        cache_dir = "/tmp/transformers_test/"
        for model_name in ['t5-base']:
            model = TFT5Model.from_pretrained(model_name, cache_dir=cache_dir)
            shutil.rmtree(cache_dir)
            self.assertIsNotNone(model)

if __name__ == "__main__":
    unittest.main()
transformers/tests/tokenization_t5_test.py (new file, 77 lines)

@@ -0,0 +1,77 @@
# coding=utf-8
# Copyright 2018 Google T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import unittest
import pytest

from transformers.tokenization_t5 import (T5Tokenizer, SPIECE_UNDERLINE)

from .tokenization_tests_commons import CommonTestCases

SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                            'fixtures/test_sentencepiece.model')

class T5TokenizationTest(CommonTestCases.CommonTokenizerTester):

    tokenizer_class = T5Tokenizer

    def setUp(self):
        super(T5TokenizationTest, self).setUp()

        # We have a SentencePiece fixture for testing
        tokenizer = T5Tokenizer(SAMPLE_VOCAB, keep_accents=True)
        tokenizer.save_pretrained(self.tmpdirname)

    def get_tokenizer(self, **kwargs):
        return T5Tokenizer.from_pretrained(self.tmpdirname, **kwargs)

    def get_input_output_texts(self):
        input_text = u"This is a test"
        output_text = u"This is a test"
        return input_text, output_text

    def test_full_tokenizer(self):
        tokenizer = T5Tokenizer(SAMPLE_VOCAB, keep_accents=True)

        tokens = tokenizer.tokenize(u'This is a test')
        self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])

        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])

        tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
        self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
                                      u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
                                      u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
                                      SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.'])
        ids = tokenizer.convert_tokens_to_ids(tokens)
        self.assertListEqual(
            ids, [8, 21, 84, 55, 24, 19, 7, 0,
                  602, 347, 347, 347, 3, 12, 66,
                  46, 72, 80, 6, 0, 4])

        back_tokens = tokenizer.convert_ids_to_tokens(ids)
        self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
                                           u'or', u'n', SPIECE_UNDERLINE + u'in',
                                           SPIECE_UNDERLINE + u'', u'<unk>', u'2', u'0', u'0', u'0', u',',
                                           SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
                                           SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's',
                                           u'<unk>', u'.'])


if __name__ == '__main__':
    unittest.main()