This commit is contained in:
Arthur 2025-06-30 14:38:25 +02:00
parent 124cd82968
commit 4a14287a60

View File

@ -1011,10 +1011,9 @@ def check_model_inputs(func):
def make_capture_fn(key, index):
def capture_fn(module, input, output):
if len(output) == 0:
if not isinstance(output, tuple):
collected_outputs[key] += (output,)
elif output[index] is not None:
print(key, module.__class__, output[index].shape, output[0].shape)
collected_outputs[key] += (output[index],)
return capture_fn