Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
[NllbMoe] Update code to properly support loss computation (#25429)

* update nllb_moe
* fix
* doc nits
* nits
* add a small test
* ficup
* remove adapted from
parent 9264fc915a
commit 181d778f83
@@ -126,7 +126,6 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l
     return incremental_indices.long() + padding_idx


-# Copied from transformers.models.switch_transformers.modeling_switch_transformers.load_balancing_loss_func with SwitchTransformers->NllbMoeModel
 def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float:
     r"""
     Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
@@ -144,6 +143,9 @@ def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.T
     Returns:
         The auxiliary loss.
     """
+    if router_probs is None:
+        return 0
+
     num_experts = router_probs.shape[-1]

     # cast the expert indices to int64, otherwise one-hot encoding will fail
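The early return added above is what keeps loss computation well defined when no router probabilities are available (for example, when only dense layers ran). As a rough, standalone sketch of the load-balancing idea, not the library's exact implementation (function and variable names here are invented for illustration), the auxiliary loss multiplies the fraction of tokens routed to each expert by the mean router probability given to that expert, so it is smallest when both are uniform:

import torch
import torch.nn.functional as F


def load_balancing_loss_sketch(router_probs, expert_indices):
    # Mirrors the new guard: with no router probabilities there is nothing to balance.
    if router_probs is None:
        return 0

    num_experts = router_probs.shape[-1]
    # Hard assignment: fraction of tokens dispatched to each expert.
    expert_mask = F.one_hot(expert_indices.to(torch.int64), num_experts).float()
    tokens_per_expert = expert_mask.mean(dim=-2)
    # Soft assignment: mean router probability given to each expert.
    prob_per_expert = router_probs.mean(dim=-2)
    # Smallest when both distributions are uniform over the experts.
    return (tokens_per_expert * prob_per_expert).sum(dim=-1).mean() * num_experts


router_probs = torch.softmax(torch.randn(2, 10, 4), dim=-1)  # (batch, tokens, experts)
expert_indices = router_probs.argmax(dim=-1)                 # (batch, tokens)
print(load_balancing_loss_sketch(router_probs, expert_indices))

The in-tree function may scale and mask differently (e.g. over a top-k expert dimension); this sketch only shows the shape of the computation that the None guard protects.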
@@ -699,7 +701,9 @@ class NllbMoeEncoderLayer(nn.Module):
         if self.is_sparse:
             hidden_states, router_states = self.ffn(hidden_states, attention_mask)
         else:
-            hidden_states = self.ffn(hidden_states)
+            # router_states set to None to track which layers have None gradients.
+            hidden_states, router_states = self.ffn(hidden_states), None

         hidden_states = self.ff_dropout(hidden_states)

         hidden_states = residual + hidden_states
@@ -830,7 +834,8 @@ class NllbMoeDecoderLayer(nn.Module):
         if self.is_sparse:
             hidden_states, router_states = self.ffn(hidden_states, attention_mask)
         else:
-            hidden_states = self.ffn(hidden_states)
+            hidden_states, router_states = self.ffn(hidden_states), None

         hidden_states = self.ff_dropout(hidden_states)

         hidden_states = residual + hidden_states
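With these two hunks, both the encoder and decoder layers return a (hidden_states, router_states) pair from either branch, with router_states explicitly None for dense FFNs, so callers can tell which layers actually produced router logits. A toy illustration of the pattern (ToyBlock and its submodules are invented for this sketch, not part of the library):

import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Illustrative only: returns (hidden_states, router_states) from both branches."""

    def __init__(self, d_model, is_sparse):
        super().__init__()
        self.is_sparse = is_sparse
        self.ffn = nn.Linear(d_model, d_model)  # stand-in for the MoE / dense FFN
        self.router = nn.Linear(d_model, 4)     # stand-in for a top-k router

    def forward(self, hidden_states):
        if self.is_sparse:
            router_states = self.router(hidden_states)
            hidden_states = self.ffn(hidden_states)
        else:
            # Dense layers have no router; returning None keeps the output
            # signature uniform so callers can tell which layers contribute
            # router logits (and therefore router loss terms).
            hidden_states, router_states = self.ffn(hidden_states), None
        return hidden_states, router_states


x = torch.randn(1, 5, 8)
for sparse in (True, False):
    _, router_states = ToyBlock(8, sparse)(x)
    print(sparse, router_states is None)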
@@ -1730,7 +1735,7 @@ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel):

         if output_router_logits:
             encoder_router_logits = outputs[-1]
-            decoder_router_logits = outputs[5 if output_attentions else 3]
+            decoder_router_logits = outputs[3 if output_attentions else 4]

             # Compute the router loss (z_loss + auxiliary loss) for each router in the encoder and decoder
             encoder_router_logits, encoder_expert_indexes = self._unpack_router_logits(encoder_router_logits)
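The comment in this hunk refers to folding two router penalties into the training loss: a z-loss that keeps router logits small and the auxiliary load-balancing loss. A hedged sketch of how such a combination is typically formed (the coefficient names and default values here are illustrative and are not necessarily the fields NllbMoeConfig uses):

import torch


def router_z_loss_sketch(router_logits):
    # Penalizes large router logits: mean of logsumexp(logits) squared.
    return torch.logsumexp(router_logits, dim=-1).pow(2).mean()


def combined_loss_sketch(lm_loss, enc_router_logits, dec_router_logits,
                         enc_aux_loss, dec_aux_loss,
                         z_loss_coef=0.001, aux_loss_coef=0.001):
    # Total training loss = cross-entropy + weighted router penalties.
    z_loss = router_z_loss_sketch(enc_router_logits) + router_z_loss_sketch(dec_router_logits)
    aux_loss = enc_aux_loss + dec_aux_loss
    return lm_loss + z_loss_coef * z_loss + aux_loss_coef * aux_loss


lm_loss = torch.tensor(2.5)
enc_logits, dec_logits = torch.randn(2, 10, 4), torch.randn(2, 10, 4)
print(combined_loss_sketch(lm_loss, enc_logits, dec_logits,
                           enc_aux_loss=torch.tensor(1.0), dec_aux_loss=torch.tensor(1.0)))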
@@ -1775,7 +1780,6 @@ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel):
             decoder_router_logits=outputs.decoder_router_logits,
         )

-    # Copied from transfomers.models.switch_transformers.SwitchTransformersForConditionalGeneration._unpack_router_logits
     def _unpack_router_logits(self, router_outputs):
         total_router_logits = []
         total_expert_indexes = []
@@ -1784,11 +1788,10 @@ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel):
             router_logits, expert_indexes = router_output
             total_router_logits.append(router_logits)
             total_expert_indexes.append(expert_indexes)
-        if len(total_expert_indexes) > 0:
-            total_router_logits = torch.cat(total_router_logits, dim=1)
-        if len(total_expert_indexes) > 0:
-            torch.cat(total_expert_indexes, dim=1)
-        return torch.cat(total_router_logits, dim=1), torch.cat(total_expert_indexes, dim=1)
+
+        total_router_logits = torch.cat(total_router_logits, dim=1) if len(total_router_logits) > 0 else None
+        total_expert_indexes = torch.stack(total_expert_indexes, dim=1) if len(total_expert_indexes) > 0 else None
+        return total_router_logits, total_expert_indexes

-    # Copied from transfomers.models.switch_transformers.SwitchTransformersForConditionalGeneration.prepare_inputs_for_generation
     def prepare_inputs_for_generation(
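After this fix, _unpack_router_logits concatenates per-layer router logits and stacks per-layer expert indexes, returning None for either when no sparse layer contributed instead of calling torch.cat on an empty list. A standalone sketch of that behavior (the per-layer filtering condition is simplified relative to the real method):

import torch


def unpack_router_logits_sketch(router_outputs):
    # router_outputs holds one entry per layer: a (router_logits, expert_indexes)
    # tuple for sparse layers, or None for dense layers.
    total_router_logits, total_expert_indexes = [], []
    for router_output in router_outputs:
        if router_output is not None:
            router_logits, expert_indexes = router_output
            total_router_logits.append(router_logits)
            total_expert_indexes.append(expert_indexes)

    # Returning None (instead of torch.cat on an empty list, which raises) lets
    # downstream loss code skip the router terms for fully dense stacks.
    total_router_logits = torch.cat(total_router_logits, dim=1) if total_router_logits else None
    total_expert_indexes = torch.stack(total_expert_indexes, dim=1) if total_expert_indexes else None
    return total_router_logits, total_expert_indexes


layers = [None, (torch.randn(2, 10, 4), torch.randint(0, 4, (2, 10))), None]
logits, indexes = unpack_router_logits_sketch(layers)
print(logits.shape, indexes.shape)  # torch.Size([2, 10, 4]) torch.Size([2, 1, 10])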
@@ -337,6 +337,16 @@ class NllbMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
         model.generate(input_ids, attention_mask=attention_mask)
         model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)

+    def test_get_loss(self):
+        config, input_dict = self.model_tester.prepare_config_and_inputs()
+        input_dict["output_router_logits"] = True
+        input_dict["labels"] = input_dict["input_ids"]
+        model = NllbMoeForConditionalGeneration(config).eval().to(torch_device)
+        out = model(**input_dict)
+        self.assertIsNotNone(out.loss)
+        self.assertIsNotNone(model(**input_dict)["encoder_router_logits"][1])
+        self.assertIsNotNone(model(**input_dict)["decoder_router_logits"][0])
+

 @require_torch
 @require_sentencepiece
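The new test checks that out.loss and the router logits are populated once output_router_logits and labels are passed. For orientation, a hedged end-to-end sketch of the same call pattern on a tiny, randomly initialized model; the configuration values below are invented for illustration only, and the field names are assumed from NllbMoeConfig's documented arguments rather than prescribed by this commit:

import torch
from transformers import NllbMoeConfig, NllbMoeForConditionalGeneration

# Tiny illustrative configuration; does not correspond to any released checkpoint.
config = NllbMoeConfig(
    vocab_size=64,
    d_model=16,
    encoder_layers=2,
    decoder_layers=2,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    num_experts=4,
    encoder_sparse_step=2,
    decoder_sparse_step=2,
)

model = NllbMoeForConditionalGeneration(config).eval()
input_ids = torch.randint(0, config.vocab_size, (1, 8))

with torch.no_grad():
    out = model(input_ids=input_ids, labels=input_ids, output_router_logits=True)

print(out.loss)  # populated now that router-less layers no longer break the loss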