mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
[T5] Enable naive Pipeline Parallelism training for T5 (#22535)
* enable PP for T5
* make fixup
* fix failing tests
This commit is contained in:
parent
cab048fb35
commit
d7a4f5becc
@@ -1778,6 +1778,8 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel):
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss(ignore_index=-100)
|
||||
# move labels to correct device to enable PP
|
||||
labels = labels.to(lm_logits.device)
|
||||
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
|
||||
# TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
|
||||
|
||||
|
@@ -1746,6 +1746,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss(ignore_index=-100)
|
||||
# move labels to correct device to enable PP
|
||||
labels = labels.to(lm_logits.device)
|
||||
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
|
||||
# TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user