Enable naive Pipeline Parallelism training for Gpt neox japanese and san japanese (#22702)

Move labels to same device as logits
This commit is contained in:
Mayank Agarwal 2023-04-11 18:36:17 +05:30 committed by GitHub
parent 28c19ab58d
commit 0224aaf67f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 0 deletions

View File

@ -682,6 +682,9 @@ class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel):
lm_loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# we are doing next-token prediction; shift prediction scores and input ids by one
shift_logits = lm_logits[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()

View File

@ -1236,6 +1236,9 @@ class GPTSanJapaneseForConditionalGeneration(GPTSanJapanesePreTrainedModel):
router_probs = None
aux_loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
if output_router_logits: