Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-30 17:52:35 +06:00)
implemented BertForQuestionAnswering
parent 5383fca458
commit c0065af6cb
@@ -404,12 +404,12 @@ class BertForSequenceClassification(nn.Module):
     input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
     token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
 
-    config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
+    config = BertConfig(vocab_size=32000, hidden_size=512,
         num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
 
     num_labels = 2
 
-    model = modeling.BertModel(config, num_labels)
+    model = BertForSequenceClassification(config, num_labels)
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
@@ -420,7 +420,7 @@ class BertForSequenceClassification(nn.Module):
         self.classifier = nn.Linear(config.hidden_size, num_labels)
 
         def init_weights(m):
-            if isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
+            if isinstance(m, (nn.Linear, nn.Embedding)):
                 print("Initializing {}".format(m))
                 # Slight difference here with the TF version which uses truncated_normal
                 # cf https://github.com/pytorch/pytorch/pull/5617
@@ -438,3 +438,52 @@ class BertForSequenceClassification(nn.Module):
             return loss, logits
         else:
             return logits
+
+class BertForQuestionAnswering(nn.Module):
+    """BERT model for Question Answering (span extraction).
+    This module is composed of the BERT model with linear layers on top of
+    the sequence output.
+
+    Example usage:
+    ```python
+    # Already been converted into WordPiece token ids
+    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
+
+    config = BertConfig(vocab_size=32000, hidden_size=512,
+        num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
+
+    model = BertForQuestionAnswering(config)
+    logits = model(input_ids, token_type_ids, input_mask)
+    ```
+    """
+    def __init__(self, config):
+        super(BertForQuestionAnswering, self).__init__()
+        self.bert = BertModel(config)
+        # TODO check if it's normal there is no dropout on SQuAD in the TF version
+        # self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.qa_outputs = nn.Linear(config.hidden_size, 2)
+
+        def init_weights(m):
+            if isinstance(m, (nn.Linear, nn.Embedding)):
+                print("Initializing {}".format(m))
+                # Slight difference here with the TF version which uses truncated_normal
+                # cf https://github.com/pytorch/pytorch/pull/5617
+                m.weight.data.normal_(config.initializer_range)
+        self.apply(init_weights)
+
+    def forward(self, input_ids, token_type_ids, attention_mask, start_positions=None, end_positions=None):
+        all_encoder_layers, _ = self.bert(input_ids, token_type_ids, attention_mask)
+        sequence_output = all_encoder_layers[-1]
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = logits.split(1, dim=-1)
+
+        if start_positions is not None and end_positions is not None:
+            loss_fct = CrossEntropyLoss()
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+            return total_loss, (start_logits, end_logits)
+        else:
+            return start_logits, end_logits
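A note on the init_weights comments above: the TF BERT code initializes weights with tf.truncated_normal_initializer(stddev=config.initializer_range), and at the time of this commit PyTorch had no built-in equivalent (hence the link to pytorch/pytorch#5617). Note also that Tensor.normal_ takes (mean, std), so normal_(config.initializer_range) sets the mean rather than the standard deviation; the intended call is normal_(mean=0.0, std=config.initializer_range). A minimal sketch, assuming a newer PyTorch release where nn.init.trunc_normal_ is available (the helper name below is hypothetical, not part of the commit):

```python
import torch.nn as nn

def init_bert_weights(module, initializer_range=0.02):
    # Hypothetical helper mirroring TF BERT's truncated_normal_initializer,
    # which truncates at two standard deviations from the mean.
    if isinstance(module, (nn.Linear, nn.Embedding)):
        nn.init.trunc_normal_(module.weight, mean=0.0, std=initializer_range,
                              a=-2 * initializer_range, b=2 * initializer_range)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)
```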
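On the new forward pass: self.qa_outputs maps the sequence output to shape [batch, seq_len, 2], and logits.split(1, dim=-1) yields start and end logits of shape [batch, seq_len, 1]. CrossEntropyLoss with position targets of shape [batch] expects inputs of shape [batch, seq_len], so the trailing dimension has to be squeezed before the loss is computed. A minimal shape sketch under that assumption (the tensors below are illustrative stand-ins, not part of the commit):

```python
import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_len, hidden_size = 2, 3, 512
sequence_output = torch.randn(batch_size, seq_len, hidden_size)  # stand-in for BERT output
qa_outputs = torch.nn.Linear(hidden_size, 2)

logits = qa_outputs(sequence_output)                # [batch, seq_len, 2]
start_logits, end_logits = logits.split(1, dim=-1)  # each [batch, seq_len, 1]
start_logits = start_logits.squeeze(-1)             # [batch, seq_len]
end_logits = end_logits.squeeze(-1)                 # [batch, seq_len]

start_positions = torch.tensor([1, 0])              # gold start token index per example
end_positions = torch.tensor([2, 1])                # gold end token index per example

loss_fct = CrossEntropyLoss()
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2            # same averaging as in the diff
```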
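The docstring calls this head "span extraction": at inference time the predicted answer is the token span bounded by the highest-scoring start and end positions. A minimal decoding sketch (a full SQuAD decoder would additionally enforce end >= start, cap the span length, and mask out question and padding tokens):

```python
import torch

def decode_span(start_logits: torch.Tensor, end_logits: torch.Tensor):
    # start_logits, end_logits: [batch, seq_len] scores from BertForQuestionAnswering.
    start_index = start_logits.argmax(dim=-1)  # [batch] most likely answer start
    end_index = end_logits.argmax(dim=-1)      # [batch] most likely answer end
    return start_index, end_index
```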