mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
Hot-fix-mixstral-loss (#27948)
* fix loss computation * compute on GPU if possible
This commit is contained in:
parent
4b759da8be
commit
680c610f97
@ -95,7 +95,8 @@ def load_balancing_loss_func(gate_logits: torch.Tensor, num_experts: torch.Tenso
|
||||
|
||||
if isinstance(gate_logits, tuple):
|
||||
# cat along the layers?
|
||||
gate_logits = torch.cat(gate_logits, dim=0)
|
||||
compute_device = gate_logits[0].device
|
||||
gate_logits = torch.cat([gate.to(compute_device) for gate in gate_logits], dim=0)
|
||||
|
||||
routing_weights, selected_experts = torch.topk(gate_logits, top_k, dim=-1)
|
||||
routing_weights = routing_weights.softmax(dim=-1)
|
||||
|
Loading…
Reference in New Issue
Block a user