mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
ignore SQuAD targets outside of seq_length
This commit is contained in:
parent
1b99cdf71b
commit
c3527cfbc4
10
modeling.py
10
modeling.py
@ -455,9 +455,15 @@ class BertForQuestionAnswering(nn.Module):
|
||||
end_logits = end_logits.squeeze(-1)
|
||||
|
||||
if start_positions is not None and end_positions is not None:
|
||||
start_positions = start_positions.squeeze(-1) # If we are on multi-GPU, split add a dimension
|
||||
# If we are on multi-GPU, split add a dimension - if not this is a no-op
|
||||
start_positions = start_positions.squeeze(-1)
|
||||
end_positions = end_positions.squeeze(-1)
|
||||
loss_fct = CrossEntropyLoss()
|
||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||
ignored_index = start_logits.size(1) + 1
|
||||
start_positions.clamp_(0, ignored_index)
|
||||
end_positions.clamp_(0, ignored_index)
|
||||
|
||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||
start_loss = loss_fct(start_logits, start_positions)
|
||||
end_loss = loss_fct(end_logits, end_positions)
|
||||
total_loss = (start_loss + end_loss) / 2
|
||||
|
Loading…
Reference in New Issue
Block a user