This commit is contained in:
Arthur 2025-06-30 14:38:25 +02:00
parent 124cd82968
commit 4a14287a60

View File

@ -1011,10 +1011,9 @@ def check_model_inputs(func):
def make_capture_fn(key, index): def make_capture_fn(key, index):
def capture_fn(module, input, output): def capture_fn(module, input, output):
if len(output) == 0: if not isinstance(output, tuple):
collected_outputs[key] += (output,) collected_outputs[key] += (output,)
elif output[index] is not None: elif output[index] is not None:
print(key, module.__class__, output[index].shape, output[0].shape)
collected_outputs[key] += (output[index],) collected_outputs[key] += (output[index],)
return capture_fn return capture_fn