more fixes to moonshine!

This commit is contained in:
Arthur 2025-07-03 11:12:29 +02:00
parent cfe62b6b95
commit 6eb5e53e75
3 changed files with 11 additions and 5 deletions

View File

@ -502,7 +502,7 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
main_input_name = "input_values"
_can_record_outputs = {
"attentions": (MoonshineAttention, 1, "self_attn"),
"hidden_states": (MoonshineDecoderLayer,),
"hidden_states": (MoonshineEncoderLayer,),
}
def __init__(self, config: MoonshineConfig):

View File

@ -537,7 +537,7 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
main_input_name = "input_values"
_can_record_outputs = {
"attentions": (MoonshineAttention, 1, "self_attn"),
"hidden_states": (MoonshineDecoderLayer,),
"hidden_states": (MoonshineEncoderLayer,),
}
def __init__(self, config: MoonshineConfig):

View File

@ -1035,9 +1035,15 @@ def check_model_inputs(func):
for name, module in self.named_modules():
for key, specs in capture_tasks:
if isinstance(module, specs[0]) and len(specs) > 2 and specs[2] in name:
hook_fn = make_capture_fn(key, specs[1] if len(specs) > 1 else 0)
hooks.append(register_hook_if_needed(module, hook_fn))
if isinstance(module, specs[0]):
if len(specs) > 2:
if specs[2] in name:
hook_fn = make_capture_fn(key, specs[1] if len(specs) > 1 else 0)
hooks.append(register_hook_if_needed(module, hook_fn))
else:
hook_fn = make_capture_fn(key, specs[1] if len(specs) > 1 else 0)
hooks.append(register_hook_if_needed(module, hook_fn))
outputs = func(self, *args, **kwargs)
for h in hooks: