Arthur 2025-07-03 11:41:30 +02:00
parent a9690f43fd
commit d462a8ea38
3 changed files with 4 additions and 5 deletions

src/transformers/models/csm/modeling_csm.py

@@ -565,7 +565,7 @@ class CsmDepthDecoderForCausalLM(CsmPreTrainedModel, GenerationMixin):
     def get_decoder(self):
         return self.model

-    @check_model_inputs
+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
@@ -693,7 +693,7 @@ class CsmBackboneModel(CsmPreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value

-    @can_return_tuple
+    @check_model_inputs
     @auto_docstring
     def forward(
         self,
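
For context, a minimal runnable sketch of the decorator placement these two hunks settle on: @check_model_inputs on the base model's forward and @can_return_tuple on the LM head's forward. The stand-in decorators and stripped-down classes below are illustrative only; just the decorator names and their targets come from the diff.

def check_model_inputs(func):
    # Stand-in: the real decorator validates inputs and records requested outputs.
    return func

def can_return_tuple(func):
    # Stand-in: the real decorator converts a ModelOutput to a tuple when asked.
    return func

def auto_docstring(func):
    # Stand-in: the real decorator injects a generated docstring.
    return func

class CsmBackboneModel:  # base model: gets @check_model_inputs
    @check_model_inputs
    @auto_docstring
    def forward(self, **kwargs):
        return kwargs

class CsmDepthDecoderForCausalLM:  # LM head: gets @can_return_tuple
    @can_return_tuple
    @auto_docstring
    def forward(self, **kwargs):
        return kwargs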

src/transformers/models/csm/modular_csm.py

@@ -324,7 +324,7 @@ class CsmDepthDecoderForCausalLM(LlamaForCausalLM, GenerationMixin):
         return model_inputs

-    @check_model_inputs
+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
@@ -413,7 +413,7 @@ class CsmBackboneModel(LlamaModel):
         super().__init__(config)
         self.embed_tokens = CsmBackboneModelEmbeddings(config)

-    @can_return_tuple
+    @check_model_inputs
     @auto_docstring
     def forward(self, **super_kwargs):
         r"""

src/transformers/utils/generic.py

@@ -1023,7 +1023,6 @@ def check_model_inputs(func):
                 for k in capture_flags
             }
             recordable_keys["output_cross_attentions"] = recordable_keys.get("output_attentions", None)
-            print(recordable_keys)
             if any(recordable_keys.values()):
                 capture_tasks = []
                 for key, layer_specs in capture_flags.items():
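
For context, a minimal runnable sketch of the capture-flag logic this hunk edits (the change itself just drops a leftover debug print). The capture_flags tuple and the hook-registration step are hypothetical simplifications of the real implementation in transformers/utils/generic.py.

from functools import wraps

def check_model_inputs(func):
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # Hypothetical flag set; the real decorator derives it from the model.
        capture_flags = ("output_attentions", "output_hidden_states")
        # Collect the recording flags the caller actually passed.
        recordable_keys = {k: kwargs.get(k) for k in capture_flags}
        # Cross-attentions are gated by the same flag as self-attentions.
        recordable_keys["output_cross_attentions"] = recordable_keys.get("output_attentions")
        if any(recordable_keys.values()):
            # The real decorator registers forward hooks here to record the
            # requested tensors; a placeholder list keeps this sketch runnable.
            capture_tasks = []
        return func(self, *args, **kwargs)
    return wrapper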