mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
more fixes to moonshine!
This commit is contained in:
parent
cfe62b6b95
commit
6eb5e53e75
@ -502,7 +502,7 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
|
|||||||
main_input_name = "input_values"
|
main_input_name = "input_values"
|
||||||
_can_record_outputs = {
|
_can_record_outputs = {
|
||||||
"attentions": (MoonshineAttention, 1, "self_attn"),
|
"attentions": (MoonshineAttention, 1, "self_attn"),
|
||||||
"hidden_states": (MoonshineDecoderLayer,),
|
"hidden_states": (MoonshineEncoderLayer,),
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, config: MoonshineConfig):
|
def __init__(self, config: MoonshineConfig):
|
||||||
|
@ -537,7 +537,7 @@ class MoonshineEncoder(MoonshinePreTrainedModel):
|
|||||||
main_input_name = "input_values"
|
main_input_name = "input_values"
|
||||||
_can_record_outputs = {
|
_can_record_outputs = {
|
||||||
"attentions": (MoonshineAttention, 1, "self_attn"),
|
"attentions": (MoonshineAttention, 1, "self_attn"),
|
||||||
"hidden_states": (MoonshineDecoderLayer,),
|
"hidden_states": (MoonshineEncoderLayer,),
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, config: MoonshineConfig):
|
def __init__(self, config: MoonshineConfig):
|
||||||
|
@ -1035,9 +1035,15 @@ def check_model_inputs(func):
|
|||||||
|
|
||||||
for name, module in self.named_modules():
|
for name, module in self.named_modules():
|
||||||
for key, specs in capture_tasks:
|
for key, specs in capture_tasks:
|
||||||
if isinstance(module, specs[0]) and len(specs) > 2 and specs[2] in name:
|
if isinstance(module, specs[0]):
|
||||||
hook_fn = make_capture_fn(key, specs[1] if len(specs) > 1 else 0)
|
if len(specs) > 2:
|
||||||
hooks.append(register_hook_if_needed(module, hook_fn))
|
if specs[2] in name:
|
||||||
|
hook_fn = make_capture_fn(key, specs[1] if len(specs) > 1 else 0)
|
||||||
|
hooks.append(register_hook_if_needed(module, hook_fn))
|
||||||
|
else:
|
||||||
|
hook_fn = make_capture_fn(key, specs[1] if len(specs) > 1 else 0)
|
||||||
|
|
||||||
|
hooks.append(register_hook_if_needed(module, hook_fn))
|
||||||
|
|
||||||
outputs = func(self, *args, **kwargs)
|
outputs = func(self, *args, **kwargs)
|
||||||
for h in hooks:
|
for h in hooks:
|
||||||
|
Loading…
Reference in New Issue
Block a user