mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
fix cross attention outputs!
This commit is contained in:
parent
6eb5e53e75
commit
a9690f43fd
@ -1022,7 +1022,8 @@ def check_model_inputs(func):
|
|||||||
f"output_{k}": all_args.get(f"output_{k}", getattr(self.config, f"output_{k}", False))
|
f"output_{k}": all_args.get(f"output_{k}", getattr(self.config, f"output_{k}", False))
|
||||||
for k in capture_flags
|
for k in capture_flags
|
||||||
}
|
}
|
||||||
|
recordable_keys["output_cross_attentions"] = recordable_keys.get("output_attentions", None)
|
||||||
|
print(recordable_keys)
|
||||||
if any(recordable_keys.values()):
|
if any(recordable_keys.values()):
|
||||||
capture_tasks = []
|
capture_tasks = []
|
||||||
for key, layer_specs in capture_flags.items():
|
for key, layer_specs in capture_flags.items():
|
||||||
@ -1038,13 +1039,13 @@ def check_model_inputs(func):
|
|||||||
if isinstance(module, specs[0]):
|
if isinstance(module, specs[0]):
|
||||||
if len(specs) > 2:
|
if len(specs) > 2:
|
||||||
if specs[2] in name:
|
if specs[2] in name:
|
||||||
|
print(f"Attaching hook to {name}, {specs[0]}, {module}")
|
||||||
hook_fn = make_capture_fn(key, specs[1] if len(specs) > 1 else 0)
|
hook_fn = make_capture_fn(key, specs[1] if len(specs) > 1 else 0)
|
||||||
hooks.append(register_hook_if_needed(module, hook_fn))
|
hooks.append(register_hook_if_needed(module, hook_fn))
|
||||||
else:
|
else:
|
||||||
hook_fn = make_capture_fn(key, specs[1] if len(specs) > 1 else 0)
|
hook_fn = make_capture_fn(key, specs[1] if len(specs) > 1 else 0)
|
||||||
|
|
||||||
hooks.append(register_hook_if_needed(module, hook_fn))
|
hooks.append(register_hook_if_needed(module, hook_fn))
|
||||||
|
|
||||||
outputs = func(self, *args, **kwargs)
|
outputs = func(self, *args, **kwargs)
|
||||||
for h in hooks:
|
for h in hooks:
|
||||||
if h is not None:
|
if h is not None:
|
||||||
@ -1058,7 +1059,7 @@ def check_model_inputs(func):
|
|||||||
outputs[key] = collected_outputs[key]
|
outputs[key] = collected_outputs[key]
|
||||||
elif key == "attentions":
|
elif key == "attentions":
|
||||||
if isinstance(capture_flags[key], list) and len(capture_flags[key]) == 2:
|
if isinstance(capture_flags[key], list) and len(capture_flags[key]) == 2:
|
||||||
# we have cross attention states
|
# we have cross attention states returned in the same buffer
|
||||||
outputs[key] = collected_outputs[key][0::2]
|
outputs[key] = collected_outputs[key][0::2]
|
||||||
outputs["cross_" + key] = collected_outputs[key][1::2]
|
outputs["cross_" + key] = collected_outputs[key][1::2]
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user