ignore SQuAD targets outside of seq_length

This commit is contained in:
thomwolf 2018-11-05 14:18:48 +01:00
parent 1b99cdf71b
commit c3527cfbc4

View File

@ -455,9 +455,15 @@ class BertForQuestionAnswering(nn.Module):
end_logits = end_logits.squeeze(-1)
if start_positions is not None and end_positions is not None:
start_positions = start_positions.squeeze(-1) # If we are on multi-GPU, split add a dimension
# If we are on multi-GPU, split add a dimension - if not this is a no-op
start_positions = start_positions.squeeze(-1)
end_positions = end_positions.squeeze(-1)
loss_fct = CrossEntropyLoss()
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) + 1
start_positions.clamp_(0, ignored_index)
end_positions.clamp_(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2