Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
fix

parent 98739ba418
commit 124cd82968
@@ -283,9 +283,8 @@ class LlamaDecoderLayer(GradientCheckpointingLayer):
     ) -> tuple[torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)

         # Self Attention
-        hidden_states = self.self_attn(
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -294,7 +293,7 @@ class LlamaDecoderLayer(GradientCheckpointingLayer):
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             **kwargs,
-        )[0]
+        )
         hidden_states = residual + hidden_states

         # Fully Connected
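In these hunks, the handling of the attention call's return value changes: instead of indexing the returned tuple with `[0]`, the new code unpacks it into `hidden_states, _`. Both forms keep only the first element; unpacking additionally fails loudly if the callee ever returns a different number of values. A minimal sketch of the equivalence, using a hypothetical `fake_self_attn` stub rather than the real `LlamaAttention` module:

    import torch

    # Hypothetical stand-in for an attention module that, like the one in
    # the diff, returns a (hidden_states, attn_weights) tuple.
    def fake_self_attn(hidden_states: torch.Tensor):
        attn_weights = None  # weights are typically None when not requested
        return hidden_states * 2.0, attn_weights

    x = torch.randn(1, 4, 8)  # (batch, seq_len, hidden_dim)

    # Old style from the hunk: index the returned tuple.
    out_indexed = fake_self_attn(x)[0]

    # New style from the hunk: unpack and discard the second element.
    # Unlike indexing, this raises if the return arity ever changes.
    out_unpacked, _ = fake_self_attn(x)

    assert torch.equal(out_indexed, out_unpacked)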
@@ -1011,7 +1011,10 @@ def check_model_inputs(func):

     def make_capture_fn(key, index):
         def capture_fn(module, input, output):
-            if output[index] is not None:
+            if len(output) == 0:
+                collected_outputs[key] += (output,)
+            elif output[index] is not None:
+                print(key, module.__class__, output[index].shape, output[0].shape)
                 collected_outputs[key] += (output[index],)

         return capture_fn
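The second hunk hardens the capture hook used by the `check_model_inputs` decorator: if a module returns an empty tuple, `output[index]` would raise an IndexError, so the new code collects the whole output in that case (it also adds a `print` of the captured shapes, which reads like a temporary debug aid). Below is a self-contained sketch of the same forward-hook pattern; the toy module, shapes, and `collected_outputs` wiring are illustrative assumptions, not the decorator's actual implementation:

    import torch
    from torch import nn

    collected_outputs = {"attentions": ()}

    def make_capture_fn(key, index):
        # Forward hook signature: hook(module, input, output), called after
        # each forward pass of the module it is registered on.
        def capture_fn(module, input, output):
            if len(output) == 0:
                collected_outputs[key] += (output,)  # nothing to index into
            elif output[index] is not None:
                collected_outputs[key] += (output[index],)
        return capture_fn

    # Toy attention module (an assumption for this sketch) that returns a
    # (output, attn_weights) tuple, mirroring what the hook indexes into.
    class ToyAttention(nn.Module):
        def forward(self, x):
            weights = torch.softmax(x @ x.transpose(-1, -2), dim=-1)
            return weights @ x, weights

    layer = ToyAttention()
    handle = layer.register_forward_hook(make_capture_fn("attentions", 1))
    layer(torch.randn(1, 4, 8))
    handle.remove()
    print(len(collected_outputs["attentions"]))  # -> 1 captured weight tensor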