VictorSanh 2019-08-28 00:45:33 +00:00
parent 60c984da6c
commit a8ad83040d


@@ -60,7 +60,7 @@ class DilBertConfig(PretrainedConfig):
                  attention_dropout=0.1,
                  activation='gelu',
                  initializer_range=0.02,
-                 tie_weights=True,
+                 tie_weights_=True,
                  **kwargs):
         super(DilBertConfig, self).__init__(**kwargs)

@@ -82,7 +82,7 @@ class DilBertConfig(PretrainedConfig):
             self.attention_dropout = attention_dropout
             self.activation = activation
             self.initializer_range = initializer_range
-            self.tie_weights = tie_weights
+            self.tie_weights_ = tie_weights_
         else:
             raise ValueError("First argument must be either a vocabulary size (int)"
                              "or the path to a pretrained model config file (str)")
@@ -274,13 +274,15 @@ class TransformerBlock(nn.Module):
         sa_output = self.attention(query=x, key=x, value=x, mask=attn_mask)
         if self.output_attentions:
             sa_output, sa_weights = sa_output  # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
+        else:
+            sa_output = sa_output[0]
         sa_output = self.sa_layer_norm(sa_output + x)  # (bs, seq_length, dim)
 
         # Feed Forward Network
         ffn_output = self.ffn(sa_output)  # (bs, seq_length, dim)
         ffn_output = self.output_layer_norm(ffn_output + sa_output)  # (bs, seq_length, dim)
 
-        output = (ffn_output)
+        output = (ffn_output,)
         if self.output_attentions:
             output = (sa_weights,) + output
         return output
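The one-character change in this hunk is the substantive fix: `(ffn_output)` is just a parenthesized expression, while `(ffn_output,)` builds the one-element tuple that downstream code indexes and concatenates. A minimal sketch of the difference (tensor shapes are placeholders, not taken from the diff):

```python
import torch

ffn_output = torch.zeros(2, 5, 16)      # (bs, seq_length, dim); sizes are illustrative

not_a_tuple = (ffn_output)              # parentheses only group: still a Tensor
one_tuple = (ffn_output,)               # the trailing comma makes a 1-element tuple

assert torch.is_tensor(not_a_tuple)
assert isinstance(one_tuple, tuple) and len(one_tuple) == 1

# With output_attentions set, the block prepends the attention weights, which
# only works if `output` is already a tuple:
sa_weights = torch.zeros(2, 4, 5, 5)    # (bs, n_heads, seq_length, seq_length)
output = (sa_weights,) + one_tuple
assert len(output) == 2
```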
@@ -468,36 +470,36 @@ class DilBertForMaskedLM(DilBertPreTrainedModel):
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
 
-        self.encoder = DilBertModel(config)
+        self.dilbert = DilBertModel(config)
         self.vocab_transform = nn.Linear(config.dim, config.dim)
         self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
         self.vocab_projector = nn.Linear(config.dim, config.vocab_size)
 
         self.apply(self.init_weights)
-        self.tie_weights_()
+        self.tie_weights()
 
         self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
 
-    def tie_weights_(self):
+    def tie_weights(self):
         """
         Tying the weights of the vocabulary projection to the base token embeddings.
         """
-        if self.config.tie_weights:
-            self.vocab_projector.weight = self.encoder.embeddings.word_embeddings.weight
+        if self.config.tie_weights_:
+            self.vocab_projector.weight = self.dilbert.embeddings.word_embeddings.weight
 
     def forward(self,
                 input_ids: torch.tensor,
                 attention_mask: torch.tensor = None,
                 masked_lm_labels: torch.tensor = None):
-        tfmr_output = self.encoder(input_ids=input_ids,
-                                   attention_mask=attention_mask)
-        hidden_states = tfmr_output[0]  # (bs, seq_length, dim)
+        dlbrt_output = self.dilbert(input_ids=input_ids,
+                                    attention_mask=attention_mask)
+        hidden_states = dlbrt_output[0]  # (bs, seq_length, dim)
         prediction_logits = self.vocab_transform(hidden_states)  # (bs, seq_length, dim)
         prediction_logits = gelu(prediction_logits)  # (bs, seq_length, dim)
         prediction_logits = self.vocab_layer_norm(prediction_logits)  # (bs, seq_length, dim)
         prediction_logits = self.vocab_projector(prediction_logits)  # (bs, seq_length, vocab_size)
 
-        outputs = (prediction_logits, ) + tfmr_output[2:]
+        outputs = (prediction_logits, ) + dlbrt_output[2:]
         if masked_lm_labels is not None:
             mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)),
                                          masked_lm_labels.view(-1))
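Two details of this hunk are easy to miss in the rename noise: `tie_weights()` makes the MLM projection and the token embeddings share one parameter, and the masked-LM loss ignores positions labelled -1. A standalone sketch of both, with made-up dimensions (the real model uses `config.dim` and `config.vocab_size`):

```python
import torch
import torch.nn as nn

vocab_size, dim = 100, 16                          # illustrative; the model reads these from config

word_embeddings = nn.Embedding(vocab_size, dim)    # weight: (vocab_size, dim)
vocab_projector = nn.Linear(dim, vocab_size)       # weight: (vocab_size, dim), same shape

# Weight tying as in tie_weights(): both modules now hold the same Parameter,
# so any update to the embeddings also moves the output projection.
vocab_projector.weight = word_embeddings.weight
assert vocab_projector.weight is word_embeddings.weight

# Masked-LM loss as in forward(): labels equal to -1 are ignored, so only the
# positions that carry a real target token contribute to the loss.
mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
prediction_logits = vocab_projector(torch.randn(3, dim))  # (3, vocab_size)
masked_lm_labels = torch.tensor([5, -1, 7])               # the -1 position adds no loss
mlm_loss = mlm_loss_fct(prediction_logits, masked_lm_labels)
```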