[NllbMoe] Update code to properly support loss computation (#25429)

* update nllb_moe

* fix

* doc nits

* nits

* add a small test

* ficup

* remove adapted from
This commit is contained in:
Arthur 2023-08-17 17:21:56 +02:00 committed by GitHub
parent 9264fc915a
commit 181d778f83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 10 deletions

View File

@ -126,7 +126,6 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l
return incremental_indices.long() + padding_idx
# Copied from transformers.models.switch_transformers.modeling_switch_transformers.load_balancing_loss_func with SwitchTransformers->NllbMoeModel
def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float:
r"""
Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
@ -144,6 +143,9 @@ def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.T
Returns:
The auxiliary loss.
"""
if router_probs is None:
return 0
num_experts = router_probs.shape[-1]
# cast the expert indices to int64, otherwise one-hot encoding will fail
@ -699,7 +701,9 @@ class NllbMoeEncoderLayer(nn.Module):
if self.is_sparse:
hidden_states, router_states = self.ffn(hidden_states, attention_mask)
else:
hidden_states = self.ffn(hidden_states)
# router_states set to None to track which layers have None gradients.
hidden_states, router_states = self.ffn(hidden_states), None
hidden_states = self.ff_dropout(hidden_states)
hidden_states = residual + hidden_states
@ -830,7 +834,8 @@ class NllbMoeDecoderLayer(nn.Module):
if self.is_sparse:
hidden_states, router_states = self.ffn(hidden_states, attention_mask)
else:
hidden_states = self.ffn(hidden_states)
hidden_states, router_states = self.ffn(hidden_states), None
hidden_states = self.ff_dropout(hidden_states)
hidden_states = residual + hidden_states
@ -1730,7 +1735,7 @@ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel):
if output_router_logits:
encoder_router_logits = outputs[-1]
decoder_router_logits = outputs[5 if output_attentions else 3]
decoder_router_logits = outputs[3 if output_attentions else 4]
# Compute the router loss (z_loss + auxiliary loss) for each router in the encoder and decoder
encoder_router_logits, encoder_expert_indexes = self._unpack_router_logits(encoder_router_logits)
@ -1775,7 +1780,6 @@ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel):
decoder_router_logits=outputs.decoder_router_logits,
)
# Copied from transfomers.models.switch_transformers.SwitchTransformersForConditionalGeneration._unpack_router_logits
def _unpack_router_logits(self, router_outputs):
total_router_logits = []
total_expert_indexes = []
@ -1784,11 +1788,10 @@ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel):
router_logits, expert_indexes = router_output
total_router_logits.append(router_logits)
total_expert_indexes.append(expert_indexes)
if len(total_expert_indexes) > 0:
total_router_logits = torch.cat(total_router_logits, dim=1)
if len(total_expert_indexes) > 0:
torch.cat(total_expert_indexes, dim=1)
return torch.cat(total_router_logits, dim=1), torch.cat(total_expert_indexes, dim=1)
total_router_logits = torch.cat(total_router_logits, dim=1) if len(total_router_logits) > 0 else None
total_expert_indexes = torch.stack(total_expert_indexes, dim=1) if len(total_expert_indexes) > 0 else None
return total_router_logits, total_expert_indexes
# Copied from transfomers.models.switch_transformers.SwitchTransformersForConditionalGeneration.prepare_inputs_for_generation
def prepare_inputs_for_generation(

View File

@ -337,6 +337,16 @@ class NllbMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
model.generate(input_ids, attention_mask=attention_mask)
model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
def test_get_loss(self):
config, input_dict = self.model_tester.prepare_config_and_inputs()
input_dict["output_router_logits"] = True
input_dict["labels"] = input_dict["input_ids"]
model = NllbMoeForConditionalGeneration(config).eval().to(torch_device)
out = model(**input_dict)
self.assertIsNotNone(out.loss)
self.assertIsNotNone(model(**input_dict)["encoder_router_logits"][1])
self.assertIsNotNone(model(**input_dict)["decoder_router_logits"][0])
@require_torch
@require_sentencepiece