Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
fix bugs
parent 60c984da6c
commit a8ad83040d
@@ -60,7 +60,7 @@ class DilBertConfig(PretrainedConfig):
                  attention_dropout=0.1,
                  activation='gelu',
                  initializer_range=0.02,
-                 tie_weights=True,
+                 tie_weights_=True,
                  **kwargs):
         super(DilBertConfig, self).__init__(**kwargs)
 
@@ -82,7 +82,7 @@ class DilBertConfig(PretrainedConfig):
             self.attention_dropout = attention_dropout
             self.activation = activation
             self.initializer_range = initializer_range
-            self.tie_weights = tie_weights
+            self.tie_weights_ = tie_weights_
         else:
             raise ValueError("First argument must be either a vocabulary size (int)"
                              "or the path to a pretrained model config file (str)")
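The config-side change renames the flag tie_weights to tie_weights_, matching the model-side rename further down, presumably so the stored setting cannot be confused with the tie_weights() method that acts on it. A minimal, self-contained sketch of that naming split (TinyConfig and TinyModel are hypothetical illustrations, not library classes):

class TinyConfig:
    def __init__(self, tie_weights_=True):
        self.tie_weights_ = tie_weights_     # attribute: the stored setting

class TinyModel:
    def __init__(self, config):
        self.config = config
        if self.config.tie_weights_:         # read the flag...
            self.tie_weights()               # ...then perform the action

    def tie_weights(self):
        print("sharing projection and embedding weights")

TinyModel(TinyConfig())                      # prints once at construction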
@@ -274,13 +274,15 @@ class TransformerBlock(nn.Module):
         sa_output = self.attention(query=x, key=x, value=x, mask=attn_mask)
         if self.output_attentions:
             sa_output, sa_weights = sa_output   # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
+        else:
+            sa_output = sa_output[0]
         sa_output = self.sa_layer_norm(sa_output + x)   # (bs, seq_length, dim)
 
         # Feed Forward Network
         ffn_output = self.ffn(sa_output)   # (bs, seq_length, dim)
         ffn_output = self.output_layer_norm(ffn_output + sa_output)   # (bs, seq_length, dim)
 
-        output = (ffn_output)
+        output = (ffn_output,)
         if self.output_attentions:
             output = (sa_weights,) + output
         return output
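Two fixes in this hunk: the new else branch unpacks the attention module's single-element output when attention weights are not returned, and the trailing comma in (ffn_output,) is what actually makes output a tuple, so that prepending (sa_weights,) and indexing output[0] downstream behave as intended. A quick standalone illustration of the trailing-comma pitfall (the shapes are made up):

import torch

ffn_output = torch.zeros(2, 5, 8)   # stand-in for a (bs, seq_length, dim) tensor

not_a_tuple = (ffn_output)          # parentheses alone: still just the tensor
real_tuple = (ffn_output,)          # trailing comma: a 1-tuple

print(type(not_a_tuple))            # <class 'torch.Tensor'>
print(type(real_tuple))             # <class 'tuple'>
print(not_a_tuple[0].shape)         # torch.Size([5, 8]) -- silently slices the batch dim
print(real_tuple[0].shape)          # torch.Size([2, 5, 8]) -- the full hidden states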
@@ -468,36 +470,36 @@ class DilBertForMaskedLM(DilBertPreTrainedModel):
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
 
-        self.encoder = DilBertModel(config)
+        self.dilbert = DilBertModel(config)
         self.vocab_transform = nn.Linear(config.dim, config.dim)
         self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
         self.vocab_projector = nn.Linear(config.dim, config.vocab_size)
 
         self.apply(self.init_weights)
-        self.tie_weights_()
+        self.tie_weights()
 
         self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
 
-    def tie_weights_(self):
+    def tie_weights(self):
         """
         Tying the weights of the vocabulary projection to the base token embeddings.
         """
-        if self.config.tie_weights:
-            self.vocab_projector.weight = self.encoder.embeddings.word_embeddings.weight
+        if self.config.tie_weights_:
+            self.vocab_projector.weight = self.dilbert.embeddings.word_embeddings.weight
 
     def forward(self,
                 input_ids: torch.tensor,
                 attention_mask: torch.tensor = None,
                 masked_lm_labels: torch.tensor = None):
-        tfmr_output = self.encoder(input_ids=input_ids,
-                                   attention_mask=attention_mask)
-        hidden_states = tfmr_output[0]   # (bs, seq_length, dim)
+        dlbrt_output = self.dilbert(input_ids=input_ids,
+                                    attention_mask=attention_mask)
+        hidden_states = dlbrt_output[0]   # (bs, seq_length, dim)
         prediction_logits = self.vocab_transform(hidden_states)   # (bs, seq_length, dim)
         prediction_logits = gelu(prediction_logits)   # (bs, seq_length, dim)
         prediction_logits = self.vocab_layer_norm(prediction_logits)   # (bs, seq_length, dim)
         prediction_logits = self.vocab_projector(prediction_logits)   # (bs, seq_length, vocab_size)
 
-        outputs = (prediction_logits, ) + tfmr_output[2:]
+        outputs = (prediction_logits, ) + dlbrt_output[2:]
         if masked_lm_labels is not None:
             mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)),
                                          masked_lm_labels.view(-1))
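Besides the encoder → dilbert attribute rename, the method is now called tie_weights and reads config.tie_weights_. What it does is share a single Parameter between the vocabulary projector and the input word embeddings, so the MLM output matrix and the embedding matrix stay identical during training. A minimal sketch of that sharing pattern with plain nn modules (the sizes are made up; this is not the library code):

import torch
import torch.nn as nn

vocab_size, dim = 100, 16
word_embeddings = nn.Embedding(vocab_size, dim)   # weight: (vocab_size, dim)
vocab_projector = nn.Linear(dim, vocab_size)      # weight: (vocab_size, dim), same shape

# Tie: point the projector at the embedding's Parameter object.
vocab_projector.weight = word_embeddings.weight

# A change through one module is visible through the other.
with torch.no_grad():
    word_embeddings.weight[0, 0] = 42.0
print(vocab_projector.weight[0, 0].item())               # 42.0
print(vocab_projector.weight is word_embeddings.weight)  # True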