mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Enable naive Pipeline Parallelism training for Gpt neox japanese and san japanese (#22702)
Move labels to same device as logits
This commit is contained in:
parent
28c19ab58d
commit
0224aaf67f
@ -682,6 +682,9 @@ class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel):
|
||||
|
||||
lm_loss = None
|
||||
if labels is not None:
|
||||
# move labels to correct device to enable model parallelism
|
||||
labels = labels.to(lm_logits.device)
|
||||
|
||||
# we are doing next-token prediction; shift prediction scores and input ids by one
|
||||
shift_logits = lm_logits[:, :-1, :].contiguous()
|
||||
labels = labels[:, 1:].contiguous()
|
||||
|
@ -1236,6 +1236,9 @@ class GPTSanJapaneseForConditionalGeneration(GPTSanJapanesePreTrainedModel):
|
||||
router_probs = None
|
||||
aux_loss = None
|
||||
if labels is not None:
|
||||
# move labels to correct device to enable model parallelism
|
||||
labels = labels.to(lm_logits.device)
|
||||
|
||||
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
|
||||
|
||||
if output_router_logits:
|
||||
|
Loading…
Reference in New Issue
Block a user