From 1f559c676f891f46fef3802b4b18e7de14f91837 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 3 Jul 2025 15:54:14 +0200 Subject: [PATCH] nicer! --- src/transformers/utils/generic.py | 30 +++++++++++------------------- 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 61cf0f9e042..e84c04d11bc 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -980,8 +980,8 @@ class OutputRecorder: layer_name (Optional[str]): Name of the submodule to target (if needed), e.g., "transformer.layer.3.attn". """ - target_class: Type - index: Optional[int] = None + target_class: torch.nn.Module + index: Optional[int] = 0 layer_name: Optional[str] = None @@ -1038,11 +1038,7 @@ def check_model_inputs(func): ) for k in capture_flags } - print(recordable_keys, capture_flags) - recordable_keys["output_cross_attentions"] = recordable_keys.get("output_attentions", None) - recordable_keys["output_mask_decoder_attentions"] = recordable_keys.get("output_attentions", None) - if "output_vision_attentions" not in recordable_keys: - recordable_keys["output_vision_attentions"] = recordable_keys.get("output_attentions", None) + if any(recordable_keys.values()): capture_tasks = [] for key, layer_specs in capture_flags.items(): @@ -1055,24 +1051,20 @@ def check_model_inputs(func): for name, module in self.named_modules(): for key, specs in capture_tasks: - if isinstance(module, specs[0]): - if len(specs) > 2: - if specs[2] in name: - print(f"Attaching hook to {name}, {specs[0]}") - hook_fn = make_capture_fn(key, specs[1] if len(specs) > 1 else 0) - hooks.append(register_hook_if_needed(module, hook_fn)) - else: - print("Skipping hook for", name, specs[0], "because it does not match", specs[2]) + if not isinstance(specs, OutputRecorder): + specs = OutputRecorder(*specs) # by default hidden states is 0, attention weights is 1 + specs.index = specs.index if specs.index is not None else 0 if "hidden_states" in key else 1 + if isinstance(module, specs.target_class): + if specs.layer_name is not None and specs.layer_name in name: + hook_fn = make_capture_fn(key, specs.index) + hooks.append(register_hook_if_needed(module, hook_fn)) else: - print(f"Attaching hook to {name}, {specs[0]}, key: {key}") - hook_fn = make_capture_fn(key, specs[1] if len(specs) > 1 else 0) - + hook_fn = make_capture_fn(key, specs.index) hooks.append(register_hook_if_needed(module, hook_fn)) outputs = func(self, *args, **kwargs) for h in hooks: if h is not None: h.remove() - print(collected_outputs.keys()) for key in collected_outputs: if key == "hidden_states": if hasattr(outputs, "vision_hidden_states"):