This commit is contained in:
Arthur 2025-07-03 15:54:14 +02:00
parent fbfaf040ee
commit 1f559c676f

View File

@ -980,8 +980,8 @@ class OutputRecorder:
layer_name (Optional[str]): Name of the submodule to target (if needed), e.g., "transformer.layer.3.attn".
"""
target_class: Type
index: Optional[int] = None
target_class: torch.nn.Module
index: Optional[int] = 0
layer_name: Optional[str] = None
@ -1038,11 +1038,7 @@ def check_model_inputs(func):
)
for k in capture_flags
}
print(recordable_keys, capture_flags)
recordable_keys["output_cross_attentions"] = recordable_keys.get("output_attentions", None)
recordable_keys["output_mask_decoder_attentions"] = recordable_keys.get("output_attentions", None)
if "output_vision_attentions" not in recordable_keys:
recordable_keys["output_vision_attentions"] = recordable_keys.get("output_attentions", None)
if any(recordable_keys.values()):
capture_tasks = []
for key, layer_specs in capture_flags.items():
@ -1055,24 +1051,20 @@ def check_model_inputs(func):
for name, module in self.named_modules():
for key, specs in capture_tasks:
if isinstance(module, specs[0]):
if len(specs) > 2:
if specs[2] in name:
print(f"Attaching hook to {name}, {specs[0]}")
hook_fn = make_capture_fn(key, specs[1] if len(specs) > 1 else 0)
hooks.append(register_hook_if_needed(module, hook_fn))
else:
print("Skipping hook for", name, specs[0], "because it does not match", specs[2])
if not isinstance(specs, OutputRecorder):
specs = OutputRecorder(*specs) # by default hidden states is 0, attention weights is 1
specs.index = specs.index if specs.index is not None else 0 if "hidden_states" in key else 1
if isinstance(module, specs.target_class):
if specs.layer_name is not None and specs.layer_name in name:
hook_fn = make_capture_fn(key, specs.index)
hooks.append(register_hook_if_needed(module, hook_fn))
else:
print(f"Attaching hook to {name}, {specs[0]}, key: {key}")
hook_fn = make_capture_fn(key, specs[1] if len(specs) > 1 else 0)
hook_fn = make_capture_fn(key, specs.index)
hooks.append(register_hook_if_needed(module, hook_fn))
outputs = func(self, *args, **kwargs)
for h in hooks:
if h is not None:
h.remove()
print(collected_outputs.keys())
for key in collected_outputs:
if key == "hidden_states":
if hasattr(outputs, "vision_hidden_states"):