huggingface/transformers (mirror of https://github.com/huggingface/transformers.git)

commit d462a8ea38
parent a9690f43fd

    fix csm!
@@ -565,7 +565,7 @@ class CsmDepthDecoderForCausalLM(CsmPreTrainedModel, GenerationMixin):
     def get_decoder(self):
         return self.model
 
-    @check_model_inputs
+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
@@ -693,7 +693,7 @@ class CsmBackboneModel(CsmPreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value
 
-    @can_return_tuple
+    @check_model_inputs
     @auto_docstring
     def forward(
         self,
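The two hunks above swap the decorators between the two classes: the backbone's forward now gets `@check_model_inputs` (which resolves the `output_*` flag kwargs and records the requested outputs), while the causal-LM head keeps `@can_return_tuple` (which converts the model output to a plain tuple when `return_dict=False`). A minimal sketch of that division of labor, using toy stand-ins rather than the real decorators from transformers:

```python
from functools import wraps

# Toy stand-ins for the two decorators this commit swaps. The real versions in
# transformers do considerably more (hook registration, config plumbing); this
# sketch only shows the division of labor implied by the diff.

def can_return_tuple(func):
    # Return the model output as-is, or as a plain tuple when return_dict=False.
    @wraps(func)
    def wrapper(self, *args, return_dict=True, **kwargs):
        output = func(self, *args, **kwargs)
        return output if return_dict else tuple(output.values())
    return wrapper

def check_model_inputs(func):
    # Resolve unset output_* flags from the model config before calling forward.
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        for flag in ("output_attentions", "output_hidden_states"):
            if kwargs.get(flag) is None:
                kwargs[flag] = getattr(self.config, flag, False)
        return func(self, *args, **kwargs)
    return wrapper

class Cfg:
    output_attentions = False
    output_hidden_states = True

class TinyModel:
    config = Cfg()

    @can_return_tuple
    @check_model_inputs
    def forward(self, x, output_attentions=None, output_hidden_states=None):
        return {"last_hidden_state": x * 2, "flags": (output_attentions, output_hidden_states)}

print(TinyModel().forward(3))                     # dict, with flags filled from the config
print(TinyModel().forward(3, return_dict=False))  # (6, (False, True))
```

In the fixed code each class uses only one of the two; they are stacked here just to show both paths on a single toy model.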
@@ -324,7 +324,7 @@ class CsmDepthDecoderForCausalLM(LlamaForCausalLM, GenerationMixin):
 
         return model_inputs
 
-    @check_model_inputs
+    @can_return_tuple
     @auto_docstring
     def forward(
         self,
@@ -413,7 +413,7 @@ class CsmBackboneModel(LlamaModel):
         super().__init__(config)
         self.embed_tokens = CsmBackboneModelEmbeddings(config)
 
-    @can_return_tuple
+    @check_model_inputs
     @auto_docstring
     def forward(self, **super_kwargs):
         r"""
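The third and fourth hunks repeat the same swap because CSM is a modular model: the `Llama*`-derived classes live in the hand-edited modular file, and the standalone `CsmPreTrainedModel` variants in the first two hunks are generated from it, so the fix has to land in both places. A toy illustration of that flattening idea (all names here are illustrative, not the real converter):

```python
# Conceptual sketch of transformers' "modular" pattern: a modular class only
# declares what differs from a shared base, and a code generator inlines the
# inherited behavior into a standalone modeling class. Purely illustrative.

class LlamaModelStandIn:
    def forward(self, x):
        return x + 1

class ModularBackbone(LlamaModelStandIn):   # modular file: override only the delta
    def forward(self, x):
        return super().forward(x) * 2

class GeneratedBackbone:                    # modeling file: flattened, self-contained
    def forward(self, x):
        return (x + 1) * 2                  # inherited logic inlined by the generator

assert ModularBackbone().forward(3) == GeneratedBackbone().forward(3) == 8
```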
@@ -1023,7 +1023,6 @@ def check_model_inputs(func):
             for k in capture_flags
         }
         recordable_keys["output_cross_attentions"] = recordable_keys.get("output_attentions", None)
-        print(recordable_keys)
         if any(recordable_keys.values()):
             capture_tasks = []
             for key, layer_specs in capture_flags.items():
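The last hunk simply deletes a stray debug print from the `check_model_inputs` helper. For context, a self-contained sketch of the surrounding capture-flag logic, with only the lines visible in the hunk taken from the source and everything else assumed:

```python
from functools import wraps

# Sketch of the helper the hunk touches. `capture_flags` is assumed to map a
# recordable output name to a layer spec; everything outside the lines shown
# in the diff is an assumption, not the real implementation.
capture_flags = {"attentions": "self_attn", "hidden_states": "layers"}

def check_model_inputs(func):
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # One "output_<name>" flag per capturable quantity: caller kwargs win,
        # then the model config, then False.
        recordable_keys = {
            f"output_{k}": kwargs.get(f"output_{k}", getattr(self.config, f"output_{k}", False))
            for k in capture_flags
        }
        recordable_keys["output_cross_attentions"] = recordable_keys.get("output_attentions", None)
        # The deleted print(recordable_keys) sat here, leaking a dict on every forward call.
        if any(recordable_keys.values()):
            capture_tasks = []
            for key, layer_specs in capture_flags.items():
                if recordable_keys.get(f"output_{key}"):
                    # The real helper registers forward hooks here to record outputs.
                    capture_tasks.append((key, layer_specs))
        return func(self, *args, **kwargs)
    return wrapper

class Cfg:
    output_attentions = False
    output_hidden_states = True

class Demo:
    config = Cfg()

    @check_model_inputs
    def forward(self, x):
        return x

print(Demo().forward(41))  # 41; the capture branch runs because output_hidden_states=True
```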