Mirror of https://github.com/huggingface/transformers.git

implemented BertForQuestionAnswering

Commit c0065af6cb (parent 5383fca458)
@@ -404,12 +404,12 @@ class BertForSequenceClassification(nn.Module):
     input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
     token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
 
-    config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
+    config = BertConfig(vocab_size=32000, hidden_size=512,
         num_hidden_layers=8, num_attention_heads=8, intermediate_size=1024)
 
     num_labels = 2
 
-    model = modeling.BertModel(config, num_labels)
+    model = BertForSequenceClassification(config, num_labels)
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
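The corrected docstring example now runs end to end; here is a minimal sketch of it as a standalone script, assuming the commit's modeling.py is importable as `modeling` (that import path is an assumption, not part of the diff). Note `num_attention_heads` is 8 rather than 6 so that `hidden_size` divides evenly across heads (512 / 8 = 64), and token type ids stay within the default two-segment vocabulary.

```python
# Minimal sketch of the updated docstring example; assumes the commit's
# modeling.py is on the path as `modeling`.
import torch
from modeling import BertConfig, BertForSequenceClassification

# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

# hidden_size must be divisible by num_attention_heads (512 / 8 = 64)
config = BertConfig(vocab_size=32000, hidden_size=512,
                    num_hidden_layers=8, num_attention_heads=8,
                    intermediate_size=1024)

model = BertForSequenceClassification(config, num_labels=2)
# With no labels passed, forward returns the classification logits,
# shape (batch_size, num_labels) = (2, 2)
logits = model(input_ids, token_type_ids, input_mask)
```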
@@ -420,7 +420,7 @@ class BertForSequenceClassification(nn.Module):
         self.classifier = nn.Linear(config.hidden_size, num_labels)
 
         def init_weights(m):
-            if isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
+            if isinstance(m, (nn.Linear, nn.Embedding)):
                 print("Initializing {}".format(m))
                 # Slight difference here with the TF version which uses truncated_normal
                 # cf https://github.com/pytorch/pytorch/pull/5617
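The comment in this hunk flags that PyTorch had no built-in truncated-normal initializer at the time (hence the linked PR). For reference, a hedged sketch of one way to emulate tf.truncated_normal by resampling out-of-range draws; the helper name `truncated_normal_` is hypothetical and not part of the commit:

```python
import torch

def truncated_normal_(tensor, mean=0.0, std=0.02):
    # Hypothetical helper, not part of the commit: redraw any sample that
    # falls outside two standard deviations of the mean, which is the
    # rejection rule tf.truncated_normal applies.
    tensor.normal_(mean=mean, std=std)
    while True:
        out_of_range = (tensor < mean - 2 * std) | (tensor > mean + 2 * std)
        if not out_of_range.any():
            return tensor
        tensor[out_of_range] = torch.empty(
            int(out_of_range.sum()), device=tensor.device
        ).normal_(mean=mean, std=std)
```

Swapping this in for the plain `normal_` call would bring the PyTorch initialization in line with the TF reference, at the cost of a resampling loop.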
@@ -438,3 +438,58 @@ class BertForSequenceClassification(nn.Module):
             return loss, logits
         else:
             return logits
+
+
+class BertForQuestionAnswering(nn.Module):
+    """BERT model for Question Answering (span extraction).
+
+    This module is composed of the BERT model with linear layers on top of
+    the sequence output.
+
+    Example usage:
+    ```python
+    # Already been converted into WordPiece token ids
+    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
+
+    config = BertConfig(vocab_size=32000, hidden_size=512,
+        num_hidden_layers=8, num_attention_heads=8, intermediate_size=1024)
+
+    model = BertForQuestionAnswering(config)
+    start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
+    ```
+    """
+    def __init__(self, config):
+        super(BertForQuestionAnswering, self).__init__()
+        self.bert = BertModel(config)
+        # TODO check if it's normal there is no dropout on SQuAD in the TF version
+        # self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.qa_outputs = nn.Linear(config.hidden_size, 2)
+
+        def init_weights(m):
+            if isinstance(m, (nn.Linear, nn.Embedding)):
+                print("Initializing {}".format(m))
+                # Slight difference here with the TF version which uses truncated_normal
+                # cf https://github.com/pytorch/pytorch/pull/5617
+                m.weight.data.normal_(mean=0.0, std=config.initializer_range)
+        self.apply(init_weights)
+
+    def forward(self, input_ids, token_type_ids, attention_mask, start_positions=None, end_positions=None):
+        all_encoder_layers, _ = self.bert(input_ids, token_type_ids, attention_mask)
+        sequence_output = all_encoder_layers[-1]
+        logits = self.qa_outputs(sequence_output)
+        # split yields (batch_size, seq_length, 1) tensors; squeeze the
+        # trailing dimension so the logits are (batch_size, seq_length)
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1)
+        end_logits = end_logits.squeeze(-1)
+
+        if start_positions is not None and end_positions is not None:
+            loss_fct = CrossEntropyLoss()
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+            return total_loss, (start_logits, end_logits)
+        else:
+            return start_logits, end_logits
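Putting the new head together, here is a minimal usage sketch, again assuming the commit's modeling.py is importable as `modeling` (the import path and the toy tensors are assumptions, not part of the diff):

```python
# Minimal usage sketch for BertForQuestionAnswering as added in this commit.
import torch
from modeling import BertConfig, BertForQuestionAnswering

input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

config = BertConfig(vocab_size=32000, hidden_size=512,
                    num_hidden_layers=8, num_attention_heads=8,
                    intermediate_size=1024)
model = BertForQuestionAnswering(config)

# Training: pass gold start/end token indices to get the averaged loss.
start_positions = torch.LongTensor([1, 0])
end_positions = torch.LongTensor([2, 1])
loss, (start_logits, end_logits) = model(input_ids, token_type_ids, input_mask,
                                         start_positions, end_positions)

# Inference: omit the positions and take the argmax over the sequence axis.
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
pred_start = start_logits.argmax(dim=-1)  # shape: (batch_size,)
pred_end = end_logits.argmax(dim=-1)
```

Averaging the start and end cross-entropy losses keeps the training signal on the same scale as a single classification head; at inference time the predicted span is simply the argmax of each per-token logit vector.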