implemented BertForQuestionAnswering

thomwolf 2018-11-02 03:04:34 +01:00
parent 5383fca458
commit c0065af6cb


@@ -404,12 +404,12 @@ class BertForSequenceClassification(nn.Module):
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
config = BertConfig(vocab_size=32000, hidden_size=512,
num_hidden_layers=8, num_attention_heads=8, intermediate_size=1024)
num_labels = 2
model = modeling.BertModel(config, num_labels)
model = BertForSequenceClassification(config, num_labels)
logits = model(input_ids, token_type_ids, input_mask)
```
"""
@@ -420,7 +420,7 @@ class BertForSequenceClassification(nn.Module):
self.classifier = nn.Linear(config.hidden_size, num_labels)
def init_weights(m):
if isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
if isinstance(m, (nn.Linear, nn.Embedding)):
print("Initializing {}".format(m))
# Slight difference here with the TF version which uses truncated_normal
# cf https://github.com/pytorch/pytorch/pull/5617
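Note on the comment above: the TF implementation initializes weights with `truncated_normal`, while this port falls back to plain `normal_` (native support only came to PyTorch later, cf. the linked PR). As a rough illustration, one common workaround is to redraw any sample that falls outside two standard deviations; the helper name `truncated_normal_` below is hypothetical and not part of this commit:

```python
import torch

def truncated_normal_(tensor, mean=0.0, std=0.02):
    # Hypothetical helper, not part of this commit: redraw values that fall
    # outside two standard deviations, approximating tf.truncated_normal.
    with torch.no_grad():
        tensor.normal_(mean, std)
        while True:
            out_of_range = (tensor < mean - 2 * std) | (tensor > mean + 2 * std)
            if not out_of_range.any():
                return tensor
            tensor[out_of_range] = tensor[out_of_range].normal_(mean, std)
```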
@@ -438,3 +438,52 @@ class BertForSequenceClassification(nn.Module):
return loss, logits
else:
return logits
class BertForQuestionAnswering(nn.Module):
"""BERT model for Question Answering (span extraction).
This module is composed of the BERT model with a linear layer on top of
the sequence output that computes the start and end logits of the answer span.
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
config = BertConfig(vocab_size=32000, hidden_size=512,
num_hidden_layers=8, num_attention_heads=8, intermediate_size=1024)
model = BertForQuestionAnswering(config)
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config):
super(BertForQuestionAnswering, self).__init__()
self.bert = BertModel(config)
# TODO check if it's normal there is no dropout on SQuAD in the TF version
# self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
def init_weights(m):
if isinstance(m, (nn.Linear, nn.Embedding)):
print("Initializing {}".format(m))
# Slight difference here with the TF version which uses truncated_normal
# cf https://github.com/pytorch/pytorch/pull/5617
m.weight.data.normal_(mean=0.0, std=config.initializer_range)
self.apply(init_weights)
def forward(self, input_ids, token_type_ids, attention_mask, start_positions=None, end_positions=None):
all_encoder_layers, _ = self.bert(input_ids, token_type_ids, attention_mask)
sequence_output = all_encoder_layers[-1]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
# squeeze the trailing size-1 dimension so each logits tensor is [batch_size, seq_len]
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
if start_positions is not None and end_positions is not None:
loss_fct = CrossEntropyLoss()
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
return total_loss, (start_logits, end_logits)
else:
return start_logits, end_logits
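For completeness, a small usage sketch of the new head in the spirit of the docstring example above; the input tensors, `num_attention_heads=8`, and the argmax span decoding are illustrative choices, not part of this commit:

```python
import torch

# Illustrative config and inputs, mirroring the docstring example above.
config = BertConfig(vocab_size=32000, hidden_size=512,
                    num_hidden_layers=8, num_attention_heads=8, intermediate_size=1024)
model = BertForQuestionAnswering(config)

input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])

# Inference: without positions, forward returns the two logit tensors.
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
start_index = start_logits.argmax(dim=-1)  # [batch_size]
end_index = end_logits.argmax(dim=-1)      # [batch_size]

# Training: passing gold span positions returns the averaged start/end loss.
start_positions = torch.LongTensor([0, 1])
end_positions = torch.LongTensor([2, 2])
loss, _ = model(input_ids, token_type_ids, input_mask, start_positions, end_positions)
loss.backward()
```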