fix cross attention outputs!

This commit is contained in:
Arthur 2025-07-03 11:32:48 +02:00
parent 6eb5e53e75
commit a9690f43fd

View File

@ -1022,7 +1022,8 @@ def check_model_inputs(func):
f"output_{k}": all_args.get(f"output_{k}", getattr(self.config, f"output_{k}", False)) f"output_{k}": all_args.get(f"output_{k}", getattr(self.config, f"output_{k}", False))
for k in capture_flags for k in capture_flags
} }
recordable_keys["output_cross_attentions"] = recordable_keys.get("output_attentions", None)
print(recordable_keys)
if any(recordable_keys.values()): if any(recordable_keys.values()):
capture_tasks = [] capture_tasks = []
for key, layer_specs in capture_flags.items(): for key, layer_specs in capture_flags.items():
@ -1038,13 +1039,13 @@ def check_model_inputs(func):
if isinstance(module, specs[0]): if isinstance(module, specs[0]):
if len(specs) > 2: if len(specs) > 2:
if specs[2] in name: if specs[2] in name:
print(f"Attaching hook to {name}, {specs[0]}, {module}")
hook_fn = make_capture_fn(key, specs[1] if len(specs) > 1 else 0) hook_fn = make_capture_fn(key, specs[1] if len(specs) > 1 else 0)
hooks.append(register_hook_if_needed(module, hook_fn)) hooks.append(register_hook_if_needed(module, hook_fn))
else: else:
hook_fn = make_capture_fn(key, specs[1] if len(specs) > 1 else 0) hook_fn = make_capture_fn(key, specs[1] if len(specs) > 1 else 0)
hooks.append(register_hook_if_needed(module, hook_fn)) hooks.append(register_hook_if_needed(module, hook_fn))
outputs = func(self, *args, **kwargs) outputs = func(self, *args, **kwargs)
for h in hooks: for h in hooks:
if h is not None: if h is not None:
@ -1058,7 +1059,7 @@ def check_model_inputs(func):
outputs[key] = collected_outputs[key] outputs[key] = collected_outputs[key]
elif key == "attentions": elif key == "attentions":
if isinstance(capture_flags[key], list) and len(capture_flags[key]) == 2: if isinstance(capture_flags[key], list) and len(capture_flags[key]) == 2:
# we have cross attention states # we have cross attention states return in the same buffer
outputs[key] = collected_outputs[key][0::2] outputs[key] = collected_outputs[key][0::2]
outputs["cross_" + key] = collected_outputs[key][1::2] outputs["cross_" + key] = collected_outputs[key][1::2]
else: else: