Fix TypeError: 'NoneType' object is not iterable for esm (#38667) (#38668)

Add post_init() calls to EsmForMaskedLM, EsmForTokenClassification and EsmForSequenceClassification.
This commit is contained in:
dbleyl 2025-06-09 11:23:20 -04:00 committed by GitHub
parent 11dca07a10
commit b9faf2f930
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1023,6 +1023,8 @@ class EsmForMaskedLM(EsmPreTrainedModel):
self.init_weights()
self.post_init()
def get_output_embeddings(self):
return self.lm_head.decoder
@ -1127,6 +1129,8 @@ class EsmForSequenceClassification(EsmPreTrainedModel):
self.init_weights()
self.post_init()
@auto_docstring
def forward(
self,
@ -1210,6 +1214,8 @@ class EsmForTokenClassification(EsmPreTrainedModel):
self.init_weights()
self.post_init()
@auto_docstring
def forward(
self,