mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
more fixes to moonshine!
This commit is contained in:
parent
cfe62b6b95
commit
6eb5e53e75
@ -502,7 +502,7 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
|
||||
main_input_name = "input_values"
|
||||
_can_record_outputs = {
|
||||
"attentions": (MoonshineAttention, 1, "self_attn"),
|
||||
"hidden_states": (MoonshineDecoderLayer,),
|
||||
"hidden_states": (MoonshineEncoderLayer,),
|
||||
}
|
||||
|
||||
def __init__(self, config: MoonshineConfig):
|
||||
|
@ -537,7 +537,7 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
|
||||
main_input_name = "input_values"
|
||||
_can_record_outputs = {
|
||||
"attentions": (MoonshineAttention, 1, "self_attn"),
|
||||
"hidden_states": (MoonshineDecoderLayer,),
|
||||
"hidden_states": (MoonshineEncoderLayer,),
|
||||
}
|
||||
|
||||
def __init__(self, config: MoonshineConfig):
|
||||
|
@ -1035,9 +1035,15 @@ def check_model_inputs(func):
|
||||
|
||||
for name, module in self.named_modules():
|
||||
for key, specs in capture_tasks:
|
||||
if isinstance(module, specs[0]) and len(specs) > 2 and specs[2] in name:
|
||||
hook_fn = make_capture_fn(key, specs[1] if len(specs) > 1 else 0)
|
||||
hooks.append(register_hook_if_needed(module, hook_fn))
|
||||
if isinstance(module, specs[0]):
|
||||
if len(specs) > 2:
|
||||
if specs[2] in name:
|
||||
hook_fn = make_capture_fn(key, specs[1] if len(specs) > 1 else 0)
|
||||
hooks.append(register_hook_if_needed(module, hook_fn))
|
||||
else:
|
||||
hook_fn = make_capture_fn(key, specs[1] if len(specs) > 1 else 0)
|
||||
|
||||
hooks.append(register_hook_if_needed(module, hook_fn))
|
||||
|
||||
outputs = func(self, *args, **kwargs)
|
||||
for h in hooks:
|
||||
|
Loading…
Reference in New Issue
Block a user