Hot-fix-mixstral-loss (#27948)

* fix loss computation

* compute on GPU if possible
This commit is contained in:
Arthur 2023-12-12 12:20:28 +01:00 committed by GitHub
parent 4b759da8be
commit 680c610f97
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -95,7 +95,8 @@ def load_balancing_loss_func(gate_logits: torch.Tensor, num_experts: torch.Tenso
if isinstance(gate_logits, tuple):
# cat along the layers?
gate_logits = torch.cat(gate_logits, dim=0)
compute_device = gate_logits[0].device
gate_logits = torch.cat([gate.to(compute_device) for gate in gate_logits], dim=0)
routing_weights, selected_experts = torch.topk(gate_logits, top_k, dim=-1)
routing_weights = routing_weights.softmax(dim=-1)