diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py
index 495cc971bba..c646c2d5cd6 100644
--- a/src/transformers/utils/generic.py
+++ b/src/transformers/utils/generic.py
@@ -1022,7 +1022,8 @@ def check_model_inputs(func):
             f"output_{k}": all_args.get(f"output_{k}", getattr(self.config, f"output_{k}", False))
             for k in capture_flags
         }
-
+        recordable_keys["output_cross_attentions"] = recordable_keys.get("output_attentions", None)
+        print(recordable_keys)
         if any(recordable_keys.values()):
             capture_tasks = []
             for key, layer_specs in capture_flags.items():
@@ -1038,13 +1039,13 @@ def check_model_inputs(func):
                 if isinstance(module, specs[0]):
                     if len(specs) > 2:
                         if specs[2] in name:
+                            print(f"Attaching hook to {name}, {specs[0]}, {module}")
                             hook_fn = make_capture_fn(key, specs[1] if len(specs) > 1 else 0)
                             hooks.append(register_hook_if_needed(module, hook_fn))
                     else:
                         hook_fn = make_capture_fn(key, specs[1] if len(specs) > 1 else 0)
                         hooks.append(register_hook_if_needed(module, hook_fn))
-
         outputs = func(self, *args, **kwargs)
         for h in hooks:
             if h is not None:
@@ -1058,7 +1059,7 @@ def check_model_inputs(func):
                 outputs[key] = collected_outputs[key]
             elif key == "attentions":
                 if isinstance(capture_flags[key], list) and len(capture_flags[key]) == 2:
-                    # we have cross attention states
+                    # we have cross attention states return in the same buffer
                     outputs[key] = collected_outputs[key][0::2]
                     outputs["cross_" + key] = collected_outputs[key][1::2]
                 else:
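
For context on the `[0::2]` / `[1::2]` slicing in the last hunk: the forward hooks append self- and cross-attention outputs to one shared buffer in call order, so the decorator later de-interleaves them by even/odd index. The following is a minimal, self-contained sketch of that pattern, not the transformers implementation; the `Toy*` modules, `collected` buffer, and `capture_fn` names are hypothetical stand-ins for illustration.

```python
import torch
import torch.nn as nn


class ToySelfAttention(nn.Module):
    def forward(self, x):
        # Pretend these are self-attention probabilities of shape (batch, heads, q, k).
        return torch.full((1, 1, 2, 2), 0.0)


class ToyCrossAttention(nn.Module):
    def forward(self, x):
        # Pretend these are cross-attention probabilities.
        return torch.full((1, 1, 2, 2), 1.0)


class ToyDecoderLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.self_attn = ToySelfAttention()
        self.cross_attn = ToyCrossAttention()

    def forward(self, x):
        # Self-attention runs before cross-attention in each layer, so hooks
        # fire in alternating order: self, cross, self, cross, ...
        x = self.self_attn(x)
        x = self.cross_attn(x)
        return x


model = nn.Sequential(ToyDecoderLayer(), ToyDecoderLayer())

collected = []  # shared buffer, analogous to collected_outputs["attentions"]


def capture_fn(module, inputs, output):
    collected.append(output)


# Attach one hook per attention module, roughly what the name/isinstance
# matching loop in the diff does with register_hook_if_needed.
hooks = [
    m.register_forward_hook(capture_fn)
    for m in model.modules()
    if isinstance(m, (ToySelfAttention, ToyCrossAttention))
]

model(torch.zeros(1, 1, 2, 2))

for h in hooks:
    h.remove()

# De-interleave the shared buffer, mirroring the [0::2] / [1::2] split above.
self_attentions = collected[0::2]   # captured by ToySelfAttention
cross_attentions = collected[1::2]  # captured by ToyCrossAttention

print(len(self_attentions), len(cross_attentions))  # 2 2
```

The slicing only works if every layer invokes self-attention and cross-attention exactly once and in that order, which is the assumption the two-entry `capture_flags[key]` spec encodes.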