gpt2 multi-gpu fix (#23149)
Co-authored-by: Prof. Peter Schneider-Kamp <jps@ordbogen.com>
parent adb0760b5f
commit 510ad0a8b8
@@ -1670,9 +1670,9 @@ class GPT2ForQuestionAnswering(GPT2PreTrainedModel):
         if start_positions is not None and end_positions is not None:
             # If we are on multi-GPU, split add a dimension
             if len(start_positions.size()) > 1:
-                start_positions = start_positions.squeeze(-1)
+                start_positions = start_positions.squeeze(-1).to(start_logits.device)
             if len(end_positions.size()) > 1:
-                end_positions = end_positions.squeeze(-1)
+                end_positions = end_positions.squeeze(-1).to(end_logits.device)
             # sometimes the start/end positions are outside our model inputs, we ignore these terms
             ignored_index = start_logits.size(1)
             start_positions = start_positions.clamp(0, ignored_index)
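For context, the sketch below (not the transformers source itself) shows the loss computation this hunk sits in and why the fix matters: when the model is split across GPUs, start_logits and end_logits can live on a different device than the start_positions/end_positions labels from the batch, and nn.CrossEntropyLoss requires its inputs on the same device, so the labels are moved with .to(...) before the loss is computed. The standalone qa_loss helper and its signature are illustrative assumptions, modeled on the standard question-answering head pattern.

import torch
from torch import nn

def qa_loss(start_logits, end_logits, start_positions, end_positions):
    # Mirror of the pattern in the diff: squeeze a trailing label dimension
    # if present, then move the labels onto the device the logits live on
    # (hypothetical standalone helper, assuming the usual QA-head loss).
    if start_positions.dim() > 1:
        start_positions = start_positions.squeeze(-1)
    start_positions = start_positions.to(start_logits.device)
    if end_positions.dim() > 1:
        end_positions = end_positions.squeeze(-1)
    end_positions = end_positions.to(end_logits.device)

    # Positions outside the model inputs are clamped to ignored_index
    # so CrossEntropyLoss drops them from the loss.
    ignored_index = start_logits.size(1)
    start_positions = start_positions.clamp(0, ignored_index)
    end_positions = end_positions.clamp(0, ignored_index)

    loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
    start_loss = loss_fct(start_logits, start_positions)
    end_loss = loss_fct(end_logits, end_positions)
    return (start_loss + end_loss) / 2

Without the .to(start_logits.device) / .to(end_logits.device) calls, a multi-GPU run where the QA head ends up on a device other than the labels' fails with a device-mismatch error inside the loss, which is what this commit fixes.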