From 2ce56d35f6054cd844980ed4265ca3289bb56e0d Mon Sep 17 00:00:00 2001
From: Leonardo Emili
Date: Wed, 28 Feb 2024 11:16:15 +0100
Subject: [PATCH] Disable Mixtral `output_router_logits` during inference
 (#29249)

* Set output_router_logits=False in prepare_inputs_for_generation for mixtral

* Add output_router_logits=False to prepare_inputs_for_generation for mixtral

* Fix style

---
 src/transformers/models/mixtral/modeling_mixtral.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index 674ace5f236..01ea7282d78 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -1415,7 +1415,13 @@ class MixtralForCausalLM(MixtralPreTrainedModel):
     )
 
     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        output_router_logits=False,
+        **kwargs,
     ):
         # Omit tokens covered by past_key_values
         if past_key_values is not None:
@@ -1467,6 +1473,7 @@ class MixtralForCausalLM(MixtralPreTrainedModel):
                 "past_key_values": past_key_values,
                 "use_cache": kwargs.get("use_cache"),
                 "attention_mask": attention_mask,
+                "output_router_logits": output_router_logits,
             }
         )
         return model_inputs
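
For context, a minimal sketch of the behavior this patch addresses (the checkpoint name and generation arguments below are illustrative, not taken from the patch). When a Mixtral config enables output_router_logits for training, so the auxiliary load-balancing loss can be computed from the router logits, generate() previously inherited that flag from the config and computed and accumulated router logits at every decoding step. With this change, prepare_inputs_for_generation passes output_router_logits=False explicitly, so generation skips that extra work.

    # Illustrative sketch only; assumes the transformers API as of early 2024.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "mistralai/Mixtral-8x7B-v0.1"  # illustrative checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Training setups often enable router logits so the auxiliary
    # load-balancing loss can be computed from them.
    model = AutoModelForCausalLM.from_pretrained(model_id, output_router_logits=True)

    inputs = tokenizer("The capital of France is", return_tensors="pt")

    # Before this patch, generate() inherited output_router_logits=True from
    # the config, so every decoding step also stacked router logits. After the
    # patch, prepare_inputs_for_generation() forces output_router_logits=False
    # during generation, avoiding the unnecessary compute and memory.
    output_ids = model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

Note the design choice: the flag is a keyword argument defaulting to False rather than a hard-coded value, so callers that do need router logits (e.g. a manual forward pass for loss computation) are unaffected; only the inputs prepared for generate() are overridden.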