[T5] Enable naive Pipeline Parallelism training for T5 (#22535)

* enable PP for T5

* make fixup

* fix failing tests
This commit is contained in:
Younes Belkada 2023-04-03 17:55:37 +02:00 committed by GitHub
parent cab048fb35
commit d7a4f5becc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 0 deletions

View File

@ -1778,6 +1778,8 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel):
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-100)
# move labels to correct device to enable PP
labels = labels.to(lm_logits.device)
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
# TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

View File

@ -1746,6 +1746,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-100)
# move labels to correct device to enable PP
labels = labels.to(lm_logits.device)
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
# TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666