This commit is contained in:
VictorSanh 2019-08-27 22:25:55 +00:00
parent 42968138c8
commit 60c984da6c
2 changed files with 81 additions and 79 deletions

View File

@ -40,7 +40,8 @@ from .modeling_xlm import (XLMConfig, XLMPreTrainedModel , XLMModel,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_roberta import (RobertaConfig, RobertaForMaskedLM, RobertaModel, RobertaForSequenceClassification,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_dilbert import (DilBertconfig, DilBertForMaskedLM, DilBertModel, DilBertForSequenceClassification,
from .modeling_dilbert import (DilBertConfig, DilBertForMaskedLM, DilBertModel,
DilBertForSequenceClassification, DilBertForQuestionAnswering,
DILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)

View File

@ -45,7 +45,7 @@ DILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
}
class DilBertconfig(PretrainedConfig):
class DilBertConfig(PretrainedConfig):
pretrained_config_archive_map = DILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
@ -62,7 +62,7 @@ class DilBertconfig(PretrainedConfig):
initializer_range=0.02,
tie_weights=True,
**kwargs):
super(DilBertconfig, self).__init__(**kwargs)
super(DilBertConfig, self).__init__(**kwargs)
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
@ -77,6 +77,7 @@ class DilBertconfig(PretrainedConfig):
self.n_layers = n_layers
self.n_heads = n_heads
self.dim = dim
self.hidden_dim = hidden_dim
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation = activation
@ -341,7 +342,7 @@ class DilBertPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
config_class = DilBertconfig
config_class = DilBertConfig
pretrained_model_archive_map = DILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = None
base_model_prefix = "dilbert"
@ -370,7 +371,7 @@ DILBERT_START_DOCSTRING = r"""
For more information on DilBERT, you should check TODO(Victor): Link to Medium
Parameters:
config (:class:`~pytorch_transformers.DilBertconfig`): Model configuration class with all the parameters of the model.
config (:class:`~pytorch_transformers.DilBertConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
@ -391,18 +392,7 @@ DILBERT_INPUTS_DOCSTRING = r"""
@add_start_docstrings("The bare DilBERT encoder/transformer outputing raw hidden-states without any specific head on top.",
DILBERT_START_DOCSTRING, DILBERT_INPUTS_DOCSTRING)
class DilBertModel(DilBertPreTrainedModel):
def __init__(self, config):
super(DilBertModel, self).__init__(config)
self.embeddings = Embeddings(config) # Embeddings
self.transformer = Transformer(config) # Encoder
self.apply(self.init_weights)
def forward(self,
input_ids: torch.tensor,
attention_mask: torch.tensor = None):
"""
r"""
Parameters
----------
input_ids: torch.tensor(bs, seq_length)
@ -422,7 +412,18 @@ class DilBertModel(DilBertPreTrainedModel):
all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
Tuple of length n_layers with the attention weights from each layer
Optional: only if output_attentions=True
"""
"""
def __init__(self, config):
super(DilBertModel, self).__init__(config)
self.embeddings = Embeddings(config) # Embeddings
self.transformer = Transformer(config) # Encoder
self.apply(self.init_weights)
def forward(self,
input_ids: torch.tensor,
attention_mask: torch.tensor = None):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids) # (bs, seq_length)
@ -438,33 +439,7 @@ class DilBertModel(DilBertPreTrainedModel):
@add_start_docstrings("""DilBert Model with a `masked language modeling` head on top. """,
DILBERT_START_DOCSTRING, DILBERT_INPUTS_DOCSTRING)
class DilBertForMaskedLM(DilBertPreTrainedModel):
def __init__(self, config):
super(DilBertForMaskedLM, self).__init__(config)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.encoder = DilBertModel(config)
self.vocab_transform = nn.Linear(config.dim, config.dim)
self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
self.vocab_projector = nn.Linear(config.dim, config.vocab_size)
self.apply(self.init_weights)
self.tie_weights()
self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
def tie_weights_(self):
"""
Tying the weights of the vocabulary projection to the base token embeddings.
"""
if self.config.tie_weights:
self.vocab_projector.weight = self.encoder.embeddings.word_embeddings.weight
def forward(self,
input_ids: torch.tensor,
attention_mask: torch.tensor = None,
masked_lm_labels: torch.tensor = None):
"""
r"""
Parameters
----------
input_ids: torch.tensor(bs, seq_length)
@ -487,7 +462,33 @@ class DilBertForMaskedLM(DilBertPreTrainedModel):
all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
Tuple of length n_layers with the attention weights from each layer
Optional: only if `output_attentions`=True
"""
def __init__(self, config):
super(DilBertForMaskedLM, self).__init__(config)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.encoder = DilBertModel(config)
self.vocab_transform = nn.Linear(config.dim, config.dim)
self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
self.vocab_projector = nn.Linear(config.dim, config.vocab_size)
self.apply(self.init_weights)
self.tie_weights_()
self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
def tie_weights_(self):
"""
Tying the weights of the vocabulary projection to the base token embeddings.
"""
if self.config.tie_weights:
self.vocab_projector.weight = self.encoder.embeddings.word_embeddings.weight
def forward(self,
input_ids: torch.tensor,
attention_mask: torch.tensor = None,
masked_lm_labels: torch.tensor = None):
tfmr_output = self.encoder(input_ids=input_ids,
attention_mask=attention_mask)
hidden_states = tfmr_output[0] # (bs, seq_length, dim)
@ -508,22 +509,7 @@ class DilBertForMaskedLM(DilBertPreTrainedModel):
the pooled output) e.g. for GLUE tasks. """,
DILBERT_START_DOCSTRING, DILBERT_INPUTS_DOCSTRING)
class DilBertForSequenceClassification(DilBertPreTrainedModel):
def __init__(self, config):
super(DilBertForSequenceClassification, self).__init__(config)
self.num_labels = config.num_labels
self.dilbert = DilBertModel(config)
self.pre_classifier = nn.Linear(config.dim, config.dim)
self.classifier = nn.Linear(config.dim, config.num_labels)
self.dropout = nn.Dropout(config.seq_classif_dropout)
self.apply(self.init_weights)
def forward(self,
input_ids: torch.tensor,
attention_mask: torch.tensor = None,
labels: torch.tensor = None):
"""
r"""
Parameters
----------
input_ids: torch.tensor(bs, seq_length)
@ -546,7 +532,22 @@ class DilBertForSequenceClassification(DilBertPreTrainedModel):
all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
Tuple of length n_layers with the attention weights from each layer
Optional: only if `output_attentions`=True
"""
"""
def __init__(self, config):
super(DilBertForSequenceClassification, self).__init__(config)
self.num_labels = config.num_labels
self.dilbert = DilBertModel(config)
self.pre_classifier = nn.Linear(config.dim, config.dim)
self.classifier = nn.Linear(config.dim, config.num_labels)
self.dropout = nn.Dropout(config.seq_classif_dropout)
self.apply(self.init_weights)
def forward(self,
input_ids: torch.tensor,
attention_mask: torch.tensor = None,
labels: torch.tensor = None):
dilbert_output = self.dilbert(input_ids=input_ids,
attention_mask=attention_mask)
pooled_output = dilbert_output[1] # (bs, dim)
@ -571,22 +572,7 @@ class DilBertForSequenceClassification(DilBertPreTrainedModel):
the hidden-states output to compute `span start logits` and `span end logits`). """,
DILBERT_START_DOCSTRING, DILBERT_INPUTS_DOCSTRING)
class DilBertForQuestionAnswering(DilBertPreTrainedModel):
def __init__(self, config):
super(DilBertForQuestionAnswering, self).__init__(config)
self.dilbert = DilBertModel(config)
self.qa_outputs = nn.Linear(config.dim, config.num_labels)
assert config.num_labels == 2
self.dropout = nn.Dropout(config.qa_dropout)
self.apply(self.init_weights)
def forward(self,
input_ids: torch.tensor,
attention_mask: torch.tensor = None,
start_positions: torch.tensor = None,
end_positions: torch.tensor = None):
"""
r"""
Parameters
----------
input_ids: torch.tensor(bs, seq_length)
@ -619,7 +605,22 @@ class DilBertForQuestionAnswering(DilBertPreTrainedModel):
all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
Tuple of length n_layers with the attention weights from each layer
Optional: only if `output_attentions`=True
"""
"""
def __init__(self, config):
super(DilBertForQuestionAnswering, self).__init__(config)
self.dilbert = DilBertModel(config)
self.qa_outputs = nn.Linear(config.dim, config.num_labels)
assert config.num_labels == 2
self.dropout = nn.Dropout(config.qa_dropout)
self.apply(self.init_weights)
def forward(self,
input_ids: torch.tensor,
attention_mask: torch.tensor = None,
start_positions: torch.tensor = None,
end_positions: torch.tensor = None):
dilbert_output = self.dilbert(input_ids=input_ids,
attention_mask=attention_mask)
hidden_states = dilbert_output[0] # (bs, max_query_len, dim)