This commit is contained in:
Arthur 2025-06-30 14:36:56 +02:00
parent 98739ba418
commit 124cd82968
2 changed files with 6 additions and 4 deletions

View File

@ -283,9 +283,8 @@ class LlamaDecoderLayer(GradientCheckpointingLayer):
) -> tuple[torch.Tensor]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states = self.self_attn(
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
@ -294,7 +293,7 @@ class LlamaDecoderLayer(GradientCheckpointingLayer):
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)[0]
)
hidden_states = residual + hidden_states
# Fully Connected

View File

@ -1011,7 +1011,10 @@ def check_model_inputs(func):
def make_capture_fn(key, index):
    """Build a forward hook that records a module's output under *key*.

    The returned hook appends to the enclosing scope's ``collected_outputs``
    mapping: empty-tuple outputs are stored whole (so the hook never indexes
    past the end), otherwise only the element at *index* is kept, and only
    when it is not ``None``.

    Args:
        key: Dictionary key in ``collected_outputs`` to accumulate under.
        index: Position within the module's output tuple to capture.

    Returns:
        A callable with the ``(module, input, output)`` forward-hook signature.
    """
    def capture_fn(module, input, output):
        # Guard first: some modules return an empty tuple, which would make
        # output[index] raise IndexError.
        if len(output) == 0:
            collected_outputs[key] += (output,)
        elif output[index] is not None:
            collected_outputs[key] += (output[index],)
    return capture_fn