Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 19:21:31 +06:00)
fix bugs
commit 60c984da6c
parent 42968138c8
@@ -40,7 +40,8 @@ from .modeling_xlm import (XLMConfig, XLMPreTrainedModel , XLMModel,
                            XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_roberta import (RobertaConfig, RobertaForMaskedLM, RobertaModel, RobertaForSequenceClassification,
                                ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
-from .modeling_dilbert import (DilBertconfig, DilBertForMaskedLM, DilBertModel, DilBertForSequenceClassification,
+from .modeling_dilbert import (DilBertConfig, DilBertForMaskedLM, DilBertModel,
+                               DilBertForSequenceClassification, DilBertForQuestionAnswering,
                                DILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
 from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
                              PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)
@@ -45,7 +45,7 @@ DILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
 }


-class DilBertconfig(PretrainedConfig):
+class DilBertConfig(PretrainedConfig):
     pretrained_config_archive_map = DILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP

     def __init__(self,
@@ -62,7 +62,7 @@ class DilBertconfig(PretrainedConfig):
                  initializer_range=0.02,
                  tie_weights=True,
                  **kwargs):
-        super(DilBertconfig, self).__init__(**kwargs)
+        super(DilBertConfig, self).__init__(**kwargs)

         if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                         and isinstance(vocab_size_or_config_json_file, unicode)):
@@ -77,6 +77,7 @@ class DilBertconfig(PretrainedConfig):
             self.n_layers = n_layers
             self.n_heads = n_heads
             self.dim = dim
+            self.hidden_dim = hidden_dim
             self.dropout = dropout
             self.attention_dropout = attention_dropout
             self.activation = activation
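A quick way to see what this hunk fixes: `hidden_dim` is now actually stored on the config. A hedged sketch of that check; the import path, the availability of keyword defaults, and the concrete values are assumptions, not taken from this diff:

```python
# Hypothetical check: class and attribute names come from the hunks above;
# the import path and the numeric values are assumptions.
from pytorch_transformers.modeling_dilbert import DilBertConfig

config = DilBertConfig(vocab_size_or_config_json_file=30522,
                       n_layers=6, n_heads=12, dim=768, hidden_dim=3072)
print(config.dim, config.hidden_dim)  # before this fix, hidden_dim was never stored on the config
```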
@@ -341,7 +342,7 @@ class DilBertPreTrainedModel(PreTrainedModel):
     """ An abstract class to handle weights initialization and
         a simple interface for downloading and loading pretrained models.
     """
-    config_class = DilBertconfig
+    config_class = DilBertConfig
     pretrained_model_archive_map = DILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
     load_tf_weights = None
     base_model_prefix = "dilbert"
@@ -370,7 +371,7 @@ DILBERT_START_DOCSTRING = r"""
     For more information on DilBERT, you should check TODO(Victor): Link to Medium

     Parameters:
-        config (:class:`~pytorch_transformers.DilBertconfig`): Model configuration class with all the parameters of the model.
+        config (:class:`~pytorch_transformers.DilBertConfig`): Model configuration class with all the parameters of the model.
             Initializing with a config file does not load the weights associated with the model, only the configuration.
             Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
 """
@@ -391,18 +392,7 @@ DILBERT_INPUTS_DOCSTRING = r"""
 @add_start_docstrings("The bare DilBERT encoder/transformer outputing raw hidden-states without any specific head on top.",
                       DILBERT_START_DOCSTRING, DILBERT_INPUTS_DOCSTRING)
 class DilBertModel(DilBertPreTrainedModel):
-    def __init__(self, config):
-        super(DilBertModel, self).__init__(config)
-
-        self.embeddings = Embeddings(config)   # Embeddings
-        self.transformer = Transformer(config) # Encoder
-
-        self.apply(self.init_weights)
-
-    def forward(self,
-                input_ids: torch.tensor,
-                attention_mask: torch.tensor = None):
-        """
+    r"""
     Parameters
     ----------
     input_ids: torch.tensor(bs, seq_length)
@@ -422,7 +412,18 @@ class DilBertModel(DilBertPreTrainedModel):
         all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
             Tuple of length n_layers with the attention weights from each layer
             Optional: only if output_attentions=True
-        """
+    """
+    def __init__(self, config):
+        super(DilBertModel, self).__init__(config)
+
+        self.embeddings = Embeddings(config)   # Embeddings
+        self.transformer = Transformer(config) # Encoder
+
+        self.apply(self.init_weights)
+
+    def forward(self,
+                input_ids: torch.tensor,
+                attention_mask: torch.tensor = None):
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids) # (bs, seq_length)

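For orientation, a minimal usage sketch of the bare encoder after this reorganization. The class names, the tuple layout of the output, and the all-ones attention-mask default come from the diff; the package-level import and the assumption that `DilBertConfig()` is constructible from defaults are not.

```python
# Hedged sketch: DilBertConfig/DilBertModel names are taken from the diff;
# the import path and usable default hyperparameters are assumptions.
import torch
from pytorch_transformers import DilBertConfig, DilBertModel

config = DilBertConfig()          # assumed: default hyperparameters are usable as-is
model = DilBertModel(config)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 16))   # (bs, seq_length)
with torch.no_grad():
    outputs = model(input_ids)    # attention_mask is filled with ones when omitted
hidden_states = outputs[0]        # (bs, seq_length, dim)
```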
@@ -438,33 +439,7 @@
 @add_start_docstrings("""DilBert Model with a `masked language modeling` head on top. """,
                       DILBERT_START_DOCSTRING, DILBERT_INPUTS_DOCSTRING)
 class DilBertForMaskedLM(DilBertPreTrainedModel):
-    def __init__(self, config):
-        super(DilBertForMaskedLM, self).__init__(config)
-        self.output_attentions = config.output_attentions
-        self.output_hidden_states = config.output_hidden_states
-
-        self.encoder = DilBertModel(config)
-        self.vocab_transform = nn.Linear(config.dim, config.dim)
-        self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
-        self.vocab_projector = nn.Linear(config.dim, config.vocab_size)
-
-        self.apply(self.init_weights)
-        self.tie_weights()
-
-        self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
-
-    def tie_weights_(self):
-        """
-        Tying the weights of the vocabulary projection to the base token embeddings.
-        """
-        if self.config.tie_weights:
-            self.vocab_projector.weight = self.encoder.embeddings.word_embeddings.weight
-
-    def forward(self,
-                input_ids: torch.tensor,
-                attention_mask: torch.tensor = None,
-                masked_lm_labels: torch.tensor = None):
-        """
+    r"""
     Parameters
     ----------
     input_ids: torch.tensor(bs, seq_length)
@@ -487,7 +462,33 @@ class DilBertForMaskedLM(DilBertPreTrainedModel):
         all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
             Tuple of length n_layers with the attention weights from each layer
             Optional: only if `output_attentions`=True
-        """
+    """
+    def __init__(self, config):
+        super(DilBertForMaskedLM, self).__init__(config)
+        self.output_attentions = config.output_attentions
+        self.output_hidden_states = config.output_hidden_states
+
+        self.encoder = DilBertModel(config)
+        self.vocab_transform = nn.Linear(config.dim, config.dim)
+        self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
+        self.vocab_projector = nn.Linear(config.dim, config.vocab_size)
+
+        self.apply(self.init_weights)
+        self.tie_weights_()
+
+        self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
+
+    def tie_weights_(self):
+        """
+        Tying the weights of the vocabulary projection to the base token embeddings.
+        """
+        if self.config.tie_weights:
+            self.vocab_projector.weight = self.encoder.embeddings.word_embeddings.weight
+
+    def forward(self,
+                input_ids: torch.tensor,
+                attention_mask: torch.tensor = None,
+                masked_lm_labels: torch.tensor = None):
         tfmr_output = self.encoder(input_ids=input_ids,
                                    attention_mask=attention_mask)
         hidden_states = tfmr_output[0]   # (bs, seq_length, dim)
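The call fixed above (the constructor previously invoked `self.tie_weights()`, which does not match the `tie_weights_` method defined on this class) implements standard input/output embedding weight tying. A self-contained sketch of the same pattern, using a hypothetical module that is not part of this repository:

```python
# Self-contained illustration of weight tying (hypothetical TinyMLMHead, not the repo class).
import torch
import torch.nn as nn

class TinyMLMHead(nn.Module):
    def __init__(self, vocab_size: int = 100, dim: int = 32):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, dim)
        self.vocab_projector = nn.Linear(dim, vocab_size)
        # Tie the output projection to the input embedding matrix:
        # both modules now share the same (vocab_size, dim) Parameter.
        self.vocab_projector.weight = self.word_embeddings.weight

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        hidden = self.word_embeddings(input_ids)    # (bs, seq_length, dim)
        return self.vocab_projector(hidden)         # (bs, seq_length, vocab_size)

model = TinyMLMHead()
logits = model(torch.randint(0, 100, (2, 8)))
# Sanity check: both weights point at the same storage.
assert model.vocab_projector.weight.data_ptr() == model.word_embeddings.weight.data_ptr()
```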
@@ -508,22 +509,7 @@ class DilBertForMaskedLM(DilBertPreTrainedModel):
                          the pooled output) e.g. for GLUE tasks. """,
                       DILBERT_START_DOCSTRING, DILBERT_INPUTS_DOCSTRING)
 class DilBertForSequenceClassification(DilBertPreTrainedModel):
-    def __init__(self, config):
-        super(DilBertForSequenceClassification, self).__init__(config)
-        self.num_labels = config.num_labels
-
-        self.dilbert = DilBertModel(config)
-        self.pre_classifier = nn.Linear(config.dim, config.dim)
-        self.classifier = nn.Linear(config.dim, config.num_labels)
-        self.dropout = nn.Dropout(config.seq_classif_dropout)
-
-        self.apply(self.init_weights)
-
-    def forward(self,
-                input_ids: torch.tensor,
-                attention_mask: torch.tensor = None,
-                labels: torch.tensor = None):
-        """
+    r"""
     Parameters
     ----------
     input_ids: torch.tensor(bs, seq_length)
@@ -546,7 +532,22 @@ class DilBertForSequenceClassification(DilBertPreTrainedModel):
         all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
             Tuple of length n_layers with the attention weights from each layer
             Optional: only if `output_attentions`=True
-        """
+    """
+    def __init__(self, config):
+        super(DilBertForSequenceClassification, self).__init__(config)
+        self.num_labels = config.num_labels
+
+        self.dilbert = DilBertModel(config)
+        self.pre_classifier = nn.Linear(config.dim, config.dim)
+        self.classifier = nn.Linear(config.dim, config.num_labels)
+        self.dropout = nn.Dropout(config.seq_classif_dropout)
+
+        self.apply(self.init_weights)
+
+    def forward(self,
+                input_ids: torch.tensor,
+                attention_mask: torch.tensor = None,
+                labels: torch.tensor = None):
         dilbert_output = self.dilbert(input_ids=input_ids,
                                       attention_mask=attention_mask)
         pooled_output = dilbert_output[1]  # (bs, dim)
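The sequence-classification head moved above follows a common pooled-output, linear, dropout, linear pattern. A self-contained sketch with hypothetical dimensions; the activation between the two linear layers is an assumption, not something this diff shows:

```python
# Hypothetical classification head, mirroring the module names used in the diff.
import torch
import torch.nn as nn

class TinyClassifierHead(nn.Module):
    def __init__(self, dim: int = 32, num_labels: int = 2, seq_classif_dropout: float = 0.2):
        super().__init__()
        self.pre_classifier = nn.Linear(dim, dim)
        self.classifier = nn.Linear(dim, num_labels)
        self.dropout = nn.Dropout(seq_classif_dropout)

    def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
        x = torch.relu(self.pre_classifier(pooled_output))  # ReLU is an assumption here
        x = self.dropout(x)
        return self.classifier(x)                           # (bs, num_labels)

head = TinyClassifierHead()
logits = head(torch.randn(4, 32))   # pooled output of shape (bs, dim)
```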
@@ -571,22 +572,7 @@ class DilBertForSequenceClassification(DilBertPreTrainedModel):
                          the hidden-states output to compute `span start logits` and `span end logits`). """,
                       DILBERT_START_DOCSTRING, DILBERT_INPUTS_DOCSTRING)
 class DilBertForQuestionAnswering(DilBertPreTrainedModel):
-    def __init__(self, config):
-        super(DilBertForQuestionAnswering, self).__init__(config)
-
-        self.dilbert = DilBertModel(config)
-        self.qa_outputs = nn.Linear(config.dim, config.num_labels)
-        assert config.num_labels == 2
-        self.dropout = nn.Dropout(config.qa_dropout)
-
-        self.apply(self.init_weights)
-
-    def forward(self,
-                input_ids: torch.tensor,
-                attention_mask: torch.tensor = None,
-                start_positions: torch.tensor = None,
-                end_positions: torch.tensor = None):
-        """
+    r"""
     Parameters
     ----------
     input_ids: torch.tensor(bs, seq_length)
@@ -619,7 +605,22 @@ class DilBertForQuestionAnswering(DilBertPreTrainedModel):
         all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
             Tuple of length n_layers with the attention weights from each layer
             Optional: only if `output_attentions`=True
-        """
+    """
+    def __init__(self, config):
+        super(DilBertForQuestionAnswering, self).__init__(config)
+
+        self.dilbert = DilBertModel(config)
+        self.qa_outputs = nn.Linear(config.dim, config.num_labels)
+        assert config.num_labels == 2
+        self.dropout = nn.Dropout(config.qa_dropout)
+
+        self.apply(self.init_weights)
+
+    def forward(self,
+                input_ids: torch.tensor,
+                attention_mask: torch.tensor = None,
+                start_positions: torch.tensor = None,
+                end_positions: torch.tensor = None):
         dilbert_output = self.dilbert(input_ids=input_ids,
                                       attention_mask=attention_mask)
         hidden_states = dilbert_output[0]  # (bs, max_query_len, dim)
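The question-answering head added by this commit maps each final hidden state to two logits, one for the span start and one for the span end (hence the `assert config.num_labels == 2`). A self-contained sketch of that pattern with hypothetical dimensions; the module names echo the diff, but the dropout placement is an assumption:

```python
# Hypothetical span-prediction head; not the repo class.
import torch
import torch.nn as nn

class TinyQAHead(nn.Module):
    def __init__(self, dim: int = 32, qa_dropout: float = 0.1):
        super().__init__()
        self.qa_outputs = nn.Linear(dim, 2)    # logit 0: span start, logit 1: span end
        self.dropout = nn.Dropout(qa_dropout)

    def forward(self, hidden_states: torch.Tensor):
        logits = self.qa_outputs(self.dropout(hidden_states))    # (bs, seq_length, 2)
        start_logits, end_logits = logits.split(1, dim=-1)
        return start_logits.squeeze(-1), end_logits.squeeze(-1)  # each (bs, seq_length)

head = TinyQAHead()
start_logits, end_logits = head(torch.randn(2, 16, 32))
# Training would typically apply CrossEntropyLoss to each against the gold start/end positions.
```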