Merge pull request #2164 from huggingface/cleanup-configs

[SMALL BREAKING CHANGE] Cleaning up configuration classes - Adding Model Cards
Thomas Wolf 2019-12-17 09:10:16 +01:00 committed by GitHub
commit f061606277
56 changed files with 683 additions and 437 deletions

View File

@ -33,6 +33,8 @@ class BertAbsConfig(PretrainedConfig):
r""" Class to store the configuration of the BertAbs model.
Arguments:
vocab_size: int
Number of tokens in the vocabulary.
max_pos: int
The maximum sequence length that this model will be used with.
enc_layer: int
@ -65,7 +67,7 @@ class BertAbsConfig(PretrainedConfig):
def __init__(
self,
vocab_size_or_config_json_file=30522,
vocab_size=30522,
max_pos=512,
enc_layers=6,
enc_hidden_size=512,
@ -81,39 +83,17 @@ class BertAbsConfig(PretrainedConfig):
):
super(BertAbsConfig, self).__init__(**kwargs)
if self._input_is_path_to_json(vocab_size_or_config_json_file):
path_to_json = vocab_size_or_config_json_file
with open(path_to_json, "r", encoding="utf-8") as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.vocab_size = vocab_size_or_config_json_file
self.max_pos = max_pos
self.vocab_size = vocab_size
self.max_pos = max_pos
self.enc_layers = enc_layers
self.enc_hidden_size = enc_hidden_size
self.enc_heads = enc_heads
self.enc_ff_size = enc_ff_size
self.enc_dropout = enc_dropout
self.enc_layers = enc_layers
self.enc_hidden_size = enc_hidden_size
self.enc_heads = enc_heads
self.enc_ff_size = enc_ff_size
self.enc_dropout = enc_dropout
self.dec_layers = dec_layers
self.dec_hidden_size = dec_hidden_size
self.dec_heads = dec_heads
self.dec_ff_size = dec_ff_size
self.dec_dropout = dec_dropout
else:
raise ValueError(
"First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)"
)
def _input_is_path_to_json(self, first_argument):
""" Checks whether the first argument passed to config
is the path to a JSON file that contains the config.
"""
is_python_2 = sys.version_info[0] == 2
if is_python_2:
return isinstance(first_argument, unicode)
else:
return isinstance(first_argument, str)
self.dec_layers = dec_layers
self.dec_hidden_size = dec_hidden_size
self.dec_heads = dec_heads
self.dec_ff_size = dec_ff_size
self.dec_dropout = dec_dropout

View File

@ -39,7 +39,7 @@ class XxxConfig(PretrainedConfig):
Arguments:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `XxxModel`.
vocab_size: Vocabulary size of `inputs_ids` in `XxxModel`.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
@ -64,7 +64,7 @@ class XxxConfig(PretrainedConfig):
pretrained_config_archive_map = XXX_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
vocab_size_or_config_json_file=50257,
vocab_size=50257,
n_positions=1024,
n_ctx=1024,
n_embd=768,
@ -75,8 +75,6 @@ class XxxConfig(PretrainedConfig):
attn_pdrop=0.1,
layer_norm_epsilon=1e-5,
initializer_range=0.02,
num_labels=1,
summary_type='cls_index',
summary_use_proj=True,
summary_activation=None,
@ -84,7 +82,7 @@ class XxxConfig(PretrainedConfig):
summary_first_dropout=0.1,
**kwargs):
super(XxxConfig, self).__init__(**kwargs)
self.vocab_size = vocab_size_or_config_json_file if isinstance(vocab_size_or_config_json_file, int) else -1
self.vocab_size = vocab_size
self.n_ctx = n_ctx
self.n_positions = n_positions
self.n_embd = n_embd
@ -95,23 +93,11 @@ class XxxConfig(PretrainedConfig):
self.attn_pdrop = attn_pdrop
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.num_labels = num_labels
self.summary_type = summary_type
self.summary_use_proj = summary_use_proj
self.summary_activation = summary_activation
self.summary_first_dropout = summary_first_dropout
self.summary_proj_to_labels = summary_proj_to_labels
if isinstance(vocab_size_or_config_json_file, six.string_types):
with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif not isinstance(vocab_size_or_config_json_file, int):
raise ValueError(
"First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)"
)
@property
def max_position_embeddings(self):

View File

@ -111,7 +111,7 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = XxxConfig(
vocab_size_or_config_json_file=self.vocab_size,
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,

View File

@ -109,7 +109,7 @@ class XxxModelTest(CommonTestCases.CommonModelTester):
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = XxxConfig(
vocab_size_or_config_json_file=self.vocab_size,
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,

View File

@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
# Files and general utilities
from .file_utils import (TRANSFORMERS_CACHE, PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
cached_path, add_start_docstrings, add_end_docstrings,
WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME,
WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME, MODEL_CARD_NAME,
is_tf_available, is_torch_available)
from .data import (is_sklearn_available,
@ -33,6 +33,9 @@ from .data import (is_sklearn_available,
if is_sklearn_available():
from .data import glue_compute_metrics, xnli_compute_metrics
# Model Cards
from .model_card import ModelCard
# Tokenizers
from .tokenization_utils import (PreTrainedTokenizer)
from .tokenization_auto import AutoTokenizer
@ -52,7 +55,7 @@ from .tokenization_t5 import T5Tokenizer
# Configurations
from .configuration_utils import PretrainedConfig
from .configuration_auto import AutoConfig
from .configuration_auto import AutoConfig, ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_bert import BertConfig, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_openai import OpenAIGPTConfig, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_transfo_xl import TransfoXLConfig, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
@ -70,7 +73,7 @@ from .configuration_t5 import T5Config, T5_PRETRAINED_CONFIG_ARCHIVE_MAP
if is_torch_available():
from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D)
from .modeling_auto import (AutoModel, AutoModelForSequenceClassification, AutoModelForQuestionAnswering,
AutoModelWithLMHead)
AutoModelWithLMHead, ALL_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_bert import (BertPreTrainedModel, BertModel, BertForPreTraining,
BertForMaskedLM, BertForNextSentencePrediction,
@ -128,7 +131,7 @@ if is_torch_available():
if is_tf_available():
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list
from .modeling_tf_auto import (TFAutoModel, TFAutoModelForSequenceClassification, TFAutoModelForQuestionAnswering,
TFAutoModelWithLMHead)
TFAutoModelWithLMHead, TF_ALL_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_tf_bert import (TFBertPreTrainedModel, TFBertMainLayer, TFBertEmbeddings,
TFBertModel, TFBertForPreTraining,
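
The practical effect of this __init__.py hunk is that the new objects become importable from the package root. An illustrative sketch, using only the names added in the import lines above (the torch/TF lines work only when the corresponding backend is installed):

from transformers import ModelCard                                          # new model card class
from transformers import AutoConfig, ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
from transformers import AutoModel, ALL_PRETRAINED_MODEL_ARCHIVE_MAP        # only if torch is available
from transformers import TFAutoModel, TF_ALL_PRETRAINED_MODEL_ARCHIVE_MAP   # only if TensorFlow 2 is available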

View File

@ -37,7 +37,7 @@ class AlbertConfig(PretrainedConfig):
pretrained_config_archive_map = ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
vocab_size_or_config_json_file=30000,
vocab_size=30000,
embedding_size=128,
hidden_size=4096,
num_hidden_layers=12,
@ -83,7 +83,7 @@ class AlbertConfig(PretrainedConfig):
"""
super(AlbertConfig, self).__init__(**kwargs)
self.vocab_size = vocab_size_or_config_json_file
self.vocab_size = vocab_size
self.embedding_size = embedding_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
@ -97,4 +97,4 @@ class AlbertConfig(PretrainedConfig):
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.layer_norm_eps = layer_norm_eps

View File

@ -18,22 +18,40 @@ from __future__ import absolute_import, division, print_function, unicode_literals
import logging
from .configuration_bert import BertConfig
from .configuration_openai import OpenAIGPTConfig
from .configuration_gpt2 import GPT2Config
from .configuration_transfo_xl import TransfoXLConfig
from .configuration_xlnet import XLNetConfig
from .configuration_xlm import XLMConfig
from .configuration_roberta import RobertaConfig
from .configuration_distilbert import DistilBertConfig
from .configuration_ctrl import CTRLConfig
from .configuration_camembert import CamembertConfig
from .configuration_albert import AlbertConfig
from .configuration_t5 import T5Config
from .configuration_bert import BertConfig, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_openai import OpenAIGPTConfig, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_transfo_xl import TransfoXLConfig, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_gpt2 import GPT2Config, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_ctrl import CTRLConfig, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_xlnet import XLNetConfig, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_xlm import XLMConfig, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_roberta import RobertaConfig, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_albert import AlbertConfig, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_camembert import CamembertConfig, CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_t5 import T5Config, T5_PRETRAINED_CONFIG_ARCHIVE_MAP
logger = logging.getLogger(__name__)
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict((key, value)
for pretrained_map in [
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
]
for key, value, in pretrained_map.items())
class AutoConfig(object):
r""":class:`~transformers.AutoConfig` is a generic configuration class
that will be instantiated as one of the configuration classes of the library
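
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP flattens every per-model archive map into a single shortcut-name to config-URL dictionary. A small usage sketch (the shortcut name is just an example of a key the BERT map is known to contain):

from transformers import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP

name = "bert-base-uncased"
if name in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
    print(ALL_PRETRAINED_CONFIG_ARCHIVE_MAP[name])  # URL of the hosted config.json for this shortcut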

View File

@ -56,7 +56,7 @@ class BertConfig(PretrainedConfig):
Arguments:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
@ -81,7 +81,7 @@ class BertConfig(PretrainedConfig):
pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
vocab_size_or_config_json_file=30522,
vocab_size=30522,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
@ -95,25 +95,15 @@ class BertConfig(PretrainedConfig):
layer_norm_eps=1e-12,
**kwargs):
super(BertConfig, self).__init__(**kwargs)
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.vocab_size = vocab_size_or_config_json_file
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
else:
raise ValueError("First argument must be either a vocabulary size (int)"
" or the path to a pretrained model config file (str)")
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
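
This hunk is representative of the breaking change in every config class: the first argument is now a plain vocabulary size, and the constructor no longer accepts a path to a JSON file. A hedged before/after sketch ("path/to/config.json" is a placeholder):

from transformers import BertConfig

# Before this PR (no longer supported):
#   config = BertConfig(vocab_size_or_config_json_file=30522)
#   config = BertConfig(vocab_size_or_config_json_file="path/to/config.json")

# After this PR:
config = BertConfig(vocab_size=30522)                       # keyword arguments only
config = BertConfig.from_json_file("path/to/config.json")   # JSON loading goes through the classmethods
config = BertConfig.from_pretrained("bert-base-uncased")    # unchanged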

View File

@ -31,7 +31,7 @@ class CTRLConfig(PretrainedConfig):
"""Configuration class to store the configuration of a `CTRLModel`.
Args:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `CTRLModel` or a configuration json file.
vocab_size: Vocabulary size of `inputs_ids` in `CTRLModel`.
n_positions: Number of positional embeddings.
n_ctx: Size of the causal mask (usually same as n_positions).
dff: Size of the inner dimension of the FFN.
@ -52,7 +52,7 @@ class CTRLConfig(PretrainedConfig):
def __init__(
self,
vocab_size_or_config_json_file=246534,
vocab_size=246534,
n_positions=256,
n_ctx=256,
n_embd=1280,
@ -64,8 +64,6 @@ class CTRLConfig(PretrainedConfig):
attn_pdrop=0.1,
layer_norm_epsilon=1e-6,
initializer_range=0.02,
num_labels=1,
summary_type='cls_index',
summary_use_proj=True,
summary_activation=None,
@ -76,7 +74,7 @@ class CTRLConfig(PretrainedConfig):
"""Constructs CTRLConfig.
Args:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `CTRLModel` or a configuration json file.
vocab_size: Vocabulary size of `inputs_ids` in `CTRLModel`.
n_positions: Number of positional embeddings.
n_ctx: Size of the causal mask (usually same as n_positions).
dff: Size of the inner dimension of the FFN.
@ -94,8 +92,7 @@ class CTRLConfig(PretrainedConfig):
initializing all weight matrices.
"""
super(CTRLConfig, self).__init__(**kwargs)
self.vocab_size = vocab_size_or_config_json_file if isinstance(vocab_size_or_config_json_file, int) else -1
self.vocab_size = vocab_size
self.n_ctx = n_ctx
self.n_positions = n_positions
self.n_embd = n_embd
@ -108,23 +105,11 @@ class CTRLConfig(PretrainedConfig):
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.num_labels = num_labels
self.summary_type = summary_type
self.summary_use_proj = summary_use_proj
self.summary_activation = summary_activation
self.summary_first_dropout = summary_first_dropout
self.summary_proj_to_labels = summary_proj_to_labels
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif not isinstance(vocab_size_or_config_json_file, int):
raise ValueError(
"First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)"
)
@property
def max_position_embeddings(self):

View File

@ -37,7 +37,7 @@ class DistilBertConfig(PretrainedConfig):
pretrained_config_archive_map = DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
vocab_size_or_config_json_file=30522,
vocab_size=30522,
max_position_embeddings=512,
sinusoidal_pos_embds=False,
n_layers=6,
@ -53,31 +53,21 @@ class DistilBertConfig(PretrainedConfig):
seq_classif_dropout=0.2,
**kwargs):
super(DistilBertConfig, self).__init__(**kwargs)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.sinusoidal_pos_embds = sinusoidal_pos_embds
self.n_layers = n_layers
self.n_heads = n_heads
self.dim = dim
self.hidden_dim = hidden_dim
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation = activation
self.initializer_range = initializer_range
self.tie_weights_ = tie_weights_
self.qa_dropout = qa_dropout
self.seq_classif_dropout = seq_classif_dropout
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.vocab_size = vocab_size_or_config_json_file
self.max_position_embeddings = max_position_embeddings
self.sinusoidal_pos_embds = sinusoidal_pos_embds
self.n_layers = n_layers
self.n_heads = n_heads
self.dim = dim
self.hidden_dim = hidden_dim
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation = activation
self.initializer_range = initializer_range
self.tie_weights_ = tie_weights_
self.qa_dropout = qa_dropout
self.seq_classif_dropout = seq_classif_dropout
else:
raise ValueError("First argument must be either a vocabulary size (int)"
" or the path to a pretrained model config file (str)")
@property
def hidden_size(self):
return self.dim

View File

@ -36,7 +36,7 @@ class GPT2Config(PretrainedConfig):
"""Configuration class to store the configuration of a `GPT2Model`.
Args:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file.
vocab_size: Vocabulary size of `inputs_ids` in `GPT2Model`.
n_positions: Number of positional embeddings.
n_ctx: Size of the causal mask (usually same as n_positions).
n_embd: Dimensionality of the embeddings and hidden states.
@ -56,7 +56,7 @@ class GPT2Config(PretrainedConfig):
def __init__(
self,
vocab_size_or_config_json_file=50257,
vocab_size=50257,
n_positions=1024,
n_ctx=1024,
n_embd=768,
@ -67,8 +67,6 @@ class GPT2Config(PretrainedConfig):
attn_pdrop=0.1,
layer_norm_epsilon=1e-5,
initializer_range=0.02,
num_labels=1,
summary_type='cls_index',
summary_use_proj=True,
summary_activation=None,
@ -79,7 +77,7 @@ class GPT2Config(PretrainedConfig):
"""Constructs GPT2Config.
Args:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file.
vocab_size: Vocabulary size of `inputs_ids` in `GPT2Model`.
n_positions: Number of positional embeddings.
n_ctx: Size of the causal mask (usually same as n_positions).
n_embd: Dimensionality of the embeddings and hidden states.
@ -96,37 +94,22 @@ class GPT2Config(PretrainedConfig):
initializing all weight matrices.
"""
super(GPT2Config, self).__init__(**kwargs)
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.vocab_size = vocab_size_or_config_json_file
self.n_ctx = n_ctx
self.n_positions = n_positions
self.n_embd = n_embd
self.n_layer = n_layer
self.n_head = n_head
self.resid_pdrop = resid_pdrop
self.embd_pdrop = embd_pdrop
self.attn_pdrop = attn_pdrop
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.num_labels = num_labels
self.summary_type = summary_type
self.summary_use_proj = summary_use_proj
self.summary_activation = summary_activation
self.summary_first_dropout = summary_first_dropout
self.summary_proj_to_labels = summary_proj_to_labels
else:
raise ValueError(
"First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)"
)
self.vocab_size = vocab_size
self.n_ctx = n_ctx
self.n_positions = n_positions
self.n_embd = n_embd
self.n_layer = n_layer
self.n_head = n_head
self.resid_pdrop = resid_pdrop
self.embd_pdrop = embd_pdrop
self.attn_pdrop = attn_pdrop
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.summary_type = summary_type
self.summary_use_proj = summary_use_proj
self.summary_activation = summary_activation
self.summary_first_dropout = summary_first_dropout
self.summary_proj_to_labels = summary_proj_to_labels
@property
def max_position_embeddings(self):

View File

@ -35,7 +35,7 @@ class OpenAIGPTConfig(PretrainedConfig):
Configuration class to store the configuration of a `OpenAIGPTModel`.
Args:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `OpenAIGPTModel` or a configuration json file.
vocab_size: Vocabulary size of `inputs_ids` in `OpenAIGPTModel`.
n_positions: Number of positional embeddings.
n_ctx: Size of the causal mask (usually same as n_positions).
n_embd: Dimensionality of the embeddings and hidden states.
@ -58,7 +58,7 @@ class OpenAIGPTConfig(PretrainedConfig):
def __init__(
self,
vocab_size_or_config_json_file=40478,
vocab_size=40478,
n_positions=512,
n_ctx=512,
n_embd=768,
@ -71,8 +71,6 @@ class OpenAIGPTConfig(PretrainedConfig):
layer_norm_epsilon=1e-5,
initializer_range=0.02,
predict_special_tokens=True,
num_labels=1,
summary_type='cls_index',
summary_use_proj=True,
summary_activation=None,
@ -83,39 +81,24 @@ class OpenAIGPTConfig(PretrainedConfig):
"""Constructs OpenAIGPTConfig.
"""
super(OpenAIGPTConfig, self).__init__(**kwargs)
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.vocab_size = vocab_size_or_config_json_file
self.n_ctx = n_ctx
self.n_positions = n_positions
self.n_embd = n_embd
self.n_layer = n_layer
self.n_head = n_head
self.afn = afn
self.resid_pdrop = resid_pdrop
self.embd_pdrop = embd_pdrop
self.attn_pdrop = attn_pdrop
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.predict_special_tokens = predict_special_tokens
self.num_labels = num_labels
self.summary_type = summary_type
self.summary_use_proj = summary_use_proj
self.summary_activation = summary_activation
self.summary_first_dropout = summary_first_dropout
self.summary_proj_to_labels = summary_proj_to_labels
else:
raise ValueError(
"First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)"
)
self.vocab_size = vocab_size
self.n_ctx = n_ctx
self.n_positions = n_positions
self.n_embd = n_embd
self.n_layer = n_layer
self.n_head = n_head
self.afn = afn
self.resid_pdrop = resid_pdrop
self.embd_pdrop = embd_pdrop
self.attn_pdrop = attn_pdrop
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.predict_special_tokens = predict_special_tokens
self.summary_type = summary_type
self.summary_use_proj = summary_use_proj
self.summary_activation = summary_activation
self.summary_first_dropout = summary_first_dropout
self.summary_proj_to_labels = summary_proj_to_labels
@property
def max_position_embeddings(self):

View File

@ -66,7 +66,7 @@ class T5Config(PretrainedConfig):
pretrained_config_archive_map = T5_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
vocab_size_or_config_json_file=32128,
vocab_size=32128,
n_positions=512,
d_model=512,
d_kv=64,
@ -79,7 +79,7 @@ class T5Config(PretrainedConfig):
initializer_factor=1.0,
**kwargs):
super(T5Config, self).__init__(**kwargs)
self.vocab_size = vocab_size_or_config_json_file if isinstance(vocab_size_or_config_json_file, int) else -1
self.vocab_size = vocab_size
self.n_positions = n_positions
self.d_model = d_model
self.d_kv = d_kv
@ -91,17 +91,6 @@ class T5Config(PretrainedConfig):
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_factor = initializer_factor
if isinstance(vocab_size_or_config_json_file, six.string_types):
with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif not isinstance(vocab_size_or_config_json_file, int):
raise ValueError(
"First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)"
)
@property
def max_position_embeddings(self):
return self.n_positions

View File

@ -34,7 +34,7 @@ class TransfoXLConfig(PretrainedConfig):
"""Configuration class to store the configuration of a `TransfoXLModel`.
Args:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `TransfoXLModel` or a configuration json file.
vocab_size: Vocabulary size of `inputs_ids` in `TransfoXLModel`.
cutoffs: cutoffs for the adaptive softmax
d_model: Dimensionality of the model's hidden states.
d_embed: Dimensionality of the embeddings
@ -68,7 +68,7 @@ class TransfoXLConfig(PretrainedConfig):
pretrained_config_archive_map = TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
vocab_size_or_config_json_file=267735,
vocab_size=267735,
cutoffs=[20000, 40000, 200000],
d_model=1024,
d_embed=1024,
@ -100,7 +100,7 @@ class TransfoXLConfig(PretrainedConfig):
"""Constructs TransfoXLConfig.
"""
super(TransfoXLConfig, self).__init__(**kwargs)
self.n_token = vocab_size_or_config_json_file if isinstance(vocab_size_or_config_json_file, int) else -1
self.vocab_size = vocab_size
self.cutoffs = []
self.cutoffs.extend(cutoffs)
self.tie_weight = tie_weight
@ -133,27 +133,17 @@ class TransfoXLConfig(PretrainedConfig):
self.init_std = init_std
self.layer_norm_epsilon = layer_norm_epsilon
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif not isinstance(vocab_size_or_config_json_file, int):
raise ValueError("First argument must be either a vocabulary size (int)"
" or the path to a pretrained model config file (str)")
@property
def max_position_embeddings(self):
return self.tgt_len + self.ext_len + self.mem_len
@property
def vocab_size(self):
return self.n_token
def n_token(self): # Backward compatibility
return self.vocab_size
@vocab_size.setter
def vocab_size(self, value):
self.n_token = value
@n_token.setter
def n_token(self, value): # Backward compatibility
self.vocab_size = value
@property
def hidden_size(self):
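
The stored attribute becomes `vocab_size`, while `n_token` survives as a read/write property for backward compatibility (XLM keeps `n_words` and XLNet keeps `n_token` the same way further down). A quick sketch of what the property pair preserves:

from transformers import TransfoXLConfig

config = TransfoXLConfig(vocab_size=267735)
assert config.n_token == config.vocab_size   # old attribute name still readable
config.n_token = 1000                        # ...and writable; it just forwards to vocab_size
assert config.vocab_size == 1000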

View File

@ -49,8 +49,7 @@ class PretrainedConfig(object):
pretrained_config_archive_map = {}
def __init__(self, **kwargs):
self.finetuning_task = kwargs.pop('finetuning_task', None)
self.num_labels = kwargs.pop('num_labels', 2)
# Attributes with defaults
self.output_attentions = kwargs.pop('output_attentions', False)
self.output_hidden_states = kwargs.pop('output_hidden_states', False)
self.output_past = kwargs.pop('output_past', True) # Not used by all models
@ -59,6 +58,22 @@ class PretrainedConfig(object):
self.pruned_heads = kwargs.pop('pruned_heads', {})
self.is_decoder = kwargs.pop('is_decoder', False)
# Fine-tuning task arguments
self.finetuning_task = kwargs.pop('finetuning_task', None)
self.num_labels = kwargs.pop('num_labels', 2)
self.id2label = kwargs.pop('id2label', {i: 'LABEL_{}'.format(i) for i in range(self.num_labels)})
self.id2label = dict((int(key), value) for key, value in self.id2label.items())
self.label2id = kwargs.pop('label2id', dict(zip(self.id2label.values(), self.id2label.keys())))
self.label2id = dict((key, int(value)) for key, value in self.label2id.items())
# Additional attributes without default values
for key, value in kwargs.items():
try:
setattr(self, key, value)
except AttributeError as err:
logger.error("Can't set {} with value {} for {}".format(key, value, self))
raise err
def save_pretrained(self, save_directory):
""" Save a configuration object to the directory `save_directory`, so that it
can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.
@ -136,10 +151,14 @@ class PretrainedConfig(object):
config_file = pretrained_model_name_or_path
else:
config_file = hf_bucket_url(pretrained_model_name_or_path, postfix=CONFIG_NAME)
# redirect to the cache, if necessary
try:
# Load from URL or cache if already cached
resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download,
proxies=proxies, resume_download=resume_download)
# Load config
config = cls.from_json_file(resolved_config_file)
except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
@ -153,15 +172,18 @@ class PretrainedConfig(object):
config_file, CONFIG_NAME)
raise EnvironmentError(msg)
except json.JSONDecodeError:
msg = "Couldn't reach server at '{}' to download configuration file or " \
"configuration file is not a valid JSON file. " \
"Please check network or file content here: {}.".format(config_file, resolved_config_file)
raise EnvironmentError(msg)
if resolved_config_file == config_file:
logger.info("loading configuration file {}".format(config_file))
else:
logger.info("loading configuration file {} from cache at {}".format(
config_file, resolved_config_file))
# Load config
config = cls.from_json_file(resolved_config_file)
if hasattr(config, 'pruned_heads'):
config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
@ -183,17 +205,15 @@ class PretrainedConfig(object):
@classmethod
def from_dict(cls, json_object):
"""Constructs a `Config` from a Python dictionary of parameters."""
config = cls(vocab_size_or_config_json_file=-1)
for key, value in json_object.items():
setattr(config, key, value)
return config
return cls(**json_object)
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `Config` from a json file of parameters."""
with open(json_file, "r", encoding='utf-8') as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
dict_obj = json.loads(text)
return cls(**dict_obj)
def __eq__(self, other):
return self.__dict__ == other.__dict__
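
Two things from this hunk are worth illustrating: the label bookkeeping (`num_labels`, `id2label`, `label2id`) now lives on the base class, and `from_dict`/`from_json_file` simply hand the parsed key/value pairs back to the constructor. A minimal sketch:

from transformers import PretrainedConfig

config = PretrainedConfig(num_labels=3)
print(config.id2label)   # {0: 'LABEL_0', 1: 'LABEL_1', 2: 'LABEL_2'}
print(config.label2id)   # {'LABEL_0': 0, 'LABEL_1': 1, 'LABEL_2': 2}

config = PretrainedConfig.from_dict({"num_labels": 2, "output_attentions": True})
print(config.output_attentions)  # True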

View File

@ -42,7 +42,7 @@ class XLMConfig(PretrainedConfig):
"""Configuration class to store the configuration of a `XLMModel`.
Args:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `XLMModel`.
vocab_size: Vocabulary size of `inputs_ids` in `XLMModel`.
d_model: Size of the encoder layers and the pooler layer.
n_layer: Number of hidden layers in the Transformer encoder.
n_head: Number of attention heads for each attention layer in
@ -81,7 +81,7 @@ class XLMConfig(PretrainedConfig):
pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
vocab_size_or_config_json_file=30145,
vocab_size=30145,
emb_dim=2048,
n_layers=12,
n_heads=16,
@ -103,9 +103,6 @@ class XLMConfig(PretrainedConfig):
unk_index=3,
mask_index=5,
is_encoder=True,
finetuning_task=None,
num_labels=2,
summary_type='first',
summary_use_proj=True,
summary_activation=None,
@ -117,56 +114,43 @@ class XLMConfig(PretrainedConfig):
"""Constructs XLMConfig.
"""
super(XLMConfig, self).__init__(**kwargs)
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.n_words = vocab_size_or_config_json_file
self.emb_dim = emb_dim
self.n_layers = n_layers
self.n_heads = n_heads
self.dropout = dropout
self.attention_dropout = attention_dropout
self.gelu_activation = gelu_activation
self.sinusoidal_embeddings = sinusoidal_embeddings
self.causal = causal
self.asm = asm
self.n_langs = n_langs
self.use_lang_emb = use_lang_emb
self.layer_norm_eps = layer_norm_eps
self.bos_index = bos_index
self.eos_index = eos_index
self.pad_index = pad_index
self.unk_index = unk_index
self.mask_index = mask_index
self.is_encoder = is_encoder
self.max_position_embeddings = max_position_embeddings
self.embed_init_std = embed_init_std
self.init_std = init_std
self.finetuning_task = finetuning_task
self.num_labels = num_labels
self.summary_type = summary_type
self.summary_use_proj = summary_use_proj
self.summary_activation = summary_activation
self.summary_proj_to_labels = summary_proj_to_labels
self.summary_first_dropout = summary_first_dropout
self.start_n_top = start_n_top
self.end_n_top = end_n_top
else:
raise ValueError("First argument must be either a vocabulary size (int)"
" or the path to a pretrained model config file (str)")
self.vocab_size = vocab_size
self.emb_dim = emb_dim
self.n_layers = n_layers
self.n_heads = n_heads
self.dropout = dropout
self.attention_dropout = attention_dropout
self.gelu_activation = gelu_activation
self.sinusoidal_embeddings = sinusoidal_embeddings
self.causal = causal
self.asm = asm
self.n_langs = n_langs
self.use_lang_emb = use_lang_emb
self.layer_norm_eps = layer_norm_eps
self.bos_index = bos_index
self.eos_index = eos_index
self.pad_index = pad_index
self.unk_index = unk_index
self.mask_index = mask_index
self.is_encoder = is_encoder
self.max_position_embeddings = max_position_embeddings
self.embed_init_std = embed_init_std
self.init_std = init_std
self.summary_type = summary_type
self.summary_use_proj = summary_use_proj
self.summary_activation = summary_activation
self.summary_proj_to_labels = summary_proj_to_labels
self.summary_first_dropout = summary_first_dropout
self.start_n_top = start_n_top
self.end_n_top = end_n_top
@property
def vocab_size(self):
return self.n_words
def n_words(self): # For backward compatibility
return self.vocab_size
@vocab_size.setter
def vocab_size(self, value):
self.n_words = value
@n_words.setter
def n_words(self, value): # For backward compatibility
self.vocab_size = value
@property
def hidden_size(self):

View File

@ -35,7 +35,7 @@ class XLNetConfig(PretrainedConfig):
"""Configuration class to store the configuration of a ``XLNetModel``.
Args:
vocab_size_or_config_json_file: Vocabulary size of ``inputs_ids`` in ``XLNetModel``.
vocab_size: Vocabulary size of ``inputs_ids`` in ``XLNetModel``.
d_model: Size of the encoder layers and the pooler layer.
n_layer: Number of hidden layers in the Transformer encoder.
n_head: Number of attention heads for each attention layer in
@ -72,28 +72,22 @@ class XLNetConfig(PretrainedConfig):
pretrained_config_archive_map = XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
vocab_size_or_config_json_file=32000,
vocab_size=32000,
d_model=1024,
n_layer=24,
n_head=16,
d_inner=4096,
max_position_embeddings=512,
ff_activation="gelu",
untie_r=True,
attn_type="bi",
initializer_range=0.02,
layer_norm_eps=1e-12,
dropout=0.1,
mem_len=None,
reuse_len=None,
bi_data=False,
clamp_len=-1,
same_length=False,
finetuning_task=None,
num_labels=2,
summary_type='last',
summary_use_proj=True,
summary_activation='tanh',
@ -104,58 +98,45 @@ class XLNetConfig(PretrainedConfig):
"""Constructs XLNetConfig.
"""
super(XLNetConfig, self).__init__(**kwargs)
self.vocab_size = vocab_size
self.d_model = d_model
self.n_layer = n_layer
self.n_head = n_head
assert d_model % n_head == 0
self.d_head = d_model // n_head
self.ff_activation = ff_activation
self.d_inner = d_inner
self.untie_r = untie_r
self.attn_type = attn_type
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
setattr(config, key, value)
elif isinstance(vocab_size_or_config_json_file, int):
self.n_token = vocab_size_or_config_json_file
self.d_model = d_model
self.n_layer = n_layer
self.n_head = n_head
assert d_model % n_head == 0
self.d_head = d_model // n_head
self.ff_activation = ff_activation
self.d_inner = d_inner
self.untie_r = untie_r
self.attn_type = attn_type
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.dropout = dropout
self.mem_len = mem_len
self.reuse_len = reuse_len
self.bi_data = bi_data
self.clamp_len = clamp_len
self.same_length = same_length
self.dropout = dropout
self.mem_len = mem_len
self.reuse_len = reuse_len
self.bi_data = bi_data
self.clamp_len = clamp_len
self.same_length = same_length
self.finetuning_task = finetuning_task
self.num_labels = num_labels
self.summary_type = summary_type
self.summary_use_proj = summary_use_proj
self.summary_activation = summary_activation
self.summary_last_dropout = summary_last_dropout
self.start_n_top = start_n_top
self.end_n_top = end_n_top
else:
raise ValueError("First argument must be either a vocabulary size (int)"
" or the path to a pretrained model config file (str)")
self.summary_type = summary_type
self.summary_use_proj = summary_use_proj
self.summary_activation = summary_activation
self.summary_last_dropout = summary_last_dropout
self.start_n_top = start_n_top
self.end_n_top = end_n_top
@property
def max_position_embeddings(self):
return -1
@property
def vocab_size(self):
return self.n_token
def n_token(self): # Backward compatibility
return self.vocab_size
@vocab_size.setter
def vocab_size(self, value):
self.n_token = value
@n_token.setter
def n_token(self, value): # Backward compatibility
self.vocab_size = value
@property
def hidden_size(self):

View File

@ -46,7 +46,7 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
roberta = FairseqRobertaModel.from_pretrained(roberta_checkpoint_path)
roberta.eval() # disable dropout
config = BertConfig(
vocab_size_or_config_json_file=50265,
vocab_size=50265,
hidden_size=roberta.args.encoder_embed_dim,
num_hidden_layers=roberta.args.encoder_layers,
num_attention_heads=roberta.args.encoder_attention_heads,

View File

@ -72,7 +72,7 @@ WEIGHTS_NAME = "pytorch_model.bin"
TF2_WEIGHTS_NAME = 'tf_model.h5'
TF_WEIGHTS_NAME = 'model.ckpt'
CONFIG_NAME = "config.json"
MODEL_CARD_NAME = "model_card.json"
DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]

transformers/model_card.py (new file, 226 additions)
View File

@ -0,0 +1,226 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Configuration base class and utilities."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import copy
import json
import logging
import os
from io import open
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
from .file_utils import CONFIG_NAME, MODEL_CARD_NAME, cached_path, is_remote_url, hf_bucket_url
logger = logging.getLogger(__name__)
class ModelCard(object):
r""" Model Card class.
Stores a model card as well as methods for loading/downloading/saving model cards.
Please read the following paper for details and an explanation of the sections:
"Model Cards for Model Reporting"
by Margaret Mitchell, Simone Wu, Andrew Zaldivar, Parker Barnes, Lucy Vasserman,
Ben Hutchinson, Elena Spitzer, Inioluwa Deborah Raji and Timnit Gebru.
Link: https://arxiv.org/abs/1810.03993
Note:
A model card can be loaded and saved to disk.
Parameters:
"""
def __init__(self, **kwargs):
# Recommended attributes from https://arxiv.org/abs/1810.03993 (see paper)
self.model_details = kwargs.pop('model_details', {})
self.intended_use = kwargs.pop('intended_use', {})
self.factors = kwargs.pop('factors', {})
self.metrics = kwargs.pop('metrics', {})
self.evaluation_data = kwargs.pop('evaluation_data', {})
self.training_data = kwargs.pop('training_data', {})
self.quantitative_analyses = kwargs.pop('quantitative_analyses', {})
self.ethical_considerations = kwargs.pop('ethical_considerations', {})
self.caveats_and_recommendations = kwargs.pop('caveats_and_recommendations', {})
# Open additional attributes
for key, value in kwargs.items():
try:
setattr(self, key, value)
except AttributeError as err:
logger.error("Can't set {} with value {} for {}".format(key, value, self))
raise err
def save_pretrained(self, save_directory_or_file):
""" Save a model card object to the directory or file `save_directory_or_file`.
"""
if os.path.isdir(save_directory_or_file):
# If we save using the predefined names, we can load using `from_pretrained`
output_model_card_file = os.path.join(save_directory_or_file, MODEL_CARD_NAME)
else:
output_model_card_file = save_directory_or_file
self.to_json_file(output_model_card_file)
logger.info("Model card saved in {}".format(output_model_card_file))
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
r""" Instantiate a :class:`~transformers.ModelCard` from a pre-trained model model card.
Parameters:
pretrained_model_name_or_path: either:
- a string with the `shortcut name` of a pre-trained model card to load from cache or download, e.g.: ``bert-base-uncased``.
- a string with the `identifier name` of a pre-trained model card that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing a model card file saved using the :func:`~transformers.ModelCard.save_pretrained` method, e.g.: ``./my_model_directory/``.
- a path or url to a saved model card JSON `file`, e.g.: ``./my_model_directory/model_card.json``.
cache_dir: (`optional`) string:
Path to a directory in which a downloaded pre-trained model
card should be cached if the standard cache should not be used.
kwargs: (`optional`) dict: key/value pairs with which to update the ModelCard object after loading.
- The values in kwargs of any keys which are model card attributes will be used to override the loaded values.
- Behavior concerning key/value pairs whose keys are *not* model card attributes is controlled by the `return_unused_kwargs` keyword parameter.
force_download: (`optional`) boolean, default False:
Force to (re-)download the model card file and override the cached version if it exists.
resume_download: (`optional`) boolean, default False:
Do not delete an incompletely received file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
return_unused_kwargs: (`optional`) bool:
- If False, then this function returns just the final model card object.
- If True, then this function returns a tuple `(model card, unused_kwargs)` where `unused_kwargs` is a dictionary consisting of the key/value pairs whose keys are not model card attributes: i.e. the part of kwargs which has not been used to update the `ModelCard` and is otherwise ignored.
Examples::
model_card = ModelCard.from_pretrained('bert-base-uncased') # Download model card from S3 and cache.
model_card = ModelCard.from_pretrained('./test/saved_model/') # E.g. model card was saved using `save_pretrained('./test/saved_model/')`
model_card = ModelCard.from_pretrained('./test/saved_model/model_card.json')
model_card = ModelCard.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
"""
cache_dir = kwargs.pop('cache_dir', None)
force_download = kwargs.pop('force_download', False)
resume_download = kwargs.pop('resume_download', False)
proxies = kwargs.pop('proxies', None)
return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
if pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
# For simplicity we use the same pretrained URL as the configuration files, but with a different suffix (model_card.json)
model_card_file = ALL_PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
model_card_file = model_card_file.replace(CONFIG_NAME, MODEL_CARD_NAME)
elif os.path.isdir(pretrained_model_name_or_path):
model_card_file = os.path.join(pretrained_model_name_or_path, MODEL_CARD_NAME)
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
model_card_file = pretrained_model_name_or_path
else:
model_card_file = hf_bucket_url(pretrained_model_name_or_path, postfix=MODEL_CARD_NAME)
try:
# Load from URL or cache if already cached
resolved_model_card_file = cached_path(model_card_file, cache_dir=cache_dir, force_download=force_download,
proxies=proxies, resume_download=resume_download)
if resolved_model_card_file == model_card_file:
logger.info("loading model card file {}".format(model_card_file))
else:
logger.info("loading model card file {} from cache at {}".format(
model_card_file, resolved_model_card_file))
# Load model card
model_card = cls.from_json_file(resolved_model_card_file)
except EnvironmentError:
if pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
logger.warning("Couldn't reach server at '{}' to download model card file.".format(
model_card_file))
else:
logger.warning("Model name '{}' was not found in model name list ({}). " \
"We assumed '{}' was a path or url to a model card file named {} or " \
"a directory containing such a file but couldn't find any such file at this path or url.".format(
pretrained_model_name_or_path,
', '.join(ALL_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
model_card_file, MODEL_CARD_NAME))
logger.warning("Creating an empty model card.")
# We fall back on creating an empty model card
model_card = cls()
except json.JSONDecodeError:
logger.warning("Couldn't reach server at '{}' to download model card file or "
"model card file is not a valid JSON file. "
"Please check network or file content here: {}.".format(model_card_file, resolved_model_card_file))
logger.warning("Creating an empty model card.")
# We fall back on creating an empty model card
model_card = cls()
# Update model card with kwargs if needed
to_remove = []
for key, value in kwargs.items():
if hasattr(model_card, key):
setattr(model_card, key, value)
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)
logger.info("Model card: %s", str(model_card))
if return_unused_kwargs:
return model_card, kwargs
else:
return model_card
@classmethod
def from_dict(cls, json_object):
"""Constructs a `ModelCard` from a Python dictionary of parameters."""
return cls(**json_object)
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `ModelCard` from a json file of parameters."""
with open(json_file, "r", encoding='utf-8') as reader:
text = reader.read()
dict_obj = json.loads(text)
return cls(**dict_obj)
def __eq__(self, other):
return self.__dict__ == other.__dict__
def __repr__(self):
return str(self.to_json_string())
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path):
""" Save this instance to a json file."""
with open(json_file_path, "w", encoding='utf-8') as writer:
writer.write(self.to_json_string())
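
A short usage sketch of the new class, using only the methods defined above (the directory path is a placeholder):

import os
from transformers import ModelCard

card = ModelCard(model_details={"name": "my-finetuned-bert"},
                 intended_use={"primary_uses": "sentiment analysis"})

os.makedirs("./my_model_directory/", exist_ok=True)
card.save_pretrained("./my_model_directory/")              # writes ./my_model_directory/model_card.json
card = ModelCard.from_pretrained("./my_model_directory/")  # round-trips through from_json_file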

View File

@ -18,18 +18,18 @@ from __future__ import absolute_import, division, print_function, unicode_literals
import logging
from .modeling_bert import BertModel, BertForMaskedLM, BertForSequenceClassification, BertForQuestionAnswering
from .modeling_openai import OpenAIGPTModel, OpenAIGPTLMHeadModel
from .modeling_gpt2 import GPT2Model, GPT2LMHeadModel
from .modeling_ctrl import CTRLModel, CTRLLMHeadModel
from .modeling_transfo_xl import TransfoXLModel, TransfoXLLMHeadModel
from .modeling_xlnet import XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering
from .modeling_xlm import XLMModel, XLMWithLMHeadModel, XLMForSequenceClassification, XLMForQuestionAnswering
from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification
from .modeling_distilbert import DistilBertModel, DistilBertForQuestionAnswering, DistilBertForMaskedLM, DistilBertForSequenceClassification
from .modeling_camembert import CamembertModel, CamembertForMaskedLM, CamembertForSequenceClassification, CamembertForMultipleChoice
from .modeling_albert import AlbertModel, AlbertForMaskedLM, AlbertForSequenceClassification, AlbertForQuestionAnswering
from .modeling_t5 import T5Model, T5WithLMHeadModel
from .modeling_bert import BertModel, BertForMaskedLM, BertForSequenceClassification, BertForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_openai import OpenAIGPTModel, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_gpt2 import GPT2Model, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_ctrl import CTRLModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_transfo_xl import TransfoXLModel, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_xlnet import XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_xlm import XLMModel, XLMWithLMHeadModel, XLMForSequenceClassification, XLMForQuestionAnswering, XLM_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_distilbert import DistilBertModel, DistilBertForQuestionAnswering, DistilBertForMaskedLM, DistilBertForSequenceClassification, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_camembert import CamembertModel, CamembertForMaskedLM, CamembertForSequenceClassification, CamembertForMultipleChoice, CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_albert import AlbertModel, AlbertForMaskedLM, AlbertForSequenceClassification, AlbertForQuestionAnswering, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_t5 import T5Model, T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_utils import PreTrainedModel, SequenceSummary
@ -38,6 +38,24 @@ from .file_utils import add_start_docstrings
logger = logging.getLogger(__name__)
ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict((key, value)
for pretrained_map in [
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
]
for key, value, in pretrained_map.items())
class AutoModel(object):
r"""
:class:`~transformers.AutoModel` is a generic model class

View File

@ -634,6 +634,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
"""
def __init__(self, config):
super(GPT2DoubleHeadsModel, self).__init__(config)
config.num_labels = 1
self.transformer = GPT2Model(config)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.multiple_choice_head = SequenceSummary(config)
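
Since `num_labels=1` was removed from the GPT-2/CTRL/OpenAI GPT config defaults (it now defaults to 2 on `PretrainedConfig`), the double-heads models set it back to 1 themselves so `SequenceSummary` still produces a single score per choice. Illustrative sketch (requires torch; builds a randomly initialized model):

from transformers import GPT2Config, GPT2DoubleHeadsModel

config = GPT2Config()                 # num_labels now comes from PretrainedConfig (default 2)
model = GPT2DoubleHeadsModel(config)
print(model.config.num_labels)        # 1 -- forced in __init__ for the multiple-choice head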

View File

@ -590,6 +590,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
def __init__(self, config):
super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
config.num_labels = 1
self.transformer = OpenAIGPTModel(config)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.multiple_choice_head = SequenceSummary(config)

View File

@ -18,22 +18,40 @@ from __future__ import absolute_import, division, print_function, unicode_literals
import logging
from .modeling_tf_bert import TFBertModel, TFBertForMaskedLM, TFBertForSequenceClassification, TFBertForQuestionAnswering
from .modeling_tf_openai import TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel
from .modeling_tf_gpt2 import TFGPT2Model, TFGPT2LMHeadModel
from .modeling_tf_transfo_xl import TFTransfoXLModel, TFTransfoXLLMHeadModel
from .modeling_tf_xlnet import TFXLNetModel, TFXLNetLMHeadModel, TFXLNetForSequenceClassification, TFXLNetForQuestionAnsweringSimple
from .modeling_tf_xlm import TFXLMModel, TFXLMWithLMHeadModel, TFXLMForSequenceClassification, TFXLMForQuestionAnsweringSimple
from .modeling_tf_roberta import TFRobertaModel, TFRobertaForMaskedLM, TFRobertaForSequenceClassification
from .modeling_tf_distilbert import TFDistilBertModel, TFDistilBertForQuestionAnswering, TFDistilBertForMaskedLM, TFDistilBertForSequenceClassification
from .modeling_tf_ctrl import TFCTRLModel, TFCTRLLMHeadModel
from .modeling_tf_t5 import TFT5Model, TFT5WithLMHeadModel
from .modeling_tf_bert import TFBertModel, TFBertForMaskedLM, TFBertForSequenceClassification, TFBertForQuestionAnswering, TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_tf_openai import TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel, TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_tf_gpt2 import TFGPT2Model, TFGPT2LMHeadModel, TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_tf_transfo_xl import TFTransfoXLModel, TFTransfoXLLMHeadModel, TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_tf_xlnet import TFXLNetModel, TFXLNetLMHeadModel, TFXLNetForSequenceClassification, TFXLNetForQuestionAnsweringSimple, TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_tf_xlm import TFXLMModel, TFXLMWithLMHeadModel, TFXLMForSequenceClassification, TFXLMForQuestionAnsweringSimple, TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_tf_roberta import TFRobertaModel, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_tf_distilbert import TFDistilBertModel, TFDistilBertForQuestionAnswering, TFDistilBertForMaskedLM, TFDistilBertForSequenceClassification, TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_tf_ctrl import TFCTRLModel, TFCTRLLMHeadModel, TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_tf_albert import TFAlbertModel, TFAlbertForMaskedLM, TFAlbertForSequenceClassification, TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_tf_t5 import TFT5Model, TFT5WithLMHeadModel, TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP
from .file_utils import add_start_docstrings
logger = logging.getLogger(__name__)
TF_ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict((key, value)
for pretrained_map in [
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP,
]
for key, value, in pretrained_map.items())
class TFAutoModel(object):
r"""
:class:`~transformers.TFAutoModel` is a generic model class
@ -144,6 +162,8 @@ class TFAutoModel(object):
return TFT5Model.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'distilbert' in pretrained_model_name_or_path:
return TFDistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'albert' in pretrained_model_name_or_path:
return TFAlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'roberta' in pretrained_model_name_or_path:
return TFRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'bert' in pretrained_model_name_or_path:
@ -280,6 +300,8 @@ class TFAutoModelWithLMHead(object):
return TFT5WithLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'distilbert' in pretrained_model_name_or_path:
return TFDistilBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'albert' in pretrained_model_name_or_path:
return TFAlbertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'roberta' in pretrained_model_name_or_path:
return TFRobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'bert' in pretrained_model_name_or_path:
@ -407,6 +429,8 @@ class TFAutoModelForSequenceClassification(object):
"""
if 'distilbert' in pretrained_model_name_or_path:
return TFDistilBertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'albert' in pretrained_model_name_or_path:
return TFAlbertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'roberta' in pretrained_model_name_or_path:
return TFRobertaForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'bert' in pretrained_model_name_or_path:
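
All three TFAutoModel* factories dispatch on substrings of the checkpoint name, which is why branch order matters: the new 'albert' branch, like 'roberta', must be tested before the plain 'bert' branch, since both names contain "bert". A hedged sketch of that dispatch pattern with a hypothetical helper (the real classmethods instantiate the model via from_pretrained instead of returning a name):

# illustrative sketch, not part of this commit: the substring dispatch used by the factories above.
def toy_pick_class(name):
    # hypothetical helper; the real code calls e.g. TFAlbertModel.from_pretrained(name, ...)
    for key, cls_name in [("distilbert", "TFDistilBertModel"),
                          ("albert", "TFAlbertModel"),
                          ("roberta", "TFRobertaModel"),
                          ("bert", "TFBertModel")]:
        if key in name:
            return cls_name
    raise ValueError("Unrecognized checkpoint name: {}".format(name))

assert toy_pick_class("albert-base-v1") == "TFAlbertModel"   # would hit the 'bert' branch if that were checked first
assert toy_pick_class("bert-base-cased") == "TFBertModel"
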

View File

@ -574,6 +574,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
"""
def __init__(self, config, *inputs, **kwargs):
super(TFGPT2DoubleHeadsModel, self).__init__(config, *inputs, **kwargs)
config.num_labels = 1
self.transformer = TFGPT2MainLayer(config, name='transformer')
self.multiple_choice_head = TFSequenceSummary(config, initializer_range=config.initializer_range, name='multiple_choice_head')

View File

@ -538,6 +538,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
"""
def __init__(self, config, *inputs, **kwargs):
super(TFOpenAIGPTDoubleHeadsModel, self).__init__(config, *inputs, **kwargs)
config.num_labels = 1
self.transformer = TFOpenAIGPTMainLayer(config, name='transformer')
self.multiple_choice_head = TFSequenceSummary(config, initializer_range=config.initializer_range, name='multiple_choice_head')
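
Both double-heads models now pin config.num_labels to 1 before building the TFSequenceSummary multiple-choice head, so the head projects each choice's summary vector to a single score rather than relying on a config default. A shape-only sketch of why one label per choice is what the head needs (toy numpy code, no transformers involved):

# illustrative sketch, not part of this commit: why the double-heads models force num_labels = 1.
# For a batch of 2 examples with 4 answer choices, projecting each choice's summary vector to a
# single score gives shape (2, 4, 1), squeezed to (2, 4) before the softmax over choices.
import numpy as np

batch, num_choices, hidden = 2, 4, 8
summaries = np.zeros((batch, num_choices, hidden))
proj = np.zeros((hidden, 1))               # num_labels = 1  ->  one score per choice
mc_logits = summaries @ proj               # shape (2, 4, 1)
assert np.squeeze(mc_logits, axis=-1).shape == (batch, num_choices)
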

View File

@ -353,7 +353,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.n_token = config.n_token
self.n_token = config.vocab_size
self.d_embed = config.d_embed
self.d_model = config.d_model
@ -361,7 +361,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
self.d_head = config.d_head
self.untie_r = config.untie_r
self.word_emb = TFAdaptiveEmbedding(config.n_token, config.d_embed, config.d_model, config.cutoffs,
self.word_emb = TFAdaptiveEmbedding(config.vocab_size, config.d_embed, config.d_model, config.cutoffs,
div_val=config.div_val, init_std=config.init_std, name='word_emb')
self.drop = tf.keras.layers.Dropout(config.dropout)
@ -729,7 +729,7 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
raise NotImplementedError
# use adaptive softmax (including standard softmax)
else:
self.crit = TFAdaptiveSoftmaxMask(config.n_token, config.d_embed, config.d_model,
self.crit = TFAdaptiveSoftmaxMask(config.vocab_size, config.d_embed, config.d_model,
config.cutoffs, div_val=config.div_val, name='crit')
def reset_length(self, tgt_len, ext_len, mem_len):
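
The hunks in this file are part of the n_token to vocab_size rename: the Transfo-XL and XLNet models now read the same vocab_size attribute as the other architectures. If older code still accesses config.n_token, one common way to keep it working is a property alias on the config class; the sketch below shows that pattern in isolation, as an illustration of the idea rather than a claim about what this PR ships:

# illustrative sketch, not part of this commit: keeping an old attribute name readable after a
# rename such as n_token -> vocab_size, via a property alias on a toy config class.
class ToyConfig(object):
    def __init__(self, vocab_size=32000):
        self.vocab_size = vocab_size

    @property
    def n_token(self):            # backward-compatible read access
        return self.vocab_size

    @n_token.setter
    def n_token(self, value):     # backward-compatible assignment
        self.vocab_size = value

cfg = ToyConfig(vocab_size=1000)
assert cfg.n_token == 1000
cfg.n_token = 2000
assert cfg.vocab_size == 2000
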

View File

@ -25,15 +25,15 @@ import tensorflow as tf
from .modeling_tf_utils import shape_list
class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
def __init__(self, vocab_size, d_embed, d_proj, cutoffs, div_val=1,
keep_order=False, **kwargs):
super(TFAdaptiveSoftmaxMask, self).__init__(**kwargs)
self.n_token = n_token
self.vocab_size = vocab_size
self.d_embed = d_embed
self.d_proj = d_proj
self.cutoffs = cutoffs + [n_token]
self.cutoffs = cutoffs + [vocab_size]
self.cutoff_ends = [0] + self.cutoffs
self.div_val = div_val
@ -66,11 +66,11 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
self.out_projs.append(weight)
else:
self.out_projs.append(None)
weight = self.add_weight(shape=(self.n_token, self.d_embed,),
weight = self.add_weight(shape=(self.vocab_size, self.d_embed,),
initializer='zeros',
trainable=True,
name='out_layers_._{}_._weight'.format(i))
bias = self.add_weight(shape=(self.n_token,),
bias = self.add_weight(shape=(self.vocab_size,),
initializer='zeros',
trainable=True,
name='out_layers_._{}_._bias'.format(i))
@ -114,7 +114,7 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
hidden, target = inputs
head_logprob = 0
if self.n_clusters == 0:
softmax_b = tf.get_variable('bias', [n_token], initializer=tf.zeros_initializer())
softmax_b = tf.get_variable('bias', [self.vocab_size], initializer=tf.zeros_initializer())
output = self._logit(hidden, self.out_layers[0][0], self.out_layers[0][1], self.out_projs[0])
if target is not None:
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target, logits=output)
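
The adaptive softmax splits the vocabulary into clusters at the given cutoffs; appending vocab_size as the final boundary and prepending 0 turns the cutoff list into explicit (start, end) token-id ranges, one per cluster. A toy walk-through of that bookkeeping:

# illustrative sketch, not part of this commit: how the cutoff bookkeeping above partitions a
# vocabulary of 10 tokens into clusters. Toy numbers only.
vocab_size = 10
cutoffs = [3, 7]                      # as passed to the layer, without the final boundary
cutoffs = cutoffs + [vocab_size]      # [3, 7, 10]
cutoff_ends = [0] + cutoffs           # [0, 3, 7, 10]

clusters = [(cutoff_ends[i], cutoff_ends[i + 1]) for i in range(len(cutoffs))]
assert clusters == [(0, 3), (3, 7), (7, 10)]   # token-id range covered by each cluster
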

View File

@ -366,7 +366,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
self.use_bfloat16 = config.use_bfloat16
self.initializer_range = config.initializer_range
self.word_embedding = TFSharedEmbeddings(config.n_token, config.d_model, initializer_range=config.initializer_range, name='word_embedding')
self.word_embedding = TFSharedEmbeddings(config.vocab_size, config.d_model, initializer_range=config.initializer_range, name='word_embedding')
self.layer = [TFXLNetLayer(config, name='layer_._{}'.format(i)) for i in range(config.n_layer)]
self.dropout = tf.keras.layers.Dropout(config.dropout)

View File

@ -592,14 +592,14 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.n_token = config.n_token
self.n_token = config.vocab_size
self.d_embed = config.d_embed
self.d_model = config.d_model
self.n_head = config.n_head
self.d_head = config.d_head
self.word_emb = AdaptiveEmbedding(config.n_token, config.d_embed, config.d_model, config.cutoffs,
self.word_emb = AdaptiveEmbedding(config.vocab_size, config.d_embed, config.d_model, config.cutoffs,
div_val=config.div_val)
self.drop = nn.Dropout(config.dropout)
@ -836,11 +836,11 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
self.sample_softmax = config.sample_softmax
# use sampled softmax
if config.sample_softmax > 0:
self.out_layer = nn.Linear(config.d_model, config.n_token)
self.sampler = LogUniformSampler(config.n_token, config.sample_softmax)
self.out_layer = nn.Linear(config.d_model, config.vocab_size)
self.sampler = LogUniformSampler(config.vocab_size, config.sample_softmax)
# use adaptive softmax (including standard softmax)
else:
self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model,
self.crit = ProjectedAdaptiveLogSoftmax(config.vocab_size, config.d_embed, config.d_model,
config.cutoffs, div_val=config.div_val)
self.init_weights()

View File

@ -609,7 +609,7 @@ class XLNetModel(XLNetPreTrainedModel):
self.clamp_len = config.clamp_len
self.n_layer = config.n_layer
self.word_embedding = nn.Embedding(config.n_token, config.d_model)
self.word_embedding = nn.Embedding(config.vocab_size, config.d_model)
self.mask_emb = nn.Parameter(torch.FloatTensor(1, 1, config.d_model))
self.layer = nn.ModuleList([XLNetLayer(config) for _ in range(config.n_layer)])
self.dropout = nn.Dropout(config.dropout)
@ -940,7 +940,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
self.same_length = config.same_length
self.transformer = XLNetModel(config)
self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)
self.lm_loss = nn.Linear(config.d_model, config.vocab_size, bias=True)
self.init_weights()
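
With the rename, both the input embedding (nn.Embedding(config.vocab_size, config.d_model)) and the LM head (nn.Linear(config.d_model, config.vocab_size)) are sized from the same attribute, so their weight matrices end up with identical shapes and can share storage. A shape-only sketch of that tying, written against plain torch.nn rather than the library's own weight-tying utilities:

# illustrative sketch, not part of this commit: sharing the LM head weights with the embedding.
import torch.nn as nn

vocab_size, d_model = 1000, 64
word_embedding = nn.Embedding(vocab_size, d_model)    # weight shape: (vocab_size, d_model)
lm_loss = nn.Linear(d_model, vocab_size, bias=True)   # weight shape: (vocab_size, d_model) as well
lm_loss.weight = word_embedding.weight                # head and embedding now share one parameter
assert lm_loss.weight is word_embedding.weight
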

View File

@ -16,15 +16,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import os
import shutil
import json
import random
import uuid
import tempfile
import unittest
import logging
from .tokenization_tests_commons import TemporaryDirectory
class ConfigTester(object):
@ -48,16 +45,28 @@ class ConfigTester(object):
def create_and_test_config_to_json_file(self):
config_first = self.config_class(**self.inputs_dict)
json_file_path = os.path.join(os.getcwd(), "config_" + str(uuid.uuid4()) + ".json")
config_first.to_json_file(json_file_path)
config_second = self.config_class.from_json_file(json_file_path)
os.remove(json_file_path)
with TemporaryDirectory() as tmpdirname:
json_file_path = os.path.join(tmpdirname, "config.json")
config_first.to_json_file(json_file_path)
config_second = self.config_class.from_json_file(json_file_path)
self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
def create_and_test_config_from_and_save_pretrained(self):
config_first = self.config_class(**self.inputs_dict)
with TemporaryDirectory() as tmpdirname:
config_first.save_pretrained(tmpdirname)
config_second = self.config_class.from_pretrained(tmpdirname)
self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
def run_common_tests(self):
self.create_and_test_config_common_properties()
self.create_and_test_config_to_json_string()
self.create_and_test_config_to_json_file()
self.create_and_test_config_from_and_save_pretrained()
if __name__ == "__main__":
unittest.main()
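
Both round-trip tests now write into a temporary directory instead of the current working directory, so a failing assertion no longer leaves a stray config.json behind. A self-contained sketch of the same save/load round trip, using a toy config class so it runs without transformers installed (to_json_file and from_json_file here are simplified stand-ins for the real methods):

# illustrative sketch, not part of this commit: the JSON round trip exercised by the tests above.
import json
import os
import tempfile

class ToyConfig(object):
    def __init__(self, vocab_size=99, hidden_size=32):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

    def to_json_file(self, path):
        with open(path, "w", encoding="utf-8") as writer:
            writer.write(json.dumps(self.__dict__, indent=2, sort_keys=True))

    @classmethod
    def from_json_file(cls, path):
        with open(path, "r", encoding="utf-8") as reader:
            return cls(**json.loads(reader.read()))

with tempfile.TemporaryDirectory() as tmpdirname:
    first = ToyConfig(vocab_size=123)
    json_file_path = os.path.join(tmpdirname, "config.json")
    first.to_json_file(json_file_path)
    second = ToyConfig.from_json_file(json_file_path)
assert second.__dict__ == first.__dict__    # nothing is written outside the temporary directory
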

View File

@ -0,0 +1,89 @@
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import json
import unittest
from transformers.model_card import ModelCard
from .tokenization_tests_commons import TemporaryDirectory
class ModelCardTester(unittest.TestCase):
def setUp(self):
self.inputs_dict = {'model_details': {
'Organization': 'testing',
'Model date': 'today',
'Model version': 'v2.1, Developed by Test Corp in 2019.',
'Architecture': 'Convolutional Neural Network.',
},
'metrics': 'BLEU and ROUGE-1',
'evaluation_data':{
'Datasets':{
'BLEU': 'My-great-dataset-v1',
'ROUGE-1': 'My-short-dataset-v2.1',
},
'Preprocessing': 'See details on https://arxiv.org/pdf/1810.03993.pdf'
},
'training_data':{
'Dataset': 'English Wikipedia dump dated 2018-12-01',
'Preprocessing': 'Using SentencePiece vocabulary of size 52k tokens. See details on https://arxiv.org/pdf/1810.03993.pdf'
},
'quantitative_analyses': {
'BLEU': 55.1,
'ROUGE-1': 76,
},
}
def test_model_card_common_properties(self):
model_card = ModelCard.from_dict(self.inputs_dict)
self.assertTrue(hasattr(model_card, 'model_details'))
self.assertTrue(hasattr(model_card, 'intended_use'))
self.assertTrue(hasattr(model_card, 'factors'))
self.assertTrue(hasattr(model_card, 'metrics'))
self.assertTrue(hasattr(model_card, 'evaluation_data'))
self.assertTrue(hasattr(model_card, 'training_data'))
self.assertTrue(hasattr(model_card, 'quantitative_analyses'))
self.assertTrue(hasattr(model_card, 'ethical_considerations'))
self.assertTrue(hasattr(model_card, 'caveats_and_recommendations'))
def test_model_card_to_json_string(self):
model_card = ModelCard.from_dict(self.inputs_dict)
obj = json.loads(model_card.to_json_string())
for key, value in self.inputs_dict.items():
self.assertEqual(obj[key], value)
def test_model_card_to_json_file(self):
model_card_first = ModelCard.from_dict(self.inputs_dict)
with TemporaryDirectory() as tmpdirname:
filename = os.path.join(tmpdirname, u"model_card.json")
model_card_first.to_json_file(filename)
model_card_second = ModelCard.from_json_file(filename)
self.assertEqual(model_card_second.to_dict(), model_card_first.to_dict())
def test_model_card_from_and_save_pretrained(self):
model_card_first = ModelCard.from_dict(self.inputs_dict)
with TemporaryDirectory() as tmpdirname:
model_card_first.save_pretrained(tmpdirname)
model_card_second = ModelCard.from_pretrained(tmpdirname)
self.assertEqual(model_card_second.to_dict(), model_card_first.to_dict())
if __name__ == "__main__":
unittest.main()
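
These tests exercise the new ModelCard object added in this PR; its sections (model_details, evaluation_data, training_data, quantitative_analyses, and so on) follow the "Model Cards for Model Reporting" template linked in the fixture above. A minimal usage sketch mirroring the tests, assuming a transformers checkout that includes this change:

# illustrative sketch, not part of this commit: creating and persisting a model card with the
# same API the tests above exercise.
import tempfile
from transformers.model_card import ModelCard

card = ModelCard.from_dict({
    "model_details": {"Organization": "my-org", "Model date": "2019-12-17"},
    "metrics": "accuracy",
})
with tempfile.TemporaryDirectory() as tmpdirname:
    card.save_pretrained(tmpdirname)              # writes the card as JSON into tmpdirname
    reloaded = ModelCard.from_pretrained(tmpdirname)
assert reloaded.to_dict()["metrics"] == "accuracy"
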

View File

@ -110,7 +110,7 @@ class AlbertModelTest(CommonTestCases.CommonModelTester):
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = AlbertConfig(
vocab_size_or_config_json_file=self.vocab_size,
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
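
The same one-line change recurs in every model test below: the tests now pass the plain vocab_size keyword instead of the old overloaded vocab_size_or_config_json_file argument. A quick sketch of the call-site change, using AlbertConfig as one example and assuming a transformers checkout that includes this PR:

# illustrative sketch, not part of this commit: new-style constructor call used by the tests.
from transformers import AlbertConfig

config = AlbertConfig(vocab_size=30000, hidden_size=768)   # previously vocab_size_or_config_json_file=30000
assert config.vocab_size == 30000
# reading a config from disk is no longer overloaded onto the first argument; it goes through the
# dedicated classmethods instead, e.g. AlbertConfig.from_json_file(...) or AlbertConfig.from_pretrained(...)
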

View File

@ -109,7 +109,7 @@ class BertModelTest(CommonTestCases.CommonModelTester):
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = BertConfig(
vocab_size_or_config_json_file=self.vocab_size,
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,

View File

@ -676,7 +676,7 @@ class CommonTestCases:
mc_token_ids = ids_tensor([self.batch_size, self.n_choices], self.seq_length)
config = self.config_class(
vocab_size_or_config_json_file=self.vocab_size,
vocab_size=self.vocab_size,
n_positions=self.n_positions,
n_embd=self.hidden_size,
n_layer=self.num_hidden_layers,

View File

@ -114,7 +114,7 @@ class CTRLModelTest(CommonTestCases.CommonModelTester):
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = CTRLConfig(
vocab_size_or_config_json_file=self.vocab_size,
vocab_size=self.vocab_size,
n_embd=self.hidden_size,
n_layer=self.num_hidden_layers,
n_head=self.num_attention_heads,

View File

@ -105,7 +105,7 @@ class DistilBertModelTest(CommonTestCases.CommonModelTester):
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = DistilBertConfig(
vocab_size_or_config_json_file=self.vocab_size,
vocab_size=self.vocab_size,
dim=self.hidden_size,
n_layers=self.num_hidden_layers,
n_heads=self.num_attention_heads,

View File

@ -110,7 +110,7 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = GPT2Config(
vocab_size_or_config_json_file=self.vocab_size,
vocab_size=self.vocab_size,
n_embd=self.hidden_size,
n_layer=self.num_hidden_layers,
n_head=self.num_attention_heads,

View File

@ -98,7 +98,7 @@ class OpenAIGPTModelTest(CommonTestCases.CommonModelTester):
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = OpenAIGPTConfig(
vocab_size_or_config_json_file=self.vocab_size,
vocab_size=self.vocab_size,
n_embd=self.hidden_size,
n_layer=self.num_hidden_layers,
n_head=self.num_attention_heads,

View File

@ -106,7 +106,7 @@ class RobertaModelTest(CommonTestCases.CommonModelTester):
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = RobertaConfig(
vocab_size_or_config_json_file=self.vocab_size,
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,

View File

@ -93,7 +93,7 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
decoder_lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
config = T5Config(
vocab_size_or_config_json_file=self.vocab_size,
vocab_size=self.vocab_size,
n_positions=self.n_positions,
d_model=self.hidden_size,
d_ff=self.d_ff,

View File

@ -118,7 +118,7 @@ class TFAlbertModelTest(TFCommonTestCases.TFCommonModelTester):
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = AlbertConfig(
vocab_size_or_config_json_file=self.vocab_size,
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,

View File

@ -114,7 +114,7 @@ class TFBertModelTest(TFCommonTestCases.TFCommonModelTester):
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = BertConfig(
vocab_size_or_config_json_file=self.vocab_size,
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,

View File

@ -112,7 +112,7 @@ class TFCTRLModelTest(TFCommonTestCases.TFCommonModelTester):
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = CTRLConfig(
vocab_size_or_config_json_file=self.vocab_size,
vocab_size=self.vocab_size,
n_embd=self.hidden_size,
n_layer=self.num_hidden_layers,
n_head=self.num_attention_heads,

View File

@ -107,7 +107,7 @@ class TFDistilBertModelTest(TFCommonTestCases.TFCommonModelTester):
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = DistilBertConfig(
vocab_size_or_config_json_file=self.vocab_size,
vocab_size=self.vocab_size,
dim=self.hidden_size,
n_layers=self.num_hidden_layers,
n_heads=self.num_attention_heads,

View File

@ -115,7 +115,7 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = GPT2Config(
vocab_size_or_config_json_file=self.vocab_size,
vocab_size=self.vocab_size,
n_embd=self.hidden_size,
n_layer=self.num_hidden_layers,
n_head=self.num_attention_heads,

View File

@ -114,7 +114,7 @@ class TFOpenAIGPTModelTest(TFCommonTestCases.TFCommonModelTester):
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = OpenAIGPTConfig(
vocab_size_or_config_json_file=self.vocab_size,
vocab_size=self.vocab_size,
n_embd=self.hidden_size,
n_layer=self.num_hidden_layers,
n_head=self.num_attention_heads,

View File

@ -109,7 +109,7 @@ class TFRobertaModelTest(TFCommonTestCases.TFCommonModelTester):
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = RobertaConfig(
vocab_size_or_config_json_file=self.vocab_size,
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,

View File

@ -87,7 +87,7 @@ class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester):
token_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
config = T5Config(
vocab_size_or_config_json_file=self.vocab_size,
vocab_size=self.vocab_size,
n_positions=self.n_positions,
d_model=self.hidden_size,
d_ff=self.d_ff,

View File

@ -92,7 +92,7 @@ class TFTransfoXLModelTest(TFCommonTestCases.TFCommonModelTester):
lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
config = TransfoXLConfig(
vocab_size_or_config_json_file=self.vocab_size,
vocab_size=self.vocab_size,
mem_len=self.mem_len,
clamp_len=self.clamp_len,
cutoffs=self.cutoffs,

View File

@ -125,7 +125,7 @@ class TFXLMModelTest(TFCommonTestCases.TFCommonModelTester):
is_impossible_labels = ids_tensor([self.batch_size], 2, dtype=tf.float32)
config = XLMConfig(
vocab_size_or_config_json_file=self.vocab_size,
vocab_size=self.vocab_size,
n_special=self.n_special,
emb_dim=self.hidden_size,
n_layers=self.num_hidden_layers,

View File

@ -64,7 +64,6 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
num_attention_heads=4,
d_inner=128,
num_hidden_layers=5,
max_position_embeddings=10,
type_sequence_label_size=2,
untie_r=True,
bi_data=False,
@ -88,7 +87,6 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
self.num_attention_heads = num_attention_heads
self.d_inner = d_inner
self.num_hidden_layers = num_hidden_layers
self.max_position_embeddings = max_position_embeddings
self.bi_data = bi_data
self.untie_r = untie_r
self.same_length = same_length
@ -122,13 +120,12 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
is_impossible_labels = ids_tensor([self.batch_size], 2, dtype=tf.float32)
config = XLNetConfig(
vocab_size_or_config_json_file=self.vocab_size,
vocab_size=self.vocab_size,
d_model=self.hidden_size,
n_head=self.num_attention_heads,
d_inner=self.d_inner,
n_layer=self.num_hidden_layers,
untie_r=self.untie_r,
max_position_embeddings=self.max_position_embeddings,
mem_len=self.mem_len,
clamp_len=self.clamp_len,
same_length=self.same_length,
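
This test (and its PyTorch counterpart further below) drops max_position_embeddings from the XLNet test configuration; XLNet uses Transformer-XL style relative position encodings, so there is no fixed position-embedding table whose size needs configuring. A toy illustration of why relative offsets impose no hard length limit:

# illustrative sketch, not part of this commit: relative position offsets can be built for any
# query/key length on the fly, with no table bounded by a maximum sequence length.
import numpy as np

def relative_positions(qlen, klen):
    # offset of each key position relative to each query position, shape (qlen, klen)
    return np.arange(klen)[None, :] - np.arange(qlen)[:, None]

assert relative_positions(3, 5).shape == (3, 5)
assert relative_positions(300, 500).shape == (300, 500)   # no hard upper bound needed
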

View File

@ -91,7 +91,7 @@ class TransfoXLModelTest(CommonTestCases.CommonModelTester):
lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
config = TransfoXLConfig(
vocab_size_or_config_json_file=self.vocab_size,
vocab_size=self.vocab_size,
mem_len=self.mem_len,
clamp_len=self.clamp_len,
cutoffs=self.cutoffs,

View File

@ -121,7 +121,7 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
is_impossible_labels = ids_tensor([self.batch_size], 2).float()
config = XLMConfig(
vocab_size_or_config_json_file=self.vocab_size,
vocab_size=self.vocab_size,
n_special=self.n_special,
emb_dim=self.hidden_size,
n_layers=self.num_hidden_layers,

View File

@ -60,7 +60,6 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
num_attention_heads=4,
d_inner=128,
num_hidden_layers=5,
max_position_embeddings=10,
type_sequence_label_size=2,
untie_r=True,
bi_data=False,
@ -84,7 +83,6 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
self.num_attention_heads = num_attention_heads
self.d_inner = d_inner
self.num_hidden_layers = num_hidden_layers
self.max_position_embeddings = max_position_embeddings
self.bi_data = bi_data
self.untie_r = untie_r
self.same_length = same_length
@ -116,13 +114,12 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
token_labels = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
config = XLNetConfig(
vocab_size_or_config_json_file=self.vocab_size,
vocab_size=self.vocab_size,
d_model=self.hidden_size,
n_head=self.num_attention_heads,
d_inner=self.d_inner,
n_layer=self.num_hidden_layers,
untie_r=self.untie_r,
max_position_embeddings=self.max_position_embeddings,
mem_len=self.mem_len,
clamp_len=self.clamp_len,
same_length=self.same_length,