Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31)
[SwitchTransformers] Fix return values (#24300)

* clean history
* remove other changes
* fix
* fix copies
This commit is contained in:
parent 0b7b4429c7
commit ba3fb4b8d7
@@ -1348,7 +1348,7 @@ class GPTSanJapaneseForConditionalGeneration(GPTSanJapanesePreTrainedModel):
         total_router_logits = []
         total_expert_indexes = []
         for router_output in router_outputs:
-            if router_output[0] is not None:
+            if len(router_output[0].shape) > 1:
                 router_logits, expert_indexes = router_output
                 total_router_logits.append(router_logits)
                 total_expert_indexes.append(expert_indexes)
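For context (not part of the commit): a standalone sketch of what the new dimensionality check buys. Dense layers report a placeholder router tuple whose first element is a 1-D dummy tensor (see the SwitchTransformersBlock hunk below), so an `is not None` test no longer filters it out, while checking the number of dimensions does. Shapes below are made up for illustration.

import torch

# Hypothetical shapes: batch=2, seq_len=4, num_experts=8.
real_router_output = (torch.randn(2, 4, 8), torch.randint(0, 8, (2, 4)))
placeholder_output = (torch.tensor([0]),)  # emitted by dense (non-MoE) layers

router_outputs = [placeholder_output, real_router_output]

total_router_logits, total_expert_indexes = [], []
for router_output in router_outputs:
    # `is not None` would also accept the dummy tensor; checking the number of
    # dimensions keeps only genuine (batch, seq_len, num_experts) router logits.
    if len(router_output[0].shape) > 1:
        router_logits, expert_indexes = router_output
        total_router_logits.append(router_logits)
        total_expert_indexes.append(expert_indexes)

print(len(total_router_logits))  # 1 -> only the real router output is kept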
@@ -798,7 +798,7 @@ class SwitchTransformersBlock(nn.Module):
         if isinstance(hidden_states, tuple):
             hidden_states, router_tuple = hidden_states
         else:
-            router_tuple = (None,)
+            router_tuple = (torch.tensor([0]),)

         # clamp inf values to enable fp16 training
         if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
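Illustration of the contract this placeholder preserves (not part of the commit; shapes are made up): every block now hands back a router tuple made of tensors, so code that collects router tuples across the layer stack never has to special-case None.

import torch

# Each block returns (hidden_states, router_tuple); dense blocks use a dummy tensor.
block_outputs = [
    (torch.randn(2, 4, 16), (torch.tensor([0]),)),                                  # dense block
    (torch.randn(2, 4, 16), (torch.randn(2, 4, 8), torch.randint(0, 8, (2, 4)))),   # sparse block
]

router_tuples = tuple(router for _, router in block_outputs)
# Every entry is a tuple of tensors, so downstream consumers (e.g. the
# shape-based filter in _unpack_router_logits) never hit a None.
assert all(isinstance(r[0], torch.Tensor) for r in router_tuples)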
@@ -1683,50 +1683,45 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedModel):
         decoder_z_loss = None
         decoder_aux_loss = None

-        if labels is not None:
-            loss_fct = CrossEntropyLoss(ignore_index=-100)
-            # todo check in the config if router loss enables
-
-            if output_router_logits:
-                # Compute the router loss (z_loss + auxiliary loss) for each router in the encoder and decoder
-                encoder_router_logits, encoder_expert_indexes = self._unpack_router_logits(
-                    encoder_outputs.router_probs
-                )
+        if output_router_logits:
+            # Compute the router loss (z_loss + auxiliary loss) for each router in the encoder and decoder
+            if self.encoder.config.encoder_sparse_step > 1:
+                encoder_router_logits, encoder_expert_indexes = self._unpack_router_logits(encoder_outputs[-1])
                 encoder_z_loss = router_z_loss_func(encoder_router_logits)
                 encoder_router_probs = nn.Softmax(dim=-1)(encoder_router_logits)
                 encoder_aux_loss = load_balancing_loss_func(encoder_router_probs, encoder_expert_indexes)
+            else:
+                encoder_z_loss = 0
+                encoder_aux_loss = 0

-                decoder_router_logits, decoder_expert_indexes = self._unpack_router_logits(
-                    decoder_outputs.router_probs
-                )
+            if self.decoder.config.decoder_sparse_step > 1:
+                decoder_router_logits, decoder_expert_indexes = self._unpack_router_logits(decoder_outputs[-1])
                 decoder_z_loss = router_z_loss_func(decoder_router_logits)
                 decoder_router_probs = nn.Softmax(dim=-1)(decoder_router_logits)
                 decoder_aux_loss = load_balancing_loss_func(decoder_router_probs, decoder_expert_indexes)
+            else:
+                decoder_z_loss = 0
+                decoder_aux_loss = 0

+        if labels is not None:
+            loss_fct = CrossEntropyLoss(ignore_index=-100)
             # move labels to correct device to enable PP
             labels = labels.to(lm_logits.device)
             loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

-            if output_router_logits and labels is not None:
+            if output_router_logits:
                 z_loss = self.router_z_loss_coef * (encoder_z_loss + decoder_z_loss)
                 aux_loss = self.router_aux_loss_coef * (encoder_aux_loss + decoder_aux_loss)
                 loss = loss + z_loss + aux_loss

         if not return_dict:
             output = (lm_logits,)
-            if output_router_logits:  # only return the loss if they are not None
-                output += (
-                    encoder_z_loss,
-                    encoder_aux_loss,
-                    decoder_z_loss,
-                    decoder_aux_loss,
-                    *decoder_outputs[1:],
-                    *encoder_outputs,
-                )
-            else:
-                output += (*decoder_outputs[1:], *encoder_outputs)
+            if output_router_logits:
+                output += (encoder_z_loss, encoder_aux_loss, decoder_z_loss, decoder_aux_loss)
+            output += (*decoder_outputs[1:], *encoder_outputs)

             return ((loss,) + output) if loss is not None else output

         return Seq2SeqMoEOutput(
             loss=loss,
             logits=lm_logits,
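A minimal usage sketch (not part of the commit; it assumes the google/switch-base-8 checkpoint) of how the tuple return reads after this change with `return_dict=False` and `output_router_logits=True`: the four router losses follow the LM logits, then the remaining decoder and encoder outputs, and the combined loss is prepended only when `labels` is passed.

from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration

# Assumed checkpoint, used only for illustration.
tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
model = SwitchTransformersForConditionalGeneration.from_pretrained("google/switch-base-8")

inputs = tokenizer("A <extra_id_0> walks into a bar.", return_tensors="pt")
labels = tokenizer("<extra_id_0> man <extra_id_1>", return_tensors="pt").input_ids

outputs = model(
    input_ids=inputs.input_ids,
    labels=labels,
    output_router_logits=True,
    return_dict=False,
)

# With labels, the combined loss (LM loss + z-loss + auxiliary loss) comes first.
loss, lm_logits = outputs[0], outputs[1]
# The router losses are plain 0 when the corresponding sparse step is <= 1.
encoder_z_loss, encoder_aux_loss, decoder_z_loss, decoder_aux_loss = outputs[2:6]
print(loss, lm_logits.shape)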
@@ -1738,18 +1733,18 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedModel):
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
             cross_attentions=decoder_outputs.cross_attentions,
+            decoder_router_logits=decoder_outputs.router_probs,
             encoder_last_hidden_state=encoder_outputs.last_hidden_state,
             encoder_hidden_states=encoder_outputs.hidden_states,
             encoder_attentions=encoder_outputs.attentions,
             encoder_router_logits=encoder_outputs.router_probs,
-            decoder_router_logits=decoder_outputs.router_probs,
         )

     def _unpack_router_logits(self, router_outputs):
         total_router_logits = []
         total_expert_indexes = []
         for router_output in router_outputs:
-            if router_output[0] is not None:
+            if len(router_output[0].shape) > 1:
                 router_logits, expert_indexes = router_output
                 total_router_logits.append(router_logits)
                 total_expert_indexes.append(expert_indexes)
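For the `return_dict=True` path, a similar sketch (same assumed checkpoint, not part of the commit) of reading the router logits off the Seq2SeqMoEOutput fields populated above; each field holds the per-layer router tuples taken from `*.router_probs`, with dense layers contributing the placeholder tuple.

import torch
from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")  # assumed checkpoint
model = SwitchTransformersForConditionalGeneration.from_pretrained("google/switch-base-8")

inputs = tokenizer("A <extra_id_0> walks into a bar.", return_tensors="pt")
outputs = model(
    input_ids=inputs.input_ids,
    decoder_input_ids=torch.tensor([[model.config.decoder_start_token_id]]),
    output_router_logits=True,
)  # return_dict=True by default -> Seq2SeqMoEOutput

# One entry per layer: (router_logits, expert_index) for sparse layers,
# the dummy placeholder tuple for dense layers.
print(len(outputs.encoder_router_logits), len(outputs.decoder_router_logits))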