add accelerate support for ESM (#20379)

This commit is contained in:
Younes Belkada 2022-11-22 14:06:00 +01:00 committed by GitHub
parent c0fe912840
commit ac3952b443
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -638,6 +638,7 @@ class EsmPreTrainedModel(PreTrainedModel):
config_class = EsmConfig
base_model_prefix = "esm"
_no_split_modules = ["EsmLayer"]
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
def _init_weights(self, module):