diff --git a/modeling_pytorch.py b/modeling_pytorch.py
index be592eb140f..0dc95269c9e 100644
--- a/modeling_pytorch.py
+++ b/modeling_pytorch.py
@@ -404,12 +404,12 @@ class BertForSequenceClassification(nn.Module):
     input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
     token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])

-    config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
+    config = BertConfig(vocab_size=32000, hidden_size=512,
         num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)

     num_labels = 2

-    model = modeling.BertModel(config, num_labels)
+    model = BertForSequenceClassification(config, num_labels)
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
@@ -420,7 +420,7 @@ class BertForSequenceClassification(nn.Module):
         self.classifier = nn.Linear(config.hidden_size, num_labels)

         def init_weights(m):
-            if isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
+            if isinstance(m, (nn.Linear, nn.Embedding)):
                 print("Initializing {}".format(m))
                 # Slight difference here with the TF version which uses truncated_normal
                 # cf https://github.com/pytorch/pytorch/pull/5617
@@ -438,3 +438,55 @@ class BertForSequenceClassification(nn.Module):
             return loss, logits
         else:
             return logits
+
+class BertForQuestionAnswering(nn.Module):
+    """BERT model for Question Answering (span extraction).
+    This module is composed of the BERT model with a linear layer on top of
+    the sequence output that computes start and end logits for the answer span.
+
+    Example usage:
+    ```python
+    # Already been converted into WordPiece token ids
+    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
+
+    config = BertConfig(vocab_size=32000, hidden_size=512,
+        num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
+
+    model = BertForQuestionAnswering(config)
+    start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
+    ```
+    """
+    def __init__(self, config):
+        super(BertForQuestionAnswering, self).__init__()
+        self.bert = BertModel(config)
+        # TODO check if it's normal there is no dropout on SQuAD in the TF version
+        # self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.qa_outputs = nn.Linear(config.hidden_size, 2)
+
+        def init_weights(m):
+            if isinstance(m, (nn.Linear, nn.Embedding)):
+                print("Initializing {}".format(m))
+                # Slight difference here with the TF version which uses truncated_normal
+                # cf https://github.com/pytorch/pytorch/pull/5617
+                m.weight.data.normal_(mean=0.0, std=config.initializer_range)
+        self.apply(init_weights)
+
+    def forward(self, input_ids, token_type_ids, attention_mask, start_positions=None, end_positions=None):
+        all_encoder_layers, _ = self.bert(input_ids, token_type_ids, attention_mask)
+        sequence_output = all_encoder_layers[-1]
+        logits = self.qa_outputs(sequence_output)
+        # (batch, seq_len, 2) -> two (batch, seq_len) tensors of span logits
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1)
+        end_logits = end_logits.squeeze(-1)
+
+        if start_positions is not None and end_positions is not None:
+            loss_fct = CrossEntropyLoss()
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+            return total_loss, (start_logits, end_logits)
+        else:
+            return start_logits, end_logits
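
For anyone who wants to exercise the new head end to end, below is a minimal sketch. It assumes the file is importable as `modeling_pytorch` and that `BertConfig` accepts the keyword arguments shown in the docstrings; the gold `start_positions`/`end_positions` values are made-up placeholders, and `num_attention_heads=8` deliberately replaces the docstring's 6 so that `hidden_size=512` divides evenly across attention heads, as standard multi-head attention requires.

```python
import torch

# Hypothetical import path: assumes this file is on PYTHONPATH as modeling_pytorch.
from modeling_pytorch import BertConfig, BertForQuestionAnswering

# 8 heads instead of the docstring's 6 so hidden_size (512) splits evenly per head.
config = BertConfig(vocab_size=32000, hidden_size=512,
                    num_hidden_layers=8, num_attention_heads=8,
                    intermediate_size=1024)
model = BertForQuestionAnswering(config)

input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

# Inference: without positions, forward returns per-token span logits.
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
start_idx = start_logits.argmax(dim=-1)  # (batch,) most likely span starts
end_idx = end_logits.argmax(dim=-1)      # (batch,) most likely span ends

# Training: with gold positions (placeholder values here), forward returns
# the averaged start/end cross-entropy loss along with the logits.
start_positions = torch.LongTensor([0, 1])
end_positions = torch.LongTensor([2, 2])
loss, _ = model(input_ids, token_type_ids, input_mask,
                start_positions, end_positions)
loss.backward()
```

Averaging the start and end cross-entropy terms rather than summing them mirrors the loss in the TF `run_squad.py` and keeps the magnitude comparable to a single classification loss.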