This commit is contained in:
Arthur 2025-07-03 15:54:14 +02:00
parent fbfaf040ee
commit 1f559c676f

View File

@ -980,8 +980,8 @@ class OutputRecorder:
layer_name (Optional[str]): Name of the submodule to target (if needed), e.g., "transformer.layer.3.attn". layer_name (Optional[str]): Name of the submodule to target (if needed), e.g., "transformer.layer.3.attn".
""" """
target_class: Type target_class: torch.nn.Module
index: Optional[int] = None index: Optional[int] = 0
layer_name: Optional[str] = None layer_name: Optional[str] = None
@ -1038,11 +1038,7 @@ def check_model_inputs(func):
) )
for k in capture_flags for k in capture_flags
} }
print(recordable_keys, capture_flags)
recordable_keys["output_cross_attentions"] = recordable_keys.get("output_attentions", None)
recordable_keys["output_mask_decoder_attentions"] = recordable_keys.get("output_attentions", None)
if "output_vision_attentions" not in recordable_keys:
recordable_keys["output_vision_attentions"] = recordable_keys.get("output_attentions", None)
if any(recordable_keys.values()): if any(recordable_keys.values()):
capture_tasks = [] capture_tasks = []
for key, layer_specs in capture_flags.items(): for key, layer_specs in capture_flags.items():
@ -1055,24 +1051,20 @@ def check_model_inputs(func):
for name, module in self.named_modules(): for name, module in self.named_modules():
for key, specs in capture_tasks: for key, specs in capture_tasks:
if isinstance(module, specs[0]): if not isinstance(specs, OutputRecorder):
if len(specs) > 2: specs = OutputRecorder(*specs) # by default hidden states is 0, attention weights is 1
if specs[2] in name: specs.index = specs.index if specs.index is not None else 0 if "hidden_states" in key else 1
print(f"Attaching hook to {name}, {specs[0]}") if isinstance(module, specs.target_class):
hook_fn = make_capture_fn(key, specs[1] if len(specs) > 1 else 0) if specs.layer_name is not None and specs.layer_name in name:
hooks.append(register_hook_if_needed(module, hook_fn)) hook_fn = make_capture_fn(key, specs.index)
else: hooks.append(register_hook_if_needed(module, hook_fn))
print("Skipping hook for", name, specs[0], "because it does not match", specs[2])
else: else:
print(f"Attaching hook to {name}, {specs[0]}, key: {key}") hook_fn = make_capture_fn(key, specs.index)
hook_fn = make_capture_fn(key, specs[1] if len(specs) > 1 else 0)
hooks.append(register_hook_if_needed(module, hook_fn)) hooks.append(register_hook_if_needed(module, hook_fn))
outputs = func(self, *args, **kwargs) outputs = func(self, *args, **kwargs)
for h in hooks: for h in hooks:
if h is not None: if h is not None:
h.remove() h.remove()
print(collected_outputs.keys())
for key in collected_outputs: for key in collected_outputs:
if key == "hidden_states": if key == "hidden_states":
if hasattr(outputs, "vision_hidden_states"): if hasattr(outputs, "vision_hidden_states"):