[ESM] fix accelerate tests for esmfold (#20387)

* fix `accelerate` tests for esmfold

* cleaner solution
This commit is contained in:
Younes Belkada 2022-11-22 18:26:55 +01:00 committed by GitHub
parent d2357a0133
commit 2e17db8a86
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 4 deletions

View File

@ -638,7 +638,7 @@ class EsmPreTrainedModel(PreTrainedModel):
config_class = EsmConfig
base_model_prefix = "esm"
_no_split_modules = ["EsmLayer"]
_no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock"]
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
def _init_weights(self, module):

View File

@ -1956,9 +1956,9 @@ class EsmFoldingTrunk(nn.Module):
for recycle_idx in range(no_recycles):
with ContextManagers([] if recycle_idx == no_recycles - 1 else [torch.no_grad()]):
# === Recycling ===
recycle_s = self.recycle_s_norm(recycle_s.detach())
recycle_z = self.recycle_z_norm(recycle_z.detach())
recycle_z += self.recycle_disto(recycle_bins.detach())
recycle_s = self.recycle_s_norm(recycle_s.detach()).to(device)
recycle_z = self.recycle_z_norm(recycle_z.detach()).to(device)
recycle_z += self.recycle_disto(recycle_bins.detach()).to(device)
s_s, s_z = trunk_iter(s_s_0 + recycle_s, s_z_0 + recycle_z, residx, mask)
@ -2207,6 +2207,9 @@ class EsmForProteinFolding(EsmPreTrainedModel):
return EsmForProteinFoldingOutput(**structure)
def af2_idx_to_esm_idx(self, aa, mask):
    """Look up ``self.af2_to_esm`` entries for the indices in ``aa``.

    Each index in ``aa`` is shifted up by one; positions where ``mask`` is
    not equal to 1 are then replaced with index 0 before the lookup.

    Side effect: if ``self.af2_to_esm`` lives on a different device than
    ``aa``, it is moved to ``aa.device`` and the moved tensor is stored back
    on ``self`` so later calls can reuse it.

    NOTE(review): the names suggest a mapping from AlphaFold2 residue indices
    to ESM token indices, with slot 0 acting as padding -- confirm against the
    code that builds ``af2_to_esm``.

    Args:
        aa: Integer index tensor.
        mask: Tensor broadcastable to ``aa``; entries equal to 1 are kept.

    Returns:
        The tensor obtained by indexing ``self.af2_to_esm`` with the shifted,
        masked indices.
    """
    target_device = aa.device
    # Indexing requires the lookup table and the indices to share a device.
    if self.af2_to_esm.device != target_device:
        self.af2_to_esm = self.af2_to_esm.to(target_device)

    shifted = aa + 1
    lookup_idx = shifted.masked_fill(mask != 1, 0)
    return self.af2_to_esm[lookup_idx]