Mamba & RecurrentGemma: enable strict signature (#31549)

* enable strict signature

* this should not have been deleted

* recurrent_gemma too
Joao Gante 2024-07-08 15:48:32 +01:00 committed by GitHub
parent ae9dd02ee1
commit 594c1610fa
3 changed files with 27 additions and 40 deletions
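
In short: rather than always passing `output_attentions` and `output_hidden_states` to every model call, the generation loops now merge these flags into `model_inputs` only when they are truthy. Models whose `forward` has a strict signature (Mamba, RecurrentGemma) therefore no longer need a `**kwargs` catch-all. A minimal sketch of the pattern, using an illustrative `DummyModel` that is not part of the library:

```python
# Minimal sketch of the "variable output controls" pattern (DummyModel is
# illustrative only; it stands in for a model whose forward(), like Mamba's,
# deliberately does not accept `output_attentions`).
class DummyModel:
    def forward(self, input_ids, output_hidden_states=False, return_dict=True):
        # Strict signature: an unexpected keyword raises TypeError instead of
        # being silently swallowed by **kwargs.
        return {"hidden_states": ["h0"] if output_hidden_states else []}

    __call__ = forward


model = DummyModel()
model_inputs = {"input_ids": [1, 2, 3]}
output_attentions = False    # the user did not request attentions
output_hidden_states = True  # the user did request hidden states

# Only truthy controls are merged in, so `output_attentions` never reaches a
# model that cannot accept it.
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

outputs = model(**model_inputs, return_dict=True)
print(outputs)  # {'hidden_states': ['h0']}
```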

src/transformers/generation/utils.py

@@ -2692,13 +2692,12 @@ class GenerationMixin:
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

+            # prepare variable output controls (note: some models won't accept all output controls)
+            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+
             # forward pass to get next token
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
+            outputs = self(**model_inputs, return_dict=True)

             if synced_gpus and this_peer_finished:
                 continue  # don't waste resources running the code we don't need
@@ -2919,6 +2918,10 @@ class GenerationMixin:
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

+            # prepare variable output controls (note: some models won't accept all output controls)
+            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+
             # if sequential is True, split the input to batches of batch_size and run sequentially
             if sequential:
                 if any(
@@ -2944,24 +2947,13 @@ class GenerationMixin:
                     model_inputs, split_size=batch_size, full_batch_size=batch_beam_size
                 )
                 outputs_per_sub_batch = [
-                    self(
-                        **inputs_per_sub_batch,
-                        return_dict=True,
-                        output_attentions=output_attentions,
-                        output_hidden_states=output_hidden_states,
-                    )
-                    for inputs_per_sub_batch in inputs_per_sub_batches
+                    self(**inputs_per_sub_batch, return_dict=True) for inputs_per_sub_batch in inputs_per_sub_batches
                 ]

                 outputs = stack_model_outputs(outputs_per_sub_batch)

             else:  # Unchanged original behavior
-                outputs = self(
-                    **model_inputs,
-                    return_dict=True,
-                    output_attentions=output_attentions,
-                    output_hidden_states=output_hidden_states,
-                )
+                outputs = self(**model_inputs, return_dict=True)

             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
@@ -3241,12 +3233,12 @@ class GenerationMixin:
             # do one decoder step on all beams of all sentences in batch
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
+
+            # prepare variable output controls (note: some models won't accept all output controls)
+            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+
+            outputs = self(**model_inputs, return_dict=True)

             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
@@ -3522,12 +3514,11 @@ class GenerationMixin:
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
+            # prepare variable output controls (note: some models won't accept all output controls)
+            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+
+            outputs = self(**model_inputs, return_dict=True)

             if synced_gpus and this_peer_finished:
                 cur_len = cur_len + 1
@@ -3793,11 +3784,11 @@ class GenerationMixin:
             model_inputs["num_logits_to_keep"] = candidate_length + 1

             # 2.2. Run a forward pass on the candidate sequence
-            outputs = self(
-                **model_inputs,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
+            # prepare variable output controls (note: some models won't accept all output controls)
+            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+
+            outputs = self(**model_inputs)

             # 2.3. Process the new logits
             new_logits = outputs.logits[:, -candidate_length - 1 :]  # excludes the input prompt if present
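
With the loops above in place, a strict-signature model only ever receives the output controls the caller actually asked for. A hedged end-to-end sketch (the checkpoint name is an assumption, a small publicly available Mamba checkpoint; any `MambaForCausalLM` checkpoint should behave the same way):

```python
from transformers import AutoTokenizer, MambaForCausalLM

checkpoint = "state-spaces/mamba-130m-hf"  # assumed small Mamba checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = MambaForCausalLM.from_pretrained(checkpoint)

input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"]

# No output controls requested -> nothing extra is merged into model_inputs,
# so Mamba's strict forward() is never handed an argument it doesn't accept.
out = model.generate(input_ids, max_new_tokens=10)

# Hidden states requested -> only `output_hidden_states` is merged in
# (Mamba's forward accepts it); `output_attentions` is still left out.
out = model.generate(
    input_ids,
    max_new_tokens=10,
    output_hidden_states=True,
    return_dict_in_generate=True,
)
print(tokenizer.decode(out.sequences[0], skip_special_tokens=True))
```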

src/transformers/models/mamba/modeling_mamba.py

@@ -545,7 +545,6 @@ class MambaModel(MambaPreTrainedModel):
         use_cache: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        **kwargs,  # `attention_mask` is passed by the tokenizer and we don't want it
     ) -> Union[Tuple, MambaOutput]:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -673,7 +672,6 @@ class MambaForCausalLM(MambaPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         use_cache: Optional[bool] = None,
-        **kwargs,  # for now we need this for generation
     ) -> Union[Tuple, MambaCausalLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
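
What a strict signature buys you in practice: without the `**kwargs` catch-all, an unsupported or misspelled argument fails loudly at the call site instead of being silently ignored. A toy illustration in plain Python (no transformers dependency; the function names are made up for the example):

```python
def lenient_forward(input_ids, output_hidden_states=None, **kwargs):
    # Old style: a misspelled `output_attentionz` is silently swallowed.
    return input_ids


def strict_forward(input_ids, output_hidden_states=None):
    # New style: unknown keyword arguments are rejected immediately.
    return input_ids


lenient_forward([1, 2, 3], output_attentionz=True)  # no error, the typo goes unnoticed

try:
    strict_forward([1, 2, 3], output_attentionz=True)
except TypeError as err:
    print(err)  # strict_forward() got an unexpected keyword argument 'output_attentionz'
```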

src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py

@@ -684,7 +684,6 @@ class RecurrentGemmaModel(RecurrentGemmaPreTrainedModel):
         use_cache: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        **kwargs,
     ) -> Union[Tuple, BaseModelOutputWithNoAttention]:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -823,7 +822,6 @@ class RecurrentGemmaForCausalLM(RecurrentGemmaPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         use_cache: Optional[bool] = None,
-        **kwargs,  # for now we need this for generation
     ) -> Union[Tuple, CausalLMOutput]:
         r"""
         Args:
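
If you want to check from user code which output controls a given model's `forward` actually accepts, signature introspection is one way to do it (illustrative only; `generate()` itself does not introspect, it simply avoids passing controls the caller did not request, as in the diff above):

```python
import inspect

from transformers import MambaForCausalLM, RecurrentGemmaForCausalLM

for cls in (MambaForCausalLM, RecurrentGemmaForCausalLM):
    params = inspect.signature(cls.forward).parameters
    print(
        cls.__name__,
        "output_attentions" in params,     # False: neither forward returns attention maps
        "output_hidden_states" in params,  # True for both
    )
```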