Mirror of https://github.com/huggingface/transformers.git
Disable Mixtral output_router_logits during inference (#29249)
* Set output_router_logits=False in prepare_inputs_for_generation for mixtral
* Add output_router_logits=False to prepare_inputs_for_generation for mixtral
* Fix style
parent 8a8a0a4ae0
commit 2ce56d35f6
src/transformers/models/mixtral/modeling_mixtral.py

```diff
@@ -1415,7 +1415,13 @@ class MixtralForCausalLM(MixtralPreTrainedModel):
         )
 
     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        output_router_logits=False,
+        **kwargs,
     ):
         # Omit tokens covered by past_key_values
         if past_key_values is not None:
@@ -1467,6 +1473,7 @@ class MixtralForCausalLM(MixtralPreTrainedModel):
                 "past_key_values": past_key_values,
                 "use_cache": kwargs.get("use_cache"),
                 "attention_mask": attention_mask,
+                "output_router_logits": output_router_logits,
             }
         )
         return model_inputs
```
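For context, here is a minimal sketch of what this change affects at generation time, assuming the standard `transformers` generation API; the checkpoint name is illustrative:

```python
# Minimal sketch (illustrative checkpoint name): with this patch,
# prepare_inputs_for_generation passes output_router_logits=False on every
# decoding step, so generate() skips computing router logits it never uses,
# even when the config enables them for training.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mixtral-8x7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Training-time setting: router logits feed the load-balancing auxiliary loss.
model.config.output_router_logits = True

inputs = tokenizer("Hello", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```

Router logits only matter for the auxiliary load-balancing loss during training, so defaulting them off in `prepare_inputs_for_generation` avoids computing and accumulating per-step, per-expert logits during decoding.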