mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
finally
This commit is contained in:
parent
124cd82968
commit
4a14287a60
@ -1011,10 +1011,9 @@ def check_model_inputs(func):
|
|||||||
|
|
||||||
def make_capture_fn(key, index):
|
def make_capture_fn(key, index):
|
||||||
def capture_fn(module, input, output):
|
def capture_fn(module, input, output):
|
||||||
if len(output) == 0:
|
if not isinstance(output, tuple):
|
||||||
collected_outputs[key] += (output,)
|
collected_outputs[key] += (output,)
|
||||||
elif output[index] is not None:
|
elif output[index] is not None:
|
||||||
print(key, module.__class__, output[index].shape, output[0].shape)
|
|
||||||
collected_outputs[key] += (output[index],)
|
collected_outputs[key] += (output[index],)
|
||||||
|
|
||||||
return capture_fn
|
return capture_fn
|
||||||
|
Loading…
Reference in New Issue
Block a user