continue to fix distributed_type from TPU to XLA in LM examples (#38652)

PT0X0E 2025-06-18 22:47:45 +08:00
parent 9cd7570f34
commit 91842a6900
4 changed files with 4 additions and 4 deletions
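All four hunks below apply the same one-line fix: recent accelerate releases renamed the TPU member of DistributedType to XLA, so the examples must compare against DistributedType.XLA or the tie-restoring branch is never taken. A minimal sketch of the pattern, assuming a recent accelerate version; the "gpt2" checkpoint is purely illustrative and not taken from these examples:

from accelerate import Accelerator
from accelerate.utils import DistributedType
from transformers import AutoModelForCausalLM

accelerator = Accelerator()

# Illustrative checkpoint only; the LM examples load user-specified models.
model = AutoModelForCausalLM.from_pretrained("gpt2")
model = accelerator.prepare(model)

# On XLA (named DistributedType.TPU in older accelerate releases),
# prepare() disconnects tied weights, so restore the ties afterwards.
if accelerator.distributed_type == DistributedType.XLA:
    model.tie_weights()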

@@ -625,7 +625,7 @@ def main():
     )
     # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
-    if accelerator.distributed_type == DistributedType.TPU:
+    if accelerator.distributed_type == DistributedType.XLA:
         model.tie_weights()
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.

@@ -531,7 +531,7 @@ def main():
     )
     # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
-    if accelerator.distributed_type == DistributedType.TPU:
+    if accelerator.distributed_type == DistributedType.XLA:
         model.tie_weights()
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.

@@ -729,7 +729,7 @@ def main():
     )
     # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
-    if accelerator.distributed_type == DistributedType.TPU:
+    if accelerator.distributed_type == DistributedType.XLA:
         model.tie_weights()
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.

@@ -568,7 +568,7 @@ def main():
     )
     # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
-    if accelerator.distributed_type == DistributedType.TPU:
+    if accelerator.distributed_type == DistributedType.XLA:
         model.tie_weights()
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.