more fixes to moonshine!

This commit is contained in:
Arthur 2025-07-03 11:12:29 +02:00
parent cfe62b6b95
commit 6eb5e53e75
3 changed files with 11 additions and 5 deletions

View File

@ -502,7 +502,7 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
main_input_name = "input_values" main_input_name = "input_values"
_can_record_outputs = { _can_record_outputs = {
"attentions": (MoonshineAttention, 1, "self_attn"), "attentions": (MoonshineAttention, 1, "self_attn"),
"hidden_states": (MoonshineDecoderLayer,), "hidden_states": (MoonshineEncoderLayer,),
} }
def __init__(self, config: MoonshineConfig): def __init__(self, config: MoonshineConfig):

View File

@ -537,7 +537,7 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
main_input_name = "input_values" main_input_name = "input_values"
_can_record_outputs = { _can_record_outputs = {
"attentions": (MoonshineAttention, 1, "self_attn"), "attentions": (MoonshineAttention, 1, "self_attn"),
"hidden_states": (MoonshineDecoderLayer,), "hidden_states": (MoonshineEncoderLayer,),
} }
def __init__(self, config: MoonshineConfig): def __init__(self, config: MoonshineConfig):

View File

@ -1035,9 +1035,15 @@ def check_model_inputs(func):
for name, module in self.named_modules(): for name, module in self.named_modules():
for key, specs in capture_tasks: for key, specs in capture_tasks:
if isinstance(module, specs[0]) and len(specs) > 2 and specs[2] in name: if isinstance(module, specs[0]):
hook_fn = make_capture_fn(key, specs[1] if len(specs) > 1 else 0) if len(specs) > 2:
hooks.append(register_hook_if_needed(module, hook_fn)) if specs[2] in name:
hook_fn = make_capture_fn(key, specs[1] if len(specs) > 1 else 0)
hooks.append(register_hook_if_needed(module, hook_fn))
else:
hook_fn = make_capture_fn(key, specs[1] if len(specs) > 1 else 0)
hooks.append(register_hook_if_needed(module, hook_fn))
outputs = func(self, *args, **kwargs) outputs = func(self, *args, **kwargs)
for h in hooks: for h in hooks: