Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 18:51:14 +06:00)

Commit 1f559c676f ("nicer!")
Parent: fbfaf040ee
@@ -980,8 +980,8 @@ class OutputRecorder:
         layer_name (Optional[str]): Name of the submodule to target (if needed), e.g., "transformer.layer.3.attn".
     """
 
-    target_class: Type
-    index: Optional[int] = None
+    target_class: torch.nn.Module
+    index: Optional[int] = 0
     layer_name: Optional[str] = None
 
 
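For reference, the two changed fields belong to a dataclass whose declaration sits outside this hunk. A minimal standalone sketch of the recorder as it reads after this commit, with the surrounding @dataclass declaration assumed rather than shown in the diff:

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class OutputRecorder:
    """Sketch of the recorder this hunk edits (docstring abridged)."""

    # Annotation copied as written in the commit; the hook loop below stores a
    # module *class* here, e.g. torch.nn.MultiheadAttention.
    target_class: torch.nn.Module
    index: Optional[int] = 0  # position of the captured tensor in the module's output tuple
    layer_name: Optional[str] = None  # e.g. "transformer.layer.3.attn"


# Hypothetical usage: capture element 1 (attention weights) of every
# MultiheadAttention submodule whose qualified name contains "attn".
recorder = OutputRecorder(torch.nn.MultiheadAttention, index=1, layer_name="attn")

Changing the index default from None to 0 means the None fallback in the hook loop below (0 for hidden states, 1 otherwise) now only fires when a caller passes index=None explicitly.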
@@ -1038,11 +1038,7 @@ def check_model_inputs(func):
             )
             for k in capture_flags
         }
-        print(recordable_keys, capture_flags)
-        recordable_keys["output_cross_attentions"] = recordable_keys.get("output_attentions", None)
-        recordable_keys["output_mask_decoder_attentions"] = recordable_keys.get("output_attentions", None)
-        if "output_vision_attentions" not in recordable_keys:
-            recordable_keys["output_vision_attentions"] = recordable_keys.get("output_attentions", None)
+
         if any(recordable_keys.values()):
            capture_tasks = []
            for key, layer_specs in capture_flags.items():
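The closing lines above belong to a dict comprehension that resolves each output_* flag from the call arguments with a fallback to the model config; the removed lines had force-aliased several modality-specific attention flags to output_attentions and left a debug print behind. A standalone sketch of that resolution pattern, with the comprehension head (outside this hunk) and the names all_args and config assumed:

# Minimal sketch, assuming capture_flags maps names like "attentions" to layer
# specs and all_args holds the merged positional/keyword call arguments.
capture_flags = {"attentions": [], "hidden_states": []}
all_args = {"output_attentions": True}


class DummyConfig:
    output_hidden_states = False


config = DummyConfig()

recordable_keys = {
    f"output_{k}": all_args.get(f"output_{k}", getattr(config, f"output_{k}", False))
    for k in capture_flags
}
print(recordable_keys)  # {'output_attentions': True, 'output_hidden_states': False}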
@@ -1055,24 +1051,20 @@ def check_model_inputs(func):
 
             for name, module in self.named_modules():
                 for key, specs in capture_tasks:
-                    if isinstance(module, specs[0]):
-                        if len(specs) > 2:
-                            if specs[2] in name:
-                                print(f"Attaching hook to {name}, {specs[0]}")
-                                hook_fn = make_capture_fn(key, specs[1] if len(specs) > 1 else 0)
-                                hooks.append(register_hook_if_needed(module, hook_fn))
-                            else:
-                                print("Skipping hook for", name, specs[0], "because it does not match", specs[2])
+                    if not isinstance(specs, OutputRecorder):
+                        specs = OutputRecorder(*specs)  # by default hidden states is 0, attention weights is 1
+                        specs.index = specs.index if specs.index is not None else 0 if "hidden_states" in key else 1
+                    if isinstance(module, specs.target_class):
+                        if specs.layer_name is not None and specs.layer_name in name:
+                            hook_fn = make_capture_fn(key, specs.index)
+                            hooks.append(register_hook_if_needed(module, hook_fn))
                         else:
-                            print(f"Attaching hook to {name}, {specs[0]}, key: {key}")
-                            hook_fn = make_capture_fn(key, specs[1] if len(specs) > 1 else 0)
-
+                            hook_fn = make_capture_fn(key, specs.index)
                             hooks.append(register_hook_if_needed(module, hook_fn))
         outputs = func(self, *args, **kwargs)
         for h in hooks:
             if h is not None:
                 h.remove()
-        print(collected_outputs.keys())
         for key in collected_outputs:
             if key == "hidden_states":
                 if hasattr(outputs, "vision_hidden_states"):
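make_capture_fn and register_hook_if_needed are called above but defined outside this diff; a self-contained sketch of the capture pattern they imply, built on torch.nn.Module.register_forward_hook (both bodies here are assumptions, not the library's code):

from collections import defaultdict

import torch

collected_outputs = defaultdict(list)


def make_capture_fn(key, index):
    """Build a forward hook that stores element `index` of a module's output under `key`."""

    def hook(module, args, output):
        # Submodules may return a bare tensor or a tuple; `index` selects from tuples.
        value = output[index] if isinstance(output, tuple) else output
        collected_outputs[key].append(value)

    return hook


def register_hook_if_needed(module, hook_fn):
    """Attach the hook and hand back the handle so the wrapper can remove it later."""
    return module.register_forward_hook(hook_fn)


# Hypothetical usage mirroring the loop above: attach, run a forward pass, detach.
layer = torch.nn.Linear(4, 4)
handle = register_hook_if_needed(layer, make_capture_fn("hidden_states", 0))
layer(torch.randn(2, 4))
handle.remove()  # matches the h.remove() cleanup after func(self, *args, **kwargs)
print(len(collected_outputs["hidden_states"]))  # 1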